# pip install accelerate
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
from PIL import Image
import requests
import torch
from threading import Thread
import logging
import time
import pynvml


class Gemma:
    def __init__(self, model_id):
        self.model_id = model_id
        self.model = Gemma3ForConditionalGeneration.from_pretrained(
            model_id, device_map="auto", torch_dtype=torch.bfloat16
        ).eval()
        self.processor = AutoProcessor.from_pretrained(model_id)

        # Set up NVML so GPU memory usage can be queried after generation.
        self.handle = None
        if torch.cuda.is_available():
            try:
                pynvml.nvmlInit()
                device_id = next(self.model.parameters()).device.index
                self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
            except Exception as e:
                logging.error(f"Failed to initialize NVML: {e}")

    def __del__(self):
        if self.handle:
            try:
                pynvml.nvmlShutdown()
            except Exception:
                pass

    def generate(self, video, prompt):
        start_time = time.time()

        # Build the chat: a system turn, then a user turn containing the text
        # prompt followed by one image entry per video frame.
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant."}],
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}],
            },
        ]
        for image in video:
            messages[1]["content"].append({"type": "image", "image": image})
        print(messages)

        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)
        logging.info(f"Prompt token length: {len(inputs.input_ids[0])}")

        # Run generation in a background thread and stream tokens as they arrive.
        streamer = TextIteratorStreamer(self.processor, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=512)
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        full_response = ""
        print("Response: ", end="")
        first_token_time = None
        for new_text in streamer:
            if first_token_time is None:
                first_token_time = time.time()
            full_response += new_text
            print(new_text, end="", flush=True)
        print()
        thread.join()
        end_time = time.time()

        # Throughput is measured from the first streamed token to the end of generation.
        if first_token_time is not None:
            generation_time = end_time - first_token_time
        else:
            generation_time = 0
        num_generated_tokens = len(self.processor.tokenizer(full_response).input_ids)
        tokens_per_second = num_generated_tokens / generation_time if generation_time > 0 else 0

        # NVML reports the GPU memory in use at this point, not a true peak over the run.
        peak_memory_mb = 0
        if self.handle:
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(self.handle)
            peak_memory_mb = mem_info.used / (1024 * 1024)

        return {
            "response": full_response,
            "tokens_per_second": tokens_per_second,
            "peak_gpu_memory_mb": peak_memory_mb,
            "num_generated_tokens": num_generated_tokens,
        }
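

# Usage sketch (not part of the original class): drives Gemma.generate with a list of
# frames and prints the returned metrics. The checkpoint ID and frame filenames below
# are assumptions for illustration only; substitute a Gemma 3 checkpoint you have
# access to and frames that exist on disk.
if __name__ == "__main__":
    gemma = Gemma("google/gemma-3-4b-it")  # assumed checkpoint ID
    frames = [Image.open(f"frame_{i}.jpg") for i in range(4)]  # hypothetical video frames
    result = gemma.generate(frames, "Describe what happens across these frames.")
    print(f"\n{result['tokens_per_second']:.1f} tok/s, "
          f"{result['peak_gpu_memory_mb']:.0f} MB GPU memory in use")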