|
from typing import Dict, List, Any |
|
from tempfile import TemporaryDirectory |
|
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration |
|
from PIL import Image |
|
import torch |
|
import requests |
|
|
|
|
|
class EndpointHandler:
    """Custom Hugging Face Inference Endpoints handler for LLaVA-Next.

    Loads llava-v1.6-mistral-7b once at startup and answers image+text
    prompts: the last uploaded file is fetched by URL, combined with the
    text prompt in the Mistral chat template, and decoded to a string.
    """

    def __init__(self, path=""):
        """Load processor and model, placing the model on GPU if available.

        Args:
            path: Unused; kept for the Inference Endpoints handler contract.
        """
        self.processor = LlavaNextProcessor.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf"
        )

        # BUG FIX: torch device strings are 'cuda'/'cpu' — 'gpu' is invalid
        # and made model.to(device) raise on any CUDA machine.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        model = LlavaNextForConditionalGeneration.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf",
            # float16 halves memory on GPU; CPU kernels need float32.
            torch_dtype=torch.float32 if device == "cpu" else torch.float16,
            low_cpu_mem_usage=True,
        )
        model.to(device)

        self.model = model
        self.device = device

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            text (:obj: `str`)
            files (:obj: `list`) - List of URLs to images
        Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        # Default to an empty prompt rather than the whole payload dict,
        # which the original code fell back to by mistake.
        prompt = data.pop("prompt", "")

        # BUG FIX: guard BEFORE indexing — the original indexed
        # `None[-1]` when "files" was missing, raising TypeError ahead
        # of its own None check. Also covers an empty files list.
        files = data.pop("files", None)
        if not files:
            return "You need to upload an image URL for LLaVA to work."
        image_url = files[-1]["path"]

        with TemporaryDirectory() as tmpdirname:
            # Timeout prevents a dead image host from hanging the endpoint.
            response = requests.get(image_url, timeout=30)
            if response.status_code != 200:
                return "Failed to download the image."

            image_path = f"{tmpdirname}/image.jpg"
            with open(image_path, "wb") as f:
                f.write(response.content)

            # Normalize to RGB: the model expects 3-channel input and
            # downloads may be grayscale / RGBA / palette images.
            with Image.open(image_path).convert("RGB") as image:
                # Mistral instruction template with the image placeholder.
                prompt = f"[INST] <image>\n{prompt} [/INST]"

                inputs = self.processor(
                    prompt, image, return_tensors="pt"
                ).to(self.device)

                output = self.model.generate(**inputs, max_new_tokens=100)

                # skip_special_tokens strips <s>, </s>, image tokens, etc.
                clean = self.processor.decode(output[0], skip_special_tokens=True)
                return clean