import base64
import io
from typing import Any, Dict, List

import numpy as np
import torch
from PIL import Image
from transformers import SamImageProcessor, SamModel

# Run on the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class EndpointHandler:
    def __init__(self, path=""):
        """
        Called once at startup.
        Load the SAM model and image processor using Hugging Face Transformers.
        """
        try:
            # Load the model and processor from the local repository path.
            self.model = SamModel.from_pretrained(path).to(device)
            self.processor = SamImageProcessor.from_pretrained(path, do_resize=False)
        except Exception as e:
            # Fall back to the public checkpoint if the local weights cannot be loaded.
            print("Failed to load from local path: {}".format(e))
            print("Attempting to load from facebook/sam-vit-base")
            self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
            self.processor = SamImageProcessor.from_pretrained("facebook/sam-vit-base", do_resize=False)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Called on every HTTP request.
        Handles both base64-encoded images and PIL images.
        Returns a list with a single dict whose 'mask' entry is a PIL Image.
        """
        inputs = data.pop("inputs", None)
        if inputs is None:
            raise ValueError("Missing 'inputs' key in the payload.")

        # Normalise the payload to an RGB PIL Image.
        if isinstance(inputs, Image.Image):
            img = inputs.convert("RGB")
        elif isinstance(inputs, str):
            # Strip a data-URI prefix such as "data:image/jpeg;base64," if present.
            if inputs.startswith("data:"):
                inputs = inputs.split(",", 1)[1]
            image_bytes = base64.b64decode(inputs)
            img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        else:
            raise TypeError("Unsupported input type. Expected a PIL Image or a base64-encoded string.")

        # Preprocess the image and move the tensors to the target device.
        model_inputs = self.processor(img, return_tensors="pt").to(device)

        # Run SAM without any point or box prompts.
        with torch.no_grad():
            outputs = self.model(**model_inputs)

        try:
            # pred_masks is typically (batch, point_batch, num_masks, height, width);
            # iou_scores is (batch, point_batch, num_masks).
            predicted_masks = outputs.pred_masks.cpu()
            iou_scores = outputs.iou_scores.cpu()[0]

            # Drop the point-batch dimension so the masks are (batch, num_masks, H, W).
            if predicted_masks.ndim == 5:
                predicted_masks = predicted_masks.squeeze(1)

            # Keep only the mask with the highest predicted IoU.
            best_mask_idx = torch.argmax(iou_scores)
            best_mask = predicted_masks[0, best_mask_idx, :, :]

            # Threshold the mask logits at 0 and convert to an 8-bit binary mask.
            mask_binary = (best_mask > 0.0).numpy().astype(np.uint8) * 255
        except Exception as e:
            print("Error processing masks: {}".format(e))

            # Fall back to a small centered square mask at the original image size
            # so the endpoint still returns something usable.
            height, width = img.size[1], img.size[0]
            mask_binary = np.zeros((height, width), dtype=np.uint8)
            center_x, center_y = width // 2, height // 2
            size = min(width, height) // 8
            mask_binary[center_y - size:center_y + size, center_x - size:center_x + size] = 255

        # Convert the binary mask to a PIL Image and wrap it in the expected response format.
        output_img = Image.fromarray(mask_binary)
        return [{'score': None, 'label': 'everything', 'mask': output_img}]
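
# Sketch (not used by the handler above): outputs.pred_masks are low-resolution
# mask logits, so the mask returned by EndpointHandler is not at the original
# image resolution. If a full-size mask is needed, SamImageProcessor's
# post_process_masks can upscale it. This assumes the processor's default
# resizing behaviour; since the handler above passes do_resize=False, the
# recorded sizes may not match what post_process_masks expects, so treat this
# as an illustrative starting point rather than a drop-in replacement.
def upscale_masks_sketch(processor, outputs, processed_inputs):
    """Illustrative helper: upscale low-resolution SAM masks toward the original image size."""
    masks = processor.post_process_masks(
        outputs.pred_masks.cpu(),
        original_sizes=processed_inputs["original_sizes"].cpu(),
        reshaped_input_sizes=processed_inputs["reshaped_input_sizes"].cpu(),
    )
    # 'masks' is a list with one boolean mask tensor per input image.
    return masks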

def main():
    input_path = "/Users/rp7/Downloads/test.jpeg"
    output_path = "output.png"

    # Encode the test image as a base64 data URI, mirroring what a client would send.
    with open(input_path, "rb") as f:
        img_bytes = f.read()
    img_b64 = base64.b64encode(img_bytes).decode("utf-8")
    payload = {"inputs": "data:image/jpeg;base64,{}".format(img_b64)}

    # Run the handler locally, loading the model from the current directory.
    handler = EndpointHandler(path=".")
    result = handler(payload)

    if result and isinstance(result, list) and 'mask' in result[0]:
        result_img = result[0]['mask']
        result_img.save(output_path)
        print("Wrote mask to {}".format(output_path))
    else:
        print("Failed to get a valid mask from the handler.")

if __name__ == "__main__":
    main()