tiny_modern_lm.py (10304B)
#!/usr/bin/env python3
"""
Tiny forward-only demo of a more modern transformer-style language model block.

This file is meant to show architecture, not serious training. It includes:

- token embeddings
- positional embeddings
- layer normalization
- causal multi-head self-attention
- residual connections
- feed-forward network
- output projection to logits

Why forward-only?
Implementing a faithful transformer block *and* all of its backpropagation in
plain standard-library Python would bury the core ideas under gradient code.
The earlier demos show training mechanics; this file shows the modern block
shape. All weights are randomly initialized and never trained here, so the
printed probabilities illustrate data flow, not learned behavior.
"""

from __future__ import annotations

import argparse
import math
import random
from pathlib import Path


def zeros(length: int) -> list[float]:
    """Return a new list holding ``length`` copies of ``0.0``."""
    return [0.0] * length


def random_vector(length: int, scale: float, rng: random.Random) -> list[float]:
    """Draw ``length`` values uniformly from ``[-scale, scale]`` using ``rng``."""
    low, high = -scale, scale
    draw = rng.uniform
    return [draw(low, high) for _ in range(length)]


def random_matrix(rows: int, cols: int, scale: float, rng: random.Random) -> list[list[float]]:
    """Build a ``rows`` x ``cols`` matrix (list of rows) of uniform values in ``[-scale, scale]``.

    Rows are drawn one after another from ``rng``, so the result is reproducible for a seed.
    """
    matrix: list[list[float]] = []
    for _ in range(rows):
        matrix.append(random_vector(cols, scale, rng))
    return matrix


def add(a: list[float], b: list[float]) -> list[float]:
    """Element-wise sum of ``a`` and ``b``; stops at the shorter input (``zip`` semantics)."""
    return [left + right for left, right in zip(a, b)]


def softmax(logits: list[float]) -> list[float]:
    """Turn ``logits`` into probabilities that sum to 1.

    The maximum is subtracted before exponentiating so large logits cannot overflow.
    """
    shift = max(logits)
    weights = [math.exp(logit - shift) for logit in logits]
    norm = sum(weights)
    return [weight / norm for weight in weights]


def gelu(values: list[float]) -> list[float]:
    """Apply the tanh approximation of GELU to each element of ``values``."""
    coef = math.sqrt(2.0 / math.pi)
    return [0.5 * x * (1.0 + math.tanh(coef * (x + 0.044715 * x ** 3))) for x in values]


def matvec(vector: list[float], matrix: list[list[float]]) -> list[float]:
    """Multiply a row ``vector`` by ``matrix`` (indexed ``matrix[row][col]``).

    The output has ``len(matrix[0])`` entries; only the first ``len(vector)`` rows
    of ``matrix`` are read. Each output entry is accumulated row by row, in order.
    """
    width = len(matrix[0])
    result: list[float] = []
    for col in range(width):
        acc = 0.0
        for row_index, coeff in enumerate(vector):
            acc += coeff * matrix[row_index][col]
        result.append(acc)
    return result


def layer_norm(vector: list[float], gamma: list[float], beta: list[float], eps: float = 1e-5) -> list[float]:
    """Normalize ``vector`` to zero mean and unit variance, then scale by ``gamma`` and shift by ``beta``.

    Uses the population variance (divides by ``len(vector)``); ``eps`` keeps the
    denominator positive for constant vectors.
    """
    mean = sum(vector) / len(vector)
    variance = sum((value - mean) ** 2 for value in vector) / len(vector)
    denom = math.sqrt(variance + eps)
    return [((value - mean) / denom) * gamma[i] + beta[i] for i, value in enumerate(vector)]


class CharTokenizer:
    """Character-level tokenizer whose vocabulary is the sorted set of characters in ``text``."""

    def __init__(self, text: str) -> None:
        chars = sorted(set(text))
        self.stoi = {char: index for index, char in enumerate(chars)}
        self.itos = {index: char for char, index in self.stoi.items()}

    @property
    def vocab_size(self) -> int:
        """Number of distinct characters in the vocabulary."""
        return len(self.stoi)

    def encode(self, text: str) -> list[int]:
        """Map each character of ``text`` to its id (raises ``KeyError`` for unknown characters)."""
        return [self.stoi[char] for char in text]

    def decode(self, token_ids: list[int]) -> str:
        """Map ids back to characters and join them into a string."""
        return "".join(self.itos[token_id] for token_id in token_ids)


class TinyModernTransformerLM:
    """One pre-norm transformer block plus embeddings and an LM head, forward pass only.

    All weights come from a single ``random.Random(seed)`` drawn in a fixed order,
    so the same arguments always produce the same model.
    """

    def __init__(self, vocab_size: int, context_size: int, embed_dim: int, num_heads: int, seed: int) -> None:
        # Validate before drawing any weights; num_heads == 0 would otherwise hit
        # ZeroDivisionError in the divisibility check below.
        if min(vocab_size, context_size, embed_dim, num_heads) < 1:
            raise ValueError("vocab_size, context_size, embed_dim and num_heads must all be positive")
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")

        rng = random.Random(seed)
        self.vocab_size = vocab_size
        self.context_size = context_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # NOTE: the order of the random_matrix calls below defines which random
        # numbers each weight receives; reordering them changes the model for a seed.
        self.token_embed = random_matrix(vocab_size, embed_dim, 0.08, rng)
        self.pos_embed = random_matrix(context_size, embed_dim, 0.08, rng)

        self.ln1_gamma = [1.0] * embed_dim
        self.ln1_beta = [0.0] * embed_dim
        self.ln2_gamma = [1.0] * embed_dim
        self.ln2_beta = [0.0] * embed_dim

        self.wq = random_matrix(embed_dim, embed_dim, 0.08, rng)
        self.wk = random_matrix(embed_dim, embed_dim, 0.08, rng)
        self.wv = random_matrix(embed_dim, embed_dim, 0.08, rng)
        self.wo = random_matrix(embed_dim, embed_dim, 0.08, rng)

        ff_dim = embed_dim * 4
        self.w1 = random_matrix(embed_dim, ff_dim, 0.08, rng)
        self.b1 = zeros(ff_dim)
        self.w2 = random_matrix(ff_dim, embed_dim, 0.08, rng)
        self.b2 = zeros(embed_dim)

        self.lm_head = random_matrix(embed_dim, vocab_size, 0.08, rng)
        self.lm_bias = zeros(vocab_size)

    def causal_attention(self, x_vectors: list[list[float]]) -> tuple[list[list[float]], list[list[list[float]]]]:
        """Causal multi-head self-attention over ``x_vectors`` (one vector per position).

        Queries, keys and values are produced with ``wq``/``wk``/``wv`` and split into
        ``num_heads`` contiguous slices of ``head_dim`` dimensions. A target position may
        attend only to source positions at or before itself.

        Returns:
            ``(outputs, head_weights)``: ``outputs[pos]`` is the concatenation of all head
            outputs at ``pos`` projected through ``wo``; ``head_weights[head][target][source]``
            is the attention weight (future sources get weight 0).
        """
        seq_len = len(x_vectors)
        q_all = [matvec(x, self.wq) for x in x_vectors]
        k_all = [matvec(x, self.wk) for x in x_vectors]
        v_all = [matvec(x, self.wv) for x in x_vectors]
        scale = math.sqrt(self.head_dim)

        head_weights: list[list[list[float]]] = []
        head_outputs: list[list[list[float]]] = []

        for head in range(self.num_heads):
            start = head * self.head_dim
            stop = start + self.head_dim
            weights_for_head: list[list[float]] = []
            outputs_for_head: list[list[float]] = []

            for target_pos in range(seq_len):
                query = q_all[target_pos]
                scores = []
                for source_pos in range(seq_len):
                    if source_pos > target_pos:
                        # Huge negative score: exp() underflows to 0 in softmax, hiding the future.
                        scores.append(-1e9)
                        continue
                    key = k_all[source_pos]
                    dot = 0.0
                    for dim in range(start, stop):
                        dot += query[dim] * key[dim]
                    scores.append(dot / scale)

                attn = softmax(scores)
                weights_for_head.append(attn)

                out = zeros(self.head_dim)
                # Masked (future) sources carry zero weight, so only visible ones are summed.
                for source_pos in range(target_pos + 1):
                    weight = attn[source_pos]
                    value = v_all[source_pos]
                    for local_dim, dim in enumerate(range(start, stop)):
                        out[local_dim] += weight * value[dim]
                outputs_for_head.append(out)

            head_weights.append(weights_for_head)
            head_outputs.append(outputs_for_head)

        combined_outputs: list[list[float]] = []
        for pos in range(seq_len):
            merged: list[float] = []
            for head in range(self.num_heads):
                merged.extend(head_outputs[head][pos])
            combined_outputs.append(matvec(merged, self.wo))

        return combined_outputs, head_weights

    def feed_forward(self, vector: list[float]) -> list[float]:
        """Position-wise MLP: ``GELU(vector @ w1 + b1) @ w2 + b2`` with hidden width ``4 * embed_dim``."""
        hidden = gelu(add(matvec(vector, self.w1), self.b1))
        return add(matvec(hidden, self.w2), self.b2)

    def forward(self, context_ids: list[int]) -> tuple[list[float], dict[str, object]]:
        """Compute logits for the token that follows ``context_ids``.

        Pre-norm block: ``x = x + attn(LN1(x))`` then ``x = x + FFN(LN2(x))``; the
        logits come from the last position's final state through ``lm_head`` + ``lm_bias``.

        Args:
            context_ids: between 1 and ``context_size`` token ids, each in ``range(vocab_size)``.

        Returns:
            ``(logits, cache)``; ``cache`` holds ``"head_weights"`` (from
            ``causal_attention``) and ``"final_state"`` (the last position's vector).

        Raises:
            ValueError: if the context is empty, longer than ``context_size``, or
                contains a token id outside ``range(vocab_size)``.
        """
        if not context_ids:
            raise ValueError("context_ids must contain at least one token")
        if len(context_ids) > self.context_size:
            raise ValueError(
                f"context has {len(context_ids)} tokens but context_size is {self.context_size}"
            )
        # Negative ids would otherwise silently index embeddings from the end.
        bad_ids = [token_id for token_id in context_ids if not 0 <= token_id < self.vocab_size]
        if bad_ids:
            raise ValueError(f"token ids out of range [0, {self.vocab_size}): {bad_ids}")

        x = [add(self.token_embed[token_id], self.pos_embed[pos]) for pos, token_id in enumerate(context_ids)]

        ln1 = [layer_norm(vec, self.ln1_gamma, self.ln1_beta) for vec in x]
        attn_out, head_weights = self.causal_attention(ln1)
        x = [add(residual, delta) for residual, delta in zip(x, attn_out)]

        ln2 = [layer_norm(vec, self.ln2_gamma, self.ln2_beta) for vec in x]
        ff_out = [self.feed_forward(vec) for vec in ln2]
        x = [add(residual, delta) for residual, delta in zip(x, ff_out)]

        final_state = x[-1]
        logits = add(matvec(final_state, self.lm_head), self.lm_bias)

        cache: dict[str, object] = {
            "head_weights": head_weights,
            "final_state": final_state,
        }
        return logits, cache


def top_k_indices(values: list[float], k: int) -> list[int]:
    """Return the indices of the ``k`` largest ``values``, largest first.

    ``sorted`` is stable even with ``reverse=True``, so ties keep the lower index first.
    """
    return sorted(range(len(values)), key=lambda index: values[index], reverse=True)[:k]


def load_corpus(path: Path) -> str:
    """Read the whole UTF-8 text file at ``path``."""
    return path.read_text(encoding="utf-8")


def main() -> None:
    """CLI entry point: build a random model, run one forward pass on ``--prompt``, print results.

    ``corpus.txt`` next to this file is read only to build the character vocabulary.
    """
    parser = argparse.ArgumentParser(description="Inspect a tiny modern transformer-style LM block.")
    parser.add_argument("--context-size", type=int, default=12, help="Context window length.")
    parser.add_argument("--embed-dim", type=int, default=24, help="Embedding size.")
    parser.add_argument("--num-heads", type=int, default=3, help="Number of attention heads.")
    parser.add_argument("--seed", type=int, default=7, help="Random seed.")
    parser.add_argument("--prompt", default="language mod", help="Prompt to inspect.")
    parser.add_argument("--top-k", type=int, default=8, help="How many next-token candidates to print.")
    args = parser.parse_args()

    # Reject bad flags with a usage message instead of a traceback from deep inside the model.
    if args.context_size < 1:
        parser.error("--context-size must be at least 1")
    if args.embed_dim < 1 or args.num_heads < 1:
        parser.error("--embed-dim and --num-heads must be at least 1")
    if args.embed_dim % args.num_heads != 0:
        parser.error("--embed-dim must be divisible by --num-heads")
    if args.top_k < 0:
        parser.error("--top-k must not be negative")
    if not args.prompt:
        # The padding below repeats the first prompt token, so an empty prompt cannot be padded.
        parser.error("--prompt must not be empty")

    project_dir = Path(__file__).resolve().parent
    corpus_text = load_corpus(project_dir / "corpus.txt")
    tokenizer = CharTokenizer(corpus_text)

    if any(char not in tokenizer.stoi for char in args.prompt):
        known = "".join(sorted(tokenizer.stoi.keys()))
        raise ValueError(f"Prompt contains characters outside the training set. Known chars: {known!r}")

    token_ids = tokenizer.encode(args.prompt)
    if len(token_ids) < args.context_size:
        # There is no pad token, so left-pad by repeating the first prompt character.
        token_ids = [token_ids[0]] * (args.context_size - len(token_ids)) + token_ids
    # Keep only the most recent context_size tokens.
    context_ids = token_ids[-args.context_size :]

    model = TinyModernTransformerLM(
        vocab_size=tokenizer.vocab_size,
        context_size=args.context_size,
        embed_dim=args.embed_dim,
        num_heads=args.num_heads,
        seed=args.seed,
    )

    logits, cache = model.forward(context_ids)
    probs = softmax(logits)
    top = top_k_indices(probs, args.top_k)

    print("tiny modern transformer-style LM block")
    print(f"context: {tokenizer.decode(context_ids)!r}")
    print(f"vocab size: {tokenizer.vocab_size}")
    print(f"embed dim: {args.embed_dim}")
    print(f"heads: {args.num_heads}")

    print("\n--- top next-token candidates ---\n")
    for index in top:
        # repr() already renders a newline as '\n'; substituting "\\n" first double-escaped it.
        print(f"{tokenizer.itos[index]!r}: {probs[index]:.4f}")

    head_weights = cache["head_weights"]
    print("\n--- attention from final position ---\n")
    context_text = tokenizer.decode(context_ids)
    for head_index, weights_for_head in enumerate(head_weights):
        final_weights = weights_for_head[-1]
        print(f"head {head_index}:")
        for pos, (char, weight) in enumerate(zip(context_text, final_weights)):
            print(f" pos {pos:2d} | char {char!r} | weight {weight:.4f}")


if __name__ == "__main__":
    main()