Listening with LLM – moomou
Overview
This is the first of several posts I'm writing to consolidate learnings on how to finetune Large Language Models (LLMs) to process audio, with the eventual goal of being able to build and host an LLM able to describe human voices.
I'm motivated to gain hands-on experience tinkering with LLMs, so, as much as practical, I tried to recreate utilities and functions with PyTorch from scratch rather than rely on third-party libraries.
tl;dr I chronicle and share the steps I took to learn how to finetune an LLM to describe a given audio file using Google's MusicCaps dataset.
Background
Recently, I came across two papers, SALMONN and Qwen-Audio, that explore how to give LLMs audio understanding capabilities.
Broadly speaking, both papers leverage an audio encoder to transform sound into embeddings that are then fed into the LLM along with text embeddings.
In SALMONN's case, they combined OpenAI's Whisper and the BEATs encoder, performed pretraining on the combined encoder, then leveraged LoRA for finetuning the LLM. Qwen-Audio bootstrapped its audio encoder from OpenAI's Whisper; after pretraining, Qwen-Audio performs a full finetuning of the LLM.
These two papers gave me a great overview of how to adapt cross-domain encoders and combine them with LLMs.
Excited by the idea of an LLM with general audio understanding ability and itching to gain hands-on experience, I decided to try to build a minimal viable LLM with audio processing capability.
Setup
To get started, I hopped over to HuggingFace to find a good base LLM and a medium-sized dataset. I wanted to do as much work locally as possible, so everything had to run on a local RTX 3090.
After testing and evaluating a few different models, I settled on Mistral OpenOrca.
For the audio encoder, I went with OpenAI's Whisper.
For the dataset, I chose MusicCaps. I didn't see any convenient links to download processed/segmented audio files, so I wrote a small script to download the YouTube videos.
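The script itself isn't included here, but a minimal sketch of the approach, assuming the MusicCaps CSV exposes ytid, start_s, and end_s columns and that yt-dlp plus ffmpeg are installed, could look like this:
# minimal sketch: download and trim each MusicCaps clip with yt-dlp
# (the CSV filename and column names are assumptions, not my exact script)
import csv
import subprocess

with open("musiccaps-public.csv") as f:
    for row in csv.DictReader(f):
        ytid, start, end = row["ytid"], float(row["start_s"]), float(row["end_s"])
        subprocess.run(
            [
                "yt-dlp",
                "-x", "--audio-format", "wav",
                "--download-sections", f"*{start}-{end}",
                "-o", f"audio/{ytid}.%(ext)s",
                f"https://www.youtube.com/watch?v={ytid}",
            ],
            check=False,  # some videos are private/removed; skip failures
        )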
One Mini Step at a Time
With the basic dependencies out of the way, I fired up my Jupyter notebook and started tinkering.
Sampling from Scratch
The first step I took was to make sure I could load the base LLM and perform inference correctly. Instead of leveraging the transformers library's generation utilities, I implemented my own sampling function to verify my understanding as well as to learn how to sample using embeddings directly, which will come in handy when feeding in audio embeddings.
@torch.no_grad()
def sampler(input_ids):
    for _ in range(50):
        inputs_embeds = model.llm.model.embed_tokens(input_ids)
        res = model.llm(inputs_embeds=inputs_embeds)
        # res.logits shape is (batch, seq_len, logits)
        # sample from a multinomial over the last position's logits
        sampled = torch.multinomial(res.logits[:, -1, :].softmax(dim=-1), 1)
        # repeatedly concat `sampled` to `input_ids` for the next sampling step
        input_ids = torch.cat((input_ids, sampled), dim=-1)
    return input_ids
Using the tokenizer obtained from Transformers' AutoTokenizer class, I was able to verify sampling worked as expected! Running
tokenizer.decode(sampler(tokenizer("tell me a story", return_tensors="pt").input_ids.to("cuda:0"))[0])
yields (as an example output)
'<s>tell me a story is a film and video production company, tell me a story is a concept that was created to allow people to come together through the power of storytelling.\n and so, with this great power in storytelling, the founders and creat'
Debugging NaNs and Infs
So far so good. However, I soon noticed that, occasionally, the sampling function would fail, complaining that the softmax function encountered an inf or NaN. I followed this insightful thread and learned to identify the source of NaNs by using the following adapted PyTorch hooks.
import torch
from functools import partial

__registered_hook_refs = []
for h in __registered_hook_refs:
    h.remove()

__global = []

def nan_hook(module, args, output, name=None):
    if not isinstance(output, tuple):
        outputs = [output]
    else:
        outputs = output

    for i, out in enumerate(outputs):
        if out is None:
            continue
        if isinstance(out, tuple):
            for j, out2 in enumerate(out):
                nan_mask = torch.isnan(out2)
                if nan_mask.any():
                    __global.append((module, args, output))
                    raise RuntimeError(
                        f"In module {name} of class {module.__class__.__name__}, found NaN in output {j} at indices: ",
                        nan_mask.nonzero(),
                        "where:",
                        out2[nan_mask.nonzero()[:, 0].unique(sorted=True)],
                    )
        elif torch.is_tensor(out):
            nan_mask = torch.isnan(out)
            if nan_mask.any():
                __global.append((module, args, output))
                raise RuntimeError(
                    f"In module {name} of class {module.__class__.__name__}, found NaN in output {i} at indices: ",
                    nan_mask.nonzero(),
                    "where:",
                    out[nan_mask.nonzero()[:, 0].unique(sorted=True)],
                )

def register_nan_hook(model: torch.nn.Module):
    for name, submodule in model.named_modules():
        new_hook = partial(nan_hook, name=name + ".back")
        hook_ref = submodule.register_full_backward_hook(new_hook)
        __registered_hook_refs.append(hook_ref)

        new_hook = partial(nan_hook, name=name + ".fwd")
        hook_ref = submodule.register_forward_hook(new_hook)
        __registered_hook_refs.append(hook_ref)

debug = True
register_nan_hook(model) if debug else None
Leveraging these hooks narrowed the source of the issue down to a specific layer, and from there I was able to trace the problem to an inf value in the model weights. Digging further, I traced the source of the inf to bad RAM sticks! After mitigation, I wrote a small script to verify the model weights and confirmed the sampling function worked as expected.
# verify model weights
from collections import Counter

pbytype = Counter()
for name, p in model.named_parameters():
    if torch.isinf(p).any() or torch.isnan(p).any():
        print(name, p)
        raise ValueError("invalid weight")
    else:
        pbytype[p.dtype] += 1
print("OK", pbytype)
Adapting Whisper to Mistral
After gaining confidence with debugging PyTorch modules, I focused on adapting the Whisper model so audio files could be transformed into embeddings that can then be fed into Mistral.
OpenAI's Whisper model consists of two main components, an AudioEncoder and a TextDecoder. For the purpose of translating audio into embeddings, I only need the AudioEncoder component. Therefore, I loaded up a full Whisper model and extracted the AudioEncoder weights using the following snippet
import whisper

model = whisper.load_model("large-v3")
audio_encoder = model.encoder
torch.save(
    audio_encoder.state_dict(),
    "<output_location>",
)
I adapted the Whisper AudioEncoder into a TunableWhisperAudioEncoder with an additional projection layer to map from Whisper's audio embedding (size 1280) to Mistral's token embedding (size 4096).
I ensured proj is the only trainable network by explicitly freezing the audio encoder's parameters. Note that TrainableSubmodule is a hyperparameter: any model that maps the output embedding to size 4096 will work. Later in the post, I'll describe what I found to work for me.
class TunableWhisperAudioEncoder(nn.Module):
    def __init__(self, *, output_embedding_size=4096):
        """
        args
            output_embedding_size: int = 4096 / mistral default embedding size
        """
        super().__init__()

        self.audio_encoder = load_whisper_v3_audio_encoder()
        self.proj = TrainableSubmodule(output_embedding_size=output_embedding_size)

        # freeze all audio encoder parameters
        for param in self.audio_encoder.parameters():
            param.requires_grad = False

    def forward(self, mels):
        res = self.audio_encoder(mels)
        res = self.proj(res)
        return res

def load_whisper_v3_audio_encoder(
    *,
    n_mels=128,
    n_audio_ctx=1500,
    n_audio_state=1280,
    n_audio_head=20,
    n_audio_layer=32,
):
    m = whisper.model.AudioEncoder(
        n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer
    )
    m.load_state_dict(torch.load(WHISPER_AUDIO_BIN))
    return m
Finally, I build up the model I'm going to use for training as follows
class Model(nn.Module):
    def __init__(self, audio_encoder: "Whisper.AudioEncoder", llm: "Mistral"):
        super().__init__()

        self.audio_encoder = audio_encoder
        self.llm = llm

        # freeze the LLM weights
        for p in self.llm.parameters():
            p.requires_grad = False

    def forward(self, batch):
        audio_mels = batch["audio_mels"]
        # caption token ids
        cap_ids = batch["cap_ids"]
        # caption attention mask
        cap_ids_attention_mask = batch["cap_attention_mask"]
        prompt_ids = batch["prompt_ids"]
        prompt_ids_attention_mask = batch["prompt_attention_mask"]
        end_prompt_ids = batch["end_prompt_ids"]
        end_prompt_ids_attention_mask = batch["end_prompt_attention_mask"]

        audio_embeds = self.audio_encoder(audio_mels)
        # audio_embeds: (batch, audio_seq_len, audio_embedding_size)
        bs, audio_seq = audio_embeds.shape[:2]

        attention_mask = torch.concat(
            (
                prompt_ids_attention_mask,
                torch.ones(bs, audio_seq).to(cap_ids.device),
                end_prompt_ids_attention_mask,
                cap_ids_attention_mask,
            ),
            dim=1,
        )

        cap_embeds = self.llm.model.embed_tokens(cap_ids)
        prompt_embeds = self.llm.model.embed_tokens(prompt_ids)
        end_prompt_embeds = self.llm.model.embed_tokens(end_prompt_ids)

        # build the inputs_embeds by concatenating all the token embeddings
        # with the audio embeddings
        inputs_embeds = torch.concat(
            (
                prompt_embeds,
                audio_embeds.to(cap_embeds.dtype),
                end_prompt_embeds,
                cap_embeds,
            ),
            dim=1,
        )

        mout = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )

        return mout, audio_embeds.shape[1]
The model itself is quite simple in that it merely holds references to the Mistral LLM and the TunableWhisperAudioEncoder. The forward method encapsulates the logic of converting the audio mel-spectrogram into audio embeddings, then concatenating the audio embeddings with the text/token embeddings and feeding these into the Mistral LLM.
Sampling with Audio from Scratch
With the basic model in place, the next step is to try to sample from this model with audio inputs. Here is the audio sampling function I came up with.
# note, the full gist is available at https://gist.github.com/moomou/7df8345d79a0063d67d1fa2b4cf55db8
@torch.no_grad()
def sample_with_audio(model, tokenizer, prompt, audio_file, device="cuda:0", iteration=50):
    audio_mels = load_audio_mels(audio_file).to(device).half()

    end_prompt_ids, end_prompt_attention_mask = text_2_ids_and_attention_mask(
        tokenizer,
        end_template(),
        truncate=True,
    )
    prompt_ids, prompt_attention_mask = text_2_ids_and_attention_mask(
        tokenizer,
        prompt,
    )

    prompt_ids = prompt_ids.to(device)
    prompt_attention_mask = prompt_attention_mask.to(device)
    end_prompt_attention_mask = end_prompt_attention_mask.to(device)
    end_prompt_ids = end_prompt_ids.to(device)

    sampled_ids = None
    prompt_embeds = None
    end_prompt_embeds = None
    audio_embeds = None

    with torch.amp.autocast(device_type="cuda", dtype=torch.float16):  # use float16 to reduce GPU memory
        if audio_embeds is None:
            audio_embeds = model.audio_encoder(audio_mels)

        bs, audio_seq = audio_embeds.shape[:2]
        mask_concat_args = [
            prompt_attention_mask,
            torch.ones(bs, audio_seq).to(audio_embeds.device),
            end_prompt_attention_mask,
        ]

        for _ in range(iteration):
            # rebuild the per-step mask list so stale entries don't accumulate across iterations
            step_mask_args = list(mask_concat_args)
            if sampled_ids is not None:
                step_mask_args.append(torch.ones(bs, sampled_ids.shape[1]).to(audio_embeds.device))

            attention_mask = torch.concat(
                tuple(step_mask_args),
                dim=1,
            )

            if prompt_embeds is None:
                prompt_embeds = model.llm.model.embed_tokens(prompt_ids)
            if end_prompt_embeds is None:
                end_prompt_embeds = model.llm.model.embed_tokens(end_prompt_ids)

            sampled_ids_embeds = None
            if sampled_ids is not None:
                sampled_ids_embeds = model.llm.model.embed_tokens(sampled_ids)

            embeds_concat_args = [
                prompt_embeds,
                audio_embeds.to(prompt_embeds.dtype),
                end_prompt_embeds,
            ]
            if sampled_ids_embeds is not None:
                embeds_concat_args.append(sampled_ids_embeds)

            inputs_embeds = torch.concat(
                tuple(embeds_concat_args),
                dim=1,
            )

            mout = model.llm(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
            )

            logits = mout.logits
            sampled = torch.multinomial(logits[:, -1, :].softmax(dim=-1), 1)

            if sampled_ids is None:
                sampled_ids = sampled
            else:
                sampled_ids = torch.cat((sampled_ids, sampled), dim=-1).to(device)

    return torch.concat((
        prompt_ids,
        end_prompt_ids,
        sampled_ids,
    ), dim=-1)
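The helpers load_audio_mels and text_2_ids_and_attention_mask live in the linked gist; a minimal version of load_audio_mels, assuming Whisper's bundled preprocessing utilities, might look like:
import whisper

def load_audio_mels(audio_file):
    # load the file, pad/trim to Whisper's 30-second window, and compute the
    # 128-bin log-mel spectrogram expected by the large-v3 encoder
    audio = whisper.load_audio(audio_file)
    audio = whisper.pad_or_trim(audio)
    mels = whisper.log_mel_spectrogram(audio, n_mels=128)
    return mels.unsqueeze(0)  # add a batch dimension: (1, 128, 3000)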
Putting the function to use via
dataloader = ...  # standard pytorch dataloader
local_batch = next(iter(dataloader))
tokenizer.decode(sample_with_audio(model, tokenizer, prompt_template_fn(), audio_file, iteration=60)[0])
produces gibberish, as expected, since the TunableWhisperAudioEncoder projection layer is untrained.
'<s> <|im_start|> system\n You are a helpful AI who follows instruction carefully<|im_end|> <|im_start|> user\n Describe the sound of the given file \n <|im_end|> <|im_start|> assistant\n battle<|im_end|> clockunits ]andfirst4Iftektime爆R Cur<|im_end|> United<|im_end|> 'daysIn"Never<|im_end|> thenAnd,and VI<|im_end|> Islo<|im_end|> GOkaydown<|im_end|> JainteYoulfailedLabelsEvenfacevC,rest<|im_end|><|im_end|><|im_end|><|im_end|> q<|im_end|> Xs<|im_end|> h<|im_end|><|im_end|>'
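As an aside, prompt_template_fn and end_template are small helpers producing the ChatML-style framing visible in the decoded output above; a minimal sketch (reconstructed from that output, so treat the exact wording as an assumption) could be:
def prompt_template_fn():
    # ChatML-style system + user prefix; the audio embeddings get spliced in
    # right after this text (see the Model.forward concatenation above)
    return (
        "<|im_start|>system\n"
        "You are a helpful AI who follows instruction carefully<|im_end|>\n"
        "<|im_start|>user\n"
        "Describe the sound of the given file\n"
    )

def end_template():
    # closes the user turn and opens the assistant turn
    return "<|im_end|>\n<|im_start|>assistant\n"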
Defining the Loss Function
The loss function here is the standard cross entropy loss on the output logits; the only trick is that the loss should only be calculated on the caption portion. Specifically,
# calculate loss
# local_batch: (b, seq, C)
prompt_ids_seq = local_batch["prompt_ids"].shape[1]
end_prompt_ids_seq = local_batch["end_prompt_ids"].shape[1]
logits_start = prompt_ids_seq + audio_seq + end_prompt_ids_seq

logits = ...  # model output
# remove the prompt and audio seq from the logits
# calculation; additionally, remove the final item
logits = logits[:, logits_start:-1, :].contiguous()

# calculate the targets using only `cap_ids`, shifted by one for next-token prediction
targets = local_batch["cap_ids"][:]
targets = targets[:, 1:]

loss = nn.functional.cross_entropy(
    logits.view(-1, logits.shape[-1]), targets.view(-1)
)
Training, Overfitting and Debugging Gradients
Finally, all the pieces are in place for training the model. The objective I had in mind was to make the frozen LLM describe a given audio file by training only the TunableWhisperAudioEncoder; achieving this won't give the LLM general audio understanding ability, since the training data is small, but it would give me great confidence that I implemented all the basic steps right.
In order to ensure training was set up correctly, I started small, one step at a time. Specifically, I interactively stepped through the training steps manually, recorded and plotted the weight updates relative to the weight data in TunableWhisperAudioEncoder, and ensured there was no inf or NaN using the PyTorch hooks described previously. These steps were repeated for various combinations of learning rate, model architecture, and optimizer.
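To make this concrete, here is a minimal sketch of the per-step check I ran; compute_loss is a hypothetical stand-in wrapping the loss calculation above, and betas=(0.0, 0.999) is how I approximate "Adam without momentum":
import torch

# only the projection submodule is trainable at this point
trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
optimizer = torch.optim.Adam(
    (p for _, p in trainable), lr=1.5e-3, betas=(0.0, 0.999)
)

ratio_log = []
for step, local_batch in enumerate(dataloader):
    optimizer.zero_grad(set_to_none=True)
    loss = compute_loss(model, local_batch)  # cross entropy over the caption tokens
    loss.backward()

    # snapshot the weights, apply the update, then record the update/weight ratio
    before = {n: p.detach().clone() for n, p in trainable}
    optimizer.step()
    with torch.no_grad():
        ratios = {
            n: ((p - before[n]).norm() / (before[n].norm() + 1e-8)).item()
            for n, p in trainable
        }
    ratio_log.append(ratios)  # later plotted per parameter over the steps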
Keeping the setup as simple as possible, I found that with Adam (without momentum), a constant learning rate of 1.5e-3, and the following simple TrainableSubmodule, I achieved stable training.
class TrainableSubmodule(nn.Module):
    def __init__(self, output_embedding_size=4096):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool1d(250)
        self.proj = nn.Linear(1280, output_embedding_size, bias=False)
        self.ln1 = nn.LayerNorm(1280)
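Only the constructor is shown above; for completeness, a plausible forward pass for this submodule might look like the following (the exact ordering of layer norm, pooling, and projection is my assumption):
    def forward(self, x):
        # x: (batch, audio_seq_len, 1280) from the Whisper encoder
        x = self.ln1(x)                   # normalize the encoder features
        x = self.pool(x.transpose(1, 2))  # pool the sequence dimension down to 250 steps
        x = x.transpose(1, 2)
        return self.proj(x)               # project 1280 -> 4096 (Mistral embedding size)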
I ran training over the course of ~4 days, and by the time I stopped, the loss was still going down. At that point I had reached a loss of ~0.46, which translates to roughly a 63% (e^-0.46) probability for the correct token!
Rerunning sample_with_audio with the same audio file that produced gibberish pre-training, I now obtain
"<s> <|im_start|> system\n You are a helpful AI who follows instruction carefully<|im_end|> <|im_start|> user\n Describe the sound of the given file \n <|im_end|> <|im_start|> assistant\n The electronica song contains a crisp acoustic kick, snap snare and hat along with a deep bass. The male vocal is rapping syncopated along with a male background vocal. The song is fast paced and there is a limited frequency range of the synths. The song"
Compare this against the ground truth
"This is a K-pop music piece performed by a boy band. Initially, a male vocalist is singing in a rap-like manner. Then, it switches to another male vocal that is singing more melodically. The melody is being played by a crisp synth sound. The rhythmic background consists of an active electronic drum beat. There is a danceable feel to it. This piece could be playing at Korean nightclubs and dance clubs."
The result is quite good!
It's worth repeating that this was achieved by training only the audio encoder projection, without modifying the LLM weights or the Whisper AudioEncoder weights.
Next Steps
With the fundamentals in place, I'm planning to scale up training by incorporating more audio tasks such as transcription, speaker identification, etc., as well as applying finetuning to the LLM, to work my way toward replicating the "emergent" behaviors described in the referenced papers.
Assuming sufficient data and a proper training regime, the LLM should be able to perform novel audio tasks, such as identifying a speaker's age or gender, without having been explicitly trained on such tasks.
More work to be done!
Acknowledgement
I would not have been able to do any of this without learning from the excellent lectures by Karpathy.