"""Image classification inference with a fine-tuned ResNet-152 checkpoint."""

import io
import json

import torch
from PIL import Image
from torchvision import models, transforms


# Load model
class CustomResNet:
    """ResNet-152 classifier whose weights come from a local ``state_dict`` file.

    The network's final ``fc`` layer is replaced by a ``num_classes``-way linear
    layer before the checkpoint is loaded, so the checkpoint must have been
    saved from a model with the same head shape.
    """

    def __init__(self, model_path, num_classes):
        """Build the network, load weights from *model_path*, and switch to eval mode.

        Args:
            model_path: Path (or file-like object) accepted by ``torch.load`` that
                holds the model's ``state_dict``.
            num_classes: Number of outputs of the replacement ``fc`` layer.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # No pretrained weights are requested; all weights come from the checkpoint.
        # NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13 in favour
        # of `weights=None`; kept here for compatibility with older torchvision.
        self.model = models.resnet152(pretrained=False)
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
        # (consider weights_only=True on torch versions that support it).
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        # Inference mode: freezes batch-norm statistics and disables dropout.
        self.model.eval()

        # Preprocessing: fixed 224x224 resize, then per-channel normalisation with
        # the commonly used ImageNet mean/std values.
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def predict(self, image_bytes):
        """Return the index of the highest-scoring class for a single image.

        Args:
            image_bytes: Raw encoded image data (``bytes``, ``bytearray`` or
                ``memoryview``), or anything ``PIL.Image.open`` accepts
                (a filesystem path or a binary file-like object).

        Returns:
            The predicted class index as a Python ``int``.
        """
        # PIL.Image.open needs a path or file-like object; wrap raw bytes so the
        # encoded image data the parameter name suggests can be passed directly.
        if isinstance(image_bytes, (bytes, bytearray, memoryview)):
            image_bytes = io.BytesIO(image_bytes)

        # The context manager releases any file PIL opened itself (path input);
        # convert() loads the pixel data before the file is closed.
        with Image.open(image_bytes) as img:
            image = img.convert("RGB")
        # Add a batch dimension of 1 and move to the model's device.
        tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # Make prediction without building an autograd graph.
        with torch.no_grad():
            outputs = self.model(tensor)
        _, predicted = torch.max(outputs, 1)
        return predicted.item()


# API function
def load_model(config_path="config.json", model_path="trained_model.pth"):
    """Create a ``CustomResNet`` from a JSON config and a weights checkpoint.

    Args:
        config_path: JSON file that must contain a ``"num_labels"`` entry.
        model_path: Checkpoint passed to ``CustomResNet``.

    Returns:
        A ready-to-use ``CustomResNet`` instance.

    Raises:
        OSError: If the config file cannot be opened.
        KeyError: If the config has no ``"num_labels"`` key.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    return CustomResNet(model_path, config["num_labels"])