CodeJackR
commited on
Commit
·
16a5f8c
1
Parent(s):
2ea60f3
handle boxes and points for SAM input
Browse files- handler.py +20 -5
handler.py
CHANGED
@@ -62,8 +62,15 @@ class EndpointHandler():
|
|
62 |
# Process the image
|
63 |
img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
64 |
|
65 |
-
#
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
# Generate masks using the model
|
69 |
with torch.no_grad():
|
@@ -76,9 +83,17 @@ class EndpointHandler():
|
|
76 |
inputs["reshaped_input_sizes"].cpu()
|
77 |
)[0]
|
78 |
|
79 |
-
# Convert the
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
# Convert result to base64
|
84 |
out = io.BytesIO()
|
|
|
62 |
# Process the image
|
63 |
img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
64 |
|
65 |
+
# SAM requires input prompts, so we'll generate a center point prompt
|
66 |
+
height, width = img.size[1], img.size[0] # PIL returns (width, height)
|
67 |
+
|
68 |
+
# Create a center point prompt for automatic segmentation
|
69 |
+
input_points = [[[width // 2, height // 2]]] # Center point
|
70 |
+
input_labels = [[1]] # Positive prompt
|
71 |
+
|
72 |
+
# Prepare inputs for the model with prompts
|
73 |
+
inputs = self.processor(img, input_points=input_points, input_labels=input_labels, return_tensors="pt")
|
74 |
|
75 |
# Generate masks using the model
|
76 |
with torch.no_grad():
|
|
|
83 |
inputs["reshaped_input_sizes"].cpu()
|
84 |
)[0]
|
85 |
|
86 |
+
# Convert the best mask to a binary mask
|
87 |
+
# SAM returns multiple masks, take the first one
|
88 |
+
if len(masks) > 0:
|
89 |
+
mask = masks[0].squeeze().numpy()
|
90 |
+
mask_binary = (mask > 0.5).astype(np.uint8) * 255
|
91 |
+
else:
|
92 |
+
# Fallback: create a simple center mask
|
93 |
+
mask_binary = np.zeros((height, width), dtype=np.uint8)
|
94 |
+
center_x, center_y = width // 2, height // 2
|
95 |
+
size = min(width, height) // 8
|
96 |
+
mask_binary[center_y-size:center_y+size, center_x-size:center_x+size] = 255
|
97 |
|
98 |
# Convert result to base64
|
99 |
out = io.BytesIO()
|