From Set Transformer to Perceiver Resampler

On multi-modal LLM Flamingo’s vision encoder

Designing multi-modal LLMs is hard.

State-of-the-art multi-modal LLMs are mostly built on existing LLM architectures, with modifications that handle the different input sources, and that is where the difficulty comes from. Nvidia's recent NVLM paper divides the commonly used multi-modal architectures into two categories:

  • decoder-based;
  • cross-attention-based.

In a previous Medium article, I discussed Meta's recent Transfusion paper, which uses a decoder-based architecture: it converts an input image into a latent representation with a VAE encoder, addressing the fact that the image space is continuous, unlike the discrete text space.

Cross-attention-based architectures face a different problem. In the multi-modal LLM Flamingo, for example, the critical issue is converting the vision embeddings produced by a generic vision model, whose temporal and spatial dimensions vary with the input, into a fixed-size set of visual tokens that the language model's cross-attention layers can attend to.

In this post, I will dive into the component Flamingo adds on top of its vision encoder, the Perceiver Resampler, to explain how this issue is solved. I will also trace the Perceiver Resampler's origin: the Induced Set Attention Block from Set Transformer, which in turn inspired DeepMind's Perceiver model for learning fixed-length latent embeddings from generic input data.

Image source: https://pxhere.com/en/photo/1399240

Set Transformer

Published in 2019, Set Transformer extended the original Transformer to sets in order to solve permutation-invariant problems such as set anomaly detection and point cloud classification. Inspired by sparse Gaussian processes, where a small set of inducing variables can adequately approximate the posterior, Set Transformer introduces the Induced Set Attention Block (ISAB), defined below:

Induced Set Attention Blocks (ISAB). Equation source: https://arxiv.org/pdf/1810.00825

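MAB(X, Y) is essentially the Transformer encoder block (multi-head attention plus a row-wise feed-forward layer, each with a residual connection and layer normalization), with query = X and key/value = Y. ISAB stacks two such blocks. In the first, the inducing matrix I takes the place of the query and attends to the input set X, compressing it into H = MAB(I, X); in the second, X is the query and attends to H, giving ISAB(X) = MAB(X, H). The set X has dimension N*D, and I has dimension M*D, representing M learnable inducing points of dimension 1*D. A visualization is shown below.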
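For readers without the figure, the two definitions from the Set Transformer paper can be written out as follows, using this article's N, M, D notation. Multihead(Q, K, V) is standard multi-head attention and rFF is a row-wise feed-forward layer; the MAB intermediate is written Z here (the paper reuses the letter H) to keep it distinct from the H inside ISAB.

```latex
\begin{aligned}
\mathrm{MAB}(X, Y) &= \mathrm{LayerNorm}\big(Z + \mathrm{rFF}(Z)\big),
  & Z &= \mathrm{LayerNorm}\big(X + \mathrm{Multihead}(X, Y, Y)\big) \\
\mathrm{ISAB}_M(X) &= \mathrm{MAB}(X, H) \in \mathbb{R}^{N \times D},
  & H &= \mathrm{MAB}(I, X) \in \mathbb{R}^{M \times D}
\end{aligned}
```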

A visualization of the multi-head attention block (MAB) and the induced set attention block (ISAB). Image source: https://arxiv.org/pdf/1810.00825

Note that ISAB is designed to save computation. Because M can be much smaller than N, the time complexity of ISAB, O(N*M*d), is much lower than the O(N**2*d) complexity of full self-attention over the set.
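To make the two-step structure and the cost difference concrete, here is a minimal NumPy sketch of MAB and ISAB. It assumes a single attention head instead of multi-head attention, LayerNorm without learnable parameters, and an rFF reduced to one ReLU layer; these are simplifications of the paper's version, and the class names are illustrative. The counter records how many query-key pairs each attention scores.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(z, eps=1e-5):
    # LayerNorm without learnable scale/shift, for brevity
    return (z - z.mean(-1, keepdims=True)) / np.sqrt(z.var(-1, keepdims=True) + eps)

class MAB:
    """Single-head MAB(X, Y): queries come from X, keys/values come from Y."""

    def __init__(self, d, rng):
        scale = 1.0 / np.sqrt(d)
        self.wq, self.wk, self.wv, self.w_ff = (rng.normal(0, scale, (d, d)) for _ in range(4))
        self.num_pairs = 0   # query-key pairs scored in the last call

    def __call__(self, x, y):
        q, k, v = x @ self.wq, y @ self.wk, y @ self.wv
        scores = q @ k.T / np.sqrt(q.shape[-1])   # shape (len(X), len(Y))
        self.num_pairs = scores.size
        z = layer_norm(x + softmax(scores) @ v)                # Z = LN(X + Attn(X, Y, Y))
        return layer_norm(z + np.maximum(z @ self.w_ff, 0.0))  # LN(Z + rFF(Z)), rFF = one ReLU layer

class ISAB:
    def __init__(self, d, m, rng):
        self.inducing = rng.normal(size=(m, d))   # I: M inducing points (learned in practice)
        self.mab1, self.mab2 = MAB(d, rng), MAB(d, rng)

    def __call__(self, x):
        h = self.mab1(self.inducing, x)   # I queries X -> H of shape (M, D)
        return self.mab2(x, h)            # X queries H -> output of shape (N, D)

rng = np.random.default_rng(0)
N, M, D = 1000, 16, 32
x = rng.normal(size=(N, D))
isab = ISAB(D, M, rng)
out = isab(x)
pairs = isab.mab1.num_pairs + isab.mab2.num_pairs
print(out.shape, pairs, N * N)   # (1000, 32) 32000 1000000
```

With N=1000 and M=16, the two attentions score 2*N*M = 32,000 pairs, versus 1,000,000 for full self-attention. Because the inducing points pool over the whole set, permuting the rows of X simply permutes the rows of the output.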

Perceiver

Inspired by Set Transformer's use of inducing points as the query matrix, DeepMind's Perceiver model makes the query a short array of learnable latent embeddings (e.g., N=512), while the keys and values come from a byte array, an ultra-long input sequence (e.g., M=224*224 pixels). Note that the Perceiver paper uses N for the latent length and M for the input length, so the letters play the opposite roles from the Set Transformer section above.

Perceiver model architecture. Image source: https://arxiv.org/abs/2103.03206

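The cross-attention is borrowed from the decoder of the original Transformer, where the query and the key/value come from different sources. Here, the queries come from the learnable latent array, while the keys and values are computed from the input byte array rather than from learned embeddings: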

Multi-head attention and cross attention. Image by author.
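The cross-attention step can be sketched in a few lines of NumPy. This is a single-head simplification with toy, hypothetical sizes (N=8 latents, 3 input channels, width 16); the real Perceiver uses multi-head attention and appends Fourier position features to the byte array, both omitted here. The names `cross_attention`, `latents`, and `byte_array` are illustrative, not from any library. The point is that the score matrix is N*M, linear in the input length, while the output always has N rows.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(latents, inputs, wq, wk, wv):
    """Single-head cross-attention: queries from the latents, keys/values from the inputs."""
    q = latents @ wq                          # (N, D)
    k = inputs @ wk                           # (M, D)
    v = inputs @ wv                           # (M, D)
    scores = q @ k.T / np.sqrt(q.shape[-1])   # (N, M): linear in the input length M
    return softmax(scores) @ v                # (N, D): independent of M

rng = np.random.default_rng(0)
N, C, D = 8, 3, 16                            # toy latent length, input channels (RGB), width
latents = rng.normal(size=(N, D))             # would be learned parameters in practice
wq = rng.normal(size=(D, D)) / np.sqrt(D)
wk = rng.normal(size=(C, D)) / np.sqrt(C)
wv = rng.normal(size=(C, D)) / np.sqrt(C)

for side in (32, 64):                         # two image sizes, flattened into byte arrays
    byte_array = rng.normal(size=(side * side, C))   # M = side * side "pixels"
    out = cross_attention(latents, byte_array, wq, wk, wv)
    print(byte_array.shape[0], out.shape)     # 1024 (8, 16), then 4096 (8, 16)
```

Passing the same array as both the query source and the key/value source (with square projection matrices) would give ordinary self-attention; cross-attention differs only in where Q and K/V come from.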

Since K and V are computed from the input, which stays fixed across layers, only the cross-attention touches the long byte array, at a cost of O(M*N). The self-attention layers that follow, called the latent transformer, operate only on the latent array, so each costs O(N**2), independent of the input size M. Decoupled from the input size, the latent transformer can be scaled up to 48 layers, a major advantage over a standard Transformer, whose per-layer cost on the raw input grows as O(M**2).
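A quick back-of-the-envelope count shows the scale of the saving. This sketch counts query-key pairs as a proxy for attention cost (ignoring the width d and the feed-forward layers) and assumes a hypothetical configuration with a single cross-attend followed by 48 latent layers, using the M=224*224 and N=512 figures above; the actual model can interleave several cross-attends.

```python
def attention_pairs(input_len, latent_len, num_layers):
    """Count query-key pairs scored by attention (a proxy for cost, ignoring the width d)."""
    # Standard Transformer: every layer runs self-attention over the raw input.
    full = num_layers * input_len ** 2
    # Perceiver-style: one cross-attend over the input, then a latent transformer.
    perceiver = input_len * latent_len + num_layers * latent_len ** 2
    return full, perceiver

full, perceiver = attention_pairs(input_len=224 * 224, latent_len=512, num_layers=48)
print(f"{full:,} vs {perceiver:,}")   # 120,846,286,848 vs 38,273,024
```

The Perceiver-style count is dominated by the single O(M*N) cross-attend; adding more latent layers costs only 512**2 = 262,144 pairs each.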

Flamingo’s Vision Encoder and Perceiver Resampler

Instead of applying the Perceiver directly, Flamingo first uses a pre-trained, frozen, CNN-based Normalizer-Free ResNet (NFNet) to extract image/video features, then adds learnable temporal position embeddings and flattens the features into a 1D sequence. The Perceiver Resampler, attached after the vision encoder, maps this variable-length sequence to a fixed number of visual tokens (64 in the paper), which are then passed to the cross-attention layers of the language model.

Flamingo architecture. Image source: https://arxiv.org/pdf/2204.14198

Like DeepMind's Perceiver model, the Perceiver Resampler uses the fixed input features as keys/values and the learnable latent vectors as queries. Note that no spatial position encoding is used here; the rationale is that the upstream vision encoder, NFNet, is a convolutional model that already embeds spatial information in its channels. To improve performance, the keys and values computed from the learnable latents are concatenated to those computed from the visual features in the cross-attention computation.

Perceiver Resampler architecture. Image source: https://arxiv.org/abs/2204.14198

The detailed algorithm is given below:

Perceiver Resampler algorithm. Algorithm source: https://arxiv.org/abs/2204.14198
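Since the algorithm figure does not survive as text, here is a runnable NumPy sketch that follows the structure of the paper's pseudocode: add the temporal embeddings, flatten [T, S, d] to [T*S, d], then in each layer run attention with q = latents and kv = concat([x_f, latents]), followed by a residual feed-forward. Assumptions and simplifications: single-head instead of multi-head attention, LayerNorm without learned parameters applied before attention and feed-forward, random untrained weights, and toy sizes (d=32, a 7x7 spatial grid). Only the 64 latents match the paper's number of visual tokens. The class and method names are hypothetical.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(z, eps=1e-5):
    # LayerNorm without learnable scale/shift, for brevity
    return (z - z.mean(-1, keepdims=True)) / np.sqrt(z.var(-1, keepdims=True) + eps)

class PerceiverResampler:
    """Toy single-head Perceiver Resampler with random (untrained) weights."""

    def __init__(self, d, num_latents, max_frames, num_layers, seed=0):
        rng = np.random.default_rng(seed)
        scale = 1.0 / np.sqrt(d)
        self.latents = rng.normal(size=(num_latents, d))     # x: R learned latents
        self.time_emb = rng.normal(size=(max_frames, 1, d))  # learned temporal embeddings [T, 1, d]
        self.layers = [
            {name: rng.normal(0, scale, (d, d)) for name in ("wq", "wk", "wv", "w1", "w2")}
            for _ in range(num_layers)
        ]

    @staticmethod
    def attention(p, q_in, kv_in):
        q = layer_norm(q_in) @ p["wq"]
        kv = layer_norm(kv_in)
        k, v = kv @ p["wk"], kv @ p["wv"]
        return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v

    @staticmethod
    def ffw(p, x):
        return np.maximum(layer_norm(x) @ p["w1"], 0.0) @ p["w2"]

    def __call__(self, x_f):
        t, s, d = x_f.shape                       # [T, S, d]: T frames, S spatial positions
        x_f = x_f + self.time_emb[:t]             # add time embeddings, broadcast over S
        x_f = x_f.reshape(t * s, d)               # flatten: [T, S, d] -> [T * S, d]
        x = self.latents
        for p in self.layers:
            # Queries are the latents; keys/values are concat([x_f, x]).
            x = x + self.attention(p, x, np.concatenate([x_f, x], axis=0))
            x = x + self.ffw(p, x)
        return x                                  # [R, d] for any T and S

resampler = PerceiverResampler(d=32, num_latents=64, max_frames=8, num_layers=2)
image = np.random.default_rng(1).normal(size=(1, 49, 32))  # one frame, 7x7 feature grid
video = np.random.default_rng(2).normal(size=(8, 49, 32))  # eight frames
print(resampler(image).shape, resampler(video).shape)      # (64, 32) (64, 32)
```

The output has 64 rows for both the single image and the eight-frame video, which is the fixed-size interface the language model's cross-attention layers rely on. Note that no spatial position embedding is added, matching the design choice described above.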

Summary

This article walked through the vision encoder of the Flamingo architecture. Its distinctive component, the Perceiver Resampler, traces back to Set Transformer's induced set attention and the Perceiver model. By compressing a variable number of spatio-temporal features into a fixed number of visual tokens, it keeps the cost of the language model's cross-attention low while still drawing on information from both the spatial and temporal domains.

References

  • Dai et al., NVLM: Open Frontier-Class Multimodal LLMs. arXiv 2024.
  • Zhou et al., Transfusion: Predict the Next Token and Diffuse Images with One Multi-Modal Model. arXiv 2024.
  • Alayrac et al., Flamingo: a Visual Language Model for Few-Shot Learning. NeurIPS 2022.
  • Jaegle et al., Perceiver: General Perception with Iterative Attention. ICML 2021.
  • Brock et al., High-Performance Large-Scale Image Recognition Without Normalization. arXiv 2021.
  • Lee et al., Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks. ICML 2019.
  • Vaswani et al., Attention Is All You Need. NeurIPS 2017.
  • Stanford CS25: V1 I DeepMind’s Perceiver and Perceiver IO: new data family architecture, https://www.youtube.com/watch?v=wTZ3o36lXoQ
  • HuggingFace, Perceiver Model Doc. https://huggingface.co/docs/transformers/v4.34.0/en/model_doc/perceiver