import json
import logging
import os

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

logger = logging.getLogger(__name__)
# Test CUDA device availability and names with:
# python -c "import torch; print('\n'.join([f'{i}: {torch.cuda.get_device_name(i)}' for i in range(torch.cuda.device_count())]))"
# Can specify GPU device with:
# CUDA_VISIBLE_DEVICES="1" python script.py

def model_fn(model_dir, context=None):
    """Load the model for inference."""
    try:
        # model_dir is unused: the model is pulled from the Hub via HF_MODEL_ID
        model_id = os.getenv("HF_MODEL_ID")

        # Select the first GPU if available, otherwise fall back to CPU
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        logger.info(f"Using device: {device}")

        # Load tokenizer and model directly via AutoModelForSequenceClassification
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_id,
            num_labels=2,
            torch_dtype=torch.bfloat16 if device.type == 'cuda' else torch.float32,
            trust_remote_code=True
        )

        # Move model to the target device
        model = model.to(device)

        # Enable cuDNN autotuning (helps with fixed-size inputs)
        if device.type == 'cuda':
            torch.backends.cudnn.benchmark = True

        # Ensure model is in eval mode
        model.eval()
        logger.info(f"Model loaded successfully on {device}")

        return {
            "model": model,
            "tokenizer": tokenizer,
            "device": device
        }
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise

def predict_fn(data, model_dict):
    """Make a prediction."""
    try:
        logger.info("Starting prediction")
        model = model_dict["model"]
        tokenizer = model_dict["tokenizer"]
        device = model_dict["device"]

        # Accept a raw string, an {"inputs": ...} / {"text": ...} dict,
        # or anything else coercible to a string
        if isinstance(data, str):
            input_text = data
        elif isinstance(data, dict):
            input_text = data.get("inputs", data.get("text", str(data)))
        else:
            input_text = str(data)

        # Tokenize to a fixed length of 128 tokens
        inputs = tokenizer(
            input_text,
            add_special_tokens=True,
            max_length=128,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Move input tensors to the model's device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Run inference without tracking gradients
        with torch.no_grad():
            if device.type == 'cuda':
                torch.cuda.empty_cache()
            outputs = model(**inputs)
            predictions = torch.softmax(outputs.logits, dim=1)

        # Move class probabilities to CPU and convert to numpy
        predictions = predictions.cpu().numpy()
        return predictions
    except Exception as e:
        logger.error(f"Error during prediction: {str(e)}")
        raise

def input_fn(request_body, request_content_type):
    """Parse the incoming request."""
    if request_content_type == "application/json":
        try:
            data = json.loads(request_body)
        except json.JSONDecodeError:
            # Fall back to treating the body as plain text
            data = request_body
        return data
    else:
        return request_body

def output_fn(prediction, response_content_type):
    """Format the output."""
    if response_content_type == "application/json":
        return json.dumps(prediction.tolist())
    else:
        raise ValueError(f"Unsupported content type: {response_content_type}")
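

# ---------------------------------------------------------------------------
# Local smoke test: a minimal sketch exercising the full handler chain
# (input_fn -> model_fn -> predict_fn -> output_fn) outside SageMaker.
# Assumes HF_MODEL_ID points at a sequence-classification checkpoint;
# "distilbert-base-uncased" below is a hypothetical stand-in (its 2-label
# head is randomly initialized), not the model this endpoint actually serves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    os.environ.setdefault("HF_MODEL_ID", "distilbert-base-uncased")

    model_dict = model_fn(model_dir=None)
    payload = input_fn(json.dumps({"inputs": "A quick test sentence."}),
                       "application/json")
    preds = predict_fn(payload, model_dict)

    # Prints a JSON list of per-class probabilities, e.g. [[0.48, 0.52]]
    print(output_fn(preds, "application/json"))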