
How to solve the huge precision drift after torch.compile #44

@hantao-zhou

Description

izero
Using cuda and torch.bfloat16...


=========================
Final action dimensions: torch.Size([2, 4, 7])
Final action values: tensor([[[-1.0000,  1.0000,  0.7852, -0.2695, -1.0000,  0.9688, -0.7930],
         [-1.0000,  1.0000, -0.8281, -1.0000,  0.3750,  0.3145, -0.8047],
         [-1.0000,  1.0000, -0.3516, -0.0557,  1.0000, -0.2285,  0.9336],
         [-1.0000, -0.0381,  1.0000, -0.7969, -0.7773,  1.0000, -1.0000]],

        [[-0.2256, -0.4766, -0.8828, -1.0000,  0.8203,  1.0000, -0.3555],
         [-0.3887, -1.0000,  1.0000,  1.0000,  0.1641,  1.0000,  0.8906],
         [-0.5781,  0.3047,  1.0000, -1.0000, -0.6719,  0.5391, -0.5000],
         [-1.0000, -0.1162,  1.0000, -0.3691,  0.6406,  1.0000, -1.0000]]],
       device='cuda:0', dtype=torch.bfloat16)
Time taken: 0.3848745822906494
============================


Max absolute difference between uncompiled and compiled: 2.0
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/hansgutmann/workspace/accelerate_pizero/open-pi-zero/src/model/vla/pizero.py", line 931, in <module>
    assert torch.allclose(uncompiled_results, compiled_results, atol=1e-3, rtol=1e-3)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError

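This is the error: after torch.compile, the model's output drifts hugely from the uncompiled output. The max absolute difference is 2.0, so the torch.allclose check with atol=1e-3 and rtol=1e-3 fails.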
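To put the numbers in the log into perspective, here is a minimal pure-Python sketch of how coarse bfloat16 is compared with the 1e-3 tolerances in the assertion. The helpers to_bfloat16 and is_close are hypothetical names written for this illustration and are not part of the repository. is_close only mirrors the elementwise rule that torch.allclose documents, |a - b| <= atol + rtol * |b|. The sketch assumes round-to-nearest-even when narrowing float32 to bfloat16.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (hypothetical helper, illustration only)."""
    # Reinterpret the float32 encoding of x as an unsigned 32-bit integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the top 16 bits of float32; round to nearest, ties to even.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (((bits + rounding_bias) >> 16) << 16) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def is_close(a: float, b: float, atol: float = 1e-3, rtol: float = 1e-3) -> bool:
    """Elementwise closeness rule as documented for torch.allclose."""
    return abs(a - b) <= atol + rtol * abs(b)


# A value from the log, snapped onto the bfloat16 grid.
a = to_bfloat16(0.7852)
# In [0.5, 1) neighbouring bfloat16 values are 2**-8 apart.
b = a + 2.0 ** -8

print("a =", a, "b =", b, "step =", b - a)
print("adjacent bfloat16 values pass the 1e-3 check:", is_close(a, b))
print("-1.0 vs 1.0 passes the 1e-3 check:", is_close(-1.0, 1.0))
```

Under these assumptions, two outputs that differ by a single bfloat16 step near 0.8 already fail the atol=1e-3, rtol=1e-3 check, so that tolerance is likely too strict for bfloat16 outputs in general. The reported difference of 2.0 is a different matter: it spans the whole [-1, 1] range seen in the printed actions, which rounding alone would not normally explain, so it may be worth checking whether both runs start from the same inputs and random seed before attributing it to precision.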
