#!/usr/bin/env python3
"""
Instant fp16 inference for openai/gpt-oss-120b
No bfloat16 assertion, no 8-bit, no re-quantization.
"""
import os
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList
# Set OUTPUT_DIR to your model output directory and make sure it has 755 permissions.
MODEL_ID = "openai/gpt-oss-120b"
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "./fp16/gpt-oss-120b-fp16")
# 1. silence tokenizer warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# 2. grab tokenizer
tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
# 3. load model in fp16
# Adjust these max_memory limits to match your GPU count and available VRAM.
max_memory = {0: "17GiB", 1: "17GiB", 2: "17GiB", 3: "17GiB", 4: "17GiB", 5: "17GiB", 6: "17GiB", 7: "17GiB", "cpu": "196GiB"}
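# Alternative sketch (assumes identical GPUs; the per-device budget is an example, tune it to your VRAM):
#   max_memory = {i: "17GiB" for i in range(torch.cuda.device_count())}
#   max_memory["cpu"] = "196GiB"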
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="sequential",
    low_cpu_mem_usage=True,
    max_memory=max_memory,
    offload_folder="./offload_cache",
    trust_remote_code=True,
)
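# Optional sanity check (a sketch; shows which dtypes actually loaded before the cast below):
#   from collections import Counter
#   print(Counter(p.dtype for p in model.parameters()))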
# Ensure any lingering BF16 params/buffers are cast to FP16 (handles lazy/offloaded shards)
def _cast_module_floating_to_dtype_(module: torch.nn.Module, dtype: torch.dtype) -> None:
    for submodule in module.modules():
        # Parameters
        for name, param in list(submodule._parameters.items()):
            if param is None:
                continue
            if getattr(param, "device", None) is not None and str(param.device) == "meta":
                continue
            if param.is_floating_point() and param.dtype != dtype:
                # In-place dtype change to preserve Parameter object/hooks
                param.data = param.data.to(dtype)
        # Buffers
        for name, buf in list(submodule._buffers.items()):
            if buf is None:
                continue
            if getattr(buf, "device", None) is not None and str(buf.device) == "meta":
                continue
            if torch.is_floating_point(buf) and buf.dtype != dtype:
                submodule._buffers[name] = buf.to(dtype)
def _cast_with_progress(module: torch.nn.Module, dtype: torch.dtype) -> None:
    """Cast all resident floating tensors to dtype with a progress bar."""
    candidates = []
    for sub in module.modules():
        for name, p in list(sub._parameters.items()):
            if p is None:
                continue
            if getattr(p, "device", None) is not None and str(p.device) == "meta":
                continue
            if p.is_floating_point() and p.dtype != dtype:
                candidates.append((sub, name, "param"))
        for name, b in list(sub._buffers.items()):
            if b is None:
                continue
            if getattr(b, "device", None) is not None and str(b.device) == "meta":
                continue
            if torch.is_floating_point(b) and b.dtype != dtype:
                candidates.append((sub, name, "buf"))
    if not candidates:
        return
    with tqdm(total=len(candidates), desc="Casting to FP16", unit="tensor", leave=False) as pbar:
        for sub, name, kind in candidates:
            if kind == "param":
                p = sub._parameters.get(name)
                if p is not None and p.is_floating_point() and p.dtype != dtype:
                    p.data = p.data.to(dtype)
            else:
                b = sub._buffers.get(name)
                if b is not None and torch.is_floating_point(b) and b.dtype != dtype:
                    sub._buffers[name] = b.to(dtype)
            pbar.update(1)
_cast_with_progress(model, torch.float16)
# Register a lightweight pre-forward hook to convert any lazily materialized/offloaded
# tensors (loaded during forward by accelerate) to FP16 before use.
def _pre_forward_fp16(module, inputs):
    _cast_module_floating_to_dtype_(module, torch.float16)
    return None
for _m in model.modules():
    _m.register_forward_pre_hook(_pre_forward_fp16)
# 4. kill the bfloat16 assert
from transformers.models.gpt_bigcode import modeling_gpt_bigcode
modeling_gpt_bigcode.GPTBigCodeModel._check_hidden_states_dtype = lambda *_, **__: None
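# Note: this patches the GPT-BigCode class; for a gpt-oss checkpoint it is effectively a
# no-op unless the model actually resolves to GPTBigCodeModel.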
# 5. inference to verify functionality
if __name__ == "__main__":
    prompt = "Explain quantum supremacy in one paragraph."
    inputs = tok(prompt, return_tensors="pt").to(model.device)

    class TqdmTokenBar(StoppingCriteria):
        def __init__(self, total: int):
            self.pbar = tqdm(total=total, desc="Generating", unit="tok", leave=False)
            self.last = 0

        def __call__(self, input_ids, scores, **kwargs):
            cur = input_ids.shape[1]
            if cur > self.last:
                self.pbar.update(cur - self.last)
                self.last = cur
            return False
    total_new = 80
    # Run generation under fp16 autocast so compute matches the fp16 weights
    with torch.amp.autocast("cuda", dtype=torch.float16):
        out = model.generate(
            **inputs,
            max_new_tokens=total_new,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tok.eos_token_id,
            stopping_criteria=StoppingCriteriaList([TqdmTokenBar(total_new)]),
        )
    print(tok.decode(out[0], skip_special_tokens=True))
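    # To print only the generated continuation (a sketch; slices off the echoed prompt tokens):
    #   new_only = out[0][inputs["input_ids"].shape[1]:]
    #   print(tok.decode(new_only, skip_special_tokens=True))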
    # Save FP16 checkpoint to disk (directory is created if missing)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    model.save_pretrained(OUTPUT_DIR, safe_serialization=True)
    tok.save_pretrained(OUTPUT_DIR)
    print(f"Saved FP16 model and tokenizer to: {OUTPUT_DIR}")