Skip to content

Commit 9de912e

Browse files
Attempt at numpy 2.x string formatting
1 parent 8993958 commit 9de912e

1 file changed

Lines changed: 13 additions & 4 deletions

File tree

src/image_utils/custom_library_ops.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import random
44
from . import is_tensor, is_ndarray, is_arr
55

6-
76
def generic_print(self, arr_values):
87
assert is_arr(self)
98

@@ -48,15 +47,25 @@ def get_first_and_last_lines(text):
4847
torch.Tensor.__repr__ = lambda self: generic_print(self, normal_repr(self))
4948

5049
np.set_printoptions(suppress=True, precision=3, threshold=10, edgeitems=2, linewidth=120)
51-
import copy
5250

5351
normal_repr_ = np.ndarray.__str__
54-
np.set_string_function(lambda self: generic_print(self, normal_repr_(self)), repr=True)
52+
if int(np.__version__.split('.')[0]) >= 2:
53+
np.set_printoptions(override_repr=lambda self: generic_print(np.array(self), normal_repr_(np.array(self))))
54+
else:
55+
np.set_string_function(lambda self: generic_print(self, normal_repr_(self)), repr=True)
5556

5657
def disable():
5758
torch.set_printoptions(profile="default")
5859
torch.Tensor.__repr__ = normal_repr
59-
np.set_string_function(normal_repr_, repr=True)
60+
if int(np.__version__.split('.')[0]) >= 2:
61+
pass
62+
# TODO: Currently broken for numpy 2.x
63+
# np.set_printoptions(formatter={
64+
# 'int_kind': normal_repr_,
65+
# 'float_kind': normal_repr_,
66+
# })
67+
else:
68+
np.set_string_function(normal_repr_, repr=True)
6069

6170
def set_random_seeds():
6271
torch.manual_seed(0)

0 commit comments

Comments
 (0)