"""
Instant fp16 inference for openai/gpt-oss-120b.

No bfloat16 assertion, no 8-bit, no re-quantization: the checkpoint is loaded
sharded across the GPUs (with CPU/disk offload), every resident floating-point
tensor is cast to float16, a short generation is run as a smoke test, and the
fp16 model is saved back to disk.

Set the OUTPUT_DIR environment variable to override the default save location.
"""

import os

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList

MODEL_ID = "openai/gpt-oss-120b"
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "./fp16/gpt-oss-120b-fp16")

os.environ["TOKENIZERS_PARALLELISM"] = "false"

tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
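
# Per-GPU weight budget: cap each of the eight visible GPUs at 17 GiB and let
# the remaining fp16 weights spill to CPU RAM and, beyond that, to the offload
# folder passed to from_pretrained below.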
max_memory = {
    0: "17GiB", 1: "17GiB", 2: "17GiB", 3: "17GiB",
    4: "17GiB", 5: "17GiB", 6: "17GiB", 7: "17GiB",
    "cpu": "196GiB",
}

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="sequential",
    low_cpu_mem_usage=True,
    max_memory=max_memory,
    offload_folder="./offload_cache",
    trust_remote_code=True,
)
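

# The helpers below walk every submodule and cast any materialized floating-point
# parameter or buffer to the target dtype. Placeholders on the "meta" device
# (how accelerate represents offloaded weights) are skipped, and non-float
# tensors are left untouched.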
def _cast_module_floating_to_dtype_(module: torch.nn.Module, dtype: torch.dtype) -> None:
    for submodule in module.modules():
        for name, param in list(submodule._parameters.items()):
            if param is None:
                continue
            if getattr(param, "device", None) is not None and str(param.device) == "meta":
                continue
            if param.is_floating_point() and param.dtype != dtype:
                param.data = param.data.to(dtype)
        for name, buf in list(submodule._buffers.items()):
            if buf is None:
                continue
            if getattr(buf, "device", None) is not None and str(buf.device) == "meta":
                continue
            if torch.is_floating_point(buf) and buf.dtype != dtype:
                submodule._buffers[name] = buf.to(dtype)


def _cast_with_progress(module: torch.nn.Module, dtype: torch.dtype) -> None:
    """Cast all resident floating tensors to dtype with a progress bar."""
    candidates = []
    for sub in module.modules():
        for name, p in list(sub._parameters.items()):
            if p is None:
                continue
            if getattr(p, "device", None) is not None and str(p.device) == "meta":
                continue
            if p.is_floating_point() and p.dtype != dtype:
                candidates.append((sub, name, "param"))
        for name, b in list(sub._buffers.items()):
            if b is None:
                continue
            if getattr(b, "device", None) is not None and str(b.device) == "meta":
                continue
            if torch.is_floating_point(b) and b.dtype != dtype:
                candidates.append((sub, name, "buf"))
    if not candidates:
        return
    with tqdm(total=len(candidates), desc="Casting to FP16", unit="tensor", leave=False) as pbar:
        for sub, name, kind in candidates:
            if kind == "param":
                p = sub._parameters.get(name)
                if p is not None and p.is_floating_point() and p.dtype != dtype:
                    p.data = p.data.to(dtype)
            else:
                b = sub._buffers.get(name)
                if b is not None and torch.is_floating_point(b) and b.dtype != dtype:
                    sub._buffers[name] = b.to(dtype)
            pbar.update(1)
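

# One up-front pass: cast everything already resident on GPU or CPU to fp16 now.
# Meta-device placeholders are skipped; they were already loaded with
# torch_dtype=torch.float16 above.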
_cast_with_progress(model, torch.float16)
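

# Belt-and-braces: just before each module's forward pass, re-cast its resident
# floating tensors so anything materialized later in a different dtype is
# coerced back to float16.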
def _pre_forward_fp16(module, inputs):
    _cast_module_floating_to_dtype_(module, torch.float16)
    return None


for _m in model.modules():
    _m.register_forward_pre_hook(_pre_forward_fp16)
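

# Note: this patches GPTBigCodeModel, which gpt-oss-120b does not use; the
# assignment is harmless here and would only matter if a GPTBigCode-based
# checkpoint were loaded instead.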
from transformers.models.gpt_bigcode import modeling_gpt_bigcode

modeling_gpt_bigcode.GPTBigCodeModel._check_hidden_states_dtype = lambda *_, **__: None


if __name__ == "__main__":
    prompt = "Explain quantum supremacy in one paragraph."
    inputs = tok(prompt, return_tensors="pt").to(model.device)
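
    # Stopping criterion that never stops generation; it only advances a tqdm
    # bar as new tokens are produced.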
    class TqdmTokenBar(StoppingCriteria):
        def __init__(self, total: int):
            self.pbar = tqdm(total=total, desc="Generating", unit="tok", leave=False)
            self.last = None

        def __call__(self, input_ids, scores, **kwargs):
            cur = input_ids.shape[1]
            if self.last is None:
                # First call: everything before this point is the prompt plus
                # one new token, so only count tokens generated from here on.
                self.last = cur - 1
            if cur > self.last:
                self.pbar.update(cur - self.last)
                self.last = cur
            return False

    total_new = 80
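    # Sample a short completion; the stopping criterion only drives the token
    # progress bar and never halts decoding.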
    token_bar = TqdmTokenBar(total_new)
    # Autocast in float16 to match the fp16 weights.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        out = model.generate(
            **inputs,
            max_new_tokens=total_new,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tok.eos_token_id,
            stopping_criteria=StoppingCriteriaList([token_bar]),
        )
    token_bar.pbar.close()
    print(tok.decode(out[0], skip_special_tokens=True))

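    # Persist the fp16 weights (sharded safetensors) and the tokenizer to OUTPUT_DIR.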
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    model.save_pretrained(OUTPUT_DIR, safe_serialization=True)
    tok.save_pretrained(OUTPUT_DIR)
    print(f"Saved FP16 model and tokenizer to: {OUTPUT_DIR}")