CodeJackR commited on
Commit
69e7d30
·
1 Parent(s): b0c81ad

Update to new way of handling model

Browse files
Files changed (1) hide show
  1. handler.py +156 -44
handler.py CHANGED
@@ -5,7 +5,7 @@ import base64
5
  import numpy as np
6
  from PIL import Image
7
  import torch
8
- from transformers import SamModel, SamImageProcessor
9
  from typing import Dict, List, Any
10
  import torch.nn.functional as F
11
 
@@ -20,20 +20,84 @@ class EndpointHandler():
20
  """
21
  try:
22
  # Load the model and processor from the local path
23
- self.model = SamModel.from_pretrained(path).to(device)
24
- self.processor = SamImageProcessor.from_pretrained(path, do_resize=False)
 
 
 
 
 
 
25
  except Exception as e:
26
  # Fallback to loading from a known SAM model if local loading fails
27
- print("Failed to load from local path: {}".format(e))
28
  print("Attempting to load from facebook/sam-vit-base")
29
- self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
30
- self.processor = SamImageProcessor.from_pretrained("facebook/sam-vit-base", do_resize=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  def __call__(self, data):
33
  """
34
  Called on every HTTP request.
35
  Handles both base64-encoded images and PIL images.
36
- Returns a PIL Image object.
37
  """
38
  # 1. Parse and decode the input image
39
  inputs = data.pop("inputs", None)
@@ -51,69 +115,117 @@ class EndpointHandler():
51
  else:
52
  raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")
53
 
54
- # 2. Prepare prompts and process the image
55
- # height, width = img.size[1], img.size[0]
56
- # input_points = [[[width // 2, height // 2]]]
57
- # input_labels = [[1]]
 
58
 
59
- inputs = self.processor(img, return_tensors="pt").to(device)
 
 
 
 
 
 
60
 
61
- # 3. Generate masks
62
- with torch.no_grad():
63
- outputs = self.model(**inputs)
64
 
65
- # 4. Process and select the best mask
66
  try:
 
 
 
67
  # Get predicted masks and scores
68
- predicted_masks = outputs.pred_masks.cpu()
69
- iou_scores = outputs.iou_scores.cpu()[0]
70
 
71
- # Handle different tensor dimensions
72
- if predicted_masks.ndim == 5:
73
- predicted_masks = predicted_masks.squeeze(1)
74
 
75
- # Select the best mask
76
- best_mask_idx = torch.argmax(iou_scores)
77
- best_mask = predicted_masks[0, best_mask_idx, :, :]
 
 
 
 
 
 
 
78
 
79
- # Convert to binary mask (no resizing needed since processor doesn't resize)
80
- mask_binary = (best_mask > 0.0).numpy().astype(np.uint8) * 255
81
-
 
 
 
82
  except Exception as e:
83
- print("Error processing masks: {}".format(e))
84
- # Fallback: create a simple mask
85
- height, width = img.size[1], img.size[0]
86
  mask_binary = np.zeros((height, width), dtype=np.uint8)
87
  center_x, center_y = width // 2, height // 2
88
  size = min(width, height) // 8
89
- mask_binary[center_y-size:center_y+size, center_x-size:center_x+size] = 255
 
 
 
 
 
90
 
91
- # 5. Create and return the output PIL Image
92
- output_img = Image.fromarray(mask_binary)
93
- return [{'score': None, 'label': 'everything', 'mask': output_img}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  def main():
96
  # This main function shows how a client would call the endpoint locally.
97
  input_path = "/Users/rp7/Downloads/test.jpeg"
98
- output_path = "output.png"
 
 
 
 
99
 
100
  # 1. Prepare the payload with a base64-encoded image string
101
  with open(input_path, "rb") as f:
102
  img_bytes = f.read()
103
  img_b64 = base64.b64encode(img_bytes).decode("utf-8")
104
- payload = {"inputs": "data:image/jpeg;base64,{}".format(img_b64)}
105
 
106
- # 2. Instantiate handler and get the PIL Image result
107
  handler = EndpointHandler(path=".")
108
- result = handler(payload)
109
 
110
- # 3. Extract the image from the result and save it
111
- if result and isinstance(result, list) and 'mask' in result[0]:
112
- result_img = result[0]['mask']
113
- result_img.save(output_path)
114
- print("Wrote mask to {}".format(output_path))
 
 
 
115
  else:
116
- print("Failed to get a valid mask from the handler.")
117
 
118
  if __name__ == "__main__":
119
  main()
 
5
  import numpy as np
6
  from PIL import Image
7
  import torch
8
+ from transformers import SamModel, SamProcessor
9
  from typing import Dict, List, Any
10
  import torch.nn.functional as F
11
 
 
20
  """
21
  try:
22
  # Load the model and processor from the local path
23
+ self.model = SamModel.from_pretrained(path).to(device).eval()
24
+ # Load processor with do_resize=False to avoid resizing
25
+ self.processor = SamProcessor.from_pretrained(path)
26
+ # Override the processor's image processor to disable resizing
27
+ self.processor.image_processor.do_resize = False
28
+ self.processor.image_processor.do_rescale = True
29
+ self.processor.image_processor.do_normalize = True
30
+
31
  except Exception as e:
32
  # Fallback to loading from a known SAM model if local loading fails
33
+ print(f"Failed to load from local path: {e}")
34
  print("Attempting to load from facebook/sam-vit-base")
35
+ self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device).eval()
36
+ self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
37
+ # Override the processor's image processor to disable resizing
38
+ self.processor.image_processor.do_resize = False
39
+ self.processor.image_processor.do_rescale = True
40
+ self.processor.image_processor.do_normalize = True
41
+
42
+ def generate_grid_points(self, width, height, points_per_side=32):
43
+ """Generate a grid of points across the image for comprehensive segmentation."""
44
+ points = []
45
+ labels = []
46
+
47
+ # Create a grid of points
48
+ x_coords = np.linspace(0, width - 1, points_per_side, dtype=int)
49
+ y_coords = np.linspace(0, height - 1, points_per_side, dtype=int)
50
+
51
+ for y in y_coords:
52
+ for x in x_coords:
53
+ points.append([x, y])
54
+ labels.append(1) # foreground point
55
+
56
+ return [points], [labels]
57
+
58
+ def filter_masks(self, masks, iou_scores, score_threshold=0.88, stability_score_threshold=0.95):
59
+ """Filter masks based on quality scores and remove duplicates."""
60
+ filtered_masks = []
61
+ filtered_scores = []
62
+
63
+ for i, (mask, score) in enumerate(zip(masks, iou_scores)):
64
+ if score > score_threshold:
65
+ # Calculate stability score (measure of mask quality)
66
+ mask_binary = mask > 0.0
67
+ stability_score = self.calculate_stability_score(mask_binary)
68
+
69
+ if stability_score > stability_score_threshold:
70
+ filtered_masks.append(mask)
71
+ filtered_scores.append(score.item())
72
+
73
+ return filtered_masks, filtered_scores
74
+
75
+ def calculate_stability_score(self, mask):
76
+ """Calculate stability score for a mask."""
77
+ # Simple stability score based on mask coherence
78
+ mask_float = mask.float()
79
+ # Calculate the ratio of the mask area to its bounding box area
80
+ mask_area = torch.sum(mask_float)
81
+ if mask_area == 0:
82
+ return 0.0
83
+
84
+ # Find bounding box
85
+ coords = torch.nonzero(mask_float)
86
+ if len(coords) == 0:
87
+ return 0.0
88
+
89
+ min_y, min_x = torch.min(coords, dim=0)[0]
90
+ max_y, max_x = torch.max(coords, dim=0)[0]
91
+ bbox_area = (max_y - min_y + 1) * (max_x - min_x + 1)
92
+
93
+ stability = mask_area / bbox_area if bbox_area > 0 else 0.0
94
+ return stability.item()
95
 
96
  def __call__(self, data):
97
  """
98
  Called on every HTTP request.
99
  Handles both base64-encoded images and PIL images.
100
+ Returns a list of segment masks.
101
  """
102
  # 1. Parse and decode the input image
103
  inputs = data.pop("inputs", None)
 
115
  else:
116
  raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")
117
 
118
+ # 2. Get image dimensions
119
+ width, height = img.size
120
+
121
+ # 3. Generate grid points for comprehensive segmentation
122
+ input_points, input_labels = self.generate_grid_points(width, height, points_per_side=16)
123
 
124
+ # 4. Process the image and points
125
+ inputs = self.processor(
126
+ img,
127
+ input_points=input_points,
128
+ input_labels=input_labels,
129
+ return_tensors="pt"
130
+ ).to(device)
131
 
132
+ # 5. Generate masks
133
+ all_masks = []
134
+ all_scores = []
135
 
 
136
  try:
137
+ with torch.no_grad():
138
+ outputs = self.model(**inputs)
139
+
140
  # Get predicted masks and scores
141
+ predicted_masks = outputs.pred_masks.cpu() # Shape: [batch, num_queries, num_masks_per_query, H, W]
142
+ iou_scores = outputs.iou_scores.cpu() # Shape: [batch, num_queries, num_masks_per_query]
143
 
144
+ # Process masks from all queries
145
+ batch_size, num_queries, num_masks_per_query = predicted_masks.shape[:3]
 
146
 
147
+ for query_idx in range(num_queries):
148
+ query_masks = predicted_masks[0, query_idx] # [num_masks_per_query, H, W]
149
+ query_scores = iou_scores[0, query_idx] # [num_masks_per_query]
150
+
151
+ # Select best mask for this query
152
+ best_mask_idx = torch.argmax(query_scores)
153
+ if query_scores[best_mask_idx] > 0.5: # Only keep high-quality masks
154
+ best_mask = query_masks[best_mask_idx]
155
+ all_masks.append(best_mask)
156
+ all_scores.append(query_scores[best_mask_idx])
157
 
158
+ # Filter and deduplicate masks
159
+ if all_masks:
160
+ filtered_masks, filtered_scores = self.filter_masks(all_masks, all_scores)
161
+ else:
162
+ filtered_masks, filtered_scores = [], []
163
+
164
  except Exception as e:
165
+ print(f"Error processing masks: {e}")
166
+ # Fallback: create a simple center mask
 
167
  mask_binary = np.zeros((height, width), dtype=np.uint8)
168
  center_x, center_y = width // 2, height // 2
169
  size = min(width, height) // 8
170
+ y_start, y_end = max(0, center_y-size), min(height, center_y+size)
171
+ x_start, x_end = max(0, center_x-size), min(width, center_x+size)
172
+ mask_binary[y_start:y_end, x_start:x_end] = 255
173
+
174
+ output_img = Image.fromarray(mask_binary)
175
+ return [{'score': 0.5, 'label': 'fallback_segment', 'mask': output_img}]
176
 
177
+ # 6. Convert masks to PIL Images and prepare results
178
+ results = []
179
+ for i, (mask, score) in enumerate(zip(filtered_masks, filtered_scores)):
180
+ # Convert to binary mask
181
+ mask_binary = (mask > 0.0).numpy().astype(np.uint8) * 255
182
+
183
+ # Create PIL Image
184
+ output_img = Image.fromarray(mask_binary)
185
+
186
+ results.append({
187
+ 'score': float(score),
188
+ 'label': f'segment_{i}',
189
+ 'mask': output_img
190
+ })
191
+
192
+ # If no segments found, return a fallback
193
+ if not results:
194
+ mask_binary = np.zeros((height, width), dtype=np.uint8)
195
+ output_img = Image.fromarray(mask_binary)
196
+ results.append({'score': 0.0, 'label': 'no_segments', 'mask': output_img})
197
+
198
+ return results
199
 
200
  def main():
201
  # This main function shows how a client would call the endpoint locally.
202
  input_path = "/Users/rp7/Downloads/test.jpeg"
203
+ output_dir = "output_masks"
204
+
205
+ # Create output directory
206
+ import os
207
+ os.makedirs(output_dir, exist_ok=True)
208
 
209
  # 1. Prepare the payload with a base64-encoded image string
210
  with open(input_path, "rb") as f:
211
  img_bytes = f.read()
212
  img_b64 = base64.b64encode(img_bytes).decode("utf-8")
213
+ payload = {"inputs": f"data:image/jpeg;base64,{img_b64}"}
214
 
215
+ # 2. Instantiate handler and get the result
216
  handler = EndpointHandler(path=".")
217
+ results = handler(payload)
218
 
219
+ # 3. Save all masks
220
+ if results and isinstance(results, list):
221
+ print(f"Found {len(results)} segments")
222
+ for i, result in enumerate(results):
223
+ if 'mask' in result:
224
+ output_path = os.path.join(output_dir, f"segment_{i}_score_{result['score']:.3f}.png")
225
+ result['mask'].save(output_path)
226
+ print(f"Saved {result['label']} (score: {result['score']:.3f}) to {output_path}")
227
  else:
228
+ print("Failed to get valid masks from the handler.")
229
 
230
  if __name__ == "__main__":
231
  main()