Listening with LLM – moomou

2024-01-13 10:09:12

Overview

This is the first of many posts I’m writing to consolidate learnings on how to finetune Large Language Models (LLMs) to process audio, with the eventual goal of being able to build and host an LLM able to describe human voices.

I’m motivated to gain hands-on experience tinkering with LLMs, so, as much as practical, I tried to recreate utilities and functions with PyTorch from scratch rather than rely on third-party libraries.

tl;dr I chronicle and share the steps I took to learn how to finetune an LLM to describe a given audio file, using Google’s MusicCaps dataset.

Background

Recently, I came across two papers, SALMONN and Qwen-Audio, that give LLMs audio understanding capabilities.

Broadly speaking, both papers explored leveraging an audio encoder to transform sound into embeddings that are then fed into LLMs alongside text embeddings.

In SALMONN’s case, they combined OpenAI’s Whisper and the BEATs encoder, performed pretraining on the combined encoder, then leveraged LoRA for finetuning the LLM. Qwen-Audio bootstrapped its audio encoder from OpenAI’s Whisper; after pretraining, Qwen-Audio performs a full finetuning of the LLM.

These two papers gave me a great overview of how to adapt cross-domain encoders and combine them with LLMs. Excited by the idea of an LLM with general audio understanding ability, and itching to gain hands-on experience, I decided to try to build a minimum viable LLM with audio processing capability.

Setup

To get started, I hopped over to HuggingFace to find a good base LLM and a medium-sized dataset. I wanted to do as much work locally as possible, so everything had to run on a local RTX 3090.

After testing and evaluating a few different models, I settled on Mistral OpenOrca.

For the audio encoder, I went with OpenAI’s Whisper.

For the dataset, I chose MusicCaps. I didn’t see any convenient links to download processed/segmented audio files, so I wrote a small script to download the YouTube videos; a sketch of such a script is shown below.
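
For reference, here is a minimal sketch of that kind of download script, assuming the public MusicCaps CSV (with ytid, start_s, end_s columns) and that yt-dlp and ffmpeg are installed; the file names and paths are illustrative.

import csv
import subprocess

with open("musiccaps-public.csv") as f:
    for row in csv.DictReader(f):
        ytid, start, end = row["ytid"], row["start_s"], row["end_s"]
        tmp, out = f"/tmp/{ytid}.wav", f"audio/{ytid}.wav"
        # download the audio track only, then cut out the labeled segment
        subprocess.run(
            ["yt-dlp", "-x", "--audio-format", "wav", "-o", tmp,
             f"https://www.youtube.com/watch?v={ytid}"],
            check=False,
        )
        subprocess.run(
            ["ffmpeg", "-y", "-i", tmp, "-ss", str(start), "-to", str(end), out],
            check=False,
        )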

One Mini Step at a Time

With the basic dependencies out of the way, I fired up my Jupyter notebook and started tinkering.

Sampling from Scratch

The first step I took was to ensure I could load the base LLM and perform inference correctly. Instead of leveraging the transformers library’s generation utilities, I implemented my own sampling function to verify my understanding, as well as to learn how to sample using embeddings directly, which will come in handy when feeding in audio embeddings.

@torch.no_grad()
def sampler(input_ids):
    for _ in range(50):
        inputs_embeds = model.llm.model.embed_tokens(input_ids)
        res = model.llm(inputs_embeds=inputs_embeds)
        # res.logits shape is (batch, seq_len, vocab_size)
        # sample from a multinomial distribution over the last position's logits
        sampled = torch.multinomial(res.logits[:, -1, :].softmax(dim=-1), 1)
        # repeatedly concat `sampled` to `input_ids` for the next sampling step
        input_ids = torch.cat((input_ids, sampled), dim=-1)

    return input_ids

Using the tokenizer obtained from the transformers library’s AutoTokenizer class, I was able to verify sampling worked as expected! Running

tokenizer.decode(sampler(tokenizer("tell me a story", return_tensors="pt").input_ids.to("cuda:0"))[0])

yields (as an example output)

'<s>tell me a story is a film and video production company, tell me a story is a concept that was created to allow people to come together through the power of storytelling.\n and so, with this huge power in storytelling, the founders and creat'

Debugging NaNs and Infs

So far so good. However, I soon noticed that, occasionally, the sampling function would fail, complaining that the softmax function encountered an inf or NaN. I followed this insightful thread and learned to identify the source of NaNs by using the following adapted PyTorch hooks

import torch
from functools import partial

__registered_hook_refs = []

for h in __registered_hook_refs:
    h.remove()

__global = []

def nan_hook(module, args, output, name=None):
    if not isinstance(output, tuple):
        outputs = [output]
    else:
        outputs = output

    for i, out in enumerate(outputs):
        if out is None:
            continue
        if isinstance(out, tuple):
            for j, out2 in enumerate(out):
                nan_mask = torch.isnan(out2)
                if nan_mask.any():
                    __global.append((module, args, output))
                    raise RuntimeError(f"In module {name} of class {module.__class__.__name__}, found NaN in output {j} at indices: ", nan_mask.nonzero(), "where:",
                               out2[nan_mask.nonzero()[:, 0].unique(sorted=True)])

        elif torch.is_tensor(out):
            nan_mask = torch.isnan(out)

            if nan_mask.any():
                __global.append((module, args, output))
                raise RuntimeError(f"In module {name} of class {module.__class__.__name__}, found NaN in output {i} at indices: ", nan_mask.nonzero(), "where:",
                                   out[nan_mask.nonzero()[:, 0].unique(sorted=True)])

def register_nan_hook(model: torch.nn.Module):
    for name, submodule in model.named_modules():
        new_hook = partial(nan_hook, name=name + '.back')
        hook_ref = submodule.register_full_backward_hook(new_hook)
        __registered_hook_refs.append(hook_ref)
        new_hook = partial(nan_hook, name=name + '.fwd')
        hook_ref = submodule.register_forward_hook(new_hook)
        __registered_hook_refs.append(hook_ref)

debug = True
register_nan_hook(model) if debug else None

Leveraging these hooks narrowed the source of the issue down to a specific layer, and from there I was able to trace the problem to an inf value in the model weights. Digging further, I traced the source of the inf to bad RAM sticks! After mitigation, I wrote a small script to verify the model weights and confirmed the sampling function worked as expected.

# verify model weights
from collections import Counter

pbytype = Counter()
for name, p in model.named_parameters():
    if torch.isinf(p).any() or torch.isnan(p).any():
        print(name, p)
        raise ValueError("invalid weight")
    else:
        pbytype[p.dtype] += 1
print("OK", pbytype)

Adapting Whisper to Mistral

After gaining confidence with debugging PyTorch modules, I focused on adapting the Whisper model so audio files could be transformed into embeddings that can then be fed into Mistral.

OpenAI’s Whisper model consists of two main components, an AudioEncoder and a TextDecoder. For the purpose of translating audio into embeddings, I only need the AudioEncoder component.

Therefore, I loaded up a full Whisper model and extracted the AudioEncoder weights using the following snippet

import torch
import whisper

model = whisper.load_model("large-v3")
audio_encoder = model.encoder
torch.save(
    audio_encoder.state_dict(),
    "<output_location>",
)

I adapted the Whisper AudioEncoder into a TunableWhisperAudioEncoder with an additional projection layer to map from Whisper’s audio embedding (size 1280) to Mistral’s token embedding (size 4096).

I ensured proj is the only trainable network by explicitly freezing the audio encoder’s parameters. Note that TrainableSubmodule is a hyperparameter, and any model that maps the output embedding to size 4096 will work. Later in the post, I’ll describe what I found to work for me.

class TunableWhisperAudioEncoder(nn.Module):
    def __init__(self, *, output_embedding_size=4096):
        """
        args
            output_embedding_size: int = 4096 / Mistral's default embedding size
        """
        super().__init__()

        self.audio_encoder = load_whisper_v3_audio_encoder()
        self.proj = TrainableSubmodule(output_embedding_size=output_embedding_size)

        # freeze all audio encoder parameters
        for param in self.audio_encoder.parameters():
            param.requires_grad = False

    def forward(self, mels):
        res = self.audio_encoder(mels)
        res = self.proj(res)
        return res

def load_whisper_v3_audio_encoder(
    *,
    n_mels=128,
    n_audio_ctx=1500,
    n_audio_state=1280,
    n_audio_head=20,
    n_audio_layer=32,
):
    m = whisper.model.AudioEncoder(
        n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer
    )
    m.load_state_dict(torch.load(WHISPER_AUDIO_BIN))
    return m

Finally, I built up the model I’m going to use for training as follows

class Model(nn.Module):
    def __init__(self, audio_encoder: "TunableWhisperAudioEncoder", llm: "Mistral"):
        super().__init__()

        self.audio_encoder = audio_encoder
        self.llm = llm

        # freeze the LLM weights
        for p in self.llm.parameters():
            p.requires_grad = False

    def forward(self, batch):
        audio_mels = batch["audio_mels"]
        # caption token ids
        cap_ids = batch["cap_ids"]
        # caption attention mask
        cap_ids_attention_mask = batch["cap_attention_mask"]
        prompt_ids = batch["prompt_ids"]
        prompt_ids_attention_mask = batch["prompt_attention_mask"]
        end_prompt_ids = batch["end_prompt_ids"]
        end_prompt_ids_attention_mask = batch["end_prompt_attention_mask"]

        audio_embeds = self.audio_encoder(audio_mels)
        # audio_embeds: (batch, audio_seq_len, audio_embedding_size)
        bs, audio_seq = audio_embeds.shape[:2]

        attention_mask = torch.concat(
            (
                prompt_ids_attention_mask,
                torch.ones(bs, audio_seq).to(cap_ids.device),
                end_prompt_ids_attention_mask,
                cap_ids_attention_mask,
            ),
            dim=1,
        )

        cap_embeds = self.llm.model.embed_tokens(cap_ids)
        prompt_embeds = self.llm.model.embed_tokens(prompt_ids)
        end_prompt_embeds = self.llm.model.embed_tokens(end_prompt_ids)

        # build the inputs_embeds by concatenating all the token embeddings
        # with the audio embeddings
        inputs_embeds = torch.concat(
            (
                prompt_embeds,
                audio_embeds.to(cap_embeds.dtype),
                end_prompt_embeds,
                cap_embeds,
            ),
            dim=1,
        )

        mout = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )

        return mout, audio_embeds.shape[1]

The model itself is quite simple in that it merely holds references to the Mistral LLM and the TunableWhisperAudioEncoder. The forward method encapsulates the logic of converting the audio mel-spectrogram into audio embeddings, then concatenating the audio embeddings with the text/token embeddings before feeding them into the Mistral LLM.

Sampling with Audio from Scratch

With the basic model in place, the next step is to try to sample from this model with audio inputs. Here is the audio sampling function I came up with.

# note, the full gist is available at https://gist.github.com/moomou/7df8345d79a0063d67d1fa2b4cf55db8

@torch.no_grad()
def sample_with_audio(model, tokenizer, prompt, audio_file, device="cuda:0", iteration=50):
    audio_mels = load_audio_mels(audio_file).to(device).half()
    end_prompt_ids, end_prompt_attention_mask = text_2_ids_and_attention_mask(
        tokenizer,
        end_template(),
        truncate=True,
    )
    prompt_ids, prompt_attention_mask = text_2_ids_and_attention_mask(
        tokenizer,
        prompt,
    )

    prompt_ids = prompt_ids.to(device)
    prompt_attention_mask = prompt_attention_mask.to(device)
    end_prompt_attention_mask = end_prompt_attention_mask.to(device)
    end_prompt_ids = end_prompt_ids.to(device)
    sampled_ids = None

    prompt_embeds = None
    end_prompt_embeds = None
    audio_embeds = None

    with torch.amp.autocast(device_type="cuda", dtype=torch.float16):  # use float16 to reduce GPU memory
        if audio_embeds is None:
            audio_embeds = model.audio_encoder(audio_mels)
        bs, audio_seq = audio_embeds.shape[:2]

        # base attention mask segments covering the prompt, audio, and end-prompt
        base_mask_args = [
            prompt_attention_mask,
            torch.ones(bs, audio_seq).to(audio_embeds.device),
            end_prompt_attention_mask,
        ]

        for _ in range(iteration):
            # rebuild the mask each iteration so it matches the current sampled length
            mask_concat_args = list(base_mask_args)
            if sampled_ids is not None:
                mask_concat_args.append(torch.ones(bs, sampled_ids.shape[1]).to(audio_embeds.device))

            attention_mask = torch.concat(
                tuple(mask_concat_args),
                dim=1,
            )

            if prompt_embeds is None:
                prompt_embeds = model.llm.model.embed_tokens(prompt_ids)
            if end_prompt_embeds is None:
                end_prompt_embeds = model.llm.model.embed_tokens(end_prompt_ids)

            sampled_ids_embeds = None
            if sampled_ids is not None:
                sampled_ids_embeds = model.llm.model.embed_tokens(sampled_ids)

            embeds_concat_args = [
                prompt_embeds,
                audio_embeds.to(prompt_embeds.dtype),
                end_prompt_embeds,
            ]
            if sampled_ids_embeds is not None:
                embeds_concat_args.append(sampled_ids_embeds)

            inputs_embeds = torch.concat(
                tuple(embeds_concat_args),
                dim=1,
            )

            mout = model.llm(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
            )

            logits = mout.logits
            sampled = torch.multinomial(logits[:, -1, :].softmax(dim=-1), 1)

            if sampled_ids is None:
                sampled_ids = sampled
            else:
                sampled_ids = torch.cat((sampled_ids, sampled), dim=-1).to(device)

    return torch.concat((
        prompt_ids,
        end_prompt_ids,
        sampled_ids,
    ), dim=-1)
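
The load_audio_mels helper is not shown in the snippet; a minimal version, assuming openai-whisper's preprocessing utilities and the 128 mel bins that large-v3 expects, could look like this.

import whisper

def load_audio_mels(audio_file):
    # load, pad/trim to 30s, and compute the 128-bin log-mel spectrogram for large-v3
    audio = whisper.load_audio(audio_file)
    audio = whisper.pad_or_trim(audio)
    mels = whisper.log_mel_spectrogram(audio, n_mels=128)
    # add a batch dimension: (1, n_mels, n_frames)
    return mels.unsqueeze(0)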

Putting the function to use via

dataloader = ...  # standard PyTorch dataloader
local_batch = next(iter(dataloader))
tokenizer.decode(sample_with_audio(model, tokenizer, prompt_template_fn(), audio_file, iteration=60)[0])

produces gibberish, as expected, since the TunableWhisperAudioEncoder projection layer is untrained.

'<s> <|im_start|> system\n    You are a helpful AI who follows instruction carefully<|im_end|> <|im_start|> user\n    Describe the sound of the given file \n    <|im_end|> <|im_start|> assistant\n     battle<|im_end|> clockunits ]andfirst4Iftektime爆R Cur<|im_end|> United<|im_end|> ’daysIn“Never<|im_end|> thenAnd,and VI<|im_end|> Islo<|im_end|> GOkaydown<|im_end|> JainteYoulfailedLabelsEvenfacevC,rest<|im_end|><|im_end|><|im_end|><|im_end|> q<|im_end|> Xs<|im_end|> h<|im_end|><|im_end|>'

Defining the Loss Function

The loss function here is the standard cross-entropy loss on the logits output; the only trick is that the loss should only be calculated on the caption portion. Specifically,

# calculate loss
# local_batch: (b, seq, C)
prompt_ids_seq = local_batch["prompt_ids"].shape[1]
end_prompt_ids_seq = local_batch["end_prompt_ids"].shape[1]
logits_start = prompt_ids_seq + audio_seq + end_prompt_ids_seq

logits = ...  # model output
# drop the prompt and audio positions from the logits calculation;
# additionally, drop the final position
logits = logits[:, logits_start:-1, :].contiguous()

# build targets using only `cap_ids`, shifted by one
targets = local_batch["cap_ids"][:]
targets = targets[:, 1:]

loss = nn.functional.cross_entropy(
    logits.view(-1, logits.shape[-1]), targets.view(-1)
)

Training, Overfitting and Debugging Gradients

Finally, all the pieces are in place for training the model. The objective I had in mind was to make the frozen LLM describe a given audio file by training only the TunableWhisperAudioEncoder; achieving this won’t give the LLM general audio understanding ability, since the training data is small, but it will give me great confidence that I implemented all the basic steps right.
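
As a rough sketch (assuming the Model, dataloader, and loss snippet above; compute_caption_loss is a hypothetical wrapper around that loss code), a single training step looked conceptually like this.

trainable_params = [p for p in model.parameters() if p.requires_grad]
opt = torch.optim.Adam(trainable_params, lr=1.5e-3, betas=(0.0, 0.999))  # Adam without momentum

for local_batch in dataloader:
    opt.zero_grad()
    mout, audio_seq = model(local_batch)
    # cross-entropy restricted to the caption portion, as described above
    loss = compute_caption_loss(mout.logits, local_batch, audio_seq)
    loss.backward()
    opt.step()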


In order to ensure training was set up correctly, I started small and took one step at a time. Specifically, I interactively stepped through the training steps manually, recorded and plotted the weight update relative to the weight data in TunableWhisperAudioEncoder (see the sketch below), and ensured there were no infs or NaNs using the PyTorch hooks described previously. These steps were repeated for various combinations of learning rate, model architecture, and optimizer.
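
For the weight-update check, I tracked a simple update-to-weight ratio per trainable parameter; a minimal sketch of what that logging can look like (the helper name is my own) is below.

update_ratios = {}

@torch.no_grad()
def log_update_to_weight_ratio(model, lr):
    # approximate step size relative to the weight scale;
    # values around 1e-3 per step are a common rule of thumb for healthy training
    for name, p in model.named_parameters():
        if p.requires_grad and p.grad is not None:
            ratio = (lr * p.grad.std() / (p.std() + 1e-8)).item()
            update_ratios.setdefault(name, []).append(ratio)

# called after loss.backward() and before opt.step(), then plotted over training steps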

(Figure: weight update relative to weight data over training steps)

Keeping the setup as simple as possible, I found that with Adam (without momentum), a constant learning rate of 1.5e-3, and the following simple TrainableSubmodule, I achieved stable training.

class TrainableSubmodule(nn.Module):
    def __init__(self, output_embedding_size=4096):
        super().__init__()

        # pool the audio sequence down to 250 positions
        self.pool = nn.AdaptiveAvgPool1d(250)
        # project Whisper's 1280-dim features to Mistral's 4096-dim token embeddings
        self.proj = nn.Linear(1280, output_embedding_size, bias=False)
        self.ln1 = nn.LayerNorm(1280)
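
The snippet only shows the layers; the forward pass is omitted. My best-guess wiring, given the shapes involved (layer norm over the 1280-dim Whisper features, pooling the 1500-step sequence down to 250 positions, then projecting to 4096), is sketched below.

    # (method of TrainableSubmodule; the exact wiring here is my assumption)
    def forward(self, x):
        # x: (batch, audio_seq_len, 1280) from the Whisper AudioEncoder
        x = self.ln1(x)
        # AdaptiveAvgPool1d pools the last dimension, so pool over the sequence axis
        x = self.pool(x.transpose(1, 2)).transpose(1, 2)  # (batch, 250, 1280)
        return self.proj(x)  # (batch, 250, 4096)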

I ran training over the course of ~4 days, and by the time I stopped, the loss was still going down. I ended at a loss of ~0.46; since cross-entropy is the average negative log-probability of the correct token, that translates to roughly exp(-0.46) ≈ 63% probability for the correct token!

(Figure: average training loss over the course of training)

Rerunning sample_with_audio with the same audio file that produced gibberish pre-training, I now obtain

"<s> <|im_start|>  systemn    You're a useful AI who follows instruction rigorously<|im_end|> <|im_start|>  usern    Describe the sound of the given file n    <|im_end|> <|im_start|> assistantn     The electronica tune incorporates a crisp acoustic kick, snap snare and hat together with a deep bass. The male vocal is rapping syncopated together with a male background vocal. The tune is quick paced and there's a restricted frequency vary of the synths. The tune"

Compare this against the ground truth

"This can be a Ok-pop music piece carried out by a boy band. Initially, a male vocalist is singing in a rap-like method. Then, it switches to a different male vocal that's singing extra melodically. The melody is being performed by a crisp synth sound. The rhythmic background consists of an lively digital drum beat. There's a danceable really feel to it. This piece could possibly be taking part in at Korean nightclubs and dance golf equipment."

The result is pretty good!

It’s worth repeating that this is achieved by only training the audio encoder projection, without modifying the LLM weights or the Whisper AudioEncoder weights.

Next Steps

With the fundamentals in place, I’m planning to scale up training by incorporating more audio tasks, such as transcription and speaker identification, as well as apply finetuning to the LLM to work my way toward replicating the “emergent” behaviors described in the referenced papers.

Assuming sufficient data and a proper training regime, the LLM should be able to perform novel audio tasks, such as identifying a speaker’s age or gender, without having been explicitly trained on those tasks.

More work to be done!

Acknowledgement

I would not have been able to do any of this without learning from the excellent lectures by Karpathy.
