Triton error when converting to ONNX

#114
by hoailebads - opened

I am having trouble converting a model that I have fine-tuned to ONNX format. The export fails with the following error:

Traceback (most recent call last):
  File "/validation/convert_module/main.py", line 20, in <module>
    onnx_converter.convert(onnx_path)
  File "/validation/convert_module/onnx_converter.py", line 41, in convert
    torch.onnx.export(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/__init__.py", line 377, in export
    export(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 502, in export
    _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1564, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 997, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 904, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 1500, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 139, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 130, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/validation/convert_module/xlm_roberta_lora/modeling_lora.py", line 374, in forward
    return self.roberta(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/validation/convert_module/xlm_roberta_lora/modeling_xlm_roberta.py", line 736, in forward
    sequence_output = self.encoder(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/validation/convert_module/xlm_roberta_lora/modeling_xlm_roberta.py", line 230, in forward
    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/validation/convert_module/xlm_roberta_lora/block.py", line 201, in forward
    mixer_out = self.mixer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/validation/convert_module/xlm_roberta_lora/mha.py", line 732, in forward
    qkv = self.rotary_emb(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/validation/convert_module/xlm_roberta_lora/rotary.py", line 604, in forward
    return apply_rotary_emb_qkv_(
  File "/validation/convert_module/xlm_roberta_lora/rotary.py", line 327, in apply_rotary_emb_qkv_
    return ApplyRotaryEmbQKV_.apply(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/validation/convert_module/xlm_roberta_lora/rotary.py", line 186, in forward
    apply_rotary(
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/ops/triton/rotary.py", line 213, in apply_rotary
    rotary_kernel[grid](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 662, in run
    kernel = self.compile(
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 276, in compile
    module = src.make_ir(options, codegen_fns, context)
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 113, in make_ir
    return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
triton.compiler.errors.CompilationError: at 34:22:
    # Meta-parameters
    BLOCK_K: tl.constexpr,
    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    INTERLEAVED: tl.constexpr,
    CONJUGATE: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_batch = tl.program_id(axis=1)
    pid_head = tl.program_id(axis=2)
    rotary_dim_half = rotary_dim // 2
                      ^
IncompatibleTypeErrorImpl('invalid operands of type pointer<int64> and triton.language.int32')

My installed packages:

pytorch-triton             3.0.0+dedb7bdf3
torch                      2.5.0a0+e000cf0ad9.nv24.10
torch_tensorrt             2.5.0a0
torchprofile               0.0.4
torchvision                0.20.0a0
sentence-transformers      3.4.1
transformers               4.48.3
flash_attn                 2.4.2
onnx                       1.16.2

Please let me know if you have fixed this issue before.

Sign up or log in to comment