|
3 | 3 | import random |
4 | 4 | from . import is_tensor, is_ndarray, is_arr |
5 | 5 |
|
6 | | - |
7 | 6 | def generic_print(self, arr_values): |
8 | 7 | assert is_arr(self) |
9 | 8 |
|
@@ -48,15 +47,25 @@ def get_first_and_last_lines(text): |
48 | 47 | torch.Tensor.__repr__ = lambda self: generic_print(self, normal_repr(self)) |
49 | 48 |
|
50 | 49 | np.set_printoptions(suppress=True, precision=3, threshold=10, edgeitems=2, linewidth=120) |
51 | | -import copy |
52 | 50 |
|
53 | 51 | normal_repr_ = np.ndarray.__str__ |
54 | | -np.set_string_function(lambda self: generic_print(self, normal_repr_(self)), repr=True) |
| 52 | +if int(np.__version__.split('.')[0]) >= 2: |
| 53 | + np.set_printoptions(override_repr=lambda self: generic_print(np.array(self), normal_repr_(np.array(self)))) |
| 54 | +else: |
| 55 | + np.set_string_function(lambda self: generic_print(self, normal_repr_(self)), repr=True) |
55 | 56 |
|
56 | 57 | def disable(): |
57 | 58 | torch.set_printoptions(profile="default") |
58 | 59 | torch.Tensor.__repr__ = normal_repr |
59 | | - np.set_string_function(normal_repr_, repr=True) |
| 60 | + if int(np.__version__.split('.')[0]) >= 2: |
| 61 | + pass |
| 62 | + # TODO: Currently broken for numpy 2.x |
| 63 | + # np.set_printoptions(formatter={ |
| 64 | + # 'int_kind': normal_repr_, |
| 65 | + # 'float_kind': normal_repr_, |
| 66 | + # }) |
| 67 | + else: |
| 68 | + np.set_string_function(normal_repr_, repr=True) |
60 | 69 |
|
61 | 70 | def set_random_seeds(): |
62 | 71 | torch.manual_seed(0) |
|
0 commit comments