From 9e8088994e0b28c78a223f47c0f0ce30a70304ff Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 19 Feb 2025 12:26:33 -0600 Subject: [PATCH] Add (non-)reproducer for einsum tag cycle --- test/test_pytato.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/test/test_pytato.py b/test/test_pytato.py index da176f124..f22b4e2f5 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1844,6 +1844,34 @@ def test_pickling_hash(): # }}} +def test_einsum_tagging(): + from pytools.tag import UniqueTag + + class FTag(UniqueTag): + pass + + class ATag(UniqueTag): + pass + + class CTag(UniqueTag): + pass + + class DTag(UniqueTag): + pass + + p = (pt.zeros((2, 3, 4, 5)) + .with_tagged_axis(0, FTag()) + .with_tagged_axis(1, ATag()) + .with_tagged_axis(2, CTag()) + .with_tagged_axis(3, DTag())) + + a = pt.zeros((2, 3, 3)) + + result = pt.einsum("facb, fdce, fad -> cbe", p, p, a) + + pt.unify_axes_tags(result) + + if __name__ == "__main__": import os if "INVOCATION_INFO" in os.environ: