Understanding MLP-Mixer as a wide and sparse MLP through random permutation matrix

2023-11-09 explainer


Understanding MLP-Mixer as a wide and sparse MLP through Random Permutation Matrices

Tomohiro Hayase. Talk at Non-Commutative Probability Theory, Random Matrix Theory and their Applications (NPRM2023), 2023/11/08--09. Preprint:

Table of Contents

  1. Effective Expression of MLP-Mixer
    1. Preliminaries
    2. Similarity between MLP and MLP-Mixer
    3. Effective Width
    4. Monarch Matrices
  2. Alternative to static sparse weight MLP
    1. Random Permuted Mixers
    2. Revisit the similarity in wider cases
    3. Conclusion and Future Works

1. Effective Expression of MLP-Mixer

Introduction

Research Question: Why does MLP-Mixer achieve higher performance than a usual MLP? Our Answer: Because the layers of MLP-Mixer act as an extremely wide MLP.

Preliminaries

MLP

An MLP (multilayer perceptron) is a composition of transforms of the form

y = \phi(Wx), \quad x \in \mathbb{R}^m,

where W is a parameter matrix (the transforms do not share parameter matrices).

Static Mask: Consider a matrix M with entries 0 or 1 and replace W by M \odot W:

y = \phi((M \odot W) x)

MLP-Mixer

Tolstikhin et al., NeurIPS 2021

The architecture is less structured than that of Convolutional Neural Networks or Vision Transformers.

Blocks of MLP-Mixer:

\text{Token-MLP}(X) = W_2 \phi(W_1 X), \quad \text{Channel-MLP}(X) = \phi(X W_3) W_4

where W_1 \in \mathbb{R}^{\gamma S \times S}, W_2 \in \mathbb{R}^{S \times \gamma S}, W_3 \in \mathbb{R}^{C \times \gamma C}, W_4 \in \mathbb{R}^{\gamma C \times C}.
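
As a concrete illustration, here is a minimal NumPy sketch of the two blocks (the sizes, the choice γ = 2, and the ReLU stand-in for the activation are illustrative assumptions, not taken from the paper's code):

```python
import numpy as np

S, C, gamma = 8, 6, 2                 # tokens, channels, expansion factor
rng = np.random.default_rng(0)
phi = lambda z: np.maximum(z, 0.0)    # entry-wise activation (ReLU as a stand-in)

# Token-MLP(X) = W2 @ phi(W1 @ X): mixes information across the S tokens
W1 = rng.standard_normal((gamma * S, S))
W2 = rng.standard_normal((S, gamma * S))
token_mlp = lambda X: W2 @ phi(W1 @ X)

# Channel-MLP(X) = phi(X @ W3) @ W4: mixes information across the C channels
W3 = rng.standard_normal((C, gamma * C))
W4 = rng.standard_normal((gamma * C, C))
channel_mlp = lambda X: phi(X @ W3) @ W4

X = rng.standard_normal((S, C))
print(token_mlp(X).shape, channel_mlp(X).shape)   # both (S, C)
```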

Similarity between MLP-Mixer and MLP via vectorization

Vectorization and effective width

We denote the vectorization of an S \times C matrix X by \text{vec}(X); more precisely,

(\text{vec}(X))_{(j-1)S + i} = X_{ij}, \quad (i = 1, \dots, S,\ j = 1, \dots, C).

In other words, this map gives the representation

\mathrm{vec}: M_{S,C}(\mathbb{R}) \to L^2(M_{S,C}(\mathbb{R})).

We also define an inverse operation \text{mat}(\cdot) to recover the matrix representation.
There is a well-known identity relating the vectorization operation and the tensor (or Kronecker) product, denoted by \otimes:

\text{vec}(W X V) = (V^\top \otimes W)\, \text{vec}(X),

for W \in \mathbb{R}^{S \times S} and V \in \mathbb{R}^{C \times C}.

As discussed later, the above identity corresponds to the vectorization of an MLP-Mixer block with a linear activation function.
The vectorization of the feature matrix W X V is equivalent to a fully connected layer of width

m = SC

with weight matrix V^\top \otimes W. We refer to this m as the *effective width* of the mixing layers.
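
A quick NumPy check of this identity and of the effective width (a sketch with illustrative sizes; the column-stacking vec is implemented via flatten(order="F")):

```python
import numpy as np

S, C = 5, 3
rng = np.random.default_rng(0)
W = rng.standard_normal((S, S))
V = rng.standard_normal((C, C))
X = rng.standard_normal((S, C))

vec = lambda A: A.flatten(order="F")        # column-stacking vectorization
mat = lambda v: v.reshape(S, C, order="F")  # inverse operation, back to an S x C matrix

# vec(W X V) = (V^T ⊗ W) vec(X): one fully connected layer of width m = S*C
lhs = vec(W @ X @ V)
rhs = np.kron(V.T, W) @ vec(X)
assert np.allclose(lhs, rhs)
assert np.allclose(mat(vec(X)), X)
print("effective width m =", S * C)
```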

Under vectorization of feature matrices

In MLP-Mixer, when we treat each S \times C feature matrix X as an SC-dimensional vector \mathrm{vec}(X), the Token-Mixing layer (left multiplication by an S \times S weight W) is converted into:

\mathrm{vec}(X) \mapsto (I_C \otimes W)\, \mathrm{vec}(X),

and the Channel-Mixing layer (right multiplication by a C \times C weight V) is converted into:

\mathrm{vec}(X) \mapsto (V^\top \otimes I_S)\, \mathrm{vec}(X).

This expression clarifies that the mixing layers work as an MLP whose weight matrices have a special tensor-product structure. In typical settings,

S, C \sim 10^2 \text{ to } 10^3,

so the Mixer is equivalent to an extremely wide MLP with

m = SC = 10^4 \text{ to } 10^6.

Moreover, the ratio of non-zero entries in the weight matrix I_C \otimes W is 1/C, and that of V^\top \otimes I_S is 1/S:

\frac{\#\{\text{non-zero entries in } I_C \otimes W\}}{(SC)^2} = \frac{1}{C}, \qquad \frac{\#\{\text{non-zero entries in } V^\top \otimes I_S\}}{(SC)^2} = \frac{1}{S}.

e.g., block-matrix representation:

I_C \otimes W = \begin{pmatrix} W & 0 & \cdots & 0 \\ 0 & W & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & W \end{pmatrix}


Therefore, the weight of the effective MLP is highly sparse.
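
These sparsity ratios can be verified directly (a small NumPy sketch with illustrative sizes):

```python
import numpy as np

S, C = 16, 12
rng = np.random.default_rng(0)
W = rng.standard_normal((S, S))   # token-mixing weight (left multiplication)
V = rng.standard_normal((C, C))   # channel-mixing weight (right multiplication)

token_weight   = np.kron(np.eye(C), W)     # I_C ⊗ W,   shape (S*C, S*C)
channel_weight = np.kron(V.T, np.eye(S))   # V^T ⊗ I_S, shape (S*C, S*C)

density = lambda A: np.count_nonzero(A) / A.size
print(density(token_weight),   1 / C)   # both 1/C
print(density(channel_weight), 1 / S)   # both 1/S
```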

Commutation Matrix

Furthermore, to consider only the left multiplication of weights, we introduce commutation matrices:

A commutation matrix J_C is defined by

J_C\, \mathrm{vec}(X) = \mathrm{vec}(X^\top),

where X is an S \times C matrix. Note that for any entry-wise function \phi,

J_C\, \phi(x) = \phi(J_C\, x), \quad x \in \mathbb{R}^m.

Note that

V^\top \otimes I_S = J_C^\top (I_S \otimes V^\top) J_C.
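
As a sketch, the commutation matrix can be built explicitly as a permutation matrix, and both the defining relation and the conjugation identity above can be checked numerically (NumPy, illustrative sizes, 0-based indices):

```python
import numpy as np

S, C = 4, 3
rng = np.random.default_rng(0)
vec = lambda A: A.flatten(order="F")   # column-stacking vectorization

# Build J_C as a permutation matrix: J_C vec(X) = vec(X^T) for an S x C matrix X.
# Entry X_{ij} sits at position j*S + i in vec(X) and at i*C + j in vec(X^T).
J = np.zeros((S * C, S * C))
for i in range(S):
    for j in range(C):
        J[i * C + j, j * S + i] = 1.0

X = rng.standard_normal((S, C))
V = rng.standard_normal((C, C))
assert np.allclose(J @ vec(X), vec(X.T))
assert np.allclose(np.kron(V.T, np.eye(S)), J.T @ np.kron(np.eye(S), V.T) @ J)
```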

Effective Expression of MLP-Mixer

Token-MLP Block:

u = \phi(J_C (I_C \otimes W_2)\, \phi((I_C \otimes W_1)\, x)),

Channel-MLP Block:

y = \phi(J_C^\top (I_S \otimes W_4^\top)\, \phi((I_S \otimes W_3^\top)\, u)).
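
To see that these blocks really act as wide sparse MLP layers, the following sketch checks that the matrix form of each block coincides with its vectorized form; for brevity the commutation matrices are left out by keeping each Kronecker factor on its natural side (NumPy, illustrative sizes):

```python
import numpy as np

S, C, gamma = 6, 4, 2
rng = np.random.default_rng(1)
phi = lambda z: np.maximum(z, 0.0)
vec = lambda A: A.flatten(order="F")

W1 = rng.standard_normal((gamma * S, S)); W2 = rng.standard_normal((S, gamma * S))
W3 = rng.standard_normal((C, gamma * C)); W4 = rng.standard_normal((gamma * C, C))
X = rng.standard_normal((S, C))

# Token-MLP as a wide sparse MLP on vec(X): weights I_C ⊗ W1 and I_C ⊗ W2
lhs = vec(W2 @ phi(W1 @ X))
rhs = np.kron(np.eye(C), W2) @ phi(np.kron(np.eye(C), W1) @ vec(X))
assert np.allclose(lhs, rhs)

# Channel-MLP as a wide sparse MLP on vec(X): weights W3^T ⊗ I_S and W4^T ⊗ I_S
lhs = vec(phi(X @ W3) @ W4)
rhs = np.kron(W4.T, np.eye(S)) @ phi(np.kron(W3.T, np.eye(S)) @ vec(X))
assert np.allclose(lhs, rhs)
```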

MLP with static-mask

Static Mask: Consider a matrix M whose entries are 0 or 1, sampled from a Bernoulli distribution, and replace W in each layer of the MLP by M \odot W:

y = \phi((M \odot W) x), \quad M \sim \mathrm{Bernoulli}(p)
  • The mask matrix M is fixed during training.
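
A minimal sketch of such a statically masked layer (NumPy; the names, sizes, and the density p are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, p = 64, 64, 0.1          # p = fraction of surviving weights
phi = lambda z: np.maximum(z, 0.0)

W = rng.standard_normal((d_out, d_in)) / np.sqrt(d_in)
M = (rng.random((d_out, d_in)) < p).astype(W.dtype)   # M ~ Bernoulli(p), sampled once

def masked_layer(x):
    # The mask M stays fixed during training; only M ⊙ W ever acts on x.
    return phi((M * W) @ x)

x = rng.standard_normal(d_in)
print(masked_layer(x).shape)   # (64,)
```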

Hidden features and test accuracy

To validate the similarity of networks in a robust and scalable way, we look at the similarity of hidden features of MLPs with sparse weights and MLP-Mixers, based on the centered kernel alignment (CKA) (Nguyen, Raghu, and Kornblith):

\text{CKA}_\text{minibatch}(X,Y) = \frac{k^{-1}\sum_i \text{HSIC}_1(X_i X_i^\top, Y_i Y_i^\top)}{\sqrt{k^{-1}\sum_i \text{HSIC}_1(X_i X_i^\top, X_i X_i^\top)}\,\sqrt{k^{-1}\sum_i \text{HSIC}_1(Y_i Y_i^\top, Y_i Y_i^\top)}}

In practice, we computed the mini-batch CKA (Section 3.1 of Nguyen et al., 2021) among the features of trained networks.
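
A sketch of mini-batch CKA with the unbiased HSIC_1 estimator used in Nguyen et al. (2021), originally due to Song et al.; the function names and the toy data are illustrative. The 1/k factors cancel between numerator and denominator, so plain sums suffice.

```python
import numpy as np

def hsic1(K, L):
    """Unbiased HSIC_1 estimator (Song et al., 2012) for two n x n Gram matrices."""
    n = K.shape[0]
    K = K - np.diag(np.diag(K))           # zero out the diagonals
    L = L - np.diag(np.diag(L))
    ones = np.ones(n)
    term1 = np.trace(K @ L)
    term2 = (ones @ K @ ones) * (ones @ L @ ones) / ((n - 1) * (n - 2))
    term3 = 2.0 * (ones @ K @ L @ ones) / (n - 2)
    return (term1 + term2 - term3) / (n * (n - 3))

def minibatch_cka(feats_x, feats_y):
    """feats_x, feats_y: lists of (batch, dim) feature arrays over k mini-batches."""
    num = den_x = den_y = 0.0
    for Xi, Yi in zip(feats_x, feats_y):
        Kx, Ky = Xi @ Xi.T, Yi @ Yi.T      # linear-kernel Gram matrices
        num   += hsic1(Kx, Ky)
        den_x += hsic1(Kx, Kx)
        den_y += hsic1(Ky, Ky)
    return num / np.sqrt(den_x * den_y)

# Toy usage: features of two "networks" over 4 mini-batches of 32 examples
rng = np.random.default_rng(0)
fx = [rng.standard_normal((32, 128)) for _ in range(4)]
R = rng.standard_normal((128, 64))
fy = [f @ R for f in fx]                   # a fixed linear readout of fx
print(minibatch_cka(fx, fy))
```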