File size: 604 Bytes
cf932d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import transformers

def replace_llama_rmsnorm_with_fused_rmsnorm():
    """Swap transformers' ``LlamaRMSNorm`` for apex's ``FusedRMSNorm`` when usable.

    The patch rebinds the attribute ``LlamaRMSNorm`` on
    ``transformers.models.llama.modeling_llama``. Rebinding a module attribute
    cannot change norm layers that already exist, so call this before the model
    is built or loaded.

    Outcomes:

    * apex not installed: nothing is patched and nothing is printed.
    * apex installed but unusable (importing it fails, or its
      ``fused_layer_norm_cuda`` extension cannot be imported): a fallback message
      that includes the error is printed, and nothing is patched.
    * otherwise: ``LlamaRMSNorm`` is replaced by ``LlamaFusedRMSNorm``, a
      ``FusedRMSNorm`` subclass taking ``(hidden_size, eps=1e-6, **kwargs)``,
      and a confirmation is printed.
    """
    fused_rmsnorm_cls = _load_fused_rmsnorm()
    if fused_rmsnorm_cls is None:
        return

    class LlamaFusedRMSNorm(fused_rmsnorm_cls):
        """``FusedRMSNorm`` exposing ``LlamaRMSNorm``'s ``(hidden_size, eps=1e-6)`` signature."""

        def __init__(self, hidden_size, eps=1e-6, **kwargs):
            # Extra keyword arguments (e.g. ``elementwise_affine``) pass straight
            # through to FusedRMSNorm, as they did with the previous partial().
            super().__init__(hidden_size, eps=eps, **kwargs)

    # A real class (unlike functools.partial) keeps isinstance()/issubclass()
    # checks against the patched name working. Pickle locates classes by
    # ``module.qualname``, so give the locally defined class a plain qualname
    # and publish it at this module's top level to keep instances picklable.
    LlamaFusedRMSNorm.__qualname__ = LlamaFusedRMSNorm.__name__
    globals()[LlamaFusedRMSNorm.__name__] = LlamaFusedRMSNorm

    # Deliberately outside the apex guard: a failure here is a transformers
    # problem and must not be reported as "apex failed to load".
    transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaFusedRMSNorm
    print("Discovered apex.normalization.FusedRMSNorm - will use it instead of LlamaRMSNorm")


def _load_fused_rmsnorm():
    """Return apex's ``FusedRMSNorm`` class if it is importable and usable, else ``None``."""
    try:
        from apex.normalization import FusedRMSNorm
    except ImportError:
        # apex is not installed: silently keep the normal LlamaRMSNorm.
        return None
    except Exception as exc:
        print(f"discovered apex but it failed to load ({exc!r}), falling back to LlamaRMSNorm")
        return None

    try:
        import importlib

        # apex's FusedRMSNorm imports its compiled ``fused_layer_norm_cuda``
        # extension lazily, when a layer is instantiated. A Python-only apex build
        # would therefore pass the import above and only fail later, during model
        # construction. Probe the extension now so we can fall back cleanly.
        importlib.import_module("fused_layer_norm_cuda")
    except Exception as exc:
        print(f"discovered apex but it failed to load ({exc!r}), falling back to LlamaRMSNorm")
        return None
    return FusedRMSNorm