# Grokking

*by* Phil Tadros

All of the code can be found here.

The “Grokking” paper is one of the most head-scratching papers to come out of the neural network field. It explores the phenomenon of a regime change, where the model appears, by all indications, to have overfit the data, and where the overfitting only worsens as training progresses. Validation loss is rising, and validation accuracy is at a standstill, while 100% training accuracy was reached ages ago. But then, suddenly, as if a divine entity itself sprinkled fairy dust on the neural network, validation loss starts to decrease and validation accuracy increases. After a while, both accuracies are at 100%. In essence, the neural network transitions from a regime of memorization to generalization.

### Experiment Setup

The datasets are of the shape:

$$a \circ b = c$$

where $\circ$ can be an arbitrary binary operation with a constant modulus base. For the experiments this will be 97. For some operations this means the table has $97^2 = 9409$ equations.
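As a concrete example, here is a minimal sketch in plain Julia (no Flux needed) of one such table, using modular addition as the binary operation:

```julia
modn = 97
f(a, b) = mod(a + b, modn)                 # a ∘ b = (a + b) mod 97

# every (a, b) pair yields one equation, giving the full 97 × 97 table
table = [f(a, b) for a in 1:modn, b in 1:modn]
length(table)                              # 97^2 = 9409 equations
```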

$a$, $b$, and $c$ are placeholders, but symbols are used to represent each number.

```
julia> alphabet = string.(vcat('a':'z', 'A':'Z', Char.(1024:2048)))
1077-element Vector{String}:
 "a"
 "b"
 "c"
 "d"
 ⋮
 "߽"
 "߾"
 "߿"
 "ࠀ"

julia> modn = 97
97

julia> nums = collect(-modn:modn)
195-element Vector{Int64}:
 -97
 -96
 -95
 -94
   ⋮
  94
  95
  96
  97

julia> num2tok = Dict{Int,String}()
Dict{Int64, String}()

julia> for (i, n) in enumerate(nums)
           num2tok[n] = alphabet[i]
       end

julia> num2tok
Dict{Int64, String} with 195 entries:
  56  => "ѥ"
  35  => "ѐ"
  60  => "ѩ"
  -5  => "Ш"
  67  => "Ѱ"
  73  => "Ѷ"
  -66 => "F"
  -71 => "A"
  ⋮   => ⋮

julia> tok2num = Dict(values(num2tok) .=> keys(num2tok))
Dict{String, Int64} with 195 entries:
  "Z" => -46
  "ф" => 23
  "Л" => -18
  "C" => -69
  "т" => 21
  "r" => -80
  "л" => 14
  "ѱ" => 68
  ⋮   => ⋮
```

And so the neural network will never see an actual number, just the symbol that represents it.

The complete code for generating a dataset:

```
using Flux, Random

function create_dataset_binop_with_mod(f::Function, modn::Int)
    nums = collect(-modn:modn)
    num2tok = Dict{Int,String}()
    for (i, n) in enumerate(nums)
        num2tok[n] = alphabet[i]
    end
    tok2num = Dict(values(num2tok) .=> keys(num2tok))
    toks = collect(values(num2tok))
    push!(toks, "=")
    push!(toks, "∘")
    tok2idx = Dict(c => i for (i, c) in enumerate(toks))
    idx2tok = Dict(i => c for (i, c) in enumerate(toks))
    data = Vector{Int}[]
    for a in 1:modn
        for b in 1:modn
            c = f(a, b)
            s = "$(num2tok[a])∘$(num2tok[b])=$(num2tok[c])"
            enc = [tok2idx[string(c)] for c in s]
            push!(data, enc)
        end
    end
    Random.shuffle!(data)
    X = zeros(Int, (length(data[1]) - 1, length(data)))
    y = zeros(Int, length(data))
    for (i, enc) in enumerate(data)
        X[:, i] = enc[1:end-1]
        y[i] = enc[end]
    end
    return (X, Flux.onehotbatch(y, 1:length(tok2idx))), tok2idx, idx2tok
end
```

This is the most important part:

```
c = f(a, b)
s = "$(num2tok[a])∘$(num2tok[b])=$(num2tok[c])"
```

The network only sees “SymbolA∘SymbolB=SymbolC”.
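To make that concrete, here is a sketch (plain Julia, reusing the alphabet and modular addition from above) of the exact string built for a single equation:

```julia
alphabet = string.(vcat('a':'z', 'A':'Z', Char.(1024:2048)))
num2tok = Dict(n => alphabet[i] for (i, n) in enumerate(-97:97))

a, b = 3, 5
c = mod(a + b, 97)                              # modular addition as the binary op
s = "$(num2tok[a])∘$(num2tok[b])=$(num2tok[c])"
collect(s)                                      # 5 tokens: SymbolA, ∘, SymbolB, =, SymbolC
```

The numbers 3, 5, and 8 never appear in `s`, only their symbols and the two operator tokens.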

The network itself is a two-layer transformer with 128 dimensions split over 4 heads.

```
vocabsize = size(trainY, 1)
blocksize = size(trainX, 1)
dembed = 128
nheads = 4
nlayers = 2
circ = Circuit(vocabsize, blocksize, dembed; nheads, nlayers) |> gpu;
opt = Flux.setup(AdamW(3e-4), circ);
```

The vocabulary size is the 195 symbols plus 2 extra for $\circ$ and $=$.

The block size is 4 (all tokens before SymbolC).
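A quick sanity check on those two numbers, assuming the 195 number symbols and two operator tokens from the dataset code:

```julia
modn = 97
nsymbols = length(-modn:modn)   # 195 number symbols (-97 through 97)
vocabsize = nsymbols + 2        # plus "∘" and "="

seqlen = 5                      # SymbolA, ∘, SymbolB, =, SymbolC
blocksize = seqlen - 1          # everything before SymbolC is model input
```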

### Data Split

Memorization -> generalization is established by the dataset split. For these experiments it is a 50/50 split, which means the neural network will have never seen half of the data. It cannot, by definition, memorize it. The only way for it to correctly label the examples in the validation set is to figure out the underlying function being performed. To generalize.

```
X, Y = data;
trainfrac = 0.5;
N = size(X, 2);
n = Int(round(N * trainfrac));
trainX, trainY = X[:, 1:n], Y[:, 1:n];
valX, valY = X[:, n+1:N], Y[:, n+1:N];
trainX = trainX |> gpu;
trainY = trainY |> gpu;
valX = valX |> gpu;
valY = valY |> gpu;
train_batchsize = min(512, size(trainX, 2))
val_batchsize = min(512, size(valX, 2))
traindata = Flux.DataLoader((trainX, trainY), batchsize = train_batchsize, shuffle = true);
valdata = Flux.DataLoader((valX, valY), batchsize = val_batchsize);
```

In the original paper they run experiments over a variety of functions and split sizes. I've picked four functions from the paper that I thought would be worthwhile reproducing.

Each of these had a 50/50 split except the last one, which I also tried with a 95/5 split, as they did in the paper. It didn't generalize in the paper and it didn't generalize for me 🙁

The results …

### Run It

```
data, _, _ = create_dataset_binop_with_mod((a, b) -> (a + b) % modn, modn)
...
run = Run()
evalevery = 10
train_model!(
    circ,
    opt,
    traindata;
    nepochs = 10_000,
    evaliters = 10,
    evalevery = evalevery,
    valdata = valdata,
    seq2val = true,
    early_stop = () -> begin
        accuracy_metric(circ, valdata; seq2val = true) >= 0.99
    end,
    run = run,
)
```

Nothing out of the ordinary here. `evalevery` defines the interval at which the loss and accuracies for the training and validation data should be captured: every 10 epochs.

`seq2val` means we only care about the loss for the last token, rather than `seq2seq`, which would be the loss for token predictions over the entire sequence.

```
f = if seq2val
    (m, x, y) -> Flux.Losses.crossentropy(softmax(m(x), dims = 1)[:, end, :], y)
else
    (m, x, y) -> Flux.Losses.crossentropy(softmax(m(x), dims = 1), y)
end
```

In Julia, arrays are column-major, so the sequence and batch dimensions are the final two: `(..., sequence_dim, batch_dim)`. This is the opposite of PyTorch.
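For example, slicing the final position out of a hypothetical `(vocab, sequence, batch)` logits array, as the `seq2val` loss above does:

```julia
vocab, seqlen, batch = 197, 4, 32
logits = rand(Float32, vocab, seqlen, batch)   # (vocab, sequence, batch)

last = logits[:, end, :]   # logits for the final position only
size(last)                 # (197, 32): one distribution per batch element
```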

If you read my notes in `grokking.jl` you'll see I initially had seq2seq loss, mainly because I had been working with seq2seq data before, but also because I was too lazy to change the loss metric. It does hurt the model in this case, because it is not useful for it to predict any of the tokens other than the final one.

Given the first two tokens you will be unable to predict the third token. It could be any of the number symbols (excluding $=$ and $\circ$) with equal probability!
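A sketch of why the middle number token is hopeless to predict: fixing the first operand, the second operand takes every value exactly once across the dataset, so its conditional distribution is uniform.

```julia
modn = 97
pairs = [(a, b) for a in 1:modn for b in 1:modn]   # every equation's operands

# second operands seen after a fixed first operand
bs = [b for (a, b) in pairs if a == 1]

# each of the 97 values appears exactly once → uniform; the best any
# model can do on this token is guess
(length(bs), allunique(bs))
```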

Anyway, back to modular addition.

Plots, plots, plots.

### Addition

Total optimization steps are defined as epochs * num_batches * evalevery. So addition generalizes quite fast, but the regime change from memorization to generalization is clear.
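Under the 50/50 split and batch size 512 above, a sketch of that bookkeeping (my reading of the formula: metrics are recorded every `evalevery` epochs, so the $i$-th recorded point sits at `i * evalevery * num_batches` optimizer steps):

```julia
N = 97^2                  # 9409 equations in the full table
n = Int(round(N * 0.5))   # training examples under the 50/50 split
num_batches = cld(n, 512) # optimizer steps per epoch at batch size 512
evalevery = 10            # metrics recorded every 10 epochs

steps(i) = i * evalevery * num_batches   # steps at the i-th recorded point
```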

### Subtraction

Generalization takes much longer than for addition.

### Asymmetric Function

I stopped this one early, at around 95% validation accuracy, simply because it was taking so long. Validation loss and accuracy kept improving the entire duration after the regime change.

### Hard Function

Even with a 95/5 split we get nowhere.

### Extra: Subtraction then Finetuned Asymmetric

I took the subtraction model once it achieved high generalization and then tried to finetune it on the asymmetric dataset. It didn't work.

### Value Counts

I'm not sure how useful this is, but these visualizations show a correlation between symmetry and the ability to generalize.

Even the hard function does have symmetry. Hmmmmm.

### Next Steps

These neural networks are fairly small, so dissecting them could be worthwhile.