# File size: 2,190 Bytes
# Revision: 567c8ad
import torch
import importlib
from triton_kernels.specialize import cacheable, specialize
import triton
import triton.language as tl
# Minimal Triton kernel that `get_specialized_kernel` hands to `specialize`
# as its template: it stores the constant 1.0 through `o`.
@triton.jit
def template_kernel(o):
    # `test_cacheable` checks this value on the host via `o.item() == 1.0`.
    cst = 1.0
    tl.store(o, cst)
def retrieve_fn(module, name):
    """Import ``module`` by its dotted path and return its attribute ``name``.

    Args:
        module: Module path in the form accepted by
            ``importlib.import_module``.
        name: Attribute to look up on the imported module.

    Returns:
        Whatever object is bound to ``name`` in the imported module.
    """
    return getattr(importlib.import_module(module), name)
# Process-wide memo for the kernel built by `get_specialized_kernel`.
_specialized_kernel = None


def get_specialized_kernel():
    """Return the specialized form of ``template_kernel``, building it once.

    On the first call a throwaway module object named "specialized_kernel"
    is created and passed to ``specialize`` together with ``template_kernel``
    and two empty dicts (no constants, no tuples). The result is stored in
    the module-level ``_specialized_kernel`` and returned as-is on every
    later call.
    """
    global _specialized_kernel
    if _specialized_kernel is None:
        import types

        holder = types.ModuleType("specialized_kernel")
        holder.specialized = specialize(template_kernel, holder, {}, {})
        _specialized_kernel = holder.specialized
    return _specialized_kernel
@cacheable
def cacheable_kernel():
    """Hand the shared specialized kernel to the ``cacheable`` decorator."""
    kernel = get_specialized_kernel()
    return kernel
def test_cacheable(device, fresh_knobs):
    """Check that a ``@cacheable`` kernel can be found by name and preloaded.

    Steps:
      1. Launch the specialized kernel while a JIT cache hook records the
         specialization data and the function/module names it is given.
      2. Clear the kernel's ``device_caches``.
      3. Re-import the kernel using the recorded names, preload it from the
         recorded specialization data, and check that the hook fired once
         and that the preloaded kernel has the same hash as the first one.
      4. Launch again and check that the hook is not invoked.

    Args:
        device: Device on which the one-element output tensor is allocated.
        fresh_knobs: Not used directly in the body.
            NOTE(review): presumably a fixture that resets ``triton.knobs``
            so the hooks installed here do not leak into other tests; confirm.
    """
    specialized_kernel = get_specialized_kernel()

    # Values captured by the first hook, later used to re-locate the kernel.
    specialization_data = None
    fn_name = None
    module_name = None

    def cache_hook(*args, **kwargs):
        nonlocal specialization_data, fn_name, module_name
        specialization_data = kwargs["compile"]["specialization_data"]
        fn_name = kwargs["fn"].name
        module_name = kwargs["fn"].module

    triton.knobs.runtime.jit_cache_hook = cache_hook
    o = torch.empty((1, ), dtype=torch.float32, device=device)
    k = specialized_kernel[(1, )](o, )
    # Named `kernel_hash` rather than `hash` to avoid shadowing the builtin.
    kernel_hash = k.hash
    assert o.item() == 1.0
    assert module_name == "tests.test_specialize"
    assert fn_name == "cacheable_kernel"

    compile_count = 0

    def count_hook(*args, **kwargs):
        nonlocal compile_count
        compile_count += 1

    triton.knobs.runtime.jit_cache_hook = count_hook
    # clear the cache
    specialized_kernel.device_caches.clear()

    # retrieve the kernel from name and preload it.
    fn = retrieve_fn(module_name, fn_name)
    assert fn == specialized_kernel
    preload = fn.preload(specialization_data)
    assert compile_count == 1
    assert preload.hash == kernel_hash

    # verify that we hit the cache.
    compile_count = 0
    specialized_kernel[(1, )](o, )
    assert compile_count == 0
|