-
Notifications
You must be signed in to change notification settings - Fork 47
Add Tensorization Example Applied to Battery Thermal Analysis #19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| # Tensorized Fourier Neural Operator for 3D Battery Heat Analysis | ||
|
|
||
| This example builds off of the [Fourier Neural Operator for 3D Battery Heat Analysis](https://github.com/matlab-deep-learning/SciML-and-Physics-Informed-Machine-Learning-Examples/tree/main/battery-module-cooling-analysis-with-fourier-neural-operator) example to apply a Tensorized Fourier Neural Operator (TFNO) [1, 2] to heat analysis of a 3D battery module. The TFNO compresses the standard Fourier Neural Operator using tensorization, achieving 14.3x parameter reduction while maintaining accuracy. | ||
|
|
||
|  | ||
|  | ||
|
Comment on lines
+5
to
+6
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's worth adding alt-text descriptions for these. |
||
|
|
||
| ## Setup | ||
|
|
||
| Run the example by running [`tensorizedFourierNeuralOperatorForBatteryCoolingAnalysis.m`](./tensorizedFourierNeuralOperatorForBatteryCoolingAnalysis.m). | ||
|
|
||
| ## Requirements | ||
|
|
||
| Requires: | ||
| - [MATLAB](https://www.mathworks.com/products/matlab.html) (R2025a or newer) | ||
| - [Deep Learning Toolbox™](https://www.mathworks.com/products/deep-learning.html) | ||
| - [Partial Differential Equation Toolbox™](https://mathworks.com/products/pde.html) | ||
| - [Parallel Computing Toolbox™](https://mathworks.com/products/parallel-computing.html) (for training on a GPU) | ||
|
|
||
| ## References | ||
| [1] Li, Zongyi, et al. "Fourier Neural Operator for Parametric Partial Differential Equations." | ||
| In International Conference on Learning Representations (2021). https://arxiv.org/pdf/2010.08895 | ||
|
|
||
| [2] Kossaifi, Jean, et al. Kossaifi, Jean, et al. "Multi-Grid Tensorized Fourier Neural Operator for High-Resolution PDEs." | ||
| Transactions on Machine Learning Research (2024). https://arxiv.org/pdf/2310.00120 | ||
|
|
||
| ## Example Overview | ||
|
|
||
| This example applies a 3D Tensorized Fourier Neural Operator (TFNO) to thermal analysis of a battery module composed of 20 cells. Given initial conditions (ambient temperature, convection, heat generation) at T=0, the TFNO predicts temperature distribution at T=10 minutes. | ||
|
|
||
| ### Architecture Modifications | ||
|
|
||
| The TFNO includes two key modifications from the standard FNO: | ||
| 1. **Transformer-like architecture**: Adds layer normalization, MLPs, and linear skip connections | ||
| 2. **Tensorized spectral convolution**: Low-rank approximation of weight tensors | ||
|
|
||
| ### Key Hyperparameters | ||
|
|
||
| - **Input channels**: 3 (ambient temperature, convection, heat generation) | ||
| - **Output channels**: 1 (temperature) | ||
| - **Number of modes**: 4 (retained Fourier modes per dimension) | ||
| - **Hidden channels**: 64 | ||
| - **FNO blocks**: 4 | ||
| - **Compression rank**: 0.05 (5% of original parameters in spectral layers) | ||
| - **Grid resolution**: 32×32×32 | ||
|
|
||
| ### Performance | ||
|
|
||
| - **Inference speed**: 88ms per sample (batch size 1) on NVIDIA RTX 2080 Ti GPU and 230ms on Intel Xeon CPU (136x faster than FEM solver, 1.15x faster than the architecture from the prior [FNO example](https://github.com/matlab-deep-learning/SciML-and-Physics-Informed-Machine-Learning-Examples/tree/main/battery-module-cooling-analysis-with-fourier-neural-operator)) | ||
| - The speedup may be more pronounced on larger problem domains, higher dimensional problems, and/or when running inference on memory -constrained devices | ||
| - **Relative L2 error**: 0.009% error on test set | ||
| - **Training time**: 5.75 hours for 1000 epochs | ||
| - **Parameter reduction**: From 3,263,809 to 227,521 parameters for a 14.35x reduction | ||
| - **Memory savings**: 2.74MB compressed model vs 23.01MB dense model | ||
|
|
||
| ### Considerations | ||
| The example here is one instance of a TFNO applied to battery thermal analysis. It is likely that the TFNO may be further optimized with negligible accuracy loss by: | ||
| - Experimenting with higher compression ratios (e.g., 0.01-0.03) to achieve even greater parameter reduction | ||
| - Reducing the number of hidden channel dimensions | ||
| - Reducing the number of FNO blocks | ||
|
|
||
| --- | ||
| Copyright 2026 The MathWorks, Inc. | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,100 @@ | ||
| function [geomModule, domainIDs, boundaryIDs, volume, boundaryArea, ReferencePoint] = createBatteryModuleGeometry(numCellsInModule, cellWidth,cellThickness,tabThickness,tabWidth,cellHeight,tabHeight, connectorHeight ) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There might already be an open discussion or issue to handle shared helpers on this repo, so we don't have duplicate implementations of things like this. It's probably best to keep the example self contained as it is currently for now, since we haven't fixed on one solution. |
||
| %% Uses Boolean geometry functionality in PDE Toolbox, which requires release R2025a or later. | ||
| % If you have an older version, use the helper function in this example: | ||
| % https://www.mathworks.com/help/pde/ug/battery-module-cooling-analysis-and-reduced-order-thermal-model.html | ||
|
|
||
| % Copyright 2025 The MathWorks, Inc. | ||
|
|
||
| % First, create a single pouch cell by unioning the cell, tab and connector | ||
| % Cell creation | ||
| cell1 = fegeometry(multicuboid(cellThickness,cellWidth,cellHeight)); | ||
| cell1 = translate(cell1,[cellThickness/2,cellWidth/2,0]); | ||
| % Tab creation | ||
| tab = fegeometry(multicuboid(tabThickness,tabWidth,tabHeight)); | ||
| tabLeft = translate(tab,[cellThickness/2,tabWidth,cellHeight]); | ||
| tabRight = translate(tab,[cellThickness/2,cellWidth-tabWidth,cellHeight]); | ||
| % Union tabs to cells | ||
| geomPouch = union(cell1, tabLeft, KeepBoundaries=true); | ||
| geomPouch = union(geomPouch, tabRight, KeepBoundaries=true); | ||
| % Connector creation | ||
| overhang = (cellThickness-tabThickness)/2; | ||
| connector = fegeometry(multicuboid(tabThickness+overhang,tabWidth,connectorHeight)); | ||
| connectorRight = translate(connector,[cellThickness/2+overhang/2,tabWidth,cellHeight+tabHeight]); | ||
| connectorLeft = translate(connector,[(cellThickness/2-overhang/2),cellWidth-tabWidth,cellHeight+tabHeight]); | ||
| % Union connectors to tabs | ||
| geomPouch = union(geomPouch,connectorLeft,KeepBoundaries=true); | ||
| geomPouch = union(geomPouch,connectorRight,KeepBoundaries=true); | ||
| % Scale and translate completed pouch cell to create mirrored cell | ||
| geomPouchMirrored = translate(scale(geomPouch,[-1 1 1]),[cellThickness,0,0]); | ||
| % Union individual pouches to create full module | ||
| % Union even-numbered pouch cells together (original cells) | ||
| geomForward = fegeometry; | ||
| for i = 0:2:numCellsInModule-1 | ||
| offset = cellThickness*i; | ||
| geom_to_append = translate(geomPouch,[offset,0,0]); | ||
| geomForward = union(geomForward,geom_to_append); | ||
| end | ||
| % Union odd-numbered pouch cells together (mirrored cells) | ||
| geomBackward = fegeometry; | ||
| for i = 1:2:numCellsInModule-1 | ||
| offset = cellThickness*i; | ||
| geom_to_append = translate(geomPouchMirrored,[offset,0,0]); | ||
| geomBackward = union(geomBackward,geom_to_append); | ||
| end | ||
| % Union to create completed geometry module | ||
| geomModule = union(geomForward,geomBackward,KeepBoundaries=true); | ||
| % Rotate and translate the geometry | ||
| geomModule = translate(scale(geomModule,[1 -1 1]),[0 cellWidth 0]); | ||
| % Mesh the geometry to use query functions for identifying cells and faces | ||
| geomModule = generateMesh(geomModule,GeometricOrder="linear"); | ||
| % Create Reference Points for each geometry future | ||
| ReferencePoint.Cell = [cellThickness/2,cellWidth/2,cellHeight/2]; | ||
| ReferencePoint.TabLeft = [cellThickness/2,tabWidth,cellHeight+tabHeight/2]; | ||
| ReferencePoint.TabRight = [cellThickness/2,cellWidth-tabWidth,cellHeight+tabHeight/2]; | ||
| ReferencePoint.ConnectorLeft = [cellThickness/2,tabWidth,cellHeight+tabHeight+connectorHeight/2]; | ||
| ReferencePoint.ConnectorRight = [cellThickness/2,cellWidth-tabWidth,cellHeight+tabHeight+connectorHeight/2]; | ||
| % Helper function to get the cell IDs belonging to cell, tab and connector | ||
| [~,~,t] = meshToPet(geomModule.Mesh); | ||
| elementDomain = t(end,:); | ||
| tr = triangulation(geomModule.Mesh.Elements',geomModule.Mesh.Nodes'); | ||
| getCellID = @(point,cellNumber) elementDomain(pointLocation(tr,point+(cellNumber(:)-1)*[cellThickness,0,0])); | ||
| % Helper function to get the volume of the cells, tabs, and connectors | ||
| getVolumeOneCell = @(geomCellID) geomModule.Mesh.volume(findElements(geomModule.Mesh,"region",Cell=geomCellID)); | ||
| getVolume = @(geomCellIDs) arrayfun(@(n) getVolumeOneCell(n),geomCellIDs); | ||
| % Initialize cell ID and volume structs | ||
| domainIDs(1:numCellsInModule) = struct(Cell=[], ... | ||
| TabLeft=[],TabRight=[], ... | ||
| ConnectorLeft=[],ConnectorRight=[]); | ||
| volume(1:numCellsInModule) = struct(Cell=[], ... | ||
| TabLeft=[],TabRight=[], ... | ||
| ConnectorLeft=[],ConnectorRight=[]); | ||
| % Helper function to get the IDs belonging to the left, right, front, back, top and bottom faces | ||
| getFaceID = @(offsetVal,offsetDirection,cellNumber) nearestFace(geomModule,... | ||
| ReferencePoint.Cell + offsetVal/2 .*offsetDirection ... % offset ref. point to face | ||
| + cellThickness*(cellNumber(:)-1)*[1,0,0]); % offset to cell | ||
| % Initialize face ID and area structs | ||
| boundaryIDs(1:numCellsInModule) = struct(FrontFace=[],BackFace=[], ... | ||
| RightFace=[],LeftFace=[], ... | ||
| TopFace=[],BottomFace=[]); | ||
| boundaryArea(1:numCellsInModule) = struct(FrontFace=[],BackFace=[], ... | ||
| RightFace=[],LeftFace=[], ... | ||
| TopFace=[],BottomFace=[]); | ||
| % Loop over cell, left tab, right tab, left connector, and right connector to get cell IDs and volumes | ||
| for part = string(fieldnames(domainIDs))' | ||
| partid = num2cell(getCellID(ReferencePoint.(part),1:numCellsInModule)); | ||
| [domainIDs.(part)] = partid{:}; | ||
| volumesPart = num2cell(getVolume([partid{:}])); | ||
| [volume.(part)] = volumesPart{:}; | ||
| end | ||
| % Loop over front, back, right, left, top, and bottom faces IDs and areas | ||
| dimensions = [cellThickness;cellThickness;cellWidth;cellWidth;cellHeight;cellHeight]; | ||
| vectors = [-1,0,0;1,0,0;0,1,0;0,-1,0;0,0,1;0,0,-1]; | ||
| areaFormula = [cellHeight*cellWidth;cellHeight*cellWidth;cellThickness*cellHeight;cellThickness*cellHeight;cellThickness*cellWidth - tabThickness*tabWidth;cellThickness*cellWidth - tabThickness*tabWidth]; | ||
| i = 1; | ||
| for face = string(fieldnames(boundaryIDs))' | ||
| faceid = num2cell(getFaceID(dimensions(i),vectors(i,:),1:numCellsInModule)); | ||
| [boundaryIDs.(face)] = faceid{:}; | ||
| areasFace = num2cell(areaFormula(i)*ones(1,numCellsInModule)); | ||
| [boundaryArea.(face)] = areasFace{:}; | ||
| i = i+1; | ||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| function downloadSimulationData(url,destination) | ||
| % The downloadSimulationData function downloads pregenerated simulation | ||
| % data for the 3D battery heat analysis problem. | ||
|
|
||
| % Copyright 2026 The MathWorks, Inc. | ||
|
|
||
| if ~exist(destination,"dir") | ||
| mkdir(destination); | ||
| end | ||
|
|
||
| [~,name,filetype] = fileparts(url); | ||
| netFileFullPath = fullfile(destination,name+filetype); | ||
|
|
||
| % Check for the existence of the file and download the file if it does not | ||
| % exist | ||
| if ~exist(netFileFullPath,"file") | ||
| disp("Downloading simulation data."); | ||
| disp("This can take several minutes to download..."); | ||
| websave(netFileFullPath,url); | ||
|
|
||
| % If the file is a ZIP file, extract it | ||
| if filetype == ".zip" | ||
| unzip(netFileFullPath,destination) | ||
| end | ||
| disp("Done."); | ||
|
|
||
| end | ||
| end |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,166 @@ | ||
| function H1 = h1Norm(X, params) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How are these functions called, given they're in a sub-directory? You'd either have to change directory or Personally I prefer using a namespace |
||
| %H1NORM - Compute H1 norm on a grid. | ||
| % H1 = H1NORM(X) computes the H1 norm of the input array X | ||
| % with default parameters. | ||
| % | ||
| % H1 = H1NORM(X, Name=Value) specifies additional options using | ||
| % one or more name-value arguments: | ||
| % | ||
| % Spacings - 1xD vector of grid spacings [Δ1, Δ2, ..., ΔD]. | ||
| % The default value is ones(1,D). | ||
| % | ||
| % IncludeL2 - If true, computes full H1 norm (L2 + gradient). | ||
| % If false, computes seminorm only (gradient). | ||
| % The default value is true. | ||
| % | ||
| % Reduction - Method for reducing the norm across batch. | ||
| % Options are 'mean', 'sum', or 'none'. | ||
| % The default value is 'mean'. | ||
| % | ||
| % Periodic - 1xD logical array indicating which spatial | ||
| % dimensions are periodic. The default value | ||
| % is true for all dimensions. | ||
| % | ||
| % SquareRoot - If false, returns the squared H1 norm. | ||
| % If true, returns the H1 norm. The default | ||
| % value is false. | ||
| % | ||
| % Normalize - If true, divides output by C*prod(S1, S2, ...). | ||
| % The default value is false. | ||
| % | ||
| % The H1 norm is defined as: | ||
| % ||u||_{H^1} = (||u||_{L^2}^2 + ||∇u||_{L^2}^2)^{1/2} | ||
| % where ||∇u||_{L^2}^2 = Σ_i ||∂u/∂x_i||_{L^2}^2. | ||
| % | ||
| % Input X must be a numeric array of size [B, C, S1, S2, ..., SD] | ||
| % where B is batch size, C is number of channels, and S1...SD are | ||
| % spatial dimensions. | ||
|
Comment on lines
+35
to
+37
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why BC(S..S)? That seems more like PyTorch's layout, whereas |
||
| % | ||
| % Gradients are estimated using central differences and one-sided | ||
| % differences at boundaries (unless periodic boundary conditions). | ||
| % | ||
| % Example: | ||
| % B=2; C=1; S1=64; S2=64; | ||
| % X = randn(B,C,S1,S2); | ||
| % H1 = h1Norm(X); | ||
| % | ||
| % Copyright 2026 The MathWorks, Inc. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We tend to separate the copyright from the m-help so it isn't displayed in |
||
|
|
||
| arguments | ||
| X dlarray {mustBeNumeric} | ||
| params.Spacings (1,:) double = [] | ||
| params.IncludeL2 (1,1) logical = true | ||
| params.Reduction (1,1) string {mustBeMember(params.Reduction, {'mean', 'sum', 'none'})} = "mean" | ||
| params.Periodic (1,:) logical = true | ||
| params.SquareRoot (1,1) logical = false | ||
| params.Normalize (1,1) logical = false | ||
| end | ||
|
|
||
| sz = size(X); | ||
| nd = ndims(X); | ||
| if nd < 3 | ||
| error('Input must be at least [B, C, S1].'); | ||
| end | ||
| B = sz(1); | ||
| C = sz(2); | ||
| spatialSizes = sz(3:end); | ||
| D = numel(spatialSizes); | ||
|
|
||
| if isempty(params.Spacings) | ||
| params.Spacings = ones(1, D); | ||
| else | ||
| if numel(params.Spacings) ~= D | ||
| error('params.Spacings must have length equal to the number of spatial dimensions (D).'); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We'd probably wouldn't include |
||
| end | ||
| end | ||
|
|
||
| if isscalar(params.Periodic) | ||
| params.Periodic = repmat(params.Periodic, 1, D); | ||
| elseif numel(params.Periodic) ~= D | ||
| error('params.Periodic must be scalar or 1xD logical.'); | ||
| end | ||
|
|
||
| % Initialize H1 as the L2 error, | ||
| if params.IncludeL2 | ||
| H1 = l2Norm(X, Reduction="none", SquareRoot=false, Normalize=false); | ||
| else | ||
| H1 = zeros(B, 1, 'like', X); | ||
| end | ||
|
|
||
| % Reshape to [B*C, S1, S2, ... Sn] so that all batch, channel | ||
| % combinations are handled independently. | ||
| X = reshape(X, [B*C spatialSizes]); | ||
|
|
||
| % Add the H1 seminorm using forward differences. | ||
| for d = 1:D | ||
| delta = params.Spacings(d); | ||
|
|
||
| dm = 1 + d; % Dimension index of this spatial axis in reshaped X. | ||
|
|
||
| % Central difference with wrap. | ||
| fd = (circshift(X, -1, dm) - circshift(X, 1, dm)) / (2 * delta); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just a warning that I expect you need
|
||
|
|
||
| if ~params.Periodic(d) | ||
| % Replace first/last elements with forward/reverse differences. | ||
|
|
||
| if min(spatialSizes) < 4 | ||
| error("Non-periodic dimensions require at least 4 grid points for 3rd-order differences."); | ||
| end | ||
|
|
||
| fd = applyThirdOrderDifferenceAtBoundary(fd, X, dm, delta); | ||
| end | ||
|
|
||
| fd = fd.^2; | ||
|
|
||
| % Reshape back to original size. | ||
| fd = reshape(fd, sz); | ||
|
|
||
| % Sum over channels and spatial dimensions, giving size of [B, 1]. | ||
| fd = sum(fd, 2:nd); | ||
|
|
||
| % Accumulate per-batch sum. | ||
| H1 = H1 + fd; | ||
| end | ||
|
|
||
| if params.SquareRoot | ||
| H1 = sqrt(H1); | ||
| end | ||
|
|
||
| if params.Normalize | ||
| % Normalize by channels and number of spatial points | ||
| H1 = H1 / (C * prod(spatialSizes)); | ||
| end | ||
|
|
||
| if strcmp(params.Reduction, "mean") | ||
| H1 = mean(H1, 1); | ||
| elseif strcmp(params.Reduction, "sum") | ||
| H1 = sum(H1, 1); | ||
| end | ||
| end | ||
|
|
||
| function fd = applyThirdOrderDifferenceAtBoundary(fd, X, d, delta) | ||
|
|
||
| % Get the indices of components for 3rd-order forward differences. | ||
| idx1 = makeIndex(ndims(fd), d, 1); | ||
| idx2 = makeIndex(ndims(fd), d, 2); | ||
| idx3 = makeIndex(ndims(fd), d, 3); | ||
| idx4 = makeIndex(ndims(fd), d, 4); | ||
|
|
||
| % Apply 3rd-order forward differences at left boundary. | ||
| fd(idx1{:})= (-11*X(idx1{:}) + 18*X(idx2{:}) - 9*X(idx3{:}) + 2*X(idx4{:})) / (6 * delta); | ||
|
|
||
| % Get the indices of components for 3rd-order backward differences. | ||
| sz = size(fd, d); | ||
| idx1 = makeIndex(ndims(fd), d, sz); | ||
| idx2 = makeIndex(ndims(fd), d, sz-1); | ||
| idx3 = makeIndex(ndims(fd), d, sz-2); | ||
| idx4 = makeIndex(ndims(fd), d, sz-3); | ||
|
|
||
| % Apply 3rd-order backward differences at right boundary | ||
| fd(idx1{:}) = (-2*X(idx4{:}) + 9*X(idx3{:}) - 18*X(idx2{:}) + 11*X(idx1{:})) / (6 * delta); | ||
| end | ||
|
|
||
| function idx = makeIndex(ndims, toChange, val) | ||
| idx = repmat({':'}, 1, ndims); | ||
| idx{toChange} = val; | ||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might be able to link the example with a relative path in the repo.