Finetuning LLMs on a Single GPU Using Gradient Accumulation
Key takeaway
Learn how to use gradient accumulation to train models with large batch sizes in order to work around hardware limitations when GPU memory is a concern.
Previously, I shared an article using multi-GPU training strategies to speed up the finetuning of large language models. Several of these strategies include mechanisms such as model or tensor sharding, which distribute the model weights and computations across different devices to work around GPU memory limitations.
However, many of us don't have access to multi-GPU resources. This article therefore demonstrates a great workaround to train models with larger batch sizes when GPU memory is a concern: gradient accumulation.
Let’s Finetune BLOOM for Classification
Let's suppose we are interested in adopting a recent pretrained large language model for a downstream task such as text classification. We are going to work with BLOOM, an open-source alternative to GPT-3. In particular, we are going to use a version of BLOOM that "only" has 560 million parameters, which should fit into the RAM of typical GPUs without problems (for reference, the free tier of Google Colab provides a GPU with 15 GB of RAM).
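As a quick sanity check (a hypothetical snippet, not part of the training script below), we can count the model's parameters and estimate how much memory the weights alone need in 16-bit precision:

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "bigscience/bloom-560m", num_labels=2)
num_params = sum(p.numel() for p in model.parameters())
# Weights only; training additionally needs memory for gradients, optimizer
# states, and activations, which is what causes trouble below.
print(f"{num_params/1e6:.0f}M parameters, "
      f"~{num_params*2/1e9:.2f} GB of weights in 16-bit precision")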
Once we start, however, we run into problems: our memory explodes during training or finetuning, and we find that the only way to train this model is with a batch size of 1.
# pip install torch lightning matplotlib pandas torchmetrics watermark transformers datasets -U
import os
import os.path as op
import time
from datasets import load_dataset
from lightning import Fabric
import torch
from torch.utils.data import DataLoader
import torchmetrics
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from watermark import watermark
from local_dataset_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset
from local_dataset_utilities import IMDBDataset
def tokenize_text(batch):
    return tokenizer(batch["text"], truncation=True, padding=True, max_length=1024)
def train(num_epochs, model, optimizer, train_loader, val_loader, fabric):

    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=2).to(fabric.device)

        for batch_idx, batch in enumerate(train_loader):
            model.train()

            ### FORWARD AND BACK PROP
            outputs = model(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["label"]
            )

            fabric.backward(outputs["loss"])

            ### UPDATE MODEL PARAMETERS
            optimizer.step()
            optimizer.zero_grad()

            ### LOGGING
            if not batch_idx % 300:
                print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} "
                      f"| Batch {batch_idx:04d}/{len(train_loader):04d} "
                      f"| Loss: {outputs['loss']:.4f}")

            model.eval()
            with torch.no_grad():
                predicted_labels = torch.argmax(outputs["logits"], 1)
                train_acc.update(predicted_labels, batch["label"])

        ### MORE LOGGING
        model.eval()
        with torch.no_grad():
            val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(fabric.device)
            for batch in val_loader:
                outputs = model(
                    batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["label"]
                )
                predicted_labels = torch.argmax(outputs["logits"], 1)
                val_acc.update(predicted_labels, batch["label"])

            print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} "
                  f"| Train acc.: {train_acc.compute()*100:.2f}% "
                  f"| Val acc.: {val_acc.compute()*100:.2f}%"
                  )
            train_acc.reset(), val_acc.reset()
if __name__ == "__main__":

    print(watermark(packages="torch,lightning,transformers", python=True))
    print("Torch CUDA available?", torch.cuda.is_available())
    device = "cuda" if torch.cuda.is_available() else "cpu"

    torch.manual_seed(123)
    # torch.use_deterministic_algorithms(True)

    ##########################
    ### 1 Loading the Dataset
    ##########################
    download_dataset()
    df = load_dataset_into_to_dataframe()
    if not (op.exists("train.csv") and op.exists("val.csv") and op.exists("test.csv")):
        partition_dataset(df)

    imdb_dataset = load_dataset(
        "csv",
        data_files={
            "train": "train.csv",
            "validation": "val.csv",
            "test": "test.csv",
        },
    )
    #########################################
    ### 2 Tokenization and Numericalization
    #########################################

    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m", max_length=1024)
    print("Tokenizer input max length:", tokenizer.model_max_length, flush=True)
    print("Tokenizer vocabulary size:", tokenizer.vocab_size, flush=True)

    print("Tokenizing ...", flush=True)
    imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None)
    del imdb_dataset
    imdb_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    #########################################
    ### 3 Set Up DataLoaders
    #########################################

    train_dataset = IMDBDataset(imdb_tokenized, partition_key="train")
    val_dataset = IMDBDataset(imdb_tokenized, partition_key="validation")
    test_dataset = IMDBDataset(imdb_tokenized, partition_key="test")
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=4,
        drop_last=True,
    )

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=1,
        num_workers=4,
        drop_last=True,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=1,
        num_workers=2,
        drop_last=True,
    )
    #########################################
    ### 4 Initializing the Model
    #########################################

    fabric = Fabric(accelerator="cuda", devices=1, precision="16-mixed")
    fabric.launch()

    model = AutoModelForSequenceClassification.from_pretrained(
        "bigscience/bloom-560m", num_labels=2)

    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

    model, optimizer = fabric.setup(model, optimizer)
    train_loader, val_loader, test_loader = fabric.setup_dataloaders(
        train_loader, val_loader, test_loader)
    #########################################
    ### 5 Finetuning
    #########################################

    start = time.time()
    train(
        num_epochs=1,
        model=model,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        fabric=fabric,
    )

    end = time.time()
    elapsed = end - start
    print(f"Time elapsed {elapsed/60:.2f} min")
    with torch.no_grad():
        model.eval()
        test_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=2).to(fabric.device)
        for batch in test_loader:
            outputs = model(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["label"]
            )
            predicted_labels = torch.argmax(outputs["logits"], 1)
            test_acc.update(predicted_labels, batch["label"])
        print(f"Test accuracy {test_acc.compute()*100:.2f}%")
I'm using Lightning Fabric because it lets me flexibly change the number of GPUs and the multi-GPU training strategy when running this code on different hardware. It also lets me enable mixed-precision training by adjusting only the precision flag. In this case, mixed-precision training can triple the training speed and reduce memory requirements by roughly 25%.
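For illustration, switching hardware setups only means changing the Fabric(...) constructor arguments; the multi-GPU and full-precision lines below are hypothetical alternatives, not settings used in this post:

from lightning import Fabric

# Single GPU with mixed precision (what the script above uses):
fabric = Fabric(accelerator="cuda", devices=1, precision="16-mixed")

# Hypothetical alternatives, each a one-line change:
# fabric = Fabric(accelerator="cuda", devices=4, strategy="ddp")       # 4 GPUs, distributed data parallel
# fabric = Fabric(accelerator="cuda", devices=1, precision="32-true")  # full 32-bit precision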
The main code shown above is executed in the if __name__ == "__main__" context, which is recommended when running Python scripts for multi-GPU training with PyTorch. Although we are only using a single GPU here, it's a best practice worth adopting. The following three code sections inside the if __name__ == "__main__" block then take care of the data loading:
# 1 Loading the Dataset
# 2 Tokenization and Numericalization
# 3 Setting Up DataLoaders
In section # 4 Initializing the Model, we initialize the model. Then, in section # 5 Finetuning, we call the train function, which is where things get interesting. In the train(...) function, we implement our standard PyTorch loop; the annotated core training loop is the train() function shown above.
The problem with a batch size of 1 is that the gradient updates will be extremely noisy, as we can see based on the fluctuating training loss and the poor test set performance below when we train the model:
...
torch : 2.0.0
lightning : 2.0.0
transformers: 4.27.2
Torch CUDA available? True
...
Epoch: 0001/0001 | Batch 23700/35000 | Loss: 0.0969
Epoch: 0001/0001 | Batch 24000/35000 | Loss: 1.9902
Epoch: 0001/0001 | Batch 24300/35000 | Loss: 0.0395
Epoch: 0001/0001 | Batch 24600/35000 | Loss: 0.2546
Epoch: 0001/0001 | Batch 24900/35000 | Loss: 0.1128
Epoch: 0001/0001 | Batch 25200/35000 | Loss: 0.2661
Epoch: 0001/0001 | Batch 25500/35000 | Loss: 0.0044
Epoch: 0001/0001 | Batch 25800/35000 | Loss: 0.0067
Epoch: 0001/0001 | Batch 26100/35000 | Loss: 0.0468
Epoch: 0001/0001 | Batch 26400/35000 | Loss: 1.7139
Epoch: 0001/0001 | Batch 26700/35000 | Loss: 0.9570
Epoch: 0001/0001 | Batch 27000/35000 | Loss: 0.1857
Epoch: 0001/0001 | Batch 27300/35000 | Loss: 0.0090
Epoch: 0001/0001 | Batch 27600/35000 | Loss: 0.9790
Epoch: 0001/0001 | Batch 27900/35000 | Loss: 0.0503
Epoch: 0001/0001 | Batch 28200/35000 | Loss: 0.2625
Epoch: 0001/0001 | Batch 28500/35000 | Loss: 0.1010
Epoch: 0001/0001 | Batch 28800/35000 | Loss: 0.0035
Epoch: 0001/0001 | Batch 29100/35000 | Loss: 0.0009
Epoch: 0001/0001 | Batch 29400/35000 | Loss: 0.0234
Epoch: 0001/0001 | Batch 29700/35000 | Loss: 0.8394
Epoch: 0001/0001 | Batch 30000/35000 | Loss: 0.9497
Epoch: 0001/0001 | Batch 30300/35000 | Loss: 0.1437
Epoch: 0001/0001 | Batch 30600/35000 | Loss: 0.1317
Epoch: 0001/0001 | Batch 30900/35000 | Loss: 0.0112
Epoch: 0001/0001 | Batch 31200/35000 | Loss: 0.0073
Epoch: 0001/0001 | Batch 31500/35000 | Loss: 0.7393
Epoch: 0001/0001 | Batch 31800/35000 | Loss: 0.0512
Epoch: 0001/0001 | Batch 32100/35000 | Loss: 0.1337
Epoch: 0001/0001 | Batch 32400/35000 | Loss: 1.1875
Epoch: 0001/0001 | Batch 32700/35000 | Loss: 0.2727
Epoch: 0001/0001 | Batch 33000/35000 | Loss: 0.1545
Epoch: 0001/0001 | Batch 33300/35000 | Loss: 0.0022
Epoch: 0001/0001 | Batch 33600/35000 | Loss: 0.2681
Epoch: 0001/0001 | Batch 33900/35000 | Loss: 0.2467
Epoch: 0001/0001 | Batch 34200/35000 | Loss: 0.0620
Epoch: 0001/0001 | Batch 34500/35000 | Loss: 2.5039
Epoch: 0001/0001 | Batch 34800/35000 | Loss: 0.0131
Epoch: 0001/0001 | Train acc.: 75.11% | Val acc.: 78.62%
Time elapsed 69.97 min
Test accuracy 78.53%
Since we don't have multiple GPUs available for tensor sharding, what can we do to train the model with larger batch sizes?
One workaround is gradient accumulation, where we modify the aforementioned training loop.
What’s gradient accumulation?
Gradient accumulation is a way to virtually increase the batch size during training, which is very useful when the available GPU memory is insufficient to accommodate the desired batch size. In gradient accumulation, gradients are computed for smaller batches and accumulated (usually summed or averaged) over multiple iterations instead of updating the model weights after every batch. Once the accumulated gradients reach the target "virtual" batch size, the model weights are updated with the accumulated gradients.
To illustrate this, consider the updated PyTorch training loop below. (The full script is available here on GitHub.)
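Since the updated loop is not reproduced in full here, the following is a minimal sketch of the change, assuming the same train() skeleton and helpers as above; the new accumulation_steps argument and the condition guarding the optimizer step are the only differences, and the logging and validation code stays as before.

def train(num_epochs, model, optimizer, train_loader, val_loader, fabric, accumulation_steps):

    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=2).to(fabric.device)

        for batch_idx, batch in enumerate(train_loader):
            model.train()

            ### FORWARD AND BACK PROP
            outputs = model(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["label"]
            )

            # Gradients from successive batches accumulate (sum up) in the
            # parameters' .grad buffers until the optimizer step below runs.
            fabric.backward(outputs["loss"])

            ### UPDATE MODEL PARAMETERS
            if not batch_idx % accumulation_steps:  # NEW: update only every accumulation_steps batches
                optimizer.step()
                optimizer.zero_grad()

            # ... logging, training accuracy, and the validation loop are
            # unchanged from the train() function shown earlier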
If we set accumulation_steps to 2, then zero_grad() and optimizer.step() will only be called every second batch. Consequently, running the modified training loop with accumulation_steps=2 will have the same effect as doubling the batch size.
For example, if we want to use a batch size of 256 but can only fit a batch size of 64 into GPU memory, we can perform gradient accumulation over four batches of size 64. (After processing all four batches, we will have accumulated gradients equivalent to a single batch of size 256.) This allows us to effectively emulate a larger batch size without requiring larger GPU memory or tensor sharding across different devices.
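Expressed as simple arithmetic (illustrative numbers only, not the values used in this post's experiments):

micro_batch_size = 64        # largest batch that fits into GPU memory
accumulation_steps = 4       # number of micro-batches to accumulate
effective_batch_size = micro_batch_size * accumulation_steps  # 256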
While gradient accumulation can help us train models with larger batch sizes, it does not reduce the total computation required. In fact, it can sometimes make training converge slightly more slowly, since the weights are updated less frequently. Nevertheless, it lets us work around situations where very small batch sizes lead to noisy updates.
For example, let's now run the code from above, where we have a batch size of 1, with 16 accumulation steps to simulate a batch size of 16. You can download the code here.
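Assuming the modified train() signature sketched above, the only change in the # 5 Finetuning section is the call itself:

    train(
        num_epochs=1,
        model=model,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        fabric=fabric,
        accumulation_steps=16,  # simulate a batch size of 16 with micro-batches of 1
    )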
The output is as follows:
...
torch : 2.0.0
lightning : 2.0.0
transformers: 4.27.2
Torch CUDA available? True
...
Epoch: 0001/0001 | Batch 23700/35000 | Loss: 0.0168
Epoch: 0001/0001 | Batch 24000/35000 | Loss: 0.0006
Epoch: 0001/0001 | Batch 24300/35000 | Loss: 0.0152
Epoch: 0001/0001 | Batch 24600/35000 | Loss: 0.0003
Epoch: 0001/0001 | Batch 24900/35000 | Loss: 0.0623
Epoch: 0001/0001 | Batch 25200/35000 | Loss: 0.0010
Epoch: 0001/0001 | Batch 25500/35000 | Loss: 0.0001
Epoch: 0001/0001 | Batch 25800/35000 | Loss: 0.0047
Epoch: 0001/0001 | Batch 26100/35000 | Loss: 0.0004
Epoch: 0001/0001 | Batch 26400/35000 | Loss: 0.1016
Epoch: 0001/0001 | Batch 26700/35000 | Loss: 0.0021
Epoch: 0001/0001 | Batch 27000/35000 | Loss: 0.0015
Epoch: 0001/0001 | Batch 27300/35000 | Loss: 0.0008
Epoch: 0001/0001 | Batch 27600/35000 | Loss: 0.0060
Epoch: 0001/0001 | Batch 27900/35000 | Loss: 0.0001
Epoch: 0001/0001 | Batch 28200/35000 | Loss: 0.0426
Epoch: 0001/0001 | Batch 28500/35000 | Loss: 0.0012
Epoch: 0001/0001 | Batch 28800/35000 | Loss: 0.0025
Epoch: 0001/0001 | Batch 29100/35000 | Loss: 0.0025
Epoch: 0001/0001 | Batch 29400/35000 | Loss: 0.0000
Epoch: 0001/0001 | Batch 29700/35000 | Loss: 0.0495
Epoch: 0001/0001 | Batch 30000/35000 | Loss: 0.0164
Epoch: 0001/0001 | Batch 30300/35000 | Loss: 0.0067
Epoch: 0001/0001 | Batch 30600/35000 | Loss: 0.0037
Epoch: 0001/0001 | Batch 30900/35000 | Loss: 0.0005
Epoch: 0001/0001 | Batch 31200/35000 | Loss: 0.0013
Epoch: 0001/0001 | Batch 31500/35000 | Loss: 0.0112
Epoch: 0001/0001 | Batch 31800/35000 | Loss: 0.0053
Epoch: 0001/0001 | Batch 32100/35000 | Loss: 0.0012
Epoch: 0001/0001 | Batch 32400/35000 | Loss: 0.1365
Epoch: 0001/0001 | Batch 32700/35000 | Loss: 0.0210
Epoch: 0001/0001 | Batch 33000/35000 | Loss: 0.0374
Epoch: 0001/0001 | Batch 33300/35000 | Loss: 0.0007
Epoch: 0001/0001 | Batch 33600/35000 | Loss: 0.0341
Epoch: 0001/0001 | Batch 33900/35000 | Loss: 0.0259
Epoch: 0001/0001 | Batch 34200/35000 | Loss: 0.0005
Epoch: 0001/0001 | Batch 34500/35000 | Loss: 0.4792
Epoch: 0001/0001 | Batch 34800/35000 | Loss: 0.0003
Epoch: 0001/0001 | Train acc.: 78.67% | Val acc.: 87.28%
Time elapsed 51.37 min
Test accuracy 87.37%
As we can see from the results above, the loss fluctuates less than before. In addition, the test set performance increased by 10%! We are only iterating through the training set once, so each training example is encountered only a single time. Training the model for multiple epochs can further improve the predictive performance, but I'll leave this as an exercise for you to try out (and let me know how it goes on Discord!).
You may also have noticed that this code executed faster than the code we used previously with a batch size of 1. With gradient accumulation and a virtual batch size of 16, we still perform the same number of forward and backward passes. However, since we only update the model weights every 16th batch, we perform far fewer optimizer steps, which lets us iterate through the examples in one epoch faster.
Conclusion
Gradient accumulation is a technique that simulates a larger batch size by accumulating gradients from multiple small batches before performing a weight update. This technique can be helpful in scenarios where the available memory is limited and the batch size that fits into memory is small.
However, consider a scenario in which you can run the desired batch size in the first place, meaning the available memory is large enough to accommodate it. In that case, gradient accumulation may not be necessary. In fact, running the larger batch size directly can be more efficient because it allows for more parallelism and reduces the number of weight updates required to train the model.
In summary, gradient accumulation can be a useful technique for reducing the impact that the noise of small batch sizes has on the accuracy of the gradient updates. It is a simple yet effective technique that lets us work around hardware limitations.
For reference, all code accompanying this blog post is available here on GitHub.