qianyuchen commited on
Commit
c331656
·
verified ·
1 Parent(s): d7bbb09

Update resampler.py

Browse files
Files changed (1) hide show
  1. resampler.py +2 -2
resampler.py CHANGED
@@ -246,7 +246,7 @@ class Resampler(nn.Module):
246
 
247
 
248
  def batch_attn_forward(self, q, k, v, pos_embed_temporal, temporal_ids, key_padding_mask):
249
- bs = k.shape[0]
250
 
251
  if pos_embed_temporal:
252
  # temporal 维度折叠
@@ -281,7 +281,7 @@ class Resampler(nn.Module):
281
 
282
 
283
  def foreach_attn_forward(self, q, k, v, pos_embed_temporal, temporal_ids, key_padding_mask):
284
- bs = k.shape[0]
285
 
286
  if pos_embed_temporal:
287
  k += torch.stack(pos_embed_temporal, dim=0)
 
246
 
247
 
248
  def batch_attn_forward(self, q, k, v, pos_embed_temporal, temporal_ids, key_padding_mask):
249
+ bs = k.shape[1]
250
 
251
  if pos_embed_temporal:
252
  # temporal 维度折叠
 
281
 
282
 
283
  def foreach_attn_forward(self, q, k, v, pos_embed_temporal, temporal_ids, key_padding_mask):
284
+ bs = k.shape[1]
285
 
286
  if pos_embed_temporal:
287
  k += torch.stack(pos_embed_temporal, dim=0)