TryingHard committed
Commit fd846ba · verified · 1 Parent(s): 736d924

Update modeling_ovis2_5.py

Files changed (1)
modeling_ovis2_5.py +48 -1
modeling_ovis2_5.py CHANGED
@@ -894,7 +894,54 @@ class Ovis2_5(OvisPreTrainedModel):
             pixel_values=kwargs.pop('pixel_values', None),
             grid_thws=kwargs.pop('grid_thws', None)
         )
-        return self.llm.generate(inputs=None, inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
+        enable_thinking = kwargs.pop('enable_thinking', False)
+        enable_thinking_budget = kwargs.pop('enable_thinking_budget', False)
+        thinking_budget = kwargs.pop('thinking_budget', 1024)
+
+        if enable_thinking and enable_thinking_budget:
+            actual_max_new_tokens = kwargs['max_new_tokens']
+            kwargs['max_new_tokens'] = thinking_budget
+            generated_ids = self.llm.generate(inputs=None, inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
+            output_ids = generated_ids
+            output_ids_list = generated_ids[0]
+
+            # check if the generation has already finished (151645 is <|im_end|>)
+            if 151645 not in output_ids_list:
+                # check if the thinking process has finished (151668 is </think>)
+                # and prepare the second model input
+                if 151668 not in output_ids_list:
+                    print("thinking budget is reached")
+                    early_stopping_text = "\n\nConsidering the limited time by the user, I have to give the solution based on the thinking directly now.\n</think>\n\n"
+                    early_stopping_ids = self.text_tokenizer(early_stopping_text, return_tensors="pt", return_attention_mask=False).input_ids.to(inputs.device)
+                    input_ids_appendent = torch.cat([output_ids, early_stopping_ids], dim=-1)
+                    kwargs['streamer'].put(early_stopping_ids) if 'streamer' in kwargs else None
+                else:
+                    input_ids_appendent = output_ids
+
+
+                # second generation
+                new_inputs = torch.cat([inputs, input_ids_appendent], dim=-1)
+                attention_mask = torch.ne(new_inputs, self.text_tokenizer.pad_token_id).to(device=inputs.device)
+                inputs_embeds_appendent = self.merge_multimodal(
+                    input_ids=input_ids_appendent,
+                    pixel_values=None,
+                    grid_thws=None
+                )
+                new_inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_appendent], dim=-2)
+
+                kwargs['max_new_tokens'] = inputs_embeds.size(-2) + actual_max_new_tokens - new_inputs_embeds.size(-2)
+                generated_ids2 = self.llm.generate(inputs=None, inputs_embeds=new_inputs_embeds, attention_mask=attention_mask, **kwargs)
+                kwargs['streamer'].manual_end() if 'streamer' in kwargs else None
+                return torch.cat([input_ids_appendent, generated_ids2], dim=-1)
+
+            else:
+                kwargs['streamer'].manual_end() if 'streamer' in kwargs else None
+                return generated_ids
+
+        else:
+            generated_ids = self.llm.generate(inputs=None, inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
+            kwargs['streamer'].manual_end() if 'streamer' in kwargs else None
+            return generated_ids
 
 
 AutoConfig.register('siglip2_navit', Siglip2NavitConfig)
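
The change makes generate() two-phase when both enable_thinking and enable_thinking_budget are set: phase one runs with max_new_tokens temporarily capped at thinking_budget; if the output is unfinished and </think> never appeared, a fixed early-stopping sentence plus </think> is spliced in, and phase two resumes with the leftover budget. The expression inputs_embeds.size(-2) + actual_max_new_tokens - new_inputs_embeds.size(-2) reduces to max_new_tokens minus the tokens already produced, so the caller's total cap still holds. Below is a minimal caller-side sketch, assuming a text-only prompt and an illustrative checkpoint id; only the enable_thinking / enable_thinking_budget / thinking_budget / max_new_tokens kwargs come from this commit.

import torch
from transformers import AutoModelForCausalLM

# The checkpoint id is an assumption for illustration; any Ovis2.5
# checkpoint shipping this modeling file exposes the same kwargs.
model = AutoModelForCausalLM.from_pretrained(
    "AIDC-AI/Ovis2.5-9B",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).cuda()

# Text-only prompt for brevity: pixel_values / grid_thws then default to
# None inside the patched generate(). Real use would go through the
# model's chat template and image preprocessing instead.
input_ids = model.text_tokenizer(
    "Why is the sky blue?", return_tensors="pt"
).input_ids.cuda()

output_ids = model.generate(
    inputs=input_ids,
    enable_thinking=True,          # emit a <think> ... </think> phase
    enable_thinking_budget=True,   # enforce the cap below
    thinking_budget=1024,          # phase-one cap on thinking tokens
    max_new_tokens=3072,           # overall cap across both phases
)
print(model.text_tokenizer.decode(output_ids[0], skip_special_tokens=True))

Note that the budget branch reads kwargs['max_new_tokens'] directly, so omitting max_new_tokens while enable_thinking_budget is set would raise a KeyError; pass it explicitly alongside thinking_budget.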