velocity-ai committed
Commit 88e320f · verified · 1 Parent(s): b86d830

Create code/inference.py

Files changed (1):
  code/inference.py +184 -0
code/inference.py ADDED
@@ -0,0 +1,184 @@
import os
import json
import logging

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

logger = logging.getLogger(__name__)

# Test CUDA device availability and names with:
# python -c "import torch; print('\n'.join([f'{i}: {torch.cuda.get_device_name(i)}' for i in range(torch.cuda.device_count())]))"
# A specific GPU can be selected with:
# CUDA_VISIBLE_DEVICES="1" python script.py
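
# Handler contract: the SageMaker Hugging Face inference toolkit calls these
# hooks -- model_fn once at container start, then input_fn -> predict_fn ->
# output_fn for each request.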

def model_fn(model_dir):
    """Load the model and tokenizer for inference."""
    # Prefer an explicit hub ID; fall back to the model artifacts directory
    # that SageMaker unpacks into model_dir
    model_id = os.getenv("HF_MODEL_ID", model_dir)

    # Pin a specific GPU device if one is available
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        torch.cuda.set_device(device)
        torch.cuda.empty_cache()
    logger.info(f"Using device: {device}")

    try:
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

        # Load model with an explicit configuration
        model = AutoModelForSequenceClassification.from_pretrained(
            model_id,
            num_labels=2,
            torch_dtype=torch.bfloat16 if device.type == 'cuda' else torch.float32,
            trust_remote_code=True,
            device_map=None
        )

        # Move the model to the target device explicitly
        model = model.to(device)

        # On GPU, make newly created float tensors default to CUDA as well.
        # Note: set_default_tensor_type is deprecated in recent PyTorch
        # releases; it is kept here because predict_fn relies on the same
        # behaviour.
        if device.type == 'cuda':
            torch.set_default_tensor_type('torch.cuda.FloatTensor')

        # Belt-and-braces pass: model.to(device) already moves parameters and
        # buffers, but custom modules loaded via trust_remote_code can hold
        # stragglers
        for param in model.parameters():
            param.data = param.data.to(device)
        for buffer in model.buffers():
            buffer.data = buffer.data.to(device)

        # Ensure the model is in eval mode
        model.eval()

        # Enable cuDNN autotuning for repeated same-shape inputs
        if device.type == 'cuda':
            torch.backends.cudnn.benchmark = True

        logger.info(f"Model loaded successfully on {device}")
        logger.info(f"Model device map: {getattr(model, 'hf_device_map', 'N/A')}")
        # torch has no get_default_tensor_type(); log the default dtype instead
        logger.info(f"Default dtype: {torch.get_default_dtype()}")

        # Verify that every model component landed on the target device
        def verify_module_devices(module, prefix=''):
            issues = []
            for name, child in module.named_children():
                child_prefix = f"{prefix}.{name}" if prefix else name
                if hasattr(child, 'device') and child.device != device:
                    issues.append(f"{child_prefix} on {child.device}")
                for param_name, param in child.named_parameters(recurse=False):
                    if param.device != device:
                        issues.append(f"{child_prefix}.{param_name} on {param.device}")
                issues.extend(verify_module_devices(child, child_prefix))
            return issues

        device_issues = verify_module_devices(model)
        if device_issues:
            logger.warning("Found model components on the wrong device:")
            for issue in device_issues:
                logger.warning(issue)

        return {
            "model": model,
            "tokenizer": tokenizer,
            "device": device
        }
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise

def predict_fn(data, model_dict):
    """Run a single prediction."""
    try:
        logger.info("Starting prediction")
        model = model_dict["model"]
        tokenizer = model_dict["tokenizer"]
        device = model_dict["device"]

        logger.info(f"Model is on device: {device}")

        # Make newly created float tensors default to the GPU for this request
        if device.type == 'cuda':
            torch.set_default_tensor_type('torch.cuda.FloatTensor')

        # Parse input: accept raw strings, {"inputs": ...}, or {"text": ...}
        if isinstance(data, str):
            input_text = data
        elif isinstance(data, dict):
            input_text = data.get("inputs", data.get("text", str(data)))
        else:
            input_text = str(data)
        logger.debug(f"Parsed input text: {input_text}")

        # Tokenize the input to fixed-length tensors
        inputs = tokenizer(
            input_text,
            add_special_tokens=True,
            max_length=128,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
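
        # Design note: padding every request to the same max_length keeps
        # input shapes constant, which avoids repeated cuDNN re-tuning when
        # torch.backends.cudnn.benchmark is enabled in model_fn. The cost is
        # wasted compute on short inputs.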

        # Move the input tensors to the target device
        if device.type == 'cuda':
            inputs = {k: v.to(device) for k, v in inputs.items()}

        logger.debug(f"Inputs moved to device: {device}")

        # Log tensor devices and dtypes
        for k, v in inputs.items():
            logger.debug(f"Input '{k}' - Device: {v.device}, Shape: {v.shape}, Dtype: {v.dtype}")

        # Run inference
        logger.info("Generating prediction")
        with torch.no_grad():
            if device.type == 'cuda':
                torch.cuda.empty_cache()

            try:
                outputs = model(**inputs)
                predictions = torch.softmax(outputs.logits, dim=1)
            except RuntimeError:
                logger.error("Error during inference:")
                logger.error(f"Model device: {next(model.parameters()).device}")
                logger.error(f"Input devices: {[f'{k}: {v.device}' for k, v in inputs.items()]}")
                raise

        # Upcast before the numpy conversion: numpy has no bfloat16 dtype
        predictions = predictions.float().cpu().numpy()

        return predictions

    except Exception as e:
        logger.error(f"Error during prediction: {str(e)}")
        # Guard the diagnostics: model/inputs may not be bound yet if the
        # failure happened early
        if 'model' in locals():
            logger.error(f"Model device: {next(model.parameters()).device}")
        if 'inputs' in locals():
            logger.error(f"Input devices: {[f'{k}: {v.device}' for k, v in inputs.items()]}")
        logger.error(f"Default dtype: {torch.get_default_dtype()}")
        raise
    finally:
        # Always restore the CPU default so a failed request does not leak
        # the CUDA default tensor type into later calls
        torch.set_default_tensor_type('torch.FloatTensor')

def input_fn(request_body, request_content_type):
    """Parse the incoming request."""
    if request_content_type == "application/json":
        # Try to parse as JSON; fall back to raw text if parsing fails
        try:
            return json.loads(request_body)
        except json.JSONDecodeError:
            return request_body
    # For non-JSON content types, pass the raw body through
    return request_body
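
# Illustrative payloads (hypothetical examples, matching the parsing above):
#   '{"inputs": "great battery life"}' -> predict_fn reads the "inputs" key
#   '{"text": "great battery life"}'   -> predict_fn falls back to "text"
#   'not valid json'                   -> the raw string is used as-is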

def output_fn(prediction, response_content_type):
    """Serialize the prediction for the response."""
    if response_content_type == "application/json":
        return json.dumps(prediction.tolist())
    raise ValueError(f"Unsupported content type: {response_content_type}")
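
# Minimal local smoke test -- a sketch, not part of the SageMaker handler
# contract. The model ID below is illustrative; any two-label
# sequence-classification checkpoint should work.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    os.environ.setdefault("HF_MODEL_ID", "distilbert-base-uncased-finetuned-sst-2-english")
    model_dict = model_fn(model_dir=".")
    body = json.dumps({"inputs": "This product exceeded my expectations"})
    data = input_fn(body, "application/json")
    prediction = predict_fn(data, model_dict)
    print(output_fn(prediction, "application/json"))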