Add explicit error for flash + alibi
Browse files- mosaic_gpt.py +3 -0
mosaic_gpt.py
CHANGED
|
@@ -31,6 +31,9 @@ class MosaicGPT(PreTrainedModel):
|
|
| 31 |
def __init__(self, config: MosaicGPTConfig):
|
| 32 |
super().__init__(config)
|
| 33 |
|
|
|
|
|
|
|
|
|
|
| 34 |
self.attn_impl = config.attn_impl
|
| 35 |
self.prefix_lm = config.prefix_lm
|
| 36 |
self.attn_uses_sequence_id = config.attn_uses_sequence_id
|
|
|
|
| 31 |
def __init__(self, config: MosaicGPTConfig):
|
| 32 |
super().__init__(config)
|
| 33 |
|
| 34 |
+
if config.attn_impl == 'flash' and config.alibi:
|
| 35 |
+
raise RuntimeError("ALiBi is not supported with flash attention. Please use triton or torch.")
|
| 36 |
+
|
| 37 |
self.attn_impl = config.attn_impl
|
| 38 |
self.prefix_lm = config.prefix_lm
|
| 39 |
self.attn_uses_sequence_id = config.attn_uses_sequence_id
|