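"""Hugging Face Inference Endpoints handler for PULSE-7B (LLaVA-based ECG chat).

Installs LLaVA at startup if needed, loads the model once in __init__, and
answers JSON requests containing an image (URL, base64, data URL, path, or raw
bytes) plus a text prompt.
"""
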
import os
import io
import sys
import base64
import subprocess
from typing import Any, Dict, Optional

import torch
from PIL import Image
import requests


def _ensure_llava(tag: str = "v1.2.0"):
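    """Install the LLaVA package from GitHub at runtime if it is not already importable."""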
    try:
        import llava  # noqa: F401
        print("[DEBUG] LLaVA already available.")
        return
    except ImportError:
        print(f"[DEBUG] LLaVA not found; installing (tag={tag}) ...")
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            f"git+https://github.com/haotian-liu/LLaVA@{tag}#egg=llava",
        ])
        print("[DEBUG] LLaVA installed.")


_ensure_llava("v1.2.0")

# LLaVA modules are imported only after _ensure_llava() has made the package available.
from llava.conversation import conv_templates
from llava.constants import (
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_TOKEN_INDEX,
)
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path


def _get_env(name: str, default: Optional[str] = None) -> Optional[str]:
    v = os.getenv(name)
    return v if v not in (None, "") else default


def _pick_device() -> torch.device:
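    """Pick the best available device: CUDA, then Apple MPS, then CPU."""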
    if torch.cuda.is_available():
        dev = torch.device("cuda")
    elif torch.backends.mps.is_available():
        dev = torch.device("mps")
    else:
        dev = torch.device("cpu")
    print(f"[DEBUG] pick_device -> {dev}")
    return dev


def _pick_dtype(device: torch.device):
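    """Choose bf16/fp16 on CUDA (depending on hardware support) and fp32 elsewhere."""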
    if device.type == "cuda":
        dt = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    else:
        dt = torch.float32
    print(f"[DEBUG] pick_dtype({device}) -> {dt}")
    return dt


def _is_probably_base64(s: str) -> bool:
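    """Heuristically detect whether a string looks like base64 image data or a data URL."""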
    s = s.strip()
    if s.startswith("data:image"):
        return True
    allowed = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=\n\r")
    return len(s) % 4 == 0 and all(c in allowed for c in s)


def _load_image_from_any(image_input: Any) -> Image.Image:
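    """Load an RGB PIL image from a PIL image, raw bytes, a file-like object,
    a data URL, a base64 string, an http(s) URL, or a local file path."""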
    print(f"[DEBUG] _load_image_from_any type={type(image_input)}")
    if isinstance(image_input, Image.Image):
        return image_input.convert("RGB")
    if isinstance(image_input, (bytes, bytearray)):
        return Image.open(io.BytesIO(image_input)).convert("RGB")
    if hasattr(image_input, "read"):
        return Image.open(image_input).convert("RGB")
    if isinstance(image_input, str):
        s = image_input.strip()
        if s.startswith("data:image"):
            try:
                _, b64 = s.split(",", 1)
                data = base64.b64decode(b64)
                return Image.open(io.BytesIO(data)).convert("RGB")
            except Exception as e:
                raise ValueError(f"Bad data URL: {e}")
        if _is_probably_base64(s) and not s.startswith(("http://", "https://")):
            try:
                data = base64.b64decode(s)
                return Image.open(io.BytesIO(data)).convert("RGB")
            except Exception as e:
                raise ValueError(f"Bad base64 image: {e}")
        if s.startswith(("http://", "https://")):
            resp = requests.get(s, timeout=20)
            resp.raise_for_status()
            return Image.open(io.BytesIO(resp.content)).convert("RGB")
        return Image.open(s).convert("RGB")
    raise ValueError(f"Unsupported image input type: {type(image_input)}")


def _get_conv_mode(model_name: str) -> str:
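    """Map a model name to the matching LLaVA conversation template key."""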
    name = (model_name or "").lower()
    if "llama-2" in name:
        return "llava_llama_2"
    if "mistral" in name:
        return "mistral_instruct"
    if "v1.6-34b" in name:
        return "chatml_direct"
    if "v1" in name or "pulse" in name:
        return "llava_v1"
    if "mpt" in name:
        return "mpt"
    return "llava_v0"


def _build_prompt_with_image(prompt: str, model_cfg) -> str:
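    """Prepend the image placeholder token(s) to the prompt if it does not already contain them."""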
    if DEFAULT_IMAGE_TOKEN in prompt or DEFAULT_IM_START_TOKEN in prompt:
        return prompt
    if getattr(model_cfg, "mm_use_im_start_end", False):
        token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        return f"{token}\n{prompt}"
    return f"{DEFAULT_IMAGE_TOKEN}\n{prompt}"


def _resolve_model_path(model_dir_hint: Optional[str], default_dir: str = "/repository") -> str:
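    """Resolve the model directory: HF_MODEL_DIR env var, then the hint, then the default."""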
    p = _get_env("HF_MODEL_DIR") or model_dir_hint or default_dir
    p = os.path.abspath(p)
    print(f"[DEBUG] resolved model path: {p}")
    return p


class EndpointHandler:
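    """Custom handler for Hugging Face Inference Endpoints.

    The model is loaded once in __init__; __call__ answers image + prompt requests.
    """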

    def __init__(self, model_dir: Optional[str] = None):
        print("🚀 Starting up PULSE-7B handler (startup load)...")
        print("📝 Enhanced by Ubden® Team")
        print(f"🔧 Python: {sys.version}")
        print(f"🔧 PyTorch: {torch.__version__}")
        try:
            import transformers
            print(f"🔧 Transformers: {transformers.__version__}")
        except Exception as e:
            print(f"[DEBUG] transformers import failed: {e}")

        self.model_dir = model_dir
        self.device = _pick_device()
        self.dtype = _pick_dtype(self.device)

        os.environ.setdefault("ATTN_IMPLEMENTATION", "flash_attention_2")
        os.environ.setdefault("FLASH_ATTENTION", "1")
        print(f"[DEBUG] ATTN_IMPLEMENTATION={os.getenv('ATTN_IMPLEMENTATION')} FLASH_ATTENTION={os.getenv('FLASH_ATTENTION')}")

        self.model = None
        self.tokenizer = None
        self.image_processor = None
        self.context_len = None
        self.model_name = None

        try:
            self._startup_load_model()
            print("✅ Model loaded & ready in __init__")
        except Exception as e:
            print(f"💥 CRITICAL: model startup load failed: {e}")
            raise

    def _startup_load_model(self):
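        """Load tokenizer, model, image processor, and context length via LLaVA's builder.

        A local checkpoint (HF_MODEL_DIR / the model_dir hint) is preferred; otherwise
        the model is pulled from the Hub via HF_MODEL_ID (default "PULSE-ECG/PULSE-7B").
        """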
        local_path = _resolve_model_path(self.model_dir)
        use_local = os.path.isdir(local_path) and any(
            os.path.exists(os.path.join(local_path, f))
            for f in ("config.json", "tokenizer_config.json")
        )
        model_base = _get_env("HF_MODEL_BASE", None)

        if use_local:
            model_path = local_path
            print(f"[DEBUG] loading model LOCALLY from: {model_path}")
        else:
            model_path = _get_env("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
            print(f"[DEBUG] loading model from HUB: {model_path} (HF_MODEL_BASE={model_base})")

        model_name = get_model_name_from_path(model_path)
        print(f"[DEBUG] resolved model_name: {model_name}")

        print("[DEBUG] calling load_pretrained_model ...")
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path=model_path,
            model_base=model_base,
            model_name=model_name,
            load_8bit=False,
            load_4bit=False,
            device_map="auto",
            device=str(self.device),  # pass a plain string ("cuda"/"cpu"/"mps"), matching the builder's string default
        )
        self.model_name = getattr(self.model.config, "name_or_path", str(model_path))
        print(f"[DEBUG] model loaded: name={self.model_name}")

        # The vision tower must be present, otherwise image inputs cannot be processed.
        vt = (
            getattr(self.model.config, "mm_vision_tower", None)
            or getattr(self.model.config, "vision_tower", None)
        )
        print(f"[DEBUG] vision tower: {vt}")
        if self.image_processor is None or vt is None:
            raise RuntimeError(
                "[ERROR] Vision tower not loaded (mm_vision_tower/vision_tower is None). "
                "For local loading, HF_MODEL_DIR must point to the correct folder; "
                "for Hub loading, HF_MODEL_ID must be a PULSE/LLaVA-based model (e.g. 'PULSE-ECG/PULSE-7B')."
            )

        # Generation-friendly tokenizer defaults.
        try:
            self.tokenizer.padding_side = "left"
            if getattr(self.tokenizer, "pad_token_id", None) is None:
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        except Exception as e:
            print(f"[DEBUG] tokenizer safety patch failed: {e}")

        self.model.eval()

    def load(self):
        print("[DEBUG] load(): model is already initialized in __init__")
        return True

    @torch.inference_mode()
    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
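        """Handle one inference request.

        Expected payload (optionally wrapped in an "inputs" dict):
          - "query" / "prompt" (or "istem"): question text
          - "image" / "image_url" / "img": URL, base64, data URL, local path, or bytes
          - optional: "temperature", "top_p", "max_new_tokens" / "max_tokens",
            "repetition_penalty", "conv_mode"

        Returns {"generated_text", "model", "conv_mode"} on success, or {"error": ...}.
        """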
        print(f"[DEBUG] __call__ inputs keys={list(inputs.keys()) if hasattr(inputs, 'keys') else 'N/A'}")

        # Inference Endpoints may nest the actual payload under "inputs".
        if "inputs" in inputs and isinstance(inputs["inputs"], dict):
            inputs = inputs["inputs"]

        prompt = inputs.get("query") or inputs.get("prompt") or inputs.get("istem") or ""
        image_in = inputs.get("image") or inputs.get("image_url") or inputs.get("img")
        if not image_in:
            return {"error": "Missing 'image' in payload"}
        if not isinstance(prompt, str) or not prompt.strip():
            return {"error": "Missing 'query'/'prompt' text"}

        # Generation parameters.
        temperature = float(inputs.get("temperature", 0.0))
        top_p = float(inputs.get("top_p", 0.9))
        max_new = int(inputs.get("max_new_tokens", inputs.get("max_tokens", 512)))
        repetition_penalty = float(inputs.get("repetition_penalty", 1.0))
        conv_mode_override = inputs.get("conv_mode") or _get_env("CONV_MODE", None)

        # Load and preprocess the image.
        try:
            image = _load_image_from_any(image_in)
            print(f"[DEBUG] loaded image size={image.size}")
        except Exception as e:
            return {"error": f"Failed to load image: {e}"}

        if self.image_processor is None:
            return {"error": "image_processor is None; model not initialized properly (no vision tower)"}

        try:
            out = self.image_processor.preprocess(image, return_tensors="pt")
            images_tensor = out["pixel_values"].to(self.device, dtype=self.dtype)
            image_sizes = [image.size]
            print(f"[DEBUG] preprocess OK; images_tensor.shape={images_tensor.shape}")
        except Exception as e:
            return {"error": f"Image preprocessing failed: {e}"}

        # Build the conversation prompt with the image placeholder token.
        mode = conv_mode_override or _get_conv_mode(self.model_name)
        conv = (conv_templates.get(mode) or conv_templates[list(conv_templates.keys())[0]]).copy()
        conv.append_message(conv.roles[0], _build_prompt_with_image(prompt.strip(), self.model.config))
        conv.append_message(conv.roles[1], None)
        full_prompt = conv.get_prompt()
        print(f"[DEBUG] conv_mode={mode}; full_prompt_len={len(full_prompt)}")

        # Tokenize, mapping the image placeholder to IMAGE_TOKEN_INDEX.
        try:
            input_ids = tokenizer_image_token(
                full_prompt, self.tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors="pt"
            ).unsqueeze(0).to(self.device)
            print(f"[DEBUG] tokenizer_image_token OK; input_ids.shape={input_ids.shape}")
        except Exception as e:
            print(f"[DEBUG] tokenizer_image_token failed: {e}; fallback to plain tokenizer")
            try:
                toks = self.tokenizer([full_prompt], return_tensors="pt", padding=True, truncation=True)
                input_ids = toks["input_ids"].to(self.device)
                print(f"[DEBUG] plain tokenizer OK; input_ids.shape={input_ids.shape}")
            except Exception as e2:
                return {"error": f"Tokenization failed: {e} / {e2}"}

        attention_mask = torch.ones_like(input_ids, device=self.device)

        # Generate.
        try:
            print(f"[DEBUG] generate(max_new_tokens={max_new}, temp={temperature}, top_p={top_p}, rep={repetition_penalty})")
            gen_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                images=images_tensor,
                image_sizes=image_sizes,
                do_sample=(temperature > 0),
                temperature=temperature,
                top_p=top_p,
                max_new_tokens=max_new,
                repetition_penalty=repetition_penalty,
                use_cache=True,
            )
            print(f"[DEBUG] generate OK; gen_ids.shape={gen_ids.shape}")
        except Exception as e:
            return {"error": f"Generation failed: {e}"}

        # Decode. LLaVA's generate (v1.2+) returns only the newly generated tokens because the
        # prompt is consumed as inputs_embeds; strip the prompt only if it was actually echoed.
        try:
            out_ids = gen_ids[0]
            if out_ids.shape[0] > input_ids.shape[1] and torch.equal(
                out_ids[: input_ids.shape[1]].cpu(), input_ids[0].cpu()
            ):
                out_ids = out_ids[input_ids.shape[1]:]
            text = self.tokenizer.decode(out_ids, skip_special_tokens=True).strip()
            print(f"[DEBUG] decoded_text_len={len(text)}")
        except Exception as e:
            return {"error": f"Decode failed: {e}"}

        return {
            "generated_text": text,
            "model": self.model_name,
            "conv_mode": mode,
        }
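

if __name__ == "__main__":
    # Minimal local smoke test with hypothetical values: instantiating the handler
    # downloads/loads the full model, so only run this where that is feasible.
    handler = EndpointHandler("/repository")
    example_payload = {
        "inputs": {
            "query": "What abnormalities are visible in this ECG?",
            "image": "https://example.com/ecg.png",  # placeholder URL
            "max_new_tokens": 256,
        }
    }
    print(handler(example_payload))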