Mixture of Experts (MoE) From Scratch in PyTorch — Building Sparse Transformers
Understand MoE architecture by building a sparse model from scratch, and learn Top-K routing, noisy gating, and expert load balancing.
1. Introduction
Mixture of Experts (MoE) is an architectural idea that is becoming increasingly popular in modern machine learning, especially in transformer models. Although the concept originated in the early 1990s, it has gained renewed attention because it allows models to scale to massive sizes without a proportional increase in computation.
In this article, we'll discuss the core idea behind Mixture of Experts and how sparsity in transformers enables conditional computation. Along the way, we will build a sparse MoE model from scratch in PyTorch for text generation. Let's begin.
2. The Scaling Bottleneck in Dense Transformers
Before diving into the details of Mixture of Experts, it is important to understand the scaling problem it aims to solve. As Transformer models grow to billions or even trillions of parameters, standard dense architectures become increasingly inefficient. In a standard Transformer, every parameter is activated for every token, which leads to high computational cost and, in many cases, diminishing performance returns as models continue to scale.
Mixture of Experts addresses this bottleneck by introducing sparsity through conditional computation. Instead of activating the entire network, only a subset of parameters is used for a given input, while the rest remain inactive. Since a large fraction of a Transformer’s parameters and computation reside in the feed-forward network (FFN), MoE is typically applied by replacing this layer with a set of expert networks controlled by a learned gating mechanism.
3. The Idea of Sparsity in MoE
The idea of sparsity in machine learning is often associated with sparse weights or parameter pruning, but in the case of Mixture of Experts, sparsity refers to conditional activation. All parameters remain part of the model and are trained; however, only a small subset is activated for a given input token. In this sense, every parameter participates in learning across the dataset, but not all parameters are used for any single token.
Dense vs Sparse MoE
There are two main reasons why feed-forward networks are a natural target for introducing sparsity. First, they account for the majority of a Transformer’s parameters and computational cost. Second, unlike the attention mechanism, which explicitly models interactions between tokens, the feed-forward network processes each token independently. Even though attention enriches token representations using contextual information, the FFN applies the same transformation to each token in isolation, making it well suited for conditional and token-wise sparse computation.
4. The Architecture of a Mixture of Experts Layer
At a high level, an MoE layer replaces the standard dense feed-forward block of a Transformer. It consists of two primary components: the Experts themselves and a Router (or Gating Network) that determines which experts are activated for a given input.
4.1 Experts
Experts are simply parallel feed-forward neural networks with identical dimensions but separate, independent weights. This allows them to evolve distinct behaviors and leads to emergent specialization over the course of training: for example, one expert may pick up the syntactic structure of language while another focuses on semantic nuance or factual knowledge.
Modern state-of-the-art LLMs employ several experts to better learn complex language and reasoning patterns. However, scaling the number of experts requires careful balance. Adding too many can lead to diminishing returns. If the expert count is too high relative to the dataset size, individual experts may suffer from data starvation, receiving too few tokens to converge effectively, resulting in an undertrained model.
4.2 The Router (Gating Network)
The router can be thought of as a controller that sends each token to the expert it thinks can best handle it. The router knows the capabilities of each expert and dispatches tokens accordingly. Formally, it is typically a simple linear projection layer followed by a softmax function. It takes the input token representation and produces a probability distribution over all available experts. These probabilities represent the confidence that a specific expert is best suited to process the given token. You can think of the router as a classification network that models the relation between tokens and experts.
Routing Mechanism Overview
This visualization captures the essence of conditional computation. While the diagram shows a simplified `Top-1` scenario where only the single highest-probability expert is chosen, practical implementations often employ a Top-k strategy. By routing a token to multiple experts, the model gains better representation and stability. The actual logic behind this selection involves managing differentiability and load balancing, which we will dissect in the upcoming sections.
4.3 The MoE Layer
As we discussed, the MoE layer has two main components: the Router and the Experts. First, the Router receives the input token representation from the preceding attention layer and generates raw logits for each available expert. These logits are then normalized using a Softmax function to produce a probability distribution across all experts. Because Top-k selection keeps only a subset of these probabilities, the selected values no longer sum to one and must be renormalized; we'll discuss this in the next section.
Next, a Top-k gating mechanism uses these probabilities to select only the k most suitable experts (top-2 in the diagram below) to process the token. Finally, the outputs from these selected experts are aggregated via a weighted sum, scaled by their respective router probabilities, to produce the overall output of the MoE layer. A visual representation of this entire pipeline is shown below.
MoE Layer (Top-K Gating)
This is a typical MoE layer as seen in most Transformer implementations. What changes in practice is where the experts are placed during distributed training: each expert lives on a different device, and the devices communicate via all-to-all connections, which lets tokens be sent to and gathered back from the devices holding their selected experts. We'll discuss this idea in another article.
4.3.1 The Routing Mechanism (Top-K Gating)
Alright, let's now discuss how the routing mechanism works in detail. The specific type of routing we are going to cover is `Top-K`. There are other routing schemes as well; for example, the Switch Transformer activates only one expert per token, which is the special case of `Top-K` gating with `k = 1`, where the selection is hard and there is no mixing of expert outputs. That scheme is typically used for very large models, with trillions of parameters and thousands of experts, but most implementations use `Top-K` with `k > 1`.
For a better understanding, let's formally compare MoE routing and expert selection with the dense case.
In a Standard Transformer FFN, every token goes through the same transformation. We can represent it by
\[y_{\text{dense}} = f(x), \quad x \in \mathbb{R}^d\]
Mathematically, the definition of an MoE layer is a direct extension of the standard Transformer Feed-Forward Network (FFN). We replace the single dense FFN with a set of `E` independent expert networks `f_1, \dots, f_E`, each of which computes `f_e(x)` just like a regular FFN. The router assigns a probability `p_e` to every expert, the Top-k experts form the selected set `S`, and each selected expert receives a gate weight obtained by renormalizing its probability over `S`:
\[g_e = \frac{p_e}{\sum_{j \in S} p_j}, \quad e \in S\]
Where:
\[\begin{aligned} g_e &:\ \text{renormalized routing weight of expert } e \\ p_e &:\ \text{router probability assigned to expert } e \\ S &:\ \text{set of selected (Top-}k\text{) experts} \\ j &:\ \text{index over selected experts} \\ \sum_{j \in S} p_j &:\ \text{total probability mass of selected experts} \end{aligned}\]
Now the gating works by doing a weighted sum of gate weights and the expert outputs,
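\[y_{\text{MoE}}(x) = \sum_{e=1}^{E} g_e \, f_e(x), \qquad g_e = 0 \ \text{for } e \notin S\]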
The conditional computation happens because the gate weights of the non-selected experts are zero, so they contribute nothing to the weighted sum. A natural question arises: are we computing `f_e(x)` anyway, even though the corresponding `g_e` is `0`? If you are asking this, you are getting the idea clearly, but in practice `f_e(x)` is never evaluated when `g_e = 0`; the implementation dispatches each token only to its selected experts and skips the rest entirely.
Another question you probably ask, since experts are selected using a `Top-K` operation, why is a weighted sum with gate values still required? The key point is that `Top-K` is a hard selection operator; it makes discrete choices based on maximum values and is therefore non-differentiable.
For effective end-to-end training, gradients must propagate smoothly through the entire MoE block. The weighted combination of expert outputs using continuous gate values provides this differentiable path. It allows the loss to flow not only into the selected experts but also into the routing scores that produced the selection, making the overall mechanism trainable despite the hard selection step.
Top-K Gating (Formulas included)
4.3.2 Noisy Top-K Gating
There is an inherent problem in Mixture of Experts models: during training, the router can quickly start favouring a few experts while the others sit idle, so the parameters of the neglected experts are never updated. This classic MoE failure mode is called Router Collapse, and it happens due to a lack of exploration. To prevent it, we add some randomness to the router logits, which encourages the router to explore rather than exploit too early.
This is done by adding a small amount of noise to the router logits. There are various ways to do this; the most common is to sample noise from a standard normal distribution, which is simple and works well.
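As a minimal sketch of this idea (`add_router_noise` is a hypothetical helper shown purely for illustration; the jitter coefficient of `0.01` matches the value discussed below):

import torch

def add_router_noise(logits: torch.Tensor, jitter: float = 0.01) -> torch.Tensor:
    # Sample noise from a standard normal distribution, scale it by a small
    # jitter coefficient, and add it to the router logits (before the softmax).
    return logits + torch.randn_like(logits) * jitter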
Notice that we multiply the noise by a small constant `0.01` before adding it to the logits. This acts as a simple jitter noise term that scales the magnitude of the noise and controls how much randomness is injected into the router.
A common practice during training is to gradually decay this jitter coefficient over time. In the early stages of training, the router explores more due to higher noise, and toward the end of training, the noise becomes small, so the router primarily exploits the learned routing behavior.
The noise term is only used during training and is disabled at inference time.
5. Expert Load Balancing
Load balancing is a critical component when training MoE models. As discussed, the router can become “lazy” and collapse to sending most tokens to the same few experts. Although the noise term helps mitigate this behavior, it is not sufficient on its own. We therefore need an explicit regularization term that encourages the optimizer to keep the router balanced, preventing it from becoming overly discriminative toward only a small subset of experts.
The idea of load balancing here is to make sure that every expert receives a roughly equal share of tokens. For this, we introduce an auxiliary load balancing loss, which is added to the primary loss objective. Let's try to understand how this is done,
Consider a batch of `N` tokens `\{x_1, \dots, x_N\}`.
We need to compute how much importance the router is giving to each of the experts for all the tokens in the batch.
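This soft usage can be written as
\[\text{Importance}_e = \sum_{i=1}^{N} p_e(x_i)\]
where \( p_e(x_i) \) is the router probability assigned to expert `e` for token `x_i`.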
The `\text{Importance}_e` simply represents how much probability mass the router assigns to expert `e` across all the tokens in the batch. This is a soft measure of usage: it captures the router's raw confidence, or intent, for that expert. It does not represent what is actually executed after `Top-K`; that is captured by a second quantity called `\text{Load}_e`.
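After the `Top-K` selection, the corresponding hard quantity is
\[\text{Load}_e = \frac{1}{N} \sum_{i=1}^{N} \mathbb{1}\big[\, e \in \text{TopK}(x_i) \,\big]\]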
The `\text{Load}_e` quantifies the actual discrete decisions made by the router. The indicator is `1` if expert `e` is among the `Top-K` experts selected for token `x_i`, and `0` otherwise, so `\text{Load}_e` is the fraction of tokens in the batch that were definitively assigned to expert `e` after the `Top-K` selection process (a "hard" binary choice), regardless of the original probability strength.
5.3 Auxiliary Load Balancing Loss
The auxiliary MoE objective couples the soft routing probabilities (importance) with the hard routing decisions (load), penalizing the router whenever both concentrate on the same few experts.
At this point, `\text{Importance}_e` and `\text{Load}_e` are not yet normalized across experts. To compute the auxiliary loss, we normalize each of them so that they sum to one over all the experts:
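\[\tilde{I}_e = \frac{\text{Importance}_e}{\sum_{j=1}^{E} \text{Importance}_j}, \qquad \tilde{L}_e = \frac{\text{Load}_e}{\sum_{j=1}^{E} \text{Load}_j}\]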
And the auxiliary loss is simply the dot product between `\tilde{I}_e` and `\tilde{L}_e`,
\[\mathcal{L}_{\text{aux}} = E \sum_{e=1}^{E} \tilde{I}_e \, \tilde{L}_e\]
We multiply by `E` (the number of experts) to normalize the loss. Without this factor, the loss value would naturally shrink towards zero as you add more experts, because the individual probabilities become smaller. This multiplication counteracts that shrinking, ensuring that a perfectly balanced router always results in a loss of `1`, regardless of whether you have `4` experts or `1,000`. For example, with `E = 4` and a perfectly uniform router, `\tilde{I}_e = \tilde{L}_e = 1/4` for every expert, so the loss is `4 \times 4 \times 1/16 = 1`.
Note: The auxiliary loss is minimized by the optimizer during training. Since the normalized load `\tilde{L}_e` is a direct consequence of the normalized importance `\tilde{I}_e`, the two are strongly correlated across a batch, and in practice `\tilde{I} \approx \tilde{L}` for each expert. As a result, their dot product behaves similarly to a sum of squares of the importance terms.
Minimizing this dot product therefore encourages the importance vector to become uniform. When the router assigns importance uniformly across experts, the resulting load also becomes uniform, meaning that each expert receives roughly the same number of tokens. This is the core intuition behind the auxiliary load-balancing loss: by reducing it, we push the router toward a uniform distribution over experts, preventing expert collapse and promoting balanced expert utilization.
5.4 Overall Training Objective
The base task loss for language modelling is the standard cross-entropy loss, which penalizes confident wrong next-word predictions. Given model output logits `z(x) \in \mathbb{R}^{d_{\text{vocab}}}` and true target `y`:
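\[\mathcal{L}_{\text{CE}} = -\log \frac{\exp\!\big(z_y(x)\big)}{\sum_{v=1}^{d_{\text{vocab}}} \exp\!\big(z_v(x)\big)}\]
In practice this is averaged over all token positions in the batch.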
Finally, we combine everything into a single objective function. While our primary goal is to predict the next token accurately, we must simultaneously ensure that the router remains balanced. We achieve this by adding the auxiliary loss to the primary loss, scaled by a hyperparameter `\lambda` (often set to a small value like `0.01` or `0.1`)
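\[\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{CE}} + \lambda \, \mathcal{L}_{\text{aux}}\]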
The training process minimizes this joint loss function. This forces the model to strike a balance, it learns to generate high-quality text (`\mathcal{L}_{\text{CE}}`) while ensuring that all experts are utilized more evenly (`\mathcal{L}_{\text{aux}}`), preventing the router from collapsing into a dense-like state.
6. Gradient Flow: Dense vs MoE
I would also like to briefly discuss how gradients flow in dense models compared to MoE. This part is not strictly necessary, since modern libraries handle it via automatic differentiation, but this short section will help you understand what happens under the hood during training and how the gradient flow differs between the two.
In a standard dense model, the output is given by
\[y = f(x),\]
and the gradient of the loss with respect to the model parameters is
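\[\frac{\partial \mathcal{L}}{\partial \theta} = \frac{\partial \mathcal{L}}{\partial y} \, \frac{\partial f(x)}{\partial \theta},\]
so every parameter receives a gradient contribution for every token. In an MoE layer, the output is the gated sum \( y = \sum_{e} g_e(x) \, f_e(x) \), and the gradient with respect to the parameters \( \theta_e \) of expert \( e \) becomes
\[\frac{\partial \mathcal{L}}{\partial \theta_e} = \frac{\partial \mathcal{L}}{\partial y} \, g_e(x) \, \frac{\partial f_e(x)}{\partial \theta_e}\]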
This expression makes the sparsity of MoE explicit. If \(g_e(x) = 0\), the gradient for expert \(e\) becomes zero, and no update is applied to its parameters. Consequently, only the selected experts receive gradients and are updated for a given token, while all other experts remain untouched. This is the key mechanism that enables conditional computation in both the forward and backward passes.
7. Implementation In PyTorch
Alright, now we are getting into the practical side. We'll implement each module one by one. First, let's define a class to store the hyperparameters used for this model.
import math
from dataclasses import dataclass
from typing import Any, Dict, Tuple
import torch
import torch.nn as nn
@dataclass(frozen=True)
class MoEConfig:
vocab_size: int = 50257
d_model: int = 768
num_layers: int = 12
num_heads: int = 16
max_seq_len: int = 512
num_experts: int = 8
top_k: int = 2
ffn_dim: int = 2048
7.1 RMS Norm Layer
For this model, we use Root Mean Square Normalization (RMSNorm), which is particularly effective for stabilizing residual learning. RMSNorm can be written in mathematical form as:
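\[\text{RMS}(x) = \sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2 + \epsilon}\]
\[\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \odot \gamma\]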
The first equation computes the root mean square of the input vector. Expressions of this form are commonly seen in linear regression and statistical normalization. The second equation applies the actual normalization, where \( \odot \) denotes element-wise multiplication and \( \gamma \) is a learnable scale vector.
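Here is a minimal `RMSNorm` module implementing these equations (assuming a small `eps` of `1e-6` inside the square root and a learnable scale `\gamma` initialized to ones):

class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Learnable per-feature scale (gamma)
        self.gamma = nn.Parameter(torch.ones(d_model, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Root mean square over the feature dimension
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x / rms) * self.gamma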
7.2 Multi-Head Attention
Most readers are already familiar with Multi-Head Attention (MHA), which is one of the core building blocks of the Transformer architecture. The main components of MHA are the `Query (Q)`, `Key (K)`, and `Value (V)` representations, along with their corresponding projection matrices \( W_Q \), \( W_K \), and \( W_V \).
Given an input matrix \( X \in \mathbb{R}^{n \times d} \), the projections are defined as:
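\[Q = X W_Q, \qquad K = X W_K, \qquad V = X W_V\]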
with independent projection matrices for each head.
At a high level, attention allows each token to dynamically weight and aggregate information from other tokens based on their relevance. If you'd like to refresh your understanding of attention or need a reference, check out this article: Transformers From Scratch PyTorch
class DenseNoBias(nn.Module):
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.kernel = nn.Parameter(torch.empty(in_features, out_features, dtype=torch.float32))
nn.init.normal_(self.kernel, std=0.02)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x @ self.kernel
# causal mask (lower triangular matrix)
def causal_mask(t: int, *, device: torch.device) -> torch.Tensor:
return torch.tril(torch.ones((t, t), dtype=torch.bool, device=device))
# standard multi-head attention
class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int, num_heads: int):
super().__init__()
if d_model % num_heads != 0:
raise ValueError("d_model must be divisible by num_heads")
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_proj = DenseNoBias(d_model, d_model)
self.k_proj = DenseNoBias(d_model, d_model)
self.v_proj = DenseNoBias(d_model, d_model)
self.out_proj = DenseNoBias(d_model, d_model)
def forward(self, x: torch.Tensor, *, attn_mask: torch.Tensor) -> torch.Tensor:
b, t, d = x.shape
q = self.q_proj(x).view(b, t, self.num_heads, self.head_dim)
k = self.k_proj(x).view(b, t, self.num_heads, self.head_dim)
v = self.v_proj(x).view(b, t, self.num_heads, self.head_dim)
scale = 1.0 / math.sqrt(self.head_dim)
att = torch.einsum("bthd,bshd->bhts", q, k) * scale
att = att.masked_fill(~attn_mask.view(1, 1, t, t), -1e30)
att = torch.softmax(att, dim=-1)
out = torch.einsum("bhts,bshd->bthd", att, v).contiguous()
out = out.view(b, t, d)
return self.out_proj(out)
We have added some helper modules, such as Dense with no bias and a causal mask, which are necessary for attention.
7.3 Router (Noisy Top-K)
Since we have already discussed the router in detail, here is how to implement those ideas in code.
class Router(nn.Module):
def __init__(self, d_model: int, num_experts: int, top_k: int, jitter_noise: float = 0.0):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
self.jitter_noise = jitter_noise
# simple linear layer for gating
self.gate = DenseNoBias(d_model, num_experts)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# x shape: [batch_size, seq_len, d_model]
# 1. Calculate raw logits (routing scores)
logits = self.gate(x) # Shape: [batch, seq, num_experts]
        # Add jitter noise to encourage exploration (only during training)
        # The noise is added to the logits before the softmax so that it changes
        # the resulting probabilities and hence the expert selection.
if self.training and self.jitter_noise > 0:
noise = torch.randn_like(logits) * self.jitter_noise
logits_noisy = logits + noise
else:
logits_noisy = logits
# Calculate Probabilities
probs = torch.softmax(logits_noisy, dim=-1) # Shape: [batch, seq, num_experts]
# Select Top-K Experts
# topk_vals: The probabilities of the chosen experts
# topk_idx: The indices of the chosen experts
topk_vals, topk_idx = torch.topk(probs, k=self.top_k, dim=-1)
# Renormalize the Top-K probabilities
# We need the selected weights to sum to 1.0
denom = topk_vals.sum(dim=-1, keepdim=True).clamp_min(1e-6)
gates = topk_vals / denom
# topk_idx: For routing tokens to experts
# gates: For the weighted sum of outputs
# probs: The full probability distribution (Needed for Aux Load Balancing Loss)
        return topk_idx, gates, probs
7.4 Expert MLP
In a standard Transformer implementation, we use a simple two-layer MLP. However, in the current MoE model, the configuration specifies 8 experts. Therefore, it is more appropriate to create an MLP bank that stores the parameters of all experts in a single contiguous block, which can then be efficiently accessed using the indices of the selected experts.
class ExpertMLPBank(nn.Module):
def __init__(self, d_model: int, hidden_dim: int, num_experts: int):
super().__init__()
        # Weights and biases carry an explicit expert dimension, which removes the need for Python loops over experts
self.w1 = nn.Parameter(torch.empty(num_experts, d_model, hidden_dim, dtype=torch.float32))
self.b1 = nn.Parameter(torch.zeros(num_experts, hidden_dim, dtype=torch.float32))
self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, d_model, dtype=torch.float32))
self.b2 = nn.Parameter(torch.zeros(num_experts, d_model, dtype=torch.float32))
nn.init.normal_(self.w1, std=0.02)
nn.init.normal_(self.w2, std=0.02)
def forward(self, x: torch.Tensor, expert_idx: torch.Tensor) -> torch.Tensor:
w1 = self.w1.index_select(0, expert_idx)
b1 = self.b1.index_select(0, expert_idx)
w2 = self.w2.index_select(0, expert_idx)
b2 = self.b2.index_select(0, expert_idx)
h = torch.einsum("nd,ndh->nh", x, w1) + b1
h = torch.nn.functional.silu(h)
y = torch.einsum("nh,nhd->nd", h, w2) + b2
return y
7.5 MoE Layer
We group the router and the experts together inside a single MoE layer. The router produces the routing probabilities and the indices of the top-k selected experts, and the input is then forwarded to the corresponding experts for computation. Much of this mirrors what we discussed in the sections above.
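Here is a minimal sketch of the `MoELayer` that wires the `Router` and `ExpertMLPBank` together. One design choice in this sketch (an assumption, not the only option) is to compute the auxiliary load-balancing loss from Section 5 inside the forward pass and store it on `self.aux_loss`, so the training loop can read it and add `\lambda \cdot \mathcal{L}_{\text{aux}}` to the cross-entropy loss; the jitter coefficient defaults to `0.01`, as discussed earlier.

class MoELayer(nn.Module):
    def __init__(self, d_model: int, hidden_dim: int, num_experts: int, top_k: int, jitter_noise: float = 0.01):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.router = Router(d_model, num_experts, top_k, jitter_noise=jitter_noise)
        self.experts = ExpertMLPBank(d_model, hidden_dim, num_experts)
        # Auxiliary load-balancing loss from the most recent forward pass
        self.aux_loss = torch.tensor(0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, d = x.shape
        topk_idx, gates, probs = self.router(x)      # [b,t,k], [b,t,k], [b,t,E]

        x_flat = x.reshape(-1, d)                    # [N, d] with N = b * t
        idx_flat = topk_idx.reshape(-1, self.top_k)  # [N, k]
        gates_flat = gates.reshape(-1, self.top_k)   # [N, k]

        # Weighted sum of the selected experts' outputs
        out = torch.zeros_like(x_flat)
        for slot in range(self.top_k):
            expert_out = self.experts(x_flat, idx_flat[:, slot])
            out = out + gates_flat[:, slot:slot + 1] * expert_out

        # Auxiliary load-balancing loss: soft importance vs. hard load (Section 5)
        probs_flat = probs.reshape(-1, self.num_experts)
        importance = probs_flat.sum(dim=0)
        load = torch.nn.functional.one_hot(idx_flat, self.num_experts).float().sum(dim=(0, 1))
        imp_n = importance / importance.sum().clamp_min(1e-6)
        load_n = load / load.sum().clamp_min(1e-6)
        self.aux_loss = self.num_experts * (imp_n * load_n).sum()

        return out.reshape(b, t, d)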
7.6 Decoder Block
Now we can combine all of these components into a single decoder MoE block, which includes the RMS Norm, MHA, and MoE layers.
class Block(nn.Module):
def __init__(
self,
d_model: int,
num_heads: int,
hidden_dim: int,
num_experts: int,
top_k: int,
):
super().__init__()
self.rmsnorm_0 = RMSNorm(d_model)
self.attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
self.rmsnorm_1 = RMSNorm(d_model)
self.moe = MoELayer(
d_model=d_model,
hidden_dim=hidden_dim,
num_experts=num_experts,
top_k=top_k,
)
def forward(self, x: torch.Tensor, *, attn_mask: torch.Tensor) -> torch.Tensor:
h = self.rmsnorm_0(x)
x = x + self.attn(h, attn_mask=attn_mask)
h = self.rmsnorm_1(x)
x = x + self.moe(h)
return x
7.7 MoE Model
Let's pack everything together to build the full model. This module contains the token and positional embeddings (we use learned positional embeddings rather than sine-cosine positional encoding), a stack of MoE decoder blocks, and finally a language modelling head.
class MoE(nn.Module):
def __init__(self, cfg: MoEConfig):
super().__init__()
self.cfg = cfg
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
self.blocks = nn.ModuleList(
[
Block(
d_model=cfg.d_model,
num_heads=cfg.num_heads,
hidden_dim=cfg.ffn_dim,
num_experts=cfg.num_experts,
top_k=cfg.top_k,
)
for _ in range(cfg.num_layers)
]
)
self.rmsnorm_f = RMSNorm(cfg.d_model)
self.lm_head = DenseNoBias(cfg.d_model, cfg.vocab_size)
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
b, t = input_ids.shape
device = input_ids.device
if t > self.cfg.max_seq_len:
raise ValueError("input length exceeds max_seq_len")
tok = self.tok_emb(input_ids)
pos_idx = torch.arange(t, device=device).unsqueeze(0)
pos = self.pos_emb(pos_idx)
x = tok + pos
attn_mask = causal_mask(t, device=device)
for blk in self.blocks:
x = blk(x, attn_mask=attn_mask)
x = self.rmsnorm_f(x)
logits = self.lm_head(x)
return logits
All of this is what we encounter when building a standard Transformer architecture, except for the MoE block.
7.8 Loading Checkpoint
It may come as a little surprise that this model is already trained; the checkpoint is available here: QMoE-400-checkpoints. You can download the torch checkpoints from the HuggingFace repo, place them in the working directory, and load them using the code below.
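A minimal loading sketch (the file name `qmoe_400.pt` is a placeholder and we assume the checkpoint stores a plain `state_dict`; adjust both to match the actual files in the repo):

cfg = MoEConfig()
model = MoE(cfg)

# Load the downloaded checkpoint weights into the model
state_dict = torch.load("qmoe_400.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()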
8. Conclusion
In this article, we explored the core ideas behind MoE, including routing, Top-K selection, and load balancing. We then built a working sparse MoE Transformer from scratch in PyTorch and saw how all the pieces fit together. This approach makes it possible to train very large models without a proportional increase in compute cost. With this foundation, you are ready to experiment, modify the architecture, and build your own scalable expert-based models.