# -*- coding: utf-8 -*-
# handler.py — Rapid_ECG / PULSE-7B — startup-load, stable, DEBUG-instrumented build
# - The model is loaded as soon as the server starts (cold start only once)
# - Honors the HF Endpoint contract (EndpointHandler.load() / __call__)
# - Load order: local (HF_MODEL_DIR) → Hub (HF_MODEL_ID)
# - Images are processed only via .preprocess() (no process_images)
# - Vision tower check: mm_vision_tower or vision_tower
# - Uses IMAGE_TOKEN_INDEX and extensive [DEBUG] logging
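#
# Illustrative request payload (field names match the parsing in __call__
# below; the values and URL are examples only):
#   {
#     "inputs": {
#       "prompt": "What does this ECG show?",     # also accepts "query"/"istem"
#       "image": "https://example.com/ecg.png",   # URL, base64, data URL, or path
#       "temperature": 0.0,
#       "max_new_tokens": 512
#     }
#   }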

import os
import io
import sys
import base64
import subprocess
from typing import Any, Dict, Optional

import torch
from PIL import Image
import requests


# ===== Ensure the LLaVA library is available =====
def _ensure_llava(tag: str = "v1.2.0"):
    try:
        import llava  # noqa
        print("[DEBUG] LLaVA already available.")
        return
    except ImportError:
        print(f"[DEBUG] LLaVA not found; installing (tag={tag}) ...")
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            f"git+https://github.com/haotian-liu/LLaVA@{tag}#egg=llava"
        ])
        print("[DEBUG] LLaVA installed.")

_ensure_llava("v1.2.0")

# ===== LLaVA imports =====
from llava.conversation import conv_templates
from llava.constants import (
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_TOKEN_INDEX,
)
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path


# ---------- helpers ----------
def _get_env(name: str, default: Optional[str] = None) -> Optional[str]:
    v = os.getenv(name)
    return v if v not in (None, "") else default

def _pick_device() -> torch.device:
    if torch.cuda.is_available():
        dev = torch.device("cuda")
    elif torch.backends.mps.is_available():
        dev = torch.device("mps")
    else:
        dev = torch.device("cpu")
    print(f"[DEBUG] pick_device -> {dev}")
    return dev

def _pick_dtype(device: torch.device):
    if device.type == "cuda":
        dt = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    else:
        dt = torch.float32
    print(f"[DEBUG] pick_dtype({device}) -> {dt}")
    return dt

def _is_probably_base64(s: str) -> bool:
    s = s.strip()
    if s.startswith("data:image"):
        return True
    allowed = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=\n\r")
    return len(s) % 4 == 0 and all(c in allowed for c in s)
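
# For illustration: _is_probably_base64("data:image/png;base64,AAAA") is True
# via the fast path, while a plain URL like "https://example.com/x.png" is
# False (":" falls outside the base64 alphabet); callers also exclude
# explicit http(s) strings before trusting this heuristic.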

def _load_image_from_any(image_input: Any) -> Image.Image:
    print(f"[DEBUG] _load_image_from_any type={type(image_input)}")
    if isinstance(image_input, Image.Image):
        return image_input.convert("RGB")
    if isinstance(image_input, (bytes, bytearray)):
        return Image.open(io.BytesIO(image_input)).convert("RGB")
    if hasattr(image_input, "read"):
        return Image.open(image_input).convert("RGB")
    if isinstance(image_input, str):
        s = image_input.strip()
        if s.startswith("data:image"):
            try:
                _, b64 = s.split(",", 1)
                data = base64.b64decode(b64)
                return Image.open(io.BytesIO(data)).convert("RGB")
            except Exception as e:
                raise ValueError(f"Bad data URL: {e}")
        if _is_probably_base64(s) and not s.startswith(("http://", "https://")):
            try:
                data = base64.b64decode(s)
                return Image.open(io.BytesIO(data)).convert("RGB")
            except Exception as e:
                raise ValueError(f"Bad base64 image: {e}")
        if s.startswith(("http://", "https://")):
            resp = requests.get(s, timeout=20)
            resp.raise_for_status()
            return Image.open(io.BytesIO(resp.content)).convert("RGB")
        # local path
        return Image.open(s).convert("RGB")
    raise ValueError(f"Unsupported image input type: {type(image_input)}")
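
# _load_image_from_any accepts PIL images, raw bytes, file-like objects,
# data URLs, bare base64 strings, http(s) URLs, and local paths, always
# normalizing to RGB — e.g. _load_image_from_any("ecg.png") (hypothetical
# local file) and _load_image_from_any(open("ecg.png", "rb")) are equivalent.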

def _get_conv_mode(model_name: str) -> str:
    name = (model_name or "").lower()
    if "llama-2" in name:
        return "llava_llama_2"
    if "mistral" in name:
        return "mistral_instruct"
    if "v1.6-34b" in name:
        return "chatml_direct"
    if "v1" in name or "pulse" in name:
        return "llava_v1"
    if "mpt" in name:
        return "mpt"
    return "llava_v0"

def _build_prompt_with_image(prompt: str, model_cfg) -> str:
    # If the user already included an image token, don't add another
    if DEFAULT_IMAGE_TOKEN in prompt or DEFAULT_IM_START_TOKEN in prompt:
        return prompt
    if getattr(model_cfg, "mm_use_im_start_end", False):
        token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        return f"{token}\n{prompt}"
    return f"{DEFAULT_IMAGE_TOKEN}\n{prompt}"

def _resolve_model_path(model_dir_hint: Optional[str], default_dir: str = "/repository") -> str:
    # Priority: HF_MODEL_DIR (local) -> model_dir_hint from the ctor -> default_dir
    p = _get_env("HF_MODEL_DIR") or model_dir_hint or default_dir
    p = os.path.abspath(p)
    print(f"[DEBUG] resolved model path: {p}")
    return p


# ---------- Endpoint Handler ----------
class EndpointHandler:
    def __init__(self, model_dir: Optional[str] = None):
        # DEBUG banner
        print("🚀 Starting up PULSE-7B handler (startup load)...")
        print("📝 Enhanced by Ubden® Team")
        print(f"🔧 Python: {sys.version}")
        print(f"🔧 PyTorch: {torch.__version__}")
        try:
            import transformers
            print(f"🔧 Transformers: {transformers.__version__}")
        except Exception as e:
            print(f"[DEBUG] transformers import failed: {e}")

        self.model_dir = model_dir
        self.device = _pick_device()
        self.dtype = _pick_dtype(self.device)

        # Environment hints (flash attention; harmless if unsupported)
        os.environ.setdefault("ATTN_IMPLEMENTATION", "flash_attention_2")
        os.environ.setdefault("FLASH_ATTENTION", "1")
        print(f"[DEBUG] ATTN_IMPLEMENTATION={os.getenv('ATTN_IMPLEMENTATION')} FLASH_ATTENTION={os.getenv('FLASH_ATTENTION')}")

        # Model / tokenizer / image-processor placeholders
        self.model = None
        self.tokenizer = None
        self.image_processor = None
        self.context_len = None
        self.model_name = None

        # ---- Load the model here (at startup) ----
        try:
            self._startup_load_model()
            print("✅ Model loaded & ready in __init__")
        except Exception as e:
            print(f"💥 CRITICAL: model startup load failed: {e}")
            raise

    def _startup_load_model(self):
        # Use the local directory if present, otherwise fall back to the Hub
        local_path = _resolve_model_path(self.model_dir)
        use_local = os.path.isdir(local_path) and any(
            os.path.exists(os.path.join(local_path, f))
            for f in ("config.json", "tokenizer_config.json")
        )
        model_base = _get_env("HF_MODEL_BASE", None)

        if use_local:
            model_path = local_path
            print(f"[DEBUG] loading model LOCALLY from: {model_path}")
        else:
            model_path = _get_env("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
            print(f"[DEBUG] loading model from HUB: {model_path} (HF_MODEL_BASE={model_base})")

        # ⬇️ FIX: the LLaVA v1.2.0 signature requires the model_name parameter
        model_name = get_model_name_from_path(model_path)
        print(f"[DEBUG] resolved model_name: {model_name}")

        print("[DEBUG] calling load_pretrained_model ...")
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path=model_path,
            model_base=model_base,
            model_name=model_name,         # <-- required parameter
            load_8bit=False,
            load_4bit=False,
            device_map="auto",
            device=self.device,
        )
        self.model_name = getattr(self.model.config, "name_or_path", str(model_path))
        print(f"[DEBUG] model loaded: name={self.model_name}")

        # Vision tower check (new/old config field names)
        vt = (
            getattr(self.model.config, "mm_vision_tower", None)
            or getattr(self.model.config, "vision_tower", None)
        )
        print(f"[DEBUG] vision tower: {vt}")
        if self.image_processor is None or vt is None:
            raise RuntimeError(
                "[ERROR] Vision tower not loaded (mm_vision_tower/vision_tower None). "
                "Yerel yükleme için HF_MODEL_DIR doğru klasörü göstermeli; "
                "Hub için HF_MODEL_ID PULSE/LLaVA tabanlı olmalı (örn: 'PULSE-ECG/PULSE-7B')."
            )

        # Tokenizer safety
        try:
            self.tokenizer.padding_side = "left"
            if getattr(self.tokenizer, "pad_token_id", None) is None:
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        except Exception as e:
            print(f"[DEBUG] tokenizer safety patch failed: {e}")

        self.model.eval()

    # The HF inference toolkit will still call load(), so keep it as a no-op
    def load(self):
        print("[DEBUG] load(): model is already initialized in __init__")
        return True

    @torch.inference_mode()
    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        print(f"[DEBUG] __call__ inputs keys={list(inputs.keys()) if hasattr(inputs,'keys') else 'N/A'}")
        # HF {"inputs": {...}} sarmasını aç
        if "inputs" in inputs and isinstance(inputs["inputs"], dict):
            inputs = inputs["inputs"]

        prompt = inputs.get("query") or inputs.get("prompt") or inputs.get("istem") or ""
        image_in = inputs.get("image") or inputs.get("image_url") or inputs.get("img")
        if not image_in:
            return {"error": "Missing 'image' in payload"}
        if not isinstance(prompt, str) or not prompt.strip():
            return {"error": "Missing 'query'/'prompt' text"}

        # Generation parameters
        temperature = float(inputs.get("temperature", 0.0))
        top_p = float(inputs.get("top_p", 0.9))
        max_new = int(inputs.get("max_new_tokens", inputs.get("max_tokens", 512)))
        repetition_penalty = float(inputs.get("repetition_penalty", 1.0))
        conv_mode_override = inputs.get("conv_mode") or _get_env("CONV_MODE", None)

        # ---- Load image + preprocess
        try:
            image = _load_image_from_any(image_in)
            print(f"[DEBUG] loaded image size={image.size}")
        except Exception as e:
            return {"error": f"Failed to load image: {e}"}

        if self.image_processor is None:
            return {"error": "image_processor is None; model not initialized properly (no vision tower)"}

        try:
            out = self.image_processor.preprocess(image, return_tensors="pt")
            images_tensor = out["pixel_values"].to(self.device, dtype=self.dtype)
            image_sizes = [image.size]
            print(f"[DEBUG] preprocess OK; images_tensor.shape={images_tensor.shape}")
        except Exception as e:
            return {"error": f"Image preprocessing failed: {e}"}

        # ---- Conversation + prompt
        mode = conv_mode_override or _get_conv_mode(self.model_name)
        conv = (conv_templates.get(mode) or conv_templates[list(conv_templates.keys())[0]]).copy()
        conv.append_message(conv.roles[0], _build_prompt_with_image(prompt.strip(), self.model.config))
        conv.append_message(conv.roles[1], None)
        full_prompt = conv.get_prompt()
        print(f"[DEBUG] conv_mode={mode}; full_prompt_len={len(full_prompt)}")

        # ---- Tokenization (via IMAGE_TOKEN_INDEX)
        try:
            input_ids = tokenizer_image_token(
                full_prompt, self.tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors="pt"
            ).unsqueeze(0).to(self.device)
            print(f"[DEBUG] tokenizer_image_token OK; input_ids.shape={input_ids.shape}")
        except Exception as e:
            print(f"[DEBUG] tokenizer_image_token failed: {e}; fallback to plain tokenizer")
            try:
                toks = self.tokenizer([full_prompt], return_tensors="pt", padding=True, truncation=True)
                input_ids = toks["input_ids"].to(self.device)
                print(f"[DEBUG] plain tokenizer OK; input_ids.shape={input_ids.shape}")
            except Exception as e2:
                return {"error": f"Tokenization failed: {e} / {e2}"}

        attention_mask = torch.ones_like(input_ids, device=self.device)

        # ---- Generate
        try:
            print(f"[DEBUG] generate(max_new_tokens={max_new}, temp={temperature}, top_p={top_p}, rep={repetition_penalty})")
            gen_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                images=images_tensor,
                image_sizes=image_sizes,
                do_sample=(temperature > 0),
                temperature=temperature,
                top_p=top_p,
                max_new_tokens=max_new,
                repetition_penalty=repetition_penalty,
                use_cache=True,
            )
            print(f"[DEBUG] generate OK; gen_ids.shape={gen_ids.shape}")
        except Exception as e:
            return {"error": f"Generation failed: {e}"}

        # ---- Decode (new tokens only)
        try:
            # LLaVA's multimodal generate() returns only the newly generated
            # tokens (it runs HF generate on inputs_embeds); strip the prompt
            # only if it was actually echoed back in front of them.
            out_ids = gen_ids[0]
            if out_ids.shape[0] > input_ids.shape[1] and torch.equal(out_ids[: input_ids.shape[1]], input_ids[0]):
                out_ids = out_ids[input_ids.shape[1]:]
            text = self.tokenizer.decode(out_ids, skip_special_tokens=True).strip()
            print(f"[DEBUG] decoded_text_len={len(text)}")
        except Exception as e:
            return {"error": f"Decode failed: {e}"}

        return {
            "generated_text": text,
            "model": self.model_name,
            "conv_mode": mode,
        }
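

# ---------- local smoke test ----------
# A minimal sketch for exercising the handler outside the HF endpoint runtime.
# It assumes the model is reachable (local HF_MODEL_DIR or Hub access) and a
# test image at the hypothetical path below; adjust both for your environment.
if __name__ == "__main__":
    handler = EndpointHandler()
    demo_payload = {
        "inputs": {
            "prompt": "Describe the rhythm in this ECG.",
            "image": "example_ecg.png",  # hypothetical local test file
            "max_new_tokens": 128,
        }
    }
    print(handler(demo_payload))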