tiny-llm-demo

tiny-llm-demo - small plain-Python LLM learning demos.
git clone git://git.beep.wimdupont.com/tiny-llm-demo.git
Log | Files | Refs | README | LICENSE

tiny_lm.py (10765B)


      1 #!/usr/bin/env python3
      2 """
      3 Tiny character-level neural language model in plain Python.
      4 
      5 This is intentionally small and readable. It is not a transformer and not a real
      6 LLM. It exists to show the basic mechanics of next-token training:
      7 
      8 - tokenize text
      9 - embed tokens
     10 - run a forward pass
     11 - compute softmax loss
     12 - backpropagate gradients
     13 - update weights
     14 - sample new text
     15 """
     16 
     17 from __future__ import annotations
     18 
     19 import argparse
     20 import math
     21 import random
     22 from dataclasses import dataclass
     23 from pathlib import Path
     24 
     25 
     26 def zeros(length: int) -> list[float]:
     27     return [0.0 for _ in range(length)]
     28 
     29 
     30 def random_vector(length: int, scale: float, rng: random.Random) -> list[float]:
     31     return [rng.uniform(-scale, scale) for _ in range(length)]
     32 
     33 
     34 def random_matrix(rows: int, cols: int, scale: float, rng: random.Random) -> list[list[float]]:
     35     return [random_vector(cols, scale, rng) for _ in range(rows)]
     36 
     37 
     38 def softmax(logits: list[float]) -> list[float]:
     39     peak = max(logits)
     40     exps = [math.exp(value - peak) for value in logits]
     41     total = sum(exps)
     42     return [value / total for value in exps]
     43 
     44 
     45 @dataclass
     46 class Batch:
     47     context_ids: list[int]
     48     target_id: int
     49 
     50 
     51 class CharTokenizer:
     52     def __init__(self, text: str) -> None:
     53         chars = sorted(set(text))
     54         self.stoi = {char: index for index, char in enumerate(chars)}
     55         self.itos = {index: char for char, index in self.stoi.items()}
     56 
     57     @property
     58     def vocab_size(self) -> int:
     59         return len(self.stoi)
     60 
     61     def encode(self, text: str) -> list[int]:
     62         return [self.stoi[char] for char in text]
     63 
     64     def decode(self, token_ids: list[int]) -> str:
     65         return "".join(self.itos[token_id] for token_id in token_ids)
     66 
     67 
     68 class TinyLanguageModel:
     69     def __init__(
     70         self,
     71         vocab_size: int,
     72         context_size: int,
     73         embed_dim: int,
     74         hidden_dim: int,
     75         seed: int,
     76     ) -> None:
     77         rng = random.Random(seed)
     78         self.vocab_size = vocab_size
     79         self.context_size = context_size
     80         self.embed_dim = embed_dim
     81         self.hidden_dim = hidden_dim
     82         self.input_dim = context_size * embed_dim
     83 
     84         self.token_embed = random_matrix(vocab_size, embed_dim, 0.08, rng)
     85         self.w1 = random_matrix(self.input_dim, hidden_dim, 0.08, rng)
     86         self.b1 = zeros(hidden_dim)
     87         self.w2 = random_matrix(hidden_dim, vocab_size, 0.08, rng)
     88         self.b2 = zeros(vocab_size)
     89 
     90     def forward(self, context_ids: list[int]) -> tuple[list[float], dict[str, list[float] | list[int]]]:
     91         combined = zeros(self.input_dim)
     92         for position, token_id in enumerate(context_ids):
     93             token_vector = self.token_embed[token_id]
     94             for dim in range(self.embed_dim):
     95                 combined[position * self.embed_dim + dim] = token_vector[dim]
     96 
     97         hidden_pre = zeros(self.hidden_dim)
     98         for hidden_idx in range(self.hidden_dim):
     99             total = self.b1[hidden_idx]
    100             for dim in range(self.input_dim):
    101                 total += combined[dim] * self.w1[dim][hidden_idx]
    102             hidden_pre[hidden_idx] = total
    103 
    104         hidden = [math.tanh(value) for value in hidden_pre]
    105 
    106         logits = zeros(self.vocab_size)
    107         for vocab_idx in range(self.vocab_size):
    108             total = self.b2[vocab_idx]
    109             for hidden_idx in range(self.hidden_dim):
    110                 total += hidden[hidden_idx] * self.w2[hidden_idx][vocab_idx]
    111             logits[vocab_idx] = total
    112 
    113         cache = {
    114             "context_ids": context_ids,
    115             "combined": combined,
    116             "hidden": hidden,
    117         }
    118         return logits, cache
    119 
    120     def train_step(self, batch: Batch, learning_rate: float) -> float:
    121         logits, cache = self.forward(batch.context_ids)
    122         probs = softmax(logits)
    123         loss = -math.log(probs[batch.target_id] + 1e-12)
    124 
    125         dlogits = probs[:]
    126         dlogits[batch.target_id] -= 1.0
    127 
    128         hidden = cache["hidden"]
    129         combined = cache["combined"]
    130 
    131         dw2 = [zeros(self.vocab_size) for _ in range(self.hidden_dim)]
    132         db2 = dlogits[:]
    133         for hidden_idx in range(self.hidden_dim):
    134             for vocab_idx in range(self.vocab_size):
    135                 dw2[hidden_idx][vocab_idx] = hidden[hidden_idx] * dlogits[vocab_idx]
    136 
    137         dhidden = zeros(self.hidden_dim)
    138         for hidden_idx in range(self.hidden_dim):
    139             total = 0.0
    140             for vocab_idx in range(self.vocab_size):
    141                 total += self.w2[hidden_idx][vocab_idx] * dlogits[vocab_idx]
    142             dhidden[hidden_idx] = total
    143 
    144         dhidden_pre = zeros(self.hidden_dim)
    145         for hidden_idx in range(self.hidden_dim):
    146             dhidden_pre[hidden_idx] = dhidden[hidden_idx] * (1.0 - hidden[hidden_idx] ** 2)
    147 
    148         dw1 = [zeros(self.hidden_dim) for _ in range(self.input_dim)]
    149         db1 = dhidden_pre[:]
    150         for dim in range(self.input_dim):
    151             for hidden_idx in range(self.hidden_dim):
    152                 dw1[dim][hidden_idx] = combined[dim] * dhidden_pre[hidden_idx]
    153 
    154         dcombined = zeros(self.input_dim)
    155         for dim in range(self.input_dim):
    156             total = 0.0
    157             for hidden_idx in range(self.hidden_dim):
    158                 total += self.w1[dim][hidden_idx] * dhidden_pre[hidden_idx]
    159             dcombined[dim] = total
    160 
    161         for position, token_id in enumerate(batch.context_ids):
    162             for dim in range(self.embed_dim):
    163                 gradient = dcombined[position * self.embed_dim + dim]
    164                 self.token_embed[token_id][dim] -= learning_rate * gradient
    165 
    166         for dim in range(self.input_dim):
    167             for hidden_idx in range(self.hidden_dim):
    168                 self.w1[dim][hidden_idx] -= learning_rate * dw1[dim][hidden_idx]
    169         for hidden_idx in range(self.hidden_dim):
    170             self.b1[hidden_idx] -= learning_rate * db1[hidden_idx]
    171 
    172         for hidden_idx in range(self.hidden_dim):
    173             for vocab_idx in range(self.vocab_size):
    174                 self.w2[hidden_idx][vocab_idx] -= learning_rate * dw2[hidden_idx][vocab_idx]
    175         for vocab_idx in range(self.vocab_size):
    176             self.b2[vocab_idx] -= learning_rate * db2[vocab_idx]
    177 
    178         return loss
    179 
    180     def sample_next_token(self, context_ids: list[int], rng: random.Random, temperature: float) -> int:
    181         logits, _ = self.forward(context_ids)
    182         scaled = [value / max(temperature, 1e-6) for value in logits]
    183         probs = softmax(scaled)
    184 
    185         threshold = rng.random()
    186         running = 0.0
    187         for index, prob in enumerate(probs):
    188             running += prob
    189             if threshold <= running:
    190                 return index
    191         return len(probs) - 1
    192 
    193 
    194 def build_batches(token_ids: list[int], context_size: int) -> list[Batch]:
    195     batches: list[Batch] = []
    196     for index in range(len(token_ids) - context_size):
    197         batches.append(
    198             Batch(
    199                 context_ids=token_ids[index : index + context_size],
    200                 target_id=token_ids[index + context_size],
    201             )
    202         )
    203     return batches
    204 
    205 
    206 def generate_text(
    207     model: TinyLanguageModel,
    208     tokenizer: CharTokenizer,
    209     prompt: str,
    210     sample_length: int,
    211     rng: random.Random,
    212     temperature: float,
    213 ) -> str:
    214     if not prompt:
    215         prompt = "language "
    216 
    217     if any(char not in tokenizer.stoi for char in prompt):
    218         known = "".join(sorted(tokenizer.stoi.keys()))
    219         raise ValueError(f"Prompt contains characters outside the training set. Known chars: {known!r}")
    220 
    221     token_ids = tokenizer.encode(prompt)
    222     if len(token_ids) < model.context_size:
    223         token_ids = [token_ids[0]] * (model.context_size - len(token_ids)) + token_ids
    224 
    225     generated = token_ids[:]
    226     for _ in range(sample_length):
    227         context_ids = generated[-model.context_size :]
    228         next_token = model.sample_next_token(context_ids, rng, temperature)
    229         generated.append(next_token)
    230 
    231     return tokenizer.decode(generated)
    232 
    233 
    234 def load_corpus(path: Path) -> str:
    235     return path.read_text(encoding="utf-8")
    236 
    237 
    238 def main() -> None:
    239     parser = argparse.ArgumentParser(description="Train a tiny next-character language model.")
    240     parser.add_argument("--steps", type=int, default=3000, help="Number of SGD steps.")
    241     parser.add_argument("--report-every", type=int, default=500, help="Progress report interval.")
    242     parser.add_argument("--learning-rate", type=float, default=0.08, help="SGD learning rate.")
    243     parser.add_argument("--context-size", type=int, default=12, help="Context window length.")
    244     parser.add_argument("--embed-dim", type=int, default=24, help="Embedding size.")
    245     parser.add_argument("--hidden-dim", type=int, default=48, help="Hidden layer size.")
    246     parser.add_argument("--sample-length", type=int, default=220, help="Characters to generate.")
    247     parser.add_argument("--temperature", type=float, default=0.9, help="Sampling temperature.")
    248     parser.add_argument("--seed", type=int, default=7, help="Random seed.")
    249     parser.add_argument("--prompt", default="language models ", help="Initial prompt for generation.")
    250     parser.add_argument(
    251         "--repeat-corpus",
    252         type=int,
    253         default=8,
    254         help="Repeat the corpus this many times to make the tiny demo learn faster.",
    255     )
    256     args = parser.parse_args()
    257 
    258     project_dir = Path(__file__).resolve().parent
    259     corpus_text = load_corpus(project_dir / "corpus.txt")
    260     corpus_text = corpus_text * max(args.repeat_corpus, 1)
    261 
    262     tokenizer = CharTokenizer(corpus_text)
    263     token_ids = tokenizer.encode(corpus_text)
    264     batches = build_batches(token_ids, args.context_size)
    265     if not batches:
    266         raise ValueError("Corpus is too small for the requested context size.")
    267 
    268     model = TinyLanguageModel(
    269         vocab_size=tokenizer.vocab_size,
    270         context_size=args.context_size,
    271         embed_dim=args.embed_dim,
    272         hidden_dim=args.hidden_dim,
    273         seed=args.seed,
    274     )
    275 
    276     rng = random.Random(args.seed)
    277     running_loss = 0.0
    278 
    279     print(f"vocab size: {tokenizer.vocab_size}")
    280     print(f"training samples: {len(batches)}")
    281     print("training...")
    282 
    283     for step in range(1, args.steps + 1):
    284         batch = batches[rng.randrange(len(batches))]
    285         loss = model.train_step(batch, args.learning_rate)
    286         running_loss += loss
    287 
    288         if step % args.report_every == 0 or step == 1:
    289             avg_loss = running_loss / min(step, args.report_every)
    290             print(f"step {step:5d} | avg loss {avg_loss:.4f}")
    291             running_loss = 0.0
    292 
    293     print("\n--- sample ---\n")
    294     sample = generate_text(
    295         model=model,
    296         tokenizer=tokenizer,
    297         prompt=args.prompt,
    298         sample_length=args.sample_length,
    299         rng=rng,
    300         temperature=args.temperature,
    301     )
    302     print(sample)
    303 
    304 
if __name__ == "__main__":
    # Train and sample only when run as a script, not when imported as a module.
    main()