|
|
|
|
|
import io |
|
import base64 |
|
import numpy as np |
|
from PIL import Image |
|
import torch |
|
from transformers import SamModel, SamProcessor |
|
from typing import Dict, List, Any |
|
import torch.nn.functional as F |
|
|
|
|
|
# Pick the compute device once, at import time: CUDA when PyTorch reports a
# usable GPU, the CPU otherwise. The handler moves the model and every request
# batch onto this device.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
|
|
|
class EndpointHandler():
    """Inference-endpoint handler that segments an image with SAM.

    The handler prompts the model with a regular grid of foreground points
    (one independent prompt per grid point), keeps the best-scoring mask of
    every prompt, filters the candidates by quality, and returns the
    survivors as PIL images at the resolution of the input image.

    Relies on the module-level ``device`` for model and tensor placement.
    """

    # Checkpoint used when nothing can be loaded from the given path.
    _FALLBACK_CHECKPOINT = "facebook/sam-vit-base"

    # Number of point prompts decoded per forward pass. Bounds the memory the
    # mask decoder needs, since it works on one copy of the image embedding
    # per prompt.
    _POINTS_PER_BATCH = 64

    def __init__(self, path=""):
        """
        Called once at startup.
        Load the SAM model using Hugging Face Transformers.

        Args:
            path: Local directory (or hub id) holding the model and processor
                files. If loading from it fails for any reason, the public
                ``facebook/sam-vit-base`` checkpoint is loaded instead.
        """
        try:
            self._load(path)
        except Exception as e:
            # Best-effort start-up: report the problem and fall back to the
            # public checkpoint rather than refusing to start.
            print(f"Failed to load from local path: {e}")
            print("Attempting to load from facebook/sam-vit-base")
            self._load(self._FALLBACK_CHECKPOINT)

    def _load(self, source):
        """Load model and processor from *source* and configure preprocessing."""
        self.model = SamModel.from_pretrained(source).to(device).eval()
        self.processor = SamProcessor.from_pretrained(source)

        image_processor = self.processor.image_processor
        # Resizing must stay enabled: SamProcessor rescales prompt coordinates
        # to the resized image, so with resizing off the points no longer land
        # on the pixels they were generated for, and images larger than the
        # pad size cannot be padded at all.
        image_processor.do_resize = True
        image_processor.do_rescale = True
        image_processor.do_normalize = True

    def generate_grid_points(self, width, height, points_per_side=32):
        """Generate a grid of points across the image for comprehensive segmentation.

        Args:
            width: Image width in pixels.
            height: Image height in pixels.
            points_per_side: Number of grid positions requested per axis.

        Returns:
            ``([points], [labels])``: ``points`` is a row-major list of
            ``[x, y]`` pixel coordinates (plain Python ints) and ``labels`` is
            a list of ``1`` (foreground) of the same length. Both are wrapped
            in an outer single-element list, one entry per image.
        """
        # On images narrower/shorter than ``points_per_side`` the integer grid
        # repeats coordinates; np.unique drops the repeats (order is kept
        # because linspace is monotonic).
        x_coords = np.unique(np.linspace(0, width - 1, points_per_side, dtype=int)).tolist()
        y_coords = np.unique(np.linspace(0, height - 1, points_per_side, dtype=int)).tolist()

        points = [[x, y] for y in y_coords for x in x_coords]
        labels = [1] * len(points)
        return [points], [labels]

    def filter_masks(self, masks, iou_scores, score_threshold=0.88,
                     stability_score_threshold=0.95, duplicate_iou_threshold=0.9):
        """Filter masks based on quality scores and remove duplicates.

        Args:
            masks: Sequence of 2-D mask tensors (logits; ``> 0`` is foreground).
            iou_scores: Predicted quality score per mask (0-dim tensors or
                plain numbers).
            score_threshold: A mask is kept only if its score is strictly
                greater than this value.
            stability_score_threshold: A mask is kept only if
                ``calculate_stability_score`` is strictly greater than this.
            duplicate_iou_threshold: Two masks whose binary IoU reaches this
                value are considered duplicates; the higher-scoring one wins.

        Returns:
            ``(filtered_masks, filtered_scores)`` in the order of the input,
            with scores converted to Python floats.
        """
        candidates = []  # (input position, mask, binary mask, score)
        for position, (mask, score) in enumerate(zip(masks, iou_scores)):
            score_value = score.item() if hasattr(score, "item") else float(score)
            if not score_value > score_threshold:
                continue
            mask_binary = mask > 0.0
            if not self.calculate_stability_score(mask_binary) > stability_score_threshold:
                continue
            candidates.append((position, mask, mask_binary, score_value))

        # Greedy suppression: visit candidates from best to worst score and
        # keep one only if it does not duplicate a mask that was already kept.
        kept = []
        for candidate in sorted(candidates, key=lambda item: -item[3]):
            if all(self._mask_iou(candidate[2], other[2]) < duplicate_iou_threshold
                   for other in kept):
                kept.append(candidate)

        kept.sort(key=lambda item: item[0])  # restore input order
        return [item[1] for item in kept], [item[3] for item in kept]

    @staticmethod
    def _mask_iou(first, second):
        """Return the intersection-over-union of two boolean masks as a float."""
        union = torch.logical_or(first, second).sum().item()
        if union == 0:
            return 0.0
        return torch.logical_and(first, second).sum().item() / union

    def calculate_stability_score(self, mask):
        """Calculate stability score for a mask.

        The score computed here is the fraction of the mask's tight bounding
        box that the mask covers (area / bounding-box area). It is not the
        threshold-perturbation stability score of the SAM paper.

        Args:
            mask: 2-D tensor; non-zero entries are foreground.

        Returns:
            A Python float in ``[0, 1]``; ``0.0`` for an empty mask.

        Raises:
            ValueError: If *mask* is not 2-D.
        """
        foreground = torch.as_tensor(mask) != 0
        if foreground.dim() != 2:
            raise ValueError("calculate_stability_score expects a 2-D mask.")

        mask_area = int(foreground.sum())
        if mask_area == 0:
            return 0.0

        rows, cols = torch.nonzero(foreground, as_tuple=True)
        bbox_height = int(rows.max() - rows.min()) + 1
        bbox_width = int(cols.max() - cols.min()) + 1
        return mask_area / (bbox_height * bbox_width)

    def _decode_image(self, inputs):
        """Turn the request's ``inputs`` value into an RGB PIL image."""
        if isinstance(inputs, Image.Image):
            return inputs.convert("RGB")
        if isinstance(inputs, str):
            if inputs.startswith("data:"):
                # Data URI: everything after the first comma is the payload.
                _, separator, inputs = inputs.partition(",")
                if not separator:
                    raise ValueError("Malformed data URI: missing ',' before the base64 payload.")
            image_bytes = base64.b64decode(inputs)
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")
        raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")

    def _predict_candidates(self, encoded):
        """Run SAM on every point prompt and keep each prompt's best mask.

        Returns two parallel lists: low-resolution mask logits and their
        predicted IoU scores, for prompts whose best score exceeds 0.5.
        """
        point_tensor = encoded["input_points"]
        label_tensor = encoded["input_labels"]
        candidate_masks = []
        candidate_scores = []

        with torch.no_grad():
            # Encode the image once; only the light mask decoder runs per chunk.
            image_embeddings = self.model.get_image_embeddings(encoded["pixel_values"])

            for start in range(0, point_tensor.shape[1], self._POINTS_PER_BATCH):
                stop = start + self._POINTS_PER_BATCH
                outputs = self.model(
                    image_embeddings=image_embeddings,
                    input_points=point_tensor[:, start:stop],
                    input_labels=label_tensor[:, start:stop],
                )
                # Index 0: the single image of this request.
                chunk_masks = outputs.pred_masks.cpu()[0]
                chunk_scores = outputs.iou_scores.cpu()[0]

                best_scores, best_indices = chunk_scores.max(dim=1)
                for query_idx in range(chunk_masks.shape[0]):
                    if best_scores[query_idx] > 0.5:
                        candidate_masks.append(chunk_masks[query_idx, best_indices[query_idx]])
                        candidate_scores.append(best_scores[query_idx])

        return candidate_masks, candidate_scores

    def __call__(self, data):
        """
        Called on every HTTP request.
        Handles both base64-encoded images and PIL images.
        Returns a list of segment masks.

        Args:
            data: Request payload. ``data["inputs"]`` is a PIL image, a base64
                string, or a ``data:`` URI with a base64 payload.

        Returns:
            A list of dicts with keys ``'score'`` (float), ``'label'`` (str)
            and ``'mask'`` (single-channel PIL image of the input's size,
            255 = foreground). If inference fails, one ``'fallback_segment'``
            entry with a centred square is returned; if nothing survives
            filtering, one empty ``'no_segments'`` entry is returned.

        Raises:
            ValueError: If ``'inputs'`` is missing or a data URI is malformed.
            TypeError: If ``'inputs'`` has an unsupported type.
        """
        # .get instead of .pop: do not mutate the caller's payload.
        payload = data.get("inputs")
        if payload is None:
            raise ValueError("Missing 'inputs' key in the payload.")

        img = self._decode_image(payload)
        width, height = img.size

        grid_points, grid_labels = self.generate_grid_points(width, height, points_per_side=16)
        # One prompt per grid point. The processor reads a 3-D point list as a
        # single prompt made of many points, so each point gets its own
        # prompt dimension: points (image, prompt, point, xy), labels
        # (image, prompt, point).
        point_prompts = [[[point] for point in grid_points[0]]]
        label_prompts = [[[label] for label in grid_labels[0]]]

        encoded = self.processor(
            img,
            input_points=point_prompts,
            input_labels=label_prompts,
            return_tensors="pt"
        ).to(device)

        try:
            all_masks, all_scores = self._predict_candidates(encoded)

            if all_masks:
                filtered_masks, filtered_scores = self.filter_masks(all_masks, all_scores)
            else:
                filtered_masks, filtered_scores = [], []

            # The model emits low-resolution logits in padded model space.
            # Only the survivors are mapped back to the input resolution, one
            # at a time, to keep peak memory small.
            image_processor = self.processor.image_processor
            original_sizes = encoded["original_sizes"].cpu()
            reshaped_sizes = encoded["reshaped_input_sizes"].cpu()
            full_res_masks = []
            for low_res in filtered_masks:
                upscaled = image_processor.post_process_masks(
                    low_res[None, None, None], original_sizes, reshaped_sizes
                )
                full_res_masks.append(upscaled[0][0, 0])

        except Exception as e:
            # Deliberate best-effort: never fail the request; report the error
            # and answer with a clearly labelled placeholder mask.
            print(f"Error processing masks: {e}")

            mask_binary = np.zeros((height, width), dtype=np.uint8)
            center_x, center_y = width // 2, height // 2
            size = min(width, height) // 8
            y_start, y_end = max(0, center_y - size), min(height, center_y + size)
            x_start, x_end = max(0, center_x - size), min(width, center_x + size)
            mask_binary[y_start:y_end, x_start:x_end] = 255

            output_img = Image.fromarray(mask_binary)
            return [{'score': 0.5, 'label': 'fallback_segment', 'mask': output_img}]

        results = []
        for i, (mask, score) in enumerate(zip(full_res_masks, filtered_scores)):
            mask_binary = mask.numpy().astype(np.uint8) * 255
            results.append({
                'score': float(score),
                'label': f'segment_{i}',
                'mask': Image.fromarray(mask_binary),
            })

        if not results:
            mask_binary = np.zeros((height, width), dtype=np.uint8)
            output_img = Image.fromarray(mask_binary)
            results.append({'score': 0.0, 'label': 'no_segments', 'mask': output_img})

        return results
|
|
|
def main(input_path="/Users/rp7/Downloads/test.jpeg", output_dir="output_masks"):
    """Run the handler on one local image and save every returned mask.

    Args:
        input_path: Path of the image file to segment.
        output_dir: Directory the mask PNGs are written to; created if it
            does not exist.
    """
    import mimetypes
    import os

    os.makedirs(output_dir, exist_ok=True)

    # Send the image the way an HTTP client would: as a base64 data URI.
    with open(input_path, "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")
    mime_type = mimetypes.guess_type(input_path)[0] or "image/jpeg"
    payload = {"inputs": f"data:{mime_type};base64,{img_b64}"}

    handler = EndpointHandler(path=".")
    results = handler(payload)

    if results and isinstance(results, list):
        print(f"Found {len(results)} segments")
        for i, result in enumerate(results):
            if 'mask' in result:
                output_path = os.path.join(output_dir, f"segment_{i}_score_{result['score']:.3f}.png")
                result['mask'].save(output_path)
                print(f"Saved {result['label']} (score: {result['score']:.3f}) to {output_path}")
    else:
        print("Failed to get valid masks from the handler.")
|
|
|
# Run the local smoke test only when the file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|
|