tiny_transformer_lm.py (14490B)
1 #!/usr/bin/env python3 2 """ 3 Tiny character-level language model with a minimal causal self-attention block. 4 5 This is still far from a real transformer, but it shows the key structural leap 6 from the plain MLP demo: 7 8 - each context position gets its own representation 9 - the final position forms a query 10 - all earlier positions provide keys and values 11 - attention weights decide what matters dynamically 12 13 The model predicts the next character from the attended summary of the context. 14 """ 15 16 from __future__ import annotations 17 18 import argparse 19 import math 20 import random 21 from dataclasses import dataclass 22 from pathlib import Path 23 24 25 def zeros(length: int) -> list[float]: 26 return [0.0 for _ in range(length)] 27 28 29 def random_vector(length: int, scale: float, rng: random.Random) -> list[float]: 30 return [rng.uniform(-scale, scale) for _ in range(length)] 31 32 33 def random_matrix(rows: int, cols: int, scale: float, rng: random.Random) -> list[list[float]]: 34 return [random_vector(cols, scale, rng) for _ in range(rows)] 35 36 37 def softmax(logits: list[float]) -> list[float]: 38 peak = max(logits) 39 exps = [math.exp(value - peak) for value in logits] 40 total = sum(exps) 41 return [value / total for value in exps] 42 43 44 @dataclass 45 class Batch: 46 context_ids: list[int] 47 target_id: int 48 49 50 class CharTokenizer: 51 def __init__(self, text: str) -> None: 52 chars = sorted(set(text)) 53 self.stoi = {char: index for index, char in enumerate(chars)} 54 self.itos = {index: char for char, index in self.stoi.items()} 55 56 @property 57 def vocab_size(self) -> int: 58 return len(self.stoi) 59 60 def encode(self, text: str) -> list[int]: 61 return [self.stoi[char] for char in text] 62 63 def decode(self, token_ids: list[int]) -> str: 64 return "".join(self.itos[token_id] for token_id in token_ids) 65 66 67 class TinyAttentionLanguageModel: 68 def __init__(self, vocab_size: int, context_size: int, embed_dim: int, seed: int) -> None: 69 rng = random.Random(seed) 70 self.vocab_size = vocab_size 71 self.context_size = context_size 72 self.embed_dim = embed_dim 73 self.scale = math.sqrt(embed_dim) 74 75 self.token_embed = random_matrix(vocab_size, embed_dim, 0.08, rng) 76 self.pos_embed = random_matrix(context_size, embed_dim, 0.08, rng) 77 78 self.wq = random_matrix(embed_dim, embed_dim, 0.08, rng) 79 self.wk = random_matrix(embed_dim, embed_dim, 0.08, rng) 80 self.wv = random_matrix(embed_dim, embed_dim, 0.08, rng) 81 self.wo = random_matrix(embed_dim, vocab_size, 0.08, rng) 82 self.bo = zeros(vocab_size) 83 84 def matvec(self, vector: list[float], matrix: list[list[float]]) -> list[float]: 85 cols = len(matrix[0]) 86 out = zeros(cols) 87 for col in range(cols): 88 total = 0.0 89 for row in range(len(vector)): 90 total += vector[row] * matrix[row][col] 91 out[col] = total 92 return out 93 94 def forward(self, context_ids: list[int]) -> tuple[list[float], dict[str, object]]: 95 x_vectors: list[list[float]] = [] 96 q_vectors: list[list[float]] = [] 97 k_vectors: list[list[float]] = [] 98 v_vectors: list[list[float]] = [] 99 100 for position, token_id in enumerate(context_ids): 101 x = [ 102 self.token_embed[token_id][dim] + self.pos_embed[position][dim] 103 for dim in range(self.embed_dim) 104 ] 105 x_vectors.append(x) 106 q_vectors.append(self.matvec(x, self.wq)) 107 k_vectors.append(self.matvec(x, self.wk)) 108 v_vectors.append(self.matvec(x, self.wv)) 109 110 last_index = len(context_ids) - 1 111 q_last = q_vectors[last_index] 112 113 scores = zeros(len(context_ids)) 114 for index in range(len(context_ids)): 115 dot = 0.0 116 for dim in range(self.embed_dim): 117 dot += q_last[dim] * k_vectors[index][dim] 118 scores[index] = dot / self.scale 119 120 attn = softmax(scores) 121 122 attended = zeros(self.embed_dim) 123 for index in range(len(context_ids)): 124 for dim in range(self.embed_dim): 125 attended[dim] += attn[index] * v_vectors[index][dim] 126 127 logits = zeros(self.vocab_size) 128 for vocab_idx in range(self.vocab_size): 129 total = self.bo[vocab_idx] 130 for dim in range(self.embed_dim): 131 total += attended[dim] * self.wo[dim][vocab_idx] 132 logits[vocab_idx] = total 133 134 cache: dict[str, object] = { 135 "context_ids": context_ids, 136 "x_vectors": x_vectors, 137 "q_vectors": q_vectors, 138 "k_vectors": k_vectors, 139 "v_vectors": v_vectors, 140 "scores": scores, 141 "attn": attn, 142 "attended": attended, 143 } 144 return logits, cache 145 146 def train_step(self, batch: Batch, learning_rate: float) -> float: 147 logits, cache = self.forward(batch.context_ids) 148 probs = softmax(logits) 149 loss = -math.log(probs[batch.target_id] + 1e-12) 150 151 dlogits = probs[:] 152 dlogits[batch.target_id] -= 1.0 153 154 attended = cache["attended"] 155 attn = cache["attn"] 156 x_vectors = cache["x_vectors"] 157 q_vectors = cache["q_vectors"] 158 k_vectors = cache["k_vectors"] 159 v_vectors = cache["v_vectors"] 160 161 dwo = [zeros(self.vocab_size) for _ in range(self.embed_dim)] 162 dbo = dlogits[:] 163 for dim in range(self.embed_dim): 164 for vocab_idx in range(self.vocab_size): 165 dwo[dim][vocab_idx] = attended[dim] * dlogits[vocab_idx] 166 167 dattended = zeros(self.embed_dim) 168 for dim in range(self.embed_dim): 169 total = 0.0 170 for vocab_idx in range(self.vocab_size): 171 total += self.wo[dim][vocab_idx] * dlogits[vocab_idx] 172 dattended[dim] = total 173 174 dattn = zeros(self.context_size) 175 dv_vectors = [zeros(self.embed_dim) for _ in range(self.context_size)] 176 for pos in range(self.context_size): 177 dot = 0.0 178 for dim in range(self.embed_dim): 179 dot += dattended[dim] * v_vectors[pos][dim] 180 dv_vectors[pos][dim] += attn[pos] * dattended[dim] 181 dattn[pos] = dot 182 183 weighted = 0.0 184 for pos in range(self.context_size): 185 weighted += attn[pos] * dattn[pos] 186 187 dscores = zeros(self.context_size) 188 for pos in range(self.context_size): 189 dscores[pos] = attn[pos] * (dattn[pos] - weighted) 190 191 dq_last = zeros(self.embed_dim) 192 dk_vectors = [zeros(self.embed_dim) for _ in range(self.context_size)] 193 for pos in range(self.context_size): 194 for dim in range(self.embed_dim): 195 dq_last[dim] += dscores[pos] * k_vectors[pos][dim] / self.scale 196 dk_vectors[pos][dim] += dscores[pos] * q_vectors[self.context_size - 1][dim] / self.scale 197 198 dwq = [zeros(self.embed_dim) for _ in range(self.embed_dim)] 199 dwk = [zeros(self.embed_dim) for _ in range(self.embed_dim)] 200 dwv = [zeros(self.embed_dim) for _ in range(self.embed_dim)] 201 dx_vectors = [zeros(self.embed_dim) for _ in range(self.context_size)] 202 203 last_x = x_vectors[self.context_size - 1] 204 for in_dim in range(self.embed_dim): 205 for out_dim in range(self.embed_dim): 206 dwq[in_dim][out_dim] += last_x[in_dim] * dq_last[out_dim] 207 dx_vectors[self.context_size - 1][in_dim] += self.wq[in_dim][out_dim] * dq_last[out_dim] 208 209 for pos in range(self.context_size): 210 x = x_vectors[pos] 211 for in_dim in range(self.embed_dim): 212 for out_dim in range(self.embed_dim): 213 dwk[in_dim][out_dim] += x[in_dim] * dk_vectors[pos][out_dim] 214 dwv[in_dim][out_dim] += x[in_dim] * dv_vectors[pos][out_dim] 215 dx_vectors[pos][in_dim] += self.wk[in_dim][out_dim] * dk_vectors[pos][out_dim] 216 dx_vectors[pos][in_dim] += self.wv[in_dim][out_dim] * dv_vectors[pos][out_dim] 217 218 for pos, token_id in enumerate(batch.context_ids): 219 for dim in range(self.embed_dim): 220 gradient = dx_vectors[pos][dim] 221 self.token_embed[token_id][dim] -= learning_rate * gradient 222 self.pos_embed[pos][dim] -= learning_rate * gradient 223 224 for in_dim in range(self.embed_dim): 225 for out_dim in range(self.embed_dim): 226 self.wq[in_dim][out_dim] -= learning_rate * dwq[in_dim][out_dim] 227 self.wk[in_dim][out_dim] -= learning_rate * dwk[in_dim][out_dim] 228 self.wv[in_dim][out_dim] -= learning_rate * dwv[in_dim][out_dim] 229 230 for dim in range(self.embed_dim): 231 for vocab_idx in range(self.vocab_size): 232 self.wo[dim][vocab_idx] -= learning_rate * dwo[dim][vocab_idx] 233 for vocab_idx in range(self.vocab_size): 234 self.bo[vocab_idx] -= learning_rate * dbo[vocab_idx] 235 236 return loss 237 238 def sample_next_token(self, context_ids: list[int], rng: random.Random, temperature: float) -> int: 239 logits, _ = self.forward(context_ids) 240 scaled = [value / max(temperature, 1e-6) for value in logits] 241 probs = softmax(scaled) 242 243 threshold = rng.random() 244 running = 0.0 245 for index, prob in enumerate(probs): 246 running += prob 247 if threshold <= running: 248 return index 249 return len(probs) - 1 250 251 def attention_weights(self, context_ids: list[int]) -> list[float]: 252 _, cache = self.forward(context_ids) 253 return list(cache["attn"]) 254 255 256 def build_batches(token_ids: list[int], context_size: int) -> list[Batch]: 257 batches: list[Batch] = [] 258 for index in range(len(token_ids) - context_size): 259 batches.append( 260 Batch( 261 context_ids=token_ids[index : index + context_size], 262 target_id=token_ids[index + context_size], 263 ) 264 ) 265 return batches 266 267 268 def generate_text( 269 model: TinyAttentionLanguageModel, 270 tokenizer: CharTokenizer, 271 prompt: str, 272 sample_length: int, 273 rng: random.Random, 274 temperature: float, 275 ) -> str: 276 if not prompt: 277 prompt = "language " 278 279 if any(char not in tokenizer.stoi for char in prompt): 280 known = "".join(sorted(tokenizer.stoi.keys())) 281 raise ValueError(f"Prompt contains characters outside the training set. Known chars: {known!r}") 282 283 token_ids = tokenizer.encode(prompt) 284 if len(token_ids) < model.context_size: 285 token_ids = [token_ids[0]] * (model.context_size - len(token_ids)) + token_ids 286 287 generated = token_ids[:] 288 for _ in range(sample_length): 289 context_ids = generated[-model.context_size :] 290 next_token = model.sample_next_token(context_ids, rng, temperature) 291 generated.append(next_token) 292 293 return tokenizer.decode(generated) 294 295 296 def load_corpus(path: Path) -> str: 297 return path.read_text(encoding="utf-8") 298 299 300 def main() -> None: 301 parser = argparse.ArgumentParser(description="Train a tiny self-attention language model.") 302 parser.add_argument("--steps", type=int, default=2500, help="Number of SGD steps.") 303 parser.add_argument("--report-every", type=int, default=500, help="Progress report interval.") 304 parser.add_argument("--learning-rate", type=float, default=0.03, help="SGD learning rate.") 305 parser.add_argument("--context-size", type=int, default=8, help="Context window length.") 306 parser.add_argument("--embed-dim", type=int, default=16, help="Embedding size.") 307 parser.add_argument("--sample-length", type=int, default=220, help="Characters to generate.") 308 parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature.") 309 parser.add_argument("--seed", type=int, default=7, help="Random seed.") 310 parser.add_argument("--prompt", default="language models ", help="Initial prompt for generation.") 311 parser.add_argument( 312 "--repeat-corpus", 313 type=int, 314 default=12, 315 help="Repeat the corpus this many times to make the tiny demo learn faster.", 316 ) 317 parser.add_argument( 318 "--show-attention", 319 action="store_true", 320 help="Print attention weights for the final prompt context after training.", 321 ) 322 args = parser.parse_args() 323 324 project_dir = Path(__file__).resolve().parent 325 corpus_text = load_corpus(project_dir / "corpus.txt") 326 corpus_text = corpus_text * max(args.repeat_corpus, 1) 327 328 tokenizer = CharTokenizer(corpus_text) 329 token_ids = tokenizer.encode(corpus_text) 330 batches = build_batches(token_ids, args.context_size) 331 if not batches: 332 raise ValueError("Corpus is too small for the requested context size.") 333 334 model = TinyAttentionLanguageModel( 335 vocab_size=tokenizer.vocab_size, 336 context_size=args.context_size, 337 embed_dim=args.embed_dim, 338 seed=args.seed, 339 ) 340 341 rng = random.Random(args.seed) 342 running_loss = 0.0 343 344 print(f"vocab size: {tokenizer.vocab_size}") 345 print(f"training samples: {len(batches)}") 346 print("training...") 347 348 for step in range(1, args.steps + 1): 349 batch = batches[rng.randrange(len(batches))] 350 loss = model.train_step(batch, args.learning_rate) 351 running_loss += loss 352 353 if step % args.report_every == 0 or step == 1: 354 avg_loss = running_loss / min(step, args.report_every) 355 print(f"step {step:5d} | avg loss {avg_loss:.4f}") 356 running_loss = 0.0 357 358 print("\n--- sample ---\n") 359 sample = generate_text( 360 model=model, 361 tokenizer=tokenizer, 362 prompt=args.prompt, 363 sample_length=args.sample_length, 364 rng=rng, 365 temperature=args.temperature, 366 ) 367 print(sample) 368 369 if args.show_attention: 370 prompt_ids = tokenizer.encode(args.prompt) 371 if len(prompt_ids) < model.context_size: 372 prompt_ids = [prompt_ids[0]] * (model.context_size - len(prompt_ids)) + prompt_ids 373 context_ids = prompt_ids[-model.context_size :] 374 weights = model.attention_weights(context_ids) 375 376 print("\n--- attention on final prompt context ---\n") 377 context_text = tokenizer.decode(context_ids) 378 for index, weight in enumerate(weights): 379 char = context_text[index] 380 label = "\\n" if char == "\n" else char 381 print(f"pos {index:2d} | char {label!r} | weight {weight:.4f}") 382 383 384 if __name__ == "__main__": 385 main()