import transformers | |
def replace_llama_rmsnorm_with_fused_rmsnorm(eps=1e-6):
    """Swap transformers' ``LlamaRMSNorm`` for apex's ``FusedRMSNorm`` if possible.

    Rebinds the module attribute
    ``transformers.models.llama.modeling_llama.LlamaRMSNorm`` to a
    ``functools.partial`` of ``apex.normalization.FusedRMSNorm``. Only lookups
    of that attribute made after this call see the replacement, so it is
    presumably meant to run before the model is constructed (verify against
    the caller).

    The patch is best-effort: if anything goes wrong the stock
    ``LlamaRMSNorm`` is left in place and no exception propagates.

    Args:
        eps: Default ``eps`` keyword pre-bound on the fused norm. Code that
            passes its own ``eps`` when instantiating the norm overrides it.

    Returns:
        None.
    """
    # Standard library; kept outside the try since it is not the optional part.
    from functools import partial

    try:
        from apex.normalization import FusedRMSNorm

        fused_rmsnorm = partial(FusedRMSNorm, eps=eps)
        transformers.models.llama.modeling_llama.LlamaRMSNorm = fused_rmsnorm
    except ImportError:
        # Optional dependency not importable (typically apex is not
        # installed): silently keep using the normal LlamaRMSNorm.
        return
    except Exception as exc:
        # Something was found but could not be set up. Still fall back, but
        # report the underlying error so the failure can be diagnosed.
        print(
            "discovered apex but it failed to load, falling back to LlamaRMSNorm"
            f": {exc!r}"
        )
        return

    print("Discovered apex.normalization.FusedRMSNorm - will use it instead of LlamaRMSNorm")