Brian Tang committed
Commit 8f0a794 · 1 Parent(s): 49ebb9c

Adds a flash attention availability check based on the device type

Files changed (1)
  1. modeling_jina_embeddings_v4.py +2 -1
modeling_jina_embeddings_v4.py CHANGED
@@ -569,7 +569,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         kwargs["torch_dtype"] = "auto"
 
         kwargs["key_mapping"] = super()._checkpoint_conversion_mapping
-        if not is_flash_attn_2_available():
+        device = kwargs.get("device", "auto")
+        if not is_flash_attn_2_available() or device == "cpu":
             kwargs["attn_implementation"] = "sdpa"
 
         base_model = super().from_pretrained(
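For context, a minimal, self-contained sketch of the fallback logic this commit introduces. The select_attn_implementation helper name is hypothetical (in the actual model the check lives inline in the from_pretrained override); is_flash_attn_2_available is the real helper from transformers.utils. Flash attention kernels are CUDA-only, so forcing SDPA when the target device is CPU avoids a failure even when the flash-attn package is installed.

from transformers.utils import is_flash_attn_2_available

def select_attn_implementation(kwargs: dict) -> dict:
    # Hypothetical helper mirroring the diff above: fall back to
    # PyTorch's scaled-dot-product attention (SDPA) when flash-attn 2
    # is not installed or the target device is CPU.
    device = kwargs.get("device", "auto")
    if not is_flash_attn_2_available() or device == "cpu":
        kwargs["attn_implementation"] = "sdpa"
    return kwargs

# Loading on CPU always selects SDPA, regardless of flash-attn availability:
print(select_attn_implementation({"device": "cpu"}))
# -> {'device': 'cpu', 'attn_implementation': 'sdpa'}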