diff --git a/visualizer.py b/visualizer.py index c826e42..ad975c4 100644 --- a/visualizer.py +++ b/visualizer.py @@ -23,11 +23,19 @@ def showImage(image, prediction=None, dimensions=None, axis=None, cmap='gray'): [cmap] - color map. Default is gray. """ - if not dimensions: - dim = int(np.sqrt(len(image))) - dimensions = (dim, dim) + # our images can have the color layer as first dimension (1 = greyscale) + if image.shape[0] == 1: + image = image[0] - img = image.reshape(dimensions[0], dimensions[1]) + # reshape if the image isn't already in right format + if len(image.shape) < 2: + if not dimensions: + dim = int(np.sqrt(len(image))) + dimensions = (dim, dim) + + img = image.reshape(dimensions[0], dimensions[1]) + else: + img = image if axis: axis.imshow(img, cmap=cmap)