import transformers
def replace_llama_rmsnorm_with_fused_rmsnorm():
    """Monkey-patch HF transformers to use Apex's FusedRMSNorm for LLaMA.

    When apex is importable, ``transformers.models.llama.modeling_llama.LlamaRMSNorm``
    is replaced with a ``functools.partial`` of ``apex.normalization.FusedRMSNorm``
    with ``eps=1e-6`` as the default; a caller that passes ``eps=`` explicitly
    still overrides it (standard ``partial`` keyword semantics).

    Fallback behavior:
      * apex not installed  -> silently keep the stock LlamaRMSNorm.
      * apex present but broken (e.g. CUDA/ABI mismatch) -> print the reason
        and keep the stock LlamaRMSNorm.

    NOTE(review): must be called before the LLaMA model is instantiated for
    the patch to take effect — confirm call order at the call site.
    """
    try:
        from apex.normalization import FusedRMSNorm
        from functools import partial

        # partial preserves the LlamaRMSNorm(hidden_size, eps=...) call shape.
        LlamaRMSNorm = partial(FusedRMSNorm, eps=1e-6)  # noqa
        transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
        print("Discovered apex.normalization.FusedRMSNorm - will use it instead of LlamaRMSNorm")
    except ImportError:
        # apex not installed: using the normal LlamaRMSNorm.
        pass
    except Exception as e:
        # apex exists but failed on import/patch; report *why* instead of
        # discarding the exception, then fall back to the stock implementation.
        print(f"discovered apex but it failed to load, falling back to LlamaRMSNorm: {e}")