# -*- coding: utf-8 -*-
# handler.py — Rapid_ECG / PULSE-7B — startup-load, stable build with DEBUG logging
# - The model is loaded as soon as the server starts (cold start happens only once)
# - Follows the HF Inference Endpoint contract (EndpointHandler with load() and __call__)
# - Load order: local (HF_MODEL_DIR) → Hub (HF_MODEL_ID)
# - Images are handled only via .preprocess() (no process_images)
# - Vision tower check: mm_vision_tower or vision_tower
# - Uses IMAGE_TOKEN_INDEX and extensive [DEBUG] logging
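#
# Example request payload (illustrative only; these are the field names read in __call__ below):
#   {
#     "inputs": {
#       "query": "Describe the key findings in this ECG.",
#       "image": "<https URL | local path | base64 string | data URL>",
#       "temperature": 0.0,
#       "top_p": 0.9,
#       "max_new_tokens": 512,
#       "repetition_penalty": 1.0,
#       "conv_mode": "llava_v1"
#     }
#   }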
import os
import io
import sys
import base64
import subprocess
from typing import Any, Dict, Optional
import torch
from PIL import Image
import requests
# ===== Ensure the LLaVA library is available =====
def _ensure_llava(tag: str = "v1.2.0"):
try:
import llava # noqa
print("[DEBUG] LLaVA already available.")
return
except ImportError:
print(f"[DEBUG] LLaVA not found; installing (tag={tag}) ...")
subprocess.check_call([
sys.executable, "-m", "pip", "install",
f"git+https://github.com/haotian-liu/LLaVA@{tag}#egg=llava"
])
print("[DEBUG] LLaVA installed.")
_ensure_llava("v1.2.0")
# ===== LLaVA imports =====
from llava.conversation import conv_templates
from llava.constants import (
DEFAULT_IMAGE_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN,
IMAGE_TOKEN_INDEX,
)
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path
# ---------- helpers ----------
def _get_env(name: str, default: Optional[str] = None) -> Optional[str]:
v = os.getenv(name)
return v if v not in (None, "") else default
def _pick_device() -> torch.device:
if torch.cuda.is_available():
dev = torch.device("cuda")
elif torch.backends.mps.is_available():
dev = torch.device("mps")
else:
dev = torch.device("cpu")
print(f"[DEBUG] pick_device -> {dev}")
return dev
def _pick_dtype(device: torch.device):
if device.type == "cuda":
dt = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
dt = torch.float32
print(f"[DEBUG] pick_dtype({device}) -> {dt}")
return dt
def _is_probably_base64(s: str) -> bool:
s = s.strip()
if s.startswith("data:image"):
return True
allowed = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=\n\r")
return len(s) % 4 == 0 and all(c in allowed for c in s)
def _load_image_from_any(image_input: Any) -> Image.Image:
print(f"[DEBUG] _load_image_from_any type={type(image_input)}")
if isinstance(image_input, Image.Image):
return image_input.convert("RGB")
if isinstance(image_input, (bytes, bytearray)):
return Image.open(io.BytesIO(image_input)).convert("RGB")
if hasattr(image_input, "read"):
return Image.open(image_input).convert("RGB")
if isinstance(image_input, str):
s = image_input.strip()
if s.startswith("data:image"):
try:
_, b64 = s.split(",", 1)
data = base64.b64decode(b64)
return Image.open(io.BytesIO(data)).convert("RGB")
except Exception as e:
raise ValueError(f"Bad data URL: {e}")
if _is_probably_base64(s) and not s.startswith(("http://", "https://")):
try:
data = base64.b64decode(s)
return Image.open(io.BytesIO(data)).convert("RGB")
except Exception as e:
raise ValueError(f"Bad base64 image: {e}")
if s.startswith(("http://", "https://")):
resp = requests.get(s, timeout=20)
resp.raise_for_status()
return Image.open(io.BytesIO(resp.content)).convert("RGB")
# local path
return Image.open(s).convert("RGB")
raise ValueError(f"Unsupported image input type: {type(image_input)}")
def _get_conv_mode(model_name: str) -> str:
name = (model_name or "").lower()
if "llama-2" in name:
return "llava_llama_2"
if "mistral" in name:
return "mistral_instruct"
if "v1.6-34b" in name:
return "chatml_direct"
if "v1" in name or "pulse" in name:
return "llava_v1"
if "mpt" in name:
return "mpt"
return "llava_v0"
def _build_prompt_with_image(prompt: str, model_cfg) -> str:
    # If the user already included an image token, do not add another one
if DEFAULT_IMAGE_TOKEN in prompt or DEFAULT_IM_START_TOKEN in prompt:
return prompt
if getattr(model_cfg, "mm_use_im_start_end", False):
token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
return f"{token}\n{prompt}"
return f"{DEFAULT_IMAGE_TOKEN}\n{prompt}"
def _resolve_model_path(model_dir_hint: Optional[str], default_dir: str = "/repository") -> str:
    # Priority: HF_MODEL_DIR (local) -> model_dir_hint from the constructor -> default_dir
p = _get_env("HF_MODEL_DIR") or model_dir_hint or default_dir
p = os.path.abspath(p)
print(f"[DEBUG] resolved model path: {p}")
return p
# ---------- Endpoint Handler ----------
class EndpointHandler:
def __init__(self, model_dir: Optional[str] = None):
# DEBUG banner
print("🚀 Starting up PULSE-7B handler (startup load)...")
print("📝 Enhanced by Ubden® Team")
print(f"🔧 Python: {sys.version}")
print(f"🔧 PyTorch: {torch.__version__}")
try:
import transformers
print(f"🔧 Transformers: {transformers.__version__}")
except Exception as e:
print(f"[DEBUG] transformers import failed: {e}")
self.model_dir = model_dir
self.device = _pick_device()
self.dtype = _pick_dtype(self.device)
        # Environment hints (flash attention hint; harmless if unsupported)
os.environ.setdefault("ATTN_IMPLEMENTATION", "flash_attention_2")
os.environ.setdefault("FLASH_ATTENTION", "1")
print(f"[DEBUG] ATTN_IMPLEMENTATION={os.getenv('ATTN_IMPLEMENTATION')} FLASH_ATTENTION={os.getenv('FLASH_ATTENTION')}")
        # Containers for model / tokenizer / image processor
self.model = None
self.tokenizer = None
self.image_processor = None
self.context_len = None
self.model_name = None
        # ---- Load the model here (at startup) ----
try:
self._startup_load_model()
print("✅ Model loaded & ready in __init__")
except Exception as e:
print(f"💥 CRITICAL: model startup load failed: {e}")
raise
def _startup_load_model(self):
        # Use the local directory if present, otherwise the Hub
local_path = _resolve_model_path(self.model_dir)
use_local = os.path.isdir(local_path) and any(
os.path.exists(os.path.join(local_path, f))
for f in ("config.json", "tokenizer_config.json")
)
model_base = _get_env("HF_MODEL_BASE", None)
if use_local:
model_path = local_path
print(f"[DEBUG] loading model LOCALLY from: {model_path}")
else:
model_path = _get_env("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
print(f"[DEBUG] loading model from HUB: {model_path} (HF_MODEL_BASE={model_base})")
        # ⬇️ FIX: the LLaVA v1.2.0 signature requires the model_name parameter
model_name = get_model_name_from_path(model_path)
print(f"[DEBUG] resolved model_name: {model_name}")
print("[DEBUG] calling load_pretrained_model ...")
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
model_path=model_path,
model_base=model_base,
            model_name=model_name,  # <-- required parameter
load_8bit=False,
load_4bit=False,
device_map="auto",
device=self.device,
)
self.model_name = getattr(self.model.config, "name_or_path", str(model_path))
print(f"[DEBUG] model loaded: name={self.model_name}")
        # Vision tower check (new/old config field names)
vt = (
getattr(self.model.config, "mm_vision_tower", None)
or getattr(self.model.config, "vision_tower", None)
)
print(f"[DEBUG] vision tower: {vt}")
if self.image_processor is None or vt is None:
raise RuntimeError(
"[ERROR] Vision tower not loaded (mm_vision_tower/vision_tower None). "
"Yerel yükleme için HF_MODEL_DIR doğru klasörü göstermeli; "
"Hub için HF_MODEL_ID PULSE/LLaVA tabanlı olmalı (örn: 'PULSE-ECG/PULSE-7B')."
)
        # Tokenizer safety
try:
self.tokenizer.padding_side = "left"
if getattr(self.tokenizer, "pad_token_id", None) is None:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
except Exception as e:
print(f"[DEBUG] tokenizer safety patch failed: {e}")
self.model.eval()
    # No-op: the HF inference toolkit will still call load() after __init__
def load(self):
print("[DEBUG] load(): model is already initialized in __init__")
return True
@torch.inference_mode()
def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
print(f"[DEBUG] __call__ inputs keys={list(inputs.keys()) if hasattr(inputs,'keys') else 'N/A'}")
# HF {"inputs": {...}} sarmasını aç
if "inputs" in inputs and isinstance(inputs["inputs"], dict):
inputs = inputs["inputs"]
prompt = inputs.get("query") or inputs.get("prompt") or inputs.get("istem") or ""
image_in = inputs.get("image") or inputs.get("image_url") or inputs.get("img")
if not image_in:
return {"error": "Missing 'image' in payload"}
if not isinstance(prompt, str) or not prompt.strip():
return {"error": "Missing 'query'/'prompt' text"}
        # Generation parameters
temperature = float(inputs.get("temperature", 0.0))
top_p = float(inputs.get("top_p", 0.9))
max_new = int(inputs.get("max_new_tokens", inputs.get("max_tokens", 512)))
repetition_penalty = float(inputs.get("repetition_penalty", 1.0))
conv_mode_override = inputs.get("conv_mode") or _get_env("CONV_MODE", None)
        # ---- Load the image + preprocess
try:
image = _load_image_from_any(image_in)
print(f"[DEBUG] loaded image size={image.size}")
except Exception as e:
return {"error": f"Failed to load image: {e}"}
if self.image_processor is None:
return {"error": "image_processor is None; model not initialized properly (no vision tower)"}
try:
out = self.image_processor.preprocess(image, return_tensors="pt")
images_tensor = out["pixel_values"].to(self.device, dtype=self.dtype)
image_sizes = [image.size]
print(f"[DEBUG] preprocess OK; images_tensor.shape={images_tensor.shape}")
except Exception as e:
return {"error": f"Image preprocessing failed: {e}"}
        # ---- Conversation template + prompt
mode = conv_mode_override or _get_conv_mode(self.model_name)
        # Fall back to the generic "llava_v0" template if the requested mode is unknown
        conv = conv_templates.get(mode, conv_templates["llava_v0"]).copy()
conv.append_message(conv.roles[0], _build_prompt_with_image(prompt.strip(), self.model.config))
conv.append_message(conv.roles[1], None)
full_prompt = conv.get_prompt()
print(f"[DEBUG] conv_mode={mode}; full_prompt_len={len(full_prompt)}")
        # ---- Tokenization (with IMAGE_TOKEN_INDEX)
try:
input_ids = tokenizer_image_token(
full_prompt, self.tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors="pt"
).unsqueeze(0).to(self.device)
print(f"[DEBUG] tokenizer_image_token OK; input_ids.shape={input_ids.shape}")
except Exception as e:
print(f"[DEBUG] tokenizer_image_token failed: {e}; fallback to plain tokenizer")
try:
toks = self.tokenizer([full_prompt], return_tensors="pt", padding=True, truncation=True)
input_ids = toks["input_ids"].to(self.device)
print(f"[DEBUG] plain tokenizer OK; input_ids.shape={input_ids.shape}")
except Exception as e2:
return {"error": f"Tokenization failed: {e} / {e2}"}
attention_mask = torch.ones_like(input_ids, device=self.device)
# ---- Generate
try:
print(f"[DEBUG] generate(max_new_tokens={max_new}, temp={temperature}, top_p={top_p}, rep={repetition_penalty})")
gen_ids = self.model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
images=images_tensor,
image_sizes=image_sizes,
do_sample=(temperature > 0),
temperature=temperature,
top_p=top_p,
max_new_tokens=max_new,
repetition_penalty=repetition_penalty,
use_cache=True,
)
print(f"[DEBUG] generate OK; gen_ids.shape={gen_ids.shape}")
except Exception as e:
return {"error": f"Generation failed: {e}"}
        # ---- Decode (only the newly generated tokens)
        try:
            in_len = input_ids.shape[1]
            # Older LLaVA builds echo the prompt ids before the answer; newer builds
            # (generate() is driven by inputs_embeds) return only the new tokens.
            if gen_ids.shape[1] > in_len and torch.equal(gen_ids[0, :in_len].cpu(), input_ids[0].cpu()):
                new_tokens = gen_ids[0, in_len:]
            else:
                new_tokens = gen_ids[0]
            text = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
            print(f"[DEBUG] decoded_text_len={len(text)}")
        except Exception as e:
            return {"error": f"Decode failed: {e}"}
return {
"generated_text": text,
"model": self.model_name,
"conv_mode": mode,
}
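
# A minimal local smoke test (illustrative only): instantiate the handler and send one
# sample payload. The image URL below is a placeholder; replace it with a real ECG image
# URL, local path, or base64 string before running.
if __name__ == "__main__":
    handler = EndpointHandler(model_dir=os.getenv("HF_MODEL_DIR"))
    sample = {
        "inputs": {
            "query": "Describe the key findings in this ECG.",
            "image": "https://example.com/sample_ecg.png",  # placeholder URL
            "max_new_tokens": 256,
        }
    }
    print(handler(sample))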