A simple way to save memory is activation "checkpointing". Instead of storing ALL intermediate results of the forward pass, you only store the activations at the boundary of each transformer layer, and recompute the missing parts during the backward pass.
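In PyTorch this is typically done with torch.utils.checkpoint. Here is a minimal sketch, assuming a stack of standard encoder layers (the CheckpointedBlock wrapper and all sizes are just illustrative):

```python
import torch
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(torch.nn.Module):
    """Runs a layer under checkpointing: only its input/output are kept,
    the layer's inner activations are recomputed during backward."""
    def __init__(self, layer: torch.nn.Module):
        super().__init__()
        self.layer = layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # use_reentrant=False is the recommended mode in recent PyTorch versions
        return checkpoint(self.layer, x, use_reentrant=False)

# Illustrative usage: checkpoint each layer of a small transformer stack
layers = torch.nn.ModuleList([
    CheckpointedBlock(
        torch.nn.TransformerEncoderLayer(d_model=256, nhead=4, batch_first=True)
    )
    for _ in range(6)
])
x = torch.randn(2, 128, 256, requires_grad=True)
for layer in layers:
    x = layer(x)
x.sum().backward()  # inner activations are recomputed here, layer by layer
```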
Matthias Seeger
mseeger's activity
@lvwerra I checked this code. Did you realize this is really inefficient?
In essence, this code computes the K, Q, V tensors explicitly, as q_nope and k_nope of shape (bs, num_heads, q_len, qk_nope_head_dim), and value_states. Then, it appends the additional RoPE-encoded vectors.
If you do it this way, you could just position-encode K and Q directly -- why append anything additional? In the paper, the authors say they do not want to do this precisely because they do not want to compute K and Q explicitly, as that would be wasteful.
Avoiding the explicit computation is in fact quite possible, but the code here does not do it. For example, under these assumptions one can write code where the K tensor going into the inner products does not have a num_heads dimension at all, just like in multi-query attention (see the sketch below).
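To make the point concrete, here is a minimal sketch of that "absorbed" formulation, with made-up sizes and tensor names, and with the decoupled RoPE part left out for simplicity. The tensor entering the inner products on the key side is just the compressed latent, with no num_heads dimension:

```python
import torch

# Hypothetical sizes (not the DeepSeek-V2 values)
bs, q_len, kv_len = 2, 8, 16
num_heads, head_dim, kv_lora_rank = 4, 32, 64

# Per-head up-projections from the shared KV latent (the low-rank assumption)
w_uk = torch.randn(num_heads, head_dim, kv_lora_rank)  # k_h = w_uk[h] @ c
w_uv = torch.randn(num_heads, head_dim, kv_lora_rank)  # v_h = w_uv[h] @ c

q = torch.randn(bs, num_heads, q_len, head_dim)  # queries, per head
c_kv = torch.randn(bs, kv_len, kv_lora_rank)     # compressed KV latent, no head dim

# Absorb w_uk into the query: q_h^T (w_uk[h] @ c) == (w_uk[h]^T q_h)^T c
q_abs = torch.einsum("bhld,hdr->bhlr", q, w_uk)   # (bs, num_heads, q_len, kv_lora_rank)

# Attention scores against the shared latent -- K is never materialized per head
scores = torch.einsum("bhlr,bsr->bhls", q_abs, c_kv) / head_dim ** 0.5
attn = scores.softmax(dim=-1)

# Attend to the latent, then apply the value up-projection once per head
ctx_latent = torch.einsum("bhls,bsr->bhlr", attn, c_kv)
out = torch.einsum("bhlr,hdr->bhld", ctx_latent, w_uv)  # (bs, num_heads, q_len, head_dim)
```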
And worse, this code here does not even use torch.nn.functional.scaled_dot_product_attention, and therefore cannot benefit from FlashAttention and similar fused kernels. Fair enough, for inference this is not so important, but if this code is used for training (e.g., fine-tuning), it will be very slow.
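For comparison, calling the fused kernel is a one-liner (all shapes below are illustrative):

```python
import torch
import torch.nn.functional as F

bs, num_heads, q_len, kv_len, head_dim = 2, 4, 8, 16, 32
query = torch.randn(bs, num_heads, q_len, head_dim)
key = torch.randn(bs, num_heads, kv_len, head_dim)
value = torch.randn(bs, num_heads, kv_len, head_dim)

# Dispatches to FlashAttention / memory-efficient kernels when available,
# instead of materializing the full (q_len, kv_len) score matrix.
out = F.scaled_dot_product_attention(query, key, value, is_causal=False)
```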
With this code, you have succeeded in making the computation less efficient than for MHA, even though you have a low-rank structure to exploit!
Hello guys, I am even just trying to find code for DeepSeek-V2, in order to fully understand multi-head latent attention. You do not seem to have code in Hugging Face even for that, or am I missing something? I don't see anything in src/transformers/models. MLA is not properly described in their paper, so it would be important to have code for it.