Update code/inference.py
code/inference.py (+35 -62)
@@ -1,7 +1,8 @@
 import os
 import json
 import torch
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
+import torch.nn as nn
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import logging
 
 logger = logging.getLogger(__name__)
@@ -11,72 +12,55 @@ logger = logging.getLogger(__name__)
 # Can specify GPU device with:
 # CUDA_VISIBLE_DEVICES="1" python script.py
 
-def model_fn(model_dir):
+class PhiForSequenceClassification(nn.Module):
+    def __init__(self, base_model, num_labels=2):
+        super().__init__()
+        self.phi = base_model
+        self.classifier = nn.Linear(self.phi.config.hidden_size, num_labels)
+
+    def forward(self, **inputs):
+        outputs = self.phi(**inputs, output_hidden_states=True)
+        # Use the last hidden state of the last token for classification
+        last_hidden_state = outputs.hidden_states[-1][:, -1, :]
+        logits = self.classifier(last_hidden_state)
+        return type('Outputs', (), {'logits': logits})()
+
+def model_fn(model_dir, context=None):
     """Load the model for inference"""
-    model_id = os.getenv("HF_MODEL_ID")
-
-    # Set specific GPU device if available
-    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
-    if device.type == 'cuda':
-        torch.cuda.set_device(device)
-        torch.cuda.empty_cache()
-    logger.info(f"Using device: {device}")
-
     try:
+        model_id = os.getenv("HF_MODEL_ID")
+
+        # Set specific GPU device if available
+        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        if device.type == 'cuda':
+            torch.cuda.set_device(device)
+            torch.cuda.empty_cache()
+        logger.info(f"Using device: {device}")
+
         # Load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
 
-        # Load model
-        model = AutoModelForSequenceClassification.from_pretrained(
+        # Load base model
+        base_model = AutoModelForCausalLM.from_pretrained(
             model_id,
-            num_labels=2,
             torch_dtype=torch.bfloat16 if device.type == 'cuda' else torch.float32,
-            trust_remote_code=True,
-            device_map=None
+            trust_remote_code=True
         )
 
-        #
+        # Create classification model
+        model = PhiForSequenceClassification(base_model, num_labels=2)
+
+        # Move model to device
         model = model.to(device)
 
-        #
+        # Set memory optimizations
         if device.type == 'cuda':
-            torch.
-
-            for param in model.parameters():
-                param.data = param.data.to(device)
-            for buffer in model.buffers():
-                buffer.data = buffer.data.to(device)
+            torch.backends.cudnn.benchmark = True
 
         # Ensure model is in eval mode
        model.eval()
-
-        # Set memory optimizations
-        if device.type == 'cuda':
-            torch.backends.cudnn.benchmark = True
 
         logger.info(f"Model loaded successfully on {device}")
-        logger.info(f"Model device map: {model.hf_device_map if hasattr(model, 'hf_device_map') else 'N/A'}")
-        logger.info(f"Default tensor type: {torch.get_default_tensor_type()}")
-
-        # Verify all model components are on correct device
-        def verify_module_devices(module, prefix=''):
-            issues = []
-            for name, child in module.named_children():
-                child_prefix = f"{prefix}.{name}" if prefix else name
-                if hasattr(child, 'device'):
-                    if child.device != device:
-                        issues.append(f"{child_prefix} on {child.device}")
-                for param_name, param in child.named_parameters(recurse=False):
-                    if param.device != device:
-                        issues.append(f"{child_prefix}.{param_name} on {param.device}")
-                issues.extend(verify_module_devices(child, child_prefix))
-            return issues
-
-        device_issues = verify_module_devices(model)
-        if device_issues:
-            logger.warning("Found model components on wrong device:")
-            for issue in device_issues:
-                logger.warning(issue)
 
         return {
             "model": model,
@@ -97,10 +81,6 @@ def predict_fn(data, model_dict):
 
     logger.info(f"Model is on device: {device}")
 
-    # Set default tensor type for any new tensors
-    if device.type == 'cuda':
-        torch.set_default_tensor_type('torch.cuda.FloatTensor')
-
     # Parse input
     if isinstance(data, str):
         input_text = data
@@ -120,7 +100,7 @@ def predict_fn(data, model_dict):
         return_tensors='pt'
     )
 
-    # Move inputs to
+    # Move inputs to device
     if device.type == 'cuda':
         inputs = {k: v.cuda() for k, v in inputs.items()}
 
@@ -150,30 +130,23 @@ def predict_fn(data, model_dict):
         # Move predictions to CPU for numpy conversion
         predictions = predictions.cpu().numpy()
 
-        # Reset default tensor type
-        torch.set_default_tensor_type('torch.FloatTensor')
-
         return predictions
 
     except Exception as e:
         logger.error(f"Error during prediction: {str(e)}")
         logger.error(f"Model device: {next(model.parameters()).device}")
         logger.error(f"Input devices: {[f'{k}: {v.device}' for k, v in inputs.items()]}")
-        logger.error(f"Default tensor type: {torch.get_default_tensor_type()}")
         raise
 
 def input_fn(request_body, request_content_type):
     """Parse input request"""
     if request_content_type == "application/json":
-        # Try to parse as JSON
         try:
             data = json.loads(request_body)
         except:
-            # If JSON parsing fails, treat as raw text
             data = request_body
         return data
     else:
-        # For non-JSON content, treat as raw text
         return request_body
 
 def output_fn(prediction, response_content_type):
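The new PhiForSequenceClassification wrapper classifies by pooling the hidden state of the last token from the base model's final layer. A minimal shape walk-through of that pooling (illustrative only; hidden_size=2560 is an assumed stand-in for base_model.config.hidden_size):

import torch
import torch.nn as nn

batch, seq_len, hidden_size, num_labels = 1, 10, 2560, 2  # hidden_size assumed for illustration

hidden_states = torch.randn(batch, seq_len, hidden_size)  # final layer, i.e. outputs.hidden_states[-1]
last_token = hidden_states[:, -1, :]                      # (1, 2560): representation of the last token
classifier = nn.Linear(hidden_size, num_labels)
logits = classifier(last_token)                           # (1, 2): one score per label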
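For a quick local smoke test of the updated handlers, something along these lines should work. This is a sketch, not part of the commit: the model id and import path are assumptions, and model_dir is passed as None on the assumption that model_fn only reads HF_MODEL_ID from the environment.

import os
os.environ["HF_MODEL_ID"] = "microsoft/phi-2"  # assumed: any Phi-style causal LM id

from inference import model_fn, input_fn, predict_fn, output_fn  # import path assumed

model_dict = model_fn(model_dir=None)  # model_dir unused here; the model id comes from the env var
data = input_fn("text to classify", "application/json")  # non-JSON body falls back to raw text
predictions = predict_fn(data, model_dict)  # numpy array with two scores (num_labels=2)
print(output_fn(predictions, "application/json"))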