Mixture of Experts (MoE) From Scratch in PyTorch — Building Sparse Transformers

Understand MoE architecture by building a sparse model from scratch, and learn Top-K routing, noisy gating, and expert load balancing.

1. Introduction

Mixture of Experts (MoE) is an architectural idea that is becoming increasingly popular in modern machine learning, especially in transformer models. Although the concept originated in the early 1990s, it has gained renewed attention because it allows models to scale to massive sizes without a proportional increase in computation.

In this article, we'll discuss the core idea behind Mixture of Experts and how sparsity in transformers enables conditional computation. Along with that, we will build a sparse MoE model from scratch in PyTorch for text generation. Let's begin.

2. The Scaling Problem

Before diving into the details of Mixture of Experts, it is important to understand the scaling problem it aims to solve. As Transformer models grow to billions or even trillions of parameters, standard dense architectures become increasingly inefficient. In a standard Transformer, every parameter is activated for every token, which leads to high computational cost and, in many cases, diminishing performance returns as models continue to scale.

Mixture of Experts addresses this bottleneck by introducing sparsity through conditional computation. Instead of activating the entire network, only a subset of parameters is used for a given input, while the rest remain inactive. Since a large fraction of a Transformer’s parameters and computation reside in the feed-forward network (FFN), MoE is typically applied by replacing this layer with a set of expert networks controlled by a learned gating mechanism.

3. The Idea of Sparsity in MoE

The idea of sparsity in machine learning is often associated with sparse weights or parameter pruning, but in the case of Mixture of Experts, sparsity refers to conditional activation. All parameters remain part of the model and are trained; however, only a small subset is activated for a given input token. In this sense, every parameter participates in learning across the dataset, but not all parameters are used for any single token.

Dense vs Sparse MoE

There are two main reasons why feed-forward networks are a natural target for introducing sparsity. First, they account for the majority of a Transformer’s parameters and computational cost. Second, unlike the attention mechanism, which explicitly models interactions between tokens, the feed-forward network processes each token independently. Even though attention enriches token representations using contextual information, the FFN applies the same transformation to each token in isolation, making it well suited for conditional and token-wise sparse computation.

4. The Architecture of a Mixture of Experts Layer

At a high level, an MoE layer replaces the standard dense feed-forward block of a Transformer. It consists of two primary components: the Experts themselves and a Router (or Gating Network) that determines which experts are activated for a given input.

4.1 Experts

Experts are simply parallel feed-forward networks with identical dimensions but separate, independent weights. This allows them to develop distinct behaviors, leading to emergent specialization over the course of training: for example, one expert may pick up the syntactic structure of language while another focuses on semantic nuance or factual knowledge.

Modern state-of-the-art LLMs employ several experts to better learn complex language and reasoning patterns. However, scaling the number of experts requires careful balance. Adding too many can lead to diminishing returns. If the expert count is too high relative to the dataset size, individual experts may suffer from data starvation, receiving too few tokens to converge effectively, resulting in an undertrained model.

4.2 The Router (Gating Network)

The router can be thought of as a controller that sends each token to the experts it thinks can do the job best. The router learns the capabilities of each expert and dispatches tokens accordingly. Formally, it is typically a simple linear projection followed by a softmax. It takes the input token representation and produces a probability distribution over all available experts. These probabilities represent the confidence that a specific expert is best suited to process the given token. You can think of the router as a classification network that models the relation between tokens and experts.

Routing Mechanism Overview

This visualization captures the essence of conditional computation. While the diagram shows a simplified `Top-1` scenario where only the single highest-probability expert is chosen, practical implementations often employ a Top-k strategy. By routing a token to multiple experts, the model gains better representation and stability. The actual logic behind this selection involves managing differentiability and load balancing, which we will dissect in the upcoming sections.

4.3 The MoE Layer

As we discussed, the MoE layer has two main components, the Router and the Experts. First, the Router receives the input token representation from the preceding attention layer and generates raw logits for each available expert. These logits are then normalized with a Softmax function to produce a probability distribution across all experts. Because the subsequent Top-K selection discards part of this distribution, the selected probabilities must be renormalized afterwards; we'll discuss this in more detail in the next section.

Next, a Top-k gating mechanism uses these probabilities to select only the k most suitable experts (top-2 in the diagram below) to process the token. Finally, the outputs from these selected experts are aggregated via a weighted sum, scaled by their respective router probabilities, to produce the overall output of the MoE layer. A visual representation of this entire pipeline is shown below.

MoE Layer (Top-K Gating)

This is a typical MoE layer seen in most Transformer implementations. What changes is the placement of experts during parallel training: each expert is placed on a different device, and the devices communicate through all-to-all connections, allowing tokens to be sent to and received from other devices by the routing mechanism. We'll discuss this idea in another article.

4.3.1 The Routing Mechanism (Top-K Gating)

Alright, let's now discuss how the routing mechanism works in detail. The specific type of routing we are going to cover is `Top-K`. There are other routing schemes as well: for example, Switch Transformer routing activates only one expert per token, which is the special case of `Top-K` gating with `k = 1`, where the selection is hard and there is no mixture of experts. It is typically used for very large models, on the order of trillions of parameters with thousands of experts. Most implementations, however, use `Top-K` with `k > 1`.

For a better understanding, we'll formally compare routing and expert selection against the dense FFN.

In a Standard Transformer FFN, every token goes through the same transformation. We can represent it by

\[y_{\text{dense}} = f(x), \quad x \in \mathbb{R}^d\]

Typically, the forward pass can be written as,

\[f(x) = W_2 \, \sigma(W_1 x)\]

Where

\[\begin{aligned} & \; W_1 \in \mathbb{R}^{d_{\text{ff}} \times d} \\ & \; W_2 \in \mathbb{R}^{d \times d_{\text{ff}}} \\ & \; \sigma(\cdot)\ \text{is a nonlinearity (GELU/ReLU)} \end{aligned}\]

Mathematically, the definition of an MoE layer is a direct extension of the standard Transformer Feed-Forward Network (FFN). We replace the single dense FFN with a set of `E` independent networks. For any specific expert `e`, the computation is defined as,

$$f_e(x) = W_2^{(e)} \, \sigma\!\big(W_1^{(e)} x\big), \quad e \in \{1, \dots, E\}$$

Here, `f_e(x)` represents the output of the `e`-th expert.

Now, the Router, which is a linear layer parameterized by the weight matrix `W_r`, projects the input token `x` to produce raw logits,

\[r = W_r x \in \mathbb{R}^{E}\]

Next, we'll do the softmax over these logits,

\[p = \text{softmax}(r)\]

Then we do the `Top-K` selection,

\[S = \operatorname{TopK}(p, k),  \quad S \subset \{1, \dots, E\},  \quad |S| = k\]

We select the `k` experts with the highest probabilities in `p`.

Note that after the `Top-K` selection, the chosen probabilities no longer sum to one, so we need to renormalize them. For example, consider these probabilities,

\[p = \begin{bmatrix} 0.25 \\ 0.10 \\ 0.50 \\ 0.15 \end{bmatrix} \in \mathbb{R}^4\]

Now we select `Top-2` from this, and we'll get,

\[v = \begin{bmatrix} 0.25 \\ 0.50 \end{bmatrix} \in \mathbb{R}^2\]

The selected values sum to `0.75`, so we need to renormalize them so that they sum to `1`. Renormalizing `v` (and placing the results back at the corresponding expert positions) gives,

\[g(x) = \begin{bmatrix} 0.3333 \\ 0 \\ 0.6667 \\ 0 \end{bmatrix} \in \mathbb{R}^4\]

The renormalization can be generally written as,

\[g_e = \frac{p_e}{\sum_{j \in S} p_j}, \quad e \in S\]

Where:

\[\begin{aligned} g_e &:\ \text{renormalized routing weight of expert } e \\ p_e &:\ \text{router probability assigned to expert } e \\ S &:\ \text{set of selected (Top-}k\text{) experts} \\ j &:\ \text{index over selected experts} \\ \sum_{j \in S} p_j &:\ \text{total probability mass of selected experts} \end{aligned}\]
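
To make this concrete, here is a tiny PyTorch sketch that reproduces the worked example above using `torch.topk`; the probability values are the illustrative ones from the example, not the output of a real router.

import torch

p = torch.tensor([0.25, 0.10, 0.50, 0.15])   # router probabilities for E = 4 experts
k = 2

# hard Top-K selection
topk_vals, topk_idx = torch.topk(p, k)        # values: [0.50, 0.25], indices: [2, 0]

# renormalize the selected probabilities so they sum to 1
gates = topk_vals / topk_vals.sum()           # [0.6667, 0.3333]

# scatter back into a dense E-dimensional gate vector (zeros for unselected experts)
g = torch.zeros_like(p).scatter(0, topk_idx, gates)
print(g)  # tensor([0.3333, 0.0000, 0.6667, 0.0000])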

Now the gating combines the expert outputs through a weighted sum using the gate values,

\[y = g_1(x)f_1(x) + g_2(x)f_2(x) + g_3(x)f_3(x) + \cdots + g_E(x)f_E(x)\]

In general, the computation can be written as,

\[y_{\text{MoE}} = \sum_{e=1}^{E} g_e(x)\, f_e(x),  \quad g_e(x) = 0 \ \text{for non-selected experts}\]

The conditional computation happens because the gate weights of the non-selected experts are zero, so they do not contribute to the weighted sum. A natural question is whether we still compute `f_{e}(x)` even though the corresponding `g_{e}(x)` is 0. If you ask this question, you are getting the idea clearly: in practice, implementations only dispatch tokens to the selected experts, so an expert with `g_{e}(x) = 0` is never evaluated at all.

Another question you might ask: since experts are selected using a `Top-K` operation, why is a weighted sum with gate values still required? The key point is that `Top-K` is a hard selection operator; it makes discrete choices based on maximum values and is therefore non-differentiable.

For effective end-to-end training, gradients must propagate smoothly through the entire MoE block. The weighted combination of expert outputs using continuous gate values provides this differentiable path. It allows the loss to flow not only into the selected experts but also into the routing scores that produced the selection, making the overall mechanism trainable despite the hard selection step.

Top-K Gating (Formulas included)

4.3.2 Noisy Top-K Gating

There is an inherent problem in Mixture of Experts models: during training, the router can quickly start favouring a few experts while the others sit idle, so the parameters of the idle experts are never updated. This is a classic failure mode in MoE called Router Collapse, and it happens due to a lack of exploration. To prevent it, we add some randomness to the router logits, which encourages the router to explore rather than exploit early on.

This is done by adding a small noise to the router logits. There are various ways to add noise; the most common way is to add noise from a standard normal distribution, which is often simpler and works well. 

\[\varepsilon \sim \mathcal{N}(0, I)\]

\[r = W_r x + 0.01 \, \varepsilon \quad \text{(jitter noise)}\]

\[p = \operatorname{softmax}(r)\]

Notice that we multiply the noise by a small constant `0.01` before adding it to the logits. This acts as a simple jitter noise term that scales the magnitude of the noise and controls how much randomness is injected into the router. 

A common practice during training is to gradually decay this jitter coefficient over time. In the early stages of training, the router explores more due to higher noise, and toward the end of training, the noise becomes small, so the router primarily exploits the learned routing behavior.

The noise term is only used during training and is disabled at inference.
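
Below is a minimal sketch of how noisy routing probabilities and a decaying jitter coefficient could look in code. The linear decay schedule and the `0.01` starting value are illustrative choices, not a fixed recipe.

import torch

def noisy_router_probs(x: torch.Tensor, W_r: torch.Tensor, jitter: float, training: bool) -> torch.Tensor:
    # x: [num_tokens, d_model], W_r: [d_model, num_experts]
    logits = x @ W_r
    if training and jitter > 0:
        logits = logits + jitter * torch.randn_like(logits)  # epsilon ~ N(0, I)
    return torch.softmax(logits, dim=-1)

def jitter_schedule(step: int, total_steps: int, start: float = 0.01) -> float:
    # one possible schedule: linear decay from `start` to 0 over training
    return start * max(0.0, 1.0 - step / total_steps)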

5. Expert Load Balancing

Load balancing is a critical component when training MoE models. As discussed, the router can become “lazy” and collapse to sending most tokens to the same few experts. Although the noise term helps mitigate this behavior, it is not sufficient on its own. We therefore need an explicit regularization term that encourages the optimizer to keep the router balanced, preventing it from becoming overly discriminative toward only a small subset of experts.

The idea of load balancing here is to make sure that every expert receives a roughly equal share of tokens. For this, we introduce an auxiliary load balancing loss, which is added to the primary loss objective. Let's try to understand how this is done,

Consider a batch of `N` tokens, `{x_1, ..., x_N}`.

We need to compute how much importance the router gives to each expert across all the tokens in the batch.

5.1 Importance

\[\mathrm{Importance}_e = \sum_{i=1}^{N} p_e(x_i)\]

The `\text{Importance}_e` is simply the total probability mass the router assigns to expert `e` across all the tokens in the batch. This is a soft usage measure: it captures the router's raw confidence, or intent, for that expert, rather than what is actually executed after `Top-K`, which is captured by `\text{Load}_e`.

5.2 Load

\[\mathrm{Load}_e = \sum_{i=1}^{N} \mathbf{1}\!\left[e \in S(x_i)\right]\]

The `\text{Load}_e` quantifies the actual discrete decisions made by the router. The indicator `\mathbf{1}[e \in S(x_i)]` is `1` if expert `e` is among the `Top-K` experts selected for token `x_i`, and `0` otherwise. `\text{Load}_e` therefore counts how many tokens in the batch were actually assigned to expert `e` after the `Top-K` selection (a hard, binary choice), regardless of the original probability strength.

5.3 Auxiliary Load Balancing Loss

The auxiliary objective couples importance (soft routing) with load (hard routing), so that minimizing it pushes both toward a uniform distribution over the experts.

Currently, the `\text{Importance}_e` and `\text{Load}_e` represent the raw accumulated totals for each expert. To compute the auxiliary loss, we must normalize these values across all the experts,

\[\tilde{I}_e = \frac{\mathrm{Importance}_e}{\sum_{j=1}^{E} \mathrm{Importance}_j}, \qquad \tilde{L}_e = \frac{\mathrm{Load}_e}{\sum_{j=1}^{E} \mathrm{Load}_j}\]

And the auxiliary loss is simply the dot product between `\tilde{I}_e` and `\tilde{L}_e`,

\[\mathcal{L}_{\text{aux}} = E \sum_{e=1}^{E} \tilde{I}_e \, \tilde{L}_e\]

We multiply by `E` (the number of experts) to normalize the loss. Without this factor, the loss value would naturally shrink towards zero as you add more experts, because the individual probabilities become smaller. This multiplication counteracts that shrinking, ensuring that a perfectly balanced router always results in a loss of `1`, regardless of whether you have `4` experts or `1,000`.

Note: The auxiliary loss is minimized by the optimizer during training. Since the normalized load `\tilde{L}_e` is a direct consequence of the normalized importance `\tilde{I}_e`, the two are strongly correlated across a batch, and in practice `\tilde{I} \approx \tilde{L}` for each expert. As a result, their dot product behaves similarly to a sum of squares of the importance terms.
Minimizing this dot product therefore encourages the importance vector to become uniform. When the router assigns importance uniformly across experts, the resulting load also becomes uniform, meaning that each expert receives roughly the same number of tokens. This is the core intuition behind the auxiliary load-balancing loss: by reducing it, we push the router toward a uniform distribution over experts, preventing expert collapse and promoting balanced expert utilization.  
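
Here is a small sketch of how `Importance`, `Load`, and the auxiliary loss could be computed from the router outputs. It assumes `probs` of shape `[N, E]` (the full softmax over experts) and `topk_idx` of shape `[N, k]` (the selected expert indices), matching the quantities defined above.

import torch

def load_balancing_loss(probs: torch.Tensor, topk_idx: torch.Tensor) -> torch.Tensor:
    # probs: [N, E] router probabilities, topk_idx: [N, k] selected expert indices
    N, E = probs.shape

    # Importance_e: soft probability mass assigned to each expert
    importance = probs.sum(dim=0)                                     # [E]

    # Load_e: number of tokens routed to each expert after Top-K
    one_hot = torch.zeros(N, E, device=probs.device, dtype=probs.dtype)
    one_hot = one_hot.scatter(1, topk_idx, torch.ones_like(topk_idx, dtype=probs.dtype))
    load = one_hot.sum(dim=0)                                         # [E]

    # normalize both to distributions and take the scaled dot product
    importance = importance / importance.sum()
    load = load / load.sum()
    return E * torch.sum(importance * load)

For a perfectly balanced router, both normalized vectors equal `1/E` everywhere and the loss evaluates to exactly `1`, matching the discussion above.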

5.4 Overall Training Objective

The base task loss for language modelling is the standard cross-entropy loss, which penalizes confident wrong next-token predictions. Given the model output logits `z(x) \in \mathbb{R}^{d_{\text{vocab}}}` and the true target `y`:

\[P(y \mid x) = \operatorname{softmax}(z(x))\]

\[\mathcal{L}_{\text{CE}}(x, y) = - \sum_{c=1}^{d_{\text{vocab}}} y_c \log P(c \mid x)\]

Finally, we combine everything into a single objective function. While our primary goal is to predict the next token accurately, we must simultaneously ensure that the router remains balanced. We achieve this by adding the auxiliary loss to the primary loss, scaled by a hyperparameter `\lambda` (often set to a small value like `0.01` or `0.1`).

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{CE}} + \lambda \, \mathcal{L}_{\text{aux}}$$

The training process minimizes this joint loss function. This forces the model to strike a balance: it learns to generate high-quality text (`\mathcal{L}_{\text{CE}}`) while ensuring that all experts are utilized more evenly (`\mathcal{L}_{\text{aux}}`), preventing the router from collapsing onto a small subset of experts.
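
In a training loop this could look roughly like the sketch below, where the auxiliary losses collected from every MoE layer during the forward pass are averaged before being added to the cross-entropy term; the function and argument names here are illustrative.

import torch
import torch.nn.functional as F

def total_loss(logits: torch.Tensor, targets: torch.Tensor,
               aux_losses: list[torch.Tensor], lambda_aux: float = 0.01) -> torch.Tensor:
    # logits: [batch, seq, vocab], targets: [batch, seq], aux_losses: one entry per MoE layer
    ce = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    aux = torch.stack(aux_losses).mean()
    return ce + lambda_aux * aux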

6. Gradient Flow: Dense vs MoE

I would also like to discuss how gradients flow in dense models versus MoE. This part is not strictly necessary, since modern libraries handle it through automatic differentiation, but this short section will help you understand what happens under the hood during training and how the gradient flow differs between the two.

In a standard dense model, the output is given by

\[y = f(x),\]

and the gradient of the loss with respect to the model parameters is

\[\frac{\partial L}{\partial \theta} = \frac{\partial L}{\partial y} \cdot \frac{\partial f(x)}{\partial \theta}.\]

In this case, every input token propagates gradients through the entire network, meaning that all parameters are updated for every token.

In contrast, for a Mixture-of-Experts (MoE) layer, the output is

\[y = \sum_{e=1}^{E} g_e(x)\, f_e(x), \quad \text{where only } k \text{ of the } g_e(x) \text{ are nonzero}.\]

The gradient of the loss with respect to the parameters of expert \(e\) is then

\[\frac{\partial L}{\partial \theta_e} = g_e(x)\, \frac{\partial L}{\partial y} \cdot \frac{\partial f_e(x)}{\partial \theta_e}.\]

This expression makes the sparsity of MoE explicit. If \(g_e(x) = 0\), the gradient for expert \(e\) becomes zero, and no update is applied to its parameters. Consequently, only the selected experts receive gradients and are updated for a given token, while all other experts remain inactive. This is the key mechanism that enables conditional computation in both the forward and backward passes.
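
This is easy to verify with autograd. In the toy sketch below, two tiny linear layers stand in for experts and the gates implement a hard Top-1 selection of expert 0; after the backward pass, only the selected expert has accumulated a gradient, while the unselected one was never even evaluated.

import torch
import torch.nn as nn

torch.manual_seed(0)
experts = nn.ModuleList([nn.Linear(4, 4, bias=False) for _ in range(2)])
x = torch.randn(1, 4)

gates = torch.tensor([1.0, 0.0])  # hard Top-1 selection of expert 0

# conditional computation: only experts with a nonzero gate are evaluated
y = sum(g * e(x) for e, g in zip(experts, gates) if g > 0)
y.sum().backward()

print(experts[0].weight.grad is None)  # False -> the selected expert received gradients
print(experts[1].weight.grad is None)  # True  -> the unselected expert was never touched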

7. Implementation In PyTorch

Alright, now we are getting into the practical side. We'll implement each module one by one. First, let's define a class to store the hyperparameters used for this model.
import math
from dataclasses import dataclass
from typing import Any, Dict, Tuple

import torch
import torch.nn as nn

@dataclass(frozen=True)
class MoEConfig:
    vocab_size: int = 50257
    d_model: int = 768
    num_layers: int = 12
    num_heads: int = 16
    max_seq_len: int = 512
    num_experts: int = 8
    top_k: int = 2
    ffn_dim: int = 2048

7.1 RMS Norm Layer

For this model, we use Root Mean Square Normalization (RMSNorm), which is particularly effective for stabilizing residual learning. RMSNorm can be written in mathematical form as:

\[\operatorname{RMS}(x) = \sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2 + \varepsilon}\]

\[\operatorname{RMSNorm}(x) = \gamma \odot \frac{x}{\operatorname{RMS}(x)}\]

The first equation computes the root mean square of the input vector. Expressions of this form are commonly seen in linear regression and statistical normalization. The second equation applies the actual normalization, where \( \odot \) denotes element-wise multiplication and \( \gamma \) is a learnable scale vector.
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(d_model, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return (x / rms) * self.scale
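
As a quick sanity check, after normalization each token vector has a root mean square of roughly 1 (the learnable scale is initialized to ones, so it does not change anything yet):

norm = RMSNorm(d_model=768)
x = torch.randn(2, 5, 768)
y = norm(x)
print(y.pow(2).mean(dim=-1).sqrt())  # approximately 1.0 everywhere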

7.2 Multi-head Attention Layer

Most readers are already familiar with Multi-Head Attention (MHA), which is one of the core building blocks of the Transformer architecture. The main components of MHA are the `Query (Q)`, `Key (K)`, and `Value (V)` representations, along with their corresponding projection matrices \( W_Q \), \( W_K \), and \( W_V \).

Given an input matrix \( X \in \mathbb{R}^{n \times d} \), the projections are defined as:

\[Q = X W_Q, \quad K = X W_K, \quad V = X W_V\]

The attention operation is then computed as:

\[\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^\top}{\sqrt{d_k}} \right) V\]

For multi-head attention with \( h \) heads:

\[\text{MHA}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W_O\]

where each head is given by:

\[\text{head}_i = \text{Attention}(Q_i, K_i, V_i)\]

with independent projection matrices for each head.

At a high level, attention allows each token to dynamically weight and aggregate information from other tokens based on their relevance. If you'd like to refresh your understanding of attention or need a reference, check out this article: Transformers From Scratch PyTorch
class DenseNoBias(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.kernel = nn.Parameter(torch.empty(in_features, out_features, dtype=torch.float32))
        nn.init.normal_(self.kernel, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.kernel


# causal mask (lower triangular matrix)
def causal_mask(t: int, *, device: torch.device) -> torch.Tensor:
    return torch.tril(torch.ones((t, t), dtype=torch.bool, device=device))

# standard multi-head attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.q_proj = DenseNoBias(d_model, d_model)
        self.k_proj = DenseNoBias(d_model, d_model)
        self.v_proj = DenseNoBias(d_model, d_model)
        self.out_proj = DenseNoBias(d_model, d_model)

    def forward(self, x: torch.Tensor, *, attn_mask: torch.Tensor) -> torch.Tensor:
        b, t, d = x.shape
        q = self.q_proj(x).view(b, t, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(b, t, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(b, t, self.num_heads, self.head_dim)

        scale = 1.0 / math.sqrt(self.head_dim)
        att = torch.einsum("bthd,bshd->bhts", q, k) * scale
        att = att.masked_fill(~attn_mask.view(1, 1, t, t), -1e30)
        att = torch.softmax(att, dim=-1)
        out = torch.einsum("bhts,bshd->bthd", att, v).contiguous()
        out = out.view(b, t, d)
        return self.out_proj(out)
We have added some helper modules, such as a Dense layer with no bias and a causal mask, which are needed for attention.
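
Here is a quick shape check for the attention block on random inputs, using the causal mask helper defined above:

mha = MultiHeadAttention(d_model=768, num_heads=16)
x = torch.randn(2, 10, 768)
mask = causal_mask(10, device=x.device)
out = mha(x, attn_mask=mask)
print(out.shape)  # torch.Size([2, 10, 768])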

7.3 Router (Noisy Topk)

Since we have discussed the router at length, here is how those ideas translate into code:
class Router(nn.Module):
    def __init__(self, d_model: int, num_experts: int, top_k: int, jitter_noise: float = 0.0):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.jitter_noise = jitter_noise
        
        # simple linear layer for gating
        self.gate = DenseNoBias(d_model, num_experts)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # x shape: [batch_size, seq_len, d_model]
        
        # 1. Calculate raw logits (routing scores)
        logits = self.gate(x) # Shape: [batch, seq, num_experts]

        # Add jitter noise to the logits to encourage exploration (only during training).
        # The noise is added before the softmax so that it affects the routing probabilities.
        if self.training and self.jitter_noise > 0:
            noise = torch.randn_like(logits) * self.jitter_noise
            logits_noisy = logits + noise
        else:
            logits_noisy = logits

        # Calculate Probabilities
        probs = torch.softmax(logits_noisy, dim=-1) # Shape: [batch, seq, num_experts]

        # Select Top-K Experts
        # topk_vals: The probabilities of the chosen experts
        # topk_idx: The indices of the chosen experts
        topk_vals, topk_idx = torch.topk(probs, k=self.top_k, dim=-1)

        # Renormalize the Top-K probabilities
        # We need the selected weights to sum to 1.0
        denom = topk_vals.sum(dim=-1, keepdim=True).clamp_min(1e-6)
        gates = topk_vals / denom

        # topk_idx: for routing tokens to experts
        # gates: for the weighted sum of expert outputs
        # (the full distribution `probs` would also be returned here if we were
        #  computing the auxiliary load-balancing loss during training)
        return topk_idx, gates
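
A quick look at what the router returns for a toy batch; with `top_k = 2`, each token gets two expert indices and two renormalized gate weights that sum to 1:

router = Router(d_model=768, num_experts=8, top_k=2, jitter_noise=0.01)
x = torch.randn(2, 10, 768)
topk_idx, gates = router(x)
print(topk_idx.shape, gates.shape)  # torch.Size([2, 10, 2]) torch.Size([2, 10, 2])
print(gates.sum(dim=-1)[0, :3])     # approximately 1.0 for every token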

7.4 Expert MLP

In a standard Transformer implementation, we use a simple two-layer MLP. However, in the current MoE model, the configuration specifies 8 experts. Therefore, it is more appropriate to create an MLP bank that stores the parameters of all experts in a single contiguous block, which can then be efficiently accessed using the indices of the selected experts.
class ExpertMLPBank(nn.Module):
    def __init__(self, d_model: int, hidden_dim: int, num_experts: int):
        super().__init__()
        # Weights and biases with an expert dimension; this removes the need for Python loops
        self.w1 = nn.Parameter(torch.empty(num_experts, d_model, hidden_dim, dtype=torch.float32))
        self.b1 = nn.Parameter(torch.zeros(num_experts, hidden_dim, dtype=torch.float32))
        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, d_model, dtype=torch.float32))
        self.b2 = nn.Parameter(torch.zeros(num_experts, d_model, dtype=torch.float32))

        nn.init.normal_(self.w1, std=0.02)
        nn.init.normal_(self.w2, std=0.02)

    def forward(self, x: torch.Tensor, expert_idx: torch.Tensor) -> torch.Tensor:
        w1 = self.w1.index_select(0, expert_idx)
        b1 = self.b1.index_select(0, expert_idx)
        w2 = self.w2.index_select(0, expert_idx)
        b2 = self.b2.index_select(0, expert_idx)

        h = torch.einsum("nd,ndh->nh", x, w1) + b1
        h = torch.nn.functional.silu(h)
        y = torch.einsum("nh,nhd->nd", h, w2) + b2
        return y

7.5 MoE Layer

We group the router and the experts together inside a single MoE layer. The router produces the routing weights and the indices of the top-k selected experts, and the input is then forwarded to the corresponding experts for computation. Most of this mirrors what we discussed in the sections above.
class MoELayer(nn.Module):
    def __init__(self, d_model: int, hidden_dim: int, num_experts: int, top_k: int):
        super().__init__()
        self.router = Router(d_model=d_model, num_experts=num_experts, top_k=top_k)
        self.experts = ExpertMLPBank(d_model=d_model, hidden_dim=hidden_dim, num_experts=num_experts)
        self.top_k = top_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, d = x.shape
        topk_idx, gates = self.router(x)

        x_flat = x.reshape(b * t, d)
        idx_flat = topk_idx.reshape(b * t, self.top_k)
        gates_flat = gates.reshape(b * t, self.top_k)

        y = torch.zeros_like(x_flat)
        for j in range(self.top_k):
            e_idx = idx_flat[:, j] # Shape: [batch * seq]
            y_j = self.experts(x_flat, e_idx) # Shape: [batch * seq, d_model]
            y = y + y_j * gates_flat[:, j : j + 1] # Shape: [batch * seq, d_model]
        return y.reshape(b, t, d) # Shape: [batch, seq, d_model]
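
The full MoE layer preserves the input shape, just like the dense FFN block it replaces:

moe = MoELayer(d_model=768, hidden_dim=2048, num_experts=8, top_k=2)
x = torch.randn(2, 10, 768)
print(moe(x).shape)  # torch.Size([2, 10, 768])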

7.6 Single Decoder MoE Block

Now we can combine all the components to create a single decoder MoE block; this includes the RMSNorm, MHA, and MoE layers.
class Block(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        hidden_dim: int,
        num_experts: int,
        top_k: int,
    ):
        super().__init__()
        self.rmsnorm_0 = RMSNorm(d_model)
        self.attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.rmsnorm_1 = RMSNorm(d_model)
        self.moe = MoELayer(
            d_model=d_model,
            hidden_dim=hidden_dim,
            num_experts=num_experts,
            top_k=top_k,
        )

    def forward(self, x: torch.Tensor, *, attn_mask: torch.Tensor) -> torch.Tensor:
        h = self.rmsnorm_0(x)
        x = x + self.attn(h, attn_mask=attn_mask)
        h = self.rmsnorm_1(x)
        x = x + self.moe(h)
        return x

7.7 MoE Model

Let's pack everything together to build the model. This module contains the token and positional embeddings (we use learned positional embeddings, not sinusoidal positional encoding), a stack of MoE blocks, and finally a language modelling head.
class MoE(nn.Module):
    def __init__(self, cfg: MoEConfig):
        super().__init__()
        self.cfg = cfg

        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.blocks = nn.ModuleList(
            [
                Block(
                    d_model=cfg.d_model,
                    num_heads=cfg.num_heads,
                    hidden_dim=cfg.ffn_dim,
                    num_experts=cfg.num_experts,
                    top_k=cfg.top_k,
                )
                for _ in range(cfg.num_layers)
            ]
        )
        self.rmsnorm_f = RMSNorm(cfg.d_model)
        self.lm_head = DenseNoBias(cfg.d_model, cfg.vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        b, t = input_ids.shape
        device = input_ids.device

        if t > self.cfg.max_seq_len:
            raise ValueError("input length exceeds max_seq_len")

        tok = self.tok_emb(input_ids)
        pos_idx = torch.arange(t, device=device).unsqueeze(0)
        pos = self.pos_emb(pos_idx)
        x = tok + pos

        attn_mask = causal_mask(t, device=device)
        for blk in self.blocks:
            x = blk(x, attn_mask=attn_mask)
        x = self.rmsnorm_f(x)
        logits = self.lm_head(x)
        return logits
Everything here is standard Transformer machinery, except for the MoE layer.
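
One way to see the sparsity in numbers is to compare the total parameter count with the parameters that are actually active per token: in every layer, only `top_k` of the `num_experts` expert MLPs run for a given token. A rough sketch, using the default configuration (the arithmetic below only discounts the unselected expert MLPs):

cfg = MoEConfig()
model = MoE(cfg)

total = sum(p.numel() for p in model.parameters())

# parameters of a single expert MLP (its slice of w1, b1, w2, b2 in the bank)
per_expert = cfg.d_model * cfg.ffn_dim + cfg.ffn_dim + cfg.ffn_dim * cfg.d_model + cfg.d_model
inactive_per_layer = (cfg.num_experts - cfg.top_k) * per_expert
active = total - cfg.num_layers * inactive_per_layer

print(f"total parameters:            {total / 1e6:.1f}M")
print(f"active parameters per token: {active / 1e6:.1f}M")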

7.8 Loading Checkpoint

It may come as a pleasant surprise that this model is already trained; the checkpoint is available here: QMoE-400-checkpoints. You can download the torch checkpoint from the HuggingFace repo, place it in the working directory, and load it using the code below.
def load_model_from_checkpoint(
    checkpoint_path: str,
    *,
    device: str | torch.device = "cpu",
    dtype: torch.dtype = torch.float32,
) -> MoE:
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    cfg_dict: Dict[str, Any] = ckpt.get("config", {})
    cfg = MoEConfig(**cfg_dict)
    model = MoE(cfg).to(torch.device(device))
    model.load_state_dict(ckpt["state_dict"], strict=True)
    model.to(dtype=dtype)
    model.eval()
    return model

If you have a GPU, you can set the device to "cuda".

You can check this GitHub repo to see how to generate text with this model: Q-MoE-400

Huggingface Repo Link: QMoE-400
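
For reference, a minimal greedy-decoding sketch could look like the code below. It assumes a GPT-2 style BPE tokenizer via the tiktoken package (the vocab_size of 50257 suggests one, but check the linked repo for the tokenizer the checkpoint was actually trained with), and the checkpoint filename is hypothetical.

import tiktoken  # assumption: GPT-2 style BPE tokenizer; see the repo for the exact one

@torch.no_grad()
def generate(model: MoE, prompt: str, max_new_tokens: int = 50) -> str:
    enc = tiktoken.get_encoding("gpt2")
    ids = torch.tensor([enc.encode(prompt)], dtype=torch.long)
    for _ in range(max_new_tokens):
        logits = model(ids[:, -model.cfg.max_seq_len:])          # crop to the context window
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick
        ids = torch.cat([ids, next_id], dim=1)
    return enc.decode(ids[0].tolist())

model = load_model_from_checkpoint("qmoe_400.pt")  # hypothetical filename
print(generate(model, "Once upon a time"))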

8. Conclusion

In this article, we explored the core ideas behind MoE, including routing, Top-K selection, and load balancing. We then built a working sparse MoE Transformer from scratch in PyTorch and saw how all the pieces fit together. This approach makes it possible to train very large models without proportional compute costs. With this foundation, you are now ready to experiment, modify the architecture, and build your own scalable expert-based models.
