"""
Sample from a trained model (London GPT)
"""
import os
import pickle
import torch
from contextlib import nullcontext
from model import GPTConfig, GPT
from tokenizers import ByteLevelBPETokenizer
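# requires torch and the Hugging Face `tokenizers` package, plus a local model.py
# providing GPTConfig and GPT (the nanoGPT model definition)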
# ─── tokenizer setup ────────────────────────────────────────────────
tok_folder = "tokenizer_london"
vocab_path = os.path.join(tok_folder, "vocab.json")
merges_path = os.path.join(tok_folder, "merges.txt")
if not (os.path.isfile(vocab_path) and os.path.isfile(merges_path)):
    raise FileNotFoundError(f"Cannot find tokenizer files in {tok_folder}: {vocab_path}, {merges_path}")
tokenizer = ByteLevelBPETokenizer(vocab_path, merges_path)
encode = lambda s: tokenizer.encode(s).ids
decode = lambda ids: tokenizer.decode(ids)
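# byte-level BPE covers every byte, so decode(encode(s)) should round-trip text exactly, e.g.
#   assert decode(encode("Hello, London!")) == "Hello, London!"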
# ────────────────────────────────────────────────────────────────────
# ─── experiment settings (you can override via CLI) ────────────────
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--out_dir", default="out_london")
parser.add_argument("--device", default="cpu")
parser.add_argument("--start", default="\n")
parser.add_argument("--num_samples", type=int, default=10)
parser.add_argument("--max_new_tokens", type=int, default=500)
parser.add_argument("--temperature", type=float, default=0.8)
parser.add_argument("--top_k", type=int, default=200)
parser.add_argument("--seed", type=int, default=1337)
parser.add_argument("--compile", action="store_true")
args = parser.parse_args()
out_dir = args.out_dir
start = args.start
num_samples = args.num_samples
max_new_tokens = args.max_new_tokens
temperature = args.temperature
top_k = args.top_k
seed = args.seed
compile_flag = args.compile
device_str = args.device
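# example invocations (script name sample.py assumed):
#   python sample.py --num_samples=3 --max_new_tokens=200
#   python sample.py --device=cuda --compile --start="FILE:prompt.txt"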
# ────────────────────────────────────────────────────────────────────
# reproducibility & device
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)  # harmless no-op without CUDA
device = torch.device(device_str)
# autocast needs a reduced-precision dtype (float32 disables it): bf16 where supported, else fp16
ptdtype = torch.bfloat16 if device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float16
ctx = nullcontext() if device.type == "cpu" else torch.amp.autocast(device_type=device.type, dtype=ptdtype)
# ─── load model checkpoint ──────────────────────────────────────────
ckpt_path = os.path.join(out_dir, "ckpt.pt")
if not os.path.isfile(ckpt_path):
    raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
ckpt = torch.load(ckpt_path, map_location=device)
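# note: PyTorch >= 2.6 defaults torch.load to weights_only=True, which can reject
# checkpoints that pickle arbitrary Python objects; for a trusted checkpoint, retry
# with weights_only=False if this load fails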
gptconf = GPTConfig(**ckpt["model_args"])
model = GPT(gptconf)
sd = ckpt["model"]
# checkpoints saved from a torch.compile()d model prefix every key with "_orig_mod."; strip it
for k in list(sd.keys()):
    if k.startswith("_orig_mod."):
        sd[k[len("_orig_mod."):]] = sd.pop(k)
model.load_state_dict(sd)
model.eval().to(device)
if compile_flag:
    model = torch.compile(model)
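# torch.compile (PyTorch 2.0+) pays a one-time compilation cost before the first
# forward pass, so it mainly helps long or repeated sampling runs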
# ────────────────────────────────────────────────────────────────────
# prepare prompt tensor
if start.startswith("FILE:"):
with open(start[5:], "r", encoding="utf-8") as f:
start = f.read()
ids = encode(start)
x = torch.tensor([ids], dtype=torch.long, device=device)
# ─── generation ─────────────────────────────────────────────────────
with torch.no_grad(), ctx:
    for _ in range(num_samples):
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        print(decode(y[0].tolist()))
        print("---------------")
# ────────────────────────────────────────────────────────────────────