Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 22 additions & 16 deletions dali/python/nvidia/dali/experimental/dynamic/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,31 +86,29 @@ class CompileNode:
class _CallTrie:
"""Trie keyed by call chain CodeLocs for safe call-site identification."""

__slots__ = ("children", "node")
__slots__ = ("children", "nodes")

def __init__(self) -> None:
self.children: dict[CodeLoc, _CallTrie] = {}
self.node: CompileNode | None = None
self.nodes: dict[type["Operator"], CompileNode] = {}

def insert(self, call_chain: CallChain, node: CompileNode) -> None:
def insert(self, call_chain: CallChain, op: type["Operator"], node: CompileNode) -> None:
current = self
for code_loc in call_chain:
if code_loc not in current.children:
current.children[code_loc] = _CallTrie()
current = current.children[code_loc]
current.node = node
current = current.children.setdefault(code_loc, _CallTrie())
current.nodes[op] = node

def find(self, call_chain: CallChain) -> CompileNode | None:
def find(self, call_chain: CallChain, op: type["Operator"]) -> CompileNode | None:
"""Look up a node by call chain tuple (not frame). Returns None if not found."""
current = self
for code_loc in call_chain:
child = current.children.get(code_loc)
if child is None:
return None
current = child
return current.node
return current.nodes.get(op)

def lookup(self, start_frame: types.FrameType) -> CompileNode | None:
def lookup(self, start_frame: types.FrameType, op: type["Operator"]) -> CompileNode | None:
"""Walk frames to stack exhaustion or stop early if a frame differs."""
current = self
frame: types.FrameType | None = start_frame
Expand All @@ -120,7 +118,7 @@ def lookup(self, start_frame: types.FrameType) -> CompileNode | None:
return None
current = child
frame = frame.f_back
return current.node
return current.nodes.get(op)


class CompiledBatch(Batch):
Expand Down Expand Up @@ -236,8 +234,12 @@ def record(
num_outputs: int,
device: Device | None = None,
) -> CompileNode | None:
if existing := self._call_trie.find(call_chain):
if existing.inputs == inputs and existing.kwargs == kwargs:
if existing := self._call_trie.find(call_chain, op_class):
if (
existing.inputs == inputs
and existing.kwargs == kwargs
and existing.device == device
):
return existing
return None

Expand All @@ -251,7 +253,7 @@ def record(
device=device,
)
self.nodes.append(node)
self._call_trie.insert(call_chain, node)
self._call_trie.insert(call_chain, op_class, node)
return node

@_nvtx_range("Building pipeline")
Expand Down Expand Up @@ -354,12 +356,13 @@ def _matches(self, actual: Any, expected: Any) -> bool:
def get_compiled_result(
self,
frame: types.FrameType,
op_class: type["Operator"],
inputs: Sequence[Any],
kwargs: Mapping[str, Any],
device: Device | None = None,
) -> Any | None:
"""Return pre-built result for a known call site, or None."""
node = self._call_trie.lookup(frame)
node = self._call_trie.lookup(frame, op_class)
if node is None:
return None
if device != node.device:
Expand Down Expand Up @@ -574,7 +577,10 @@ def wrapper(*inputs, batch_size=None, device=None, **raw_kwargs):
f"called with batch_size={batch_size}. Cannot change batch_size in "
f"compiled mode."
)
if result := compile_ctx.get_compiled_result(frame, inputs, raw_kwargs, device=device):
result = compile_ctx.get_compiled_result(
frame, op_class, inputs, raw_kwargs, device=device
)
if result is not None:
return result
return fn_call(
*inputs, batch_size=batch_size, device=device, _backend=backend, **raw_kwargs
Expand Down
52 changes: 36 additions & 16 deletions dali/test/python/experimental_mode/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ def test_compile_basic_pipeline():
compiled_results.append(ndd.as_tensor(images))
assert _is_compiled(images)

assert len(dynamic_results) == len(compiled_results)
for dyn, comp in zip(dynamic_results, compiled_results):
for dyn, comp in zip(dynamic_results, compiled_results, strict=True):
np.testing.assert_array_equal(dyn, comp)


Expand All @@ -114,6 +113,34 @@ def flip(images):
)


@eval_modes()
def test_compile_different_ops_same_call_site():
ops = [ndd.flip, ndd.sphere]

reader_dyn = ndd.readers.File(file_root=images_root, pad_last_batch=True)
reader_comp = ndd.readers.File(file_root=images_root, pad_last_batch=True)

dynamic_results = []
for jpegs, _ in reader_dyn.next_epoch(batch_size=4):
images = ndd.decoders.image(jpegs)
for op in ops:
out = op(images)
assert not _is_compiled(out)
dynamic_results.append(ndd.as_tensor(out, pad=True))

for _ in range(3):
compiled_results = []
for jpegs, _ in reader_comp.next_epoch(batch_size=4, compile=True):
images = ndd.decoders.image(jpegs)
for op in ops:
out = op(images)
assert _is_compiled(out)
compiled_results.append(ndd.as_tensor(out, pad=True))

for dyn, comp in zip(dynamic_results, compiled_results, strict=True):
np.testing.assert_array_equal(dyn, comp)


@eval_modes()
def test_compile_partial():
reader_dyn = ndd.readers.File(file_root=images_root)
Expand All @@ -135,8 +162,7 @@ def test_compile_partial():
assert not _is_compiled(resized)
compiled_results.append(ndd.as_tensor(resized))

assert len(dynamic_results) == len(compiled_results)
for dyn, comp in zip(dynamic_results, compiled_results):
for dyn, comp in zip(dynamic_results, compiled_results, strict=True):
np.testing.assert_array_equal(dyn, comp)


Expand All @@ -156,8 +182,7 @@ def test_compile_multi_epoch():
images = ndd.decoders.image(jpegs)
assert _is_compiled(images)
compiled_results.append(ndd.as_tensor(images, pad=True))
assert len(compiled_results) == len(dynamic_results)
for dyn, comp in zip(dynamic_results, compiled_results):
for dyn, comp in zip(dynamic_results, compiled_results, strict=True):
np.testing.assert_array_equal(dyn, comp)


Expand Down Expand Up @@ -205,8 +230,7 @@ def test_compile_loop_identical():
assert _is_compiled(resized)
compiled_results.append(ndd.as_tensor(resized))

assert len(dynamic_results) == len(compiled_results)
for dyn, comp in zip(dynamic_results, compiled_results):
for dyn, comp in zip(dynamic_results, compiled_results, strict=True):
np.testing.assert_array_equal(dyn, comp)


Expand All @@ -230,8 +254,7 @@ def test_compile_loop_data_dependent():
assert _is_compiled(images) == (i == 0)
compiled_results.append(ndd.as_tensor(images))

assert len(dynamic_results) == len(compiled_results)
for dyn, comp in zip(dynamic_results, compiled_results):
for dyn, comp in zip(dynamic_results, compiled_results, strict=True):
np.testing.assert_array_equal(dyn, comp)


Expand Down Expand Up @@ -269,8 +292,7 @@ def test_compile_diverging_inputs():
assert _is_compiled(images) == (i % 2 == 0)
compiled_results.append(ndd.as_tensor(images))

assert len(dynamic_results) == len(compiled_results)
for dyn, comp in zip(dynamic_results, compiled_results):
for dyn, comp in zip(dynamic_results, compiled_results, strict=True):
np.testing.assert_array_equal(dyn, comp)


Expand Down Expand Up @@ -356,8 +378,7 @@ def _test_video_resize(**resize_args):
assert _is_compiled(rotated)
compiled_results.append(ndd.as_tensor(rotated).cpu())

assert len(dynamic_results) == len(compiled_results)
for dyn, comp in zip(dynamic_results, compiled_results):
for dyn, comp in zip(dynamic_results, compiled_results, strict=True):
np.testing.assert_array_equal(dyn, comp)


Expand Down Expand Up @@ -393,8 +414,7 @@ def test_compile_incompatible_kwarg_dtype():
assert _is_compiled(resized), resized
compiled_results.append(ndd.as_tensor(resized, pad=True).cpu())

assert len(dynamic_results) == len(compiled_results)
for dyn, comp in zip(dynamic_results, compiled_results):
for dyn, comp in zip(dynamic_results, compiled_results, strict=True):
np.testing.assert_array_equal(dyn, comp)


Expand Down