Understanding Muon

Chapter 2: Source Code

Some of Muon's code can appear puzzling, like the magic numbers $(3.4445, -4.7750, 2.0315)$ in Newton-Schulz. To clear up the confusion, we'll go through the code with line-by-line annotations, including pictures. Afterwards we'll discuss how you can use Muon to get ~1.5x speedups without headaches.

import torch

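def set_all_singular_values_to_near_one(G, steps: int):
    # Newton-Schulz iteration: approximately orthogonalizes G without computing an SVD.
    assert G.ndim >= 2
    # Coefficients of the odd quintic p(x) = a*x + b*x^3 + c*x^5 applied to each singular value.
    a, b, c = (3.4445, -4.7750,  2.0315)
    X = G.bfloat16()
    # For tall matrices, work on the transpose so that X @ X.mT is the smaller Gram matrix.
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Divide by the Frobenius norm so that every singular value is at most 1.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        # Same as X <- a*X + b*(X X^T) X + c*(X X^T)^2 X.
        X = a * X + B @ X

    # Undo the transpose so the output has the same shape as G.
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X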

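def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
    # Update the exponential moving average of gradients in place.
    momentum.lerp_(grad, 1 - beta)
    # Nesterov: blend the gradient toward the updated momentum (this overwrites grad in place).
    update = grad.lerp_(momentum, beta) if nesterov else momentum
    update = set_all_singular_values_to_near_one(update, steps=ns_steps)
    # Scale by sqrt(rows / cols), i.e. sqrt(d_out / d_in) for a weight stored as (d_out, d_in).
    update *= (grad.size(-2) / grad.size(-1))**0.5
    return update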

def adam_update(grad, buf1, buf2, step, betas, eps):
    buf1.lerp_(grad, 1 - betas[0])
    buf2.lerp_(grad.square(), 1 - betas[1])
    buf1c = buf1 / (1 - betas[0]**step)
    buf2c = buf2 / (1 - betas[1]**step)
    return buf1c / (buf2c.sqrt() + eps)

class Muon(torch.optim.Optimizer):
    def __init__(self, param_groups):
        for group in param_groups:
            assert "use_muon" in group
            if group["use_muon"]:
                # defaults
                group["lr"] = group.get("lr", 0.02)
                group["momentum"] = group.get("momentum", 0.95)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
            else:
                # defaults
                group["lr"] = group.get("lr", 3e-4)
                group["betas"] = group.get("betas", (0.9, 0.95))
                group["eps"] = group.get("eps", 1e-10)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
        super().__init__(param_groups, dict())

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if group["use_muon"]:
                for p in group["params"]:
                    if p.grad is None:
                        p.grad = torch.zeros_like(p)
                    state = self.state[p]
                    if len(state) == 0:
                        state["momentum_buffer"] = torch.zeros_like(p)
                    update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])
            else:
                for p in group["params"]:
                    if p.grad is None:
                        p.grad = torch.zeros_like(p)
                    state = self.state[p]
                    if len(state) == 0:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
                                        state["step"], group["betas"], group["eps"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])

        return loss

(Implementation adapted from Keller Jordan's GitHub.)
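
One way to see what the magic numbers do: each loop iteration in set_all_singular_values_to_near_one computes $aX + b(XX^\top)X + c(XX^\top)^2X$, which keeps the singular vectors of $X$ and sends every singular value $\sigma$ to $a\sigma + b\sigma^3 + c\sigma^5$. The sketch below iterates that scalar map in plain Python. The helper name newton_schulz_scalar is made up for illustration and is not part of the implementation above.

```python
# Coefficients copied from set_all_singular_values_to_near_one above.
A, B, C = 3.4445, -4.7750, 2.0315


def newton_schulz_scalar(sigma: float, steps: int = 5) -> float:
    """Apply p(x) = A*x + B*x**3 + C*x**5 to a single singular value `steps` times."""
    for _ in range(steps):
        sigma = A * sigma + B * sigma**3 + C * sigma**5
    return sigma


# After Frobenius normalization every singular value lies in (0, 1].
for sigma in (0.01, 0.1, 0.5, 1.0):
    print(f"{sigma:5.2f} -> {newton_schulz_scalar(sigma):.3f}")
```

Starting values between about 0.01 and 1 should end up in a band of roughly 0.7 to 1.2 after five steps rather than at exactly 1, which is the "near one" in the function name. Much smaller singular values need more steps to get there.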

Now that we have seen Muon's code, let's tackle some FAQs:

  1. Which parameters should Muon optimize, and which should be left to AdamW?
    Muon is designed for parameters that map dense activations to dense activations, such as the weight matrices in MLPs, attention $(Q, K, V, W)$, convolution kernels, and the final layer of a diffusion model. These are 2D matrices with dense vectors for both input and output. Muon is not designed for elementwise vector parameters such as biases and layer norm scales, or for the embedding matrix of a language model. The Muon class above optimizes these other parameters with its Adam-style fallback (the groups with use_muon set to False).

    Exercise 1: what's wrong with the parameter rule below?
    def is_muon_param(p):
        return p.ndim == 2  # incorrect
                    
    Solution 1: this rule also selects 2D matrix parameters whose natural input/output norms aren't RMS norms, such as an LLM embedding matrix. Since Muon is designed for RMS norms, it may not perform well on them.

    Exercise 2: does Muon make sense for the patch embedding matrix of a vision transformer?
    Solution 2: likely yes, since images are dense vectors with entries $\approx 1$.
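
A minimal sketch of a more careful rule than the one in Exercise 1, assuming parameters can be told apart by name. The keyword "embed", the parameter names, and the shapes are all hypothetical, so adjust them to your own model. The rule works on names and shapes only, which keeps it runnable without PyTorch.

```python
def is_muon_param(name: str, shape: tuple, adam_keywords=("embed",)) -> bool:
    """Heuristic: matrices (and conv kernels) go to Muon unless the name marks an embedding.

    `adam_keywords` is an assumption about how the model names its parameters.
    """
    if len(shape) < 2:  # biases, layer norm scales
        return False
    return not any(key in name for key in adam_keywords)


def split_param_names(named_shapes: dict) -> tuple:
    """Return (muon_names, adam_names), preserving the input order."""
    muon = [n for n, s in named_shapes.items() if is_muon_param(n, s)]
    adam = [n for n in named_shapes if n not in muon]
    return muon, adam


# Hypothetical names and shapes for a tiny GPT-style model.
named_shapes = {
    "tok_embed.weight": (50257, 768),
    "blocks.0.attn.q_proj.weight": (768, 768),
    "blocks.0.mlp.fc_in.weight": (3072, 768),
    "blocks.0.mlp.fc_in.bias": (3072,),
    "blocks.0.ln.weight": (768,),
}
muon_names, adam_names = split_param_names(named_shapes)
print("Muon:", muon_names)
print("Adam:", adam_names)
```

With a real model, named_shapes could be built as {n: tuple(p.shape) for n, p in model.named_parameters()}, and the two resulting parameter lists passed to the Muon class above as one group with use_muon=True and one with use_muon=False.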

  2. Should Muon and Adam share the same learning rate?
    Not necessarily. How best to share a learning rate may still be an open question, but if you want to, two approaches are discussed in Microsoft's Dion paper and Moonshot AI's Muon paper.
    • Dion's approach: follow the usual dimensional factors that transfer learning rate across width, $\sqrt{d_\text{out}/d_\text{in}}$. Then use the same learning rate across Adam and Muon, except scale down the unembedding layer's learning rate by $1/\sqrt{d_\text{in}}$. More discussion in their Appendix D. For more on "natural norms" that inspired this approach see Jeremy's spectral condition paper.
    • Moonshot's approach: match the RMS norm of Adam's updates with the "RMS norm" of Muon's updates. (Yes, we're stepping out of the Matrix for a moment to view the weight update as a vector.) Matching the entrywise update scales may be a good default if you want to use Adam-like learning rates (3e-4 anyone?) for both Adam and Muon. If $A$ is an orthogonal matrix, $$\| \text{flatten}(A) \|_{\text{RMS}} = 1/\sqrt{\text{max}(d_\text{out}, d_\text{in})}.$$ An experiment you can try at home: averaged over 10k steps with Gaussian random gradients, the RMS norm of an Adam update is around $0.2$. So Moonshot recommends that Muon's update be $$W_{n+1} = W_n - 0.2 \sqrt{\text{max}(d_\text{out}, d_\text{in})} \cdot \eta \cdot \text{orthogonalize}(M_{n+1}).$$
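
The RMS bookkeeping behind Moonshot's approach can be checked in a few lines of plain Python. This is a sketch under stated assumptions, not Moonshot's code: the helper names are made up, the semi-orthogonal matrix is just an identity block padded with zeros, and the Adam simulation reuses the betas and eps defaults from the code above with fewer than 10k steps to keep it quick.

```python
import math
import random


def rms(values) -> float:
    values = list(values)
    return math.sqrt(sum(v * v for v in values) / len(values))


def semi_orthogonal(d_out: int, d_in: int):
    """A trivial semi-orthogonal matrix: ones on the main diagonal, zeros elsewhere."""
    return [[1.0 if i == j else 0.0 for j in range(d_in)] for i in range(d_out)]


def moonshot_scale(d_out: int, d_in: int, target_rms: float = 0.2) -> float:
    """Factor that brings an orthogonalized update to the target entrywise RMS."""
    return target_rms * math.sqrt(max(d_out, d_in))


def adam_update_rms(steps=2000, n=64, betas=(0.9, 0.95), eps=1e-10, seed=0) -> float:
    """Average RMS of bias-corrected Adam updates on i.i.d. standard Gaussian gradients."""
    rng = random.Random(seed)
    m, v, total = [0.0] * n, [0.0] * n, 0.0
    for t in range(1, steps + 1):
        g = [rng.gauss(0.0, 1.0) for _ in range(n)]
        m = [betas[0] * mi + (1 - betas[0]) * gi for mi, gi in zip(m, g)]
        v = [betas[1] * vi + (1 - betas[1]) * gi * gi for vi, gi in zip(v, g)]
        c1, c2 = 1 - betas[0] ** t, 1 - betas[1] ** t
        total += rms((mi / c1) / (math.sqrt(vi / c2) + eps) for mi, vi in zip(m, v))
    return total / steps


d_out, d_in = 8, 2
flat = [x for row in semi_orthogonal(d_out, d_in) for x in row]
print("RMS of flattened matrix:", rms(flat))
print("1/sqrt(max(d_out, d_in)):", 1 / math.sqrt(max(d_out, d_in)))
print("RMS after Moonshot scaling:", moonshot_scale(d_out, d_in) * rms(flat))
print("simulated Adam update RMS:", adam_update_rms())
```

The first two printed numbers should agree, the third should equal the 0.2 target by construction, and the simulated Adam value is the quantity that target is meant to match.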

  3. What are good default hyperparameters for Muon?
    Muon works well with a learning rate around $0.02$, but sweep in logspace to find the best one. Using the $\text{RMS}$ to $\text{RMS}$ norm means Muon's optimal learning rate transfers as model dimension grows, so you can sweep on a smaller model. For momentum, $\beta = 0.95$ works well, and tuning it may not make much difference. Weight decay ($0.1$, $0.01$, or sweep in logspace) becomes important for larger models (>0.5B parameters). For the internal Adam optimizer, $\beta_1 = 0.9$ and $\beta_2 = 0.95$ work well. It may also be promising to replace the internal Adam with an internal Lion optimizer with $\beta_1 = 0.95$ and $\beta_2 = 0.98$.
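
As a rough sketch, the defaults above can be written down as a log-spaced learning rate grid plus the two parameter-group dicts that the Muon class expects. The helper names, the factor-of-two grid spacing, and the weight decay of 0.01 are illustrative choices, and the empty params lists are placeholders for real parameter lists.

```python
def logspace_sweep(center: float, factor: float = 2.0, points_each_side: int = 3):
    """Learning rates spaced by a constant ratio around `center`."""
    return [center * factor**k for k in range(-points_each_side, points_each_side + 1)]


def make_param_groups(muon_params, adam_params, muon_lr=0.02, adam_lr=3e-4, weight_decay=0.01):
    """Group dicts with exactly the keys that Muon.__init__ above asserts on."""
    return [
        dict(params=muon_params, lr=muon_lr, momentum=0.95,
             weight_decay=weight_decay, use_muon=True),
        dict(params=adam_params, lr=adam_lr, betas=(0.9, 0.95), eps=1e-10,
             weight_decay=weight_decay, use_muon=False),
    ]


for lr in logspace_sweep(0.02):
    # Placeholder empty lists; pass real parameter lists before building the optimizer.
    groups = make_param_groups(muon_params=[], adam_params=[], muon_lr=lr)
    print(f"muon lr = {groups[0]['lr']:.5f}")
```

Each groups list would then be passed to Muon(groups) for one run of the sweep.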

What's next?

So far our focus has been on regulating the gradients. But if all we control are the gradients, then every weight matrix usually grows during training, and activations or attention logits deep in the model may spike. Layer norm, QK norm, and logit softcapping are very valuable but indirect: is it possible to attack the problem at its source?

Chapter 3: Weight Regulation →