"""

Sample from a trained model (London GPT)

"""
import os
import pickle
import torch
from contextlib import nullcontext
from model import GPTConfig, GPT
from tokenizers import ByteLevelBPETokenizer

# ─── tokenizer setup ────────────────────────────────────────────────
tok_folder = "tokenizer_london"
vocab_path = os.path.join(tok_folder, "vocab.json")
merges_path = os.path.join(tok_folder, "merges.txt")
if not (os.path.isfile(vocab_path) and os.path.isfile(merges_path)):
    raise FileNotFoundError(f"Cannot find tokenizer files in {tok_folder}: {vocab_path}, {merges_path}")

tokenizer = ByteLevelBPETokenizer(vocab_path, merges_path)
encode = lambda s: tokenizer.encode(s).ids
decode = lambda ids: tokenizer.decode(ids)
# ────────────────────────────────────────────────────────────────────
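
# Quick round-trip sanity check (illustrative; exact ids depend on the vocab):
#   ids = encode("London")           # e.g. a single id if "London" was merged
#   assert decode(ids) == "London"   # byte-level BPE round-trips typical text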

# ─── experiment settings (you can override via CLI) ────────────────
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--out_dir",        default="out_london")
parser.add_argument("--device",         default="cpu")
parser.add_argument("--start",          default="\n")
parser.add_argument("--num_samples",    type=int, default=10)
parser.add_argument("--max_new_tokens", type=int, default=500)
parser.add_argument("--temperature",    type=float, default=0.8)
parser.add_argument("--top_k",          type=int, default=200)
parser.add_argument("--seed",           type=int, default=1337)
parser.add_argument("--compile",        action="store_true")
args = parser.parse_args()

out_dir        = args.out_dir
start          = args.start
num_samples    = args.num_samples
max_new_tokens = args.max_new_tokens
temperature    = args.temperature
top_k          = args.top_k
seed           = args.seed
compile_flag   = args.compile
device_str     = args.device
# ────────────────────────────────────────────────────────────────────
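
# Note: lower --temperature makes samples more deterministic, while a smaller
# --top_k restricts sampling to the k most likely next tokens at each step.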

# reproducibility & device
torch.manual_seed(seed)
device = torch.device(device_str)
# autocast with float32 is a no-op; use bf16 (fp16 fallback) mixed precision on GPU
amp_dtype = torch.bfloat16 if device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float16
ctx = nullcontext() if device.type == "cpu" else torch.amp.autocast(device_type=device.type, dtype=amp_dtype)

# ─── load model checkpoint ──────────────────────────────────────────
ckpt_path = os.path.join(out_dir, "ckpt.pt")
if not os.path.isfile(ckpt_path):
    raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
ckpt = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**ckpt["model_args"])
model   = GPT(gptconf)
sd      = ckpt["model"]
# strip the "_orig_mod." prefix that torch.compile adds to saved state dicts
for k in list(sd.keys()):
    if k.startswith("_orig_mod."):
        sd[k[len("_orig_mod."):]] = sd.pop(k)
model.load_state_dict(sd)
model.eval().to(device)
if compile_flag:
    model = torch.compile(model)
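
# Optional: report training metadata if the training script stored it
# (nanoGPT-style checkpoints usually carry these keys; guarded in case not)
if "iter_num" in ckpt:
    print(f"loaded checkpoint from iter {ckpt['iter_num']} "
          f"(best val loss: {ckpt.get('best_val_loss', 'n/a')})")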
# ────────────────────────────────────────────────────────────────────

# prepare the prompt tensor; "--start FILE:<path>" reads the prompt from a file
if start.startswith("FILE:"):
    with open(start[5:], "r", encoding="utf-8") as f:
        start = f.read()
ids = encode(start)
x   = torch.tensor([ids], dtype=torch.long, device=device)

# ─── generation ─────────────────────────────────────────────────────
with torch.no_grad(), ctx:
    for _ in range(num_samples):
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
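        # y holds the prompt tokens followed by the newly generated continuation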
        print(decode(y[0].tolist()))
        print("---------------")
# ────────────────────────────────────────────────────────────────────