megiddo commited on
Commit
313d602
·
verified ·
1 Parent(s): 38778af

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +36 -62
model.py CHANGED
@@ -1,66 +1,40 @@
1
- from flask import Flask, request, jsonify
2
- from flask_cors import CORS
3
- from PIL import Image
4
  import torch
5
- from torchvision import transforms, models
6
-
7
- # Initialize Flask app
8
- app = Flask(__name__)
9
-
10
- # Enable CORS
11
- CORS(app)
12
-
13
- # Load the trained model
14
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
-
16
- # Define the model architecture
17
- model = models.resnet152()
18
- model.fc = torch.nn.Linear(model.fc.in_features, 26) # Adjust for the number of classes
19
- model.load_state_dict(torch.load("trained_model.pth", map_location=device))
20
- model = model.to(device)
21
- model.eval()
22
-
23
- # Define preprocessing for the input image
24
- preprocess = transforms.Compose([
25
- transforms.Resize((224, 224)),
26
- transforms.ToTensor(),
27
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
28
- ])
29
-
30
- # Class labels (replace with your dataset's classes)
31
- CLASS_LABELS = [
32
- "bluebell", "buttercup", "colts_foot", "corn_poppy", "cowslip",
33
- "crocus", "daffodil", "daisy", "dandelion", "foxglove",
34
- "fritillary", "geranium", "hibiscus", "iris", "lily_valley",
35
- "pansy", "petunia", "rose", "snowdrop", "sunflower",
36
- "tigerlily", "tulip", "wallflower", "water_lily", "wild_tulip",
37
- "windflower"
38
- ]
39
-
40
- @app.route("/predict", methods=["POST"])
41
- def predict():
42
- if "file" not in request.files:
43
- return jsonify({"error": "No file uploaded"}), 400
44
-
45
- file = request.files["file"]
46
-
47
- try:
48
- # Load and preprocess the image
49
- image = Image.open(file.stream).convert("RGB")
50
- input_tensor = preprocess(image).unsqueeze(0).to(device)
51
-
52
- # Perform inference
53
  with torch.no_grad():
54
- outputs = model(input_tensor)
55
- _, predicted_class = torch.max(outputs, 1)
56
-
57
- predicted_label = CLASS_LABELS[predicted_class.item()]
58
- print(f"Predicted class: {predicted_label}")
59
- return jsonify({"predicted_class": predicted_label})
60
 
61
- except Exception as e:
62
- return jsonify({"error": f"Error during prediction: {str(e)}"}), 500
63
 
64
- # Run the app (Hugging Face Spaces requires `app.run()` here)
65
- if __name__ == "__main__":
66
- app.run(host="0.0.0.0", port=8080)
 
 
 
 
 
 
1
  import torch
2
+ from torchvision import models, transforms
3
+ from PIL import Image
4
+ import json
5
+
6
+ # Load model
7
+ class CustomResNet:
8
+ def __init__(self, model_path, num_classes):
9
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ self.model = models.resnet152(pretrained=False)
11
+ self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)
12
+ self.model.load_state_dict(torch.load(model_path, map_location=self.device))
13
+ self.model.to(self.device)
14
+ self.model.eval()
15
+
16
+ # Preprocessing
17
+ self.preprocess = transforms.Compose([
18
+ transforms.Resize((224, 224)),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
21
+ ])
22
+
23
+ def predict(self, image_bytes):
24
+ # Load and preprocess image
25
+ image = Image.open(image_bytes).convert("RGB")
26
+ tensor = self.preprocess(image).unsqueeze(0).to(self.device)
27
+
28
+ # Make prediction
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  with torch.no_grad():
30
+ outputs = self.model(tensor)
31
+ _, predicted = torch.max(outputs, 1)
32
+
33
+ return predicted.item()
 
 
34
 
 
 
35
 
36
+ # API function
37
+ def load_model():
38
+ with open("config.json", "r") as f:
39
+ config = json.load(f)
40
+ return CustomResNet("trained_model.pth", config["num_labels"])