llava-next-inference / handler.py
eBoreal's picture
gpu support
f0e2eeb
raw
history blame
2.26 kB
from typing import Dict, List, Any
from tempfile import TemporaryDirectory
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
from PIL import Image
import torch
import requests
class EndpointHandler:
    """Custom inference handler that answers a text prompt about one image.

    The handler loads the LLaVA-NeXT (v1.6, Mistral-7B) processor and model
    once at construction time and then serves requests through ``__call__``.
    """

    # Hub identifier used for both the processor and the model weights.
    MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"
    # Seconds to wait for the image download before giving up.
    DOWNLOAD_TIMEOUT = 30

    def __init__(self, path=""):
        """Load the processor and model and move the model to the best device.

        Args:
            path: Accepted for interface compatibility; it is not used here,
                the weights are always loaded from ``MODEL_ID``.
        """
        self.processor = LlavaNextProcessor.from_pretrained(self.MODEL_ID)
        # torch names the NVIDIA GPU device type "cuda"; "gpu" is not a valid
        # device string and makes ``model.to(...)`` raise.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = LlavaNextForConditionalGeneration.from_pretrained(
            self.MODEL_ID,
            # Half precision on GPU, full precision on CPU.
            torch_dtype=torch.float32 if device == "cpu" else torch.float16,
            low_cpu_mem_usage=True,
        )
        model.to(device)
        self.model = model
        self.device = device

    @staticmethod
    def _extract_image_url(files):
        """Return the URL of the last entry in ``files``, or None if absent."""
        if not isinstance(files, (list, tuple)) or not files:
            return None
        entry = files[-1]
        if isinstance(entry, dict):
            return entry.get("path") or None
        if isinstance(entry, str):
            # Also accept a plain list of URL strings.
            return entry or None
        return None

    def __call__(self, data: Dict[str, Any]) -> str:
        """Run LLaVA on the prompt and the last image listed in the request.

        Args:
            data: Request payload. ``prompt`` holds the user text; when it is
                missing the whole payload is used as the prompt. ``files`` is
                a list whose last entry is either a dict with a ``path`` URL
                or a plain URL string.

        Returns:
            The decoded model output, or a short human-readable message when
            no image URL was supplied or the image could not be downloaded.
        """
        # Get inputs; both keys are removed from the payload.
        prompt = data.pop("prompt", data)
        # Validate ``files`` before indexing so a missing or empty list yields
        # the friendly message below instead of a TypeError/IndexError.
        image_url = self._extract_image_url(data.pop("files", None))
        print(image_url)
        print(prompt)
        if image_url is None:
            return "You need to upload an image URL for LLaVA to work."

        # Download the image; a timeout keeps a stalled host from hanging the
        # worker, and transport errors map to the same failure message.
        try:
            response = requests.get(image_url, timeout=self.DOWNLOAD_TIMEOUT)
        except requests.RequestException:
            return "Failed to download the image."
        if response.status_code != 200:
            return "Failed to download the image."

        with TemporaryDirectory() as tmpdirname:
            # Path for the downloaded image inside the temporary directory.
            image_path = f"{tmpdirname}/image.jpg"
            with open(image_path, "wb") as f:
                f.write(response.content)
            # Close the file-backed image explicitly; convert() returns an
            # in-memory copy that stays valid after the directory is removed.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")

        prompt = f"[INST] <image>\n{prompt} [/INST]"
        # Keyword arguments avoid depending on the processor's positional
        # (text, images) order.
        inputs = self.processor(
            text=prompt, images=image, return_tensors="pt"
        ).to(self.device)
        output = self.model.generate(**inputs, max_new_tokens=100)
        return self.processor.decode(output[0], skip_special_tokens=True)