An experimental package for accelerating DASCore with JAX.
```bash
python -m pip install -e ".[dev]"
```

`dasjax`'s main feature is the ability to build compiled DAS pipelines that run on CPU, GPU, or TPU. Because each pipeline is compiled as a whole, XLA can also fuse adjacent operations into fewer kernels for increased efficiency.
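The idea behind compiled pipelines can be illustrated in plain JAX. The sketch below does not use `dasjax`: `chain` is a hypothetical array-level stand-in for a scale, add, constant-detrend, and normalize sequence. It assumes the time axis is the last array axis and uses an L2 normalization. Wrapping the whole chain in `jax.jit` lets XLA compile it as one program, where it is free to fuse the elementwise steps, and JAX caches the compiled executable per input shape and dtype.

```python
import jax
import jax.numpy as jnp

TRACE_COUNT = 0  # incremented only while JAX traces (compiles) the function


def chain(data):
    """Hypothetical array analogue of scale -> add -> detrend -> normalize on axis -1."""
    global TRACE_COUNT
    TRACE_COUNT += 1  # Python side effect: runs at trace time, not on cached calls
    out = data * 2.0 + 1.0  # scale, then add
    out = out - jnp.mean(out, axis=-1, keepdims=True)  # constant detrend
    norm = jnp.linalg.norm(out, axis=-1, keepdims=True)  # L2 norm of each row
    return out / jnp.where(norm == 0, 1.0, norm)  # avoid dividing by zero


compiled_chain = jax.jit(chain)

a = jax.random.normal(jax.random.PRNGKey(0), (8, 256))
b = jax.random.normal(jax.random.PRNGKey(1), (8, 256))
c = jax.random.normal(jax.random.PRNGKey(2), (8, 512))

out_a = compiled_chain(a)  # first call: trace and compile
out_b = compiled_chain(b)  # same shape and dtype: reuses the cached executable
out_c = compiled_chain(c)  # new shape: triggers a second trace and compile
print(TRACE_COUNT)  # 2
```

The counter shows why the fast path is to compile once and reuse the callable: inputs with the same shape and dtype skip tracing entirely, while a new shape pays the compilation cost again.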
Use `JaxPatchPipeline` when you want to compile a reusable sequence of operations once and run it across many compatible patches.

```python
import dascore as dc
from dasjax import JaxPatchPipeline

patch = dc.get_example_patch("example_event_1")

# Record the chain of operations; nothing is computed yet.
pipeline = (
    JaxPatchPipeline()
    .scale(2.0)
    .add(1.0)
    .detrend(dim="time", type="constant")
    .normalize(dim="time")
)

# Compile once; the returned callable can be reused on compatible patches.
compiled = pipeline.compile()
out = patch.pipe(compiled)
print(out.shape)
```

`dasjax` is organized as a small three-tier stack:
- Pipeline layer: `src/dasjax/pipeline.py` records operation chains and compiles reusable patch transforms. This is the main user-facing API.
- Operation layer: `src/dasjax/operations/` defines the operation registry, execution policies, validation rules, eager patch implementations, and compiled leaf transforms.
- Kernel layer: `src/dasjax/kernels/` contains the array-level JAX and callback-backed kernels that do the actual numerical work, grouped by domain (`basic`, `signal`, `filters`, `spectral`).
This split makes the package easier to extend: add or update numerical behavior in the kernel layer, describe how it plugs into compiled execution in the operation layer, and expose it through the pipeline layer.
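A toy version of the same three-tier split, written in plain JAX, may make the division of labor clearer. Everything in it (`scale_kernel`, `add_kernel`, `OpSpec`, `REGISTRY`, `TinyPipeline`) is a hypothetical illustration, not the actual `dasjax` internals, which also carry execution policies, validation rules, and eager patch implementations.

```python
from dataclasses import dataclass
from typing import Callable

import jax
import jax.numpy as jnp


# Kernel layer (hypothetical): pure array functions that do the numerical work.
def scale_kernel(data, factor):
    return data * factor


def add_kernel(data, value):
    return data + value


# Operation layer (hypothetical): a registry describing how each method plugs in.
@dataclass(frozen=True)
class OpSpec:
    name: str
    kernel: Callable


REGISTRY = {
    "scale": OpSpec("scale", scale_kernel),
    "add": OpSpec("add", add_kernel),
}


# Pipeline layer (hypothetical): records steps, then compiles the chain once.
class TinyPipeline:
    def __init__(self, steps=()):
        self.steps = tuple(steps)

    def _then(self, name, *args):
        if name not in REGISTRY:
            raise ValueError(f"unsupported operation: {name}")
        # Return a new pipeline so recorded chains stay immutable.
        return TinyPipeline(self.steps + ((name, args),))

    def scale(self, factor):
        return self._then("scale", factor)

    def add(self, value):
        return self._then("add", value)

    def compile(self):
        steps = self.steps

        def run(data):
            for name, args in steps:
                data = REGISTRY[name].kernel(data, *args)
            return data

        # The Python loop is unrolled at trace time into a single XLA program.
        return jax.jit(run)


fn = TinyPipeline().scale(2.0).add(1.0).compile()
print(fn(jnp.arange(3.0)))  # [1. 3. 5.]
```

Adding a method in this toy means writing one kernel, registering one spec, and exposing one recording method, which mirrors the extension path described above.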
The list and table below track which DASCore patch methods are already supported and which are still missing, with rough implementation notes for each missing method.
Implemented in the current package:

- `real`, `imag`, `angle`, `conj`
- `flip`, `roll`, `pad`
- `standardize`, `differentiate`, `integrate`
- `dft`, `idft`
- `hilbert`, `envelope`
- `taper`, `taper_range`
- `whiten`
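For a sense of what a kernel-layer function for `hilbert` and `envelope` can look like, here is a minimal FFT-based sketch of the analytic signal, using the same frequency-domain construction as `scipy.signal.hilbert`. It is not the `dasjax` kernel; the function names and the default of transforming along the last axis are assumptions.

```python
import jax.numpy as jnp


def analytic_signal(x, axis=-1):
    """FFT-based analytic signal x + i*H[x], built like scipy.signal.hilbert."""
    n = x.shape[axis]
    spectrum = jnp.fft.fft(x, axis=axis)
    # Frequency-domain weights: keep DC (and Nyquist for even n),
    # double positive frequencies, zero negative frequencies.
    h = jnp.zeros(n).at[0].set(1.0)
    if n % 2 == 0:
        h = h.at[1 : n // 2].set(2.0).at[n // 2].set(1.0)
    else:
        h = h.at[1 : (n + 1) // 2].set(2.0)
    # Reshape the weights so they broadcast along the transformed axis.
    shape = [1] * x.ndim
    shape[axis] = n
    return jnp.fft.ifft(spectrum * h.reshape(shape), axis=axis)


def envelope(x, axis=-1):
    """Instantaneous amplitude: magnitude of the analytic signal."""
    return jnp.abs(analytic_signal(x, axis=axis))


t = jnp.arange(256)
x = jnp.cos(2 * jnp.pi * 5 * t / 256)  # exactly 5 periods in the window
env = envelope(x)
print(float(env.min()), float(env.max()))  # both approximately 1.0
```

For a cosine with a whole number of periods in the window, the analytic signal is a pure complex exponential, so the envelope is 1 everywhere up to floating-point error.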
Not yet implemented. These methods either need more work in the kernel layer or are shape-changing, which requires segmented pipeline execution (the same mechanism used by `fbe`).
| Method | Implementation notes |
|---|---|
| `notch_filter` | SOS filter; same pattern as `pass_filter` |
| `savgol_filter` | Polynomial fitting per frame; JAX-doable |
| `rolling` | Rolling-window reductions (mean, std, …); needs strided views |
| `correlate` | Cross-correlation via `jnp.fft` |
| `stft` / `istft` | Expose the STFT kernel already used by `fbe` |
| `decimate` | Anti-aliased downsampling; shape-changing |
| `aggregate` / `mean` / `std` / `sum` | Axis reductions; shape-changing |
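The `correlate` row notes cross-correlation via `jnp.fft`. A minimal sketch of that technique for real-valued inputs follows; `fft_cross_correlate` is a hypothetical name, and it does not reproduce DASCore's `correlate` semantics (coordinate handling, lag selection). It zero-pads both signals so the circular FFT correlation equals the linear one, then returns the full lag range in the same order as `numpy.correlate(a, b, mode="full")`.

```python
import jax.numpy as jnp
import numpy as np


def fft_cross_correlate(a, b, axis=-1):
    """Full cross-correlation c[k] = sum_n a[n + k] * b[n] of real inputs via FFT.

    Lags run from -(len(b) - 1) to len(a) - 1, matching the output order of
    numpy.correlate(a, b, mode="full") for 1-D real arrays.
    """
    n_a, n_b = a.shape[axis], b.shape[axis]
    size = n_a + n_b - 1  # zero-pad so circular correlation equals linear
    spec = jnp.fft.rfft(a, n=size, axis=axis) * jnp.conj(
        jnp.fft.rfft(b, n=size, axis=axis)
    )
    circular = jnp.fft.irfft(spec, n=size, axis=axis)
    # Negative lags wrap to the end of the buffer; roll them to the front.
    return jnp.roll(circular, n_b - 1, axis=axis)


a = np.array([1.0, 2.0, 3.0])
b = np.array([0.0, 1.0, 0.5])
result = fft_cross_correlate(jnp.asarray(a), jnp.asarray(b))
print(np.allclose(result, np.correlate(a, b, mode="full"), atol=1e-5))  # True
```

The FFT route costs O(N log N) instead of the O(N·M) of direct summation, which is why it suits long DAS traces.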
- The intended fast path is to build a `JaxPatchPipeline`, call `.compile()` once, and reuse the returned callable across many patches of compatible shape and dtype.
- Equivalent pipeline definitions reuse cached compiled callables automatically.
- Benchmarks live under `benchmarks/` and compare compiled `dasjax` pipelines against equivalent DASCore operation chains.
- Add new JAX patch methods by defining an array kernel in `src/dasjax/kernels/` and one operation spec in the relevant `src/dasjax/operations/` family module.
- The operation spec is the single source of truth for pipeline support, validation, and shared parity test cases.
- Every new patch method must be tested against a DASCore baseline across the shared mixed-patch fixture in `tests/conftest.py`.
- Prefer comparing internal operation behavior and compiled pipeline outputs against the closest native DASCore method or operator. If DASCore has no direct method, compare against an equivalent `Patch.update(...)` baseline.
- Method-equivalence assertions should check data closeness (with `equal_nan=True` when needed) and should also verify that coordinates are preserved.
- Compiled pipeline parity should come from the same declared operation cases rather than a separate hand-maintained test matrix.
- Install Git hooks locally with `prek install`.
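The testing rules above can be sketched in miniature: one declared spec carries the kernel, a reference implementation, and the shared cases, and the same cases drive both the eager and the compiled (`jax.jit`) checks. This is a hypothetical, array-only illustration: `OperationSpec`, `DETREND_CONSTANT`, and `check_parity` are not `dasjax` names, the NumPy function stands in for a DASCore baseline, and coordinate preservation is reduced to a shape check because plain arrays carry no coordinates.

```python
import functools
from dataclasses import dataclass
from typing import Callable

import jax
import jax.numpy as jnp
import numpy as np


@dataclass(frozen=True)
class OperationSpec:
    """Hypothetical spec: one place that declares an operation and its test cases."""

    name: str
    kernel: Callable  # JAX array kernel under test
    baseline: Callable  # NumPy reference standing in for the DASCore method
    cases: tuple  # keyword-argument dicts shared by eager and compiled checks


DETREND_CONSTANT = OperationSpec(
    name="detrend_constant",
    kernel=lambda x, axis: x - jnp.mean(x, axis=axis, keepdims=True),
    baseline=lambda x, axis: x - np.mean(x, axis=axis, keepdims=True),
    cases=({"axis": -1}, {"axis": 0}),
)


def check_parity(spec, data):
    """Check eager and compiled kernel outputs against the baseline for every case."""
    for kwargs in spec.cases:
        expected = spec.baseline(data, **kwargs)
        eager = spec.kernel(jnp.asarray(data), **kwargs)
        # Bind the case's static arguments before jitting, so only data is traced.
        compiled = jax.jit(functools.partial(spec.kernel, **kwargs))(jnp.asarray(data))
        for got in (eager, compiled):
            # equal_nan=True requires NaNs to appear in the same places.
            np.testing.assert_allclose(
                np.asarray(got), expected, rtol=1e-5, atol=1e-5, equal_nan=True
            )
            # Stand-in for coordinate preservation: plain arrays only have a shape.
            assert got.shape == data.shape
    return len(spec.cases)


rng = np.random.default_rng(0)
data = rng.standard_normal((4, 64)).astype(np.float32)
data[1, 3] = np.nan  # NaN propagates through the mean in both implementations
print(check_parity(DETREND_CONSTANT, data))  # 2
```

Because the cases live on the spec, adding a case automatically extends both the eager and the compiled checks, which is the property the last testing rule asks for.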