JaxMaterials

High-performance JAX/CUDA-based library for differentiable materials modelling, designed to integrate physics-based models into modern machine learning workflows.

It enables gradient-based sensitivity studies and the optimisation of ML surrogate models for material simulations by combining efficient computation with scalable ML infrastructure that runs on both CPUs and GPUs.

Goals

Materials with fine microstructure, such as carbon fibre composites, are expensive to simulate with classical PDE methods. Upscaling methods require a large number of simulations to infer distributions of material parameters. In addition, it is often desirable to provide

  • Sensitivity of output to input parameters
  • Support for running on CPU and GPU hardware

Machine learning surrogate models can reduce this runtime. While JAX provides automatic forward- and reverse-mode differentiation capabilities, iterative solvers with a dynamic stopping criterion require custom backward derivative implementations. Here, this is realised with the adjoint state method.
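The adjoint-state idea can be illustrated on a scalar toy problem: for a fixed point $x^* = f(x^*, \theta)$ found by an iteration with a dynamic stopping criterion, the gradient $dx^*/d\theta$ follows from the implicit function theorem rather than from unrolling the loop, so the number of iterations never enters the backward pass. The following pure-Python sketch is illustrative only (the function f and all numbers are made up, not part of the library); the adjoint state method used here generalises this to strain/stress fields and a loss functional:

```python
import math

def f(x, theta):
    """Toy fixed-point map; a stand-in for one solver iteration."""
    return 0.5 * math.cos(x) + theta

def solve(theta, tol=1e-12, max_iter=10_000):
    """Fixed-point iteration with a dynamic stopping criterion."""
    x = 0.0
    for _ in range(max_iter):
        x_new = f(x, theta)
        if abs(x_new - x) < tol:
            return x_new
        x = x_new
    return x

theta = 0.3
x_star = solve(theta)

# Implicit/adjoint gradient: differentiating x* = f(x*, theta) gives
# dx*/dtheta = (df/dtheta) / (1 - df/dx), evaluated at the converged x*.
df_dx = -0.5 * math.sin(x_star)   # df/dx for the toy map
df_dtheta = 1.0                   # df/dtheta for the toy map
grad_adjoint = df_dtheta / (1.0 - df_dx)

# Finite-difference check: re-solve at perturbed theta.
h = 1e-6
grad_fd = (solve(theta + h) - solve(theta - h)) / (2 * h)
```

The key point is that `grad_adjoint` requires only quantities available at the converged solution, which is why the dynamic loop bound of the forward solve poses no obstacle.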

Features

  • GPU accelerated differentiable material models (isotropic & anisotropic) implemented in JAX
  • Reverse-mode differentiation with a bespoke adjoint implementation that handles dynamic loop bounds
  • Automatic differentiation enables sensitivity studies, optimisation and ML training
  • Bespoke CUDA solvers for fast data generation and inference
  • Modular design for extending models and components
  • Compatible with JAX ML pipelines and optimisation frameworks

Achievements

The following figure compares the performance of the (an)isotropic JAX and CUDA solvers when applied to an isotropic material. Results for reverse-mode gradient computation with the adjoint state method are also shown.

Performance of JAX and CUDA solvers

All results are for a $64\times 64\times 32$ grid. The code was run on an NVIDIA GeForce GTX 1660 Super GPU.

Quick installation

To install the JAX-Python library, run

pip install jaxmaterials

See the detailed instructions below for CUDA support.

Sample usage

The following code snippets demonstrate the forward-solve and reverse mode differentiation capabilities.

First, import the necessary libraries and set up the specification of the computational grid:

import numpy as np
import jax

from jax import numpy as jnp

jax.config.update("jax_enable_x64", True)

from jaxmaterials.common import GridSpec
from jaxmaterials.solver.lippmann_schwinger import lippmann_schwinger_isotropic

nx = 32
ny = 32
nz = 16

grid_spec = GridSpec(nx, ny, nz, Lx=1.0, Ly=1.0, Lz=0.5)

Forward solve

The interfaces to the differentiable solvers for isotropic and anisotropic materials can be found in lippmann_schwinger.py

The forward solve for given random Lame parameters $\mu$, $\lambda$ and mean strain $\overline{\varepsilon}$ requires a call to lippmann_schwinger_isotropic() which can optionally use the CUDA backend. It returns the strain $\varepsilon$ and stress $\sigma$:

rng = np.random.default_rng(seed=47273)

mu = rng.uniform(low=0.8, high=1.1, size=(nx, ny, nz))
lmbda = rng.uniform(low=0.6, high=0.7, size=(nx, ny, nz))
epsilon_bar = rng.normal(size=6)

epsilon, sigma = lippmann_schwinger_isotropic(
    mu, lmbda, epsilon_bar, grid_spec=grid_spec, use_cuda=True
)

Gradient computation

Since the JAX implementation of lippmann_schwinger_isotropic() is fully reverse-mode differentiable, jax.grad() can be used to compute the gradient of a loss function $L=L(\varepsilon,\sigma)$ with respect to the inputs $\mu$, $\lambda$ and $\overline{\varepsilon}$. This is demonstrated in the following code snippet:

def loss_fn(mu, lmbda, epsilon_bar):
    epsilon, sigma = lippmann_schwinger_isotropic(
        mu, lmbda, epsilon_bar, grid_spec=grid_spec
    )
    return jnp.sum(epsilon**2 + sigma**2)

grad_fn = jax.grad(loss_fn, argnums=(0, 1, 2))

g_mu, g_lmbda, g_epsilon_bar = grad_fn(mu, lmbda, epsilon_bar)
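Gradients obtained this way can be sanity-checked against central finite differences. The helper below is a library-agnostic sketch: `toy_loss` is a stand-in for `loss_fn` above (chosen so the exact gradient is known), not part of jaxmaterials:

```python
def fd_grad(loss, x, h=1e-6):
    """Central finite-difference gradient, one component at a time."""
    g = []
    for i in range(len(x)):
        xp = list(x); xp[i] += h
        xm = list(x); xm[i] -= h
        g.append((loss(xp) - loss(xm)) / (2 * h))
    return g

# Stand-in scalar loss with known gradient 2*x (illustrative only).
def toy_loss(x):
    return sum(v * v for v in x)

x = [0.5, -1.0, 2.0]
analytic = [2 * v for v in x]     # exact gradient of the toy loss
numeric = fd_grad(toy_loss, x)
```

Applied to the real solver, the same pattern compares `grad_fn(mu, lmbda, epsilon_bar)` against re-solves at perturbed inputs; for field inputs one typically checks a directional derivative rather than every component.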

Contents

This repository contains the following code:

CUDA linear elasticity solver

A highly efficient CUDA-accelerated solver for the linear elasticity equation in isotropic materials, based on the Lippmann-Schwinger method of [Moulinec and Suquet, 1998. Computer Methods in Applied Mechanics and Engineering, 157(1-2), pp.69-94].

JAX linear elasticity solver

A JAX implementation of the same method, which allows back-propagation through the solver for later use in an ML setting. Solvers for both isotropic and anisotropic materials have been implemented.

In addition to the plain Lippmann-Schwinger solver, the code also supports Anderson acceleration as described in [Wicht, Schneider and Böhlke, 2021. International Journal for Numerical Methods in Engineering, 122(9), pp.2287-2311]. Since JAX code is differentiable by construction, the solver can be used as a building block in a machine learning framework (see below).

Both solvers use the same discretisation as the AMITEX solver, which is described in [Gelebart 2020. Comptes Rendus. Mecanique, 348(8-9), pp.693-704]. For mathematical details see the ./doc subdirectory.
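The basic fixed-point scheme of Moulinec and Suquet alternates between evaluating the constitutive law in real space and applying a reference Green operator in Fourier space. A minimal 1D scalar analogue (a two-phase laminate with a scalar modulus C(x), playing the role of the elasticity tensor) can be sketched in pure Python; the naive DFT, grid size and material values are illustrative only:

```python
import cmath

def dft(a):
    """Naive DFT (stand-in for FFT; fine for tiny illustrative grids)."""
    n = len(a)
    return [sum(a[k] * cmath.exp(-2j * cmath.pi * j * k / n)
                for k in range(n)) for j in range(n)]

def idft(A):
    n = len(A)
    return [sum(A[j] * cmath.exp(2j * cmath.pi * j * k / n)
                for j in range(n)).real / n for k in range(n)]

def basic_scheme(C, eps_bar, C0, tol=1e-12, max_iter=200):
    """1D scalar analogue of the Moulinec-Suquet basic scheme."""
    n = len(C)
    eps = [eps_bar] * n
    for _ in range(max_iter):
        # polarisation tau = (C - C0) * eps in real space
        tau = [(C[i] - C0) * eps[i] for i in range(n)]
        tau_hat = dft(tau)
        # Green operator in Fourier space: Gamma0_hat = 1/C0 for xi != 0
        eps_hat = [-t / C0 for t in tau_hat]
        eps_hat[0] = eps_bar * n      # enforce the prescribed mean strain
        eps_new = idft(eps_hat)
        if max(abs(a - b) for a, b in zip(eps_new, eps)) < tol:
            eps = eps_new
            break
        eps = eps_new
    sigma = [C[i] * eps[i] for i in range(n)]
    return eps, sigma

C = [1.0] * 4 + [2.0] * 4            # two-phase laminate
eps, sigma = basic_scheme(C, eps_bar=1.0, C0=1.5)
C_eff = sum(sigma) / len(sigma)      # effective modulus <sigma> / eps_bar
```

In 1D the effective modulus is the harmonic mean of C (here 4/3), which the converged scheme reproduces; the library implements the full tensorial 3D version of this iteration on the GPU.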

Fibre distribution generator

Code for sampling the Lame parameters that can be used as inputs to the solvers. The generated samples contain cross-ply layers of fibres as shown in the following figure:

Generated fibres

ML toolchain (under construction)

An ML toolchain that uses the above code to train machine learning models which predict the effective elasticity tensor for a given pair of Lame-parameter fields will be added later.

Installation

Prerequisites

The CUDA solver requires a working CUDA installation, including the NVIDIA CUDA Toolkit, which contains the NVIDIA cuFFT library. A working C++ compiler and CMake are required to compile and install the solver. Running the automated tests additionally requires the GoogleTest framework.

See pyproject.toml for a list of required Python packages.

Instructions

The following instructions should work on Linux machines, but will need to be adapted on Windows and Mac.

CUDA solver

  1. Clone this repository
  2. Change to the cuda subdirectory
  3. Configure in the build directory with
cmake -B build -DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>

where <INSTALL_DIR> is the directory in which the solver library should be installed. If the -DCMAKE_INSTALL_PREFIX option is omitted, the default (usually /usr/local) is used, and you will likely need root access to install the library in this location.

  4. Build the solver with
cmake --build build
  5. (Optional) If the GoogleTest framework is installed, test the library by running
./build/bin/test
  6. Install the library by running
cmake --install build
  7. Add the install directory to LD_LIBRARY_PATH to ensure that the library can be loaded from Python:
export LD_LIBRARY_PATH=<INSTALL_DIR>:${LD_LIBRARY_PATH}

Python library

In the main directory of the repository run

python -m pip install .

Optionally, add the --editable flag for an editable install.

Run the test suite with

python -m pytest
