tiny_lm.py (10765B)
1 #!/usr/bin/env python3 2 """ 3 Tiny character-level neural language model in plain Python. 4 5 This is intentionally small and readable. It is not a transformer and not a real 6 LLM. It exists to show the basic mechanics of next-token training: 7 8 - tokenize text 9 - embed tokens 10 - run a forward pass 11 - compute softmax loss 12 - backpropagate gradients 13 - update weights 14 - sample new text 15 """ 16 17 from __future__ import annotations 18 19 import argparse 20 import math 21 import random 22 from dataclasses import dataclass 23 from pathlib import Path 24 25 26 def zeros(length: int) -> list[float]: 27 return [0.0 for _ in range(length)] 28 29 30 def random_vector(length: int, scale: float, rng: random.Random) -> list[float]: 31 return [rng.uniform(-scale, scale) for _ in range(length)] 32 33 34 def random_matrix(rows: int, cols: int, scale: float, rng: random.Random) -> list[list[float]]: 35 return [random_vector(cols, scale, rng) for _ in range(rows)] 36 37 38 def softmax(logits: list[float]) -> list[float]: 39 peak = max(logits) 40 exps = [math.exp(value - peak) for value in logits] 41 total = sum(exps) 42 return [value / total for value in exps] 43 44 45 @dataclass 46 class Batch: 47 context_ids: list[int] 48 target_id: int 49 50 51 class CharTokenizer: 52 def __init__(self, text: str) -> None: 53 chars = sorted(set(text)) 54 self.stoi = {char: index for index, char in enumerate(chars)} 55 self.itos = {index: char for char, index in self.stoi.items()} 56 57 @property 58 def vocab_size(self) -> int: 59 return len(self.stoi) 60 61 def encode(self, text: str) -> list[int]: 62 return [self.stoi[char] for char in text] 63 64 def decode(self, token_ids: list[int]) -> str: 65 return "".join(self.itos[token_id] for token_id in token_ids) 66 67 68 class TinyLanguageModel: 69 def __init__( 70 self, 71 vocab_size: int, 72 context_size: int, 73 embed_dim: int, 74 hidden_dim: int, 75 seed: int, 76 ) -> None: 77 rng = random.Random(seed) 78 self.vocab_size = vocab_size 79 self.context_size = context_size 80 self.embed_dim = embed_dim 81 self.hidden_dim = hidden_dim 82 self.input_dim = context_size * embed_dim 83 84 self.token_embed = random_matrix(vocab_size, embed_dim, 0.08, rng) 85 self.w1 = random_matrix(self.input_dim, hidden_dim, 0.08, rng) 86 self.b1 = zeros(hidden_dim) 87 self.w2 = random_matrix(hidden_dim, vocab_size, 0.08, rng) 88 self.b2 = zeros(vocab_size) 89 90 def forward(self, context_ids: list[int]) -> tuple[list[float], dict[str, list[float] | list[int]]]: 91 combined = zeros(self.input_dim) 92 for position, token_id in enumerate(context_ids): 93 token_vector = self.token_embed[token_id] 94 for dim in range(self.embed_dim): 95 combined[position * self.embed_dim + dim] = token_vector[dim] 96 97 hidden_pre = zeros(self.hidden_dim) 98 for hidden_idx in range(self.hidden_dim): 99 total = self.b1[hidden_idx] 100 for dim in range(self.input_dim): 101 total += combined[dim] * self.w1[dim][hidden_idx] 102 hidden_pre[hidden_idx] = total 103 104 hidden = [math.tanh(value) for value in hidden_pre] 105 106 logits = zeros(self.vocab_size) 107 for vocab_idx in range(self.vocab_size): 108 total = self.b2[vocab_idx] 109 for hidden_idx in range(self.hidden_dim): 110 total += hidden[hidden_idx] * self.w2[hidden_idx][vocab_idx] 111 logits[vocab_idx] = total 112 113 cache = { 114 "context_ids": context_ids, 115 "combined": combined, 116 "hidden": hidden, 117 } 118 return logits, cache 119 120 def train_step(self, batch: Batch, learning_rate: float) -> float: 121 logits, cache = self.forward(batch.context_ids) 122 probs = softmax(logits) 123 loss = -math.log(probs[batch.target_id] + 1e-12) 124 125 dlogits = probs[:] 126 dlogits[batch.target_id] -= 1.0 127 128 hidden = cache["hidden"] 129 combined = cache["combined"] 130 131 dw2 = [zeros(self.vocab_size) for _ in range(self.hidden_dim)] 132 db2 = dlogits[:] 133 for hidden_idx in range(self.hidden_dim): 134 for vocab_idx in range(self.vocab_size): 135 dw2[hidden_idx][vocab_idx] = hidden[hidden_idx] * dlogits[vocab_idx] 136 137 dhidden = zeros(self.hidden_dim) 138 for hidden_idx in range(self.hidden_dim): 139 total = 0.0 140 for vocab_idx in range(self.vocab_size): 141 total += self.w2[hidden_idx][vocab_idx] * dlogits[vocab_idx] 142 dhidden[hidden_idx] = total 143 144 dhidden_pre = zeros(self.hidden_dim) 145 for hidden_idx in range(self.hidden_dim): 146 dhidden_pre[hidden_idx] = dhidden[hidden_idx] * (1.0 - hidden[hidden_idx] ** 2) 147 148 dw1 = [zeros(self.hidden_dim) for _ in range(self.input_dim)] 149 db1 = dhidden_pre[:] 150 for dim in range(self.input_dim): 151 for hidden_idx in range(self.hidden_dim): 152 dw1[dim][hidden_idx] = combined[dim] * dhidden_pre[hidden_idx] 153 154 dcombined = zeros(self.input_dim) 155 for dim in range(self.input_dim): 156 total = 0.0 157 for hidden_idx in range(self.hidden_dim): 158 total += self.w1[dim][hidden_idx] * dhidden_pre[hidden_idx] 159 dcombined[dim] = total 160 161 for position, token_id in enumerate(batch.context_ids): 162 for dim in range(self.embed_dim): 163 gradient = dcombined[position * self.embed_dim + dim] 164 self.token_embed[token_id][dim] -= learning_rate * gradient 165 166 for dim in range(self.input_dim): 167 for hidden_idx in range(self.hidden_dim): 168 self.w1[dim][hidden_idx] -= learning_rate * dw1[dim][hidden_idx] 169 for hidden_idx in range(self.hidden_dim): 170 self.b1[hidden_idx] -= learning_rate * db1[hidden_idx] 171 172 for hidden_idx in range(self.hidden_dim): 173 for vocab_idx in range(self.vocab_size): 174 self.w2[hidden_idx][vocab_idx] -= learning_rate * dw2[hidden_idx][vocab_idx] 175 for vocab_idx in range(self.vocab_size): 176 self.b2[vocab_idx] -= learning_rate * db2[vocab_idx] 177 178 return loss 179 180 def sample_next_token(self, context_ids: list[int], rng: random.Random, temperature: float) -> int: 181 logits, _ = self.forward(context_ids) 182 scaled = [value / max(temperature, 1e-6) for value in logits] 183 probs = softmax(scaled) 184 185 threshold = rng.random() 186 running = 0.0 187 for index, prob in enumerate(probs): 188 running += prob 189 if threshold <= running: 190 return index 191 return len(probs) - 1 192 193 194 def build_batches(token_ids: list[int], context_size: int) -> list[Batch]: 195 batches: list[Batch] = [] 196 for index in range(len(token_ids) - context_size): 197 batches.append( 198 Batch( 199 context_ids=token_ids[index : index + context_size], 200 target_id=token_ids[index + context_size], 201 ) 202 ) 203 return batches 204 205 206 def generate_text( 207 model: TinyLanguageModel, 208 tokenizer: CharTokenizer, 209 prompt: str, 210 sample_length: int, 211 rng: random.Random, 212 temperature: float, 213 ) -> str: 214 if not prompt: 215 prompt = "language " 216 217 if any(char not in tokenizer.stoi for char in prompt): 218 known = "".join(sorted(tokenizer.stoi.keys())) 219 raise ValueError(f"Prompt contains characters outside the training set. Known chars: {known!r}") 220 221 token_ids = tokenizer.encode(prompt) 222 if len(token_ids) < model.context_size: 223 token_ids = [token_ids[0]] * (model.context_size - len(token_ids)) + token_ids 224 225 generated = token_ids[:] 226 for _ in range(sample_length): 227 context_ids = generated[-model.context_size :] 228 next_token = model.sample_next_token(context_ids, rng, temperature) 229 generated.append(next_token) 230 231 return tokenizer.decode(generated) 232 233 234 def load_corpus(path: Path) -> str: 235 return path.read_text(encoding="utf-8") 236 237 238 def main() -> None: 239 parser = argparse.ArgumentParser(description="Train a tiny next-character language model.") 240 parser.add_argument("--steps", type=int, default=3000, help="Number of SGD steps.") 241 parser.add_argument("--report-every", type=int, default=500, help="Progress report interval.") 242 parser.add_argument("--learning-rate", type=float, default=0.08, help="SGD learning rate.") 243 parser.add_argument("--context-size", type=int, default=12, help="Context window length.") 244 parser.add_argument("--embed-dim", type=int, default=24, help="Embedding size.") 245 parser.add_argument("--hidden-dim", type=int, default=48, help="Hidden layer size.") 246 parser.add_argument("--sample-length", type=int, default=220, help="Characters to generate.") 247 parser.add_argument("--temperature", type=float, default=0.9, help="Sampling temperature.") 248 parser.add_argument("--seed", type=int, default=7, help="Random seed.") 249 parser.add_argument("--prompt", default="language models ", help="Initial prompt for generation.") 250 parser.add_argument( 251 "--repeat-corpus", 252 type=int, 253 default=8, 254 help="Repeat the corpus this many times to make the tiny demo learn faster.", 255 ) 256 args = parser.parse_args() 257 258 project_dir = Path(__file__).resolve().parent 259 corpus_text = load_corpus(project_dir / "corpus.txt") 260 corpus_text = corpus_text * max(args.repeat_corpus, 1) 261 262 tokenizer = CharTokenizer(corpus_text) 263 token_ids = tokenizer.encode(corpus_text) 264 batches = build_batches(token_ids, args.context_size) 265 if not batches: 266 raise ValueError("Corpus is too small for the requested context size.") 267 268 model = TinyLanguageModel( 269 vocab_size=tokenizer.vocab_size, 270 context_size=args.context_size, 271 embed_dim=args.embed_dim, 272 hidden_dim=args.hidden_dim, 273 seed=args.seed, 274 ) 275 276 rng = random.Random(args.seed) 277 running_loss = 0.0 278 279 print(f"vocab size: {tokenizer.vocab_size}") 280 print(f"training samples: {len(batches)}") 281 print("training...") 282 283 for step in range(1, args.steps + 1): 284 batch = batches[rng.randrange(len(batches))] 285 loss = model.train_step(batch, args.learning_rate) 286 running_loss += loss 287 288 if step % args.report_every == 0 or step == 1: 289 avg_loss = running_loss / min(step, args.report_every) 290 print(f"step {step:5d} | avg loss {avg_loss:.4f}") 291 running_loss = 0.0 292 293 print("\n--- sample ---\n") 294 sample = generate_text( 295 model=model, 296 tokenizer=tokenizer, 297 prompt=args.prompt, 298 sample_length=args.sample_length, 299 rng=rng, 300 temperature=args.temperature, 301 ) 302 print(sample) 303 304 305 if __name__ == "__main__": 306 main()