Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/fp-arena-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
sudo apt-get update
# gcc/g++ (C++20) + cmake/ninja drive DaCe's CPU codegen build; the
# generated maps use OpenMP (libgomp).
sudo apt-get install -y cmake ninja-build gcc g++ libgomp1
sudo apt-get install -y cmake ninja-build gcc g++ libgomp1 libmpfr-dev

- name: Install DaCe (yakup/dev) and FP-Arena
run: |
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ time — no fork, no DaCe source changes.

## Install

FP-Arena tracks the latest DaCe `yakup/dev`:
FP-Arena tracks the latest DaCe [`extended`](https://github.com/spcl/dace/commits/extended/):

```bash
pip install git+https://github.com/spcl/FP-Arena.git
```

Already have a DaCe checkout (any `yakup/dev`-based branch) you want to use? Install
Already have a DaCe checkout (any `extended`-based branch) you want to use? Install
without pulling DaCe:

```bash
Expand Down
8 changes: 8 additions & 0 deletions fp_arena/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Float64sr,
float32sr,
float64sr,
mpfr,
register,
FP_ARENA_TYPECLASSES,
)
Expand All @@ -29,6 +30,10 @@
INCLUDE_DIR,
)
from fp_arena.transformations.change_fp_types import change_fptype
from fp_arena.transformations.change_and_propagate_fp_types import (
DEFAULT_PROMOTION_RULES,
change_and_propagate_fp_types,
)

# Register the types and the SDFG convenience method on import (idempotent).
register()
Expand All @@ -47,6 +52,7 @@
"Float64sr",
"float32sr",
"float64sr",
"mpfr",
"register",
"FP_ARENA_TYPECLASSES",
"enable_fp_arena_extensions",
Expand All @@ -59,4 +65,6 @@
"fp_arena_global_code",
"INCLUDE_DIR",
"change_fptype",
"change_and_propagate_fp_types",
"DEFAULT_PROMOTION_RULES",
]
61 changes: 60 additions & 1 deletion fp_arena/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
"""

import numpy
import ctypes

import dace
from dace import dtypes as _ddtypes
from dace import dtypes as _ddtypes, typeclass

#: C++ namespace-qualified type names emitted into generated code.
_FLOAT32SR_CTYPE = "fp_arena::float32sr"
Expand Down Expand Up @@ -77,6 +78,60 @@ def __repr__(self) -> str:
}


class _mpfr_t(ctypes.Structure):
_fields_ = [
("_mpfr_prec", ctypes.c_long),
("_mpfr_sign", ctypes.c_int),
("_mpfr_exp", ctypes.c_long),
("_mpfr_d", ctypes.c_void_p),
]


class mpfr(typeclass):
"""
A data type for custom Multiple Precision Floating-Point (MPFR) types.

Example use: `dace.mpfr(128)` for 128-bit precision.
"""

def __init__(self, precision: int):
self.precision = precision
self.type = numpy.object_
self.bytes = ctypes.sizeof(_mpfr_t)
self.dtype = self
self.typename = f"mpfr{precision}"

def to_string(self):
return self.typename

def to_json(self):
return {"type": "mpfr", "precision": self.precision}

@staticmethod
def from_json(json_obj, context=None):
if json_obj["type"] != "mpfr":
raise TypeError("Invalid type for mpfr")
return mpfr(json_obj["precision"])

@property
def ctype(self):
return f"dace::mpfr<{self.precision}>"

@property
def ctype_unaligned(self):
return self.ctype

def as_ctypes(self):
return ctypes.c_void_p

def as_numpy_dtype(self):
return numpy.dtype(numpy.object_)

@property
def base_type(self):
return self


def register():
"""
Register the FP-Arena types into DaCe's global registries.
Expand All @@ -100,4 +155,8 @@ def register():
_ddtypes.TYPECLASS_STRINGS.append(name)
_ddtypes.TYPECLASS_TO_STRING.setdefault(tc, tc.ctype)

# Also expose the parametric mpfr class so `dace.mpfr(128)` works.
setattr(_ddtypes, "mpfr", mpfr)
setattr(dace, "mpfr", mpfr)

return FP_ARENA_TYPECLASSES
21 changes: 16 additions & 5 deletions fp_arena/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
_HEADERS = (
os.path.join(INCLUDE_DIR, "fp_arena", "float32sr.h"),
os.path.join(INCLUDE_DIR, "fp_arena", "float64sr.h"),
os.path.join(INCLUDE_DIR, "fp_arena", "mpfr.h"),
)

#: Backends whose global-code section receives the includes (CPU frame + CUDA).
Expand All @@ -51,8 +52,8 @@
#: Marker so the includes are only injected once per SDFG.
_GUARD = "// fp_arena extensions enabled"

#: Substring identifying an FP-Arena C type (used to detect SR usage in an SDFG).
_CTYPE_MARKER = "fp_arena::"
#: Substrings identifying an FP-Arena C type (fp_arena:: for SR types, dace::mpfr for mpfr).
_CTYPE_MARKERS = ("fp_arena::", "dace::mpfr")


def fp_arena_global_code() -> str:
Expand Down Expand Up @@ -145,13 +146,23 @@ def enable_fp_arena_extensions(sdfg: dace.SDFG) -> dace.SDFG:
def uses_fp_arena_types(sdfg: dace.SDFG) -> bool:
"""
:param sdfg: the SDFG to inspect.
:returns: ``True`` if any data descriptor in ``sdfg`` or its nested SDFGs has
an FP-Arena C type (e.g. ``float32sr``), ``False`` otherwise.
:returns: ``True`` if any data descriptor or tasklet body in ``sdfg`` or its
nested SDFGs references an FP-Arena C type, ``False`` otherwise.
"""
from dace.sdfg import nodes as _dnodes
for nested in sdfg.all_sdfgs_recursive():
for desc in nested.arrays.values():
if _CTYPE_MARKER in (getattr(desc.dtype, "ctype", "") or ""):
if any(m in (getattr(desc.dtype, "ctype", "") or "") for m in _CTYPE_MARKERS):
return True
for state in nested.states():
for node in state.nodes():
if isinstance(node, _dnodes.Tasklet):
try:
code_str = node.code.as_string
except AttributeError:
code_str = str(node.code)
if any(m in code_str for m in _CTYPE_MARKERS):
return True
return False


Expand Down
Loading
Loading