# handler.py
import base64
import io
import os

import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

# Select the GPU when available, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class EndpointHandler():
    def __init__(self, path=""):
        """
        Called once at startup.
        Load the SAM model and processor using Hugging Face Transformers.
        """
        try:
            # Load the model and processor from the local path
            self.model = SamModel.from_pretrained(path).to(device).eval()
            self.processor = SamProcessor.from_pretrained(path)
        except Exception as e:
            # Fall back to a known SAM checkpoint if local loading fails
            print(f"Failed to load from local path: {e}")
            print("Attempting to load from facebook/sam-vit-base")
            self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device).eval()
            self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
        # Configure the image processor: skip resizing, but keep the
        # rescaling and normalization the model expects
        self.processor.image_processor.do_resize = False
        self.processor.image_processor.do_rescale = True
        self.processor.image_processor.do_normalize = True

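    # Caveat worth flagging: SAM's vision encoder uses fixed 1024x1024
    # positional embeddings, and the image processor still pads inputs to
    # 1024x1024 by default (do_pad=True). With do_resize disabled, images
    # larger than 1024 px on a side should not be expected to work.
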
    def generate_grid_points(self, width, height, points_per_side=32):
        """Generate a grid of point prompts across the image, one query per point."""
        x_coords = np.linspace(0, width - 1, points_per_side, dtype=int)
        y_coords = np.linspace(0, height - 1, points_per_side, dtype=int)
        # SamProcessor expects input_points nested as
        # [batch, point_batch, points_per_query, 2]. Wrapping each grid point
        # in its own sub-list makes every point an independent query, which is
        # what the per-query loop in __call__ assumes.
        points = [[[int(x), int(y)]] for y in y_coords for x in x_coords]
        labels = [[1] for _ in points]  # label 1 = foreground
        return [points], [labels]
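    # For example, a 2x2 grid on a 100x100 image yields
    #   input_points == [[[[0, 0]], [[99, 0]], [[0, 99]], [[99, 99]]]]
    #   input_labels == [[[1], [1], [1], [1]]]
    # i.e. four independent single-point foreground queries for one image.
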
    def filter_masks(self, masks, iou_scores, score_threshold=0.88, stability_score_threshold=0.95):
        """Filter masks by predicted IoU score and a simple stability proxy."""
        filtered_masks = []
        filtered_scores = []
        for mask, score in zip(masks, iou_scores):
            if score > score_threshold:
                # Check mask quality with the stability proxy below
                mask_binary = mask > 0.0
                stability_score = self.calculate_stability_score(mask_binary)
                if stability_score > stability_score_threshold:
                    filtered_masks.append(mask)
                    filtered_scores.append(score.item())
        return filtered_masks, filtered_scores

    def calculate_stability_score(self, mask):
        """Proxy stability score: ratio of mask area to its bounding-box area.

        Note this is not SAM's original stability score (which compares masks
        binarized at offset thresholds); it simply penalizes sparse,
        scattered masks.
        """
        mask_float = mask.float()
        mask_area = torch.sum(mask_float)
        if mask_area == 0:
            return 0.0
        # mask_area > 0 guarantees a non-empty, non-degenerate bounding box
        coords = torch.nonzero(mask_float)
        min_y, min_x = torch.min(coords, dim=0)[0]
        max_y, max_x = torch.max(coords, dim=0)[0]
        bbox_area = (max_y - min_y + 1) * (max_x - min_x + 1)
        return (mask_area / bbox_area).item()
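    # Under this proxy a solid rectangle scores 1.0, while a solid circle
    # scores about pi/4 ~= 0.785 (area pi*r^2 over a 2r x 2r box), so the
    # 0.95 default threshold in filter_masks keeps only masks that nearly
    # fill their bounding box.
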
    def __call__(self, data):
        """
        Called on every HTTP request.
        Handles both base64-encoded images and PIL images.
        Returns a list of segment masks.
        """
        # 1. Parse and decode the input image
        inputs = data.pop("inputs", None)
        if inputs is None:
            raise ValueError("Missing 'inputs' key in the payload.")
        # Accept both pre-loaded PIL Images and base64-encoded strings
        if isinstance(inputs, Image.Image):
            img = inputs.convert("RGB")
        elif isinstance(inputs, str):
            # Strip a "data:image/...;base64," prefix if present
            if inputs.startswith("data:"):
                inputs = inputs.split(",", 1)[1]
            image_bytes = base64.b64decode(inputs)
            img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        else:
            raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")
        # 2. Get image dimensions
        width, height = img.size
        # 3. Generate grid points for comprehensive segmentation
        input_points, input_labels = self.generate_grid_points(width, height, points_per_side=16)
        # 4. Preprocess the image and point prompts
        model_inputs = self.processor(
            img,
            input_points=input_points,
            input_labels=input_labels,
            return_tensors="pt"
        ).to(device)
        # 5. Generate masks
        all_masks = []
        all_scores = []
        try:
            with torch.no_grad():
                outputs = self.model(**model_inputs)
                # Get predicted masks and scores
                predicted_masks = outputs.pred_masks.cpu()  # Shape: [batch, num_queries, num_masks_per_query, H, W]
                iou_scores = outputs.iou_scores.cpu()  # Shape: [batch, num_queries, num_masks_per_query]
            # Process masks from all queries
            num_queries = predicted_masks.shape[1]
            for query_idx in range(num_queries):
                query_masks = predicted_masks[0, query_idx]  # [num_masks_per_query, H, W]
                query_scores = iou_scores[0, query_idx]  # [num_masks_per_query]
                # Select the best candidate mask for this query
                best_mask_idx = torch.argmax(query_scores)
                if query_scores[best_mask_idx] > 0.5:  # Only keep high-quality masks
                    all_masks.append(query_masks[best_mask_idx])
                    all_scores.append(query_scores[best_mask_idx])
            # Filter masks on IoU score and stability
            if all_masks:
                filtered_masks, filtered_scores = self.filter_masks(all_masks, all_scores)
            else:
                filtered_masks, filtered_scores = [], []
        except Exception as e:
            print(f"Error processing masks: {e}")
            # Fallback: return a simple square mask at the image center
            mask_binary = np.zeros((height, width), dtype=np.uint8)
            center_x, center_y = width // 2, height // 2
            size = min(width, height) // 8
            y_start, y_end = max(0, center_y - size), min(height, center_y + size)
            x_start, x_end = max(0, center_x - size), min(width, center_x + size)
            mask_binary[y_start:y_end, x_start:x_end] = 255
            output_img = Image.fromarray(mask_binary)
            return [{'score': 0.5, 'label': 'fallback_segment', 'mask': output_img}]
        # 6. Upscale the kept masks and convert them to PIL Images
        results = []
        if filtered_masks:
            # pred_masks come out at the model's low mask resolution, so map
            # them back to the original image size before thresholding
            stacked = torch.stack(filtered_masks)[None, :, None]  # [1, n, 1, h, w]
            upscaled = self.processor.image_processor.post_process_masks(
                stacked,
                model_inputs["original_sizes"].cpu(),
                model_inputs["reshaped_input_sizes"].cpu(),
                binarize=False,
            )[0][:, 0]  # [n, height, width]
            for i, (mask, score) in enumerate(zip(upscaled, filtered_scores)):
                # Threshold the logits into a 0/255 binary mask
                mask_binary = (mask > 0.0).numpy().astype(np.uint8) * 255
                results.append({
                    'score': float(score),
                    'label': f'segment_{i}',
                    'mask': Image.fromarray(mask_binary)
                })
        # If no segments were found, return an empty fallback mask
        if not results:
            mask_binary = np.zeros((height, width), dtype=np.uint8)
            output_img = Image.fromarray(mask_binary)
            results.append({'score': 0.0, 'label': 'no_segments', 'mask': output_img})
        return results
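
# A minimal visualization sketch, not part of the endpoint contract: it
# assumes the result format produced by EndpointHandler.__call__ above
# (masks as PIL images matching the source image size) and blends each mask
# over the image in a random color for quick inspection.
def overlay_masks(image, results, alpha=0.5):
    """Blend binary masks from handler results over `image` for debugging."""
    overlay = np.array(image.convert("RGB"), dtype=np.float32)
    rng = np.random.default_rng(0)  # fixed seed for reproducible colors
    for result in results:
        mask = np.array(result['mask']) > 0
        color = rng.integers(0, 256, size=3).astype(np.float32)
        overlay[mask] = (1 - alpha) * overlay[mask] + alpha * color
    return Image.fromarray(overlay.astype(np.uint8))
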
def main():
    # This main function shows how a client would call the endpoint locally.
    input_path = "/Users/rp7/Downloads/test.jpeg"
    output_dir = "output_masks"
    os.makedirs(output_dir, exist_ok=True)
    # 1. Prepare the payload with a base64-encoded image string
    with open(input_path, "rb") as f:
        img_bytes = f.read()
    img_b64 = base64.b64encode(img_bytes).decode("utf-8")
    payload = {"inputs": f"data:image/jpeg;base64,{img_b64}"}
    # 2. Instantiate the handler and run it on the payload
    handler = EndpointHandler(path=".")
    results = handler(payload)
    # 3. Save all masks
    if results and isinstance(results, list):
        print(f"Found {len(results)} segments")
        for i, result in enumerate(results):
            if 'mask' in result:
                output_path = os.path.join(output_dir, f"segment_{i}_score_{result['score']:.3f}.png")
                result['mask'].save(output_path)
                print(f"Saved {result['label']} (score: {result['score']:.3f}) to {output_path}")
    else:
        print("Failed to get valid masks from the handler.")


if __name__ == "__main__":
    main()
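
# Once deployed, a client would POST the same payload over HTTP instead of
# instantiating the handler locally. A minimal sketch using `requests`; the
# URL and token below are placeholders, not real values:
#
#   import requests
#   response = requests.post(
#       "https://<your-endpoint>.endpoints.huggingface.cloud",  # placeholder URL
#       headers={"Authorization": "Bearer <HF_TOKEN>"},  # placeholder token
#       json={"inputs": f"data:image/jpeg;base64,{img_b64}"},
#   )
#   print(response.status_code)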