#!/usr/bin/env python3
"""
Instant fp16 inference for openai/gpt-oss-120b
No bfloat16 assertion, no 8-bit, no re-quantization.
"""
import os
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList

# Set OUTPUT_DIR (env var or the default below) to a writable directory,
# e.g. one with 755 permissions.
MODEL_ID = "openai/gpt-oss-120b"
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "./fp16/gpt-oss-120b-fp16")

# 1. silence tokenizer warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# 2. grab tokenizer
tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

# 3. load the model in fp16, sharded sequentially across GPUs with CPU overflow
# Adjust max_memory to your hardware: device_map="sequential" fills the listed GPUs in
# order and spills whatever does not fit into CPU RAM / the offload folder below.
max_memory = {0: "17GiB", 1: "17GiB", 2: "17GiB", 3: "17GiB", 4: "17GiB", 5: "17GiB", 6: "17GiB", 7: "17GiB", "cpu": "196GiB"}
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="sequential",
    low_cpu_mem_usage=True,
    max_memory=max_memory,
    offload_folder="./offload_cache",
    trust_remote_code=True,
)

# Ensure any lingering BF16 params/buffers are cast to FP16 (handles lazy/offloaded shards)
def _cast_module_floating_to_dtype_(module: torch.nn.Module, dtype: torch.dtype) -> None:
    for submodule in module.modules():
        # Parameters
        for name, param in list(submodule._parameters.items()):
            if param is None:
                continue
            if getattr(param, "device", None) is not None and str(param.device) == "meta":
                continue
            if param.is_floating_point() and param.dtype != dtype:
                # In-place dtype change to preserve Parameter object/hooks
                param.data = param.data.to(dtype)
        # Buffers
        for name, buf in list(submodule._buffers.items()):
            if buf is None:
                continue
            if getattr(buf, "device", None) is not None and str(buf.device) == "meta":
                continue
            if torch.is_floating_point(buf) and buf.dtype != dtype:
                submodule._buffers[name] = buf.to(dtype)

def _cast_with_progress(module: torch.nn.Module, dtype: torch.dtype) -> None:
    """Cast all resident floating tensors to dtype with a progress bar."""
    candidates = []
    for sub in module.modules():
        for name, p in list(sub._parameters.items()):
            if p is None:
                continue
            if getattr(p, "device", None) is not None and str(p.device) == "meta":
                continue
            if p.is_floating_point() and p.dtype != dtype:
                candidates.append((sub, name, "param"))
        for name, b in list(sub._buffers.items()):
            if b is None:
                continue
            if getattr(b, "device", None) is not None and str(b.device) == "meta":
                continue
            if torch.is_floating_point(b) and b.dtype != dtype:
                candidates.append((sub, name, "buf"))
    if not candidates:
        return
    with tqdm(total=len(candidates), desc="Casting to FP16", unit="tensor", leave=False) as pbar:
        for sub, name, kind in candidates:
            if kind == "param":
                p = sub._parameters.get(name)
                if p is not None and p.is_floating_point() and p.dtype != dtype:
                    p.data = p.data.to(dtype)
            else:
                b = sub._buffers.get(name)
                if b is not None and torch.is_floating_point(b) and b.dtype != dtype:
                    sub._buffers[name] = b.to(dtype)
            pbar.update(1)

_cast_with_progress(model, torch.float16)

# Register a lightweight pre-forward hook to convert any lazily materialized/offloaded
# tensors (loaded during forward by accelerate) to FP16 before use.
def _pre_forward_fp16(module, inputs):
    _cast_module_floating_to_dtype_(module, torch.float16)
    return None

for _m in model.modules():
    _m.register_forward_pre_hook(_pre_forward_fp16)
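# Note: each hook re-scans that module's own subtree before every forward. Once all
# resident tensors are FP16 this only catches weights accelerate materializes lazily
# from the offload folder; the overhead is small relative to a forward pass of this
# model, but the hooks are meant for this one-off run, not for sustained serving.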

# 4. defuse the bfloat16 hidden-states assert.
# Note: this patches the GPTBigCode implementation specifically; for checkpoints that
# use a different modeling class (gpt-oss ships its own), the override is a harmless no-op.
from transformers.models.gpt_bigcode import modeling_gpt_bigcode
modeling_gpt_bigcode.GPTBigCodeModel._check_hidden_states_dtype = lambda *_, **__: None

# 5. inference to verify functionality
if __name__ == "__main__":
    prompt = "Explain quantum supremacy in one paragraph."
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    class TqdmTokenBar(StoppingCriteria):
        """Progress bar over newly generated tokens; never actually stops generation."""
        def __init__(self, total: int, prompt_len: int):
            self.pbar = tqdm(total=total, desc="Generating", unit="tok", leave=False)
            self.last = prompt_len  # start counting after the prompt tokens
        def __call__(self, input_ids, scores, **kwargs):
            cur = input_ids.shape[1]
            if cur > self.last:
                self.pbar.update(cur - self.last)
                self.last = cur
            return False

    total_new = 80
    # Run the check under fp16 autocast so activations match the fp16 weights.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        out = model.generate(
            **inputs,
            max_new_tokens=total_new,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tok.eos_token_id,
            stopping_criteria=StoppingCriteriaList([TqdmTokenBar(total_new, inputs["input_ids"].shape[1])])
        )
    print(tok.decode(out[0], skip_special_tokens=True))

    # Save FP16 checkpoint to disk (directory is created if missing)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    model.save_pretrained(OUTPUT_DIR, safe_serialization=True)
    tok.save_pretrained(OUTPUT_DIR)
    print(f"Saved FP16 model and tokenizer to: {OUTPUT_DIR}")