drbh committed
Commit 6324f5a · 1 Parent(s): 62efba7

feat: add quick start and readme example

Files changed (2)
  1. README.md +58 -0
  2. readme_example.py +47 -0
README.md CHANGED
@@ -6,6 +6,64 @@ license: apache-2.0
 
 This is an implementation of Flash Attention 3 CUDA kernels with support for attention sinks. The attention sinks implementation was contributed to Flash Attention by the [vLLM team](https://huggingface.co/vllm-project). The [transformers team](https://huggingface.co/transformers-community) packaged the implementation and pre-built it for use with the [kernels library](https://github.com/huggingface/kernels).
 
+
+ ## Quickstart
+
+ ```bash
+ uv run https://huggingface.co/kernels-community/vllm-flash-attn3/raw/main/readme_example.py
+ ```
+
+ ```python
+ # /// script
+ # requires-python = ">=3.10"
+ # dependencies = [
+ #     "torch",
+ #     "triton",
+ #     "numpy",
+ #     "kernels",
+ # ]
+ # ///
+
+ import torch
+ from kernels import get_kernel
+
+ # Load vllm-flash-attn3 via kernels library
+ vllm_flash_attn3 = get_kernel("kernels-community/vllm-flash-attn3")
+
+ # Access Flash Attention function
+ flash_attn_func = vllm_flash_attn3.flash_attn_func
+
+ # Set device and seed for reproducibility
+ device = "cuda"
+ torch.manual_seed(42)
+ torch.cuda.manual_seed(42)
+
+ # Parameters
+ batch_size = 2
+ seqlen_q = 128  # Query sequence length
+ seqlen_k = 256  # Key sequence length
+ nheads = 8  # Number of attention heads
+ d = 64  # Head dimension
+
+ # Create input tensors (Q, K, V)
+ q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=torch.bfloat16)
+ k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=torch.bfloat16)
+ v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=torch.bfloat16)
+
+ print(f"Query shape: {q.shape}")
+ print(f"Key shape: {k.shape}")
+ print(f"Value shape: {v.shape}")
+
+ # Run Flash Attention 3
+ output, lse = flash_attn_func(q, k, v, causal=True)
+
+ print(f"\nOutput shape: {output.shape}")
+ print(f"LSE (log-sum-exp) shape: {lse.shape}")
+ print(f"\nAttention computation successful!")
+ print(f"Output tensor stats - Mean: {output.mean().item():.4f}, Std: {output.std().item():.4f}")
+ ```
+
+
 ## How to Use
 
 When loading your model with transformers, provide this repository id as the source of the attention implementation:
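
The snippet that follows that sentence falls outside this hunk. A minimal sketch of the step it describes, assuming the repository id is passed through transformers' `attn_implementation` argument (the model id below is purely illustrative, not taken from the commit):

```python
# Hedged sketch: assumes transformers can source the attention kernel from the Hub
# by passing this repository id to attn_implementation; model id is illustrative.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "openai/gpt-oss-20b"  # illustrative model that uses attention sinks

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    # Repository id of this kernel, used as the attention implementation source
    attn_implementation="kernels-community/vllm-flash-attn3",
)

inputs = tokenizer("Attention sinks keep early tokens visible.", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```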
readme_example.py ADDED
@@ -0,0 +1,47 @@
+ # /// script
+ # requires-python = ">=3.10"
+ # dependencies = [
+ #     "torch",
+ #     "triton",
+ #     "numpy",
+ #     "kernels",
+ # ]
+ # ///
+
+ import torch
+ from kernels import get_kernel
+
+ # Load vllm-flash-attn3 via kernels library
+ vllm_flash_attn3 = get_kernel("kernels-community/vllm-flash-attn3")
+
+ # Access Flash Attention function
+ flash_attn_func = vllm_flash_attn3.flash_attn_func
+
+ # Set device and seed for reproducibility
+ device = "cuda"
+ torch.manual_seed(42)
+ torch.cuda.manual_seed(42)
+
+ # Parameters
+ batch_size = 2
+ seqlen_q = 128  # Query sequence length
+ seqlen_k = 256  # Key sequence length
+ nheads = 8  # Number of attention heads
+ d = 64  # Head dimension
+
+ # Create input tensors (Q, K, V)
+ q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=torch.bfloat16)
+ k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=torch.bfloat16)
+ v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=torch.bfloat16)
+
+ print(f"Query shape: {q.shape}")
+ print(f"Key shape: {k.shape}")
+ print(f"Value shape: {v.shape}")
+
+ # Run Flash Attention 3
+ output, lse = flash_attn_func(q, k, v, causal=True)
+
+ print(f"\nOutput shape: {output.shape}")
+ print(f"LSE (log-sum-exp) shape: {lse.shape}")
+ print(f"\nAttention computation successful!")
+ print(f"Output tensor stats - Mean: {output.mean().item():.4f}, Std: {output.std().item():.4f}")