CodeJackR
committed on
Commit
·
69e7d30
1
Parent(s):
b0c81ad
Update to new way of handling model
Browse files- handler.py +156 -44
handler.py
CHANGED
@@ -5,7 +5,7 @@ import base64
|
|
5 |
import numpy as np
|
6 |
from PIL import Image
|
7 |
import torch
|
8 |
-
from transformers import SamModel,
|
9 |
from typing import Dict, List, Any
|
10 |
import torch.nn.functional as F
|
11 |
|
@@ -20,20 +20,84 @@ class EndpointHandler():
|
|
20 |
"""
|
21 |
try:
|
22 |
# Load the model and processor from the local path
|
23 |
-
self.model = SamModel.from_pretrained(path).to(device)
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
except Exception as e:
|
26 |
# Fallback to loading from a known SAM model if local loading fails
|
27 |
-
print("Failed to load from local path: {}"
|
28 |
print("Attempting to load from facebook/sam-vit-base")
|
29 |
-
self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
|
30 |
-
self.processor =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
def __call__(self, data):
|
33 |
"""
|
34 |
Called on every HTTP request.
|
35 |
Handles both base64-encoded images and PIL images.
|
36 |
-
Returns a
|
37 |
"""
|
38 |
# 1. Parse and decode the input image
|
39 |
inputs = data.pop("inputs", None)
|
@@ -51,69 +115,117 @@ class EndpointHandler():
|
|
51 |
else:
|
52 |
raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")
|
53 |
|
54 |
-
# 2.
|
55 |
-
|
56 |
-
|
57 |
-
#
|
|
|
58 |
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
-
#
|
62 |
-
|
63 |
-
|
64 |
|
65 |
-
# 4. Process and select the best mask
|
66 |
try:
|
|
|
|
|
|
|
67 |
# Get predicted masks and scores
|
68 |
-
predicted_masks = outputs.pred_masks.cpu()
|
69 |
-
iou_scores = outputs.iou_scores.cpu()[
|
70 |
|
71 |
-
#
|
72 |
-
|
73 |
-
predicted_masks = predicted_masks.squeeze(1)
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
-
#
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
82 |
except Exception as e:
|
83 |
-
print("Error processing masks: {}"
|
84 |
-
# Fallback: create a simple mask
|
85 |
-
height, width = img.size[1], img.size[0]
|
86 |
mask_binary = np.zeros((height, width), dtype=np.uint8)
|
87 |
center_x, center_y = width // 2, height // 2
|
88 |
size = min(width, height) // 8
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
-
#
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
def main():
|
96 |
# This main function shows how a client would call the endpoint locally.
|
97 |
input_path = "/Users/rp7/Downloads/test.jpeg"
|
98 |
-
|
|
|
|
|
|
|
|
|
99 |
|
100 |
# 1. Prepare the payload with a base64-encoded image string
|
101 |
with open(input_path, "rb") as f:
|
102 |
img_bytes = f.read()
|
103 |
img_b64 = base64.b64encode(img_bytes).decode("utf-8")
|
104 |
-
payload = {"inputs": "data:image/jpeg;base64,{}"
|
105 |
|
106 |
-
# 2. Instantiate handler and get the
|
107 |
handler = EndpointHandler(path=".")
|
108 |
-
|
109 |
|
110 |
-
# 3.
|
111 |
-
if
|
112 |
-
|
113 |
-
|
114 |
-
|
|
|
|
|
|
|
115 |
else:
|
116 |
-
print("Failed to get
|
117 |
|
118 |
if __name__ == "__main__":
|
119 |
main()
|
|
|
5 |
import numpy as np
|
6 |
from PIL import Image
|
7 |
import torch
|
8 |
+
from transformers import SamModel, SamProcessor
|
9 |
from typing import Dict, List, Any
|
10 |
import torch.nn.functional as F
|
11 |
|
|
|
20 |
"""
|
21 |
try:
|
22 |
# Load the model and processor from the local path
|
23 |
+
self.model = SamModel.from_pretrained(path).to(device).eval()
|
24 |
+
# Load processor with do_resize=False to avoid resizing
|
25 |
+
self.processor = SamProcessor.from_pretrained(path)
|
26 |
+
# Override the processor's image processor to disable resizing
|
27 |
+
self.processor.image_processor.do_resize = False
|
28 |
+
self.processor.image_processor.do_rescale = True
|
29 |
+
self.processor.image_processor.do_normalize = True
|
30 |
+
|
31 |
except Exception as e:
|
32 |
# Fallback to loading from a known SAM model if local loading fails
|
33 |
+
print(f"Failed to load from local path: {e}")
|
34 |
print("Attempting to load from facebook/sam-vit-base")
|
35 |
+
self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device).eval()
|
36 |
+
self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
|
37 |
+
# Override the processor's image processor to disable resizing
|
38 |
+
self.processor.image_processor.do_resize = False
|
39 |
+
self.processor.image_processor.do_rescale = True
|
40 |
+
self.processor.image_processor.do_normalize = True
|
41 |
+
|
42 |
+
def generate_grid_points(self, width, height, points_per_side=32):
    """Generate an evenly spaced grid of foreground prompt points.

    The grid spans the full image: x runs from 0 to ``width - 1`` and y
    from 0 to ``height - 1``, each sampled at ``points_per_side``
    positions (truncated to integers). Points are emitted row by row
    (y outer, x inner).

    Args:
        width: Image width in pixels.
        height: Image height in pixels.
        points_per_side: Number of samples along each axis; must be >= 1.

    Returns:
        A ``(points, labels)`` tuple. ``points`` is ``[[[x, y], ...]]`` and
        ``labels`` is ``[[1, ...]]`` (every point marked as foreground),
        each wrapped in one outer list.

    Raises:
        ValueError: If ``points_per_side`` is smaller than 1.
    """
    if points_per_side < 1:
        # An empty grid would only fail later with a less helpful error.
        raise ValueError("points_per_side must be at least 1")

    # .tolist() converts numpy scalars to plain Python ints so the nested
    # lists do not depend on numpy scalar types.
    x_coords = np.linspace(0, width - 1, points_per_side, dtype=int).tolist()
    y_coords = np.linspace(0, height - 1, points_per_side, dtype=int).tolist()

    points = [[x, y] for y in y_coords for x in x_coords]
    labels = [1] * len(points)  # 1 == foreground point

    # NOTE(review): a 3-level nested list is presumably interpreted by the
    # SAM processor as ONE prompt containing all points rather than one
    # prompt per point -- confirm against the transformers SAM docs.
    return [points], [labels]
|
57 |
+
|
58 |
+
def filter_masks(self, masks, iou_scores, score_threshold=0.88, stability_score_threshold=0.95):
    """Keep only the masks that pass both quality thresholds.

    A mask is kept when its score is strictly greater than
    ``score_threshold`` AND the value returned by
    ``self.calculate_stability_score`` for its binarised form
    (``mask > 0.0``) is strictly greater than
    ``stability_score_threshold``. No de-duplication is performed.

    Args:
        masks: Iterable of mask logits (anything supporting ``> 0.0``).
        iou_scores: Iterable of scores aligned with ``masks``; each entry
            may be a 0-dim tensor or a plain number.
        score_threshold: Minimum (exclusive) score to keep a mask.
        stability_score_threshold: Minimum (exclusive) stability value.

    Returns:
        ``(filtered_masks, filtered_scores)`` where the masks are the
        original (non-binarised) objects and the scores are Python floats.
    """
    filtered_masks = []
    filtered_scores = []

    for mask, score in zip(masks, iou_scores):
        # float() accepts both 0-dim tensors and plain numbers, whereas
        # .item() would fail on a Python float.
        score_value = float(score)
        if score_value <= score_threshold:
            continue

        stability_score = self.calculate_stability_score(mask > 0.0)
        if stability_score > stability_score_threshold:
            filtered_masks.append(mask)
            filtered_scores.append(score_value)

    return filtered_masks, filtered_scores
|
74 |
+
|
75 |
+
def calculate_stability_score(self, mask):
    """Return how much of its bounding box a mask fills.

    Despite the name, this is a simple coherence heuristic: the number of
    set pixels divided by the area of the tight axis-aligned bounding box
    around them. A solid rectangle scores 1.0; sparse or irregular masks
    score lower.

    Args:
        mask: A 2-D torch tensor (the bounding box is unpacked as
            ``(y, x)``); non-zero entries count as part of the mask.

    Returns:
        A Python float in ``[0.0, 1.0]``; ``0.0`` for an empty mask.
    """
    mask_float = mask.float()
    mask_area = torch.sum(mask_float)
    if mask_area == 0:
        # Empty mask: no bounding box exists.
        return 0.0

    # Row/column indices of every set pixel; non-empty at this point, so
    # the bounding box always has an area of at least 1.
    coords = torch.nonzero(mask_float)
    min_y, min_x = torch.min(coords, dim=0)[0]
    max_y, max_x = torch.max(coords, dim=0)[0]
    bbox_area = (max_y - min_y + 1) * (max_x - min_x + 1)

    return (mask_area / bbox_area).item()
|
95 |
|
96 |
def __call__(self, data):
|
97 |
"""
|
98 |
Called on every HTTP request.
|
99 |
Handles both base64-encoded images and PIL images.
|
100 |
+
Returns a list of segment masks.
|
101 |
"""
|
102 |
# 1. Parse and decode the input image
|
103 |
inputs = data.pop("inputs", None)
|
|
|
115 |
else:
|
116 |
raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")
|
117 |
|
118 |
+
# 2. Get image dimensions
|
119 |
+
width, height = img.size
|
120 |
+
|
121 |
+
# 3. Generate grid points for comprehensive segmentation
|
122 |
+
input_points, input_labels = self.generate_grid_points(width, height, points_per_side=16)
|
123 |
|
124 |
+
# 4. Process the image and points
|
125 |
+
inputs = self.processor(
|
126 |
+
img,
|
127 |
+
input_points=input_points,
|
128 |
+
input_labels=input_labels,
|
129 |
+
return_tensors="pt"
|
130 |
+
).to(device)
|
131 |
|
132 |
+
# 5. Generate masks
|
133 |
+
all_masks = []
|
134 |
+
all_scores = []
|
135 |
|
|
|
136 |
try:
|
137 |
+
with torch.no_grad():
|
138 |
+
outputs = self.model(**inputs)
|
139 |
+
|
140 |
# Get predicted masks and scores
|
141 |
+
predicted_masks = outputs.pred_masks.cpu() # Shape: [batch, num_queries, num_masks_per_query, H, W]
|
142 |
+
iou_scores = outputs.iou_scores.cpu() # Shape: [batch, num_queries, num_masks_per_query]
|
143 |
|
144 |
+
# Process masks from all queries
|
145 |
+
batch_size, num_queries, num_masks_per_query = predicted_masks.shape[:3]
|
|
|
146 |
|
147 |
+
for query_idx in range(num_queries):
|
148 |
+
query_masks = predicted_masks[0, query_idx] # [num_masks_per_query, H, W]
|
149 |
+
query_scores = iou_scores[0, query_idx] # [num_masks_per_query]
|
150 |
+
|
151 |
+
# Select best mask for this query
|
152 |
+
best_mask_idx = torch.argmax(query_scores)
|
153 |
+
if query_scores[best_mask_idx] > 0.5: # Only keep high-quality masks
|
154 |
+
best_mask = query_masks[best_mask_idx]
|
155 |
+
all_masks.append(best_mask)
|
156 |
+
all_scores.append(query_scores[best_mask_idx])
|
157 |
|
158 |
+
# Filter and deduplicate masks
|
159 |
+
if all_masks:
|
160 |
+
filtered_masks, filtered_scores = self.filter_masks(all_masks, all_scores)
|
161 |
+
else:
|
162 |
+
filtered_masks, filtered_scores = [], []
|
163 |
+
|
164 |
except Exception as e:
|
165 |
+
print(f"Error processing masks: {e}")
|
166 |
+
# Fallback: create a simple center mask
|
|
|
167 |
mask_binary = np.zeros((height, width), dtype=np.uint8)
|
168 |
center_x, center_y = width // 2, height // 2
|
169 |
size = min(width, height) // 8
|
170 |
+
y_start, y_end = max(0, center_y-size), min(height, center_y+size)
|
171 |
+
x_start, x_end = max(0, center_x-size), min(width, center_x+size)
|
172 |
+
mask_binary[y_start:y_end, x_start:x_end] = 255
|
173 |
+
|
174 |
+
output_img = Image.fromarray(mask_binary)
|
175 |
+
return [{'score': 0.5, 'label': 'fallback_segment', 'mask': output_img}]
|
176 |
|
177 |
+
# 6. Convert masks to PIL Images and prepare results
|
178 |
+
results = []
|
179 |
+
for i, (mask, score) in enumerate(zip(filtered_masks, filtered_scores)):
|
180 |
+
# Convert to binary mask
|
181 |
+
mask_binary = (mask > 0.0).numpy().astype(np.uint8) * 255
|
182 |
+
|
183 |
+
# Create PIL Image
|
184 |
+
output_img = Image.fromarray(mask_binary)
|
185 |
+
|
186 |
+
results.append({
|
187 |
+
'score': float(score),
|
188 |
+
'label': f'segment_{i}',
|
189 |
+
'mask': output_img
|
190 |
+
})
|
191 |
+
|
192 |
+
# If no segments found, return a fallback
|
193 |
+
if not results:
|
194 |
+
mask_binary = np.zeros((height, width), dtype=np.uint8)
|
195 |
+
output_img = Image.fromarray(mask_binary)
|
196 |
+
results.append({'score': 0.0, 'label': 'no_segments', 'mask': output_img})
|
197 |
+
|
198 |
+
return results
|
199 |
|
200 |
def main(input_path="/Users/rp7/Downloads/test.jpeg", output_dir="output_masks"):
    """Run the handler locally the way a client would call the endpoint.

    Reads an image from disk, sends it to ``EndpointHandler`` as a
    base64 data URI, and writes every returned mask as a PNG file.

    Args:
        input_path: Path of the image to segment.
        output_dir: Directory that receives the mask PNGs; created if
            it does not exist.
    """
    import os

    os.makedirs(output_dir, exist_ok=True)

    # 1. Prepare the payload with a base64-encoded image string
    with open(input_path, "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")
    payload = {"inputs": f"data:image/jpeg;base64,{img_b64}"}

    # 2. Instantiate handler and get the result
    handler = EndpointHandler(path=".")
    results = handler(payload)

    # 3. Save all masks
    if results and isinstance(results, list):
        print(f"Found {len(results)} segments")
        for i, result in enumerate(results):
            if 'mask' in result:
                output_path = os.path.join(output_dir, f"segment_{i}_score_{result['score']:.3f}.png")
                result['mask'].save(output_path)
                print(f"Saved {result['label']} (score: {result['score']:.3f}) to {output_path}")
    else:
        print("Failed to get valid masks from the handler.")
|
229 |
|
230 |
if __name__ == "__main__":
|
231 |
main()
|