# handler.py
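"""
Custom inference handler for a SAM (Segment Anything) model served via
Hugging Face Inference Endpoints. The model is loaded once at startup;
each request decodes an input image, runs a prompt-free forward pass,
and returns the highest-scoring mask as a PIL image.
"""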
import io
import base64
import numpy as np
from PIL import Image
import torch
from transformers import SamModel, SamImageProcessor
from typing import Dict, List, Any
# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class EndpointHandler:
def __init__(self, path=""):
"""
Called once at startup.
Load the SAM model using Hugging Face Transformers.
"""
try:
            # Load the model and processor from the local path.
            # do_resize=False keeps the image at its original resolution.
            self.model = SamModel.from_pretrained(path).to(device)
            self.processor = SamImageProcessor.from_pretrained(path, do_resize=False)
except Exception as e:
# Fallback to loading from a known SAM model if local loading fails
print("Failed to load from local path: {}".format(e))
print("Attempting to load from facebook/sam-vit-base")
self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
self.processor = SamImageProcessor.from_pretrained("facebook/sam-vit-base", do_resize=False)
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Called on every HTTP request.
        Accepts either a PIL image or a base64-encoded image string under
        the 'inputs' key and returns a list with a single dict holding the
        predicted mask as a PIL image.
        """
# 1. Parse and decode the input image
inputs = data.pop("inputs", None)
if inputs is None:
raise ValueError("Missing 'inputs' key in the payload.")
# Check the type of inputs to handle both base64 strings and pre-processed PIL Images
if isinstance(inputs, Image.Image):
img = inputs.convert("RGB")
elif isinstance(inputs, str):
if inputs.startswith("data:"):
inputs = inputs.split(",", 1)[1]
image_bytes = base64.b64decode(inputs)
img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
else:
raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")
        # 2. Process the image for the model. No point prompts are passed,
        # so SAM runs prompt-free over the whole image.
        model_inputs = self.processor(img, return_tensors="pt").to(device)
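        # Hedged sketch: if point prompts were wanted, the combined
        # SamProcessor could build them (an assumption; this handler is
        # deliberately prompt-free, and SamProcessor is not imported above):
        #   from transformers import SamProcessor
        #   prompt_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
        #   model_inputs = prompt_processor(
        #       img,
        #       input_points=[[[img.width // 2, img.height // 2]]],
        #       return_tensors="pt",
        #   ).to(device)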
        # 3. Generate masks
        with torch.no_grad():
            outputs = self.model(**model_inputs)
        # 4. Process and select the best mask
        try:
            # pred_masks: (batch, point_batch, num_masks, H, W)
            # iou_scores: (batch, point_batch, num_masks)
            predicted_masks = outputs.pred_masks.cpu()
            iou_scores = outputs.iou_scores.cpu()[0]
            # Drop the point-batch dimension if present
            if predicted_masks.ndim == 5:
                predicted_masks = predicted_masks.squeeze(1)
            # Select the mask with the highest predicted IoU
            best_mask_idx = torch.argmax(iou_scores)
            best_score = iou_scores.flatten()[best_mask_idx].item()
            best_mask = predicted_masks[0, best_mask_idx, :, :]
            # Threshold the mask logits into a binary mask
            # (no resizing needed since the processor does not resize)
            mask_binary = (best_mask > 0.0).numpy().astype(np.uint8) * 255
        except Exception as e:
            print("Error processing masks: {}".format(e))
            # Fallback: draw a centered square so the endpoint still
            # returns a usable mask
            best_score = None
            height, width = img.size[1], img.size[0]
            mask_binary = np.zeros((height, width), dtype=np.uint8)
            center_x, center_y = width // 2, height // 2
            size = min(width, height) // 8
            mask_binary[center_y - size:center_y + size, center_x - size:center_x + size] = 255
# 5. Create and return the output PIL Image
output_img = Image.fromarray(mask_binary)
        return [{'score': best_score, 'label': 'everything', 'mask': output_img}]
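
# Optional illustrative helper (a sketch, not part of the handler API):
# blends a binary mask over the source image in red for quick visual
# checks. The function name and the alpha default are assumptions.
def overlay_mask(img, mask, alpha=0.5):
    """Return `img` with `mask` blended in red at the given opacity."""
    img = img.convert("RGB")
    # Resize the mask to the image size in case the resolutions differ
    mask = mask.convert("L").resize(img.size)
    red = Image.new("RGB", img.size, (255, 0, 0))
    # Scale mask intensity to control blend opacity, then composite
    mask_alpha = mask.point(lambda p: int(p * alpha))
    return Image.composite(red, img, mask_alpha)
    # Example: overlay_mask(img, result[0]['mask']).save("overlay.png")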
def main():
# This main function shows how a client would call the endpoint locally.
input_path = "/Users/rp7/Downloads/test.jpeg"
output_path = "output.png"
# 1. Prepare the payload with a base64-encoded image string
with open(input_path, "rb") as f:
img_bytes = f.read()
img_b64 = base64.b64encode(img_bytes).decode("utf-8")
payload = {"inputs": "data:image/jpeg;base64,{}".format(img_b64)}
# 2. Instantiate handler and get the PIL Image result
handler = EndpointHandler(path=".")
result = handler(payload)
# 3. Extract the image from the result and save it
if result and isinstance(result, list) and 'mask' in result[0]:
result_img = result[0]['mask']
result_img.save(output_path)
print("Wrote mask to {}".format(output_path))
else:
print("Failed to get a valid mask from the handler.")
if __name__ == "__main__":
main()
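
# Client-side sketch for calling a deployed endpoint over HTTP (assumptions:
# the placeholder URL and token, and that the serving toolkit serializes the
# returned mask; check your endpoint's actual response format before use):
#   import requests
#   resp = requests.post(
#       "https://<endpoint-name>.endpoints.huggingface.cloud",
#       headers={"Authorization": "Bearer <HF_TOKEN>"},
#       json={"inputs": "data:image/jpeg;base64,..."},
#   )
#   print(resp.status_code, resp.headers.get("content-type"))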