@@ -0,0 +1,63 @@
# Tensorized Fourier Neural Operator for 3D Battery Heat Analysis

This example builds on the [Fourier Neural Operator for 3D Battery Heat Analysis](https://github.com/matlab-deep-learning/SciML-and-Physics-Informed-Machine-Learning-Examples/tree/main/battery-module-cooling-analysis-with-fourier-neural-operator) example and applies a Tensorized Fourier Neural Operator (TFNO) [1, 2] to heat analysis of a 3D battery module. The TFNO compresses the standard Fourier Neural Operator using tensorization, achieving a 14.3x parameter reduction while maintaining accuracy.
Collaborator

Might be able to link the example with a relative path in the repo.


![](./images/prediction_vs_gt.png)
![](./images/absolute_error.png)
Comment on lines +5 to +6
Collaborator

It's worth adding alt-text descriptions for these.


## Setup

Run the example by running [`tensorizedFourierNeuralOperatorForBatteryCoolingAnalysis.m`](./tensorizedFourierNeuralOperatorForBatteryCoolingAnalysis.m).

## Requirements

Requires:
- [MATLAB](https://www.mathworks.com/products/matlab.html) (R2025a or newer)
- [Deep Learning Toolbox™](https://www.mathworks.com/products/deep-learning.html)
- [Partial Differential Equation Toolbox™](https://mathworks.com/products/pde.html)
- [Parallel Computing Toolbox™](https://mathworks.com/products/parallel-computing.html) (for training on a GPU)

## References
[1] Li, Zongyi, et al. "Fourier Neural Operator for Parametric Partial Differential Equations."
In International Conference on Learning Representations (2021). https://arxiv.org/pdf/2010.08895

[2] Kossaifi, Jean, et al. "Multi-Grid Tensorized Fourier Neural Operator for High-Resolution PDEs."
Transactions on Machine Learning Research (2024). https://arxiv.org/pdf/2310.00120

## Example Overview

This example applies a 3D Tensorized Fourier Neural Operator (TFNO) to thermal analysis of a battery module composed of 20 cells. Given initial conditions (ambient temperature, convection, heat generation) at T=0, the TFNO predicts temperature distribution at T=10 minutes.

### Architecture Modifications

The TFNO includes two key modifications from the standard FNO:
1. **Transformer-like architecture**: Adds layer normalization, MLPs, and linear skip connections
2. **Tensorized spectral convolution**: Low-rank approximation of weight tensors
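
To make the second modification concrete, the sketch below compares parameter counts for a dense weight matrix and a rank-`r` factorization of it. It is only a minimal illustration using a truncated SVD on a plain matrix: the sizes `m`, `n`, `r` and all variable names are assumptions chosen for illustration, and the factorization of the higher-order weight tensors used in this example is not reproduced here.

```matlab
% Illustrative low-rank factorization of a dense weight matrix.
% All sizes and names are hypothetical, not taken from the example code.
rng(0);                          % reproducible illustration
m = 64; n = 256; r = 4;          % assumed matrix size and rank
W = randn(m,r)*randn(r,n);       % dense matrix that is exactly rank r

[U,S,V] = svd(W,"econ");         % factor the dense weights
A = U(:,1:r)*S(1:r,1:r);         % m-by-r factor
B = V(:,1:r)';                   % r-by-n factor

denseParams = m*n;               % parameters stored by the dense matrix
factoredParams = r*(m+n);        % parameters stored by the two factors
relErr = norm(W - A*B,"fro")/norm(W,"fro");

fprintf("dense: %d, factored: %d, relative error: %.2e\n", ...
    denseParams,factoredParams,relErr);
```

Because `W` is constructed with rank `r`, the two factors reproduce it to rounding error while storing far fewer values. Trained weights are generally not exactly low rank, so in practice the rank trades accuracy against size.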

### Key Hyperparameters

- **Input channels**: 3 (ambient temperature, convection, heat generation)
- **Output channels**: 1 (temperature)
- **Number of modes**: 4 (retained Fourier modes per dimension)
- **Hidden channels**: 64
- **FNO blocks**: 4
- **Compression rank**: 0.05 (5% of original parameters in spectral layers)
- **Grid resolution**: 32×32×32
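
As a rough sketch, the hyperparameters above could be collected in a single struct so they are defined in one place. The field names and the spatial-first sample layout below are assumptions for illustration and are not taken from the example script.

```matlab
% Hypothetical container for the hyperparameters listed above.
% Field names are illustrative only.
hp = struct( ...
    InputChannels=3, ...     % ambient temperature, convection, heat generation
    OutputChannels=1, ...    % temperature
    NumModes=4, ...          % retained Fourier modes per dimension
    HiddenChannels=64, ...
    NumBlocks=4, ...
    CompressionRank=0.05, ...
    GridSize=[32 32 32]);

% Size of one input and one output sample, assuming a spatial-first layout.
inputSize  = [hp.GridSize hp.InputChannels];
outputSize = [hp.GridSize hp.OutputChannels];
```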

### Performance

- **Inference speed**: 88 ms per sample (batch size 1) on an NVIDIA RTX 2080 Ti GPU and 230 ms on an Intel Xeon CPU (136x faster than the FEM solver, 1.15x faster than the architecture from the prior [FNO example](https://github.com/matlab-deep-learning/SciML-and-Physics-Informed-Machine-Learning-Examples/tree/main/battery-module-cooling-analysis-with-fourier-neural-operator))
  - The speedup may be more pronounced on larger problem domains, higher-dimensional problems, and/or when running inference on memory-constrained devices
- **Relative L2 error**: 0.009% error on test set
- **Training time**: 5.75 hours for 1000 epochs
- **Parameter reduction**: From 3,263,809 to 227,521 parameters for a 14.35x reduction
- **Memory savings**: 2.74MB compressed model vs 23.01MB dense model
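
The reduction factors follow directly from the numbers quoted above. The short check below recomputes them; the variable names are illustrative only.

```matlab
% Recompute the reduction factors from the figures quoted in this README.
denseParams = 3263809;     % parameters in the dense FNO
tfnoParams  = 227521;      % parameters in the TFNO
paramReduction = denseParams/tfnoParams;

denseMB = 23.01;           % dense model size in MB
tfnoMB  = 2.74;            % compressed model size in MB
memoryReduction = denseMB/tfnoMB;

fprintf("Parameter reduction: %.2fx\n",paramReduction);
fprintf("Memory reduction: %.2fx\n",memoryReduction);
```

Note that the memory ratio is smaller than the parameter ratio, so the two figures should not be used interchangeably.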

### Considerations
The example here is one instance of a TFNO applied to battery thermal analysis. The TFNO can likely be optimized further with negligible accuracy loss by:
- Experimenting with more aggressive compression (e.g., a compression rank of 0.01-0.03) to achieve even greater parameter reduction
- Reducing the number of hidden channel dimensions
- Reducing the number of FNO blocks

---
Copyright 2026 The MathWorks, Inc.
@@ -0,0 +1,100 @@
function [geomModule, domainIDs, boundaryIDs, volume, boundaryArea, ReferencePoint] = createBatteryModuleGeometry(numCellsInModule, cellWidth, cellThickness, tabThickness, tabWidth, cellHeight, tabHeight, connectorHeight)
Collaborator

There might already be an open discussion or issue to handle shared helpers on this repo, so we don't have duplicate implementations of things like this. It's probably best to keep the example self contained as it is currently for now, since we haven't fixed on one solution.

%% Uses Boolean geometry functionality in PDE Toolbox, which requires release R2025a or later.
% If you have an older version, use the helper function in this example:
% https://www.mathworks.com/help/pde/ug/battery-module-cooling-analysis-and-reduced-order-thermal-model.html

% Copyright 2025 The MathWorks, Inc.

% First, create a single pouch cell by unioning the cell, tab and connector
% Cell creation
cell1 = fegeometry(multicuboid(cellThickness,cellWidth,cellHeight));
cell1 = translate(cell1,[cellThickness/2,cellWidth/2,0]);
% Tab creation
tab = fegeometry(multicuboid(tabThickness,tabWidth,tabHeight));
tabLeft = translate(tab,[cellThickness/2,tabWidth,cellHeight]);
tabRight = translate(tab,[cellThickness/2,cellWidth-tabWidth,cellHeight]);
% Union tabs to cells
geomPouch = union(cell1, tabLeft, KeepBoundaries=true);
geomPouch = union(geomPouch, tabRight, KeepBoundaries=true);
% Connector creation
overhang = (cellThickness-tabThickness)/2;
connector = fegeometry(multicuboid(tabThickness+overhang,tabWidth,connectorHeight));
connectorRight = translate(connector,[cellThickness/2+overhang/2,tabWidth,cellHeight+tabHeight]);
connectorLeft = translate(connector,[(cellThickness/2-overhang/2),cellWidth-tabWidth,cellHeight+tabHeight]);
% Union connectors to tabs
geomPouch = union(geomPouch,connectorLeft,KeepBoundaries=true);
geomPouch = union(geomPouch,connectorRight,KeepBoundaries=true);
% Scale and translate completed pouch cell to create mirrored cell
geomPouchMirrored = translate(scale(geomPouch,[-1 1 1]),[cellThickness,0,0]);
% Union individual pouches to create full module
% Union even-numbered pouch cells together (original cells)
geomForward = fegeometry;
for i = 0:2:numCellsInModule-1
    offset = cellThickness*i;
    geom_to_append = translate(geomPouch,[offset,0,0]);
    geomForward = union(geomForward,geom_to_append);
end
% Union odd-numbered pouch cells together (mirrored cells)
geomBackward = fegeometry;
for i = 1:2:numCellsInModule-1
    offset = cellThickness*i;
    geom_to_append = translate(geomPouchMirrored,[offset,0,0]);
    geomBackward = union(geomBackward,geom_to_append);
end
% Union to create completed geometry module
geomModule = union(geomForward,geomBackward,KeepBoundaries=true);
% Mirror the geometry in y and translate it back into the positive-y range
geomModule = translate(scale(geomModule,[1 -1 1]),[0 cellWidth 0]);
% Mesh the geometry to use query functions for identifying cells and faces
geomModule = generateMesh(geomModule,GeometricOrder="linear");
% Create reference points for each geometry feature
ReferencePoint.Cell = [cellThickness/2,cellWidth/2,cellHeight/2];
ReferencePoint.TabLeft = [cellThickness/2,tabWidth,cellHeight+tabHeight/2];
ReferencePoint.TabRight = [cellThickness/2,cellWidth-tabWidth,cellHeight+tabHeight/2];
ReferencePoint.ConnectorLeft = [cellThickness/2,tabWidth,cellHeight+tabHeight+connectorHeight/2];
ReferencePoint.ConnectorRight = [cellThickness/2,cellWidth-tabWidth,cellHeight+tabHeight+connectorHeight/2];
% Helper function to get the cell IDs belonging to cell, tab and connector
[~,~,t] = meshToPet(geomModule.Mesh);
elementDomain = t(end,:);
tr = triangulation(geomModule.Mesh.Elements',geomModule.Mesh.Nodes');
getCellID = @(point,cellNumber) elementDomain(pointLocation(tr,point+(cellNumber(:)-1)*[cellThickness,0,0]));
% Helper function to get the volume of the cells, tabs, and connectors
getVolumeOneCell = @(geomCellID) geomModule.Mesh.volume(findElements(geomModule.Mesh,"region",Cell=geomCellID));
getVolume = @(geomCellIDs) arrayfun(@(n) getVolumeOneCell(n),geomCellIDs);
% Initialize cell ID and volume structs
domainIDs(1:numCellsInModule) = struct(Cell=[], ...
TabLeft=[],TabRight=[], ...
ConnectorLeft=[],ConnectorRight=[]);
volume(1:numCellsInModule) = struct(Cell=[], ...
TabLeft=[],TabRight=[], ...
ConnectorLeft=[],ConnectorRight=[]);
% Helper function to get the IDs belonging to the left, right, front, back, top and bottom faces
getFaceID = @(offsetVal,offsetDirection,cellNumber) nearestFace(geomModule,...
ReferencePoint.Cell + offsetVal/2 .*offsetDirection ... % offset ref. point to face
+ cellThickness*(cellNumber(:)-1)*[1,0,0]); % offset to cell
% Initialize face ID and area structs
boundaryIDs(1:numCellsInModule) = struct(FrontFace=[],BackFace=[], ...
RightFace=[],LeftFace=[], ...
TopFace=[],BottomFace=[]);
boundaryArea(1:numCellsInModule) = struct(FrontFace=[],BackFace=[], ...
RightFace=[],LeftFace=[], ...
TopFace=[],BottomFace=[]);
% Loop over cell, left tab, right tab, left connector, and right connector to get cell IDs and volumes
for part = string(fieldnames(domainIDs))'
    partid = num2cell(getCellID(ReferencePoint.(part),1:numCellsInModule));
    [domainIDs.(part)] = partid{:};
    volumesPart = num2cell(getVolume([partid{:}]));
    [volume.(part)] = volumesPart{:};
end
% Loop over front, back, right, left, top, and bottom faces IDs and areas
dimensions = [cellThickness;cellThickness;cellWidth;cellWidth;cellHeight;cellHeight];
vectors = [-1,0,0;1,0,0;0,1,0;0,-1,0;0,0,1;0,0,-1];
areaFormula = [cellHeight*cellWidth; ...
    cellHeight*cellWidth; ...
    cellThickness*cellHeight; ...
    cellThickness*cellHeight; ...
    cellThickness*cellWidth - tabThickness*tabWidth; ...
    cellThickness*cellWidth - tabThickness*tabWidth];
i = 1;
for face = string(fieldnames(boundaryIDs))'
    faceid = num2cell(getFaceID(dimensions(i),vectors(i,:),1:numCellsInModule));
    [boundaryIDs.(face)] = faceid{:};
    areasFace = num2cell(areaFormula(i)*ones(1,numCellsInModule));
    [boundaryArea.(face)] = areasFace{:};
    i = i+1;
end
@@ -0,0 +1,28 @@
function downloadSimulationData(url,destination)
% The downloadSimulationData function downloads pregenerated simulation
% data for the 3D battery heat analysis problem.

% Copyright 2026 The MathWorks, Inc.

if ~exist(destination,"dir")
    mkdir(destination);
end

[~,name,filetype] = fileparts(url);
% Convert to string so that + concatenates even if url is a char vector.
netFileFullPath = fullfile(destination,string(name)+filetype);

% Check for the existence of the file and download the file if it does not
% exist
if ~exist(netFileFullPath,"file")
    disp("Downloading simulation data.");
    disp("This can take several minutes to download...");
    websave(netFileFullPath,url);

    % If the file is a ZIP file, extract it
    if filetype == ".zip"
        unzip(netFileFullPath,destination)
    end
    disp("Done.");

end
end
@@ -0,0 +1,166 @@
function H1 = h1Norm(X, params)
Collaborator

How are these functions called, given they're in a sub-directory? You'd either have to change directory or addpath right?

Personally I prefer using a namespace +lossFunctions so you can call everything like lossFunctions.h1Norm from the base directory of this example. Maybe something more standard is to just put everything in the base directory of the example - that might be more or less what doc examples do when you use the openExample command.

%H1NORM - Compute H1 norm on a grid.
% H1 = H1NORM(X) computes the H1 norm of the input array X
% with default parameters.
%
% H1 = H1NORM(X, Name=Value) specifies additional options using
% one or more name-value arguments:
%
% Spacings - 1xD vector of grid spacings [Δ1, Δ2, ..., ΔD].
% The default value is ones(1,D).
%
% IncludeL2 - If true, computes full H1 norm (L2 + gradient).
% If false, computes seminorm only (gradient).
% The default value is true.
%
% Reduction - Method for reducing the norm across batch.
% Options are 'mean', 'sum', or 'none'.
% The default value is 'mean'.
%
% Periodic - 1xD logical array indicating which spatial
% dimensions are periodic. The default value
% is true for all dimensions.
%
% SquareRoot - If false, returns the squared H1 norm.
% If true, returns the H1 norm. The default
% value is false.
%
% Normalize - If true, divides output by C*prod(S1, S2, ...).
% The default value is false.
%
% The H1 norm is defined as:
% ||u||_{H^1} = (||u||_{L^2}^2 + ||∇u||_{L^2}^2)^{1/2}
% where ||∇u||_{L^2}^2 = Σ_i ||∂u/∂x_i||_{L^2}^2.
%
% Input X must be a numeric array of size [B, C, S1, S2, ..., SD]
% where B is batch size, C is number of channels, and S1...SD are
% spatial dimensions.
Comment on lines +35 to +37
Collaborator

Why BC(S..S)? That seems more like PyTorch's layout, whereas dlarray default orders to "SSCB" when using labels.
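
For reference, converting between the two layouts is a single `permute`. The following is a minimal sketch on a plain numeric array with arbitrary, assumed sizes; it does not cover how labeled dlarray inputs would be handled.

```matlab
% Convert a spatial-first array [S1 S2 C B] to batch-first [B C S1 S2].
% Sizes are arbitrary and chosen only for illustration.
S1 = 8; S2 = 6; C = 2; B = 5;
Xsscb = rand(S1,S2,C,B);

Xbcss = permute(Xsscb,[4 3 1 2]);   % [B C S1 S2]
Xback = permute(Xbcss,[3 4 2 1]);   % back to [S1 S2 C B]
```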

%
% Gradients are estimated using central differences, with third-order
% one-sided differences at the boundaries of non-periodic dimensions.
%
% Example:
% B=2; C=1; S1=64; S2=64;
% X = randn(B,C,S1,S2);
% H1 = h1Norm(X);
%
% Copyright 2026 The MathWorks, Inc.
Collaborator

We tend to separate the copyright from the m-help so it isn't displayed in help(h1Norm).


arguments
    X dlarray {mustBeNumeric}
    params.Spacings (1,:) double = []
    params.IncludeL2 (1,1) logical = true
    params.Reduction (1,1) string {mustBeMember(params.Reduction, {'mean', 'sum', 'none'})} = "mean"
    params.Periodic (1,:) logical = true
    params.SquareRoot (1,1) logical = false
    params.Normalize (1,1) logical = false
end

sz = size(X);
nd = ndims(X);
if nd < 3
    error('Input must be at least [B, C, S1].');
end
B = sz(1);
C = sz(2);
spatialSizes = sz(3:end);
D = numel(spatialSizes);

if isempty(params.Spacings)
    params.Spacings = ones(1, D);
else
    if numel(params.Spacings) ~= D
        error('params.Spacings must have length equal to the number of spatial dimensions (D).');
Collaborator

We probably wouldn't include params here - it's a variable name in the implementation, not something the user is aware of without looking into the implementation.

    end
end

if isscalar(params.Periodic)
    params.Periodic = repmat(params.Periodic, 1, D);
elseif numel(params.Periodic) ~= D
    error('params.Periodic must be scalar or 1xD logical.');
end

% Initialize H1 with the squared L2 norm (or zeros for the seminorm only).
if params.IncludeL2
    H1 = l2Norm(X, Reduction="none", SquareRoot=false, Normalize=false);
else
    H1 = zeros(B, 1, 'like', X);
end

% Reshape to [B*C, S1, S2, ..., SD] so that all batch, channel
% combinations are handled independently.
X = reshape(X, [B*C spatialSizes]);

% Add the squared H1 seminorm using central differences.
for d = 1:D
    delta = params.Spacings(d);

    dm = 1 + d; % Dimension index of this spatial axis in reshaped X.

    % Central difference with wrap.
    fd = (circshift(X, -1, dm) - circshift(X, 1, dm)) / (2 * delta);
Collaborator

Just a warning that circshift isn't a dlarray method, so the way it supports dlarray functionality like dlgradient and dlaccelerate is that we trace the dlarray-s through the circshift implementation - if that implementation happens to use only dlarray compatible methods and patterns, things should work out.

I expect you need dlgradient, and dlaccelerate would be beneficial for a loss function. There are a couple of reasons to be cautious with functions that aren't explicitly dlarray methods but work through this "tracing" approach:

  1. There are many codepaths underlying circshift and other functions - you'd need to verify that all of those are dlarray compatible code, or ensure that you only ever go down codepaths that are.

  2. Since circshift isn't a dlarray method, there's no reason it couldn't be replaced in a future release by a C/C++ built-in that would not support dlgradient or dlaccelerate. I wouldn't expect us to have internal tests that would catch this, because circshift isn't a dlarray method and we can't reasonably say that every function in MATLAB that supports dlarray through tracing should always support it in future.
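
One possible way to avoid the dependency, sketched here as an assumption rather than a tested fix, is to express the periodic shift with plain indexing, which the boundary helper in this file already relies on. The snippet below uses an ordinary numeric array and only checks that the indexed version matches the circshift version; whether it behaves as intended under dlgradient and dlaccelerate would still need to be verified.

```matlab
% Periodic central difference along dimension dm using indexing only.
% Array contents, dm, and delta are arbitrary illustration values.
X = reshape(1:24,[2 3 4]);
dm = 2;
delta = 0.5;
n = size(X,dm);

idx = repmat({':'},1,ndims(X));
idxNext = idx; idxNext{dm} = [2:n 1];      % same as circshift(X,-1,dm)
idxPrev = idx; idxPrev{dm} = [n 1:n-1];    % same as circshift(X,1,dm)

fdIndexed = (X(idxNext{:}) - X(idxPrev{:})) / (2*delta);
fdShifted = (circshift(X,-1,dm) - circshift(X,1,dm)) / (2*delta);
```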


    if ~params.Periodic(d)
        % Replace first/last elements with forward/backward differences.

        % Only the current non-periodic dimension needs at least 4 points.
        if spatialSizes(d) < 4
            error("Non-periodic dimensions require at least 4 grid points for 3rd-order differences.");
        end

        fd = applyThirdOrderDifferenceAtBoundary(fd, X, dm, delta);
    end

    fd = fd.^2;

    % Reshape back to original size.
    fd = reshape(fd, sz);

    % Sum over channels and spatial dimensions, giving size of [B, 1].
    fd = sum(fd, 2:nd);

    % Accumulate per-batch sum.
    H1 = H1 + fd;
end

if params.SquareRoot
    H1 = sqrt(H1);
end

if params.Normalize
    % Normalize by channels and number of spatial points
    H1 = H1 / (C * prod(spatialSizes));
end

if strcmp(params.Reduction, "mean")
    H1 = mean(H1, 1);
elseif strcmp(params.Reduction, "sum")
    H1 = sum(H1, 1);
end
end

function fd = applyThirdOrderDifferenceAtBoundary(fd, X, d, delta)

% Get the indices of components for 3rd-order forward differences.
idx1 = makeIndex(ndims(fd), d, 1);
idx2 = makeIndex(ndims(fd), d, 2);
idx3 = makeIndex(ndims(fd), d, 3);
idx4 = makeIndex(ndims(fd), d, 4);

% Apply 3rd-order forward differences at left boundary.
fd(idx1{:}) = (-11*X(idx1{:}) + 18*X(idx2{:}) - 9*X(idx3{:}) + 2*X(idx4{:})) / (6 * delta);

% Get the indices of components for 3rd-order backward differences.
sz = size(fd, d);
idx1 = makeIndex(ndims(fd), d, sz);
idx2 = makeIndex(ndims(fd), d, sz-1);
idx3 = makeIndex(ndims(fd), d, sz-2);
idx4 = makeIndex(ndims(fd), d, sz-3);

% Apply 3rd-order backward differences at right boundary.
fd(idx1{:}) = (-2*X(idx4{:}) + 9*X(idx3{:}) - 18*X(idx2{:}) + 11*X(idx1{:})) / (6 * delta);
end

function idx = makeIndex(ndims, toChange, val)
% Build an index cell that selects everything except along toChange.
idx = repmat({':'}, 1, ndims);
idx{toChange} = val;
end