CodeJackR commited on
Commit
d05bd8d
·
1 Parent(s): 2f7cfdc

Fix errors

Browse files
Files changed (1) hide show
  1. handler.py +46 -12
handler.py CHANGED
@@ -1,41 +1,75 @@
1
  # handler.py
2
 
3
  import io
 
4
  import numpy as np
5
  from PIL import Image
6
  from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
7
- from huggingface_inference_toolkit.handler import BaseHandler
8
 
9
- class EndpointHandler(BaseHandler):
10
- def __init__(self, model_dir):
11
  """
12
  Called once at startup.
13
- The model files are mounted under /mnt/models by the Inference Endpoint.
14
  """
15
- super().__init__(model_dir)
16
- checkpoint = "/mnt/models/pytorch_model.bin"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  sam = sam_model_registry["vit_b"](checkpoint=checkpoint)
18
  self.mask_generator = SamAutomaticMaskGenerator(sam)
19
 
20
- async def __call__(self, request):
21
  """
22
  Called on every HTTP request.
23
- Expecting multipart/form-data with an 'image' field.
24
  """
25
- form = await request.form()
26
- image_bytes = form["image"].file.read()
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
 
28
  img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
29
  img_np = np.array(img)
30
 
 
31
  masks = self.mask_generator.generate(img_np)
32
  combined = np.zeros(img_np.shape[:2], dtype=np.uint8)
33
  for m in masks:
34
  combined[m["segmentation"]] = 255
35
 
 
36
  out = io.BytesIO()
37
  Image.fromarray(combined).save(out, format="PNG")
38
  out.seek(0)
 
39
 
40
- # Return a JSON-able dict; binary data will be base64-encoded by the toolkit
41
- return {"mask_png": out.getvalue()}
 
1
  # handler.py
2
 
3
  import io
4
+ import base64
5
  import numpy as np
6
  from PIL import Image
7
  from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
8
+ from typing import Dict, List, Any
9
 
10
+ class EndpointHandler():
11
+ def __init__(self, path=""):
12
  """
13
  Called once at startup.
14
+ The model files are mounted under /opt/ml/model by default in Inference Endpoints.
15
  """
16
+ # Try different possible checkpoint paths
17
+ import os
18
+ possible_paths = [
19
+ os.path.join(path, "pytorch_model.bin"),
20
+ os.path.join(path, "model.safetensors"),
21
+ "/opt/ml/model/pytorch_model.bin",
22
+ "/opt/ml/model/model.safetensors"
23
+ ]
24
+
25
+ checkpoint = None
26
+ for p in possible_paths:
27
+ if os.path.exists(p):
28
+ checkpoint = p
29
+ break
30
+
31
+ if checkpoint is None:
32
+ raise FileNotFoundError("Could not find model checkpoint in any of the expected locations")
33
+
34
  sam = sam_model_registry["vit_b"](checkpoint=checkpoint)
35
  self.mask_generator = SamAutomaticMaskGenerator(sam)
36
 
37
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
38
  """
39
  Called on every HTTP request.
40
+ Expecting base64 encoded image in the 'inputs' field or 'image' field.
41
  """
42
+ # Handle different input formats
43
+ if "inputs" in data:
44
+ if isinstance(data["inputs"], str):
45
+ # Base64 encoded image
46
+ image_bytes = base64.b64decode(data["inputs"])
47
+ elif isinstance(data["inputs"], dict) and "image" in data["inputs"]:
48
+ # Nested structure with image field
49
+ image_bytes = base64.b64decode(data["inputs"]["image"])
50
+ else:
51
+ raise ValueError("Invalid input format. Expected base64 encoded image string.")
52
+ elif "image" in data:
53
+ # Direct image field
54
+ image_bytes = base64.b64decode(data["image"])
55
+ else:
56
+ raise ValueError("No image found in request. Expected 'inputs' or 'image' field with base64 encoded image.")
57
 
58
+ # Process the image
59
  img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
60
  img_np = np.array(img)
61
 
62
+ # Generate masks
63
  masks = self.mask_generator.generate(img_np)
64
  combined = np.zeros(img_np.shape[:2], dtype=np.uint8)
65
  for m in masks:
66
  combined[m["segmentation"]] = 255
67
 
68
+ # Convert result to base64
69
  out = io.BytesIO()
70
  Image.fromarray(combined).save(out, format="PNG")
71
  out.seek(0)
72
+ mask_base64 = base64.b64encode(out.getvalue()).decode('utf-8')
73
 
74
+ # Return in the expected format
75
+ return [{"mask_png_base64": mask_base64, "num_masks": len(masks)}]