File size: 4,733 Bytes
592adee
2f7cfdc
592adee
d05bd8d
592adee
 
d816a26
0e71822
d05bd8d
f9b3f94
592adee
c78d04e
 
 
d05bd8d
 
2f7cfdc
 
d816a26
2f7cfdc
d816a26
 
c78d04e
0e71822
d816a26
 
e0fb0e6
d816a26
c78d04e
0e71822
592adee
e52ad65
2f7cfdc
 
2f4ef92
e52ad65
2f7cfdc
e0fb0e6
2f4ef92
 
 
c78d04e
2f4ef92
 
 
 
 
e52ad65
2f4ef92
 
 
 
16a5f8c
e0fb0e6
0e71822
 
 
16a5f8c
0e71822
d816a26
e0fb0e6
d816a26
 
 
e0fb0e6
f9b3f94
233c56f
 
 
38a30a4
233c56f
 
 
40b9d26
233c56f
 
 
38a30a4
b0c81ad
 
d816a26
f9b3f94
e0fb0e6
b0c81ad
 
16a5f8c
 
 
 
d816a26
e52ad65
 
06bd1fa
f9b3f94
 
e52ad65
f9b3f94
e0fb0e6
f9b3f94
e52ad65
f9b3f94
 
 
e0fb0e6
f9b3f94
e52ad65
f9b3f94
d29fd9e
e0fb0e6
d29fd9e
 
 
 
 
 
 
f9b3f94
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# handler.py

import io
import base64
import numpy as np
from PIL import Image
import torch
from transformers import SamModel, SamImageProcessor
from typing import Dict, List, Any
import torch.nn.functional as F

# Choose the compute device once, at import time: CUDA when torch reports it
# is available, otherwise the CPU. Every model/tensor below is moved here.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class EndpointHandler():
    """Inference-endpoint handler that segments an image with SAM.

    The model is run without point/box prompts. The mask with the highest
    predicted IoU score is returned as an 8-bit PIL image (0 outside the mask,
    255 inside) with the same width and height as the input image.
    """

    def __init__(self, path=""):
        """
        Called once at startup.
        Load the SAM model and image processor using Hugging Face Transformers.

        Args:
            path: Location passed to ``from_pretrained`` for both the model and
                the processor. If loading from it raises for any reason, the
                ``facebook/sam-vit-base`` checkpoint is loaded instead.
        """
        # NOTE: the processor keeps its checkpoint defaults (resize, then pad
        # to the model's fixed input size). Disabling resizing breaks the
        # padding step for images larger than the pad size. Masks are mapped
        # back to the original resolution in __call__.
        try:
            self.model = SamModel.from_pretrained(path).to(device)
            self.processor = SamImageProcessor.from_pretrained(path)
        except Exception as e:
            # Fallback to a known public SAM checkpoint if local loading fails
            print("Failed to load from local path: {}".format(e))
            print("Attempting to load from facebook/sam-vit-base")
            self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
            self.processor = SamImageProcessor.from_pretrained("facebook/sam-vit-base")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Called on every HTTP request.

        Args:
            data: Request payload. ``data["inputs"]`` must be either a PIL
                image or a base64-encoded image string. An optional
                ``data:<mime>;base64,`` prefix is stripped. The payload dict
                is not modified.

        Returns:
            ``[{'score': None, 'label': 'everything', 'mask': <PIL.Image>}]``.
            The mask is a single-channel uint8 image the same size as the
            input.

        Raises:
            ValueError: if ``"inputs"`` is missing/None, or is a ``data:`` URI
                without a ``,`` separator.
            TypeError: if ``"inputs"`` is neither a PIL image nor a string.
        """
        img = self._load_image(data)
        width, height = img.size

        # No point/box prompts are supplied to the model.
        model_inputs = self.processor(img, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = self.model(**model_inputs)

        try:
            mask_binary = self._select_best_mask(outputs, model_inputs)
        except Exception as e:
            # Deliberate best-effort: keep the endpoint answering even if mask
            # post-processing fails, but log the reason.
            print("Error processing masks: {}".format(e))
            mask_binary = self._fallback_mask(width, height)

        output_img = Image.fromarray(mask_binary)
        return [{'score': None, 'label': 'everything', 'mask': output_img}]

    @staticmethod
    def _load_image(data):
        """Extract ``data["inputs"]`` and return it as an RGB PIL image."""
        raw = data.get("inputs")  # .get, not .pop: don't mutate the caller's payload
        if raw is None:
            raise ValueError("Missing 'inputs' key in the payload.")

        if isinstance(raw, str):
            if raw.startswith("data:"):
                # Strip the "data:<mime>;base64," header of a data URI
                header, sep, raw = raw.partition(",")
                if not sep:
                    raise ValueError("Malformed data URI: missing ',' separator.")
            image_bytes = base64.b64decode(raw)
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")

        if isinstance(raw, Image.Image):
            return raw.convert("RGB")

        raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")

    def _select_best_mask(self, outputs, model_inputs):
        """Return the highest-IoU mask as a uint8 (H, W) array of 0/255 at the input's size."""
        # pred_masks are low-resolution logits for the resized and padded
        # frame. post_process_masks upsamples them, removes the padding,
        # resizes to original_sizes and thresholds them into boolean masks.
        masks = self.processor.post_process_masks(
            outputs.pred_masks.cpu(),
            model_inputs["original_sizes"].cpu(),
            model_inputs["reshaped_input_sizes"].cpu(),
        )[0]  # first (and only) image in the batch
        # Flatten any leading prompt dims so masks and scores index in lockstep
        masks = masks.reshape(-1, *masks.shape[-2:])
        scores = outputs.iou_scores.cpu()[0].reshape(-1)
        best_idx = int(torch.argmax(scores))
        return masks[best_idx].numpy().astype(np.uint8) * 255

    @staticmethod
    def _fallback_mask(width, height):
        """Return a placeholder (height, width) uint8 mask with a 255-valued square at the centre."""
        mask = np.zeros((height, width), dtype=np.uint8)
        center_x, center_y = width // 2, height // 2
        half = min(width, height) // 8
        mask[center_y - half:center_y + half, center_x - half:center_x + half] = 255
        return mask

def main(input_path="/Users/rp7/Downloads/test.jpeg", output_path="output.png"):
    """Run the handler locally on one image file and save the returned mask.

    Mimics a client call. The file is base64-encoded into a data-URI payload
    and passed to ``EndpointHandler(path=".")``, and the mask image is written
    to ``output_path``.

    Args:
        input_path: Image file to segment. The default is the original
            developer's local test image; pass your own path.
        output_path: Destination for the mask. PIL infers the format from the
            extension.
    """
    import mimetypes

    # 1. Prepare the payload with a base64-encoded image string
    with open(input_path, "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")
    mime = mimetypes.guess_type(input_path)[0] or "image/jpeg"
    payload = {"inputs": "data:{};base64,{}".format(mime, img_b64)}

    # 2. Instantiate handler and get the result list
    handler = EndpointHandler(path=".")
    result = handler(payload)

    # 3. Extract the image from the result and save it
    if result and isinstance(result, list) and 'mask' in result[0]:
        result[0]['mask'].save(output_path)
        print("Wrote mask to {}".format(output_path))
    else:
        print("Failed to get a valid mask from the handler.")

if __name__ == "__main__":
    import sys

    # Optional positional overrides: python handler.py [INPUT_IMAGE] [OUTPUT_PATH]
    main(*sys.argv[1:3])