#!/usr/bin/env python3
"""
Instant fp16 inference for openai/gpt-oss-120b
No bfloat16 assertion, no 8-bit, no re-quantization.
"""
import os

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList

# Set your model output directory and make sure it has 755 permissions.
MODEL_ID = "openai/gpt-oss-120b"
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "./fp16/gpt-oss-120b-fp16")

# 1. Silence tokenizer warnings.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# 2. Grab the tokenizer.
tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

# 3. Load the model in fp16.
# Adjust these max_memory settings to match your hardware.
max_memory = {
    0: "17GiB", 1: "17GiB", 2: "17GiB", 3: "17GiB",
    4: "17GiB", 5: "17GiB", 6: "17GiB", 7: "17GiB",
    "cpu": "196GiB",
}

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="sequential",
    low_cpu_mem_usage=True,
    max_memory=max_memory,
    offload_folder="./offload_cache",
    trust_remote_code=True,
)


# Ensure any lingering BF16 params/buffers are cast to FP16 (handles lazy/offloaded shards).
def _cast_module_floating_to_dtype_(module: torch.nn.Module, dtype: torch.dtype) -> None:
    for submodule in module.modules():
        # Parameters
        for name, param in list(submodule._parameters.items()):
            if param is None:
                continue
            if getattr(param, "device", None) is not None and str(param.device) == "meta":
                continue
            if param.is_floating_point() and param.dtype != dtype:
                # In-place dtype change to preserve the Parameter object and its hooks.
                param.data = param.data.to(dtype)
        # Buffers
        for name, buf in list(submodule._buffers.items()):
            if buf is None:
                continue
            if getattr(buf, "device", None) is not None and str(buf.device) == "meta":
                continue
            if torch.is_floating_point(buf) and buf.dtype != dtype:
                submodule._buffers[name] = buf.to(dtype)


def _cast_with_progress(module: torch.nn.Module, dtype: torch.dtype) -> None:
    """Cast all resident floating tensors to dtype with a progress bar."""
    candidates = []
    for sub in module.modules():
        for name, p in list(sub._parameters.items()):
            if p is None:
                continue
            if getattr(p, "device", None) is not None and str(p.device) == "meta":
                continue
            if p.is_floating_point() and p.dtype != dtype:
                candidates.append((sub, name, "param"))
        for name, b in list(sub._buffers.items()):
            if b is None:
                continue
            if getattr(b, "device", None) is not None and str(b.device) == "meta":
                continue
            if torch.is_floating_point(b) and b.dtype != dtype:
                candidates.append((sub, name, "buf"))
    if not candidates:
        return
    with tqdm(total=len(candidates), desc="Casting to FP16", unit="tensor", leave=False) as pbar:
        for sub, name, kind in candidates:
            if kind == "param":
                p = sub._parameters.get(name)
                if p is not None and p.is_floating_point() and p.dtype != dtype:
                    p.data = p.data.to(dtype)
            else:
                b = sub._buffers.get(name)
                if b is not None and torch.is_floating_point(b) and b.dtype != dtype:
                    sub._buffers[name] = b.to(dtype)
            pbar.update(1)


_cast_with_progress(model, torch.float16)


# Register a lightweight pre-forward hook to convert any lazily materialized/offloaded
# tensors (loaded during forward by accelerate) to FP16 before use.
def _pre_forward_fp16(module, inputs):
    _cast_module_floating_to_dtype_(module, torch.float16)
    return None


for _m in model.modules():
    _m.register_forward_pre_hook(_pre_forward_fp16)

# 4. Kill the bfloat16 assert.
from transformers.models.gpt_bigcode import modeling_gpt_bigcode

modeling_gpt_bigcode.GPTBigCodeModel._check_hidden_states_dtype = lambda *_, **__: None

# 5. Run a short generation to verify functionality.
if __name__ == "__main__":
    prompt = "Explain quantum supremacy in one paragraph."
    inputs = tok(prompt, return_tensors="pt").to(model.device)

    class TqdmTokenBar(StoppingCriteria):
        """Stopping criterion that only drives a tqdm bar; it never stops generation."""

        def __init__(self, total: int):
            self.pbar = tqdm(total=total, desc="Generating", unit="tok", leave=False)
            self.last = 0

        def __call__(self, input_ids, scores, **kwargs):
            cur = input_ids.shape[1]
            if cur > self.last:
                self.pbar.update(cur - self.last)
                self.last = cur
            return False

    total_new = 80
    # Autocast in fp16 so compute matches the fp16 weights cast above.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        out = model.generate(
            **inputs,
            max_new_tokens=total_new,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tok.eos_token_id,
            stopping_criteria=StoppingCriteriaList([TqdmTokenBar(total_new)]),
        )
    print(tok.decode(out[0], skip_special_tokens=True))

    # Save the FP16 checkpoint to disk (directory is created if missing).
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    model.save_pretrained(OUTPUT_DIR, safe_serialization=True)
    tok.save_pretrained(OUTPUT_DIR)
    print(f"Saved FP16 model and tokenizer to: {OUTPUT_DIR}")
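

# ---------------------------------------------------------------------------
# Optional: a minimal sketch of reloading the exported FP16 checkpoint later.
# This is an illustration, not part of the original flow: the helper name
# `load_saved_fp16` is hypothetical, it assumes the save above completed and
# that OUTPUT_DIR still points at the same directory, and device_map="auto"
# simply lets accelerate place the shards across whatever GPUs/CPU are free.
# It reuses the imports already present at the top of this script.
# ---------------------------------------------------------------------------
def load_saved_fp16(output_dir: str = OUTPUT_DIR):
    """Reload the saved FP16 model and tokenizer for later inference runs."""
    tok_fp16 = AutoTokenizer.from_pretrained(output_dir, use_fast=True)
    model_fp16 = AutoModelForCausalLM.from_pretrained(
        output_dir,
        torch_dtype=torch.float16,   # checkpoint is already fp16; no casting pass needed
        device_map="auto",
        trust_remote_code=True,
    )
    return tok_fp16, model_fp16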