import os
import json
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import logging

logger = logging.getLogger(__name__)

# List available CUDA devices and their names with:
#   python -c "import torch; print('\n'.join([f'{i}: {torch.cuda.get_device_name(i)}' for i in range(torch.cuda.device_count())]))"
# Pin the script to a specific GPU with:
#   CUDA_VISIBLE_DEVICES="1" python script.py

def model_fn(model_dir, context=None):
    """Load the model for inference"""
    try:
        # Prefer HF_MODEL_ID; fall back to the model artifacts unpacked into model_dir
        model_id = os.getenv("HF_MODEL_ID", model_dir)
        
        # Select the first GPU if available, otherwise fall back to CPU
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        logger.info(f"Using device: {device}")
        
        # Load tokenizer and model directly using AutoModelForSequenceClassification
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_id,
            num_labels=2,
            torch_dtype=torch.bfloat16 if device.type == 'cuda' else torch.float32,
            trust_remote_code=True
        )
        
        # Move model to device
        model = model.to(device)
        
        # Enable cuDNN autotuning; speeds up repeated fixed-shape forward passes
        if device.type == 'cuda':
            torch.backends.cudnn.benchmark = True
            
        # Ensure model is in eval mode
        model.eval()
            
        logger.info(f"Model loaded successfully on {device}")
        
        return {
            "model": model,
            "tokenizer": tokenizer,
            "device": device
        }
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise

def predict_fn(data, model_dict):
    """Make a prediction"""
    try:
        logger.info("Starting prediction")
        model = model_dict["model"]
        tokenizer = model_dict["tokenizer"]
        device = model_dict["device"]
        
        # Parse input
        if isinstance(data, str):
            input_text = data
        elif isinstance(data, dict):
            input_text = data.get("inputs", data.get("text", str(data)))
        else:
            input_text = str(data)
            
        # Tokenize input
        inputs = tokenizer(
            input_text,
            add_special_tokens=True,
            max_length=128,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        # Move inputs to device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        # Generate prediction
        with torch.no_grad():
            if device.type == 'cuda':
                torch.cuda.empty_cache()
            
            outputs = model(**inputs)
            predictions = torch.softmax(outputs.logits, dim=1)
        
        # Move predictions to CPU and convert to numpy
        predictions = predictions.cpu().numpy()
        
        return predictions
        
    except Exception as e:
        logger.error(f"Error during prediction: {str(e)}")
        raise

def input_fn(request_body, request_content_type):
    """Parse input request"""
    if request_content_type == "application/json":
        try:
            data = json.loads(request_body)
        except json.JSONDecodeError:
            # Fall back to the raw body if it is not valid JSON
            data = request_body
        return data
    else:
        return request_body

def output_fn(prediction, response_content_type):
    """Format the output"""
    if response_content_type == "application/json":
        return json.dumps(prediction.tolist())
    else:
        raise ValueError(f"Unsupported content type: {response_content_type}")
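
# A minimal local smoke test for the handler chain (a sketch, not part of the deployed
# SageMaker flow). It assumes network access to the Hugging Face Hub and uses
# "distilbert-base-uncased-finetuned-sst-2-english" purely as an illustrative two-label
# checkpoint; swap in whatever HF_MODEL_ID the endpoint is actually configured with.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    os.environ.setdefault("HF_MODEL_ID", "distilbert-base-uncased-finetuned-sst-2-english")

    # model_dir is only used as a fallback when HF_MODEL_ID is unset; "." is a placeholder
    model_dict = model_fn(".")

    # Exercise the same path SageMaker would: input_fn -> predict_fn -> output_fn
    body = json.dumps({"inputs": "This endpoint seems to be working nicely."})
    data = input_fn(body, "application/json")
    prediction = predict_fn(data, model_dict)
    print(output_fn(prediction, "application/json"))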