import logging
import time
from threading import Thread

import pynvml
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
|
|
class Gemma:
    """Wraps a Gemma 3 multimodal checkpoint: streams generated text to stdout and
    reports throughput and GPU memory metrics for each generate() call."""

    def __init__(self, model_id):
        self.model_id = model_id
        self.model = Gemma3ForConditionalGeneration.from_pretrained(
            model_id, device_map="auto", torch_dtype=torch.bfloat16
        ).eval()
        self.processor = AutoProcessor.from_pretrained(model_id)

        # Initialize NVML so GPU memory usage can be queried after generation.
        self.handle = None
        if torch.cuda.is_available():
            try:
                pynvml.nvmlInit()
                # device.index can be None (e.g. a bare "cuda" device); fall back to GPU 0.
                device_id = next(self.model.parameters()).device.index or 0
                self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
            except Exception as e:
                logging.error(f"Failed to initialize NVML: {e}")
|
    def __del__(self):
        if self.handle:
            try:
                pynvml.nvmlShutdown()
            except Exception:
                pass
|
    def generate(self, video, prompt):
        start_time = time.time()

        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant."}],
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}],
            },
        ]

        # Attach each frame of the video to the user turn as an image.
        for image in video:
            messages[1]["content"].append({"type": "image", "image": image})
|
        logging.debug(f"Messages: {messages}")
        inputs = self.processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt"
        ).to(self.model.device)

        logging.info(f"Prompt token length: {len(inputs.input_ids[0])}")

        streamer = TextIteratorStreamer(self.processor, skip_prompt=True, skip_special_tokens=True)
|
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=512,
        )

        # Run generation in a background thread; tokens are consumed from the
        # streamer on the main thread as they are produced.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
|
full_response = "" |
|
print("Response: ", end="") |
|
first_token_time = None |
|
for new_text in streamer: |
|
if first_token_time is None: |
|
first_token_time = time.time() |
|
full_response += new_text |
|
print(new_text, end="", flush=True) |
|
print() |
|
thread.join() |
|
|
|
        end_time = time.time()
        logging.info(f"Total generation time: {end_time - start_time:.2f}s")

        # Decode throughput is measured from the first streamed token to the end of
        # generation, so time-to-first-token (prefill) is excluded.
        if first_token_time is not None:
            generation_time = end_time - first_token_time
        else:
            generation_time = 0

        # Approximate count: the decoded text is re-tokenized (special tokens excluded)
        # rather than counting the ids returned by generate().
        num_generated_tokens = len(
            self.processor.tokenizer(full_response, add_special_tokens=False).input_ids
        )
        tokens_per_second = num_generated_tokens / generation_time if generation_time > 0 else 0

        # NVML reports the memory currently in use on the device, which serves as an
        # approximation of the peak usage during generation.
        peak_memory_mb = 0
        if self.handle:
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(self.handle)
            peak_memory_mb = mem_info.used / (1024 * 1024)
|
        return {
            "response": full_response,
            "tokens_per_second": tokens_per_second,
            "peak_gpu_memory_mb": peak_memory_mb,
            "num_generated_tokens": num_generated_tokens,
        }
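

# Minimal usage sketch. Assumptions (not part of the class above): the
# "google/gemma-3-4b-it" checkpoint is accessible and the sample image URL is
# reachable. A "video" here is simply a list of PIL images, so a single
# downloaded frame is enough to smoke-test the streaming path and the metrics.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    frame = Image.open(
        requests.get(
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
            stream=True,
        ).raw
    )

    gemma = Gemma("google/gemma-3-4b-it")
    result = gemma.generate([frame], "Describe this frame in one sentence.")
    print(
        f"\n{result['num_generated_tokens']} tokens at "
        f"{result['tokens_per_second']:.1f} tok/s, "
        f"{result['peak_gpu_memory_mb']:.0f} MB GPU memory in use"
    )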