Act I: Data collection and pretraining loop

Vocabulary

  • Parquet files
  • DDP (DistributedDataParallel)
  • Causal language modeling (as opposed to masked language modeling)

Basics / Parquet

First: look at dataloader.py.

from collections import deque

import torch
import pyarrow.parquet as pq

from nanochat.common import get_dist_info
from nanochat.dataset import list_parquet_files
from nanochat.tokenizer import get_tokenizer

Understand these dependencies.

It may be helpful to go through the common, dataset, and tokenizer files.
Dataset — we are shown how we download a dataset for language modelling into Parquet file shards.

Parquet is really interesting and useful, here is some helpful info:
The hierarchy of a Parquet file: file → row groups → column chunks → pages.

Row groups are horizontal slices of the data (chunks of the pretraining corpus), and each group can be read independently, which is why they parallelize nicely across GPUs. A column chunk holds one column's data within a row group, along with metadata such as min/max statistics about the values it contains; pages are the smallest unit inside a column chunk.
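The "parallelized across GPUs" bit is just striding. A toy sketch (the helper name is mine, not from the repo) of how ranks could divvy up row groups:

```python
def shard_row_groups(num_row_groups, rank, world_size):
    # rank r reads row groups r, r + world_size, r + 2*world_size, ...
    return list(range(rank, num_row_groups, world_size))

# with 10 row groups and 4 GPUs:
shard_row_groups(10, 0, 4)  # rank 0 reads [0, 4, 8]
shard_row_groups(10, 1, 4)  # rank 1 reads [1, 5, 9]
```

Every row group is read by exactly one rank, with no coordination needed.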

These are other statistics, courtesy of ChatGPT:

  • min / max
    Lexicographic for strings, numeric for numbers.
    Used for predicate pushdown (“can this block possibly match?”).
  • null_count
    How many values are null in this page / row group.
  • distinct_count (optional, newer)
    Approximate or exact count of distinct values.
  • data_page_offset / sizes
    Byte offsets and compressed/uncompressed sizes (used for skipping, not filtering).
  • encoding metadata
    RLE / dictionary / delta encodings used.
  • dictionary presence
    Whether a dictionary page exists and its size.
import argparse
from multiprocessing import Pool

# MAX_SHARD, DATA_DIR, and download_single_file are defined earlier in the dataset module

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="download fineweb-edu 100bt dataset shards")
    parser.add_argument("-n", "--num-files", type=int, default=-1, help="number of shards to download")
    parser.add_argument("-w", "--num-workers", type=int, default=4, help="number of parallel workers")
    args = parser.parse_args()

    num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
    ids_to_download = list(range(num))
    print(f"downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
    print(f"target directory: {DATA_DIR}")
    print()
    with Pool(processes=args.num_workers) as pool:
        results = pool.map(download_single_file, ids_to_download)

    # report results
    successful = sum(1 for success in results if success)
    print(f"done! downloaded: {successful}/{len(ids_to_download)} shards to {DATA_DIR}")

Pool comes from the standard library's multiprocessing module and spawns separate worker processes (not threads).
results is just the list of per-shard success flags returned by pool.map.

Tokenizer

HuggingFaceTokenizer:

  1. Contractions: 's 't 'm 'd 'll 've 're
  2. Words: optional punct + letters
  3. Numbers: 1–2 digits
  4. Punctuation: symbols (with optional space/newlines)
  5. Newlines: whitespace + newline
  6. Trailing spaces: whitespace at end of text
  7. Other whitespace: everything else

The goal of the tokenizer is to tokenize. Woohoo.

You do have to train a tokenizer, and you train it on a lot of text. Training a tokenizer means greedily merging candidate token pairs until you reach a target vocab size.

You start from the byte tokens of the corpus and count every adjacent token pair in a frequency hashmap.

  • Collect billions of characters of text
  • Split into chunks with regex pattern
  • Start with 256 byte tokens
  • Repeatedly find most common adjacent pair and merge them into a new token
  • Repeat ~65,000 times to get final vocab
  • Result: learned vocabulary that efficiently compresses text
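The loop above can be sketched in a few lines of Python. This is a toy byte-level version (no regex splitting, no Rust), just to make the merge loop concrete:

```python
from collections import Counter

def train_bpe(text, num_merges):
    # start from raw bytes; token ids 0..255 are the byte vocabulary
    ids = list(text.encode("utf-8"))
    merges = {}  # (left, right) -> new token id
    next_id = 256
    for _ in range(num_merges):
        pairs = Counter(zip(ids, ids[1:]))  # frequency of every adjacent pair
        if not pairs:
            break
        pair = max(pairs, key=pairs.get)  # most common adjacent pair
        merges[pair] = next_id
        # replace every occurrence of the pair with the new token
        out, i = [], 0
        while i < len(ids):
            if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
                out.append(next_id)
                i += 2
            else:
                out.append(ids[i])
                i += 1
        ids = out
        next_id += 1
    return merges, ids

merges, ids = train_bpe("the cat sat on the mat, the cat sat", 3)
```

Each merge shrinks the token stream, which is exactly the "efficiently compresses text" part.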

There are two options in the repo: HuggingFaceTokenizer and RustBPETokenizer:

  1. HuggingFace tokenizer that can do both training and inference but is really confusing
  2. Our own RustBPE tokenizer for training and tiktoken for efficient inference

A tricky section in tokenizer.py (around line 106) requires some context: prepending and appending special tokens. You can't really "learn" special tokens. You have to hard-code the scripts to insert them; the model will not "discover" or invent special tokens on its own.

def _encode_one(self, text, prepend=None, append=None):
    # encode a single string
    # prepend/append can be either the string of a special token or a token id directly.
    assert isinstance(text, str)
    ids = []
    if prepend is not None:
        prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
        ids.append(prepend_id)
    ids.extend(self.tokenizer.encode(text, add_special_tokens=False).ids)
    if append is not None:
        append_id = append if isinstance(append, int) else self.encode_special(append)
        ids.append(append_id)
    return ids

tokenizer.py is basically an interface that defines the operations the rest of the codebase will repeatedly call on the tokenizer.

A natural question: why do we have this “render conversation” function?
Firstly, we use the language of “rendering” because we’re injecting this chat structure into our tokenization. Consider if you’ve ever looked at the chat template structure of LLaMA… <|end_of_assistant_turn|> etc. Think of it like how you render HTML into the hierarchical form which you recognize.

More on special tokens: llama 3.1 model cards and prompt formats.

This render conversation function is going to be more important once we get to the chat SFT part.

At this point, we should all have a pretty solid understanding of how the tokenization will work and literally what our data will look like. Now we can get into the fun part of pretraining.

Model architecture

Now we need to literally familiarize ourselves with what’s going on in the model. If you’ve coded LLaMA from scratch or anything like that, this will be pretty familiar.

"""
gpt model (rewrite, a lot simpler)
notable features:
- rotary embeddings (and no positional embeddings)
- qk norm
- untied weights for token embedding and lm_head
- relu^2 activation in mlp
- norm after token embedding
- no learnable params in rmsnorm
- no bias in linear layers
- group-query attention (gqa) support for more efficient inference
"""

Rotary embeddings (RoPE)

Explanation on a rotary embedding courtesy of ChatGPT:

A rotary embedding (usually “RoPE”, rotary positional embedding) is a way to add position information to a transformer by rotating the query/key vectors in attention, instead of adding a learned/sinusoidal position vector to the token embeddings.

def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4  # multihead attention
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:]  # split up last dim into two halves
    y1 = x1 * cos + x2 * sin  # rotate pairs of dims
    y2 = x1 * (-sin) + x2 * cos
    out = torch.cat([y1, y2], 3)  # re-assemble
    out = out.to(x.dtype)  # ensure input/output dtypes match
    return out

A natural question: why are we passing in cos and sin? Great question. Turns out that when you actually go to the initialization of the GPT, you precompute the rotary embeddings.
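Here's what that precompute plus the key relative-position property look like in a numpy sketch (function names are mine, not the repo's; the pairing matches apply_rotary_emb above):

```python
import numpy as np

def precompute_rotary(seq_len, head_dim, base=10000.0):
    # one frequency per pair of dims: omega_i = base^(-2i / head_dim)
    half = head_dim // 2
    freqs = base ** (-2.0 * np.arange(half) / head_dim)
    angles = np.outer(np.arange(seq_len), freqs)  # (seq_len, half): phi_{p,i} = p * omega_i
    return np.cos(angles), np.sin(angles)

def rotate(x, cos, sin):
    # same pairing as apply_rotary_emb: first half of dims with second half
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return np.concatenate([x1 * cos + x2 * sin, -x1 * sin + x2 * cos], axis=-1)
```

Rotations preserve vector norms, and the q·k dot product after rotation depends only on the offset between the two positions, not their absolute values.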

An aside: we can definitely go through the math of RoPE more rigorously (Vincent knows a lot about it). I only know the high level stuff so we can read the paper together, optionally.

A brief primer on RoPE:
The intuition is that the rotations encode relative position: they inject a linear operator into the attention dot product that depends only on the offset between the two positions.

setup
- hidden dimension: d (even)
- position index: p ∈ {0, 1, 2, ...}
- frequency base: θ (typically 10,000)

define per-dimension angular frequencies:

    ω_i = θ^{-2i / d},   i = 0, 1, ..., d/2 - 1

rotation angles

for position p:

    φ_{p,i} = p · ω_i

core rope formula (this is the key)

take a vector x ∈ ℝ^d and group it into pairs: (x_{2i}, x_{2i+1})

apply a 2d rotation to each pair:

    (x'_{2i}  )   ( cos φ_{p,i}  -sin φ_{p,i} ) (x_{2i}  )
    (x'_{2i+1}) = ( sin φ_{p,i}   cos φ_{p,i} ) (x_{2i+1})

that's rope.

We basically compute, for each position p and each dimension pair i, a specific rotation angle φ_{p,i} = p · ω_i (one distinct frequency ω_i per pair) and apply the corresponding 2D rotation matrix.

Please recall that qᵀk is the attention "logit", i.e. the raw compatibility score between the token/embedding at position p and the token/embedding at position q.

Good visualization here: https://www.youtube.com/watch?v=GQPOtyITy54.

Why do we need d/2 rotations, one for each dimensional pair? Because one rotation scale cannot encode all distances: different pairs give different spatial resolutions.

That's the core reason. Now the precise version.

1. what goes wrong with a single rotation?

suppose every pair rotated with the same frequency ω:

    r(p) = rotation by pω

then attention depends on:

    qᵀ r(q - p) k = qᵀ rot((q - p) ω) k

but rotation is periodic:

    (q - p) ω ≡ (q - p) ω + 2π

so distances Δ and Δ + 2π / ω are indistinguishable.

→ the model aliases long-range positions catastrophically.

So that is the response to "why do we need dim/2 rotations, one for each dimensional pair?": multiple frequencies let the model distinguish short- and long-range offsets without aliasing.

Karpathy also includes a tunable n_kv_head parameter to change the number of K/V heads relative to query heads. This is because in practice there are three different ways of doing your attention (courtesy of Chat):

  1. Multi-Query Attention (MQA)
    n_kv_head = 1
    All queries share one K and V.
    Maximum speed, some quality loss.
  2. Grouped Query Attention (GQA)
    1 < n_kv_head < n_head
    Queries grouped, share some K/V.
    Balanced speed/quality.
  3. Multi-Head Attention (MHA)
    n_kv_head = n_head
    Every query has its own K/V.
    Highest quality, slowest.

But basically, with fewer K/V heads the KV cache is smaller, so more of it fits in fast memory and decoding can read it much faster.
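A toy sketch of the grouping (my own helper, not repo code): query head h reads from KV head h // (n_head // n_kv_head), which covers all three regimes above.

```python
def kv_head_for(h, n_head, n_kv_head):
    # consecutive groups of (n_head // n_kv_head) query heads share one kv head
    return h // (n_head // n_kv_head)

[kv_head_for(h, 8, 2) for h in range(8)]  # GQA: [0, 0, 0, 0, 1, 1, 1, 1]
```

With n_kv_head == n_head every query head gets its own K/V (MHA); with n_kv_head == 1 they all share one (MQA).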

This syntax was a bit confusing to me: what ended up being the case is that the second argument stipulates the trailing dimension(s) the norm is computed over.

Importantly, an RMS norm doesn't zero-center or squash features into a fixed range. It just divides by the root mean square. This is in contrast to LayerNorm, which also subtracts the mean. So RMSNorm normalizes by magnitude but, in some sense, preserves the bias/mean offset.
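A quick numeric check of that difference, in plain Python (eps omitted for clarity):

```python
import math

def rms_norm(x):
    # divide by the root mean square; no centering
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    return [v / rms for v in x]

def layer_norm(x):
    # subtract the mean, then divide by the standard deviation
    mu = sum(x) / len(x)
    sd = math.sqrt(sum((v - mu) ** 2 for v in x) / len(x))
    return [(v - mu) / sd for v in x]

x = [1.0, 2.0, 3.0, 4.0]
# layer_norm(x) has mean 0; rms_norm(x) keeps the positive offset
```

Both produce unit-scale outputs, but only LayerNorm kills the mean.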

Obvious thing check: an nn.Linear layer is just a fully connected layer which outputs a projection into the output space.

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # To support meta device initialization, we init the rotary embeddings here, but it's fake
        # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
        # so let's just over-compute them, but assert fail if we ever reach that amount.
        # In the future we can dynamically grow the cache, for now it's fine.
        self.rotary_seq_len = config.sequence_len * 10  # 10X over-compute should be enough, TODO make nicer?
        head_dim = config.n_embd // config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False)  # persistent=False means it's not saved to the checkpoint
        self.register_buffer("sin", sin, persistent=False)

sequence_len is used as the maximum length which the model will see during training—at a high level, the whole corpus is turned into a continuous stream of batches (batch_size, seq_len) which are fed into the model.

We also notice in the code above that the rotary embeddings are precomputed at 10x the training sequence length. This is because we need rotary embeddings at inference time too, and it's possible to generate more tokens than the model saw in any training sequence, so we precompute extra.

def init_weights(self):
    self.apply(self._init_weights)
    # zero out classifier weights
    torch.nn.init.zeros_(self.lm_head.weight)
    # zero out c_proj weights in all blocks
    for block in self.transformer.h:
        torch.nn.init.zeros_(block.mlp.c_proj.weight)
        torch.nn.init.zeros_(block.attn.c_proj.weight)
    # init the rotary embeddings
    head_dim = self.config.n_embd // self.config.n_head
    cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
    self.cos, self.sin = cos, sin
    # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory
    if self.transformer.wte.weight.device.type == "cuda":
        self.transformer.wte.to(dtype=torch.bfloat16)

The next functions to discuss are the flop estimations, the optimizer setups, the forward pass, and the generate function.

Important aside: wte = word token embeddings; h = hidden (the stack of transformer blocks).

So when you actually use DDP, you have three parameters which are necessary to know:

  • RANK – global process ID
  • LOCAL_RANK – GPU index on the current machine
  • WORLD_SIZE – total number of processes
def estimate_flops(self):
    """ return the estimated flops per token for the model. ref: https://arxiv.org/abs/2204.02311 """
    nparams = sum(p.numel() for p in self.parameters())
    nparams_embedding = self.transformer.wte.weight.numel()
    l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.block_size
    num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
    return num_flops_per_token

This estimation follows from the arxiv paper which is linked there:
We multiply by six because the forward pass costs 2 FLOPs per parameter per token (a multiply and an add), and the backward pass costs roughly 4 FLOPs per parameter per token.

So you might ask, where do these estimations come from? For a more complete discussion, one can visit the PaLM paper linked.

A basic but important tenet: each multiply-accumulate in a matrix multiplication counts as 2 FLOPs (one multiply, one add).
The backward pass costs about twice the forward because, for each layer, you compute the gradient with respect to the weights (to update the parameters) and the gradient with respect to the input itself (to keep propagating). Recall backprop through an MLP from your intro machine learning class:

# layer 2 backward
dL_dx1 = backward_layer2(dL_dx2)  # this produces gradient for layer 1

# layer 1 backward
dL_dx0 = backward_layer1(dL_dx1)  # uses dL_dx1 from layer 2!

Another important bit of information this calculation reveals is the difference between matrix layers and embedding layers. Naturally, embedding something is not computationally expensive, because it’s just a lookup, really.

Matrix layers are what you think of generally: nn.Linear… embedding layers are nn.Embedding.
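Plugging toy numbers into the formula from estimate_flops makes the two terms concrete (the helper and its numbers are mine, just for illustration):

```python
def flops_per_token(n_params, n_params_embedding, n_layer, n_head, head_dim, seq_len):
    # 6 flops/param/token (2 forward + 4 backward) for non-embedding params,
    # plus 12 * l * h * q * t for the attention score/value matmuls
    return 6 * (n_params - n_params_embedding) + 12 * n_layer * n_head * head_dim * seq_len

flops_per_token(1000, 200, n_layer=2, n_head=4, head_dim=8, seq_len=16)
# = 6 * 800 + 12 * 2 * 4 * 8 * 16 = 4800 + 12288 = 17088
```

Note the embedding parameters are subtracted out: a lookup costs essentially nothing.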

Recall that different optimizers here like SGD, AdamW, Muon control the update rules.
We can take a brief detour to explain Muon and exactly why it’s helpful in saving us some time.

Muon is a different optimizer that works best on large 2D weight matrices.
The key idea: instead of applying the raw elementwise momentum update, we orthogonalize the update matrix.

One of the main savings is memory: unlike AdamW, Muon doesn't store a second-moment EMA (v_t = EMA(g_t²)); it only keeps a momentum buffer.

And then why are orthogonal weight updates better in general, for the other form of savings? Because turning the raw gradient update into its polar (orthogonal) factor keeps its direction but removes uneven singular-value scaling, so the step isn’t dominated by a few “large” modes and is better conditioned / stabler across directions.

This concludes our aside into Muon.
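For the curious, here's a numpy sketch of the Newton-Schulz iteration Muon uses for that orthogonalization (coefficients from the public Muon write-up; the repo's exact details may differ):

```python
import numpy as np

def orthogonalize(G, steps=5, eps=1e-7):
    # quintic Newton-Schulz iteration: pushes every singular value of G
    # toward ~1 while keeping the singular vectors (the update's "direction")
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + eps)  # frobenius norm >= spectral norm, so ||X||_2 <= 1
    tall = X.shape[0] > X.shape[1]
    if tall:
        X = X.T  # work in the wide orientation so X @ X.T is the small gram matrix
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if tall else X
```

After a few steps, the singular values cluster near 1, so no single "large" mode dominates the step.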

As we get into the forward function, we can observe that there’s a lot going on in the first ten or so lines with using the rotary embeddings and KV cache in our computation.

Notice that we allow an existing KV cache in the params. We start by calculating the positional embeddings at the top, taking into account if we already have some KV cache, necessarily implying that we are in the middle of a sequence.

As we push the input sequence through the layers, the KV cache stores, per layer, the key and value tensors for every previous token (not attention scores; those are recomputed against each new query).

kv_cache = {
    "layer_0": [all k, v pairs from previous tokens for layer 0],
    "layer_1": [all k, v pairs from previous tokens for layer 1],
    "layer_2": [all k, v pairs from previous tokens for layer 2],
    "layer_3": [all k, v pairs from previous tokens for layer 3],
    # ...
}

This is all pretty intuitive, but it’s important to know exactly what’s going on under the hood.

Next you'll notice the "softcap" on the logits. This seems like something you'd expect to already be built in: you limit the logits so that the softmax doesn't drown out the other entries. Imagine softmax on (100, 2, 1) vs (15, 2, 1). With (15, 2, 1), you cap the confidence of the model.
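Concretely, the usual softcap is a scaled tanh. A sketch matching the 100 → 15 example (a cap of 15 is my assumption here, not confirmed from the repo):

```python
import math

def softcap(logit, cap=15.0):
    # smooth squash: ~identity for small logits, asymptotes to +/- cap
    return cap * math.tanh(logit / cap)

softcap(100.0)  # just under 15: huge logits get capped
softcap(1.0)    # ~1: small logits pass through almost unchanged
```

Unlike a hard clamp, this stays differentiable everywhere, so training gradients still flow through capped logits.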

Another note: the LM head is just a projection from the last layer of the hidden state back into the vocabulary space.

Act II: The base training loop and evaluation

The first set of 50 or so lines is just the user config settings.

One critical piece of information is how gradient accumulation works. The high-level idea: split the batch into micro-batches, compute the gradient for each against the same weights, sum all of those gradients together, and only then do one weight update. This lets you simulate a bigger batch size than fits in memory at once.

# all gradients computed on SAME model weights
grad1 = compute_grad(model, batch1)  # model at step n
grad2 = compute_grad(model, batch2)  # model still at step n
grad3 = compute_grad(model, batch3)  # model still at step n
grad_avg = (grad1 + grad2 + grad3) / 3
model.update(grad_avg)  # 1 update: step n → step n+1

vs.

# each gradient computed on DIFFERENT model weights
grad1 = compute_grad(model, batch1)  # model at step n
model.update(grad1)  # step n → n+1
grad2 = compute_grad(model, batch2)  # model at step n+1 (changed!)
model.update(grad2)  # step n+1 → n+2
grad3 = compute_grad(model, batch3)  # model at step n+2 (changed again!)
model.update(grad3)  # step n+2 → n+3
# --------------------------------------------------------------
# single training step
# evaluate the gradient
synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
    with autocast_ctx:
        loss = model(x, y)
        train_loss = loss.detach()  # for logging
        loss = loss / grad_accum_steps  # each .backward() is a grad sum => normalize loss here
        loss.backward()
        x, y, dataloader_state_dict = next(train_loader)  # prefetch the next batch while the GPU is busy with forward
# gradient clipping
grad_clip_enabled = grad_clip > 0.0
if grad_clip_enabled:
    grad_norm_tensor = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
    grad_norm = grad_norm_tensor.item()  # GPU tensor -> CPU float (note: cpu-gpu sync point)
# step the optimizers
lrm = get_lr_multiplier(step)
for opt in optimizers:
    for group in opt.param_groups:
        group["lr"] = group["initial_lr"] * lrm
muon_momentum = get_muon_momentum(step)
for group in muon_optimizer.param_groups:
    group["momentum"] = muon_momentum
for opt in optimizers:
    opt.step()
model.zero_grad(set_to_none=True)
synchronize()
t1 = time.time()
dt = t1 - t0
# --------------------------------------------------------------

So in each step we synchronize the CPU with the GPU at the beginning and end of the timed region. This loop occurs once per optimizer step.

Initially, this was confusing for me—because in my pea brain, a forward/backward is generally the same as an optimizer step. But notice here, we say that for micro_step in range(grad_accum_steps):, we do this common work of calculating the loss, doing the backward pass, etc.

Under the hood: each parameter in PyTorch has a .grad attribute. We don’t manually stash gradient matrices anywhere because param.grad gets modified inplace during backprop.

Once we’ve accumulated gradients over micro-steps, it’s only when we actually call optimizer.step() that the param.data for all these weights get updated.
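You can check the equivalence with a toy per-example loss (w − x)², whose gradient is 2(w − x); the numbers here are made up for illustration:

```python
def grad(w, x):
    # d/dw of the per-example loss (w - x)^2
    return 2 * (w - x)

w, batch = 1.0, [0.5, 2.0, 3.5]

# one big batch: gradient of the mean loss over all examples
full_batch = sum(grad(w, x) for x in batch) / len(batch)

# gradient accumulation: each micro-step backprops loss / grad_accum_steps,
# and the .grad buffers sum across micro-steps before the single opt.step()
accumulated = sum(grad(w, x) / len(batch) for x in batch)
# same update, up to float rounding
```

This is exactly why the training loop divides loss by grad_accum_steps before each .backward().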

A note in the implementation: we actually use Muon and AdamW for different parts of the model—this is due to the aforementioned note that Muon works well for dense matrices and AdamW for parameters that don’t fit that regime.

To be honest, that’s basically the important part of the training loop. Besides the wandb logging setup, which is not super important, this is kind of the crux of what’s happening there.

Next up: base_eval.py, loss_eval.py, core_eval.py.

base_eval.py is mostly about running the CORE benchmark tasks.

autocast_ctx is a helper which allows you to do mixed precision training / evaluation really easily.

Centering results: you always want to be comparing to a baseline, so the random guessing baseline is selected—this means we center so that 0 is random guessing and 1 is perfect accuracy.
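A sketch of that centering (assuming the usual affine rescale; the exact CORE bookkeeping may differ):

```python
def center(acc, num_choices):
    # rescale so random guessing maps to 0 and perfect accuracy maps to 1
    baseline = 1.0 / num_choices  # random-guessing accuracy
    return (acc - baseline) / (1.0 - baseline)

center(0.25, 4)  # 0.0: random guessing on a 4-way task
center(1.0, 4)   # 1.0: perfect accuracy
```

This makes scores comparable across tasks with different numbers of choices.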

A common thing that gets calculated is bpb, which stands for bits per byte. This sounds kind of ridiculous—doesn’t a byte just have 8 bits? In reality, bpb is measuring bits in the Shannon information theory sense, and figuring out how many of these bits of information you need in order to predict every byte of the text you’re trying to reconstruct.
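The conversion itself is simple: the model's cross-entropy comes out in nats per token, so convert to bits and amortize over the raw byte count of the text (the numbers below are hypothetical inputs):

```python
import math

def bits_per_byte(total_loss_nats, total_bytes):
    # nats -> bits (divide by ln 2), then divide by the utf-8 byte count
    return total_loss_nats / math.log(2) / total_bytes

# e.g. a summed loss of 800 * ln(2) nats over 1000 bytes of text:
bits_per_byte(800 * math.log(2), 1000)  # 0.8 bpb
```

Measuring per byte (rather than per token) makes the number independent of the tokenizer's vocabulary.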

Now that we’ve finished up loss_eval.py, we can move on to core_eval.py.

Notice that there’s a lot of templating going on, since we’re feeding in variable values to create prompts. We generally use the Jinja2 library to get this done.

Basically, to get these benchmarks, we have two versions of the questions which we want to feed to the language model—a “without completion” and a “with completion” version. The with-completion version basically has the answer. The first couple prompts are using Jinja to render the actual prompts for the models.

def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
    """Render complete prompts for a multiple choice question"""
    template_str = """
""".strip()  # (the Jinja template body is omitted in these notes)

    template = Template(template_str)
    fewshot_examples = fewshot_examples or []
    context = {
        'fewshot_examples': fewshot_examples,
        'continuation_delimiter': continuation_delimiter,
        'item': item
    }

    prompts = [template.render(choice=choice, **context) for choice in item['choices']]
    return prompts

You’ll notice that for rendering a multiple-choice question, we actually create multiple different continuations, one per option:

tokens = [
    [BOS, "What", "is", "2", "+", "2", "?", " ", "A", ")", " ", "3"],  # question + choice a
    [BOS, "What", "is", "2", "+", "2", "?", " ", "B", ")", " ", "4"],  # question + choice b
    [BOS, "What", "is", "2", "+", "2", "?", " ", "C", ")", " ", "5"],  # question + choice c
    [BOS, "What", "is", "2", "+", "2", "?", " ", "D", ")", " ", "6"],  # question + choice d
]
# 4 sequences (one per answer choice)
def stack_sequences(tokens, pad_token_id):
    """Stack up a list of token sequences, pad to longest on the right"""
    bsz, seq_len = len(tokens), max(len(x) for x in tokens)
    input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long)
    for i, x in enumerate(tokens):
        input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long)
    return input_ids

This helps clarify what happens once we get to sequence stacking: we have multiple continuations we want to run as one batch. To turn them into a rectangular tensor, we right-pad each sequence to the length of the longest one.

def evaluate_task(model, tokenizer, data, device, task_meta):
    """
    This function is responsible for evaluating one task across many examples.
    It also handles dispatch to all processes if the script is run with torchrun.
    """
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    correct = torch.zeros(len(data), dtype=torch.float32, device=device)
    # stride the examples to each rank
    for idx in range(rank, len(data), world_size):
        is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta)
        correct[idx] = float(is_correct)
    # sync results across all the processes if running distributed
    if world_size > 1:
        dist.barrier()
        dist.all_reduce(correct, op=dist.ReduceOp.SUM)
    # compute the mean
    mean_correct = correct.mean().item()
    return mean_correct
@torch.no_grad()
def forward_model(model, input_ids):
    """
    Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
    The last column of losses is set to nan because we don't have autoregressive targets there.
    """
    batch_size, seq_len = input_ids.size()
    outputs = model(input_ids)
    # Roll the tensor to the left by one position to get the (autoregressive) target ids
    target_ids = torch.roll(input_ids, shifts=-1, dims=1)
    # Calculate cross entropy at all positions
    losses = torch.nn.functional.cross_entropy(
        outputs.view(batch_size * seq_len, -1),
        target_ids.view(batch_size * seq_len),
        reduction='none'
    ).view(batch_size, seq_len)
    # Set the last column to be nan because there is no autoregressive loss there
    losses[:, -1] = float('nan')
    # Get the argmax predictions at each position
    predictions = outputs.argmax(dim=-1)
    return losses, predictions

forward_model is the evaluation pass. The outputs are a (batch_size, sequence_length, vocab_size) tensor of logits. We roll the input ids left by one position to get the targets, and evaluate the loss at every position.

Notably, nothing is being “generated” per se—we just have the predictions.
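The torch.roll trick on one row, in plain Python, shows why the last column of losses gets NaN'd: the wrapped-around first token is a bogus target.

```python
def roll_left(row):
    # mimics torch.roll(x, shifts=-1, dims=1) on a single row: shift left, wrap around
    return row[1:] + row[:1]

roll_left([5, 6, 7, 8])  # [6, 7, 8, 5] -- the trailing 5 wrapped around from the front
```

There is no autoregressive target for the last position, so its loss is discarded.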


This concludes the first part of these nanochat notes. I hope the annotations here help provide a better understanding of the codebase.

Sincerely, Laerdon