diff --git a/cvtModel/setup.py b/cvtModel/setup.py index 5d4c8809..73d06f05 100644 --- a/cvtModel/setup.py +++ b/cvtModel/setup.py @@ -10,12 +10,15 @@ "scipy", "matplotlib", "pandas", - "pydantic" + "pydantic", + "numba" ], - extras_require={"dev": [ - "black", - "flake8", - "pytest", - "coverage" - ]}, + extras_require={ + "dev": [ + "black", + "flake8", + "pytest", + "coverage" + ], + }, ) diff --git a/cvtModel/src/cvt_simulator/models/slip_model.py b/cvtModel/src/cvt_simulator/models/slip_model.py index 45062271..b7d5bcfb 100644 --- a/cvtModel/src/cvt_simulator/models/slip_model.py +++ b/cvtModel/src/cvt_simulator/models/slip_model.py @@ -16,6 +16,7 @@ from cvt_simulator.constants.constants import ( RUBBER_ALUMINUM_STATIC_FRICTION, ) +from cvt_simulator.utils.numba_utils import maybe_njit class SlipModel: @@ -152,8 +153,9 @@ def calculate_t_max(self, state: SystemState) -> tuple[float, float]: secondary_t_max = max(0, secondary_t_max) return primary_t_max, secondary_t_max + @staticmethod + @maybe_njit(cache=True, fastmath=True) def _relative_speed( - self, primary_angular_velocity: float, secondary_angular_velocity: float, cvt_ratio: float, diff --git a/cvtModel/src/cvt_simulator/utils/numba_utils.py b/cvtModel/src/cvt_simulator/utils/numba_utils.py new file mode 100644 index 00000000..6da40a38 --- /dev/null +++ b/cvtModel/src/cvt_simulator/utils/numba_utils.py @@ -0,0 +1,16 @@ +try: + from numba import njit + + NUMBA_ENABLED = True + + def maybe_njit(*args, **kwargs): + return njit(*args, **kwargs) + +except ImportError: # pragma: no cover - exercised when numba is installed + NUMBA_ENABLED = False + + def maybe_njit(*args, **kwargs): + def decorator(func): + return func + + return decorator