# handler.py
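"""Custom inference handler for a Segment Anything (SAM) endpoint.

Loads a SAM checkpoint via Hugging Face Transformers, prompts it with a grid
of points to segment the whole image, filters the predicted masks by IoU and
a stability proxy, and returns each kept mask as a PIL image.
"""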

import base64
import io
import os

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import SamModel, SamProcessor

# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class EndpointHandler:
    def __init__(self, path=""):
        """
        Called once at startup. 
        Load the SAM model using Hugging Face Transformers.
        """
        try:
            # Load the model and processor from the local path
            self.model = SamModel.from_pretrained(path).to(device).eval()
            self.processor = SamProcessor.from_pretrained(path)
        except Exception as e:
            # Fall back to a known SAM checkpoint if local loading fails
            print(f"Failed to load from local path: {e}")
            print("Attempting to load from facebook/sam-vit-base")
            self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device).eval()
            self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        # Disable resizing so the image keeps its original resolution;
        # keep rescaling (to [0, 1]) and normalization enabled.
        self.processor.image_processor.do_resize = False
        self.processor.image_processor.do_rescale = True
        self.processor.image_processor.do_normalize = True

    def generate_grid_points(self, width, height, points_per_side=32):
        """Generate a grid of single-point prompts across the image.

        Points are nested as [image][prompt][point][x, y] and labels as
        [image][prompt][point], so each grid point becomes its own prompt
        and yields its own mask, rather than all points forming one prompt.
        """
        points = []
        labels = []

        # Evenly spaced grid covering the full image
        x_coords = np.linspace(0, width - 1, points_per_side, dtype=int)
        y_coords = np.linspace(0, height - 1, points_per_side, dtype=int)

        for y in y_coords:
            for x in x_coords:
                points.append([[int(x), int(y)]])  # one point per prompt
                labels.append([1])                 # 1 = foreground point

        return [points], [labels]
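    # Shape sketch (illustrative, not executed): with points_per_side=16 on a
    # 640x480 image, generate_grid_points returns one batch entry holding
    # 16 * 16 = 256 single-point prompts, e.g. points[0][0] == [[0, 0]].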

    def filter_masks(self, masks, iou_scores, score_threshold=0.88, stability_score_threshold=0.95):
        """Filter masks by predicted IoU score and a stability-score proxy."""
        filtered_masks = []
        filtered_scores = []

        for mask, score in zip(masks, iou_scores):
            if score > score_threshold:
                # Binarize the mask, then check its quality via the stability proxy
                mask_binary = mask > 0.0
                stability_score = self.calculate_stability_score(mask_binary)

                if stability_score > stability_score_threshold:
                    filtered_masks.append(mask)
                    filtered_scores.append(score.item())

        return filtered_masks, filtered_scores
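    # With the defaults above, a mask survives only if its predicted IoU
    # exceeds 0.88 and it nearly fills its own bounding box (ratio > 0.95).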

    def calculate_stability_score(self, mask):
        """Calculate a stability-score proxy for a mask.

        This is the ratio of mask area to bounding-box area (a fill ratio),
        a cheap proxy rather than SAM's threshold-perturbation stability score.
        """
        mask_float = mask.float()
        # Ratio of the mask area to its bounding-box area
        mask_area = torch.sum(mask_float)
        if mask_area == 0:
            return 0.0
        
        # Find bounding box
        coords = torch.nonzero(mask_float)
        if len(coords) == 0:
            return 0.0
            
        min_y, min_x = torch.min(coords, dim=0)[0]
        max_y, max_x = torch.max(coords, dim=0)[0]
        bbox_area = (max_y - min_y + 1) * (max_x - min_x + 1)
        
        stability = mask_area / bbox_area if bbox_area > 0 else 0.0
        return stability.item()
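    # Worked example: a solid 10x10 square fills its 10x10 bounding box and
    # scores 100 / 100 = 1.0; a one-pixel-wide diagonal through the same box
    # covers only 10 of 100 pixels and scores 0.1, so it would be filtered.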

    def __call__(self, data):
        """
        Called on every HTTP request.
        Handles both base64-encoded images and PIL images.
        Returns a list of segment masks.
        """
        # 1. Parse and decode the input image
        inputs = data.pop("inputs", None)
        if inputs is None:
            raise ValueError("Missing 'inputs' key in the payload.")

        # Check the type of inputs to handle both base64 strings and pre-processed PIL Images
        if isinstance(inputs, Image.Image):
            img = inputs.convert("RGB")
        elif isinstance(inputs, str):
            if inputs.startswith("data:"):
                inputs = inputs.split(",", 1)[1]
            image_bytes = base64.b64decode(inputs)
            img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        else:
            raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")
        
        # 2. Get image dimensions
        width, height = img.size
        
        # 3. Generate grid points for comprehensive segmentation
        input_points, input_labels = self.generate_grid_points(width, height, points_per_side=16)
        
        # 4. Preprocess the image and point prompts
        model_inputs = self.processor(
            img,
            input_points=input_points,
            input_labels=input_labels,
            return_tensors="pt"
        ).to(device)
        
        # 5. Generate masks
        all_masks = []
        all_scores = []
        
        try:
            with torch.no_grad():
                outputs = self.model(**model_inputs)
            
            # Predicted masks and scores; pred_masks come out of the mask
            # decoder at low resolution, not at the input image's resolution
            predicted_masks = outputs.pred_masks.cpu()  # [batch, num_prompts, masks_per_prompt, H', W']
            iou_scores = outputs.iou_scores.cpu()       # [batch, num_prompts, masks_per_prompt]

            # Process the masks from every point prompt
            num_prompts = predicted_masks.shape[1]

            for prompt_idx in range(num_prompts):
                prompt_masks = predicted_masks[0, prompt_idx]  # [masks_per_prompt, H', W']
                prompt_scores = iou_scores[0, prompt_idx]      # [masks_per_prompt]

                # Keep only the best of the candidate masks for this prompt
                best_mask_idx = torch.argmax(prompt_scores)
                if prompt_scores[best_mask_idx] > 0.5:  # Only keep high-quality masks
                    # Upsample the low-resolution mask logits to the image size
                    best_mask = F.interpolate(
                        prompt_masks[best_mask_idx][None, None].float(),
                        size=(height, width),
                        mode="bilinear",
                        align_corners=False,
                    )[0, 0]
                    all_masks.append(best_mask)
                    all_scores.append(prompt_scores[best_mask_idx])
            
            # Filter masks by the IoU and stability thresholds
            if all_masks:
                filtered_masks, filtered_scores = self.filter_masks(all_masks, all_scores)
            else:
                filtered_masks, filtered_scores = [], []
                
        except Exception as e:
            print(f"Error processing masks: {e}")
            # Fallback: create a simple center mask
            mask_binary = np.zeros((height, width), dtype=np.uint8)
            center_x, center_y = width // 2, height // 2
            size = min(width, height) // 8
            y_start, y_end = max(0, center_y-size), min(height, center_y+size)
            x_start, x_end = max(0, center_x-size), min(width, center_x+size)
            mask_binary[y_start:y_end, x_start:x_end] = 255
            
            output_img = Image.fromarray(mask_binary)
            return [{'score': 0.5, 'label': 'fallback_segment', 'mask': output_img}]
        
        # 6. Convert masks to PIL Images and prepare results
        results = []
        for i, (mask, score) in enumerate(zip(filtered_masks, filtered_scores)):
            # Convert to binary mask
            mask_binary = (mask > 0.0).numpy().astype(np.uint8) * 255
            
            # Create PIL Image
            output_img = Image.fromarray(mask_binary)
            
            results.append({
                'score': float(score),
                'label': f'segment_{i}',
                'mask': output_img
            })
        
        # If no segments found, return a fallback
        if not results:
            mask_binary = np.zeros((height, width), dtype=np.uint8)
            output_img = Image.fromarray(mask_binary)
            results.append({'score': 0.0, 'label': 'no_segments', 'mask': output_img})
        
        return results
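# Each __call__ returns a list of dicts, one per kept segment, e.g.
#   [{'score': 0.91, 'label': 'segment_0', 'mask': <PIL.Image.Image>}, ...]
# (the score value here is illustrative; 'mask' is a binary PIL image).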

def main():
    # This main function shows how a client would call the endpoint locally.
    input_path = "/Users/rp7/Downloads/test.jpeg"
    output_dir = "output_masks"

    # Create the output directory
    os.makedirs(output_dir, exist_ok=True)

    # 1. Prepare the payload with a base64-encoded image string
    with open(input_path, "rb") as f:
        img_bytes = f.read()
    img_b64 = base64.b64encode(img_bytes).decode("utf-8")
    payload = {"inputs": f"data:image/jpeg;base64,{img_b64}"}

    # 2. Instantiate handler and get the result
    handler = EndpointHandler(path=".")
    results = handler(payload)

    # 3. Save all masks
    if results and isinstance(results, list):
        print(f"Found {len(results)} segments")
        for i, result in enumerate(results):
            if 'mask' in result:
                output_path = os.path.join(output_dir, f"segment_{i}_score_{result['score']:.3f}.png")
                result['mask'].save(output_path)
                print(f"Saved {result['label']} (score: {result['score']:.3f}) to {output_path}")
    else:
        print("Failed to get valid masks from the handler.")

if __name__ == "__main__":
    main()
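
# --- Example: calling a deployed endpoint over HTTP ---
# A minimal client sketch, assuming this handler is deployed as a Hugging Face
# Inference Endpoint; ENDPOINT_URL and HF_TOKEN are placeholders, not values
# defined in this file. The JSON payload mirrors the one built in main().
#
#   import base64
#   import requests
#
#   with open("test.jpeg", "rb") as f:
#       img_b64 = base64.b64encode(f.read()).decode("utf-8")
#
#   response = requests.post(
#       ENDPOINT_URL,
#       headers={"Authorization": f"Bearer {HF_TOKEN}"},
#       json={"inputs": f"data:image/jpeg;base64,{img_b64}"},
#   )
#   response.raise_for_status()
#   print(response.json())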