"""
TensorRT Inference Example for WayraPPL

Requires A100 GPU with TensorRT 8.6+
"""

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np
from transformers import AutoTokenizer
import torch


class WayraPPLTensorRT:
    """Run a serialized WayraPPL TensorRT engine with dynamic input shapes.

    The engine is driven through two inputs named ``input_ids`` and
    ``attention_mask``; every tensor the engine reports with
    ``TensorIOMode.OUTPUT`` is returned by :meth:`infer`.
    """

    def __init__(self, engine_path: str):
        """Deserialize the engine at *engine_path* and create an execution context.

        Args:
            engine_path: Path to a serialized TensorRT engine file.

        Raises:
            RuntimeError: If the engine cannot be deserialized or no execution
                context can be created for it.
        """
        # Keep the logger and runtime referenced for the lifetime of the engine
        # so they are not garbage-collected while the engine is still in use.
        self.trt_logger = trt.Logger(trt.Logger.INFO)
        self.runtime = trt.Runtime(self.trt_logger)
        with open(engine_path, "rb") as f:
            engine_data = f.read()
        # deserialize_cuda_engine signals failure by returning None, not by raising.
        self.engine = self.runtime.deserialize_cuda_engine(engine_data)
        if self.engine is None:
            raise RuntimeError(f"Failed to deserialize TensorRT engine: {engine_path}")
        self.context = self.engine.create_execution_context()
        if self.context is None:
            raise RuntimeError(f"Failed to create execution context for: {engine_path}")
        self.stream = cuda.Stream()

    def _engine_dtype(self, tensor_name: str) -> np.dtype:
        """Return the NumPy dtype the engine declares for *tensor_name*."""
        return np.dtype(trt.nptype(self.engine.get_tensor_dtype(tensor_name)))

    def infer(self, input_ids: np.ndarray, attention_mask: np.ndarray) -> dict:
        """Run one forward pass and return every engine output as a host array.

        Args:
            input_ids: 2-D array of token ids, shape (batch, seq_len). It is cast
                to whatever integer dtype the engine declares for ``input_ids``
                (TensorRT 8.6 engines typically use int32, not int64).
            attention_mask: Array with the same shape as *input_ids*; cast to the
                engine's declared dtype for ``attention_mask``.

        Returns:
            Dict mapping each engine output tensor name to a NumPy array whose
            shape is taken from the execution context and whose dtype is the one
            the engine declares (e.g. float16 for an fp16 output).

        Raises:
            ValueError: If the inputs are not 2-D, their shapes differ, or an
                output shape is still unresolved after the input shapes are set.
            RuntimeError: If TensorRT rejects an input shape or fails to enqueue.
        """
        input_ids = np.asarray(input_ids)
        attention_mask = np.asarray(attention_mask)
        if input_ids.ndim != 2:
            raise ValueError(f"input_ids must be 2-D (batch, seq_len), got shape {input_ids.shape}")
        if attention_mask.shape != input_ids.shape:
            raise ValueError(
                f"attention_mask shape {attention_mask.shape} does not match "
                f"input_ids shape {input_ids.shape}"
            )
        batch_size = input_ids.shape[0]

        # Convert each input to the engine's dtype *before* sizing the device
        # buffer, so allocation size and copied bytes always agree. The dict also
        # keeps these host arrays alive until the copies have finished.
        host_inputs = {
            name: np.ascontiguousarray(array, dtype=self._engine_dtype(name))
            for name, array in (("input_ids", input_ids), ("attention_mask", attention_mask))
        }

        device_buffers = []  # every allocation made below; freed in `finally`
        try:
            for name, host_array in host_inputs.items():
                if not self.context.set_input_shape(name, host_array.shape):
                    raise RuntimeError(f"TensorRT rejected shape {host_array.shape} for input '{name}'")
                device_input = cuda.mem_alloc(host_array.nbytes)
                device_buffers.append(device_input)
                cuda.memcpy_htod_async(device_input, host_array, self.stream)
                self.context.set_tensor_address(name, int(device_input))

            outputs = {}
            device_outputs = {}
            for i in range(self.engine.num_io_tensors):
                tensor_name = self.engine.get_tensor_name(i)
                if self.engine.get_tensor_mode(tensor_name) != trt.TensorIOMode.OUTPUT:
                    continue
                # trt.Dims is not a tuple; convert before slicing/concatenating.
                output_shape = tuple(self.context.get_tensor_shape(tensor_name))
                if output_shape and output_shape[0] == -1:
                    output_shape = (batch_size,) + output_shape[1:]
                if any(dim < 0 for dim in output_shape):
                    raise ValueError(f"Unresolved shape {output_shape} for output '{tensor_name}'")
                host_output = np.empty(output_shape, dtype=self._engine_dtype(tensor_name))
                device_output = cuda.mem_alloc(host_output.nbytes)
                device_buffers.append(device_output)
                outputs[tensor_name] = host_output
                device_outputs[tensor_name] = device_output
                self.context.set_tensor_address(tensor_name, int(device_output))

            if not self.context.execute_async_v3(stream_handle=self.stream.handle):
                raise RuntimeError("TensorRT execute_async_v3 failed to enqueue inference")

            for tensor_name, host_output in outputs.items():
                cuda.memcpy_dtoh_async(host_output, device_outputs[tensor_name], self.stream)
            # Host arrays are only valid once every queued copy has completed.
            self.stream.synchronize()
        finally:
            for device_buffer in device_buffers:
                device_buffer.free()

        return outputs


# Usage example
if __name__ == "__main__":
    # Load model
    model = WayraPPLTensorRT("wayrappl_fp16_bs2048.engine")
    tokenizer = AutoTokenizer.from_pretrained(".")

    # Prepare input
    texts = ["Esta es una muestra de texto en español."]
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)

    # Run inference; infer() casts the ids/mask to the dtypes the engine declares.
    outputs = model.infer(
        inputs['input_ids'].numpy(),
        inputs['attention_mask'].numpy()
    )

    print(f"PPL Score: {outputs['ppl']}")