Commit e52d1ec · drbh committed
feat: mrope position id kernel and reference
Browse files:
- .gitignore +5 -0
- build.toml +19 -0
- ext-torch/get_position_ids/__init__.py +16 -0
- ext-torch/registration.h +27 -0
- ext-torch/torch_binding.cpp +11 -0
- ext-torch/torch_binding.h +6 -0
- flake.lock +97 -0
- flake.nix +21 -0
- get_position_ids/get_position_ids.cu +167 -0
- test/reference.py +91 -0
- test/test.py +101 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+.pytest_cache
+__pycache__
+.bak
+
+# result
build.toml ADDED
@@ -0,0 +1,19 @@
+[general]
+version = "0.0.1"
+
+[torch]
+name = "get_position_ids"
+src = [
+  "ext-torch/registration.h",
+  "ext-torch/torch_binding.cpp",
+  "ext-torch/torch_binding.h",
+]
+include = ["."]
+pyroot = "ext-torch"
+pyext = ["py", "json"]
+
+[kernel.get_position_ids]
+capabilities = ["7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0"]
+src = ["get_position_ids/get_position_ids.cu"]
+include = ["."]
+depends = ["torch"]
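Note: the `capabilities` list spans compute capability 7.0 (Volta) through 9.0 (Hopper), so the kernel is presumably compiled once per listed CUDA architecture by kernel-builder, and `depends = ["torch"]` pulls in the Torch headers that the .cu file includes.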
ext-torch/get_position_ids/__init__.py ADDED
@@ -0,0 +1,16 @@
+import torch
+
+try:
+    from ._ops import ops
+except ImportError as e:
+    # Fallback for local development.
+    try:
+        import _get_position_ids
+
+        ops = torch.ops._get_position_ids
+    except ImportError:
+        raise e
+
+def get_position_ids(out: torch.Tensor, input_ids: torch.Tensor, image_grid_thw: torch.Tensor) -> torch.Tensor:
+    ops.get_position_ids(out, input_ids, image_grid_thw)
+    return out
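For orientation, a minimal usage sketch of this wrapper (not part of the commit, and assuming the extension has been built and is importable). Following test/test.py, the output buffer is sized from the reference implementation, since the number of produced position ids depends on the image grids rather than just the input length:

import torch
import get_position_ids
from reference import DummyModel  # test/reference.py

input_ids = torch.tensor([10] * 5 + [151652, 151653] + [20] * 5,
                         dtype=torch.int32, device="cuda")
image_grid_thw = torch.tensor([[2, 4, 6]], dtype=torch.int32, device="cuda")

# Size the output from the reference; the op fills `out` in place.
ref = DummyModel().get_position_ids(input_ids, image_grid_thw)
out = torch.empty(ref.shape, dtype=torch.int32, device="cuda")
get_position_ids.get_position_ids(out, input_ids, image_grid_thw)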
ext-torch/registration.h ADDED
@@ -0,0 +1,27 @@
+#pragma once
+
+#include <Python.h>
+
+#define _CONCAT(A, B) A##B
+#define CONCAT(A, B) _CONCAT(A, B)
+
+#define _STRINGIFY(A) #A
+#define STRINGIFY(A) _STRINGIFY(A)
+
+// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
+// could be a macro instead of a literal token.
+#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
+
+// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
+// could be a macro instead of a literal token.
+#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
+  TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
+
+// REGISTER_EXTENSION allows the shared library to be loaded and initialized
+// via python's import statement.
+#define REGISTER_EXTENSION(NAME)                                               \
+  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                     \
+    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,                 \
+                                        STRINGIFY(NAME), nullptr, 0, nullptr}; \
+    return PyModule_Create(&module);                                           \
+  }
ext-torch/torch_binding.cpp ADDED
@@ -0,0 +1,11 @@
+#include "torch_binding.h"
+#include "registration.h"
+#include <torch/library.h>
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+  ops.def("get_position_ids(Tensor out, Tensor input_ids, Tensor "
+          "image_grid_thw) -> ()");
+  ops.impl("get_position_ids", torch::kCUDA, &get_position_ids);
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
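With TORCH_EXTENSION_NAME set by the build (presumably to `_get_position_ids`, given the fallback import in the Python wrapper above), this registers the schema `get_position_ids(Tensor out, Tensor input_ids, Tensor image_grid_thw) -> ()` and exposes the CUDA implementation as `torch.ops._get_position_ids.get_position_ids`, which is exactly what `__init__.py` calls.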
ext-torch/torch_binding.h ADDED
@@ -0,0 +1,6 @@
+#pragma once
+
+#include <torch/torch.h>
+
+void get_position_ids(torch::Tensor &out, torch::Tensor &input_ids,
+                      torch::Tensor &image_grid_thw);
flake.lock ADDED
@@ -0,0 +1,97 @@
+{
+  "nodes": {
+    "flake-compat": {
+      "locked": {
+        "lastModified": 1733328505,
+        "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
+        "type": "github"
+      },
+      "original": {
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "type": "github"
+      }
+    },
+    "flake-utils": {
+      "inputs": {
+        "systems": "systems"
+      },
+      "locked": {
+        "lastModified": 1731533236,
+        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+        "type": "github"
+      },
+      "original": {
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "type": "github"
+      }
+    },
+    "kernel-builder": {
+      "inputs": {
+        "flake-compat": "flake-compat",
+        "flake-utils": "flake-utils",
+        "nixpkgs": "nixpkgs"
+      },
+      "locked": {
+        "lastModified": 1738830746,
+        "narHash": "sha256-WwMzQXiHnkgb+4xEn3mlTOLJ9/7rInn+SJdaC/rQr3M=",
+        "ref": "refs/heads/main",
+        "rev": "21c056ac3575e78d4228e9ed7924cfbe987398d6",
+        "revCount": 73,
+        "submodules": true,
+        "type": "git",
+        "url": "git+ssh://[email protected]/huggingface/kernel-builder"
+      },
+      "original": {
+        "submodules": true,
+        "type": "git",
+        "url": "git+ssh://[email protected]/huggingface/kernel-builder"
+      }
+    },
+    "nixpkgs": {
+      "locked": {
+        "lastModified": 1738247409,
+        "narHash": "sha256-F72dKl9Na6/2N+garOm9qCXPa92GzR8eYSuDra6kbjY=",
+        "owner": "danieldk",
+        "repo": "nixpkgs",
+        "rev": "358f57074b70e3ee9e1dc118151a4f6f81fcd3bb",
+        "type": "github"
+      },
+      "original": {
+        "owner": "danieldk",
+        "ref": "cuda-12.6-for-kernel-builder",
+        "repo": "nixpkgs",
+        "type": "github"
+      }
+    },
+    "root": {
+      "inputs": {
+        "kernel-builder": "kernel-builder"
+      }
+    },
+    "systems": {
+      "locked": {
+        "lastModified": 1681028828,
+        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+        "owner": "nix-systems",
+        "repo": "default",
+        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nix-systems",
+        "repo": "default",
+        "type": "github"
+      }
+    }
+  },
+  "root": "root",
+  "version": 7
+}
flake.nix ADDED
@@ -0,0 +1,21 @@
+{
+  description = "Flake for mrope_get_position_ids kernel";
+  inputs = {
+    kernel-builder = {
+      url = "git+ssh://[email protected]/huggingface/kernel-builder";
+      type = "git";
+      submodules = true;
+    };
+  };
+  outputs =
+    {
+      self,
+      kernel-builder,
+    }:
+    kernel-builder.lib.genFlakeOutputs ./.;
+
+  nixConfig = {
+    extra-substituters = [ "https://kernel-builder.cachix.org" ];
+    extra-trusted-public-keys = [ "kernel-builder.cachix.org-1:JCt71vSCqW9tnmOsUigxf7tVLztjYxQ198FI/j8LrFQ=" ];
+  };
+}
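Presumably the intended workflow is the standard flake one: `nix build` (or `nix develop` for a development shell), with kernel-builder's `genFlakeOutputs` generating the build outputs for this directory. The `nixConfig` substituter entries let Nix pull prebuilt artifacts from the kernel-builder Cachix cache instead of compiling the toolchain locally.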
get_position_ids/get_position_ids.cu ADDED
@@ -0,0 +1,167 @@
+#include <cuda_runtime.h>
+#include <torch/torch.h>
+#include <vector>
+#include <stdio.h>
+
+#define SPATIAL_MERGE_SIZE 2
+#define MAX_THREADS_PER_BLOCK 256
+
+// Kernel: each block processes one vision segment.
+// For a given segment, the kernel computes image positions by "unraveling" a 1D index
+// into 3D coordinates (t_idx, h_idx, w_idx) and then adds a per-segment offset.
+__global__ void create_image_positions_kernel(
+    const int *image_grid_thw,                // shape: [num_segments * 3]
+    const int *segment_offsets,               // shape: [num_segments]
+    const int *vision_segment_lengths_cumsum, // shape: [num_segments]
+    int *image_positions)                     // output: shape [total_image_positions, 3]
+{
+    int segment_idx = blockIdx.x;
+
+    // Load grid dims for this segment.
+    int t = image_grid_thw[segment_idx * 3];
+    int h = image_grid_thw[segment_idx * 3 + 1] / SPATIAL_MERGE_SIZE;
+    int w = image_grid_thw[segment_idx * 3 + 2] / SPATIAL_MERGE_SIZE;
+    int total_length = t * h * w;
+
+    // Get the starting output position for this segment.
+    int pos_offset = segment_offsets[segment_idx];
+    // The per-segment offset to add to each coordinate.
+    int offset_add = vision_segment_lengths_cumsum[segment_idx];
+
+    // Process all positions in this segment using a block-stride loop.
+    for (int pos_idx = threadIdx.x; pos_idx < total_length; pos_idx += blockDim.x)
+    {
+        // Compute the "unraveled" coordinates.
+        int t_idx = pos_idx / (h * w);
+        int h_idx = (pos_idx / w) % h;
+        int w_idx = pos_idx % w;
+        // Write out the 3 coordinates (each image token gets 3 ints).
+        int out_index = (pos_offset + pos_idx) * 3;
+        image_positions[out_index] = t_idx + offset_add;
+        image_positions[out_index + 1] = h_idx + offset_add;
+        image_positions[out_index + 2] = w_idx + offset_add;
+    }
+}
+
+// This function computes text and image position ids then interleaves them as:
+// [text segment 0, image segment 0, text segment 1, image segment 1, ...].
+// If extra text tokens exist after the last vision segment, they are appended at the end.
+void get_position_ids(
+    torch::Tensor &out,            // Final output tensor
+    torch::Tensor &input_ids,      // tensor holding token ids
+    torch::Tensor &image_grid_thw) // tensor of shape [num_segments, 3]: each row is [t, h, w]
+{
+    TORCH_CHECK(input_ids.device().is_cuda(), "input_ids must be a CUDA tensor");
+    TORCH_CHECK(image_grid_thw.device().is_cuda(), "image_grid_thw must be a CUDA tensor");
+    TORCH_CHECK(out.device().is_cuda(), "out must be a CUDA tensor");
+
+    const int input_len = input_ids.size(0);
+    auto options_int = torch::TensorOptions().device(input_ids.device()).dtype(torch::kInt);
+    auto options_long = torch::TensorOptions().device(input_ids.device()).dtype(torch::kLong);
+
+    const int VISION_START_TOKEN_ID = 151652;
+    const int VISION_END_TOKEN_ID = 151653;
+
+    // Find vision segments.
+    auto vision_starts_mask = input_ids == VISION_START_TOKEN_ID;
+    auto vision_ends_mask = input_ids == VISION_END_TOKEN_ID;
+
+    auto starts = torch::where(vision_starts_mask)[0].to(torch::kInt);
+    auto ends = torch::where(vision_ends_mask)[0].to(torch::kInt);
+
+    int actual_segments = starts.size(0);
+    auto prev_end = torch::cat({torch::zeros({1}, options_long), ends.slice(0, 0, actual_segments - 1)});
+
+    // Compute text lengths between vision tokens.
+    auto text_lengths_between_vision = starts - prev_end + 1;
+    auto zeros = torch::zeros({1}, options_long);
+    auto widths = image_grid_thw.slice(0, 0, actual_segments).select(1, 2);
+    auto divided_widths = widths / SPATIAL_MERGE_SIZE;
+    auto vision_widths_max = torch::cat({zeros, divided_widths.slice(0, 0, actual_segments - 1)});
+    // The vision segment length is the sum of text tokens plus the (merged) image width.
+    auto vision_segment_lengths = text_lengths_between_vision + vision_widths_max;
+    auto vision_segment_lengths_cumsum = vision_segment_lengths.cumsum(0);
+    auto text_segment_lengths = vision_segment_lengths_cumsum - text_lengths_between_vision;
+
+    // Compute per-segment starting indices for image positions.
+    std::vector<int> segment_offsets_vec(actual_segments);
+    int total_image_positions = 0;
+    // (Using a CPU copy because the number of segments is small.)
+    auto image_grid_cpu = image_grid_thw.to(torch::kCPU);
+    auto image_grid_accessor = image_grid_cpu.accessor<int, 2>(); // shape: [actual_segments, 3]
+    for (int i = 0; i < actual_segments; i++)
+    {
+        int t = image_grid_accessor[i][0];
+        int h = image_grid_accessor[i][1] / SPATIAL_MERGE_SIZE;
+        int w = image_grid_accessor[i][2] / SPATIAL_MERGE_SIZE;
+        segment_offsets_vec[i] = total_image_positions;
+        total_image_positions += t * h * w;
+    }
+
+    // IMPORTANT: Create the segment_offsets tensor directly so that its memory is on the device.
+    auto segment_offsets_tensor = torch::tensor(segment_offsets_vec, options_int);
+
+    // Make sure vision_segment_lengths_cumsum is int and on the correct device.
+    auto vision_segment_lengths_cumsum_int = vision_segment_lengths_cumsum.to(torch::kInt);
+
+    // Allocate one contiguous output tensor for all image positions.
+    // Each image token produces 3 ints.
+    auto image_positions_tensor = torch::empty({total_image_positions, 3}, options_int);
+
+    // Launch one block per vision segment.
+    int threads = MAX_THREADS_PER_BLOCK;
+    int blocks = actual_segments;
+    create_image_positions_kernel<<<blocks, threads>>>(
+        image_grid_thw.data_ptr<int>(),
+        segment_offsets_tensor.data_ptr<int>(),
+        vision_segment_lengths_cumsum_int.data_ptr<int>(),
+        image_positions_tensor.data_ptr<int>());
+    cudaDeviceSynchronize();
+    cudaError_t error = cudaGetLastError();
+    TORCH_CHECK(error == cudaSuccess, "CUDA error: ", cudaGetErrorString(error));
+
+    // Process text segments on host.
+    // Each text segment is computed as a tensor of shape [3, seq_len] with all entries equal to text_segment_lengths[i].
+    std::vector<torch::Tensor> text_positions_list;
+    for (int i = 0; i < actual_segments; i++)
+    {
+        int seq_len = text_lengths_between_vision[i].item<int>();
+        auto text_range = torch::zeros({3, seq_len}, options_long) + text_segment_lengths[i];
+        text_positions_list.push_back(text_range);
+    }
+
+    // Interleave text and image segments.
+    std::vector<torch::Tensor> full_positions_list;
+    // For each vision segment, first add its text positions then add its image positions.
+    for (int i = 0; i < actual_segments; i++)
+    {
+        // Append text segment for vision segment i.
+        full_positions_list.push_back(text_positions_list[i]);
+        // Determine the slice boundaries for this vision segment's image positions.
+        int start = segment_offsets_vec[i];
+        int seg_length = 0;
+        if (i == actual_segments - 1)
+            seg_length = total_image_positions - segment_offsets_vec[i];
+        else
+            seg_length = segment_offsets_vec[i + 1] - segment_offsets_vec[i];
+        // Slice the image_positions_tensor for this segment.
+        // (Kernel output is [total_image_positions, 3]; we want a tensor of shape [3, seg_length] as in the Python reference.)
+        torch::Tensor image_segment = image_positions_tensor.slice(0, start, start + seg_length).t();
+        full_positions_list.push_back(image_segment);
+    }
+    // If there are extra text tokens after the last vision segment, add them.
+    int full_text_len = input_len - ends[actual_segments - 1].item<int>();
+    if (full_text_len > 0)
+    {
+        int max_s = full_positions_list.back().max().item<int>() + 1;
+        auto extra_text = torch::arange(full_text_len, options_long).view({1, -1}).expand({3, -1}) + max_s;
+        full_positions_list.push_back(extra_text);
+    }
+
+    // Concatenate along dimension 1 (the "position" dimension), then transpose so that the final tensor is [total_tokens, 3].
+    auto full_positions_concatenated = torch::cat(full_positions_list, 1);
+    auto full_positions_concatenated_transposed = full_positions_concatenated.t();
+
+    // Write final result to output tensor.
+    out.copy_(full_positions_concatenated_transposed);
}
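As a sanity check on the index math, here is the kernel's unravel logic for a single segment rewritten in plain Python (a hedged sketch, not part of the commit; `h` and `w` are assumed to already be divided by SPATIAL_MERGE_SIZE):

def unravel_positions(t, h, w, offset_add):
    # Mirrors create_image_positions_kernel for one segment: each flat
    # index is unraveled into (t_idx, h_idx, w_idx), then shifted by the
    # segment's cumulative offset.
    positions = []
    for pos_idx in range(t * h * w):
        t_idx = pos_idx // (h * w)
        h_idx = (pos_idx // w) % h
        w_idx = pos_idx % w
        positions.append((t_idx + offset_add, h_idx + offset_add, w_idx + offset_add))
    return positions

# e.g. a [2, 4, 6] grid with spatial_merge_size = 2 becomes t=2, h=2, w=3,
# producing 12 triples, each shifted by the segment's offset.
assert len(unravel_positions(2, 2, 3, offset_add=6)) == 12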
test/reference.py ADDED
@@ -0,0 +1,91 @@
+import torch
+from typing import Optional
+
+class DummyModel:
+    spatial_merge_size = 2
+    vision_start_token_id = 151652
+    vision_end_token_id = 151653
+
+    # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391
+    # modified to first find segments then initialize position ids for each segment
+    # Steps:
+    # locate all vision and text segments
+    # calculate `vision_segment_lengths` for each vision segment to be used as offset
+    # calculate `text_segment_lengths` for each text segment to be used as offset
+    # create position ids for each vision segment based on the image grid
+    # create position ids for each text segment
+    # combine all the position ids
+    # the final segment is the difference between the last vision segment and the end of the input
+    # combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
+    def get_position_ids(
+        self,
+        input_ids: torch.Tensor,
+        image_grid_thw: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if image_grid_thw is None:
+            return (
+                torch.arange(input_ids.shape[0], device=input_ids.device)
+                .unsqueeze(1)
+                .repeat(1, 3)
+            )
+
+        spatial_merge_size = self.spatial_merge_size
+        vision_start_token_id = self.vision_start_token_id
+        vision_end_token_id = self.vision_end_token_id
+        device = input_ids.device
+        dtype = input_ids.dtype
+        input_ids_len = input_ids.shape[0]
+
+        vision_starts = torch.where(input_ids == vision_start_token_id)[0]
+        vision_ends = torch.where(input_ids == vision_end_token_id)[0]
+        vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
+        prev_vision_end = torch.cat(
+            [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
+        )
+        text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
+        vision_widths_max = torch.cat(
+            [
+                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
+                image_grid_thw[:-1, 2] // spatial_merge_size,
+            ]
+        )
+        vision_segment_lengths = vision_widths_max + text_lengths_between_vision
+        vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
+        text_segment_lengths = vision_segment_lengths - text_lengths_between_vision
+
+        # create position ids for each vision segment based on the image grid
+        llm_pos_ids_list = []
+        for i, _ in enumerate(vision_segments):
+            t, h, w = (
+                image_grid_thw[i][0],
+                image_grid_thw[i][1] // spatial_merge_size,
+                image_grid_thw[i][2] // spatial_merge_size,
+            )
+            t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
+            h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
+            w_indices = torch.arange(w, device=device).repeat(t * h)
+            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)
+
+            # offset by the position of the last vision segment
+            im = image_position_ids + vision_segment_lengths[i]
+            llm_pos_ids_list.append(im)
+
+        # create position ids for each text segment
+        text_ranges = [
+            torch.zeros(3, seq_len, device=device) + text_segment_lengths[i]
+            for i, seq_len in enumerate(text_lengths_between_vision)
+        ]
+
+        full_llm_pos_ids_list = [
+            item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
+        ]
+        max_s = full_llm_pos_ids_list[-1].max() + 1
+        final_text_len = input_ids_len - vision_ends[-1]
+        if final_text_len > 0:
+            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
+            full_llm_pos_ids_list.append(m + max_s)
+
+        position_ids = (
+            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
+        )
+        return position_ids
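Tracing this reference on the "one_segment" test case below (5 text tokens, the vision start/end pair, 5 trailing text tokens, grid [[2, 4, 6]]): vision_starts = [5] and vision_ends = [6], so text_lengths_between_vision = [6], vision_segment_lengths.cumsum = [6], and text_segment_lengths = [0]. The leading text segment contributes 6 constant (0, 0, 0) rows, the image segment contributes t*h*w = 2*2*3 = 12 rows offset by 6, and the trailing text contributes 6 rows counting up from max + 1 = 9, giving a (24, 3) result. Note that with no image placeholder tokens between the start/end markers, the output row count (24) differs from the input length (12) in this synthetic case.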
test/test.py ADDED
@@ -0,0 +1,101 @@
+import time
+import torch
+import pytest
+import get_position_ids  # noqa: E402
+from reference import DummyModel
+
+# Each configuration includes:
+# - name: A label for the test case.
+# - input_ids: A list of token IDs (with vision start (151652) and vision end (151653) tokens embedded).
+# - grid: A list of [t, h, w] values (one per vision segment).
+#
+# The cases below include:
+# 1. one_segment: a single vision segment.
+# 2. two_segments: two vision segments with extra text tokens afterward.
+# 3. three_segments: three vision segments.
+VISION_CONFIGS = [
+    {
+        "name": "one_segment",
+        "input_ids": (
+            [10] * 5 +          # 5 text tokens before vision segment
+            [151652, 151653] +  # vision tokens for segment 1
+            [20] * 5            # 5 extra text tokens after vision segment
+        ),
+        "grid": [[2, 4, 6]]  # one vision segment grid
+    },
+    {
+        "name": "two_segments",
+        "input_ids": (
+            [100] * 5 +         # 5 text tokens for segment 1
+            [151652, 151653] +  # vision tokens for segment 1
+            [101] * 5 +         # 5 text tokens for segment 2
+            [151652, 151653] +  # vision tokens for segment 2
+            [102] * 5           # 5 extra text tokens after last vision segment
+        ),
+        "grid": [
+            [2, 4, 6],  # vision segment 1 grid
+            [3, 4, 6]   # vision segment 2 grid
+        ],
+    },
+    {
+        "name": "three_segments",
+        "input_ids": (
+            [11] * 5 +          # Segment 1: 5 text tokens
+            [151652, 151653] +  # vision tokens for segment 1
+            [12] * 6 +          # Segment 2: 6 text tokens
+            [151652, 151653] +  # vision tokens for segment 2
+            [13] * 7 +          # Segment 3: 7 text tokens
+            [151652, 151653] +  # vision tokens for segment 3
+            [14] * 8            # 8 extra text tokens after the last vision segment
+        ),
+        "grid": [
+            [2, 4, 6],  # vision segment 1 grid
+            [3, 6, 6],  # vision segment 2 grid
+            [4, 4, 8]   # vision segment 3 grid
+        ],
+    },
+]
+
+CUDA_DEVICES = ["cuda"]  # List of CUDA devices; you can add more if needed.
+SEEDS = [42]             # Seeds for reproducibility.
+DTYPES = [torch.int32]   # In our test the tokens and grid are created with int32.
+
+
+@pytest.mark.parametrize("vision_config",
+                         VISION_CONFIGS,
+                         ids=[cfg["name"] for cfg in VISION_CONFIGS])
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_get_position_ids(vision_config, seed, device):
+    torch.manual_seed(seed)
+    input_ids = torch.tensor(vision_config["input_ids"], dtype=torch.int32, device=device)
+    image_grid_thw = torch.tensor(vision_config["grid"], dtype=torch.int32, device=device)
+
+    # Create a DummyModel instance from the reference implementation.
+    dummy_model = DummyModel()
+
+    # reference implementation
+    torch.cuda.synchronize()
+    start_ref = time.perf_counter()
+    pos_ids_ref = dummy_model.get_position_ids(input_ids, image_grid_thw)
+    torch.cuda.synchronize()
+    end_ref = time.perf_counter()
+    ref_time = (end_ref - start_ref) * 1000  # ms
+    print(f"\nVision config {vision_config['name']} - Reference time: {ref_time:.2f} ms")
+    # Convert reference output to int32 for comparison (since it's returned as a float tensor).
+    pos_ids_ref = pos_ids_ref.to(dtype=torch.int32)
+
+    # kernel implementation
+    torch.cuda.synchronize()
+    start_ext = time.perf_counter()
+    out = torch.empty(pos_ids_ref.shape, dtype=torch.int32, device=device)
+    get_position_ids.get_position_ids(out, input_ids, image_grid_thw)
+    torch.cuda.synchronize()
+    end_ext = time.perf_counter()
+    ext_time = (end_ext - start_ext) * 1000  # ms
+    print(f"Vision config {vision_config['name']} - Extension time: {ext_time:.2f} ms\n")
+    ext_out = out.clone()
+
+    # verify the results
+    torch.testing.assert_close(ext_out.cpu(), pos_ids_ref.cpu())
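Once the extension is importable (e.g. after a kernel-builder build places it on the import path), these tests would typically be run with `python -m pytest test -v`; the exact invocation depends on how the built extension is exposed. The printed reference/extension timings are informational only; correctness is what `torch.testing.assert_close` verifies.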