MLP: Attention in a Trench Coat
Transformer models have revolutionized how we approach natural language processing. At the heart of the transformer architecture lie two fundamental operations: attention and the multilayer perceptron (MLP). While attention has received significant optimization effort, the MLP is often overlooked or treated as something fundamentally different. A closer examination, however, reveals striking similarities between the two operations. In this article, I argue that MLPs can be interpreted as a specialized form of attention (essentially "attention in a trench coat") and use this insight to develop more efficient kernels for MLP operations.
The Attention Mechanism: A Quick Refresher
Attention is the mechanism that allows transformers to focus on different parts of the input sequence when producing each output token. In its core implementation, attention involves computing similarities between queries (Q) and keys (K) to create an attention matrix, applying a softmax, and then using these weights to blend values (V).
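As a reference point for the comparison that follows, here is a minimal PyTorch sketch of single-head attention, with no masking, batching, or multi-head plumbing:

```python
import torch.nn.functional as F

def simple_attention(q, k, v):
    # q, k, v: (N, d) tensors for a single head
    scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)  # (N, N) similarity matrix
    weights = F.softmax(scores, dim=-1)                        # normalize each query's row of scores
    return weights @ v                                         # blend the values
```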

The Flash Attention kernel, introduced by Tri Dao and collaborators in 2022, revolutionized this process by never materializing the full O(N²) attention matrix in GPU global memory, which unlocked training and inference at much longer context lengths.
The MLP Module
The MLP module in transformers has traditionally been viewed as an operation that acts on each token independently, in contrast to attention, which mixes information across tokens. A typical MLP consists of:
- Projecting input vectors to a larger dimension
- Applying a non-linearity
- Projecting back to the original dimension
In earlier transformer models, this used a simple elementwise non-linearity such as ReLU (BERT and GPT-2 use the closely related GELU):
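A minimal PyTorch sketch of this structure, with illustrative names and ReLU for simplicity:

```python
import torch
import torch.nn as nn

class SimpleMLP(nn.Module):
    """Classic transformer MLP: project up, apply a non-linearity, project back down."""
    def __init__(self, hidden_dim: int, intermediate_dim: int):
        super().__init__()
        self.up_proj = nn.Linear(hidden_dim, intermediate_dim)
        self.down_proj = nn.Linear(intermediate_dim, hidden_dim)

    def forward(self, x):
        return self.down_proj(torch.relu(self.up_proj(x)))
```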

Modern models like Llama and Qwen use the more sophisticated SwiGLU activation, which incorporates a data-dependent gating mechanism:
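In PyTorch terms, a SwiGLU block looks roughly like this, using the Llama-style gate_proj/up_proj/down_proj naming and bias-free projections:

```python
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUMLP(nn.Module):
    """SwiGLU MLP as used in Llama-style models: SiLU(gate(x)) * up(x), projected back down."""
    def __init__(self, hidden_dim: int, intermediate_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
        self.up_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
        self.down_proj = nn.Linear(intermediate_dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
```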

Unveiling the Disguise: MLP as Attention
Looking at these operations side by side reveals something fascinating: MLPs and attention mechanisms are structurally very similar. Let's draw the parallel:
- In attention, the K^T tensor transforms queries into attention scores
- In MLP, the up_proj matrix projects inputs into a higher dimension
- In attention, the V tensor maps attention weights to output values
- In MLP, the down_proj matrix maps activated values back to output space
- The softmax in attention and the non-linearity in MLP both serve to modulate the influence of different values
This suggests a reinterpretation of the MLP layer: MLPs are essentially attention mechanisms operating on a fixed KV cache. The key difference is that attention creates dynamic K and V tensors at runtime from the input, whereas MLP layers operate on learned parameters that remain fixed after training.
In other words, an MLP is like an attention mechanism that always attends to the same set of "tokens" (its weight matrices) regardless of the input. It's attention wearing a trench coat, pretending to be something else entirely!
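To make the correspondence concrete, here is a small numerical check. The shapes are illustrative, and a plain ReLU MLP is used so that the only structural difference from attention is the missing softmax:

```python
import torch

hidden, intermediate, tokens = 16, 64, 4
x = torch.randn(tokens, hidden)
w_up = torch.randn(intermediate, hidden)     # plays the role of K: one learned "key" per intermediate neuron
w_down = torch.randn(hidden, intermediate)   # its transpose plays the role of V

# Standard MLP: down_proj(relu(up_proj(x)))
mlp_out = torch.relu(x @ w_up.T) @ w_down.T

# The same computation phrased as attention over a fixed "KV cache"
q, k, v = x, w_up, w_down.T                  # queries are the tokens; K and V are learned constants
attn_out = torch.relu(q @ k.T) @ v           # a non-linearity stands in for the softmax

assert torch.allclose(mlp_out, attn_out)
```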
Flash Optimization for MLPs
If MLPs are secretly just attention operations, why don't we apply the same fusion techniques that made Flash Attention so successful? The answer lies in the scaling properties and computational bottlenecks:
- Attention's intermediate tensors grow with sequence length: K and V grow linearly with the input, and the attention score matrix grows quadratically
- MLP operations use fixed-size matrices regardless of sequence length
However, during inference, when we're often memory-bandwidth bound rather than compute-bound, kernel fusion becomes tremendously valuable for MLPs too. By fusing operations, we can avoid redundant memory transfers between GPU global memory and compute units.
Looking at the SwiGLU operation step by step, we can see significant opportunities to reduce memory movement through fusion:
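The breakdown below is an illustrative sketch (B tokens, hidden size d, intermediate size d_ff, roughly Llama-3-8B-shaped). In a naive implementation each step is a separate kernel, and every intermediate of shape (B, d_ff) makes a round trip through global memory:

```python
import torch
import torch.nn.functional as F

B, d, d_ff = 32, 4096, 14336
x = torch.randn(B, d)
w_gate, w_up = torch.randn(d_ff, d), torch.randn(d_ff, d)
w_down = torch.randn(d, d_ff)

g = x @ w_gate.T    # read x and w_gate, write g -> (B, d_ff)
u = x @ w_up.T      # read x again and w_up, write u -> (B, d_ff)
a = F.silu(g)       # read g, write a -> (B, d_ff)
h = a * u           # read a and u, write h -> (B, d_ff)
y = h @ w_down.T    # read h and w_down, write y -> (B, d)
```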

Building a FlashMLP Kernel
With this reframing of MLPs as attention-like operations, we can adapt the principles from Flash Attention to optimize MLP operations. Just as Flash Attention avoids materializing the full attention matrix, our FlashMLP can avoid materializing large intermediate tensors.
Learning from Flash Attention
We'll use the Triton DSL, the same tool that powers many efficient attention implementations, to build our optimized MLP kernel. Our starting point is PyTorch's flex-attention kernel, which performs competitively with hand-tuned flash attention implementations.
First Attempt: Mimicking Flex-Attention
Our first approach copies the computation pattern from flex-attention, in which the query vectors for a single head are loaded and dot products are taken against the K and V tensors.
However, this approach falls short for MLPs because the inner dimensions are much larger (e.g., 4096 for Llama 3 8B's hidden dimension versus 128 for its attention head dimension). This means we cannot keep the entire input vector in fast shared memory during the computation:

This results in a much slower kernel than the baseline torch implementation:

Improved Approach: Keeping Activations in Shared Memory
A more effective strategy is to keep the temporary (post-non-linearity) activations in shared memory for longer and perform the dot product across the entire row of the down_proj (V-equivalent) tensor.
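As a rough illustration of this strategy, here is a simplified Triton sketch for the plain ReLU MLP in FP32. It is not the actual FlashMLP kernel: the block size is illustrative, the intermediate dimension is assumed to be a multiple of BLOCK_FF, the inputs are assumed to be contiguous FP32 CUDA tensors, and a production kernel would also tile the hidden dimension and autotune.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def fused_mlp_kernel(
    x_ptr, w_up_ptr, w_down_ptr, out_ptr,
    D: tl.constexpr,         # hidden dimension (power of two)
    D_FF,                    # intermediate dimension (assumed divisible by BLOCK_FF)
    BLOCK_FF: tl.constexpr,  # intermediate features processed per iteration
):
    row = tl.program_id(0)                  # one program per token
    d = tl.arange(0, D)

    x = tl.load(x_ptr + row * D + d)        # the input row is loaded once and stays on-chip
    acc = tl.zeros([D], dtype=tl.float32)   # the full output row also accumulates on-chip

    for start in range(0, D_FF, BLOCK_FF):
        ff = start + tl.arange(0, BLOCK_FF)
        # "Keys": a block of up_proj rows, shape (BLOCK_FF, D), stored row-major.
        w_up = tl.load(w_up_ptr + ff[:, None] * D + d[None, :])
        h = tl.maximum(tl.sum(w_up * x[None, :], axis=1), 0.0)  # scores + ReLU, shape (BLOCK_FF,)
        # "Values": the matching block of down_proj columns, shape (D, BLOCK_FF).
        w_down = tl.load(w_down_ptr + d[:, None] * D_FF + ff[None, :])
        acc += tl.sum(w_down * h[None, :], axis=1)               # partial h @ W_down^T

    tl.store(out_ptr + row * D + d, acc)


def fused_mlp(x, w_up, w_down, block_ff=8):
    # x: (B, D); w_up: (D_FF, D); w_down: (D, D_FF) -- the layouts of nn.Linear weights
    B, D = x.shape
    out = torch.empty_like(x)
    fused_mlp_kernel[(B,)](x, w_up, w_down, out, D=D, D_FF=w_up.shape[0], BLOCK_FF=block_ff)
    return out
```

Checking the output against a plain `down_proj(torch.relu(up_proj(x)))` reference is a useful first test before any benchmarking.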
This modification significantly improves performance: the kernel now matches the naive implementation at small batch sizes, though it still falls behind at larger ones.

Extending to SwiGLU
With this faster kernel in hand, we can tackle the SwiGLU operation, which requires loading and computing with the gate_proj matrix in addition to up_proj and down_proj.
A critical optimization here is that we can reuse the same block of memory from the input tensor to compute both the gate and up tensor activations. This gives our fusion kernel an even greater advantage over naive implementations and the partially-fused SiLU+multiply kernels used in systems like vLLM.
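Under the same assumptions as the earlier ReLU sketch, the change amounts to replacing the body of the inner loop: the x block already held on-chip is dotted with both the gate_proj and up_proj tiles before the down_proj accumulation, so the input is read from global memory only once. The fragment below shows the modified loop (w_gate_ptr is a new kernel argument with the same (D_FF, D) layout as w_up_ptr; x, d, acc, D, D_FF, and BLOCK_FF are as in the earlier sketch):

```python
    for start in range(0, D_FF, BLOCK_FF):
        ff = start + tl.arange(0, BLOCK_FF)
        w_gate = tl.load(w_gate_ptr + ff[:, None] * D + d[None, :])
        w_up = tl.load(w_up_ptr + ff[:, None] * D + d[None, :])
        g = tl.sum(w_gate * x[None, :], axis=1)        # gate scores, reusing the on-chip x
        u = tl.sum(w_up * x[None, :], axis=1)          # up scores, no extra read of x
        h = (g / (1.0 + tl.exp(-g))) * u               # SiLU(gate) * up
        w_down = tl.load(w_down_ptr + d[:, None] * D_FF + ff[None, :])
        acc += tl.sum(w_down * h[None, :], axis=1)
```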
Performance Results
Our FlashSwiGLU implementation consistently outperforms the native PyTorch SwiGLU implementation across various batch sizes:

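A comparison along these lines can be reproduced with Triton's built-in timing helper. The harness below is a sketch: swiglu_ref is the plain PyTorch implementation from earlier, and fused_swiglu stands in for a fused kernel wrapper (a hypothetical name here):

```python
import torch
import torch.nn.functional as F
import triton

d, d_ff = 4096, 14336                          # illustrative Llama-3-8B-like shapes
w_gate = torch.randn(d_ff, d, device="cuda")
w_up = torch.randn(d_ff, d, device="cuda")
w_down = torch.randn(d, d_ff, device="cuda")

def swiglu_ref(x):
    return (F.silu(x @ w_gate.T) * (x @ w_up.T)) @ w_down.T

for batch in (1, 8, 32, 128):
    x = torch.randn(batch, d, device="cuda")
    ms_ref = triton.testing.do_bench(lambda: swiglu_ref(x))   # median runtime in ms
    print(f"batch={batch:4d}  unfused reference: {ms_ref:.3f} ms")
    # The fused kernel would be timed the same way, e.g.:
    # ms_fused = triton.testing.do_bench(lambda: fused_swiglu(x, w_gate, w_up, w_down))
```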
What's Next for Flash MLP
This exercise demonstrates that reimagining MLPs as attention mechanisms can lead to meaningful performance improvements through specialized kernels. To make these kernels more broadly useful, several extensions are necessary:
- Support for more data types (FP16, BF16) beyond the current FP32 implementation
- Auto-tuning for different model architectures, GPU types, and tensor shapes
- Integration with quantized models (4-bit, 3-bit, 2-bit parameters)
- Compatibility with additional features like token skipping for mixture-of-experts models
Conclusion
By recognizing that MLPs are essentially "attention in a trench coat"—attention mechanisms operating on fixed parameters—we gain new insights into optimizing these crucial components of transformer models. The FlashMLP kernel demonstrates that the same principles that made Flash Attention so successful can be applied to MLP operations, resulting in faster inference and reduced memory usage. As large language models continue to grow in size and importance, these optimizations will help make AI more efficient and accessible across a wider range of hardware.