# NOTE(review): removed pasted git-blame/viewer metadata (file size, commit
# hashes, line-number gutter) that was not part of the source file.
# handler.py
import io
import base64
import numpy as np
from PIL import Image
import torch
from transformers import SamModel, SamImageProcessor
from typing import Dict, List, Any
import torch.nn.functional as F
# Select the compute device once at import time: CUDA when available, else CPU.
# Both the model weights and per-request tensors are moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class EndpointHandler():
    """Inference Endpoint handler wrapping a SAM segmentation model.

    Loads the model once at startup; per request, decodes the input image,
    runs SAM with no point/box prompts, and returns the highest-IoU predicted
    mask as a binary PIL image.
    """

    def __init__(self, path=""):
        """
        Called once at startup.

        Load the SAM model and processor from ``path``; if that fails, fall
        back to downloading ``facebook/sam-vit-base`` from the Hub.

        Parameters
        ----------
        path : str
            Directory containing the model weights (the endpoint's model repo).
        """
        try:
            # Load the model and processor from the local path.
            # do_resize=False keeps the image at its native resolution.
            self.model = SamModel.from_pretrained(path).to(device)
            self.processor = SamImageProcessor.from_pretrained(path, do_resize=False)
        except Exception as e:
            # Fallback to loading from a known SAM model if local loading fails.
            print("Failed to load from local path: {}".format(e))
            print("Attempting to load from facebook/sam-vit-base")
            self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
            self.processor = SamImageProcessor.from_pretrained(
                "facebook/sam-vit-base", do_resize=False
            )

    def __call__(self, data):
        """
        Called on every HTTP request.

        Parameters
        ----------
        data : dict
            Request payload. ``data["inputs"]`` must be either a
            ``PIL.Image.Image`` or a base64-encoded image string (optionally
            prefixed with a ``data:...;base64,`` data-URL header).

        Returns
        -------
        list[dict]
            A single-element list:
            ``[{"score": None, "label": "everything", "mask": <PIL.Image>}]``
            where ``mask`` is a mode-"L" image with values 0/255.

        Raises
        ------
        ValueError
            If the payload has no ``"inputs"`` value.
        TypeError
            If ``"inputs"`` is neither a PIL image nor a string.
        """
        # 1. Parse and decode the input image.
        # Use get() instead of pop() so the caller's payload dict is not
        # mutated as a side effect of handling the request.
        inputs = data.get("inputs")
        if inputs is None:
            raise ValueError("Missing 'inputs' key in the payload.")

        # Accept both pre-decoded PIL images and base64-encoded strings.
        if isinstance(inputs, Image.Image):
            img = inputs.convert("RGB")
        elif isinstance(inputs, str):
            # Strip a leading "data:image/...;base64," header if present.
            if inputs.startswith("data:"):
                inputs = inputs.split(",", 1)[1]
            image_bytes = base64.b64decode(inputs)
            img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        else:
            raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")

        # 2. Preprocess the image. No point/box prompts are passed, so the
        #    model returns its default multimask output for the whole image.
        inputs = self.processor(img, return_tensors="pt").to(device)

        # 3. Run the model.
        with torch.no_grad():
            outputs = self.model(**inputs)

        # 4. Select the mask with the highest predicted IoU.
        try:
            predicted_masks = outputs.pred_masks.cpu()
            iou_scores = outputs.iou_scores.cpu()[0]
            # pred_masks is assumed to be (batch, num_prompts, num_masks, H, W)
            # when 5-D; drop the prompt axis. TODO(review): confirm tensor
            # shapes against the deployed transformers version.
            if predicted_masks.ndim == 5:
                predicted_masks = predicted_masks.squeeze(1)
            best_mask_idx = torch.argmax(iou_scores)
            best_mask = predicted_masks[0, best_mask_idx, :, :]
            # Threshold the mask logits at 0 to get a 0/255 binary mask.
            # NOTE(review): the mask comes back at the model's output
            # resolution; since the processor does not resize, this is assumed
            # to match the input image — verify for non-square/large inputs.
            mask_binary = (best_mask > 0.0).numpy().astype(np.uint8) * 255
        except Exception as e:
            # Best-effort fallback: log the failure and return a centered
            # square mask so the endpoint still emits a well-formed response.
            print("Error processing masks: {}".format(e))
            height, width = img.size[1], img.size[0]
            mask_binary = np.zeros((height, width), dtype=np.uint8)
            center_x, center_y = width // 2, height // 2
            size = min(width, height) // 8
            mask_binary[center_y - size:center_y + size, center_x - size:center_x + size] = 255

        # 5. Wrap the mask in a PIL image and return the endpoint response.
        output_img = Image.fromarray(mask_binary)
        return [{'score': None, 'label': 'everything', 'mask': output_img}]
def main(input_path="/Users/rp7/Downloads/test.jpeg", output_path="output.png"):
    """Local smoke test: drive the handler the way an HTTP client would.

    Parameters
    ----------
    input_path : str
        Path to a JPEG image to segment. Parameterized (previously a
        hard-coded machine-specific path); the old value remains the default
        for backward compatibility.
    output_path : str
        Where to write the resulting mask image.
    """
    # 1. Build the payload: a data URL carrying the base64-encoded image.
    with open(input_path, "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")
    payload = {"inputs": "data:image/jpeg;base64,{}".format(img_b64)}

    # 2. Instantiate the handler (loads the model) and run inference.
    handler = EndpointHandler(path=".")
    result = handler(payload)

    # 3. Extract the mask image from the response and save it.
    if result and isinstance(result, list) and 'mask' in result[0]:
        result[0]['mask'].save(output_path)
        print("Wrote mask to {}".format(output_path))
    else:
        print("Failed to get a valid mask from the handler.")
# Script entry point: run the local demo only when executed directly,
# not when imported by the inference runtime.
if __name__ == "__main__":
    main()