from typing import Any, Dict
from tempfile import TemporaryDirectory

import requests
import torch
from PIL import Image
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration


class EndpointHandler:
    def __init__(self, path: str = ""):
        self.processor = LlavaNextProcessor.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf"
        )
        # Use CUDA when available; fall back to CPU with float32 otherwise.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = LlavaNextForConditionalGeneration.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf",
            torch_dtype=torch.float32 if device == "cpu" else torch.float16,
            low_cpu_mem_usage=True,
        )
        model.to(device)
        self.model = model
        self.device = device

    def __call__(self, data: Dict[str, Any]) -> str:
        """
        data args:
            prompt (:obj:`str`): the question to ask about the image
            files (:obj:`list`): list of dicts, each with a "path" key holding an image URL
        Return:
            A :obj:`str` with the generated answer; it will be serialized and returned.
        """
        # Get the text prompt.
        prompt = data.pop("prompt", None)

        # Get the additional "files" field and use the last uploaded image.
        files = data.pop("files", None)
        if not files:
            return "You need to upload an image URL for LLaVA to work."
        image_url = files[-1]["path"]
        print(image_url)
        print(prompt)

        # Create a temporary directory for the downloaded image.
        with TemporaryDirectory() as tmpdirname:
            # Download the image.
            response = requests.get(image_url)
            if response.status_code != 200:
                return "Failed to download the image."

            # Define the path for the downloaded image and write it to disk.
            image_path = f"{tmpdirname}/image.jpg"
            with open(image_path, "wb") as f:
                f.write(response.content)

            # Open the downloaded image and run generation.
            with Image.open(image_path).convert("RGB") as image:
                # LLaVA-NeXT (Mistral) prompt template; the <image> placeholder
                # marks where the image features are inserted.
                prompt = f"[INST] <image>\n{prompt} [/INST]"
                inputs = self.processor(
                    text=prompt, images=image, return_tensors="pt"
                ).to(self.device)
                output = self.model.generate(**inputs, max_new_tokens=100)
                clean = self.processor.decode(output[0], skip_special_tokens=True)
                return clean
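

# A minimal local smoke test for the handler above, not part of the Inference
# Endpoints contract: it assumes network access and local GPU/CPU resources,
# and the image URL below is only an illustrative example; adjust both before
# running.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "prompt": "What is shown in this image?",
        "files": [{"path": "https://www.ilankelman.org/stopsigns/australia.jpg"}],
    }
    print(handler(payload))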