import logging
import time

import pynvml
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer


class MiniCPM:
    """Wrapper around a MiniCPM-V checkpoint for multi-image (video-frame) chat."""

    def __init__(self, model_id):
        self.model_id = model_id
        # Load the model with its custom remote code and run it in bfloat16 on the GPU.
        self.model = AutoModel.from_pretrained(
            model_id,
            trust_remote_code=True,
            attn_implementation='sdpa',
            torch_dtype=torch.bfloat16,
        )
        self.model = self.model.eval().cuda()
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id, trust_remote_code=True
        )

        # Set up NVML so GPU memory usage can be queried after generation.
        self.handle = None
        if torch.cuda.is_available():
            try:
                pynvml.nvmlInit()
                device_id = next(self.model.parameters()).device.index
                self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
            except Exception as e:
                logging.error(f"Failed to initialize NVML: {e}")
    def __del__(self):
        # Release the NVML handle when the object is garbage-collected.
        if self.handle:
            try:
                pynvml.nvmlShutdown()
            except Exception:
                pass
    def generate(self, video, prompt):
        """Stream a response for a list of video frame paths plus a text prompt."""
        # Decode each frame path into an RGB PIL image; MiniCPM-V takes the images
        # followed by the text prompt as a single user turn.
        images = [Image.open(frame).convert('RGB') for frame in video]
        content = images + [prompt]
        msgs = [{'role': 'user', 'content': content}]

        # Stream tokens as they are produced so time-to-first-token can be measured.
        res = self.model.chat(
            image=None,
            msgs=msgs,
            tokenizer=self.tokenizer,
            stream=True
        )
full_response = "" |
|
print("Response: ", end="") |
|
first_token_time = None |
|
for new_text in res: |
|
if first_token_time is None: |
|
first_token_time = time.time() |
|
full_response += new_text |
|
print(new_text, end="", flush=True) |
|
print() |
|
|
|
end_time = time.time() |
|
|
|
if first_token_time is not None: |
|
generation_time = end_time - first_token_time |
|
else: |
|
generation_time = 0 |
|
|
|
num_generated_tokens = len(self.tokenizer(full_response).input_ids) |
|
tokens_per_second = num_generated_tokens / generation_time if generation_time > 0 else 0 |
|
|
|
        # NVML reports the device's current memory usage; sampling it right after
        # generation approximates the peak for this call.
        peak_memory_mb = 0
        if self.handle:
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(self.handle)
            peak_memory_mb = mem_info.used / (1024 * 1024)

        return {
            "response": full_response,
            "tokens_per_second": tokens_per_second,
            "peak_gpu_memory_mb": peak_memory_mb,
            "num_generated_tokens": num_generated_tokens,
        }
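

# Minimal usage sketch. The checkpoint id and frame paths below are illustrative
# assumptions, not part of the original code.
if __name__ == "__main__":
    runner = MiniCPM("openbmb/MiniCPM-V-2_6")  # assumed/example checkpoint id
    frames = ["frame_000.jpg", "frame_001.jpg", "frame_002.jpg"]  # example frame paths
    stats = runner.generate(frames, "Describe what happens across these frames.")
    print(f"{stats['num_generated_tokens']} tokens "
          f"at {stats['tokens_per_second']:.2f} tok/s, "
          f"{stats['peak_gpu_memory_mb']:.0f} MB GPU memory used")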