# CodeJackR: commit 69e7d30, "Update to new way of handling model" (9.48 kB).
# handler.py
import io
import base64
import numpy as np
from PIL import Image
import torch
from transformers import SamModel, SamProcessor
from typing import Dict, List, Any
import torch.nn.functional as F
# Choose the compute device once, at import time: the default CUDA device
# when PyTorch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
class EndpointHandler():
    """Hugging Face Inference Endpoint handler that segments images with SAM.

    Each request prompts SAM with a regular grid of single-point foreground
    prompts and keeps the best candidate mask per prompt. The masks are then
    filtered by predicted IoU and stability score, near-duplicates are removed,
    and the survivors are returned as binary PIL masks at the input image's
    resolution.
    """

    # Grid density used for every request (points_per_side ** 2 prompts).
    points_per_side = 16
    # Prompts decoded per forward pass. The image is encoded once per request
    # and the prompts are decoded in chunks of this size to bound peak memory.
    points_per_batch = 64

    def __init__(self, path=""):
        """
        Called once at startup.
        Load the SAM model and processor using Hugging Face Transformers.

        Args:
            path: Location passed to ``from_pretrained``. If loading from it
                raises, ``facebook/sam-vit-base`` is loaded instead.
        """
        try:
            self._load(path)
        except Exception as e:
            # Best-effort fallback so the endpoint can still start.
            print(f"Failed to load from local path: {e}")
            print("Attempting to load from facebook/sam-vit-base")
            self._load("facebook/sam-vit-base")

    def _load(self, source):
        """Load model and processor from *source* and pin the preprocessing flags."""
        self.model = SamModel.from_pretrained(source).to(device).eval()
        self.processor = SamProcessor.from_pretrained(source)
        image_processor = self.processor.image_processor
        # Resizing must stay enabled. SamProcessor always rescales prompt
        # coordinates into the resized (size["longest_edge"]) frame, so
        # switching the pixel resize off misplaces every point. It also makes
        # padding fail for images larger than the pad size. Masks are brought
        # back to the original resolution in _upscale_mask instead.
        image_processor.do_resize = True
        image_processor.do_rescale = True
        image_processor.do_normalize = True

    def generate_grid_points(self, width, height, points_per_side=32):
        """Generate a grid of points across the image for comprehensive segmentation.

        Points are evenly spaced integer pixel coordinates spanning the whole
        image (edges included), ordered row by row.

        Args:
            width: Image width in pixels.
            height: Image height in pixels.
            points_per_side: Number of points along each axis (at least 1).

        Returns:
            ``([points], [labels])``, nested for a batch of one image.
            ``points`` is a list of ``[x, y]`` pairs and ``labels`` holds
            ``1`` (foreground) for every point.

        Raises:
            ValueError: ``points_per_side`` is smaller than 1.
        """
        if points_per_side < 1:
            raise ValueError("points_per_side must be at least 1.")
        xs = np.linspace(0, width - 1, points_per_side, dtype=int).tolist()
        ys = np.linspace(0, height - 1, points_per_side, dtype=int).tolist()
        points = [[x, y] for y in ys for x in xs]
        labels = [1] * len(points)  # every prompt is a foreground point
        return [points], [labels]

    def filter_masks(self, masks, iou_scores, score_threshold=0.88, stability_score_threshold=0.95,
                     nms_iou_threshold=0.7, mask_threshold=0.0, stability_offset=1.0):
        """Filter masks based on quality scores and remove duplicates.

        A mask survives when all three conditions hold:

        - its predicted IoU exceeds ``score_threshold``;
        - its SAM stability score (see ``_stability_score``) exceeds
          ``stability_score_threshold``;
        - it does not overlap an already-kept, higher-scoring mask with a mask
          IoU above ``nms_iou_threshold`` (greedy non-maximum suppression).

        The default thresholds follow Segment Anything's automatic mask
        generator, which uses box IoU rather than mask IoU for its NMS step.

        Args:
            masks: Sequence (or stacked tensor) of equally shaped 2-D mask
                logit tensors.
            iou_scores: Predicted IoU per mask (tensors or plain numbers).
            score_threshold: Minimum predicted IoU (exclusive).
            stability_score_threshold: Minimum stability score (exclusive).
            nms_iou_threshold: Maximum mask IoU with a kept, better mask.
            mask_threshold: Logit threshold that binarises a mask.
            stability_offset: Logit offset used by the stability score.

        Returns:
            ``(kept_masks, kept_scores)``: the surviving logit tensors and
            their scores as Python floats, in their original input order.
        """
        masks = list(masks)
        scores = [float(s) for s in iou_scores]
        candidates = [
            i for i, (mask, score) in enumerate(zip(masks, scores))
            if score > score_threshold
            and self._stability_score(mask, mask_threshold, stability_offset) > stability_score_threshold
        ]
        if not candidates:
            return [], []

        # Pairwise IoU of the binarised candidates, computed with one matmul.
        flat = torch.stack([masks[i] > mask_threshold for i in candidates]).flatten(1).float()
        areas = flat.sum(dim=1)
        intersections = flat @ flat.T
        unions = areas[:, None] + areas[None, :] - intersections
        overlaps = (intersections / unions.clamp(min=1.0)).tolist()

        # Visit candidates from best to worst score. A candidate is kept only
        # if it does not duplicate a mask that has already been kept.
        kept = []
        by_score = sorted(range(len(candidates)), key=lambda c: scores[candidates[c]], reverse=True)
        for j in by_score:
            if all(overlaps[j][k] <= nms_iou_threshold for k in kept):
                kept.append(j)
        kept.sort()
        return [masks[candidates[j]] for j in kept], [scores[candidates[j]] for j in kept]

    @staticmethod
    def _stability_score(logits, mask_threshold=0.0, offset=1.0):
        """Return SAM's stability score for one mask-logit map.

        This is the IoU between the masks obtained by thresholding *logits* at
        ``mask_threshold + offset`` and at ``mask_threshold - offset``. The
        first mask is contained in the second, so the IoU reduces to the ratio
        of their areas. Returns 0.0 when even the looser mask is empty.
        """
        strict_area = int((logits > mask_threshold + offset).sum())
        loose_area = int((logits > mask_threshold - offset).sum())
        return strict_area / loose_area if loose_area else 0.0

    def calculate_stability_score(self, mask):
        """Return the fraction of *mask*'s bounding box that the mask covers.

        Despite its name, this is a compactness (bounding-box fill) ratio, not
        SAM's stability score. ``filter_masks`` uses ``_stability_score``; this
        method is kept for backward compatibility.

        Args:
            mask: 2-D tensor; nonzero entries belong to the mask.

        Returns:
            ``sum(mask) / bbox_area`` as a float, or 0.0 for an empty mask.
        """
        mask_float = mask.float()
        mask_area = float(mask_float.sum())
        if mask_area == 0:
            return 0.0
        coords = torch.nonzero(mask_float)  # [n, 2] as (row, col)
        if len(coords) == 0:
            return 0.0
        min_y, min_x = coords.min(dim=0).values.tolist()
        max_y, max_x = coords.max(dim=0).values.tolist()
        bbox_area = (max_y - min_y + 1) * (max_x - min_x + 1)
        return mask_area / bbox_area

    @staticmethod
    def _decode_image(data):
        """Extract the request image from *data* as an RGB PIL image (without mutating *data*)."""
        inputs = data.get("inputs")
        if inputs is None:
            raise ValueError("Missing 'inputs' key in the payload.")
        if isinstance(inputs, Image.Image):
            return inputs.convert("RGB")
        if isinstance(inputs, str):
            if inputs.startswith("data:"):
                inputs = inputs.split(",", 1)[1]
            inputs = base64.b64decode(inputs)
        if isinstance(inputs, (bytes, bytearray)):
            return Image.open(io.BytesIO(inputs)).convert("RGB")
        raise TypeError(
            "Unsupported input type. Expected a PIL Image, raw image bytes, or a base64 encoded string."
        )

    @staticmethod
    def _best_mask_per_prompt(pred_masks, iou_scores, min_score=0.5):
        """Keep the highest-scoring candidate mask of every prompt.

        Args:
            pred_masks: Mask logits shaped ``[prompts, candidates, h, w]``.
            iou_scores: Predicted IoU shaped ``[prompts, candidates]``.
            min_score: Winners scoring at or below this value are dropped.

        Returns:
            ``(logits [k, h, w], scores [k])`` for the surviving prompts.
        """
        best_scores, best_idx = iou_scores.max(dim=-1)
        best_logits = pred_masks[torch.arange(pred_masks.shape[0]), best_idx]
        keep = best_scores > min_score
        return best_logits[keep], best_scores[keep]

    def _upscale_mask(self, logits, inputs):
        """Upsample one low-res logit map to the original image size as a 0/255 uint8 array."""
        upscaled = self.processor.image_processor.post_process_masks(
            logits[None, None, None],  # [image, prompt, channel, h, w]
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )[0]  # boolean [prompt, channel, H, W], thresholded at 0.0
        return upscaled[0, 0].numpy().astype(np.uint8) * 255

    @staticmethod
    def _fallback_result(width, height):
        """Build the placeholder result (a centred square) used when inference fails."""
        mask_binary = np.zeros((height, width), dtype=np.uint8)
        center_x, center_y = width // 2, height // 2
        size = min(width, height) // 8
        y_start, y_end = max(0, center_y - size), min(height, center_y + size)
        x_start, x_end = max(0, center_x - size), min(width, center_x + size)
        mask_binary[y_start:y_end, x_start:x_end] = 255
        return {'score': 0.5, 'label': 'fallback_segment', 'mask': Image.fromarray(mask_binary)}

    def __call__(self, data):
        """
        Called on every HTTP request.

        Args:
            data: Payload dict whose ``"inputs"`` value is a PIL image, raw
                image bytes, or a base64 string (a ``data:`` URI prefix is
                allowed). The dict is not modified.

        Returns:
            A list of ``{"score", "label", "mask"}`` dicts. Each ``mask`` is a
            0/255 ``uint8`` PIL image of the input's size. When nothing
            survives filtering, a single ``"no_segments"`` entry is returned.
            When mask inference raises, a single ``"fallback_segment"`` centre
            square is returned.

        Raises:
            ValueError: ``"inputs"`` is missing.
            TypeError: ``"inputs"`` has an unsupported type.
        """
        # 1. Parse and decode the input image
        img = self._decode_image(data)
        width, height = img.size

        # 2. Build one single-point prompt per grid point.
        # generate_grid_points nests all points as ONE prompt
        # ([image][points][xy]). Passed as-is, the processor would ask SAM
        # for a single mask covering every point at once.
        grid_points, grid_labels = self.generate_grid_points(
            width, height, points_per_side=self.points_per_side
        )
        inputs = self.processor(
            img,
            input_points=[[[point] for point in grid_points[0]]],  # [image][prompt][point][xy]
            input_labels=[[[label] for label in grid_labels[0]]],  # [image][prompt][point]
            return_tensors="pt",
        ).to(device)

        # 3. Predict, filter, deduplicate and upscale masks
        try:
            with torch.no_grad():
                image_embeddings = self.model.get_image_embeddings(inputs["pixel_values"])
                logit_chunks, score_chunks = [], []
                num_prompts = inputs["input_points"].shape[1]
                for start in range(0, num_prompts, self.points_per_batch):
                    stop = start + self.points_per_batch
                    outputs = self.model(
                        image_embeddings=image_embeddings,
                        input_points=inputs["input_points"][:, start:stop],
                        input_labels=inputs["input_labels"][:, start:stop],
                        multimask_output=True,
                    )
                    # pred_masks: [batch, prompts, candidates, h, w];
                    # iou_scores: [batch, prompts, candidates]; batch is 1.
                    logits, scores = self._best_mask_per_prompt(
                        outputs.pred_masks[0].cpu(), outputs.iou_scores[0].cpu()
                    )
                    logit_chunks.append(logits)
                    score_chunks.append(scores)
                filtered_masks, filtered_scores = self.filter_masks(
                    torch.cat(logit_chunks), torch.cat(score_chunks)
                )
                binary_masks = [self._upscale_mask(mask, inputs) for mask in filtered_masks]
        except Exception as e:
            print(f"Error processing masks: {e}")
            return [self._fallback_result(width, height)]

        # 4. Convert masks to PIL Images and prepare results
        results = [
            {'score': float(score), 'label': f'segment_{i}', 'mask': Image.fromarray(mask)}
            for i, (mask, score) in enumerate(zip(binary_masks, filtered_scores))
        ]
        if not results:
            empty = np.zeros((height, width), dtype=np.uint8)
            results.append({'score': 0.0, 'label': 'no_segments', 'mask': Image.fromarray(empty)})
        return results
def main(input_path="/Users/rp7/Downloads/test.jpeg", output_dir="output_masks", model_path="."):
    """Run the handler locally on one image and save every returned mask.

    This shows how a client would call the endpoint. The image is sent as a
    base64 ``data:`` URI, and each returned mask is saved as a PNG.

    Args:
        input_path: Image to segment. The default is the original author's
            local test file; pass a path that exists on your machine.
        output_dir: Directory for the mask PNGs; created if missing.
        model_path: Passed to ``EndpointHandler(path=...)``.
    """
    import os

    os.makedirs(output_dir, exist_ok=True)

    # 1. Prepare the payload with a base64-encoded image string
    with open(input_path, "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")
    # The handler drops everything up to the first comma of a data: URI, so
    # the MIME type here is informational only.
    payload = {"inputs": f"data:image/jpeg;base64,{img_b64}"}

    # 2. Instantiate handler and get the result
    handler = EndpointHandler(path=model_path)
    results = handler(payload)

    # 3. Save all masks
    if not (results and isinstance(results, list)):
        print("Failed to get valid masks from the handler.")
        return
    print(f"Found {len(results)} segments")
    for i, result in enumerate(results):
        if 'mask' not in result:
            continue
        output_path = os.path.join(output_dir, f"segment_{i}_score_{result['score']:.3f}.png")
        result['mask'].save(output_path)
        print(f"Saved {result['label']} (score: {result['score']:.3f}) to {output_path}")
# Run the local demo only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()