File size: 2,892 Bytes
b4cad21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#pragma once

#include <torch/torch.h>

/// Reports whether the CUTLASS scaled-mm path can handle FP8 inputs on a
/// device with the given CUDA compute capability.
///
/// NOTE(review): semantics inferred from the name; the definition lives in
/// another translation unit — confirm there.
///
/// @param cuda_device_capability  CUDA compute capability of the target
///        device. Assumed to be encoded as major * 10 + minor (e.g. 89, 90) —
///        TODO confirm against callers.
/// @return true if an FP8 kernel is available for that capability.
///
/// [[nodiscard]]: the name indicates a query, so discarding the answer is
/// almost certainly a caller bug.
[[nodiscard]] bool cutlass_scaled_mm_supports_fp8(
    int64_t cuda_device_capability);
                                                                                                                                                                                                                     
/// Scaled matrix multiply of `a` and `b` dispatched to CUTLASS; the result is
/// written into `out`.
///
/// `out` is the only non-const parameter, so it is the destination; every
/// other tensor is read-only. Presumably computes
///   out = (a_scales * a) @ (b_scales * b) [+ bias]
/// NOTE(review): confirm the exact scaling/broadcast rules (per-tensor vs.
/// per-row/per-column scales) and supported dtypes against the definition.
///
/// @param out       Output tensor, written in place (assumed preallocated by
///                  the caller — TODO confirm).
/// @param a         Left operand.
/// @param b         Right operand.
/// @param a_scales  Scale factors presumably applied to `a`.
/// @param b_scales  Scale factors presumably applied to `b`.
/// @param bias      Optional bias; pass c10::nullopt to omit it.
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
                       c10::optional<torch::Tensor> const& bias);
                                                                                                                                                                                                                     
/// Variant of cutlass_scaled_mm that additionally takes zero-point ("azp",
/// presumably asymmetric zero point) correction inputs; the result is written
/// into `out`.
///
/// NOTE(review): the following is inferred from parameter names — confirm
/// against the definition. `azp_adj` is presumably a precomputed correction
/// term derived from `b` and the zero point(s). `azp`, when present,
/// presumably holds per-row zero points for `a`. When it is absent, the zero
/// point is presumably already folded into `azp_adj`.
///
/// @param out       Output tensor, written in place (the only non-const
///                  parameter).
/// @param a         Left operand.
/// @param b         Right operand.
/// @param a_scales  Scale factors presumably applied to `a`.
/// @param b_scales  Scale factors presumably applied to `b`.
/// @param azp_adj   Zero-point adjustment term (always required).
/// @param azp       Optional zero points; may be c10::nullopt.
/// @param bias      Optional bias; may be c10::nullopt.
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
                           c10::optional<torch::Tensor> const& azp,
                           c10::optional<torch::Tensor> const& bias);