Update resampler.py
Browse files- resampler.py +2 -2
resampler.py
CHANGED
|
@@ -246,7 +246,7 @@ class Resampler(nn.Module):
|
|
| 246 |
|
| 247 |
|
| 248 |
def batch_attn_forward(self, q, k, v, pos_embed_temporal, temporal_ids, key_padding_mask):
|
| 249 |
-
bs = k.shape[
|
| 250 |
|
| 251 |
if pos_embed_temporal:
|
| 252 |
# temporal 维度折叠
|
|
@@ -281,7 +281,7 @@ class Resampler(nn.Module):
|
|
| 281 |
|
| 282 |
|
| 283 |
def foreach_attn_forward(self, q, k, v, pos_embed_temporal, temporal_ids, key_padding_mask):
|
| 284 |
-
bs = k.shape[
|
| 285 |
|
| 286 |
if pos_embed_temporal:
|
| 287 |
k += torch.stack(pos_embed_temporal, dim=0)
|
|
|
|
| 246 |
|
| 247 |
|
| 248 |
def batch_attn_forward(self, q, k, v, pos_embed_temporal, temporal_ids, key_padding_mask):
|
| 249 |
+
bs = k.shape[1]
|
| 250 |
|
| 251 |
if pos_embed_temporal:
|
| 252 |
# temporal 维度折叠
|
|
|
|
| 281 |
|
| 282 |
|
| 283 |
def foreach_attn_forward(self, q, k, v, pos_embed_temporal, temporal_ids, key_padding_mask):
|
| 284 |
+
bs = k.shape[1]
|
| 285 |
|
| 286 |
if pos_embed_temporal:
|
| 287 |
k += torch.stack(pos_embed_temporal, dim=0)
|