#!/usr/bin/env python3
"""
Instant fp16 inference for openai/gpt-oss-120b
No bfloat16 assertion, no 8-bit, no re-quantization.
"""
import os
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList
# Set OUTPUT_DIR to your model output directory and make sure it has 755 permissions.
MODEL_ID = "openai/gpt-oss-120b"
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "./fp16/gpt-oss-120b-fp16")
# 1. silence tokenizer warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# 2. grab tokenizer
tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
# 3. load model in fp16
# Adjust these max_memory budgets to your own GPU count and per-GPU VRAM.
max_memory = {0: "17GiB", 1: "17GiB", 2: "17GiB", 3: "17GiB", 4: "17GiB", 5: "17GiB", 6: "17GiB", 7: "17GiB", "cpu": "196GiB"}
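# Optional sketch: derive max_memory from the visible GPUs instead of hard-coding it.
# The budgets and the AUTO_MAX_MEMORY / PER_GPU_MEM / CPU_MEM environment variables
# below are illustrative assumptions; tune them to your hardware and leave headroom
# for activations and the CUDA context.
if os.environ.get("AUTO_MAX_MEMORY") == "1":
    per_gpu = os.environ.get("PER_GPU_MEM", "17GiB")
    max_memory = {i: per_gpu for i in range(torch.cuda.device_count())}
    max_memory["cpu"] = os.environ.get("CPU_MEM", "196GiB")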
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="sequential",
    low_cpu_mem_usage=True,
    max_memory=max_memory,
    offload_folder="./offload_cache",
    trust_remote_code=True,
)
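# Optional: print the device placement accelerate chose, so you can see which layers
# landed on each GPU and what was offloaded to CPU/disk. PRINT_DEVICE_MAP is a
# hypothetical opt-in flag used here for illustration, not part of transformers.
if os.environ.get("PRINT_DEVICE_MAP") == "1":
    for _layer, _dev in getattr(model, "hf_device_map", {}).items():
        print(f"{_layer} -> {_dev}")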
# Ensure any lingering BF16 params/buffers are cast to FP16 (handles lazy/offloaded shards)
def _cast_module_floating_to_dtype_(module: torch.nn.Module, dtype: torch.dtype) -> None:
    for submodule in module.modules():
        # Parameters
        for name, param in list(submodule._parameters.items()):
            if param is None:
                continue
            if getattr(param, "device", None) is not None and str(param.device) == "meta":
                continue
            if param.is_floating_point() and param.dtype != dtype:
                # In-place dtype change to preserve Parameter object/hooks
                param.data = param.data.to(dtype)
        # Buffers
        for name, buf in list(submodule._buffers.items()):
            if buf is None:
                continue
            if getattr(buf, "device", None) is not None and str(buf.device) == "meta":
                continue
            if torch.is_floating_point(buf) and buf.dtype != dtype:
                submodule._buffers[name] = buf.to(dtype)
def _cast_with_progress(module: torch.nn.Module, dtype: torch.dtype) -> None:
    """Cast all resident floating tensors to dtype with a progress bar."""
    candidates = []
    for sub in module.modules():
        for name, p in list(sub._parameters.items()):
            if p is None:
                continue
            if getattr(p, "device", None) is not None and str(p.device) == "meta":
                continue
            if p.is_floating_point() and p.dtype != dtype:
                candidates.append((sub, name, "param"))
        for name, b in list(sub._buffers.items()):
            if b is None:
                continue
            if getattr(b, "device", None) is not None and str(b.device) == "meta":
                continue
            if torch.is_floating_point(b) and b.dtype != dtype:
                candidates.append((sub, name, "buf"))
    if not candidates:
        return
    with tqdm(total=len(candidates), desc="Casting to FP16", unit="tensor", leave=False) as pbar:
        for sub, name, kind in candidates:
            if kind == "param":
                p = sub._parameters.get(name)
                if p is not None and p.is_floating_point() and p.dtype != dtype:
                    p.data = p.data.to(dtype)
            else:
                b = sub._buffers.get(name)
                if b is not None and torch.is_floating_point(b) and b.dtype != dtype:
                    sub._buffers[name] = b.to(dtype)
            pbar.update(1)
_cast_with_progress(model, torch.float16)
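# Optional sanity check (an illustrative addition, not required): count resident
# floating tensors that are still not FP16 after the eager cast. Offloaded "meta"
# tensors are skipped here; they are handled by the pre-forward hook registered below.
_not_fp16 = sum(
    1
    for _t in list(model.parameters()) + list(model.buffers())
    if str(_t.device) != "meta" and _t.is_floating_point() and _t.dtype != torch.float16
)
print(f"Resident floating tensors not yet in FP16: {_not_fp16}")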
# Register a lightweight pre-forward hook to convert any lazily materialized/offloaded
# tensors (loaded during forward by accelerate) to FP16 before use.
def _pre_forward_fp16(module, inputs):
    _cast_module_floating_to_dtype_(module, torch.float16)
    return None
for _m in model.modules():
    _m.register_forward_pre_hook(_pre_forward_fp16)
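# Note: register_forward_pre_hook returns a torch.utils.hooks.RemovableHandle. If the
# per-call cast adds noticeable overhead once every offloaded shard has been touched,
# you can keep the handles and drop the hooks after a warm-up pass, e.g.:
#   handles = [m.register_forward_pre_hook(_pre_forward_fp16) for m in model.modules()]
#   ...run one warm-up generate() call...
#   for h in handles:
#       h.remove()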
# 4. kill the bfloat16 assert
# NOTE: this patches the GPTBigCode model class specifically; if your transformers
# version serves gpt-oss through a different model class, the override is simply a no-op.
from transformers.models.gpt_bigcode import modeling_gpt_bigcode
modeling_gpt_bigcode.GPTBigCodeModel._check_hidden_states_dtype = lambda *_, **__: None
# 5. inference to verify functionality
if __name__ == "__main__":
    prompt = "Explain quantum supremacy in one paragraph."
    inputs = tok(prompt, return_tensors="pt").to(model.device)

    class TqdmTokenBar(StoppingCriteria):
        """Progress bar driven by the growing sequence length; never stops generation early."""

        def __init__(self, total: int):
            self.pbar = tqdm(total=total, desc="Generating", unit="tok", leave=False)
            self.last = 0

        def __call__(self, input_ids, scores, **kwargs):
            cur = input_ids.shape[1]
            if cur > self.last:
                self.pbar.update(cur - self.last)
                self.last = cur
            return False

    total_new = 80
    # Autocast to FP16 during generation to stay consistent with the FP16-only goal of this script.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        out = model.generate(
            **inputs,
            max_new_tokens=total_new,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tok.eos_token_id,
            stopping_criteria=StoppingCriteriaList([TqdmTokenBar(total_new)]),
        )
    print(tok.decode(out[0], skip_special_tokens=True))
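    # Optional: decode only the continuation; the first inputs["input_ids"].shape[1]
    # positions of out[0] are the prompt tokens echoed back by generate().
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    print("Continuation only:", tok.decode(new_tokens, skip_special_tokens=True))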
    # Save FP16 checkpoint to disk (directory is created if missing)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    model.save_pretrained(OUTPUT_DIR, safe_serialization=True)
    tok.save_pretrained(OUTPUT_DIR)
    print(f"Saved FP16 model and tokenizer to: {OUTPUT_DIR}")