File size: 3,438 Bytes
2991961
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""
TensorRT Inference Example for WayraPPL
Requires A100 GPU with TensorRT 8.6+
"""

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np
from transformers import AutoTokenizer
import torch

class WayraPPLTensorRT:
    """Run a serialized WayraPPL TensorRT engine on tokenized batches.

    The engine is expected to expose two inputs named ``"input_ids"`` and
    ``"attention_mask"``; every tensor the engine reports as an output is
    returned by :meth:`infer`, keyed by its tensor name.
    """

    def __init__(self, engine_path: str):
        """Deserialize the engine stored at *engine_path*.

        Args:
            engine_path: Path to a serialized TensorRT engine file.

        Raises:
            OSError: If the engine file cannot be read.
            RuntimeError: If TensorRT cannot deserialize the engine or
                cannot create an execution context for it.
        """
        trt_logger = trt.Logger(trt.Logger.INFO)
        # Keep a reference to the runtime for as long as the engine is used
        # instead of letting it go out of scope with __init__.
        self.runtime = trt.Runtime(trt_logger)

        with open(engine_path, 'rb') as f:
            engine_data = f.read()

        self.engine = self.runtime.deserialize_cuda_engine(engine_data)
        if self.engine is None:
            # TensorRT reports deserialization failure by returning None.
            raise RuntimeError(
                f"Failed to deserialize TensorRT engine from {engine_path!r}"
            )

        self.context = self.engine.create_execution_context()
        if self.context is None:
            raise RuntimeError("Failed to create TensorRT execution context")

        self.stream = cuda.Stream()

    def _to_engine_dtype(self, tensor_name: str, array: np.ndarray) -> np.ndarray:
        """Return *array* as a C-contiguous array of the engine's dtype.

        The dtype is queried from the engine so the bytes copied to the
        device match what the engine reads for *tensor_name*.
        """
        dtype = trt.nptype(self.engine.get_tensor_dtype(tensor_name))
        return np.ascontiguousarray(array, dtype=dtype)

    def infer(self, input_ids: np.ndarray, attention_mask: np.ndarray):
        """Run one inference pass and return all engine outputs.

        Args:
            input_ids: 2-D array of token ids, shape ``(batch, seq_len)``.
            attention_mask: Array with exactly the same shape as *input_ids*.

        Returns:
            Dict mapping each output tensor name to a host ``numpy`` array.
            Floating-point outputs are returned as ``float32``; other
            outputs keep the engine's dtype.

        Raises:
            ValueError: If the inputs are not 2-D, are empty, differ in
                shape, are rejected by the engine, or an output shape
                stays unresolved.
            RuntimeError: If TensorRT fails to enqueue the inference.
        """
        input_ids = np.asarray(input_ids)
        attention_mask = np.asarray(attention_mask)

        # Validate before touching the GPU so bad calls fail cheaply.
        if input_ids.ndim != 2:
            raise ValueError(
                f"input_ids must be 2-D (batch, seq_len), got shape {input_ids.shape}"
            )
        if attention_mask.shape != input_ids.shape:
            raise ValueError(
                f"attention_mask shape {attention_mask.shape} does not match "
                f"input_ids shape {input_ids.shape}"
            )
        if input_ids.size == 0:
            raise ValueError("input_ids must not be empty")

        batch_size = input_ids.shape[0]

        # Convert first, then size the device buffers from the converted
        # arrays, so the allocation always matches the bytes copied.
        host_inputs = {
            "input_ids": self._to_engine_dtype("input_ids", input_ids),
            "attention_mask": self._to_engine_dtype("attention_mask", attention_mask),
        }

        device_buffers = []  # every allocation, released in ``finally``
        device_outputs = {}
        outputs = {}

        try:
            # Set dynamic shapes on all inputs before querying output shapes.
            for name, host_array in host_inputs.items():
                if not self.context.set_input_shape(name, host_array.shape):
                    raise ValueError(
                        f"Engine rejected shape {host_array.shape} for input {name!r}"
                    )

            # Allocate, upload and bind the inputs.
            for name, host_array in host_inputs.items():
                device_input = cuda.mem_alloc(host_array.nbytes)
                device_buffers.append(device_input)
                cuda.memcpy_htod_async(device_input, host_array, self.stream)
                self.context.set_tensor_address(name, int(device_input))

            # Allocate and bind a buffer for every output tensor.
            for i in range(self.engine.num_io_tensors):
                tensor_name = self.engine.get_tensor_name(i)
                if self.engine.get_tensor_mode(tensor_name) != trt.TensorIOMode.OUTPUT:
                    continue

                # Convert to a plain tuple so it can be sliced and
                # concatenated like one.
                output_shape = tuple(self.context.get_tensor_shape(tensor_name))
                if output_shape and output_shape[0] == -1:
                    output_shape = (batch_size,) + output_shape[1:]
                if any(dim < 0 for dim in output_shape):
                    raise ValueError(
                        f"Unresolved shape {output_shape} for output {tensor_name!r}"
                    )

                # Use the engine's output dtype so the buffer size and the
                # interpretation of the bytes are both correct.
                output_dtype = trt.nptype(self.engine.get_tensor_dtype(tensor_name))
                host_output = np.empty(output_shape, dtype=output_dtype)
                device_output = cuda.mem_alloc(host_output.nbytes)
                device_buffers.append(device_output)

                outputs[tensor_name] = host_output
                device_outputs[tensor_name] = device_output
                self.context.set_tensor_address(tensor_name, int(device_output))

            # Execute
            if not self.context.execute_async_v3(stream_handle=self.stream.handle):
                raise RuntimeError("TensorRT execute_async_v3 failed to enqueue inference")

            # Copy outputs back to the host.
            for tensor_name, host_output in outputs.items():
                cuda.memcpy_dtoh_async(host_output, device_outputs[tensor_name], self.stream)

            self.stream.synchronize()
        finally:
            # Release device memory even when a step above raised.
            for device_buffer in device_buffers:
                device_buffer.free()

        # Keep the float32 return contract for floating-point outputs.
        return {
            name: (
                host.astype(np.float32, copy=False)
                if np.issubdtype(host.dtype, np.floating)
                else host
            )
            for name, host in outputs.items()
        }

# Usage example: score a single Spanish sentence with the TensorRT engine.
if __name__ == "__main__":
    # Deserialize the engine, then load the tokenizer from the current directory.
    engine_runner = WayraPPLTensorRT("wayrappl_fp16_bs2048.engine")
    tok = AutoTokenizer.from_pretrained(".")

    # Tokenize the sample batch (padded, truncated to at most 512 tokens).
    samples = ["Esta es una muestra de texto en español."]
    encoded = tok(
        samples,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )

    # Hand the tensors to the engine as NumPy arrays.
    ids_array = encoded['input_ids'].numpy()
    mask_array = encoded['attention_mask'].numpy()
    results = engine_runner.infer(ids_array, mask_array)

    ppl_value = results['ppl']
    print(f"PPL Score: {ppl_value}")