Brian Tang
committed · Commit 8f0a794 · 1 Parent(s): 49ebb9c
Adds flash attention check with the device type
modeling_jina_embeddings_v4.py CHANGED
@@ -569,7 +569,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         kwargs["torch_dtype"] = "auto"
 
         kwargs["key_mapping"] = super()._checkpoint_conversion_mapping
-
-        kwargs["attn_implementation"] = "sdpa"
+        device = kwargs.get("device", "auto")
+        if not is_flash_attn_2_available() or device == "cpu":
+            kwargs["attn_implementation"] = "sdpa"
 
         base_model = super().from_pretrained(
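For context, the change makes flash attention the default only when it can actually run: if flash-attn 2 is not installed, or the model is being loaded on CPU (where flash attention kernels are unavailable), the loader falls back to PyTorch's SDPA implementation. A minimal standalone sketch of that selection logic follows; the helper name resolve_attn_implementation is hypothetical, while the check itself mirrors the diff above.

from transformers.utils import is_flash_attn_2_available


def resolve_attn_implementation(kwargs: dict) -> dict:
    # Hypothetical helper mirroring the check added in this commit:
    # fall back to PyTorch's scaled-dot-product attention when flash
    # attention 2 is unavailable or the target device is CPU.
    device = kwargs.get("device", "auto")
    if not is_flash_attn_2_available() or device == "cpu":
        kwargs["attn_implementation"] = "sdpa"
    return kwargs


# Example: on a machine without flash-attn, or when loading on CPU,
# the kwargs passed through to from_pretrained() request SDPA.
print(resolve_attn_implementation({"device": "cpu"}))
# {'device': 'cpu', 'attn_implementation': 'sdpa'}

Doing the check at load time, rather than at inference time, means the attention backend is fixed once per model instance and no per-call branching is needed.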