import importlib
import types

import torch
import triton
import triton.language as tl

from triton_kernels.specialize import cacheable, specialize
|
|
|
|
|
# Minimal Triton kernel used as the template handed to `specialize` in this
# file: it stores the constant 1.0 through `o`. The test in this file launches
# it on a one-element float32 tensor and checks that the element reads back
# as 1.0.
@triton.jit
def template_kernel(o):
    cst = 1.0
    tl.store(o, cst)
|
|
|
|
|
def retrieve_fn(module, name):
    """Import a module by its dotted name and return one of its attributes.

    Args:
        module: Dotted module path, as accepted by
            ``importlib.import_module``.
        name: Name of the attribute to fetch from the imported module.

    Returns:
        The attribute ``name`` of the imported module.

    Raises:
        ImportError: If the module cannot be imported.
        AttributeError: If the imported module has no attribute ``name``.
    """
    return getattr(importlib.import_module(module), name)
|
|
|
|
|
# Module-level cache for the kernel built by get_specialized_kernel();
# stays None until the first call populates it.
_specialized_kernel = None
|
|
|
|
|
def get_specialized_kernel():
    """Return the specialized form of ``template_kernel``, building it once.

    The first call runs ``specialize`` on ``template_kernel`` and stores the
    result in the module-level ``_specialized_kernel``; every later call
    returns that same object without calling ``specialize`` again.

    Returns:
        The object returned by ``specialize`` for ``template_kernel``.
    """
    global _specialized_kernel
    if _specialized_kernel is not None:
        return _specialized_kernel

    # No constants or tuples are specialized here: both mappings are empty.
    spec_constants = {}
    spec_tuples = {}
    # `specialize` is given a fresh, empty module object named
    # "specialized_kernel"; its result is attached to that module as
    # `specialized` and is also what this function caches and returns.
    module = types.ModuleType("specialized_kernel")
    module.specialized = specialize(template_kernel, module, spec_constants, spec_tuples)
    _specialized_kernel = module.specialized
    return _specialized_kernel
|
|
|
|
|
# NOTE(review): `cacheable` is imported from triton_kernels.specialize and its
# implementation is not visible in this file. test_cacheable expects the JIT
# cache hook to report this function's name ("cacheable_kernel") and this
# module's name for the kernel returned here -- confirm against `cacheable`.
@cacheable
def cacheable_kernel():
    """Return the shared specialized kernel from ``get_specialized_kernel``."""
    return get_specialized_kernel()
|
|
|
|
|
def test_cacheable(device, fresh_knobs):
    """Check that the cacheable kernel can be found by name and preloaded.

    The test runs in three phases:

    1. Launch the specialized kernel with a JIT cache hook installed and
       record the specialization data, function name and module name that
       the hook receives.
    2. Clear the kernel's device caches, look the function up again through
       ``retrieve_fn`` using the recorded names, and ``preload`` it from the
       recorded specialization data. The hook must be invoked exactly once
       and the preloaded kernel's hash must match the first launch.
    3. Launch the kernel again and check that the hook is not invoked.

    Args:
        device: Device on which the output tensor is allocated (fixture
            defined outside this file).
        fresh_knobs: Fixture that is not referenced in the body; presumably
            isolates the ``triton.knobs`` changes made here -- TODO confirm.
    """
    specialized_kernel = get_specialized_kernel()

    specialization_data = None
    fn_name = None
    module_name = None

    def cache_hook(*args, **kwargs):
        # Capture what is needed to locate and preload the kernel later.
        nonlocal specialization_data, fn_name, module_name
        specialization_data = kwargs["compile"]["specialization_data"]
        fn_name = kwargs["fn"].name
        module_name = kwargs["fn"].module

    # Phase 1: first launch, with the recording hook installed.
    triton.knobs.runtime.jit_cache_hook = cache_hook
    o = torch.empty((1, ), dtype=torch.float32, device=device)
    k = specialized_kernel[(1, )](o, )
    # Named `kernel_hash` rather than `hash` to avoid shadowing the builtin.
    kernel_hash = k.hash
    assert o.item() == 1.0
    assert module_name == "tests.test_specialize"
    assert fn_name == "cacheable_kernel"

    compile_count = 0

    def count_hook(*args, **kwargs):
        # Count how many times the JIT cache hook is invoked.
        nonlocal compile_count
        compile_count += 1

    triton.knobs.runtime.jit_cache_hook = count_hook

    # Phase 2: drop the cached kernels so that preload has to go through the
    # hook again.
    specialized_kernel.device_caches.clear()

    fn = retrieve_fn(module_name, fn_name)
    assert fn == specialized_kernel
    preload = fn.preload(specialization_data)
    assert compile_count == 1
    assert preload.hash == kernel_hash

    # Phase 3: after preloading, a launch must not invoke the hook.
    compile_count = 0
    specialized_kernel[(1, )](o, )
    assert compile_count == 0
|
|