How do Mixture of Experts Layers Work? Part 1¶
Mixture of Experts layers have become increasingly popular in newer large language models. However, the idea has been around for a while in different flavours. It can be traced to a 1989 paper by Hampshire and Waibel[1], which used a gating network to allocate cases to one or more expert networks. A 1991 paper by Jacobs, Jordan, Nowlan and Hinton[2] added the idea of making the gating network trainable using a supervised learning approach.
On their own, mixture of experts models are not especially useful, since deep neural networks can learn to discriminate between training cases even without the explicit inductive bias that the mixture of experts architecture provides. Combined with sparse activations, however, they make it possible to train extremely large networks with comparatively little computation. This was demonstrated in the 2017 paper by Shazeer et al.[3], which used sparse mixture of experts layers to grow the size of the model while keeping the growth in activations sub-linear.
In this post, we will train a simple mixture of experts layer that learns to discriminate between different training cases. Each expert can be modeled as a feed-forward network, while the gating layer can be modeled as a softmax classifier. Sparsity is introduced through a Top-K function, which routes each example only to its K highest-ranked experts. However, this makes the model difficult to train, because the Top-K function is not differentiable.
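As a quick illustration (this snippet is not from the post, and the logits are made up), jax.lax.top_k picks out the K largest gating logits, but which experts get selected is a discrete choice, so no gradient flows through the ranking itself:
import jax
import jax.numpy as jnp
logits = jnp.array([0.1, 2.0, -0.5, 1.2])     # gating logits for 4 hypothetical experts
values, indices = jax.lax.top_k(logits, k=2)  # values: [2.0, 1.2], indices: [1, 3]
# gradients can flow through the selected values, but the integer indices carry
# no gradient, so the routing decision itself cannot be learned directly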
For now, we will sidestep this problem by not using the Top-K function at all, and instead smoothly allocating examples to experts based on the softmax weights. Mathematically,
\[
y = \sum_{i=1}^{n} \operatorname{softmax}(x W_g)_i \, E_i(x),
\]
where \(W_g\) is the weight matrix of the gating network and \(E_i\) is the \(i\)-th expert.
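As a rough sketch of how the formula maps to code (the names and shapes here are illustrative placeholders, not the model we build below):
import jax
import jax.numpy as jnp
def dense_moe(x, W_g, experts):
    # one softmax weight per expert per example
    gates = jax.nn.softmax(x @ W_g, axis=-1)                   # (batch, num_experts)
    # every expert processes every example; the gate weights combine the outputs
    expert_outs = jnp.stack([e(x) for e in experts], axis=0)   # (num_experts, batch, dim)
    return jnp.einsum('nbd,bn->bd', expert_outs, gates)
# tiny usage example with placeholder weights and two toy "experts"
xb = jax.random.normal(jax.random.key(0), (4, 3))
W_g = jax.random.normal(jax.random.key(1), (3, 2))
print(dense_moe(xb, W_g, [lambda z: z, lambda z: -z]).shape)   # (4, 3)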
Let’s start by creating a simple dataset.¶
This dataset consists of random vectors split into two groups based on the sign of their first component; each group is then multiplied by a different random square matrix, i.e. it undergoes a different linear transformation.
import jax
import jax.numpy as jnp
# D mini-batches of B examples each, with C features per example
D, B, C = 10, 1000, 3
x = jax.random.normal(
    jax.random.key(0),
    (D * B, C)
)
# group each example by the sign of its first feature
expert_ids = (x[:, 0] > 0).astype(jnp.int32)
# one random linear transformation per group
t = [
    jax.random.normal(
        jax.random.key(1000), (C, C)
    ),
    jax.random.normal(
        jax.random.key(2000), (C, C)
    )
]
def transform(xi, ei):
    return jnp.where(ei == 0, xi @ t[0], xi @ t[1])
y = jax.vmap(transform)(x, expert_ids)
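A quick sanity check (not part of the original code) that the shapes are as expected and the two groups are roughly balanced:
print(x.shape, y.shape)    # (10000, 3) (10000, 3)
print(expert_ids.mean())   # close to 0.5: about half the rows use each transform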
Next, let’s build our model.¶
Let’s define the gating function, also called a router.
import flax.nnx as nnx
class Router(nnx.Module):
    def __init__(self, dim: int, num_experts: int, *, rngs: nnx.Rngs):
        self.w1 = nnx.Linear(dim, num_experts, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.w1(x)
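For instance (an illustrative check only, using a throwaway instance), a router over two experts maps (batch, dim) inputs to (batch, 2) logits:
r = Router(dim=C, num_experts=2, rngs=nnx.Rngs(1))  # throwaway instance, arbitrary seed
print(r(x[:4]).shape)                               # (4, 2)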
Next, let’s define our expert as a simple single-layer linear network.
class Expert(nnx.Module):
    def __init__(self, dim: int, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.linear(x)
Finally, let’s define the model.
class SimpleMoE(nnx.Module):
    def __init__(self, dim: int, *, rngs: nnx.Rngs):
        num_experts = 2
        self.router = Router(dim, num_experts=num_experts, rngs=rngs)
        self.experts = [
            Expert(dim, rngs=rngs)
            for _ in range(num_experts)
        ]

    def __call__(self, x: jax.Array) -> jax.Array:
        gate_logits = self.router(x)                           # (batch, num_experts)
        expert_weights = jax.nn.softmax(gate_logits, axis=-1)  # (batch, num_experts)
        # dense mixture: every expert processes every example, and the outputs
        # are combined using the softmax weights
        outputs = [e(x) for e in self.experts]                 # each (batch, dim)
        result = jnp.zeros_like(x)
        for i, o in enumerate(outputs):
            result += o * expert_weights[:, i:i + 1]
        return result
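Before training, a quick smoke test (again just an illustrative check, with an arbitrary seed) confirms that the layer preserves the input shape:
moe = SimpleMoE(dim=C, rngs=nnx.Rngs(42))  # throwaway instance, seed chosen arbitrarily
print(moe(x[:5]).shape)                    # (5, 3)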
Now let’s train our model¶
Since this is a regression problem, we can use the mean squared error to train our model.
import optax
def loss_fn(model, x, y):
    y_pred = model(x)
    loss = jnp.mean((y - y_pred) ** 2)  # + lb_loss (a load-balancing term could go here; not used in this post)
    return loss
model = SimpleMoE(dim=C, rngs=nnx.Rngs(0))
tx = optax.adam(1e-3)
state = nnx.Optimizer(model, tx)

@nnx.jit
def step(state, x, y):
    loss, grads = nnx.value_and_grad(loss_fn)(state.model, x, y)
    state.update(grads)
    return loss
x = x.reshape(D, B, C)
y = y.reshape(D, B, C)

for e in range(2000):
    for i in range(D):
        loss = step(state, x[i], y[i])
    if e % 200 == 0:
        print(e, loss)
0 3.5661597
200 0.90568215
400 0.13584565
600 0.0723149
800 0.052140094
1000 0.039651874
1200 0.031048419
1400 0.024952801
1600 0.020527959
1800 0.017231295
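As a rough check that is not part of the original post, we can ask whether the trained router's hard assignments line up with the ground-truth grouping; since the two experts could come out in either order, we accept whichever labelling matches better:
flat_x = x.reshape(-1, C)                               # undo the (D, B, C) reshape
pred = jnp.argmax(state.model.router(flat_x), axis=-1)  # hard routing decision per example
agree = jnp.mean((pred == expert_ids).astype(jnp.float32))
print(jnp.maximum(agree, 1.0 - agree))                  # fraction routed consistently with the data split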
It works! In the next post, we’ll build on this simple model and introduce the notion of sparse activations.
References¶
[1] S. R. Hampshire and A. Waibel. “The use of a priori knowledge in the design of a speech recognition system based on neural networks.” In Proceedings of the International Joint Conference on Neural Networks (IJCNN), 1989.
[2] R. A. Jacobs, M. I. Jordan, S. J. Nowlan, and G. E. Hinton. “Adaptive Mixtures of Local Experts.” Neural Computation, 3(1), 79–87, 1991.
[3] N. Shazeer, A. Mirhoseini, K. Maziarz, A. Davis, Q. Le, G. Hinton, and J. Dean. “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer.” In Proceedings of the International Conference on Learning Representations (ICLR), 2017.