How to Write a Softmax Kernel in Pallas

softmax-2d

The softmax function is fundamental to neural network training because it converts raw model outputs (logits) into valid probability distributions. This is essential for classification tasks where we need to interpret network predictions as probabilities over discrete classes. Additionally, softmax is differentiable, allowing gradients to flow effectively through the network during training, which is why it’s the standard choice for the output layer in multi-class classification models.

In the previous post, we wrote a GPU kernel in Pallas for performing efficient matrix multiplication. In this post, we’ll build on this by writing a GPU kernel for the softmax function. We will also write the backward pass and test it with a neural network training run.

Softmax Operation

Given an input vector \(z = (z_1, \dots, z_n) \in \mathbb{R}^n\), the softmax function \(\sigma : \mathbb{R}^n \to (0,1)^n\) produces a probability distribution over the n entries:

\[ \sigma(z)_i = \frac{\exp(z_i)}{\sum_{j=1}^n \exp(z_j)} \quad\text{for } i=1,\dots,n. \]

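Softmax is invariant to shifts: \(\sigma(z) = \sigma(z + c)\) for any scalar \(c\) added to every entry, because the common factor \(e^{c}\) cancels between numerator and denominator. For numerical stability one therefore subtracts the maximum entry, so every exponent is at most 0 and \(\exp\) cannot overflow: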

\[ \sigma(z)_i = \frac{\exp(z_i - \max_j z_j)}{\sum_{k=1}^n \exp(z_k - \max_j z_j)}. \]

Let’s start with a naive implementation:

import jax.numpy as jnp

def manual_softmax(logits):
    """Numerically stable softmax over the last axis, written naively.

    The four numbered steps are four separate array operations; each step
    needs the complete result of the previous one before it can start.
    """
    peak = jnp.max(logits, axis=-1, keepdims=True)   # 1: per-row max
    numer = jnp.exp(logits - peak)                   # 2: shifted exponentials (<= 1)
    denom = jnp.sum(numer, axis=-1, keepdims=True)   # 3: per-row normalizer
    return numer / denom                             # 4: normalize

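Now imagine implementing this on a GPU. Line #1 needs every element of a row before it can produce that row’s max, and lines #2–#4 cannot start until the max (and then the sum) is known. Done naively, each step is a separate pass over the logits in slow global memory, and for long rows a whole row will not fit in the fast on-chip memory available to a single thread block. This becomes very slow for large matrices.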

Online Softmax

The max and the sum are reductions along each row (over the column axis), so rows are independent of each other and it’s easy to tile across them in parallel. The column axis is the harder part: a row may be too long to process in one go, and the reduction within a row cannot be split into fully independent pieces. Instead, we walk over each row in blocks of columns, keeping a running max (m) and a running normalizer (l), and rescale l whenever the max increases:

new_m = max(old_m, block_max)
l = l * exp(old_m - new_m) + sum(exp(x_block - new_m))

Let’s update our implementation to reflect our new algorithm:

BLOCK_M = 64  # Row-tile size; can be tuned for your GPU.
BLOCK_N = 64  # Column-tile size; can be tuned for your GPU.

# Online softmax
def online_softmax(logits):
    """Pure-JAX, Python-looped online softmax over the last axis of a 2-D array.

    Rows are processed in tiles of BLOCK_M. Within a row tile, column tiles of
    BLOCK_N are streamed once to maintain a running max and a running
    normalizer (rescaled whenever the max grows). A second sweep over the
    column tiles then writes the normalized output. This mirrors the structure
    of the Pallas kernel below.

    Args:
      logits: array of shape (rows, cols).

    Returns:
      Array with the same shape and dtype as ``logits``.
    """
    n_rows, n_cols = logits.shape[0], logits.shape[1]
    result = jnp.zeros_like(logits)
    row_max = jnp.full((n_rows,), -jnp.inf)
    row_sum = jnp.zeros((n_rows,))

    for r0 in range(0, n_rows, BLOCK_M):  # row tiles are independent (parallelizable)
        rows = slice(r0, r0 + BLOCK_M)

        # Pass 1: stream column tiles sequentially, updating the running stats.
        for c0 in range(0, n_cols, BLOCK_N):
            tile = logits[rows, c0:c0 + BLOCK_N]
            prev_max = row_max[rows]
            next_max = jnp.maximum(prev_max, jnp.max(tile, axis=-1))
            # Rescale the old sum to the new max, then add this tile's terms.
            rescaled = row_sum[rows] * jnp.exp(prev_max - next_max)
            tile_sum = jnp.sum(jnp.exp(tile - next_max[:, None]), axis=-1)
            row_sum = row_sum.at[rows].set(rescaled + tile_sum)
            row_max = row_max.at[rows].set(next_max)

        # Pass 2: with the final stats known, emit normalized output tiles.
        for c0 in range(0, n_cols, BLOCK_N):
            cols = slice(c0, c0 + BLOCK_N)
            probs = jnp.exp(logits[rows, cols] - row_max[rows][:, None]) / row_sum[rows][:, None]
            result = result.at[rows, cols].set(probs)

    return result

The next step is to convert this into an efficient GPU kernel.

Let’s Implement the Forward Pass Kernel

from functools import partial

import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl
from jax.experimental.pallas import triton as plgpu

# True runs every pallas_call through the Pallas interpreter (handy for
# debugging without a GPU); keep False to compile the kernels for the GPU.
INTERPRET_MODE = False
NUM_WARPS = 4  # Triton num_warps, passed to every pallas_call below.
NUM_STAGES = 3  # Triton pipelining stages, passed to every pallas_call below.

def softmax_kernel(x_ref, out_ref, *, n_col_blocks, n_rows, n_cols):
    """Pallas kernel: softmax of one BLOCK_M-row block across all its columns.

    The block's columns are visited in BLOCK_N-wide tiles, twice:
      1. accumulate the running row max and the rescaled exp-sum (online
         softmax), in float32;
      2. re-read every tile and store exp(x - max) / sum.
    Rows >= n_rows and columns >= n_cols are masked: they load as -inf and are
    never stored.
    """
    row_ids = pl.program_id(0) * BLOCK_M + jnp.arange(BLOCK_M)
    row_in_bounds = row_ids < n_rows

    def read_tile(t):
        # Column window of tile t, its (BLOCK_M, BLOCK_N) validity mask, and the
        # tile itself upcast to float32.
        col_ids = t * BLOCK_N + jnp.arange(BLOCK_N)
        valid = row_in_bounds[:, None] & (col_ids < n_cols)[None, :]
        cols = pl.dslice(t * BLOCK_N, BLOCK_N)
        tile = plgpu.load(x_ref.at[:, cols], mask=valid, other=-jnp.inf)
        return cols, valid, tile.astype(jnp.float32)

    def accumulate_stats(t, carry):
        running_max, running_sum = carry
        _, _, tile = read_tile(t)
        updated_max = jnp.maximum(running_max, jnp.max(tile, axis=-1))
        correction = jnp.exp(running_max - updated_max)
        tile_sum = jnp.sum(jnp.exp(tile - updated_max[:, None]), axis=-1)
        return updated_max, running_sum * correction + tile_sum

    init_stats = (
        jnp.full((BLOCK_M,), -jnp.inf, dtype=jnp.float32),
        jnp.zeros((BLOCK_M,), dtype=jnp.float32),
    )
    row_max, row_sum = jax.lax.fori_loop(0, n_col_blocks, accumulate_stats, init_stats)

    def write_tile(t, carry):
        cols, valid, tile = read_tile(t)
        probs = jnp.exp(tile - row_max[:, None]) / row_sum[:, None]
        plgpu.store(out_ref.at[:, cols], probs.astype(jnp.float32), mask=valid)
        return carry

    jax.lax.fori_loop(0, n_col_blocks, write_tile, None)


@jax.jit
def softmax(logits):
    """Row-wise softmax of a 2-D array, computed by ``softmax_kernel``.

    The grid has one program per BLOCK_M-row block. Each program's block spans
    every column and tiles the columns internally in BLOCK_N chunks. The result
    is always float32.
    """
    n_rows, n_cols = logits.shape[0], logits.shape[1]
    kernel = partial(
        softmax_kernel,
        n_col_blocks=pl.cdiv(n_cols, BLOCK_N),
        n_rows=n_rows,
        n_cols=n_cols,
    )

    def row_block_spec():
        # Block i covers rows [i * BLOCK_M, (i + 1) * BLOCK_M) and all columns.
        return pl.BlockSpec((BLOCK_M, n_cols), lambda i: (i, 0))

    return pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct(logits.shape, jnp.float32),
        grid=(pl.cdiv(n_rows, BLOCK_M),),
        in_specs=[row_block_spec()],
        out_specs=row_block_spec(),
        interpret=INTERPRET_MODE,
        compiler_params=plgpu.CompilerParams(
            num_warps=NUM_WARPS,
            num_stages=NUM_STAGES,
        ),
    )(logits)

Performance

Let’s compare our performance with the out-of-the-box softmax implementation provided by JAX.

import time

def bench(fn, *args, iters=10):
    """Return the median wall-clock time, in seconds, of ``fn(*args)``.

    Each call waits on ``block_until_ready()`` so asynchronous dispatch is
    included in the measurement. For an even ``iters`` the upper of the two
    middle samples is returned. No warm-up is done here; compile/warm ``fn``
    before calling this.
    """
    def timed_call():
        start = time.perf_counter()
        fn(*args).block_until_ready()
        return time.perf_counter() - start

    samples = sorted(timed_call() for _ in range(iters))
    return samples[len(samples) // 2]


# Correctness: both the Python-looped online softmax and the Pallas kernel
# must match jax.nn.softmax on a random d x d input.
d = 1024
key = jax.random.key(0)
logits = jax.random.normal(key, (d, d))

out_jax = jax.nn.softmax(logits)
out_online = online_softmax(logits)
out_pallas = softmax(logits)

assert jnp.allclose(out_jax, out_pallas)
assert jnp.allclose(out_jax, out_online)

# Timing: jit the reference so both sides are compiled, and call each once up
# front so compilation is excluded from the measurements.
softmax_jit = jax.jit(jax.nn.softmax)
for warmup_fn in (softmax_jit, softmax):
    _ = warmup_fn(logits).block_until_ready()

t_jax = bench(softmax_jit, logits)
t_pallas = bench(softmax, logits)

print(f"Jax Softmax: {t_jax*1e3:.2f} ms")
print(f"Pallas Softmax: {t_pallas*1e3:.2f} ms")
print(f"Speedup (jax / pallas): {t_jax / t_pallas:.2f}x")
W0107 10:44:50.336112    3714 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W0107 10:44:50.340067    3436 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
Jax Softmax: 0.30 ms
Pallas Softmax: 0.33 ms
Speedup (jax / pallas): 0.93x

Our softmax kernel is around 7% slower than the JAX implementation. This is not surprising: XLA compiles jax.nn.softmax into fused GPU kernels that are already well tuned for this kind of row-wise reduction. Now, let’s implement the backward pass.

Let’s Implement the Backward Pass Kernel

Let’s first derive the expression for the backward pass.

Let \(x\in\mathbb{R}^n, y=\mathrm{softmax}(x)\) with \(y_i=\frac{e^{x_i}}{\sum_{k=1}^n e^{x_k}}.\) Assume an upstream gradient \(g=\frac{\partial L}{\partial y}\) is given, and we want \(\frac{\partial L}{\partial x}\).

First compute the Jacobian of softmax. Let \(S=\sum_k e^{x_k}\), so \(y_i=e^{x_i}/S\).

Differentiate \(y_i\) w.r.t. \(x_j\):

\[ \frac{\partial y_i}{\partial x_j} =\frac{\partial}{\partial x_j}\left(\frac{e^{x_i}}{S}\right) =\frac{\delta_{ij}e^{x_i}\,S - e^{x_i}\,\frac{\partial S}{\partial x_j}}{S^2}\]

But \(\frac{\partial S}{\partial x_j}=e^{x_j}\). Substituting:

\[ \frac{\partial y_i}{\partial x_j} =\frac{\delta_{ij}e^{x_i}S - e^{x_i}e^{x_j}}{S^2} =\delta_{ij}\frac{e^{x_i}}{S} - \frac{e^{x_i}}{S}\frac{e^{x_j}}{S} =\delta_{ij}y_i - y_i y_j \]

Now apply the chain rule:

\[ \frac{\partial L}{\partial x_j} =\sum_{i=1}^n \frac{\partial L}{\partial y_i}\frac{\partial y_i}{\partial x_j} =\sum_i g_i\left(\delta_{ij}y_i - y_i y_j\right) \]
\[ = g_j y_j - y_j\sum_i g_i y_i = y_j\left(g_j - \sum_i g_i y_i\right) \]

Finally, in vector form:

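\[ \frac{\partial L}{\partial x} = y \odot \left(g - \langle g, y\rangle \mathbf{1}\right), \]

where \(\mathbf{1}\) is the all-ones vector, i.e. the scalar \(\langle g, y\rangle\) is subtracted from every entry of \(g\).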

The kernel for the backward pass can be implemented in two steps: first compute the inner product \( \langle g, y\rangle \) for each row, then apply the elementwise expression above. In the binary classifier we train below, both the upstream gradient \(g\) and the output \(y\) have shape (B, C), where B is the batch size and C is the number of classes. Since C = 2, an entire row fits in a single block, so we only need to tile the kernel along the B axis, which simplifies the implementation greatly.

def softmax_backward_kernel(y_ref, dy_ref, dx_ref, *, n_rows=None):
    """Compute one row block of the softmax VJP: dx = y * (dy - <dy, y>).

    The inner product <dy, y> is taken per row. Every row of the block is
    loaded in full (the BlockSpec in ``softmax_backward`` spans all N columns).

    Args:
      y_ref: block of the forward softmax output, shape (BLOCK_M, N).
      dy_ref: matching block of the upstream gradient.
      dx_ref: matching block of the result.
      n_rows: total number of rows of the full arrays. When given, rows of the
        last grid block that lie past ``n_rows`` are masked out of every load
        and store, so a partial final block never touches out-of-bounds
        memory. When None, no masking is done (the row count must then be a
        multiple of BLOCK_M).
    """
    mask = None
    other = None
    if n_rows is not None:
        row_ids = pl.program_id(0) * BLOCK_M + jnp.arange(BLOCK_M)
        # Broadcast the per-row validity across all N columns of the block.
        mask = jnp.broadcast_to((row_ids < n_rows)[:, None], y_ref.shape)
        other = 0.0

    # Reduce in float32 regardless of the storage dtype.
    dy_reg = plgpu.load(dy_ref, mask=mask, other=other).astype(jnp.float32)
    y_reg = plgpu.load(y_ref, mask=mask, other=other).astype(jnp.float32)
    g_dot_y = jnp.sum(dy_reg * y_reg, axis=1)  # per-row <dy, y>

    dx_reg = y_reg * (dy_reg - g_dot_y[:, None])
    plgpu.store(dx_ref, dx_reg.astype(dx_ref.dtype), mask=mask)


@jax.jit
def softmax_backward(y, dy):
    """Pallas VJP of row-wise softmax: dx = y * (dy - sum(dy * y, axis=1)).

    Args:
      y: forward softmax output, shape (M, N).
      dy: upstream gradient, same shape as ``y``.

    Returns:
      dx with shape (M, N) and dtype ``y.dtype``.

    Raises:
      ValueError: if ``y`` and ``dy`` have different shapes.
    """
    if y.shape != dy.shape:
        raise ValueError(
            f"softmax_backward: y and dy must have the same shape, got {y.shape} and {dy.shape}"
        )
    M, N = y.shape

    grid = (pl.cdiv(M, BLOCK_M),)
    out_shape = jax.ShapeDtypeStruct((M, N), y.dtype)

    return pl.pallas_call(
        # Pass the true row count so the last (possibly partial) block is masked.
        partial(softmax_backward_kernel, n_rows=M),
        out_shape=out_shape,
        grid=grid,
        in_specs=[
            pl.BlockSpec((BLOCK_M, N), lambda i: (i, 0)),  # y
            pl.BlockSpec((BLOCK_M, N), lambda i: (i, 0)),  # dy
        ],
        out_specs=pl.BlockSpec((BLOCK_M, N), lambda i: (i, 0)),  # dx
        interpret=INTERPRET_MODE,
        compiler_params=plgpu.CompilerParams(
            num_warps=NUM_WARPS,
            num_stages=NUM_STAGES,
        ),
    )(y, dy)


@jax.custom_vjp
def softmax_pallas(x):
    """Row-wise softmax using the Pallas forward kernel.

    Differentiable through the Pallas backward kernel via the custom VJP
    registered below.
    """
    return softmax(x)


def softmax_fwd(x):
    """Forward rule: return y = softmax(x) and keep y as the only residual.

    The softmax VJP needs only the output y, so x itself is not saved.
    """
    y = softmax(x)
    return y, y


def softmax_bwd(saved_y, dy):
    """Backward rule: map the saved output and the cotangent ``dy`` to the
    cotangent of ``x``, returned as a 1-tuple (one entry per primal argument).

    NOTE(review): ``softmax`` always produces float32, so dx is float32 too;
    for non-float32 ``x`` the custom_vjp dtype check may reject this. TODO
    confirm before using with bf16/fp16 inputs.
    """
    return (softmax_backward(saved_y, dy),)


softmax_pallas.defvjp(softmax_fwd, softmax_bwd)

Let’s Evaluate our Kernel

We will attempt to train a binary classifier model on some synthetic data. Let’s start by generating a toy dataset.

import jax

B, E = 256, 24  # batch size, number of features
x = jax.random.normal(jax.random.key(1000), (B, E))
# Label is 1 when the first feature is positive, else 0.
y = class_ids = (x[:, 0] > 0).astype(jnp.int32)

Next, let’s define our binary classifier and loss function.

from dataclasses import dataclass
import flax.nnx as nnx

@dataclass
class ModelConfig:
    """Layer sizes for the two-layer MLP ``Model``.

    Fields are positional in the order below, so
    ``ModelConfig(in_dim, hidden_dim, out_dim)`` is valid.
    """

    in_dim: int  # number of input features (first Linear's input size)
    hidden_dim: int  # width of the hidden layer
    out_dim: int  # number of output logits (second Linear's output size)

class Model(nnx.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) that returns logits.

    Inputs of any rank are first flattened to ``(-1, in_dim)``, so the output
    has shape ``(num_rows, out_dim)``.
    """

    def __init__(self, config: ModelConfig, rngs: nnx.Rngs):
        self.config = config
        self.l1 = nnx.Linear(config.in_dim, config.hidden_dim, rngs=rngs)
        self.l2 = nnx.Linear(config.hidden_dim, config.out_dim, rngs=rngs)

    def __call__(self, x):
        flat = x.reshape(-1, x.shape[-1])
        hidden = jax.nn.relu(self.l1(flat))
        return self.l2(hidden)


def loss_fn(model, x, y):
    """Mean cross-entropy of ``model`` on (x, y), using the Pallas softmax.

    A 1e-9 epsilon is added inside the log to avoid log(0).
    """
    probs = softmax_pallas(model(x))
    targets = jax.nn.one_hot(y.reshape(-1), probs.shape[-1], dtype=probs.dtype)
    per_example = jnp.sum(targets * jnp.log(probs + 1e-9), axis=-1)
    return -jnp.mean(per_example)

Before we train the model, let’s first test if our backward pass kernel is correct.

def loss_from_logits_pallas(logits, y):
    """Mean cross-entropy computed from logits via ``softmax_pallas``."""
    probs = softmax_pallas(logits)
    targets = jax.nn.one_hot(y.reshape(-1), probs.shape[-1], dtype=probs.dtype)
    per_example = jnp.sum(targets * jnp.log(probs + 1e-9), axis=-1)
    return -jnp.mean(per_example)

def loss_from_logits_gt(logits, y):
    """Reference mean cross-entropy from logits via ``jax.nn.softmax``."""
    probs = jax.nn.softmax(logits)
    targets = jax.nn.one_hot(y.reshape(-1), probs.shape[-1], dtype=probs.dtype)
    per_example = jnp.sum(targets * jnp.log(probs + 1e-9), axis=-1)
    return -jnp.mean(per_example)

@nnx.jit
def verify(model, x, y):
    """Check d(loss)/d(logits) through ``softmax_pallas`` against jax.nn.softmax.

    Returns the result of ``jnp.allclose`` on the two logit gradients.
    """
    logits = model(x)
    grad_pallas, grad_ref = (
        jax.grad(loss)(logits, y)
        for loss in (loss_from_logits_pallas, loss_from_logits_gt)
    )
    return jnp.allclose(grad_pallas, grad_ref)


# Build the classifier, then check gradients through softmax_pallas against
# the jax.nn.softmax reference.
default = jax.random.key(69)
rngs = nnx.Rngs(default=default)
config = ModelConfig(in_dim=E, hidden_dim=4 * E, out_dim=2)
model = Model(config, rngs)
model.train()

grads_match = verify(model, x, y)
print(grads_match)
True

Excellent! Looks like our backward pass kernel works correctly. Finally, let’s overfit the model on our toy dataset using our softmax kernel.

import optax 

@nnx.jit
def step(model, state, x, y):
    """Run one optimizer step on (x, y) and return the pre-update loss."""
    grad_fn = nnx.value_and_grad(loss_fn)
    loss, grads = grad_fn(model, x, y)
    state.update(model, grads)
    return loss

iters = 15
# Adam at lr = 0.1, updating only the model's nnx.Param variables.
tx = optax.adam(1e-1)
state = nnx.Optimizer(model, tx, wrt=nnx.Param)

for i in range(iters):
    loss = step(model, state, x, y)
    print(f"iter {i}: loss={loss}")
iter 0: loss=0.7163699865341187
iter 1: loss=0.578166127204895
iter 2: loss=1.0504496097564697
iter 3: loss=0.13622123003005981
iter 4: loss=0.3550783097743988
iter 5: loss=0.1289454996585846
iter 6: loss=0.06486515700817108
iter 7: loss=0.056444257497787476
iter 8: loss=0.05389563366770744
iter 9: loss=0.03236997872591019
iter 10: loss=0.016692688688635826
iter 11: loss=0.0100052859634161
iter 12: loss=0.003909544087946415
iter 13: loss=0.0019604917615652084
iter 14: loss=0.0014198491116985679

We were able to successfully overfit our toy dataset using our softmax implementation. In the next post, we’ll build on this and implement a custom Pallas kernel for computing self-attention.