drbh committed
Commit 6324f5a · 1 Parent(s): 62efba7

feat: add quick start and readme example

- README.md +58 -0
- readme_example.py +47 -0
README.md
CHANGED
@@ -6,6 +6,64 @@ license: apache-2.0
This is an implementation of Flash Attention 3 CUDA kernels with support for attention sinks. The attention sinks implementation was contributed to Flash Attention by the [vLLM team](https://huggingface.co/vllm-project). The [transformers team](https://huggingface.co/transformers-community) packaged the implementation and pre-built it for use with the [kernels library](https://github.com/huggingface/kernels).

## Quickstart

```bash
uv run https://huggingface.co/kernels-community/vllm-flash-attn3/raw/main/readme_example.py
```

This runs `readme_example.py` (the file added in this commit, reproduced below); the `# /// script` header lets `uv` resolve the listed dependencies before executing it.

```python
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch",
#     "triton",
#     "numpy",
#     "kernels",
# ]
# ///

import torch
from kernels import get_kernel

# Load vllm-flash-attn3 via the kernels library
vllm_flash_attn3 = get_kernel("kernels-community/vllm-flash-attn3")

# Access the Flash Attention function
flash_attn_func = vllm_flash_attn3.flash_attn_func

# Set device and seeds for reproducibility
device = "cuda"
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Parameters
batch_size = 2
seqlen_q = 128  # Query sequence length
seqlen_k = 256  # Key sequence length
nheads = 8      # Number of attention heads
d = 64          # Head dimension

# Create input tensors (Q, K, V)
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=torch.bfloat16)
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=torch.bfloat16)
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=torch.bfloat16)

print(f"Query shape: {q.shape}")
print(f"Key shape: {k.shape}")
print(f"Value shape: {v.shape}")

# Run Flash Attention 3
output, lse = flash_attn_func(q, k, v, causal=True)

print(f"\nOutput shape: {output.shape}")
print(f"LSE (log-sum-exp) shape: {lse.shape}")
print("\nAttention computation successful!")
print(f"Output tensor stats - Mean: {output.mean().item():.4f}, Std: {output.std().item():.4f}")
```
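
As a quick sanity check (a sketch, not part of this commit's example), the kernel output can be compared against a plain PyTorch `scaled_dot_product_attention` reference. One caveat: Flash Attention aligns the causal mask to the bottom-right corner when `seqlen_q != seqlen_k`, so the reference needs an explicit mask rather than `is_causal=True`:

```python
# Sketch only: compare the kernel output with a PyTorch SDPA reference.
import torch.nn.functional as F

# Flash Attention's causal mask is bottom-right aligned when seqlen_q != seqlen_k,
# so build the equivalent boolean mask explicitly (True = attend).
i = torch.arange(seqlen_q, device=device).unsqueeze(1)
j = torch.arange(seqlen_k, device=device).unsqueeze(0)
causal_mask = j <= i + (seqlen_k - seqlen_q)

# SDPA expects (batch, nheads, seqlen, head_dim), so transpose in and out.
ref = F.scaled_dot_product_attention(
    q.transpose(1, 2),
    k.transpose(1, 2),
    v.transpose(1, 2),
    attn_mask=causal_mask,
).transpose(1, 2)

print(f"Max abs diff vs. SDPA reference: {(output - ref).abs().max().item():.4f}")
```
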
## How to Use
When loading your model with transformers, provide this repository id as the source of the attention implementation:
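
A rough sketch of what that looks like, assuming a recent transformers release with the `kernels` package installed so that `attn_implementation` accepts a Hub kernel repo id; the model id `openai/gpt-oss-20b` is only an illustration, not something this diff prescribes:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "openai/gpt-oss-20b"  # illustrative; substitute your own model

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    # Pull the pre-built Flash Attention 3 kernel from this repository
    attn_implementation="kernels-community/vllm-flash-attn3",
).to("cuda")

inputs = tokenizer("Attention sinks help because", return_tensors="pt").to("cuda")
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```
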
readme_example.py
ADDED
@@ -0,0 +1,47 @@
```python
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch",
#     "triton",
#     "numpy",
#     "kernels",
# ]
# ///

import torch
from kernels import get_kernel

# Load vllm-flash-attn3 via the kernels library
vllm_flash_attn3 = get_kernel("kernels-community/vllm-flash-attn3")

# Access the Flash Attention function
flash_attn_func = vllm_flash_attn3.flash_attn_func

# Set device and seeds for reproducibility
device = "cuda"
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Parameters
batch_size = 2
seqlen_q = 128  # Query sequence length
seqlen_k = 256  # Key sequence length
nheads = 8      # Number of attention heads
d = 64          # Head dimension

# Create input tensors (Q, K, V)
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=torch.bfloat16)
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=torch.bfloat16)
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=torch.bfloat16)

print(f"Query shape: {q.shape}")
print(f"Key shape: {k.shape}")
print(f"Value shape: {v.shape}")

# Run Flash Attention 3
output, lse = flash_attn_func(q, k, v, causal=True)

print(f"\nOutput shape: {output.shape}")
print(f"LSE (log-sum-exp) shape: {lse.shape}")
print("\nAttention computation successful!")
print(f"Output tensor stats - Mean: {output.mean().item():.4f}, Std: {output.std().item():.4f}")
```
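
If you'd rather not use `uv`, an equivalent manual setup (assuming a CUDA environment with a Hopper-class GPU, which Flash Attention 3 targets) is:

```bash
# Install the same dependencies listed in the script's inline metadata, then run it.
pip install torch triton numpy kernels
python readme_example.py
```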