drbh committed
Commit e52d1ec · 0 Parent(s)

feat: mrope position id kernel and reference

.gitignore ADDED
@@ -0,0 +1,5 @@
+.pytest_cache
+__pycache__
+.bak
+
+# result
build.toml ADDED
@@ -0,0 +1,19 @@
+[general]
+version = "0.0.1"
+
+[torch]
+name = "get_position_ids"
+src = [
+  "ext-torch/registration.h",
+  "ext-torch/torch_binding.cpp",
+  "ext-torch/torch_binding.h",
+]
+include = ["."]
+pyroot = "ext-torch"
+pyext = ["py", "json"]
+
+[kernel.get_position_ids]
+capabilities = ["7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0"]
+src = ["get_position_ids/get_position_ids.cu"]
+include = ["."]
+depends = ["torch"]
ext-torch/get_position_ids/__init__.py ADDED
@@ -0,0 +1,16 @@
+import torch
+
+try:
+    from ._ops import ops
+except ImportError as e:
+    # Fallback for local development.
+    try:
+        import _get_position_ids
+
+        ops = torch.ops._get_position_ids
+    except ImportError:
+        raise e
+
+def get_position_ids(out: torch.Tensor, input_ids: torch.Tensor, image_grid_thw: torch.Tensor) -> torch.Tensor:
+    ops.get_position_ids(out, input_ids, image_grid_thw)
+    return out
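
A minimal call sketch for the wrapper above, assuming the extension is built and a CUDA device is available. `out` must be preallocated with the expected `[num_positions, 3]` shape; test/test.py derives that shape from the Python reference, and for the one_segment input used here it comes out to 24 rows (see the worked example after test/reference.py).

import torch
import get_position_ids

input_ids = torch.tensor([10] * 5 + [151652, 151653] + [20] * 5,
                         dtype=torch.int32, device="cuda")
image_grid_thw = torch.tensor([[2, 4, 6]], dtype=torch.int32, device="cuda")

# 24 position rows for this input: 6 text + 2*2*3 image + 6 trailing text.
out = torch.empty((24, 3), dtype=torch.int32, device="cuda")
pos_ids = get_position_ids.get_position_ids(out, input_ids, image_grid_thw)
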
ext-torch/registration.h ADDED
@@ -0,0 +1,27 @@
+#pragma once
+
+#include <Python.h>
+
+#define _CONCAT(A, B) A##B
+#define CONCAT(A, B) _CONCAT(A, B)
+
+#define _STRINGIFY(A) #A
+#define STRINGIFY(A) _STRINGIFY(A)
+
+// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
+// could be a macro instead of a literal token.
+#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
+
+// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
+// could be a macro instead of a literal token.
+#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
+  TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
+
+// REGISTER_EXTENSION allows the shared library to be loaded and initialized
+// via python's import statement.
+#define REGISTER_EXTENSION(NAME)                                               \
+  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                     \
+    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,                 \
+                                        STRINGIFY(NAME), nullptr, 0, nullptr}; \
+    return PyModule_Create(&module);                                           \
+  }
ext-torch/torch_binding.cpp ADDED
@@ -0,0 +1,11 @@
+#include "torch_binding.h"
+#include "registration.h"
+#include <torch/library.h>
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+  ops.def("get_position_ids(Tensor out, Tensor input_ids, Tensor "
+          "image_grid_thw) -> ()");
+  ops.impl("get_position_ids", torch::kCUDA, &get_position_ids);
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
ext-torch/torch_binding.h ADDED
@@ -0,0 +1,6 @@
+#pragma once
+
+#include <torch/torch.h>
+
+void get_position_ids(torch::Tensor &out, torch::Tensor &input_ids,
+                      torch::Tensor &image_grid_thw);
flake.lock ADDED
@@ -0,0 +1,97 @@
+{
+  "nodes": {
+    "flake-compat": {
+      "locked": {
+        "lastModified": 1733328505,
+        "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
+        "type": "github"
+      },
+      "original": {
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "type": "github"
+      }
+    },
+    "flake-utils": {
+      "inputs": {
+        "systems": "systems"
+      },
+      "locked": {
+        "lastModified": 1731533236,
+        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+        "type": "github"
+      },
+      "original": {
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "type": "github"
+      }
+    },
+    "kernel-builder": {
+      "inputs": {
+        "flake-compat": "flake-compat",
+        "flake-utils": "flake-utils",
+        "nixpkgs": "nixpkgs"
+      },
+      "locked": {
+        "lastModified": 1738830746,
+        "narHash": "sha256-WwMzQXiHnkgb+4xEn3mlTOLJ9/7rInn+SJdaC/rQr3M=",
+        "ref": "refs/heads/main",
+        "rev": "21c056ac3575e78d4228e9ed7924cfbe987398d6",
+        "revCount": 73,
+        "submodules": true,
+        "type": "git",
+        "url": "git+ssh://[email protected]/huggingface/kernel-builder"
+      },
+      "original": {
+        "submodules": true,
+        "type": "git",
+        "url": "git+ssh://[email protected]/huggingface/kernel-builder"
+      }
+    },
+    "nixpkgs": {
+      "locked": {
+        "lastModified": 1738247409,
+        "narHash": "sha256-F72dKl9Na6/2N+garOm9qCXPa92GzR8eYSuDra6kbjY=",
+        "owner": "danieldk",
+        "repo": "nixpkgs",
+        "rev": "358f57074b70e3ee9e1dc118151a4f6f81fcd3bb",
+        "type": "github"
+      },
+      "original": {
+        "owner": "danieldk",
+        "ref": "cuda-12.6-for-kernel-builder",
+        "repo": "nixpkgs",
+        "type": "github"
+      }
+    },
+    "root": {
+      "inputs": {
+        "kernel-builder": "kernel-builder"
+      }
+    },
+    "systems": {
+      "locked": {
+        "lastModified": 1681028828,
+        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+        "owner": "nix-systems",
+        "repo": "default",
+        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nix-systems",
+        "repo": "default",
+        "type": "github"
+      }
+    }
+  },
+  "root": "root",
+  "version": 7
+}
flake.nix ADDED
@@ -0,0 +1,21 @@
+{
+  description = "Flake for mrope_get_position_ids kernel";
+  inputs = {
+    kernel-builder = {
+      url = "git+ssh://[email protected]/huggingface/kernel-builder";
+      type = "git";
+      submodules = true;
+    };
+  };
+  outputs =
+    {
+      self,
+      kernel-builder,
+    }:
+    kernel-builder.lib.genFlakeOutputs ./.;
+
+  nixConfig = {
+    extra-substituters = [ "https://kernel-builder.cachix.org" ];
+    extra-trusted-public-keys = [ "kernel-builder.cachix.org-1:JCt71vSCqW9tnmOsUigxf7tVLztjYxQ198FI/j8LrFQ=" ];
+  };
+}
get_position_ids/get_position_ids.cu ADDED
@@ -0,0 +1,167 @@
+#include <cuda_runtime.h>
+#include <torch/torch.h>
+#include <vector>
+#include <stdio.h>
+
+#define SPATIAL_MERGE_SIZE 2
+#define MAX_THREADS_PER_BLOCK 256
+
+// Kernel: each block processes one vision segment.
+// For a given segment, the kernel computes image positions by "unraveling" a 1D index
+// into 3D coordinates (t_idx, h_idx, w_idx) and then adds a per-segment offset.
+__global__ void create_image_positions_kernel(
+    const int *image_grid_thw,                // shape: [num_segments * 3]
+    const int *segment_offsets,               // shape: [num_segments]
+    const int *vision_segment_lengths_cumsum, // shape: [num_segments]
+    int *image_positions)                     // output: shape [total_image_positions, 3]
+{
+    int segment_idx = blockIdx.x;
+
+    // Load grid dims for this segment.
+    int t = image_grid_thw[segment_idx * 3];
+    int h = image_grid_thw[segment_idx * 3 + 1] / SPATIAL_MERGE_SIZE;
+    int w = image_grid_thw[segment_idx * 3 + 2] / SPATIAL_MERGE_SIZE;
+    int total_length = t * h * w;
+
+    // Get the starting output position for this segment.
+    int pos_offset = segment_offsets[segment_idx];
+    // The per-segment offset to add to each coordinate.
+    int offset_add = vision_segment_lengths_cumsum[segment_idx];
+
+    // Process all positions in this segment using a grid-stride loop.
+    for (int pos_idx = threadIdx.x; pos_idx < total_length; pos_idx += blockDim.x)
+    {
+        // Compute the "unraveled" coordinates.
+        int t_idx = pos_idx / (h * w);
+        int h_idx = (pos_idx / w) % h;
+        int w_idx = pos_idx % w;
+        // Write out the 3 coordinates (each image token gets 3 ints).
+        int out_index = (pos_offset + pos_idx) * 3;
+        image_positions[out_index] = t_idx + offset_add;
+        image_positions[out_index + 1] = h_idx + offset_add;
+        image_positions[out_index + 2] = w_idx + offset_add;
+    }
+}
+
+// This function computes text and image position ids then interleaves them as:
+// [text segment 0, image segment 0, text segment 1, image segment 1, ...].
+// If extra text tokens exist after the last vision segment, they are appended at the end.
+void get_position_ids(
+    torch::Tensor &out,            // Final output tensor
+    torch::Tensor &input_ids,      // tensor holding token ids
+    torch::Tensor &image_grid_thw) // tensor of shape [num_segments, 3]: each row is [t, h, w]
+{
+    TORCH_CHECK(input_ids.device().is_cuda(), "input_ids must be a CUDA tensor");
+    TORCH_CHECK(image_grid_thw.device().is_cuda(), "image_grid_thw must be a CUDA tensor");
+    TORCH_CHECK(out.device().is_cuda(), "out must be a CUDA tensor");
+
+    const int input_len = input_ids.size(0);
+    auto options_int = torch::TensorOptions().device(input_ids.device()).dtype(torch::kInt);
+    auto options_long = torch::TensorOptions().device(input_ids.device()).dtype(torch::kLong);
+
+    const int VISION_START_TOKEN_ID = 151652;
+    const int VISION_END_TOKEN_ID = 151653;
+
+    // Find vision segments.
+    auto vision_starts_mask = input_ids == VISION_START_TOKEN_ID;
+    auto vision_ends_mask = input_ids == VISION_END_TOKEN_ID;
+
+    auto starts = torch::where(vision_starts_mask)[0].to(torch::kInt);
+    auto ends = torch::where(vision_ends_mask)[0].to(torch::kInt);
+
+    int actual_segments = starts.size(0);
+    auto prev_end = torch::cat({torch::zeros({1}, options_long), ends.slice(0, 0, actual_segments - 1)});
+
+    // Compute text lengths between vision tokens.
+    auto text_lengths_between_vision = starts - prev_end + 1;
+    auto zeros = torch::zeros({1}, options_long);
+    auto widths = image_grid_thw.slice(0, 0, actual_segments).select(1, 2);
+    auto divided_widths = widths / SPATIAL_MERGE_SIZE;
+    auto vision_widths_max = torch::cat({zeros, divided_widths.slice(0, 0, actual_segments - 1)});
+    // The vision segment length is the sum of text tokens plus the (merged) image width.
+    auto vision_segment_lengths = text_lengths_between_vision + vision_widths_max;
+    auto vision_segment_lengths_cumsum = vision_segment_lengths.cumsum(0);
+    auto text_segment_lengths = vision_segment_lengths_cumsum - text_lengths_between_vision;
+
+    // Compute per-segment starting indices for image positions.
+    std::vector<int> segment_offsets_vec(actual_segments);
+    int total_image_positions = 0;
+    // (Using a CPU copy because the number of segments is small.)
+    auto image_grid_cpu = image_grid_thw.to(torch::kCPU);
+    auto image_grid_accessor = image_grid_cpu.accessor<int, 2>(); // shape: [actual_segments, 3]
+    for (int i = 0; i < actual_segments; i++)
+    {
+        int t = image_grid_accessor[i][0];
+        int h = image_grid_accessor[i][1] / SPATIAL_MERGE_SIZE;
+        int w = image_grid_accessor[i][2] / SPATIAL_MERGE_SIZE;
+        segment_offsets_vec[i] = total_image_positions;
+        total_image_positions += t * h * w;
+    }
+
+    // IMPORTANT: Create the segment_offsets tensor directly so that its memory is on the device.
+    auto segment_offsets_tensor = torch::tensor(segment_offsets_vec, options_int);
+
+    // Make sure vision_segment_lengths_cumsum is int and on the correct device.
+    auto vision_segment_lengths_cumsum_int = vision_segment_lengths_cumsum.to(torch::kInt);
+
+    // Allocate one contiguous output tensor for all image positions.
+    // Each image token produces 3 ints.
+    auto image_positions_tensor = torch::empty({total_image_positions, 3}, options_int);
+
+    // Launch one block per vision segment.
+    int threads = MAX_THREADS_PER_BLOCK;
+    int blocks = actual_segments;
+    create_image_positions_kernel<<<blocks, threads>>>(
+        image_grid_thw.data_ptr<int>(),
+        segment_offsets_tensor.data_ptr<int>(),
+        vision_segment_lengths_cumsum_int.data_ptr<int>(),
+        image_positions_tensor.data_ptr<int>());
+    cudaDeviceSynchronize();
+    cudaError_t error = cudaGetLastError();
+    TORCH_CHECK(error == cudaSuccess, "CUDA error: ", cudaGetErrorString(error));
+
+    // Process text segments on host.
+    // Each text segment is computed as a tensor of shape [3, seq_len] with all entries equal to text_segment_lengths[i].
+    std::vector<torch::Tensor> text_positions_list;
+    for (int i = 0; i < actual_segments; i++)
+    {
+        int seq_len = text_lengths_between_vision[i].item<int>();
+        auto text_range = torch::zeros({3, seq_len}, options_long) + text_segment_lengths[i];
+        text_positions_list.push_back(text_range);
+    }
+
+    // Interleave text and image segments.
+    std::vector<torch::Tensor> full_positions_list;
+    // For each vision segment, first add its text positions then add its image positions.
+    for (int i = 0; i < actual_segments; i++)
+    {
+        // Append text segment for vision segment i.
+        full_positions_list.push_back(text_positions_list[i]);
+        // Determine the slice boundaries for this vision segment's image positions.
+        int start = segment_offsets_vec[i];
+        int seg_length = 0;
+        if (i == actual_segments - 1)
+            seg_length = total_image_positions - segment_offsets_vec[i];
+        else
+            seg_length = segment_offsets_vec[i + 1] - segment_offsets_vec[i];
+        // Slice the image_positions_tensor for this segment.
+        // (Kernel output is [total_image_positions, 3]; we want to obtain a tensor of shape [3, seg_length] as in the Python reference.)
+        torch::Tensor image_segment = image_positions_tensor.slice(0, start, start + seg_length).t();
+        full_positions_list.push_back(image_segment);
+    }
+    // If there are extra text tokens after the last vision segment, add them.
+    int full_text_len = input_len - ends[actual_segments - 1].item<int>();
+    if (full_text_len > 0)
+    {
+        int max_s = full_positions_list.back().max().item<int>() + 1;
+        auto extra_text = torch::arange(full_text_len, options_long).view({1, -1}).expand({3, -1}) + max_s;
+        full_positions_list.push_back(extra_text);
+    }
+
+    // Concatenate along dimension 1 (the "position" dimension), then transpose so that the final tensor is [total_tokens, 3].
+    auto full_positions_concatenated = torch::cat(full_positions_list, 1);
+    auto full_positions_concatenated_transposed = full_positions_concatenated.t();
+
+    // Write final result to output tensor.
+    out.copy_(full_positions_concatenated_transposed);
+}
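
For intuition, here is a small Python sketch (illustrative only, not part of the extension) of the same index unraveling that create_image_positions_kernel performs for each position; t, h, and w are assumed to already be divided by SPATIAL_MERGE_SIZE, and offset_add stands in for the cumulative segment offset passed via vision_segment_lengths_cumsum.

def unravel_segment(t, h, w, offset_add):
    # Mirrors the per-thread arithmetic in create_image_positions_kernel.
    positions = []
    for pos_idx in range(t * h * w):
        t_idx = pos_idx // (h * w)
        h_idx = (pos_idx // w) % h
        w_idx = pos_idx % w
        positions.append((t_idx + offset_add, h_idx + offset_add, w_idx + offset_add))
    return positions

# For a [2, 4, 6] grid merged to t=2, h=2, w=3 with offset 6, the first rows are:
# (6, 6, 6), (6, 6, 7), (6, 6, 8), (6, 7, 6), ...
print(unravel_segment(2, 2, 3, 6)[:4])
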
test/reference.py ADDED
@@ -0,0 +1,91 @@
+import torch
+from typing import Optional
+
+class DummyModel:
+    spatial_merge_size = 2
+    vision_start_token_id = 151652
+    vision_end_token_id = 151653
+
+    # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391
+    # modified to first find segments, then initialize position ids for each segment
+    # Steps:
+    # locate all vision and text segments
+    # calculate `vision_segment_lengths` for each vision segment to be used as an offset
+    # calculate `text_segment_lengths` for each text segment to be used as an offset
+    # create position ids for each vision segment based on the image grid
+    # create position ids for each text segment
+    # combine all the position ids
+    # the final segment is the difference between the last vision segment and the end of the input
+    # combine all the position ids and reshape to (3, input_ids_len), then swap dimensions to (input_ids_len, 3)
+    def get_position_ids(
+        self,
+        input_ids: torch.Tensor,
+        image_grid_thw: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if image_grid_thw is None:
+            return (
+                torch.arange(input_ids.shape[0], device=input_ids.device)
+                .unsqueeze(1)
+                .repeat(1, 3)
+            )
+
+        spatial_merge_size = self.spatial_merge_size
+        vision_start_token_id = self.vision_start_token_id
+        vision_end_token_id = self.vision_end_token_id
+        device = input_ids.device
+        dtype = input_ids.dtype
+        input_ids_len = input_ids.shape[0]
+
+        vision_starts = torch.where(input_ids == vision_start_token_id)[0]
+        vision_ends = torch.where(input_ids == vision_end_token_id)[0]
+        vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
+        prev_vision_end = torch.cat(
+            [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
+        )
+        text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
+        vision_widths_max = torch.cat(
+            [
+                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
+                image_grid_thw[:-1, 2] // spatial_merge_size,
+            ]
+        )
+        vision_segment_lengths = vision_widths_max + text_lengths_between_vision
+        vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
+        text_segment_lengths = vision_segment_lengths - text_lengths_between_vision
+
+        # create position ids for each vision segment based on the image grid
+        llm_pos_ids_list = []
+        for i, _ in enumerate(vision_segments):
+            t, h, w = (
+                image_grid_thw[i][0],
+                image_grid_thw[i][1] // spatial_merge_size,
+                image_grid_thw[i][2] // spatial_merge_size,
+            )
+            t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
+            h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
+            w_indices = torch.arange(w, device=device).repeat(t * h)
+            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)
+
+            # offset by the position of the last vision segment
+            im = image_position_ids + vision_segment_lengths[i]
+            llm_pos_ids_list.append(im)
+
+        # create position ids for each text segment
+        text_ranges = [
+            torch.zeros(3, seq_len, device=device) + text_segment_lengths[i]
+            for i, seq_len in enumerate(text_lengths_between_vision)
+        ]
+
+        full_llm_pos_ids_list = [
+            item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
+        ]
+        max_s = full_llm_pos_ids_list[-1].max() + 1
+        final_text_len = input_ids_len - vision_ends[-1]
+        if final_text_len > 0:
+            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
+            full_llm_pos_ids_list.append(m + max_s)
+
+        position_ids = (
+            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
+        )
+        return position_ids
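
A small worked example of the reference on CPU, using the one_segment configuration from test/test.py (it assumes this file is importable as `reference`, as in the tests):

import torch
from reference import DummyModel

input_ids = torch.tensor([10] * 5 + [151652, 151653] + [20] * 5, dtype=torch.int32)
image_grid_thw = torch.tensor([[2, 4, 6]], dtype=torch.int32)

pos_ids = DummyModel().get_position_ids(input_ids, image_grid_thw)
# 6 text positions, 2*2*3 = 12 image positions, 6 trailing text positions.
print(pos_ids.shape)  # torch.Size([24, 3])
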
test/test.py ADDED
@@ -0,0 +1,101 @@
+import time
+import torch
+import pytest
+import get_position_ids  # noqa: E402
+from reference import DummyModel
+
+# Each configuration includes:
+#   - name: A label for the test case.
+#   - input_ids: A list of token IDs (with vision start (151652) and vision end (151653) tokens embedded).
+#   - grid: A list of [t, h, w] values (one per vision segment).
+#
+# The cases below include:
+#   1. one_segment: a single vision segment.
+#   2. two_segments: two vision segments with extra text tokens afterward.
+#   3. three_segments: three vision segments.
+VISION_CONFIGS = [
+    {
+        "name": "one_segment",
+        "input_ids": (
+            [10] * 5 +          # 5 text tokens before vision segment
+            [151652, 151653] +  # vision tokens for segment 1
+            [20] * 5            # 5 extra text tokens after vision segment
+        ),
+        "grid": [[2, 4, 6]],  # one vision segment grid
+    },
+    {
+        "name": "two_segments",
+        "input_ids": (
+            [100] * 5 +         # 5 text tokens for segment 1
+            [151652, 151653] +  # vision tokens for segment 1
+            [101] * 5 +         # 5 text tokens for segment 2
+            [151652, 151653] +  # vision tokens for segment 2
+            [102] * 5           # 5 extra text tokens after last vision segment
+        ),
+        "grid": [
+            [2, 4, 6],  # vision segment 1 grid
+            [3, 4, 6],  # vision segment 2 grid
+        ],
+    },
+    {
+        "name": "three_segments",
+        "input_ids": (
+            [11] * 5 +          # Segment 1: 5 text tokens
+            [151652, 151653] +  # vision tokens for segment 1
+            [12] * 6 +          # Segment 2: 6 text tokens
+            [151652, 151653] +  # vision tokens for segment 2
+            [13] * 7 +          # Segment 3: 7 text tokens
+            [151652, 151653] +  # vision tokens for segment 3
+            [14] * 8            # 8 extra text tokens after the last vision segment
+        ),
+        "grid": [
+            [2, 4, 6],  # vision segment 1 grid
+            [3, 6, 6],  # vision segment 2 grid
+            [4, 4, 8],  # vision segment 3 grid
+        ],
+    },
+]
+
+CUDA_DEVICES = ["cuda"]  # List of CUDA devices; you can add more if needed.
+SEEDS = [42]             # Seeds for reproducibility.
+DTYPES = [torch.int32]   # In our test the tokens and grid are created with int32.
+
+
+@pytest.mark.parametrize("vision_config",
+                         VISION_CONFIGS,
+                         ids=[cfg["name"] for cfg in VISION_CONFIGS])
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_get_position_ids(vision_config, seed, device):
+    torch.manual_seed(seed)
+    input_ids = torch.tensor(vision_config["input_ids"], dtype=torch.int32, device=device)
+    image_grid_thw = torch.tensor(vision_config["grid"], dtype=torch.int32, device=device)
+
+    # Create a DummyModel instance from the reference implementation.
+    dummy_model = DummyModel()
+
+    # reference implementation
+    torch.cuda.synchronize()
+    start_ref = time.perf_counter()
+    pos_ids_ref = dummy_model.get_position_ids(input_ids, image_grid_thw)
+    torch.cuda.synchronize()
+    end_ref = time.perf_counter()
+    ref_time = (end_ref - start_ref) * 1000  # ms
+    print(f"\nVision config {vision_config['name']} - Reference time: {ref_time:.2f} ms")
+    # Convert reference output to int32 for comparison (since it's returned as a float tensor).
+    pos_ids_ref = pos_ids_ref.to(dtype=torch.int32)
+
+    # kernel implementation
+    torch.cuda.synchronize()
+    start_ext = time.perf_counter()
+    out = torch.empty(pos_ids_ref.shape, dtype=torch.int32, device=device)
+    get_position_ids.get_position_ids(out, input_ids, image_grid_thw)
+    torch.cuda.synchronize()
+    end_ext = time.perf_counter()
+    ext_time = (end_ext - start_ext) * 1000  # ms
+    print(f"Vision config {vision_config['name']} - Extension time: {ext_time:.2f} ms\n")
+    ext_out = out.clone()
+
+    # verify the results
+    torch.testing.assert_close(ext_out.cpu(), pos_ids_ref.cpu())
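
The suite can also be driven programmatically; a minimal sketch, assuming pytest is installed, the extension has been built, and a CUDA device is present:

import sys
import pytest

# Equivalent to invoking pytest on the test file from the repository root.
sys.exit(pytest.main(["-v", "test/test.py"]))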