diff --git a/dali/python/nvidia/dali/experimental/dynamic/_compile.py b/dali/python/nvidia/dali/experimental/dynamic/_compile.py index 8170799bf53..1577b728b9f 100644 --- a/dali/python/nvidia/dali/experimental/dynamic/_compile.py +++ b/dali/python/nvidia/dali/experimental/dynamic/_compile.py @@ -86,21 +86,19 @@ class CompileNode: class _CallTrie: """Trie keyed by call chain CodeLocs for safe call-site identification.""" - __slots__ = ("children", "node") + __slots__ = ("children", "nodes") def __init__(self) -> None: self.children: dict[CodeLoc, _CallTrie] = {} - self.node: CompileNode | None = None + self.nodes: dict[type["Operator"], CompileNode] = {} - def insert(self, call_chain: CallChain, node: CompileNode) -> None: + def insert(self, call_chain: CallChain, op: type["Operator"], node: CompileNode) -> None: current = self for code_loc in call_chain: - if code_loc not in current.children: - current.children[code_loc] = _CallTrie() - current = current.children[code_loc] - current.node = node + current = current.children.setdefault(code_loc, _CallTrie()) + current.nodes[op] = node - def find(self, call_chain: CallChain) -> CompileNode | None: + def find(self, call_chain: CallChain, op: type["Operator"]) -> CompileNode | None: """Look up a node by call chain tuple (not frame). Returns None if not found.""" current = self for code_loc in call_chain: @@ -108,9 +106,9 @@ def find(self, call_chain: CallChain) -> CompileNode | None: if child is None: return None current = child - return current.node + return current.nodes.get(op) - def lookup(self, start_frame: types.FrameType) -> CompileNode | None: + def lookup(self, start_frame: types.FrameType, op: type["Operator"]) -> CompileNode | None: """Walk frames to stack exhaustion or stop early if a frame differs.""" current = self frame: types.FrameType | None = start_frame @@ -120,7 +118,7 @@ def lookup(self, start_frame: types.FrameType) -> CompileNode | None: return None current = child frame = frame.f_back - return current.node + return current.nodes.get(op) class CompiledBatch(Batch): @@ -236,8 +234,12 @@ def record( num_outputs: int, device: Device | None = None, ) -> CompileNode | None: - if existing := self._call_trie.find(call_chain): - if existing.inputs == inputs and existing.kwargs == kwargs: + if existing := self._call_trie.find(call_chain, op_class): + if ( + existing.inputs == inputs + and existing.kwargs == kwargs + and existing.device == device + ): return existing return None @@ -251,7 +253,7 @@ def record( device=device, ) self.nodes.append(node) - self._call_trie.insert(call_chain, node) + self._call_trie.insert(call_chain, op_class, node) return node @_nvtx_range("Building pipeline") @@ -354,12 +356,13 @@ def _matches(self, actual: Any, expected: Any) -> bool: def get_compiled_result( self, frame: types.FrameType, + op_class: type["Operator"], inputs: Sequence[Any], kwargs: Mapping[str, Any], device: Device | None = None, ) -> Any | None: """Return pre-built result for a known call site, or None.""" - node = self._call_trie.lookup(frame) + node = self._call_trie.lookup(frame, op_class) if node is None: return None if device != node.device: @@ -574,7 +577,10 @@ def wrapper(*inputs, batch_size=None, device=None, **raw_kwargs): f"called with batch_size={batch_size}. Cannot change batch_size in " f"compiled mode." ) - if result := compile_ctx.get_compiled_result(frame, inputs, raw_kwargs, device=device): + result = compile_ctx.get_compiled_result( + frame, op_class, inputs, raw_kwargs, device=device + ) + if result is not None: return result return fn_call( *inputs, batch_size=batch_size, device=device, _backend=backend, **raw_kwargs diff --git a/dali/test/python/experimental_mode/test_compile.py b/dali/test/python/experimental_mode/test_compile.py index c5515a289a4..a96c5125dfe 100644 --- a/dali/test/python/experimental_mode/test_compile.py +++ b/dali/test/python/experimental_mode/test_compile.py @@ -87,8 +87,7 @@ def test_compile_basic_pipeline(): compiled_results.append(ndd.as_tensor(images)) assert _is_compiled(images) - assert len(dynamic_results) == len(compiled_results) - for dyn, comp in zip(dynamic_results, compiled_results): + for dyn, comp in zip(dynamic_results, compiled_results, strict=True): np.testing.assert_array_equal(dyn, comp) @@ -114,6 +113,34 @@ def flip(images): ) +@eval_modes() +def test_compile_different_ops_same_call_site(): + ops = [ndd.flip, ndd.sphere] + + reader_dyn = ndd.readers.File(file_root=images_root, pad_last_batch=True) + reader_comp = ndd.readers.File(file_root=images_root, pad_last_batch=True) + + dynamic_results = [] + for jpegs, _ in reader_dyn.next_epoch(batch_size=4): + images = ndd.decoders.image(jpegs) + for op in ops: + out = op(images) + assert not _is_compiled(out) + dynamic_results.append(ndd.as_tensor(out, pad=True)) + + for _ in range(3): + compiled_results = [] + for jpegs, _ in reader_comp.next_epoch(batch_size=4, compile=True): + images = ndd.decoders.image(jpegs) + for op in ops: + out = op(images) + assert _is_compiled(out) + compiled_results.append(ndd.as_tensor(out, pad=True)) + + for dyn, comp in zip(dynamic_results, compiled_results, strict=True): + np.testing.assert_array_equal(dyn, comp) + + @eval_modes() def test_compile_partial(): reader_dyn = ndd.readers.File(file_root=images_root) @@ -135,8 +162,7 @@ def test_compile_partial(): assert not _is_compiled(resized) compiled_results.append(ndd.as_tensor(resized)) - assert len(dynamic_results) == len(compiled_results) - for dyn, comp in zip(dynamic_results, compiled_results): + for dyn, comp in zip(dynamic_results, compiled_results, strict=True): np.testing.assert_array_equal(dyn, comp) @@ -156,8 +182,7 @@ def test_compile_multi_epoch(): images = ndd.decoders.image(jpegs) assert _is_compiled(images) compiled_results.append(ndd.as_tensor(images, pad=True)) - assert len(compiled_results) == len(dynamic_results) - for dyn, comp in zip(dynamic_results, compiled_results): + for dyn, comp in zip(dynamic_results, compiled_results, strict=True): np.testing.assert_array_equal(dyn, comp) @@ -205,8 +230,7 @@ def test_compile_loop_identical(): assert _is_compiled(resized) compiled_results.append(ndd.as_tensor(resized)) - assert len(dynamic_results) == len(compiled_results) - for dyn, comp in zip(dynamic_results, compiled_results): + for dyn, comp in zip(dynamic_results, compiled_results, strict=True): np.testing.assert_array_equal(dyn, comp) @@ -230,8 +254,7 @@ def test_compile_loop_data_dependent(): assert _is_compiled(images) == (i == 0) compiled_results.append(ndd.as_tensor(images)) - assert len(dynamic_results) == len(compiled_results) - for dyn, comp in zip(dynamic_results, compiled_results): + for dyn, comp in zip(dynamic_results, compiled_results, strict=True): np.testing.assert_array_equal(dyn, comp) @@ -269,8 +292,7 @@ def test_compile_diverging_inputs(): assert _is_compiled(images) == (i % 2 == 0) compiled_results.append(ndd.as_tensor(images)) - assert len(dynamic_results) == len(compiled_results) - for dyn, comp in zip(dynamic_results, compiled_results): + for dyn, comp in zip(dynamic_results, compiled_results, strict=True): np.testing.assert_array_equal(dyn, comp) @@ -356,8 +378,7 @@ def _test_video_resize(**resize_args): assert _is_compiled(rotated) compiled_results.append(ndd.as_tensor(rotated).cpu()) - assert len(dynamic_results) == len(compiled_results) - for dyn, comp in zip(dynamic_results, compiled_results): + for dyn, comp in zip(dynamic_results, compiled_results, strict=True): np.testing.assert_array_equal(dyn, comp) @@ -393,8 +414,7 @@ def test_compile_incompatible_kwarg_dtype(): assert _is_compiled(resized), resized compiled_results.append(ndd.as_tensor(resized, pad=True).cpu()) - assert len(dynamic_results) == len(compiled_results) - for dyn, comp in zip(dynamic_results, compiled_results): + for dyn, comp in zip(dynamic_results, compiled_results, strict=True): np.testing.assert_array_equal(dyn, comp)