File size: 2,190 Bytes
567c8ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import importlib
from triton_kernels.specialize import cacheable, specialize
import triton
import triton.language as tl


@triton.jit
def template_kernel(o):
    # Minimal kernel handed to `specialize` as the template: it stores the
    # constant 1.0 through the pointer `o`. The test launches it on a
    # one-element float32 tensor and checks that the element reads back as 1.0.
    cst = 1.0
    tl.store(o, cst)


def retrieve_fn(module, name):
    """Import ``module`` by its dotted name and return its attribute ``name``.

    Args:
        module: Dotted import path of the module, e.g. ``"os.path"``.
        name: Name of the attribute to fetch from the imported module.

    Returns:
        Whatever object is bound to ``name`` in the imported module.

    Raises:
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute called ``name``.
    """
    return getattr(importlib.import_module(module), name)


# Process-wide memo for the kernel built by get_specialized_kernel();
# stays None until the first call.
_specialized_kernel = None


def get_specialized_kernel():
    """Return the specialized form of ``template_kernel``, building it once.

    On the first call a fresh ``types.ModuleType`` named
    ``"specialized_kernel"`` is created and passed to ``specialize`` together
    with ``template_kernel`` and two empty dicts (constants and tuples). The
    result is stored on that module as ``specialized`` and memoized in the
    module-level ``_specialized_kernel``; later calls return the same object.
    """
    global _specialized_kernel
    if _specialized_kernel is None:
        import types

        holder = types.ModuleType("specialized_kernel")
        # No constants or tuples are specialized: both mappings are empty.
        holder.specialized = specialize(template_kernel, holder, {}, {})
        _specialized_kernel = holder.specialized
    return _specialized_kernel


@cacheable
def cacheable_kernel():
    # Importable, module-level name for the specialized kernel. test_cacheable
    # resolves this name through retrieve_fn() and expects the result to
    # compare equal to the object returned by get_specialized_kernel().
    # NOTE(review): what `cacheable` does with this function (e.g. when it
    # calls it and what it binds to the name) is not visible in this file --
    # confirm against triton_kernels.specialize.
    return get_specialized_kernel()


def test_cacheable(device, fresh_knobs):
    """Check that a cacheable kernel can be found by name and preloaded.

    Steps:
      1. Launch the specialized kernel with a recording JIT cache hook and
         capture the specialization data plus the function/module names the
         hook is given.
      2. Clear the kernel's device caches, look the kernel up again by the
         recorded module and function names, and preload it from the recorded
         specialization data; the counting hook must fire exactly once and the
         preloaded kernel must have the same hash as the first launch.
      3. Launch again and verify the counting hook does not fire.

    Args:
        device: Device on which the output tensor is allocated.
        fresh_knobs: Fixture requested by the test; not referenced in the body.
    """
    kernel = get_specialized_kernel()

    # Filled in by the first hook; values stay None if the hook never runs.
    seen = {"specialization_data": None, "fn_name": None, "module_name": None}

    def record_hook(*args, **kwargs):
        seen["specialization_data"] = kwargs["compile"]["specialization_data"]
        seen["fn_name"] = kwargs["fn"].name
        seen["module_name"] = kwargs["fn"].module

    triton.knobs.runtime.jit_cache_hook = record_hook
    out = torch.empty((1,), dtype=torch.float32, device=device)
    launched = kernel[(1,)](out)
    expected_hash = launched.hash
    assert out.item() == 1.0
    assert seen["module_name"] == "tests.test_specialize"
    assert seen["fn_name"] == "cacheable_kernel"

    # One entry is appended per invocation of the counting hook.
    hook_calls = []

    def count_hook(*args, **kwargs):
        hook_calls.append(1)

    triton.knobs.runtime.jit_cache_hook = count_hook
    # Drop everything cached so far so the preload below starts cold.
    kernel.device_caches.clear()

    # Locate the kernel purely from the recorded names, then preload it.
    located = retrieve_fn(seen["module_name"], seen["fn_name"])
    assert located == kernel
    preloaded = located.preload(seen["specialization_data"])
    assert len(hook_calls) == 1
    assert preloaded.hash == expected_hash

    # A further launch must be served from the cache: no hook invocation.
    hook_calls.clear()
    kernel[(1,)](out)
    assert len(hook_calls) == 0