Skip to content
Open
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,10 @@ url = 'https://gridtools.github.io/pypi/'
# dace = {index = "gridtools"}
[tool.uv.sources]
atlas4py = {index = "test.pypi"}
dace = [
{git = "https://github.com/GridTools/dace", branch = "romanc/math-functions", group = "dace-cartesian"},
{index = "gridtools", group = "dace-next"}
]

# -- versioningit --
[tool.versioningit]
Expand Down
19 changes: 18 additions & 1 deletion src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,7 @@ def generate_extension(self) -> None:

@register
class DaceGPUBackend(BaseDaceBackend):
"""DaCe python backend using gt4py.cartesian.gtc."""
"""GPU DaCe python with an optimal KJI loop layout"""

name = "dace:gpu"
languages: ClassVar[dict] = {"computation": "cuda", "bindings": ["python"]}
Expand All @@ -933,3 +933,20 @@ class DaceGPUBackend(BaseDaceBackend):

def generate_extension(self) -> None:
return self.make_extension(uses_cuda=True)


@register
class DaceGPUBackendIJK(BaseDaceBackend):
    """GPU DaCe python with an optimal IJK loop layout.

    Declared under the backend name ``dace:gpu_IJK`` with CUDA as the
    computation language and Python bindings. The storage layout is looked
    up in the layout registry under that same name.
    """

    name = "dace:gpu_IJK"
    languages: ClassVar[dict] = {"computation": "cuda", "bindings": ["python"]}
    # The layout registry is queried with the backend name defined just above.
    storage_info: ClassVar[layout.LayoutInfo] = layout_registry.from_name(name)
    MODULE_GENERATOR_CLASS = DaCeCUDAPyExtModuleGenerator
    # Start from the shared GT backend options and add `device_sync` on top.
    options: ClassVar[GTBackendOptions] = {
        **BaseGTBackend.GT_BACKEND_OPTS,
        "device_sync": {"versioning": True, "type": bool},
    }

    def generate_extension(self) -> None:
        """Delegate to ``make_extension`` with ``uses_cuda=True``."""
        return self.make_extension(uses_cuda=True)
28 changes: 14 additions & 14 deletions src/gt4py/cartesian/gtc/dace/oir_to_tasklet.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,38 +226,38 @@ def visit_NativeFunction(self, node: common.NativeFunction, **_kwargs: Any) -> s
common.NativeFunction.ABS: "abs",
common.NativeFunction.MIN: "min",
common.NativeFunction.MAX: "max",
common.NativeFunction.MOD: "fmod",
common.NativeFunction.MOD: "dace.math.fmod",
common.NativeFunction.SIN: "dace.math.sin",
common.NativeFunction.COS: "dace.math.cos",
common.NativeFunction.TAN: "dace.math.tan",
common.NativeFunction.ARCSIN: "asin",
common.NativeFunction.ARCCOS: "acos",
common.NativeFunction.ARCTAN: "atan",
common.NativeFunction.ARCSIN: "dace.math.asin",
common.NativeFunction.ARCCOS: "dace.math.acos",
common.NativeFunction.ARCTAN: "dace.math.atan",
common.NativeFunction.SINH: "dace.math.sinh",
common.NativeFunction.COSH: "dace.math.cosh",
common.NativeFunction.TANH: "dace.math.tanh",
common.NativeFunction.ARCSINH: "asinh",
common.NativeFunction.ARCCOSH: "acosh",
common.NativeFunction.ARCTANH: "atanh",
common.NativeFunction.ARCSINH: "dace.math.asinh",
common.NativeFunction.ARCCOSH: "dace.math.acosh",
common.NativeFunction.ARCTANH: "dace.math.atanh",
common.NativeFunction.SQRT: "dace.math.sqrt",
common.NativeFunction.POW: "dace.math.pow",
common.NativeFunction.EXP: "dace.math.exp",
common.NativeFunction.LOG: "dace.math.log",
common.NativeFunction.LOG10: "log10",
common.NativeFunction.GAMMA: "tgamma",
common.NativeFunction.CBRT: "cbrt",
common.NativeFunction.LOG10: "dace.math.log10",
common.NativeFunction.GAMMA: "dace.math.tgamma",
common.NativeFunction.CBRT: "dace.math.cbrt",
common.NativeFunction.ISFINITE: "isfinite",
common.NativeFunction.ISINF: "isinf",
common.NativeFunction.ISNAN: "isnan",
common.NativeFunction.FLOOR: "dace.math.ifloor",
common.NativeFunction.CEIL: "ceil",
common.NativeFunction.TRUNC: "trunc",
common.NativeFunction.CEIL: "dace.math.ceil",
common.NativeFunction.TRUNC: "dace.math.trunc",
common.NativeFunction.INT32: "dace.int32",
common.NativeFunction.INT64: "dace.int64",
common.NativeFunction.FLOAT32: "dace.float32",
common.NativeFunction.FLOAT64: "dace.float64",
common.NativeFunction.ERF: "erf",
common.NativeFunction.ERFC: "erfc",
common.NativeFunction.ERF: "dace.math.erf",
common.NativeFunction.ERFC: "dace.math.erfc",
common.NativeFunction.ROUND: "nearbyint",
common.NativeFunction.ROUND_AWAY_FROM_ZERO: "round",
}
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/cartesian/utils/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class GPUCompilerName(enum.Enum):
class GPUConfiguration:
name: GPUCompilerName
"""Name identifier of the compiler"""
gpu_compile_flags: list[str]
gpu_compile_flags: str
"""Compile flags for device code"""
binary_path: str
"""Path to binaries for GPU compiler & tools"""
Expand Down Expand Up @@ -181,7 +181,7 @@ def gpu_configuration(optimization_level: str) -> GPUConfiguration:

return GPUConfiguration(
name=name,
gpu_compile_flags=gpu_compile_flags,
gpu_compile_flags=" ".join(gpu_compile_flags).strip(),
binary_path=os.path.join(cuda_root, "bin"),
include_path=os.path.join(cuda_root, "include"),
library_path=library_path,
Expand Down
9 changes: 9 additions & 0 deletions src/gt4py/storage/cartesian/layout_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ def register(name: str, info: LayoutInfo) -> None:
is_optimal_layout=layout_checker_factory(layout_maker_factory((2, 1, 0))),
),
)
# Layout entry for the "dace:gpu_IJK" name: device "gpu", alignment 32, and
# the same (0, 1, 2) axis tuple for both the layout map and the
# optimal-layout check.
# NOTE(review): presumably (0, 1, 2) is the reverse ordering of the (2, 1, 0)
# tuple used by the preceding entry -- confirm the axis convention against
# `layout_maker_factory`.
register(
    "dace:gpu_IJK",
    LayoutInfo(
        alignment=32,
        device="gpu",
        layout_map=layout_maker_factory((0, 1, 2)),
        is_optimal_layout=layout_checker_factory(layout_maker_factory((0, 1, 2))),
    ),
)
register(
"debug",
LayoutInfo(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,22 @@ def test_integer_power_of_integer() -> None:
tasklet_code = visitor.visit_NativeFuncCall(pow_call, ctx=fake_context, is_target=False)

assert "ipow" not in tasklet_code


@pytest.mark.parametrize(
    "arg",
    [
        oir.Literal(value="2", dtype=common.DataType.FLOAT32),
        oir.Literal(value="2", dtype=common.DataType.FLOAT64),
    ],
)
def test_log10_respects_floating_point_precision(arg: oir.Literal) -> None:
    """LOG10 calls are emitted as ``dace.math.log10`` for 32- and 64-bit float literals."""
    call = oir.NativeFuncCall(func=common.NativeFunction.LOG10, args=[arg])

    tasklet_visitor = oir_to_tasklet.OIRToTasklet()
    # Placeholder context: only the arguments required to construct it are given.
    ctx = oir_to_tasklet.Context(
        code="asdf",
        targets=set(),
        inputs={},
        outputs={},
        tree=None,
        scope=None,
    )

    generated = tasklet_visitor.visit_NativeFuncCall(call, ctx=ctx, is_target=False)

    assert "dace.math.log10" in generated
assert "dace.math.log10" in tasklet_code
29 changes: 25 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading