tiny-llm-demo

tiny-llm-demo - small plain-Python LLM learning demos.
git clone git://git.beep.wimdupont.com/tiny-llm-demo.git
Log | Files | Refs | README | LICENSE

tiny_modern_lm.py (10304B)


      1 #!/usr/bin/env python3
      2 """
      3 Tiny forward-only demo of a more modern transformer-style language model block.
      4 
      5 This file is meant to show architecture, not serious training. It includes:
      6 
      7 - token embeddings
      8 - positional embeddings
      9 - layer normalization
     10 - causal multi-head self-attention
     11 - residual connections
     12 - feed-forward network
     13 - output projection to logits
     14 
     15 Why forward-only?
     16 Because implementing a faithful transformer block *and* all of its backpropagation
     17 in plain standard-library Python would swamp the core ideas in gradient code.
     18 The earlier demos show training mechanics. This file shows the modern block shape.
     19 """
     20 
     21 from __future__ import annotations
     22 
     23 import argparse
     24 import math
     25 import random
     26 from pathlib import Path
     27 
     28 
     29 def zeros(length: int) -> list[float]:
     30     return [0.0 for _ in range(length)]
     31 
     32 
     33 def random_vector(length: int, scale: float, rng: random.Random) -> list[float]:
     34     return [rng.uniform(-scale, scale) for _ in range(length)]
     35 
     36 
     37 def random_matrix(rows: int, cols: int, scale: float, rng: random.Random) -> list[list[float]]:
     38     return [random_vector(cols, scale, rng) for _ in range(rows)]
     39 
     40 
     41 def add(a: list[float], b: list[float]) -> list[float]:
     42     return [x + y for x, y in zip(a, b)]
     43 
     44 
     45 def softmax(logits: list[float]) -> list[float]:
     46     peak = max(logits)
     47     exps = [math.exp(value - peak) for value in logits]
     48     total = sum(exps)
     49     return [value / total for value in exps]
     50 
     51 
     52 def gelu(values: list[float]) -> list[float]:
     53     out = []
     54     for x in values:
     55         out.append(0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3))))
     56     return out
     57 
     58 
     59 def matvec(vector: list[float], matrix: list[list[float]]) -> list[float]:
     60     cols = len(matrix[0])
     61     out = zeros(cols)
     62     for col in range(cols):
     63         total = 0.0
     64         for row in range(len(vector)):
     65             total += vector[row] * matrix[row][col]
     66         out[col] = total
     67     return out
     68 
     69 
     70 def layer_norm(vector: list[float], gamma: list[float], beta: list[float], eps: float = 1e-5) -> list[float]:
     71     mean = sum(vector) / len(vector)
     72     variance = sum((value - mean) ** 2 for value in vector) / len(vector)
     73     denom = math.sqrt(variance + eps)
     74     return [((value - mean) / denom) * gamma[i] + beta[i] for i, value in enumerate(vector)]
     75 
     76 
     77 class CharTokenizer:
     78     def __init__(self, text: str) -> None:
     79         chars = sorted(set(text))
     80         self.stoi = {char: index for index, char in enumerate(chars)}
     81         self.itos = {index: char for char, index in self.stoi.items()}
     82 
     83     @property
     84     def vocab_size(self) -> int:
     85         return len(self.stoi)
     86 
     87     def encode(self, text: str) -> list[int]:
     88         return [self.stoi[char] for char in text]
     89 
     90     def decode(self, token_ids: list[int]) -> str:
     91         return "".join(self.itos[token_id] for token_id in token_ids)
     92 
     93 
     94 class TinyModernTransformerLM:
     95     def __init__(self, vocab_size: int, context_size: int, embed_dim: int, num_heads: int, seed: int) -> None:
     96         if embed_dim % num_heads != 0:
     97             raise ValueError("embed_dim must be divisible by num_heads")
     98 
     99         rng = random.Random(seed)
    100         self.vocab_size = vocab_size
    101         self.context_size = context_size
    102         self.embed_dim = embed_dim
    103         self.num_heads = num_heads
    104         self.head_dim = embed_dim // num_heads
    105 
    106         self.token_embed = random_matrix(vocab_size, embed_dim, 0.08, rng)
    107         self.pos_embed = random_matrix(context_size, embed_dim, 0.08, rng)
    108 
    109         self.ln1_gamma = [1.0] * embed_dim
    110         self.ln1_beta = [0.0] * embed_dim
    111         self.ln2_gamma = [1.0] * embed_dim
    112         self.ln2_beta = [0.0] * embed_dim
    113 
    114         self.wq = random_matrix(embed_dim, embed_dim, 0.08, rng)
    115         self.wk = random_matrix(embed_dim, embed_dim, 0.08, rng)
    116         self.wv = random_matrix(embed_dim, embed_dim, 0.08, rng)
    117         self.wo = random_matrix(embed_dim, embed_dim, 0.08, rng)
    118 
    119         ff_dim = embed_dim * 4
    120         self.w1 = random_matrix(embed_dim, ff_dim, 0.08, rng)
    121         self.b1 = zeros(ff_dim)
    122         self.w2 = random_matrix(ff_dim, embed_dim, 0.08, rng)
    123         self.b2 = zeros(embed_dim)
    124 
    125         self.lm_head = random_matrix(embed_dim, vocab_size, 0.08, rng)
    126         self.lm_bias = zeros(vocab_size)
    127 
    128     def causal_attention(self, x_vectors: list[list[float]]) -> tuple[list[list[float]], list[list[list[float]]]]:
    129         q_all = [matvec(x, self.wq) for x in x_vectors]
    130         k_all = [matvec(x, self.wk) for x in x_vectors]
    131         v_all = [matvec(x, self.wv) for x in x_vectors]
    132         scale = math.sqrt(self.head_dim)
    133 
    134         head_weights: list[list[list[float]]] = []
    135         head_outputs: list[list[list[float]]] = []
    136 
    137         for head in range(self.num_heads):
    138             start = head * self.head_dim
    139             stop = start + self.head_dim
    140             weights_for_head: list[list[float]] = []
    141             outputs_for_head: list[list[float]] = []
    142 
    143             for target_pos in range(self.context_size):
    144                 scores = []
    145                 for source_pos in range(self.context_size):
    146                     if source_pos > target_pos:
    147                         scores.append(-1e9)
    148                         continue
    149                     dot = 0.0
    150                     for dim in range(start, stop):
    151                         dot += q_all[target_pos][dim] * k_all[source_pos][dim]
    152                     scores.append(dot / scale)
    153 
    154                 attn = softmax(scores)
    155                 weights_for_head.append(attn)
    156 
    157                 out = zeros(self.head_dim)
    158                 for source_pos in range(self.context_size):
    159                     for local_dim, dim in enumerate(range(start, stop)):
    160                         out[local_dim] += attn[source_pos] * v_all[source_pos][dim]
    161                 outputs_for_head.append(out)
    162 
    163             head_weights.append(weights_for_head)
    164             head_outputs.append(outputs_for_head)
    165 
    166         combined_outputs: list[list[float]] = []
    167         for pos in range(self.context_size):
    168             merged = []
    169             for head in range(self.num_heads):
    170                 merged.extend(head_outputs[head][pos])
    171             combined_outputs.append(matvec(merged, self.wo))
    172 
    173         return combined_outputs, head_weights
    174 
    175     def feed_forward(self, vector: list[float]) -> list[float]:
    176         hidden = matvec(vector, self.w1)
    177         hidden = [hidden[i] + self.b1[i] for i in range(len(hidden))]
    178         hidden = gelu(hidden)
    179         out = matvec(hidden, self.w2)
    180         return [out[i] + self.b2[i] for i in range(len(out))]
    181 
    182     def forward(self, context_ids: list[int]) -> tuple[list[float], dict[str, object]]:
    183         x = []
    184         for pos, token_id in enumerate(context_ids):
    185             x.append(add(self.token_embed[token_id], self.pos_embed[pos]))
    186 
    187         ln1 = [layer_norm(vec, self.ln1_gamma, self.ln1_beta) for vec in x]
    188         attn_out, head_weights = self.causal_attention(ln1)
    189         x = [add(x[pos], attn_out[pos]) for pos in range(self.context_size)]
    190 
    191         ln2 = [layer_norm(vec, self.ln2_gamma, self.ln2_beta) for vec in x]
    192         ff_out = [self.feed_forward(vec) for vec in ln2]
    193         x = [add(x[pos], ff_out[pos]) for pos in range(self.context_size)]
    194 
    195         final_state = x[-1]
    196         logits = matvec(final_state, self.lm_head)
    197         logits = [logits[i] + self.lm_bias[i] for i in range(self.vocab_size)]
    198 
    199         cache = {
    200             "head_weights": head_weights,
    201             "final_state": final_state,
    202         }
    203         return logits, cache
    204 
    205 
    206 def top_k_indices(values: list[float], k: int) -> list[int]:
    207     return sorted(range(len(values)), key=lambda index: values[index], reverse=True)[:k]
    208 
    209 
    210 def load_corpus(path: Path) -> str:
    211     return path.read_text(encoding="utf-8")
    212 
    213 
    214 def main() -> None:
    215     parser = argparse.ArgumentParser(description="Inspect a tiny modern transformer-style LM block.")
    216     parser.add_argument("--context-size", type=int, default=12, help="Context window length.")
    217     parser.add_argument("--embed-dim", type=int, default=24, help="Embedding size.")
    218     parser.add_argument("--num-heads", type=int, default=3, help="Number of attention heads.")
    219     parser.add_argument("--seed", type=int, default=7, help="Random seed.")
    220     parser.add_argument("--prompt", default="language mod", help="Prompt to inspect.")
    221     parser.add_argument("--top-k", type=int, default=8, help="How many next-token candidates to print.")
    222     args = parser.parse_args()
    223 
    224     project_dir = Path(__file__).resolve().parent
    225     corpus_text = load_corpus(project_dir / "corpus.txt")
    226     tokenizer = CharTokenizer(corpus_text)
    227 
    228     if any(char not in tokenizer.stoi for char in args.prompt):
    229         known = "".join(sorted(tokenizer.stoi.keys()))
    230         raise ValueError(f"Prompt contains characters outside the training set. Known chars: {known!r}")
    231 
    232     token_ids = tokenizer.encode(args.prompt)
    233     if len(token_ids) < args.context_size:
    234         token_ids = [token_ids[0]] * (args.context_size - len(token_ids)) + token_ids
    235     context_ids = token_ids[-args.context_size :]
    236 
    237     model = TinyModernTransformerLM(
    238         vocab_size=tokenizer.vocab_size,
    239         context_size=args.context_size,
    240         embed_dim=args.embed_dim,
    241         num_heads=args.num_heads,
    242         seed=args.seed,
    243     )
    244 
    245     logits, cache = model.forward(context_ids)
    246     probs = softmax(logits)
    247     top = top_k_indices(probs, args.top_k)
    248 
    249     print("tiny modern transformer-style LM block")
    250     print(f"context: {tokenizer.decode(context_ids)!r}")
    251     print(f"vocab size: {tokenizer.vocab_size}")
    252     print(f"embed dim: {args.embed_dim}")
    253     print(f"heads: {args.num_heads}")
    254 
    255     print("\n--- top next-token candidates ---\n")
    256     for index in top:
    257         char = tokenizer.itos[index]
    258         label = "\\n" if char == "\n" else char
    259         print(f"{label!r}: {probs[index]:.4f}")
    260 
    261     head_weights = cache["head_weights"]
    262     print("\n--- attention from final position ---\n")
    263     context_text = tokenizer.decode(context_ids)
    264     for head_index, weights_for_head in enumerate(head_weights):
    265         final_weights = weights_for_head[-1]
    266         print(f"head {head_index}:")
    267         for pos, weight in enumerate(final_weights):
    268             char = context_text[pos]
    269             label = "\\n" if char == "\n" else char
    270             print(f"  pos {pos:2d} | char {label!r} | weight {weight:.4f}")
    271 
    272 
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()