"""
Sample from a trained model (London GPT).
"""

import argparse
import os
from contextlib import nullcontext

import torch
from tokenizers import ByteLevelBPETokenizer

from model import GPTConfig, GPT

# Byte-level BPE tokenizer trained alongside the model.
tok_folder = "tokenizer_london"
vocab_path = os.path.join(tok_folder, "vocab.json")
merges_path = os.path.join(tok_folder, "merges.txt")
if not (os.path.isfile(vocab_path) and os.path.isfile(merges_path)):
    raise FileNotFoundError(f"Cannot find tokenizer files in {tok_folder}: {vocab_path}, {merges_path}")

tokenizer = ByteLevelBPETokenizer(vocab_path, merges_path)
encode = lambda s: tokenizer.encode(s).ids  # text -> list of token ids
decode = lambda ids: tokenizer.decode(ids)  # list of token ids -> text
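
# Illustrative sanity check: byte-level BPE is lossless on text, so decoding
# an encoding should round-trip, e.g.
#   assert decode(encode("Hello, London")) == "Hello, London"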

# Command-line options for sampling.
parser = argparse.ArgumentParser()
parser.add_argument("--out_dir", default="out_london", help="directory containing ckpt.pt")
parser.add_argument("--device", default="cpu", help="torch device, e.g. cpu or cuda")
parser.add_argument("--start", default="\n", help='prompt string, or "FILE:<path>" to read the prompt from a file')
parser.add_argument("--num_samples", type=int, default=10, help="number of samples to draw")
parser.add_argument("--max_new_tokens", type=int, default=500, help="tokens generated per sample")
parser.add_argument("--temperature", type=float, default=0.8, help="sampling temperature; lower is more deterministic")
parser.add_argument("--top_k", type=int, default=200, help="keep only the top-k logits when sampling")
parser.add_argument("--seed", type=int, default=1337, help="random seed")
parser.add_argument("--compile", action="store_true", help="compile the model with torch.compile (PyTorch 2.x)")
args = parser.parse_args()

out_dir = args.out_dir
start = args.start
num_samples = args.num_samples
max_new_tokens = args.max_new_tokens
temperature = args.temperature
top_k = args.top_k
seed = args.seed
compile_flag = args.compile
device_str = args.device

torch.manual_seed(seed)
device = torch.device(device_str)
# CUDA autocast supports only float16/bfloat16 (a float32 target disables it),
# so prefer bf16 where the GPU supports it; on CPU, run in plain float32 with
# no autocast context.
ptdtype = torch.bfloat16 if device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float16
ctx = nullcontext() if device.type == "cpu" else torch.amp.autocast(device_type=device.type, dtype=ptdtype)

# Load the checkpoint and rebuild the model from its saved config.
ckpt_path = os.path.join(out_dir, "ckpt.pt")
if not os.path.isfile(ckpt_path):
    raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
ckpt = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**ckpt["model_args"])
model = GPT(gptconf)
sd = ckpt["model"]

# Checkpoints saved from a torch.compile'd model prefix parameter names with
# "_orig_mod."; strip the prefix so the keys match the uncompiled module.
for k in list(sd.keys()):
    if k.startswith("_orig_mod."):
        sd[k[len("_orig_mod."):]] = sd.pop(k)
model.load_state_dict(sd)
model.eval().to(device)
if compile_flag:
    model = torch.compile(model)
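
# Note: torch.compile (PyTorch 2.x) pays a one-time compilation cost on the
# first forward pass, so --compile mostly pays off when drawing many samples.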

# Build the prompt: "FILE:<path>" reads the prompt text from a file, anything
# else is used verbatim.
if start.startswith("FILE:"):
    with open(start[5:], "r", encoding="utf-8") as f:
        start = f.read()
ids = encode(start)
x = torch.tensor([ids], dtype=torch.long, device=device)  # shape (1, prompt_length)
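
# Assumption about model.py (imported above, not defined here): a nanoGPT-style
# GPT.generate crops the running context to the model's block_size, so prompts
# longer than the context window are truncated from the left inside generate.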

# Draw num_samples completions and print each one, separated by a rule.
with torch.no_grad(), ctx:
    for _ in range(num_samples):
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        print(decode(y[0].tolist()))
        print("---------------")