From 369a5645ae18b1b2ba3148ec5c1e56738b100fa0 Mon Sep 17 00:00:00 2001 From: Manfredss Date: Thu, 5 Feb 2026 12:43:59 +0800 Subject: [PATCH] fix --- tests/test_Tensor_inverse.py | 79 ++++++++++++++++++++++++++++++++++++ tests/test_inverse.py | 65 +++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+) diff --git a/tests/test_Tensor_inverse.py b/tests/test_Tensor_inverse.py index 97c18b89f..f85ed30ae 100644 --- a/tests/test_Tensor_inverse.py +++ b/tests/test_Tensor_inverse.py @@ -64,3 +64,82 @@ def _test_case_3(): """ ) obj.run(pytorch_code, ["result"]) + + +# Test with float32 dtype +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 0.7308, 1.0060, 0.5270, 1.4516], + [-0.1383, 1.5706, 0.4724, 0.4141], + [ 0.1193, 0.2829, 0.9037, 0.3957], + [-0.8202, -0.6474, -0.1631, -0.6543]], dtype=torch.float32) + result = x.inverse() + """ + ) + obj.run(pytorch_code, ["result"], rtol=1e-5) + + +# Test with float64 dtype +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 0.7308, 1.0060, 0.5270, 1.4516], + [-0.1383, 1.5706, 0.4724, 0.4141], + [ 0.1193, 0.2829, 0.9037, 0.3957], + [-0.8202, -0.6474, -0.1631, -0.6543]], dtype=torch.float64) + result = x.inverse() + """ + ) + obj.run(pytorch_code, ["result"]) + + +# Test with 3D batched input +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[[ 0.7308, 1.0060, 0.5270, 1.4516], + [-0.1383, 1.5706, 0.4724, 0.4141], + [ 0.1193, 0.2829, 0.9037, 0.3957], + [-0.8202, -0.6474, -0.1631, -0.6543]], + [[ 0.7308, 1.0060, 0.5270, 1.4516], + [-0.1383, 1.5706, 0.4724, 0.4141], + [ 0.1193, 0.2829, 0.9037, 0.3957], + [-0.8202, -0.6474, -0.1631, -0.6543]]]) + result = x.inverse() + """ + ) + obj.run(pytorch_code, ["result"], rtol=1e-5) + + +# Test with method chaining +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 0.7308, 1.0060, 0.5270, 1.4516], + [-0.1383, 1.5706, 0.4724, 0.4141], + [ 0.1193, 0.2829, 0.9037, 0.3957], + [-0.8202, -0.6474, -0.1631, -0.6543]]) + result = x.clone().inverse() + """ + ) + obj.run(pytorch_code, ["result"]) + + +# Test inverse of inverse (should approximately equal original) +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 0.7308, 1.0060, 0.5270, 1.4516], + [-0.1383, 1.5706, 0.4724, 0.4141], + [ 0.1193, 0.2829, 0.9037, 0.3957], + [-0.8202, -0.6474, -0.1631, -0.6543]], dtype=torch.float64) + result = x.inverse() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 218a2d7a2..d3a241174 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -111,3 +111,68 @@ def test_case_6(): """ ) obj.run(pytorch_code, ["result", "out"]) + + +# Test with *args unpacking +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 0.7308, 1.0060, 0.5270, 1.4516], + [-0.1383, 1.5706, 0.4724, 0.4141], + [ 0.1193, 0.2829, 0.9037, 0.3957], + [-0.8202, -0.6474, -0.1631, -0.6543]]) + args = (x,) + result = torch.inverse(*args) + """ + ) + obj.run(pytorch_code, ["result"]) + + +# Test with float32 dtype +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 0.7308, 1.0060, 0.5270, 1.4516], + [-0.1383, 1.5706, 0.4724, 0.4141], + [ 0.1193, 0.2829, 0.9037, 0.3957], + [-0.8202, -0.6474, -0.1631, -0.6543]], dtype=torch.float32) + result = torch.inverse(x) + """ + ) + obj.run(pytorch_code, ["result"], rtol=1e-5) + + +# Test with float64 dtype +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 0.7308, 1.0060, 0.5270, 1.4516], + [-0.1383, 1.5706, 0.4724, 0.4141], + [ 0.1193, 0.2829, 0.9037, 0.3957], + [-0.8202, -0.6474, -0.1631, -0.6543]], dtype=torch.float64) + result = torch.inverse(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +# Test with 3D batched input and keyword argument +def test_case_10(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[[ 0.7308, 1.0060, 0.5270, 1.4516], + [-0.1383, 1.5706, 0.4724, 0.4141], + [ 0.1193, 0.2829, 0.9037, 0.3957], + [-0.8202, -0.6474, -0.1631, -0.6543]], + [[ 0.7308, 1.0060, 0.5270, 1.4516], + [-0.1383, 1.5706, 0.4724, 0.4141], + [ 0.1193, 0.2829, 0.9037, 0.3957], + [-0.8202, -0.6474, -0.1631, -0.6543]]]) + result = torch.inverse(input=x) + """ + ) + obj.run(pytorch_code, ["result"], rtol=1e-5)