
Add Tensorization Example Applied to Battery Thermal Analysis #19

Open
jonahweiss wants to merge 1 commit into matlab-deep-learning:main from jonahweiss:feature/tfno_example

Conversation

@jonahweiss

The example is a live script, tensorizedFourierNeuralOperatorForBatteryCoolingAnalysis.m, which applies the method from the paper Multi-Grid Tensorized Fourier Neural Operator for High-Resolution PDEs to the Battery Heat Analysis example.

Once the support files containing pregenerated simulation data are live, the URL variable pregeneratedSimulationDataURL in the example will need to be set; the function downloadSimulationData.m can then download and unzip the data from that URL.

The tfno/ folder includes the implementation of the TFNO 3D model.

The lossFunctions/ folder includes the implementation of the relative H1 loss.

The trainingPartitions.m and createBatteryModuleGeometry.m functions are taken from the existing Battery Heat Analysis example.

@@ -0,0 +1,166 @@
function H1 = h1Norm(X, params)

How are these functions called, given they're in a sub-directory? You'd either have to change directory or call addpath, right?

Personally I prefer using a namespace +lossFunctions so you can call everything like lossFunctions.h1Norm from the base directory of this example. Maybe something more standard is to just put everything in the base directory of the example - that might be more or less what doc examples do when you use the openExample command.

% X = randn(B,C,S1,S2);
% H1 = h1Norm(X);
%
% Copyright 2026 The MathWorks, Inc.

We tend to separate the copyright from the m-help so it isn't displayed in help(h1Norm).

Comment on lines +35 to +37
% Input X must be a numeric array of size [B, C, S1, S2, ..., SD]
% where B is batch size, C is number of channels, and S1...SD are
% spatial dimensions.

Why BC(S..S)? That seems more like PyTorch's layout, whereas dlarray defaults to "SSCB" ordering when using labels.
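For illustration, the two layouts differ only by an axis permutation; a NumPy sketch (shapes here are made up) of going from the PyTorch-style B,C,S1,S2 ordering to the dlarray-style S1,S2,C,B ordering:

```python
import numpy as np

# Hypothetical shapes: PyTorch-style layout is (B, C, S1, S2);
# dlarray's default "SSCB" label ordering puts them as (S1, S2, C, B).
B, C, S1, S2 = 2, 3, 4, 5
x_bcss = np.zeros((B, C, S1, S2))
x_sscb = np.transpose(x_bcss, (2, 3, 1, 0))  # reorder axes to S1, S2, C, B
```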

params.Spacings = ones(1, D);
else
if numel(params.Spacings) ~= D
error('params.Spacings must have length equal to the number of spatial dimensions (D).');

We probably wouldn't include params here - it's a variable name in the implementation, not something the user is aware of without looking into the implementation.

dm = 1 + d; % Dimension index of this spatial axis in reshaped X.

% Central difference with wrap.
fd = (circshift(X, -1, dm) - circshift(X, 1, dm)) / (2 * delta);

Just a warning that circshift isn't a dlarray method, so the way it supports dlarray functionality like dlgradient and dlaccelerate is that we trace the dlarrays through the circshift implementation - if that implementation happens to use only dlarray-compatible methods and patterns, things should work out.

I expect you need dlgradient for a loss function, and dlaccelerate would be beneficial. A couple of reasons to be cautious with functions that aren't explicitly dlarray methods but work through this "tracing" approach:

  1. There are many codepaths underlying circshift and other functions - you'd need to verify that all of those are dlarray compatible code, or ensure that you only ever go down codepaths that are.

  2. Since circshift isn't a dlarray method, there's no reason it couldn't be replaced in a future release by a C/C++ built-in which would not support dlgradient or dlaccelerate - I wouldn't expect us to have internal tests that would catch this, because circshift isn't a dlarray method and we can't reasonably guarantee that every function in MATLAB that supports dlarray through tracing will always support it.
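As a quick sanity check of the numerical pattern itself (not of the dlarray tracing concern), the same periodic central difference can be written with np.roll, whose shift directions match circshift's:

```python
import numpy as np

# Periodic central difference, mirroring
# fd = (circshift(X, -1, dm) - circshift(X, 1, dm)) / (2 * delta):
def central_diff_periodic(x, delta, axis):
    return (np.roll(x, -1, axis=axis) - np.roll(x, 1, axis=axis)) / (2 * delta)

n = 256
delta = 2 * np.pi / n
t = np.arange(n) * delta
df = central_diff_periodic(np.sin(t), delta, axis=0)
max_err = np.max(np.abs(df - np.cos(t)))  # second-order accurate, O(delta^2)
```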

Dim = finddim(X, dim);
permuteOrder = [Dim setdiff(1:ndims(X), Dim, 'stable')];
X = permute(stripdims(X), permuteOrder);
X = dlarray(X, fmt);

Does it matter if the format still makes sense here - e.g. x = dlarray(rand(5,4),"CB"); y = permuteDimFirst(x,"B") will re-label x's batch dim as y's channel dim.

I think if you need the dimensions in a particular layout, it's probably best to just work without format labels for as long as that's needed, since the dlarray label auto-permutes are always going to fight back against non-default layouts. If you still need dlarray methods when you don't have format labels, most methods that require labelled data should also have something like a DataFormat name-value pair.

SquareRoot=params.SquareRoot, ...
Periodic=params.Periodic);

loss = num./(den + eps);

We sometimes make this eps settable, e.g. layernorm has an Epsilon name-value pair - I suppose because eps can still be very small and num./(den + eps) is only bounded above by about num*4.5e15, since eps is about 2.2e-16.
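A small numeric illustration of why a settable epsilon helps (the name Epsilon below is borrowed from layernorm's name-value pair; the values are made up):

```python
import numpy as np

# Machine eps is about 2.2e-16, so num / (den + eps) is only bounded
# by num / eps, roughly num * 4.5e15, when den vanishes.
eps = np.finfo(np.float64).eps
num, den = 1.0, 0.0
ratio = num / (den + eps)        # explodes when den is zero

# A user-settable Epsilon gives a much tighter, controllable bound.
epsilon = 1e-6
bounded = num / (den + epsilon)  # at most num / epsilon = 1e6
```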

@@ -0,0 +1,73 @@
classdef depthwiseConv3dLayer < nnet.layer.Layer & ...

Could this be convolution3dLayer(1,numChannels) and convolution3dLayer(1,numChannels,BiasLearnRateFactor=0) when UseBias==false?

assertValidNumConvolutionDimensions(3, hasTimeDimension, numSpatialDimensions);

% Check the input data has a channel dimension
assertInputHasChannelDim(1, cdim);

assertInputHasChannelDim(3,cdim)


% Same initialization as convolution2Dlayer, from
% /matlab/toolbox/nnet/cnn/+nnet/+internal/+cnn/+layer/+learnable/+initializer/Normal.m
layer.Weight = dlarray(randn(weightSize), layout.Format) * 0.01;

This is the "narrow-normal" weight initializer, but the default for conv layers is Glorot/Xavier initialization which uses uniform random + a scale factor.
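For reference, a sketch of the Glorot/Xavier uniform rule next to the narrow-normal one in the diff (channel counts are made up; for a 1x1x1 conv, fan-in/fan-out are effectively the channel counts):

```python
import numpy as np

rng = np.random.default_rng(0)
fan_in, fan_out = 64, 128

# Glorot/Xavier uniform: U(-b, b) with b = sqrt(6 / (fan_in + fan_out)),
# which gives weight variance 2 / (fan_in + fan_out).
bound = np.sqrt(6.0 / (fan_in + fan_out))
w_glorot = rng.uniform(-bound, bound, size=(fan_out, fan_in))

# Narrow-normal, as in the diff: zero-mean Gaussian with std 0.01.
w_narrow = 0.01 * rng.standard_normal((fan_out, fan_in))
```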

% /matlab/toolbox/nnet/cnn/+nnet/+internal/+cnn/+layer/+learnable/+initializer/Normal.m
layer.Weight = dlarray(randn(weightSize), layout.Format) * 0.01;
if layer.UseBias
layer.Bias = dlarray(zeros(weightSize), layout.Format);

Most built-in layers initialize weights as single, since most dlnetwork work happens in single by default.

layerNormalizationLayer(Name="ln1"), ...
additionLayer(2, Name="add1"), ...
geluLayer(Name="gelu1"), ...
convolution3dLayer(1, latentChannelSize * args.MLPExpansion, Name="channelMLP1"), ...

You should use something like ceil(latentChannelSize * args.MLPExpansion) or floor or round here.

net = connectLayers(net, "channelSkip", "add2/in2");
else
net = connectLayers(net, "in", "add2/in2");
end

Could LinearFNOSkip and ChannelMLPSkip be merged into a SkipConnectionMode = ["identity","linear"]? That would miss the option of using "linear" for just one of the skips, but I don't expect that's common.

@@ -0,0 +1,30 @@
function [pos,neg] = iPositiveAndNegativeFrequencies(N)

The i prefix convention is for internal functions, i.e. local functions defined inside another function/class file.
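On the function's behavior itself, a cross-check of the positive/negative frequency split against NumPy's FFT bin ordering (illustration only; the helper's exact convention for the Nyquist term may differ):

```python
import numpy as np

# For even N, FFT bin frequencies are ordered [0, 1, ..., N/2-1, -N/2, ..., -1].
N = 8
k = np.fft.fftfreq(N, d=1.0 / N)  # integer-valued frequencies for this spacing
pos = k[k >= 0]
neg = k[k < 0]
```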

@@ -0,0 +1,57 @@
classdef spatialEmbeddingLayer3D < nnet.layer.Layer & ...
nnet.layer.Formattable & nnet.layer.Acceleratable %#codegen
%SPATIALEMBEDDINGLAYER3D - 3D spatial embedding layer.

I've seen this called grid embedding elsewhere - I'd probably include that phrase somewhere.

function layer = spatialEmbeddingLayer3D(spatialLimits, args)
arguments
spatialLimits (3, 2) double
args.Name (1, 1) string = "depthwiseConv"

The default name should be something else.

S3 = linspace(layer.SpatialLimits(3, 1), ...
layer.SpatialLimits(3, 2), sSize(3));

[embedding1, embedding2, embedding3] = meshgrid(S1, S2, S3);

I think all this should be done once at data preprocessing/feature engineering rather than on every iteration.

You could also compute the embedding once at initialize time and store it as a property to be returned from predict. The repmat over batch dimension would have to happen in predict though.
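A sketch of the precompute-once idea in NumPy terms (limits and sizes are made up): build the coordinate grid a single time, then broadcast over the batch dimension instead of rebuilding it every iteration.

```python
import numpy as np

# Build the coordinate grid once, outside the training loop.
limits = [(0.0, 1.0), (0.0, 2.0), (0.0, 3.0)]   # assumed spatial limits
sizes = (4, 5, 6)
axes = [np.linspace(lo, hi, n) for (lo, hi), n in zip(limits, sizes)]
g1, g2, g3 = np.meshgrid(*axes, indexing="ij")  # "ij" matches MATLAB's ndgrid
grid = np.stack([g1, g2, g3], axis=-1)          # (S1, S2, S3, 3), computed once

# Per batch, broadcasting is a cheap view rather than a repmat-style copy.
batch_size = 8
batched = np.broadcast_to(grid, (batch_size,) + grid.shape)
```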

% creates a spectral convolution 3d layer. outChannels
% specifies the number of channels in the layer output.
% numModes specifies the number of modes which are combined
% in Fourier space for each of the 2 spatial dimensions.

3 spatial dimensions

@@ -0,0 +1,63 @@
# Tensorized Fourier Neural Operator for 3D Battery Heat Analysis

This example builds on the [Fourier Neural Operator for 3D Battery Heat Analysis](https://github.com/matlab-deep-learning/SciML-and-Physics-Informed-Machine-Learning-Examples/tree/main/battery-module-cooling-analysis-with-fourier-neural-operator) example, applying a Tensorized Fourier Neural Operator (TFNO) [1, 2] to heat analysis of a 3D battery module. The TFNO compresses the standard Fourier Neural Operator using tensorization, achieving a 14.3x parameter reduction while maintaining accuracy.

Might be able to link the example with a relative path in the repo.

Comment on lines +5 to +6
![](./images/prediction_vs_gt.png)
![](./images/absolute_error.png)

It's worth adding alt-text descriptions for these.

@@ -0,0 +1,100 @@
function [geomModule, domainIDs, boundaryIDs, volume, boundaryArea, ReferencePoint] = createBatteryModuleGeometry(numCellsInModule, cellWidth,cellThickness,tabThickness,tabWidth,cellHeight,tabHeight, connectorHeight )

There might already be an open discussion or issue to handle shared helpers on this repo, so we don't have duplicate implementations of things like this. It's probably best to keep the example self contained as it is currently for now, since we haven't fixed on one solution.

Comment on lines +53 to +54
convolution3dLayer(1, liftingChannels, Name="lifting1"), ...
convolution3dLayer(1, latentChannelSize, Name="lifting2")];

Should there be a nonlinearity between these 2? It's usually a bit odd to have 2 consecutive linear layers (since you could just multiply the two weight matrices together and make it 1 linear layer) and I don't see "lifting1" connected to anything else that would require it to be split up like this.

Xout = zeros([N1,N2,N3,this.OutputSize,size(X,5)],like=X);
Xout(xFreq,yFreq,zFreq,:,:) = X;

% Make Xout conjugate symmetric.

We could add a bit more detail to this comment and say it's so the ifft output is real valued, and reference the Algorithms section of the ifftn doc page.
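To make the point concrete, a NumPy sanity check (not the example's code): the inverse FFT is real-valued, up to round-off, exactly when the spectrum is conjugate symmetric, i.e. Y[k] equals conj(Y[-k mod N]) in every dimension.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal((4, 4, 4))

# The FFT of real data is conjugate symmetric, so the round trip is real.
Y = np.fft.fftn(x)
max_imag = np.max(np.abs(np.fft.ifftn(Y).imag))   # ~ machine precision

# Breaking the symmetry in a single bin makes the inverse complex.
Y[1, 0, 0] += 1j
broken_imag = np.max(np.abs(np.fft.ifftn(Y).imag))
```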
