14 January 2025

About that Trick in Fourier Neural Operators

Notes on some implementation details of Fourier Neural Operators.

Fourier Neural Operators (FNOs)[1] have been quite successful, to say the least, with immediate practical applications including accurate high-resolution weather forecasting, Carbon Capture and Storage (CCS) simulations, industrial-scale automotive aerodynamics, and material deformation, to name a few[2].

Here's a quick overview of FNOs. Recall that a Neural Operator is a composition of layers $v$ of the general form[3]

$$ v^{l+1}(\mathbf{x}) := \sigma \left( \mathbf{W} v^l(\mathbf{x}) + \int_{\mathcal{X}} k(\mathbf{x}, \mathbf{z})\, v^l(\mathbf{z})\, d\mathbf{z} \right) $$

where $\sigma$ is the usual pointwise nonlinearity and $k$ is some kernel function. The motivation here is that, to learn an operator as a mapping between functions, we have to go beyond pointwise transformations like $\mathbf{W} v(\mathbf{x})$. The integral term in the above equation defines a family of functional transforms called Kernel Integral Transformations, which lets the network learn a more general form of the input-output functional mapping. Because the integral runs over the whole input domain $\mathcal{X} \subset \mathbb{R}^d$, the network is not tied to a fixed number of input function measurements or to their locations.
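To make the kernel integral concrete, here is a minimal NumPy sketch that approximates it with a Riemann sum on a uniform 1D grid over $[0, 1)$. The names `kernel_integral`, `toy_kernel`, and `sample_v` are hypothetical, and the hand-picked periodic kernel stands in for a learned one. This is not how FNOs evaluate the integral; it only illustrates that the same kernel can be applied to any number of sample points.

```python
import numpy as np

def kernel_integral(kernel, v, x):
    """Riemann-sum approximation of (K v)(x_i) = integral of k(x_i, z) v(z) dz.

    kernel: callable (x_i, z_j) -> (d_v, d_v) matrix (stand-in for a learned kernel)
    v:      array of shape (n, d_v), samples of v on the grid
    x:      array of shape (n,), uniform grid on [0, 1)
    """
    n = x.shape[0]
    dz = 1.0 / n
    # K[i, j] holds the (d_v, d_v) matrix k(x_i, z_j)
    K = np.array([[kernel(xi, zj) for zj in x] for xi in x])
    return np.einsum("ijab,jb->ia", K, v) * dz

def toy_kernel(xi, zj):
    # Periodic, stationary toy kernel acting identically on both channels
    return (1.0 + np.cos(2 * np.pi * (xi - zj))) * np.eye(2)

def sample_v(x):
    # Two-channel input function v(z) = (sin 2*pi*z, cos 2*pi*z)
    return np.stack([np.sin(2 * np.pi * x), np.cos(2 * np.pi * x)], axis=-1)

# The same kernel works at any resolution; for this toy choice the exact
# result is 0.5 * v(x), so we can compare against it.
for n in (16, 64):
    x = np.arange(n) / n
    out = kernel_integral(toy_kernel, sample_v(x), x)
    print(n, np.abs(out - 0.5 * sample_v(x)).max())
```

Note the cost: the dense sum is quadratic in the number of grid points, which is what motivates looking for something cheaper.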

The integral term is a convolution, provided the kernel $k$ is stationary (translation invariant), i.e., $k(\mathbf{x}, \mathbf{z}) = k(\mathbf{x} - \mathbf{z})$. This immediately suggests using the Fourier transform, which turns the convolution into a multiplication applied independently at each frequency (a small matrix product per mode). That is computationally more efficient than computing the integral directly.

$$ \int_{\mathcal{X}} k(\mathbf{x}, \mathbf{z})\, v(\mathbf{z})\, d\mathbf{z} = \int_{\mathcal{X}} k(\mathbf{x} - \mathbf{z})\, v(\mathbf{z})\, d\mathbf{z} = \mathcal{F}^{-1}\left( \mathcal{F}(k) \cdot \mathcal{F}(v) \right)(\mathbf{x}) $$
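The identity above is easy to check numerically in its discrete, periodic form. The sketch below uses NumPy and random scalar samples (an assumption made for brevity; the grid spacing factor is dropped on both sides) and compares a direct circular convolution with the FFT route.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 64
k = rng.standard_normal(n)  # samples of a stationary, periodic kernel k(x - z)
v = rng.standard_normal(n)  # samples of the input function v(z)

# Direct evaluation: out[i] = sum_j k[(i - j) mod n] * v[j], O(n^2) work
idx = (np.arange(n)[:, None] - np.arange(n)[None, :]) % n
direct = (k[idx] * v[None, :]).sum(axis=1)

# Fourier route: pointwise product in frequency space, O(n log n) work
spectral = np.fft.ifft(np.fft.fft(k) * np.fft.fft(v)).real

print(np.allclose(direct, spectral))
```

The periodic wrap-around in `idx` is exactly why periodicity of the kernel matters for the FFT-based version.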

Therefore, instead of designing a network to learn the kernel $k: \mathcal{X} \to \mathbb{R}^{d_v \times d_v}$, we directly learn its Fourier transform $\mathcal{F}(k)$, represented as $R_{\phi}$.

$$ v^{l+1}(\mathbf{x}) := \sigma \left( \mathbf{W} v^l(\mathbf{x}) + \mathcal{F}^{-1}\left( R_{\phi} \cdot \mathcal{F}(v^l) \right)(\mathbf{x}) \right) $$

Here, I would like to point out that the above equation has the form of a residual layer (you may have to squint hard). I find it funny that DNN researchers have this obsession about showing that their layer architecture looks like a residual layer. Anyhow.

We note a couple of things: FNOs work on uniformly discretized domains $\mathcal{X}$, and the kernel $k$ is assumed to be periodic and bandlimited.

These assumptions (three in total, counting stationarity) are crucial for an efficient practical implementation of FNOs. In the actual implementation, there is a neat little trick[5] for computing this Fourier multiplication. Refer to the typical implementation for 2D inputs shown below: $R$ is nowhere to be seen! And what are those weights1 and weights2?

# Reference:

import torch
import torch.nn as nn

class FourierLayer2D(nn.Module):
    def __init__(self,
                 d_in: int,
                 d_out: int,
                 modes1: int,
                 modes2: int):

        """
        Fourier Layer for 2D inputs. Computes the following operation:

        v_{l+1} = F^{-1}( F(v_l) * R )

        where F is the Fourier Transform, R is the kernel's Fourier coefficients.

        Args:
            d_in: int: Dimension of input
            d_out: int: Dimension of output
            modes1: int: Max number of modes in the x-direction
            modes2: int: Max number of modes in the y-direction
        """

        super(FourierLayer2D, self).__init__()

        self.d_in   = d_in
        self.d_out  = d_out
        self.modes1 = modes1
        self.modes2 = modes2

        self.scale = (1 / (d_in * d_out))

        # Set the complex weights to parameterize the
        # Kernel's Fourier coefficients R.
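        # Shape (d_in, d_out, modes1, modes2), matching "ioxy" in compl_mul2d.
        self.weights1 = nn.Parameter(self.scale *
                                     torch.rand(d_in, d_out, self.modes1,
                                                self.modes2, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale *
                                     torch.rand(d_in, d_out, self.modes1,
                                                self.modes2, dtype=torch.cfloat))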

    def forward(self, v: torch.Tensor):
        batchsize, M, N = v.shape[0], v.shape[-2], v.shape[-1]

        # Compute Fourier transform of the
        # input tensor v. Note that this is the
        # output of the previous layer.
        v_ft = torch.fft.rfft2(v)

        # Truncate the Fourier modes and multiply with the
        # weight matrices
        out_ft = torch.zeros(batchsize,
                             self.d_out,
                             M,
                             N//2 + 1,
                             dtype=torch.cfloat, device=v.device)

        out_ft[:, :, :self.modes1, :self.modes2] = \
            self.compl_mul2d(v_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = \
            self.compl_mul2d(v_ft[:, :, -self.modes1:, :self.modes2], self.weights2)

        # Inverse Fourier Transform to get the layer outputs in
        # the input physical space.
        v = torch.fft.irfft2(out_ft, s=(v.size(-2), v.size(-1)))
        return v

    def compl_mul2d(self, tensor1: torch.Tensor, tensor2: torch.Tensor):
        # (batch, d_in, x,y ), (d_in, d_out, x,y) -> (batch, d_out, x,y)
        return torch.einsum("bixy,ioxy->boxy", tensor1, tensor2)
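To see what the einsum in compl_mul2d is doing, the following NumPy sketch (random complex tensors and small, made-up sizes; `crandn` is a hypothetical helper) checks that it is the same as an independent (d_in x d_out) matrix product at every retained Fourier mode.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, d_in, d_out, modes1, modes2 = 2, 3, 4, 5, 6

def crandn(*shape):
    # Random complex tensor, standing in for Fourier coefficients / weights
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

v_ft = crandn(batch, d_in, modes1, modes2)   # truncated F(v)
R = crandn(d_in, d_out, modes1, modes2)      # one weight matrix per mode

# Same contraction string as compl_mul2d
out = np.einsum("bixy,ioxy->boxy", v_ft, R)

# Equivalent explicit loop: one small matrix product per mode (x, y)
loop = np.empty((batch, d_out, modes1, modes2), dtype=complex)
for x in range(modes1):
    for y in range(modes2):
        # (batch, d_in) @ (d_in, d_out) -> (batch, d_out)
        loop[:, :, x, y] = v_ft[:, :, x, y] @ R[:, :, x, y]

print(out.shape, np.allclose(out, loop))
```

So the weights never mix different modes; they only mix channels within a mode, which is the "multiplication in frequency space" from the convolution theorem.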

Let's try to understand what's happening here. The 2D output tensor $\mathbf{v} \in \mathbb{R}^{d_v \times M \times N}$ of the previous layer is the input to the current layer and is Fourier transformed using rfft2. Since the input is real-valued, its Fourier transform is Hermitian-symmetric, so roughly half of the coefficients along the last dimension are redundant and it is more efficient to use rfft2 instead of fft2. Therefore, $\mathcal{F}(\mathbf{v}) \in \mathbb{C}^{d_v \times M \times (N/2 + 1)}$, and correspondingly the result of the spectral operation (equation (2)) is also in $\mathbb{C}^{d_v \times M \times (N/2 + 1)}$ (note that fewer coefficients, and hence fewer parameters, are needed).
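The shape and symmetry claims can be checked directly. The sketch below uses NumPy's `np.fft.rfft2`, `np.fft.fft2`, and `np.fft.irfft2` on a single real-valued channel with arbitrary small sizes; the assumption is that the `torch.fft` functions used in the layer follow the same layout.

```python
import numpy as np

rng = np.random.default_rng(0)
M, N = 8, 10
v = rng.standard_normal((M, N))   # one real-valued channel

full = np.fft.fft2(v)             # shape (M, N)
half = np.fft.rfft2(v)            # shape (M, N // 2 + 1)
print(full.shape, half.shape)

# rfft2 keeps only the non-negative frequencies along the last axis
print(np.allclose(half, full[:, : N // 2 + 1]))

# Hermitian symmetry of the full transform: F[-p, -q] == conj(F[p, q])
p, q = np.meshgrid(np.arange(M), np.arange(N), indexing="ij")
print(np.allclose(full[(-p) % M, (-q) % N], np.conj(full)))

# Nothing is lost: the half spectrum reconstructs the input
print(np.allclose(np.fft.irfft2(half, s=(M, N)), v))
```

Note that only the last axis is halved. The first axis still carries both positive and negative frequencies.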

Figure: Truncating Fourier modes using rfft2. The white boxes show the boundary for truncation. (Right) The top half represents positive modes, while the bottom half represents negative modes.

Now, the trick is in understanding that during truncation, we have to consider both the positive and negative Fourier modes. Our actual range of truncation is $[-p_{\text{max}}, p_{\text{max}}]$. If the input domain is discretized into $M$ and $N$ points along the respective dimensions, we truncate the Fourier modes as v_ft[:, :, :self.modes1, :self.modes2] (positive modes) and v_ft[:, :, -self.modes1:, :self.modes2] (negative modes). Only the first transformed dimension needs both slices, because rfft2 already discards the negative modes along the last dimension. This can be verified from the figure above. Therefore, in practice, instead of a single matrix $R$, we split it into two, $R_1, R_2 \in \mathbb{C}^{p_{\text{max}} \times q_{\text{max}} \times d_v \times d_v}$, corresponding to the positive and negative Fourier modes respectively (weights1 and weights2 in the code). This results in a more efficient implementation that is extensible to higher dimensions as well[6].
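Here is a small NumPy sketch of the truncation on its own, without any learned weights (equivalent to setting every per-mode weight to the identity, which is an assumption made purely for illustration). It shows where the positive and negative modes sit in the rfft2 layout, and that dropping the second slice loses a low-frequency component.

```python
import numpy as np

M, N = 16, 16
modes1, modes2 = 4, 4

# Integer wavenumbers stored at each index of the rfft2 output
p = np.fft.fftfreq(M, d=1.0 / M).astype(int)    # first axis: both signs
q = np.fft.rfftfreq(N, d=1.0 / N).astype(int)   # last axis: non-negative only
print(p[:modes1], p[-modes1:], q[:modes2])

x = np.arange(M)[:, None] / M
y = np.arange(N)[None, :] / N
# Two low-frequency waves (p = 2 and p = -3) plus one high-frequency wave (p = 6)
low = np.cos(2 * np.pi * (2 * x + 1 * y)) + np.cos(2 * np.pi * (-3 * x + 2 * y))
high = np.cos(2 * np.pi * (6 * x + 1 * y))
v_ft = np.fft.rfft2(low + high)

# Keep both corners, as in the forward pass of the layer
out_ft = np.zeros_like(v_ft)
out_ft[:modes1, :modes2] = v_ft[:modes1, :modes2]      # positive modes
out_ft[-modes1:, :modes2] = v_ft[-modes1:, :modes2]    # negative modes
filtered = np.fft.irfft2(out_ft, s=(M, N))
print(np.allclose(filtered, low))    # high-frequency wave removed

# Keeping only the first corner drops the p = -3 wave as well
only_pos_ft = np.zeros_like(v_ft)
only_pos_ft[:modes1, :modes2] = v_ft[:modes1, :modes2]
only_pos = np.fft.irfft2(only_pos_ft, s=(M, N))
print(np.allclose(only_pos, low))
```

One detail worth noticing: with this slicing the first corner covers wavenumbers 0 to modes1 - 1 and the second covers -modes1 to -1, so the retained band is not perfectly symmetric around zero.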

