diff --git a/pyproject.toml b/pyproject.toml index a10488c7a9..3560ec3d80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -476,6 +476,10 @@ url = 'https://gridtools.github.io/pypi/' # dace = {index = "gridtools"} [tool.uv.sources] atlas4py = {index = "test.pypi"} +dace = [ + {git = "https://github.com/GridTools/dace", branch = "romanc/math-functions", group = "dace-cartesian"}, + {index = "gridtools", group = "dace-next"} +] # -- versioningit -- [tool.versioningit] diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index b7a9a0cece..5bcb8e9e08 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -920,7 +920,7 @@ def generate_extension(self) -> None: @register class DaceGPUBackend(BaseDaceBackend): - """DaCe python backend using gt4py.cartesian.gtc.""" + """GPU DaCe python with an optimal KJI loop layout""" name = "dace:gpu" languages: ClassVar[dict] = {"computation": "cuda", "bindings": ["python"]} @@ -933,3 +933,20 @@ class DaceGPUBackend(BaseDaceBackend): def generate_extension(self) -> None: return self.make_extension(uses_cuda=True) + + +@register +class DaceGPUBackendIJK(BaseDaceBackend): + """GPU DaCe python with an optimal IJK loop layout""" + + name = "dace:gpu_IJK" + languages: ClassVar[dict] = {"computation": "cuda", "bindings": ["python"]} + storage_info: ClassVar[layout.LayoutInfo] = layout_registry.from_name(name) + MODULE_GENERATOR_CLASS = DaCeCUDAPyExtModuleGenerator + options: ClassVar[GTBackendOptions] = { + **BaseGTBackend.GT_BACKEND_OPTS, + "device_sync": {"versioning": True, "type": bool}, + } + + def generate_extension(self) -> None: + return self.make_extension(uses_cuda=True) diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_tasklet.py b/src/gt4py/cartesian/gtc/dace/oir_to_tasklet.py index 2329128d70..6a7c71047e 100644 --- a/src/gt4py/cartesian/gtc/dace/oir_to_tasklet.py +++ b/src/gt4py/cartesian/gtc/dace/oir_to_tasklet.py @@ -226,38 +226,38 @@ def visit_NativeFunction(self, node: common.NativeFunction, **_kwargs: Any) -> s common.NativeFunction.ABS: "abs", common.NativeFunction.MIN: "min", common.NativeFunction.MAX: "max", - common.NativeFunction.MOD: "fmod", + common.NativeFunction.MOD: "dace.math.fmod", common.NativeFunction.SIN: "dace.math.sin", common.NativeFunction.COS: "dace.math.cos", common.NativeFunction.TAN: "dace.math.tan", - common.NativeFunction.ARCSIN: "asin", - common.NativeFunction.ARCCOS: "acos", - common.NativeFunction.ARCTAN: "atan", + common.NativeFunction.ARCSIN: "dace.math.asin", + common.NativeFunction.ARCCOS: "dace.math.acos", + common.NativeFunction.ARCTAN: "dace.math.atan", common.NativeFunction.SINH: "dace.math.sinh", common.NativeFunction.COSH: "dace.math.cosh", common.NativeFunction.TANH: "dace.math.tanh", - common.NativeFunction.ARCSINH: "asinh", - common.NativeFunction.ARCCOSH: "acosh", - common.NativeFunction.ARCTANH: "atanh", + common.NativeFunction.ARCSINH: "dace.math.asinh", + common.NativeFunction.ARCCOSH: "dace.math.acosh", + common.NativeFunction.ARCTANH: "dace.math.atanh", common.NativeFunction.SQRT: "dace.math.sqrt", common.NativeFunction.POW: "dace.math.pow", common.NativeFunction.EXP: "dace.math.exp", common.NativeFunction.LOG: "dace.math.log", - common.NativeFunction.LOG10: "log10", - common.NativeFunction.GAMMA: "tgamma", - common.NativeFunction.CBRT: "cbrt", + common.NativeFunction.LOG10: "dace.math.log10", + common.NativeFunction.GAMMA: "dace.math.tgamma", + common.NativeFunction.CBRT: "dace.math.cbrt", common.NativeFunction.ISFINITE: "isfinite", common.NativeFunction.ISINF: "isinf", common.NativeFunction.ISNAN: "isnan", common.NativeFunction.FLOOR: "dace.math.ifloor", - common.NativeFunction.CEIL: "ceil", - common.NativeFunction.TRUNC: "trunc", + common.NativeFunction.CEIL: "dace.math.ceil", + common.NativeFunction.TRUNC: "dace.math.trunc", common.NativeFunction.INT32: "dace.int32", common.NativeFunction.INT64: "dace.int64", common.NativeFunction.FLOAT32: "dace.float32", common.NativeFunction.FLOAT64: "dace.float64", - common.NativeFunction.ERF: "erf", - common.NativeFunction.ERFC: "erfc", + common.NativeFunction.ERF: "dace.math.erf", + common.NativeFunction.ERFC: "dace.math.erfc", common.NativeFunction.ROUND: "nearbyint", common.NativeFunction.ROUND_AWAY_FROM_ZERO: "round", } diff --git a/src/gt4py/cartesian/utils/compiler.py b/src/gt4py/cartesian/utils/compiler.py index 409a4c7e09..ec8dc20de0 100644 --- a/src/gt4py/cartesian/utils/compiler.py +++ b/src/gt4py/cartesian/utils/compiler.py @@ -137,7 +137,7 @@ class GPUCompilerName(enum.Enum): class GPUConfiguration: name: GPUCompilerName """Name identifier of the compiler""" - gpu_compile_flags: list[str] + gpu_compile_flags: str """Compile flags for device code""" binary_path: str """Path to binaries for GPU compiler & tools""" @@ -181,7 +181,7 @@ def gpu_configuration(optimization_level: str) -> GPUConfiguration: return GPUConfiguration( name=name, - gpu_compile_flags=gpu_compile_flags, + gpu_compile_flags=" ".join(gpu_compile_flags).strip(), binary_path=os.path.join(cuda_root, "bin"), include_path=os.path.join(cuda_root, "include"), library_path=library_path, diff --git a/src/gt4py/storage/cartesian/layout_registry.py b/src/gt4py/storage/cartesian/layout_registry.py index 4fadc0f7d0..f80c7e3f8d 100644 --- a/src/gt4py/storage/cartesian/layout_registry.py +++ b/src/gt4py/storage/cartesian/layout_registry.py @@ -66,6 +66,15 @@ def register(name: str, info: LayoutInfo) -> None: is_optimal_layout=layout_checker_factory(layout_maker_factory((2, 1, 0))), ), ) +register( + "dace:gpu_IJK", + LayoutInfo( + alignment=32, + device="gpu", + layout_map=layout_maker_factory((0, 1, 2)), + is_optimal_layout=layout_checker_factory(layout_maker_factory((0, 1, 2))), + ), +) register( "debug", LayoutInfo( diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/test_oir_to_tasklet.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_oir_to_tasklet.py index 1ecb7659ee..980727fa3c 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/dace/test_oir_to_tasklet.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_oir_to_tasklet.py @@ -139,3 +139,22 @@ def test_integer_power_of_integer() -> None: tasklet_code = visitor.visit_NativeFuncCall(pow_call, ctx=fake_context, is_target=False) assert "ipow" not in tasklet_code + + +@pytest.mark.parametrize( + "arg", + [ + oir.Literal(value="2", dtype=common.DataType.FLOAT32), + oir.Literal(value="2", dtype=common.DataType.FLOAT64), + ], +) +def test_log10_respects_floating_point_precision(arg: oir.Literal) -> None: + log10_call = oir.NativeFuncCall(func=common.NativeFunction.LOG10, args=[arg]) + + visitor = oir_to_tasklet.OIRToTasklet() + fake_context = oir_to_tasklet.Context( + code="asdf", targets=set(), inputs={}, outputs={}, tree=None, scope=None + ) + tasklet_code = visitor.visit_NativeFuncCall(log10_call, ctx=fake_context, is_target=False) + + assert "dace.math.log10" in tasklet_code diff --git a/uv.lock b/uv.lock index 7027f7d4f5..7e06c6bb11 100644 --- a/uv.lock +++ b/uv.lock @@ -1206,8 +1206,23 @@ wheels = [ [[package]] name = "dace" -version = "2.0.0a3" -source = { registry = "https://pypi.org/simple" } +version = "1.0.0" +source = { git = "https://github.com/GridTools/dace?branch=romanc%2Fmath-functions#3df061c8aeabcaeea966f79e39a4dbded2628df9" } +resolution-markers = [ + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version >= '3.14' and sys_platform == 'emscripten'", + "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'emscripten'", + "python_full_version == '3.13.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.12.*' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and sys_platform == 'emscripten'", + "python_full_version == '3.12.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and sys_platform == 'emscripten'", + "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version < '3.11'", +] dependencies = [ { name = "astunparse" }, { name = "dill" }, @@ -1734,6 +1749,12 @@ build = [ { name = "setuptools" }, { name = "wheel" }, ] +dace-cartesian = [ + { name = "dace", version = "1.0.0", source = { git = "https://github.com/GridTools/dace?branch=romanc%2Fmath-functions#3df061c8aeabcaeea966f79e39a4dbded2628df9" } }, +] +dace-next = [ + { name = "dace", version = "43!2026.4.27", source = { registry = "https://gridtools.github.io/pypi/" } }, +] dev = [ { name = "atlas4py" }, { name = "coverage", extra = ["toml"] }, @@ -1902,9 +1923,9 @@ build = [ { name = "setuptools", specifier = ">=77.0.3" }, { name = "wheel", specifier = ">=0.33.6" }, ] +dace-cartesian = [{ name = "dace", git = "https://github.com/GridTools/dace?branch=romanc%2Fmath-functions" }] +dace-next = [{ name = "dace", specifier = "==43!2026.4.27", index = "https://gridtools.github.io/pypi/", conflict = { package = "gt4py", group = "dace-next" } }] dev = [ - { name = "atlas4py", specifier = ">=0.41", index = "https://test.pypi.org/simple" }, - { name = "coverage", extras = ["toml"], specifier = ">=7.6.1" }, { name = "cython", specifier = ">=3.0.0" }, { name = "esbonio", specifier = ">=0.16.0" }, { name = "hypothesis", specifier = ">=6.0.0" },