# handler.py
import base64
import io
import logging
from typing import Dict, List, Any

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import SamModel, SamProcessor

logger = logging.getLogger(__name__)

# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Checkpoint loaded when the model cannot be loaded from the path given to EndpointHandler.
FALLBACK_MODEL_ID = "facebook/sam-vit-base"


class EndpointHandler:
    """Grid-prompted ("segment everything") SAM handler for an inference endpoint.

    The image is prompted with a regular grid of foreground points; each point
    yields candidate masks, which are filtered by predicted IoU and stability,
    deduplicated, and returned as binary PIL images at the input's resolution.
    """

    def __init__(self, path=""):
        """
        Called once at startup. Load the SAM model using Hugging Face Transformers.

        Args:
            path: Location of the SAM weights and processor config, passed to
                ``from_pretrained``. If loading from it raises for any reason,
                ``facebook/sam-vit-base`` is loaded instead.
        """
        try:
            self.model, self.processor = self._load_model_and_processor(path)
        except Exception as e:
            # Deliberate best-effort: fall back to a known public checkpoint.
            print(f"Failed to load from local path: {e}")
            print("Attempting to load from facebook/sam-vit-base")
            self.model, self.processor = self._load_model_and_processor(FALLBACK_MODEL_ID)

    @staticmethod
    def _load_model_and_processor(source):
        """Load a SAM model (eval mode, on ``device``) and its processor from ``source``."""
        model = SamModel.from_pretrained(source).to(device).eval()
        processor = SamProcessor.from_pretrained(source)
        image_processor = processor.image_processor
        # Resizing must stay ON. SamProcessor rescales every prompt coordinate
        # by the longest-edge resize factor whether or not the image itself is
        # resized, so with resizing disabled the grid points no longer land on
        # the pixels they were generated for (and images bigger than the pad
        # size cannot be padded). Masks are mapped back to the original
        # resolution in ``_upscale_mask``.
        image_processor.do_resize = True
        image_processor.do_rescale = True
        image_processor.do_normalize = True
        return model, processor

    def generate_grid_points(self, width, height, points_per_side=32):
        """Generate a grid of points across the image for comprehensive segmentation.

        Returns:
            ``([points], [labels])``: ``points`` holds ``points_per_side**2``
            ``[x, y]`` pixel coordinates in row-major order spanning
            ``0..width-1`` and ``0..height-1``; ``labels`` holds a ``1``
            (foreground) per point. The outer list is the image-batch level.
        """
        x_coords = np.linspace(0, width - 1, points_per_side, dtype=int)
        y_coords = np.linspace(0, height - 1, points_per_side, dtype=int)
        points = [[x, y] for y in y_coords for x in x_coords]
        labels = [1] * len(points)  # every grid point is a foreground prompt
        return [points], [labels]

    def filter_masks(self, masks, iou_scores, score_threshold=0.88, stability_score_threshold=0.95,
                     mask_threshold=0.0, stability_offset=1.0, dedup_iou_threshold=0.7):
        """Filter masks based on quality scores and remove duplicates.

        A mask survives when its predicted IoU exceeds ``score_threshold`` and
        its SAM stability score (see ``_sam_stability_score``) exceeds
        ``stability_score_threshold``. Survivors are then deduplicated greedily
        in descending score order: a mask whose IoU with an already-kept mask
        exceeds ``dedup_iou_threshold`` is dropped. Pass
        ``dedup_iou_threshold=None`` to skip deduplication.

        Args:
            masks: Sequence of 2-D mask-logit tensors of identical shape.
            iou_scores: Predicted IoU per mask (0-d tensors or floats).
            mask_threshold: Logit value that separates mask from background.
            stability_offset: Logit offset used by the stability score.

        Returns:
            ``(kept_masks, kept_scores)``: the surviving logit tensors and their
            scores as Python floats.
        """
        candidates = []
        for mask, score in zip(masks, iou_scores):
            score = float(score)
            if score <= score_threshold:
                continue
            stability = self._sam_stability_score(mask, mask_threshold, stability_offset)
            if stability > stability_score_threshold:
                candidates.append((mask, score))

        if dedup_iou_threshold is None:
            return [mask for mask, _ in candidates], [score for _, score in candidates]

        candidates.sort(key=lambda item: item[1], reverse=True)
        kept_masks, kept_binaries, kept_scores = [], [], []
        for mask, score in candidates:
            binary = mask > mask_threshold
            if any(self._mask_iou(binary, other) > dedup_iou_threshold for other in kept_binaries):
                continue  # near-duplicate of a higher-scoring mask
            kept_masks.append(mask)
            kept_binaries.append(binary)
            kept_scores.append(score)
        return kept_masks, kept_scores

    @staticmethod
    def _sam_stability_score(mask_logits, mask_threshold=0.0, offset=1.0):
        """Return SAM's stability score for one mask-logit tensor.

        This is the IoU between the mask binarized at ``mask_threshold + offset``
        and at ``mask_threshold - offset``. The stricter mask is a subset of
        the looser one, so the IoU is the ratio of their areas; a value near 1
        means the mask barely changes when the cutoff moves. Returns 0.0 when
        even the looser mask is empty.
        """
        loose_area = (mask_logits > mask_threshold - offset).sum().item()
        if loose_area == 0:
            return 0.0
        strict_area = (mask_logits > mask_threshold + offset).sum().item()
        return strict_area / loose_area

    @staticmethod
    def _mask_iou(a, b):
        """Return the IoU of two boolean masks of equal shape (0.0 if both are empty)."""
        union = (a | b).sum().item()
        if union == 0:
            return 0.0
        return (a & b).sum().item() / union

    def calculate_stability_score(self, mask):
        """Return the fill ratio of a binary mask: mask area / bounding-box area.

        Kept for backward compatibility; ``filter_masks`` now uses the SAM
        stability score instead. Returns 0.0 for an empty mask.
        """
        # Simple stability score based on mask coherence
        mask_float = mask.float()
        # Calculate the ratio of the mask area to its bounding box area
        mask_area = torch.sum(mask_float)
        if mask_area == 0:
            return 0.0
        # Find bounding box
        coords = torch.nonzero(mask_float)
        if len(coords) == 0:
            return 0.0
        min_y, min_x = torch.min(coords, dim=0)[0]
        max_y, max_x = torch.max(coords, dim=0)[0]
        bbox_area = (max_y - min_y + 1) * (max_x - min_x + 1)
        stability = mask_area / bbox_area if bbox_area > 0 else 0.0
        return stability.item()

    def __call__(self, data, points_per_side=16, points_per_batch=64):
        """
        Called on every HTTP request. Handles both base64-encoded images and PIL images.

        Args:
            data: Request payload. ``data["inputs"]`` is a PIL image or a base64
                string (optionally a ``data:`` URI). ``data`` is not modified.
            points_per_side: Grid density; ``points_per_side**2`` point prompts.
            points_per_batch: Number of prompts sent through the mask decoder
                per forward pass (bounds peak memory).

        Returns:
            A list of ``{'score', 'label', 'mask'}`` dicts where ``mask`` is a
            0/255 PIL image the size of the input. If nothing survives filtering,
            a single all-black ``'no_segments'`` entry is returned; if inference
            raises, a single ``'fallback_segment'`` entry (centred square) is
            returned instead.

        Raises:
            ValueError: ``'inputs'`` is missing or ``points_per_batch < 1``.
            TypeError: ``'inputs'`` has an unsupported type.
        """
        # 1. Parse and decode the input image
        inputs = data.get("inputs")
        if inputs is None:
            raise ValueError("Missing 'inputs' key in the payload.")
        if points_per_batch < 1:
            raise ValueError("points_per_batch must be a positive integer.")
        img = self._decode_image(inputs)

        # 2. Get image dimensions
        width, height = img.size

        # 3. Generate grid points. generate_grid_points nests all points as ONE
        # prompt ([image][point]), which makes SAM predict a single mask for the
        # whole grid. Wrapping each point in its own list puts it on the
        # point-batch axis ([image][prompt][point]) so every grid point gets
        # its own candidate masks.
        grid_points, grid_labels = self.generate_grid_points(width, height, points_per_side=points_per_side)
        prompt_points = [[[point] for point in grid_points[0]]]
        prompt_labels = [[[label] for label in grid_labels[0]]]

        # 4. Process the image and points
        model_inputs = self.processor(
            img,
            input_points=prompt_points,
            input_labels=prompt_labels,
            return_tensors="pt",
        ).to(device)

        # 5. Generate, filter and upscale masks
        try:
            masks, scores = self._segment(model_inputs, points_per_batch)
        except Exception as e:
            # Best-effort endpoint: keep serving, but keep the full traceback.
            logger.exception("Error processing masks: %s", e)
            return self._fallback_segment(width, height)

        # 6. Convert masks to PIL Images and prepare results
        results = [
            {
                'score': float(score),
                'label': f'segment_{i}',
                'mask': Image.fromarray(mask.numpy().astype(np.uint8) * 255),
            }
            for i, (mask, score) in enumerate(zip(masks, scores))
        ]

        # If no segments found, return a fallback
        if not results:
            empty_mask = np.zeros((height, width), dtype=np.uint8)
            results.append({'score': 0.0, 'label': 'no_segments', 'mask': Image.fromarray(empty_mask)})
        return results

    @staticmethod
    def _decode_image(inputs):
        """Return ``inputs`` as an RGB PIL image (accepts a PIL image or a base64 string / data URI)."""
        if isinstance(inputs, Image.Image):
            return inputs.convert("RGB")
        if isinstance(inputs, str):
            if inputs.startswith("data:"):
                inputs = inputs.split(",", 1)[1]
            image_bytes = base64.b64decode(inputs)
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")
        raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")

    def _segment(self, model_inputs, points_per_batch):
        """Predict, filter, deduplicate and upscale masks for every point prompt.

        Returns:
            ``(masks, scores)``: boolean tensors at the original image size and
            their predicted-IoU scores as floats.
        """
        logits, scores = self._best_mask_per_prompt(model_inputs, points_per_batch)
        # Cheap pre-filter: drop prompts whose best mask is clearly poor.
        keep = scores > 0.5
        if not keep.any():
            return [], []
        masks, kept_scores = self.filter_masks(list(logits[keep]), list(scores[keep]))
        original_sizes = model_inputs["original_sizes"].cpu()
        reshaped_sizes = model_inputs["reshaped_input_sizes"].cpu()
        # Upscale one mask at a time so peak memory does not grow with the mask count.
        upscaled = [self._upscale_mask(mask, original_sizes, reshaped_sizes) for mask in masks]
        return upscaled, kept_scores

    def _best_mask_per_prompt(self, model_inputs, points_per_batch):
        """Run SAM over the prompts in chunks and keep each prompt's top-scoring mask.

        Returns:
            ``(logits, scores)`` on the CPU: the low-resolution mask logits of
            each prompt's best mask, stacked along dim 0, and the matching
            predicted IoU values.
        """
        prompt_points = model_inputs["input_points"]
        prompt_labels = model_inputs["input_labels"]
        best_logits, best_scores = [], []
        with torch.no_grad():
            # The vision encoder is the expensive part; run it once and reuse
            # the embedding for every chunk of prompts.
            image_embeddings = self.model.get_image_embeddings(model_inputs["pixel_values"])
            for start in range(0, prompt_points.shape[1], points_per_batch):
                stop = start + points_per_batch
                outputs = self.model(
                    image_embeddings=image_embeddings,
                    input_points=prompt_points[:, start:stop],
                    input_labels=prompt_labels[:, start:stop],
                    multimask_output=True,
                )
                pred_masks = outputs.pred_masks[0].cpu()  # [prompts, masks_per_prompt, H, W]
                iou_scores = outputs.iou_scores[0].cpu()  # [prompts, masks_per_prompt]
                best = iou_scores.argmax(dim=-1)
                rows = torch.arange(best.shape[0])
                best_logits.append(pred_masks[rows, best])
                best_scores.append(iou_scores[rows, best])
        return torch.cat(best_logits), torch.cat(best_scores)

    def _upscale_mask(self, mask_logits, original_sizes, reshaped_input_sizes):
        """Map one low-resolution mask-logit tensor to a boolean mask at the original image size."""
        # post_process_masks takes a per-image list of 4-D (N, C, h, w) tensors
        # and undoes SAM's resize + padding before binarizing.
        upscaled = self.processor.image_processor.post_process_masks(
            [mask_logits[None, None]],
            original_sizes,
            reshaped_input_sizes,
            mask_threshold=0.0,
            binarize=True,
        )
        return upscaled[0][0, 0]

    @staticmethod
    def _fallback_segment(width, height):
        """Return the error-path result: one centred square mask labelled ``fallback_segment``."""
        mask_binary = np.zeros((height, width), dtype=np.uint8)
        center_x, center_y = width // 2, height // 2
        size = min(width, height) // 8
        y_start, y_end = max(0, center_y - size), min(height, center_y + size)
        x_start, x_end = max(0, center_x - size), min(width, center_x + size)
        mask_binary[y_start:y_end, x_start:x_end] = 255
        return [{'score': 0.5, 'label': 'fallback_segment', 'mask': Image.fromarray(mask_binary)}]


def main():
    """Run the handler locally on one image and save every returned mask as a PNG.

    Usage: ``python handler.py [IMAGE_PATH] [OUTPUT_DIR]``. Both arguments are
    optional and default to the previously hard-coded values.
    """
    # This main function shows how a client would call the endpoint locally.
    import os
    import sys

    logging.basicConfig(level=logging.INFO)
    input_path = sys.argv[1] if len(sys.argv) > 1 else "/Users/rp7/Downloads/test.jpeg"
    output_dir = sys.argv[2] if len(sys.argv) > 2 else "output_masks"

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # 1. Prepare the payload with a base64-encoded image string
    with open(input_path, "rb") as f:
        img_bytes = f.read()
    img_b64 = base64.b64encode(img_bytes).decode("utf-8")
    payload = {"inputs": f"data:image/jpeg;base64,{img_b64}"}

    # 2. Instantiate handler and get the result
    handler = EndpointHandler(path=".")
    results = handler(payload)

    # 3. Save all masks
    if results and isinstance(results, list):
        print(f"Found {len(results)} segments")
        for i, result in enumerate(results):
            if 'mask' in result:
                output_path = os.path.join(output_dir, f"segment_{i}_score_{result['score']:.3f}.png")
                result['mask'].save(output_path)
                print(f"Saved {result['label']} (score: {result['score']:.3f}) to {output_path}")
    else:
        print("Failed to get valid masks from the handler.")


if __name__ == "__main__":
    main()