Some of Muon's code can appear puzzling, like the magic numbers $(3.4445, -4.7750, 2.0315)$ in Newton-Schulz. No more confusion: we'll go through the code with line-by-line annotations. Afterwards, we'll discuss how you can use Muon to get ~1.5x speedups without headaches.
import torch

def set_all_singular_values_to_near_one(G, steps: int):
    # Newton-Schulz iteration: approximately orthogonalize G, i.e. push all of its
    # singular values toward 1 while keeping its singular vectors. The quintic
    # coefficients (a, b, c) are tuned to inflate small singular values as fast as possible.
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()  # bfloat16 is accurate enough here and much faster
    if G.size(-2) > G.size(-1):
        X = X.mT  # work in the wide orientation so X @ X.mT is the smaller Gram matrix
    # Normalize so the spectral norm is at most 1, which the iteration needs to converge.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X  # X <- a*X + b*(X X^T) X + c*(X X^T)^2 X
    if G.size(-2) > G.size(-1):
        X = X.mT  # undo the transpose
    return X
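As a quick sanity check (a hypothetical snippet, not part of the original implementation), we can confirm that a few iterations drive the singular values of a random matrix close to 1:

# Hypothetical sanity check: after Newton-Schulz, singular values cluster near 1.
G = torch.randn(256, 128)
X = set_all_singular_values_to_near_one(G, steps=5)
svals = torch.linalg.svdvals(X.float())
print(svals.min().item(), svals.max().item())  # both close to 1 (near, not exactly)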
def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
    # (Nesterov) momentum, written with lerp: momentum <- beta*momentum + (1-beta)*grad.
    momentum.lerp_(grad, 1 - beta)
    update = grad.lerp_(momentum, beta) if nesterov else momentum
    # Orthogonalize the smoothed gradient.
    update = set_all_singular_values_to_near_one(update, steps=ns_steps)
    # Rescale so non-square matrices get the right RMS-to-RMS operator norm.
    update *= (grad.size(-2) / grad.size(-1))**0.5
    return update
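Here is a minimal, hypothetical single call to muon_update on a fresh momentum buffer and a random gradient. Note that it mutates both grad and the momentum buffer in place, and that the returned update's spectral norm is on the order of sqrt(fan_out / fan_in), i.e. an RMS-to-RMS operator norm of roughly 1:

# Hypothetical single Muon update on a random 512x256 gradient.
grad = torch.randn(512, 256)
momentum_buffer = torch.zeros_like(grad)
update = muon_update(grad, momentum_buffer)
# Spectral norm ~ sqrt(512/256) ~ 1.41 (roughly, since singular values are near 1, not exactly 1).
print(torch.linalg.matrix_norm(update.float(), ord=2).item())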
def adam_update(grad, buf1, buf2, step, betas, eps):
    # Standard Adam: exponential moving averages of the gradient and its square,
    # with bias correction.
    buf1.lerp_(grad, 1 - betas[0])
    buf2.lerp_(grad.square(), 1 - betas[1])
    buf1c = buf1 / (1 - betas[0]**step)
    buf2c = buf2 / (1 - betas[1]**step)
    return buf1c / (buf2c.sqrt() + eps)
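For comparison, a hypothetical first Adam step on fresh buffers: with bias correction, the step-1 update is approximately grad / (|grad| + eps), so its entries have magnitude close to 1:

# Hypothetical first Adam step (step=1, fresh buffers).
grad = torch.randn(512, 256)
exp_avg, exp_avg_sq = torch.zeros_like(grad), torch.zeros_like(grad)
update = adam_update(grad, exp_avg, exp_avg_sq, step=1, betas=(0.9, 0.95), eps=1e-10)
print(update.abs().mean().item())  # close to 1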
class Muon(torch.optim.Optimizer):
    def __init__(self, param_groups):
        for group in param_groups:
            # Each group must declare whether it is updated with Muon or with Adam.
            assert "use_muon" in group
            if group["use_muon"]:
                # defaults
                group["lr"] = group.get("lr", 0.02)
                group["momentum"] = group.get("momentum", 0.95)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
            else:
                # defaults
                group["lr"] = group.get("lr", 3e-4)
                group["betas"] = group.get("betas", (0.9, 0.95))
                group["eps"] = group.get("eps", 1e-10)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
        super().__init__(param_groups, dict())

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            if group["use_muon"]:
                for p in group["params"]:
                    if p.grad is None:
                        p.grad = torch.zeros_like(p)
                    state = self.state[p]
                    if len(state) == 0:
                        state["momentum_buffer"] = torch.zeros_like(p)
                    update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])  # decoupled weight decay
                    p.add_(update, alpha=-group["lr"])
            else:
                for p in group["params"]:
                    if p.grad is None:
                        p.grad = torch.zeros_like(p)
                    state = self.state[p]
                    if len(state) == 0:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
                                         state["step"], group["betas"], group["eps"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])  # decoupled weight decay
                    p.add_(update, alpha=-group["lr"])
        return loss
(Implementation adapted from Keller Jordan's GitHub.)
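To make the "without headaches" part concrete, here is a minimal usage sketch on a toy model (the model and the parameter split are illustrative assumptions, not part of the adapted code): route the Linear weight matrices to the Muon group and everything else to the Adam group.

# Hypothetical usage sketch (toy model; adapt the parameter split to your network).
model = torch.nn.Sequential(
    torch.nn.Linear(64, 128), torch.nn.ReLU(), torch.nn.Linear(128, 10)
)
muon_params = [m.weight for m in model if isinstance(m, torch.nn.Linear)]
adam_params = [m.bias for m in model if isinstance(m, torch.nn.Linear)]
optimizer = Muon([
    dict(params=muon_params, use_muon=True),   # hidden weight matrices
    dict(params=adam_params, use_muon=False),  # biases (and, in general, everything non-matrix)
])

x, y = torch.randn(32, 64), torch.randint(0, 10, (32,))
loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()

In a real LLM you would also keep the embedding matrix (and often the output head) in the Adam group; the first FAQ below explains why.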
Now that we have seen Muon's code, let's tackle some FAQs:
Which parameters should be routed to Muon? A tempting first rule is to send every 2D parameter its way:

def is_muon_param(p):
    return p.ndim == 2  # incorrect

This rule is incorrect because it also catches 2D matrices whose input/output norms aren't RMS norms, for example an LLM's embedding matrix. Since Muon is designed for RMS norms, it may not perform well on those parameters.
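A more robust routing rule, sketched under the assumption of a typical decoder-style LLM (the name fragments "embed" and "lm_head" are hypothetical placeholders; adjust them to your model):

def build_param_groups(model):
    # Hypothetical helper: 2D hidden weight matrices go to Muon; embeddings,
    # the output head, and all 1D parameters (biases, norm gains) go to Adam.
    muon_params, adam_params = [], []
    for name, p in model.named_parameters():
        hidden_matrix = p.ndim == 2 and "embed" not in name and "lm_head" not in name
        (muon_params if hidden_matrix else adam_params).append(p)
    return [
        dict(params=muon_params, use_muon=True),
        dict(params=adam_params, use_muon=False),
    ]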
So far our focus has been on regulating the gradients. But if all we control are the gradients, then during training the weight matrices typically keep growing, and activations or attention logits deep in the model may spike. Layer norm, QK norm, and logit softcapping are very valuable mitigations, but they are indirect: is it possible to attack the problem at its source?
Chapter 3: Weight Regulation →