CodeJackR commited on
Commit
16a5f8c
·
1 Parent(s): 2ea60f3

handle boxes and points for SAM input

Browse files
Files changed (1) hide show
  1. handler.py +20 -5
handler.py CHANGED
@@ -62,8 +62,15 @@ class EndpointHandler():
62
  # Process the image
63
  img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
64
 
65
- # Prepare inputs for the model
66
- inputs = self.processor(img, return_tensors="pt")
 
 
 
 
 
 
 
67
 
68
  # Generate masks using the model
69
  with torch.no_grad():
@@ -76,9 +83,17 @@ class EndpointHandler():
76
  inputs["reshaped_input_sizes"].cpu()
77
  )[0]
78
 
79
- # Convert the first mask to a binary mask
80
- mask = masks[0].squeeze().numpy()
81
- mask_binary = (mask > 0.0).astype(np.uint8) * 255
 
 
 
 
 
 
 
 
82
 
83
  # Convert result to base64
84
  out = io.BytesIO()
 
62
  # Process the image
63
  img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
64
 
65
+ # SAM requires input prompts, so we'll generate a center point prompt
66
+ height, width = img.size[1], img.size[0] # PIL returns (width, height)
67
+
68
+ # Create a center point prompt for automatic segmentation
69
+ input_points = [[[width // 2, height // 2]]] # Center point
70
+ input_labels = [[1]] # Positive prompt
71
+
72
+ # Prepare inputs for the model with prompts
73
+ inputs = self.processor(img, input_points=input_points, input_labels=input_labels, return_tensors="pt")
74
 
75
  # Generate masks using the model
76
  with torch.no_grad():
 
83
  inputs["reshaped_input_sizes"].cpu()
84
  )[0]
85
 
86
+ # Convert the best mask to a binary mask
87
+ # SAM returns multiple masks, take the first one
88
+ if len(masks) > 0:
89
+ mask = masks[0].squeeze().numpy()
90
+ mask_binary = (mask > 0.5).astype(np.uint8) * 255
91
+ else:
92
+ # Fallback: create a simple center mask
93
+ mask_binary = np.zeros((height, width), dtype=np.uint8)
94
+ center_x, center_y = width // 2, height // 2
95
+ size = min(width, height) // 8
96
+ mask_binary[center_y-size:center_y+size, center_x-size:center_x+size] = 255
97
 
98
  # Convert result to base64
99
  out = io.BytesIO()