tiny-llm-demo

tiny-llm-demo - small plain-Python LLM learning demos.
git clone git://git.beep.wimdupont.com/tiny-llm-demo.git
Log | Files | Refs | README | LICENSE

tiny_transformer_lm.py (14490B)


      1 #!/usr/bin/env python3
      2 """
      3 Tiny character-level language model with a minimal causal self-attention block.
      4 
      5 This is still far from a real transformer, but it shows the key structural leap
      6 from the plain MLP demo:
      7 
      8 - each context position gets its own representation
      9 - the final position forms a query
     10 - every position in the window, the final one included, provides a key and a value
     11 - attention weights decide what matters dynamically
     12 
     13 The model predicts the next character from the attended summary of the context.
     14 """
     15 
     16 from __future__ import annotations
     17 
     18 import argparse
     19 import math
     20 import random
     21 from dataclasses import dataclass
     22 from pathlib import Path
     23 
     24 
     25 def zeros(length: int) -> list[float]:
     26     return [0.0 for _ in range(length)]
     27 
     28 
     29 def random_vector(length: int, scale: float, rng: random.Random) -> list[float]:
     30     return [rng.uniform(-scale, scale) for _ in range(length)]
     31 
     32 
     33 def random_matrix(rows: int, cols: int, scale: float, rng: random.Random) -> list[list[float]]:
     34     return [random_vector(cols, scale, rng) for _ in range(rows)]
     35 
     36 
     37 def softmax(logits: list[float]) -> list[float]:
     38     peak = max(logits)
     39     exps = [math.exp(value - peak) for value in logits]
     40     total = sum(exps)
     41     return [value / total for value in exps]
     42 
     43 
     44 @dataclass
     45 class Batch:
     46     context_ids: list[int]
     47     target_id: int
     48 
     49 
     50 class CharTokenizer:
     51     def __init__(self, text: str) -> None:
     52         chars = sorted(set(text))
     53         self.stoi = {char: index for index, char in enumerate(chars)}
     54         self.itos = {index: char for char, index in self.stoi.items()}
     55 
     56     @property
     57     def vocab_size(self) -> int:
     58         return len(self.stoi)
     59 
     60     def encode(self, text: str) -> list[int]:
     61         return [self.stoi[char] for char in text]
     62 
     63     def decode(self, token_ids: list[int]) -> str:
     64         return "".join(self.itos[token_id] for token_id in token_ids)
     65 
     66 
     67 class TinyAttentionLanguageModel:
     68     def __init__(self, vocab_size: int, context_size: int, embed_dim: int, seed: int) -> None:
     69         rng = random.Random(seed)
     70         self.vocab_size = vocab_size
     71         self.context_size = context_size
     72         self.embed_dim = embed_dim
     73         self.scale = math.sqrt(embed_dim)
     74 
     75         self.token_embed = random_matrix(vocab_size, embed_dim, 0.08, rng)
     76         self.pos_embed = random_matrix(context_size, embed_dim, 0.08, rng)
     77 
     78         self.wq = random_matrix(embed_dim, embed_dim, 0.08, rng)
     79         self.wk = random_matrix(embed_dim, embed_dim, 0.08, rng)
     80         self.wv = random_matrix(embed_dim, embed_dim, 0.08, rng)
     81         self.wo = random_matrix(embed_dim, vocab_size, 0.08, rng)
     82         self.bo = zeros(vocab_size)
     83 
     84     def matvec(self, vector: list[float], matrix: list[list[float]]) -> list[float]:
     85         cols = len(matrix[0])
     86         out = zeros(cols)
     87         for col in range(cols):
     88             total = 0.0
     89             for row in range(len(vector)):
     90                 total += vector[row] * matrix[row][col]
     91             out[col] = total
     92         return out
     93 
     94     def forward(self, context_ids: list[int]) -> tuple[list[float], dict[str, object]]:
     95         x_vectors: list[list[float]] = []
     96         q_vectors: list[list[float]] = []
     97         k_vectors: list[list[float]] = []
     98         v_vectors: list[list[float]] = []
     99 
    100         for position, token_id in enumerate(context_ids):
    101             x = [
    102                 self.token_embed[token_id][dim] + self.pos_embed[position][dim]
    103                 for dim in range(self.embed_dim)
    104             ]
    105             x_vectors.append(x)
    106             q_vectors.append(self.matvec(x, self.wq))
    107             k_vectors.append(self.matvec(x, self.wk))
    108             v_vectors.append(self.matvec(x, self.wv))
    109 
    110         last_index = len(context_ids) - 1
    111         q_last = q_vectors[last_index]
    112 
    113         scores = zeros(len(context_ids))
    114         for index in range(len(context_ids)):
    115             dot = 0.0
    116             for dim in range(self.embed_dim):
    117                 dot += q_last[dim] * k_vectors[index][dim]
    118             scores[index] = dot / self.scale
    119 
    120         attn = softmax(scores)
    121 
    122         attended = zeros(self.embed_dim)
    123         for index in range(len(context_ids)):
    124             for dim in range(self.embed_dim):
    125                 attended[dim] += attn[index] * v_vectors[index][dim]
    126 
    127         logits = zeros(self.vocab_size)
    128         for vocab_idx in range(self.vocab_size):
    129             total = self.bo[vocab_idx]
    130             for dim in range(self.embed_dim):
    131                 total += attended[dim] * self.wo[dim][vocab_idx]
    132             logits[vocab_idx] = total
    133 
    134         cache: dict[str, object] = {
    135             "context_ids": context_ids,
    136             "x_vectors": x_vectors,
    137             "q_vectors": q_vectors,
    138             "k_vectors": k_vectors,
    139             "v_vectors": v_vectors,
    140             "scores": scores,
    141             "attn": attn,
    142             "attended": attended,
    143         }
    144         return logits, cache
    145 
    146     def train_step(self, batch: Batch, learning_rate: float) -> float:
    147         logits, cache = self.forward(batch.context_ids)
    148         probs = softmax(logits)
    149         loss = -math.log(probs[batch.target_id] + 1e-12)
    150 
    151         dlogits = probs[:]
    152         dlogits[batch.target_id] -= 1.0
    153 
    154         attended = cache["attended"]
    155         attn = cache["attn"]
    156         x_vectors = cache["x_vectors"]
    157         q_vectors = cache["q_vectors"]
    158         k_vectors = cache["k_vectors"]
    159         v_vectors = cache["v_vectors"]
    160 
    161         dwo = [zeros(self.vocab_size) for _ in range(self.embed_dim)]
    162         dbo = dlogits[:]
    163         for dim in range(self.embed_dim):
    164             for vocab_idx in range(self.vocab_size):
    165                 dwo[dim][vocab_idx] = attended[dim] * dlogits[vocab_idx]
    166 
    167         dattended = zeros(self.embed_dim)
    168         for dim in range(self.embed_dim):
    169             total = 0.0
    170             for vocab_idx in range(self.vocab_size):
    171                 total += self.wo[dim][vocab_idx] * dlogits[vocab_idx]
    172             dattended[dim] = total
    173 
    174         dattn = zeros(self.context_size)
    175         dv_vectors = [zeros(self.embed_dim) for _ in range(self.context_size)]
    176         for pos in range(self.context_size):
    177             dot = 0.0
    178             for dim in range(self.embed_dim):
    179                 dot += dattended[dim] * v_vectors[pos][dim]
    180                 dv_vectors[pos][dim] += attn[pos] * dattended[dim]
    181             dattn[pos] = dot
    182 
    183         weighted = 0.0
    184         for pos in range(self.context_size):
    185             weighted += attn[pos] * dattn[pos]
    186 
    187         dscores = zeros(self.context_size)
    188         for pos in range(self.context_size):
    189             dscores[pos] = attn[pos] * (dattn[pos] - weighted)
    190 
    191         dq_last = zeros(self.embed_dim)
    192         dk_vectors = [zeros(self.embed_dim) for _ in range(self.context_size)]
    193         for pos in range(self.context_size):
    194             for dim in range(self.embed_dim):
    195                 dq_last[dim] += dscores[pos] * k_vectors[pos][dim] / self.scale
    196                 dk_vectors[pos][dim] += dscores[pos] * q_vectors[self.context_size - 1][dim] / self.scale
    197 
    198         dwq = [zeros(self.embed_dim) for _ in range(self.embed_dim)]
    199         dwk = [zeros(self.embed_dim) for _ in range(self.embed_dim)]
    200         dwv = [zeros(self.embed_dim) for _ in range(self.embed_dim)]
    201         dx_vectors = [zeros(self.embed_dim) for _ in range(self.context_size)]
    202 
    203         last_x = x_vectors[self.context_size - 1]
    204         for in_dim in range(self.embed_dim):
    205             for out_dim in range(self.embed_dim):
    206                 dwq[in_dim][out_dim] += last_x[in_dim] * dq_last[out_dim]
    207                 dx_vectors[self.context_size - 1][in_dim] += self.wq[in_dim][out_dim] * dq_last[out_dim]
    208 
    209         for pos in range(self.context_size):
    210             x = x_vectors[pos]
    211             for in_dim in range(self.embed_dim):
    212                 for out_dim in range(self.embed_dim):
    213                     dwk[in_dim][out_dim] += x[in_dim] * dk_vectors[pos][out_dim]
    214                     dwv[in_dim][out_dim] += x[in_dim] * dv_vectors[pos][out_dim]
    215                     dx_vectors[pos][in_dim] += self.wk[in_dim][out_dim] * dk_vectors[pos][out_dim]
    216                     dx_vectors[pos][in_dim] += self.wv[in_dim][out_dim] * dv_vectors[pos][out_dim]
    217 
    218         for pos, token_id in enumerate(batch.context_ids):
    219             for dim in range(self.embed_dim):
    220                 gradient = dx_vectors[pos][dim]
    221                 self.token_embed[token_id][dim] -= learning_rate * gradient
    222                 self.pos_embed[pos][dim] -= learning_rate * gradient
    223 
    224         for in_dim in range(self.embed_dim):
    225             for out_dim in range(self.embed_dim):
    226                 self.wq[in_dim][out_dim] -= learning_rate * dwq[in_dim][out_dim]
    227                 self.wk[in_dim][out_dim] -= learning_rate * dwk[in_dim][out_dim]
    228                 self.wv[in_dim][out_dim] -= learning_rate * dwv[in_dim][out_dim]
    229 
    230         for dim in range(self.embed_dim):
    231             for vocab_idx in range(self.vocab_size):
    232                 self.wo[dim][vocab_idx] -= learning_rate * dwo[dim][vocab_idx]
    233         for vocab_idx in range(self.vocab_size):
    234             self.bo[vocab_idx] -= learning_rate * dbo[vocab_idx]
    235 
    236         return loss
    237 
    238     def sample_next_token(self, context_ids: list[int], rng: random.Random, temperature: float) -> int:
    239         logits, _ = self.forward(context_ids)
    240         scaled = [value / max(temperature, 1e-6) for value in logits]
    241         probs = softmax(scaled)
    242 
    243         threshold = rng.random()
    244         running = 0.0
    245         for index, prob in enumerate(probs):
    246             running += prob
    247             if threshold <= running:
    248                 return index
    249         return len(probs) - 1
    250 
    251     def attention_weights(self, context_ids: list[int]) -> list[float]:
    252         _, cache = self.forward(context_ids)
    253         return list(cache["attn"])
    254 
    255 
    256 def build_batches(token_ids: list[int], context_size: int) -> list[Batch]:
    257     batches: list[Batch] = []
    258     for index in range(len(token_ids) - context_size):
    259         batches.append(
    260             Batch(
    261                 context_ids=token_ids[index : index + context_size],
    262                 target_id=token_ids[index + context_size],
    263             )
    264         )
    265     return batches
    266 
    267 
    268 def generate_text(
    269     model: TinyAttentionLanguageModel,
    270     tokenizer: CharTokenizer,
    271     prompt: str,
    272     sample_length: int,
    273     rng: random.Random,
    274     temperature: float,
    275 ) -> str:
    276     if not prompt:
    277         prompt = "language "
    278 
    279     if any(char not in tokenizer.stoi for char in prompt):
    280         known = "".join(sorted(tokenizer.stoi.keys()))
    281         raise ValueError(f"Prompt contains characters outside the training set. Known chars: {known!r}")
    282 
    283     token_ids = tokenizer.encode(prompt)
    284     if len(token_ids) < model.context_size:
    285         token_ids = [token_ids[0]] * (model.context_size - len(token_ids)) + token_ids
    286 
    287     generated = token_ids[:]
    288     for _ in range(sample_length):
    289         context_ids = generated[-model.context_size :]
    290         next_token = model.sample_next_token(context_ids, rng, temperature)
    291         generated.append(next_token)
    292 
    293     return tokenizer.decode(generated)
    294 
    295 
    296 def load_corpus(path: Path) -> str:
    297     return path.read_text(encoding="utf-8")
    298 
    299 
    300 def main() -> None:
    301     parser = argparse.ArgumentParser(description="Train a tiny self-attention language model.")
    302     parser.add_argument("--steps", type=int, default=2500, help="Number of SGD steps.")
    303     parser.add_argument("--report-every", type=int, default=500, help="Progress report interval.")
    304     parser.add_argument("--learning-rate", type=float, default=0.03, help="SGD learning rate.")
    305     parser.add_argument("--context-size", type=int, default=8, help="Context window length.")
    306     parser.add_argument("--embed-dim", type=int, default=16, help="Embedding size.")
    307     parser.add_argument("--sample-length", type=int, default=220, help="Characters to generate.")
    308     parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature.")
    309     parser.add_argument("--seed", type=int, default=7, help="Random seed.")
    310     parser.add_argument("--prompt", default="language models ", help="Initial prompt for generation.")
    311     parser.add_argument(
    312         "--repeat-corpus",
    313         type=int,
    314         default=12,
    315         help="Repeat the corpus this many times to make the tiny demo learn faster.",
    316     )
    317     parser.add_argument(
    318         "--show-attention",
    319         action="store_true",
    320         help="Print attention weights for the final prompt context after training.",
    321     )
    322     args = parser.parse_args()
    323 
    324     project_dir = Path(__file__).resolve().parent
    325     corpus_text = load_corpus(project_dir / "corpus.txt")
    326     corpus_text = corpus_text * max(args.repeat_corpus, 1)
    327 
    328     tokenizer = CharTokenizer(corpus_text)
    329     token_ids = tokenizer.encode(corpus_text)
    330     batches = build_batches(token_ids, args.context_size)
    331     if not batches:
    332         raise ValueError("Corpus is too small for the requested context size.")
    333 
    334     model = TinyAttentionLanguageModel(
    335         vocab_size=tokenizer.vocab_size,
    336         context_size=args.context_size,
    337         embed_dim=args.embed_dim,
    338         seed=args.seed,
    339     )
    340 
    341     rng = random.Random(args.seed)
    342     running_loss = 0.0
    343 
    344     print(f"vocab size: {tokenizer.vocab_size}")
    345     print(f"training samples: {len(batches)}")
    346     print("training...")
    347 
    348     for step in range(1, args.steps + 1):
    349         batch = batches[rng.randrange(len(batches))]
    350         loss = model.train_step(batch, args.learning_rate)
    351         running_loss += loss
    352 
    353         if step % args.report_every == 0 or step == 1:
    354             avg_loss = running_loss / min(step, args.report_every)
    355             print(f"step {step:5d} | avg loss {avg_loss:.4f}")
    356             running_loss = 0.0
    357 
    358     print("\n--- sample ---\n")
    359     sample = generate_text(
    360         model=model,
    361         tokenizer=tokenizer,
    362         prompt=args.prompt,
    363         sample_length=args.sample_length,
    364         rng=rng,
    365         temperature=args.temperature,
    366     )
    367     print(sample)
    368 
    369     if args.show_attention:
    370         prompt_ids = tokenizer.encode(args.prompt)
    371         if len(prompt_ids) < model.context_size:
    372             prompt_ids = [prompt_ids[0]] * (model.context_size - len(prompt_ids)) + prompt_ids
    373         context_ids = prompt_ids[-model.context_size :]
    374         weights = model.attention_weights(context_ids)
    375 
    376         print("\n--- attention on final prompt context ---\n")
    377         context_text = tokenizer.decode(context_ids)
    378         for index, weight in enumerate(weights):
    379             char = context_text[index]
    380             label = "\\n" if char == "\n" else char
    381             print(f"pos {index:2d} | char {label!r} | weight {weight:.4f}")
    382 
    383 
    384 if __name__ == "__main__":
    385     main()