drbh committed
Commit 22b535b · 1 Parent(s): 1d2e955

feat: add quick start and readme example

Files changed (2):
  1. README.md +53 -1
  2. readme_example.py +42 -0
README.md CHANGED
@@ -7,4 +7,56 @@ triton-kernels is a set of kernels that enable fast moe on different architectur
 
  Original code here https://github.com/triton-lang/triton/tree/main/python/triton_kernels
 
- The current version is the following commit 7d0efaa7231661299284a603512fce4fa255e62c
+ The current version is the following commit 7d0efaa7231661299284a603512fce4fa255e62c
+
+
+ ## Quickstart
+
+ ```bash
+ uv run https://huggingface.co/kernels-community/triton_kernels/raw/main/readme_example.py
+ ```
+
+ ```python
+ # /// script
+ # requires-python = ">=3.10"
+ # dependencies = [
+ #     "torch",
+ #     "triton",
+ #     "numpy",
+ #     "kernels",
+ # ]
+ # ///
+
+ import torch
+ import sys
+ from kernels import get_kernel
+
+ torch.manual_seed(42)
+ torch.cuda.manual_seed(42)
+
+ # Load triton_kernels module via kernels library
+ triton_kernels = get_kernel("kernels-community/triton_kernels")
+
+ # Access modules directly from the loaded kernel
+ swiglu = triton_kernels.swiglu
+ routing = triton_kernels.routing
+
+ # Setup
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # SwiGLU example
+ x = torch.randn(512, 1024, device=device, dtype=torch.bfloat16)
+ y = swiglu.swiglu_torch(x, 0.5, swiglu.PrecisionConfig(limit=1.0))
+ print(f"SwiGLU: {x.shape} -> {y.shape}")
+
+ # Routing example
+ logits = torch.randn(128, 8, device=device, dtype=torch.float16)
+ routing_data, gather_idx, scatter_idx = routing.routing_torch(logits, n_expts_act=2)
+ print(f"Routing: {routing_data.expt_hist.sum()} tokens routed")
+
+ # MoE integrated
+ n_tokens = routing_data.expt_hist.sum().item()
+ x_moe = torch.randn(n_tokens, 512, device=device, dtype=torch.bfloat16)
+ y_moe = swiglu.swiglu_torch(x_moe, 0.5, swiglu.PrecisionConfig(limit=1.0))
+ print(f"MoE SwiGLU: {x_moe.shape} -> {y_moe.shape}")
+ ```
readme_example.py ADDED
@@ -0,0 +1,42 @@
+ # /// script
+ # requires-python = ">=3.10"
+ # dependencies = [
+ #     "torch",
+ #     "triton",
+ #     "numpy",
+ #     "kernels",
+ # ]
+ # ///
+
+ import torch
+ import sys
+ from kernels import get_kernel
+
+ torch.manual_seed(42)
+ torch.cuda.manual_seed(42)
+
+ # Load triton_kernels module via kernels library
+ triton_kernels = get_kernel("kernels-community/triton_kernels")
+
+ # Access modules directly from the loaded kernel
+ swiglu = triton_kernels.swiglu
+ routing = triton_kernels.routing
+
+ # Setup
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # SwiGLU example
+ x = torch.randn(512, 1024, device=device, dtype=torch.bfloat16)
+ y = swiglu.swiglu_torch(x, 0.5, swiglu.PrecisionConfig(limit=1.0))
+ print(f"SwiGLU: {x.shape} -> {y.shape}")
+
+ # Routing example
+ logits = torch.randn(128, 8, device=device, dtype=torch.float16)
+ routing_data, gather_idx, scatter_idx = routing.routing_torch(logits, n_expts_act=2)
+ print(f"Routing: {routing_data.expt_hist.sum()} tokens routed")
+
+ # MoE integrated
+ n_tokens = routing_data.expt_hist.sum().item()
+ x_moe = torch.randn(n_tokens, 512, device=device, dtype=torch.bfloat16)
+ y_moe = swiglu.swiglu_torch(x_moe, 0.5, swiglu.PrecisionConfig(limit=1.0))
+ print(f"MoE SwiGLU: {x_moe.shape} -> {y_moe.shape}")