-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDataset_Support.py
More file actions
151 lines (116 loc) · 5.22 KB
/
Dataset_Support.py
File metadata and controls
151 lines (116 loc) · 5.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# File: Dataset_Support.py
# Owner: Jeff Brown
# Dependencies: pandas, numpy, matplotlib, seaborn, tensorflow (Keras)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math
# Machine Learning - Keras (Tensorflow) - Dataset Generation
from tensorflow.keras.datasets import mnist # Images: Handwritten digits 0-9 (28x28 grayscale, 60K train, 10K test)
# Function to import the Keras MNIST handwritten digits sample dataset
def mnist_load_ds():
    """Load the Keras MNIST handwritten-digits sample dataset.

    Returns:
        ((X_train, y_train), (X_test, y_test)) -- the train and test splits
        returned by ``mnist.load_data()``, unpacked and repacked into the same
        nested-tuple structure.
    """
    train_split, test_split = mnist.load_data()
    # Unpack each split explicitly so an unexpected structure fails loudly here
    train_images, train_labels = train_split
    test_images, test_labels = test_split
    return (train_images, train_labels), (test_images, test_labels)
# Function to plot up to 10 digits from a list, one subplot per digit, in a single figure
def mnist_plot_digit_list(a_X_list=None, a_y_list=None, a_find_all_digits=False):
    """Plot up to 10 digit images in one figure, at most 5 subplots per row.

    a_X_list: Sequence of images; each selected element is passed to ``imshow``.
    a_y_list: Sequence of labels, parallel to ``a_X_list``.
    a_find_all_digits:
        False: Plot the first (up to) 10 samples in the list.
        True:  Plot the first occurrence of each digit 0-9 in ``a_y_list``,
               searching from index 0; digits that do not occur are skipped.

    Returns:
        None if either list is missing. Otherwise, the number of images
        plotted. When nothing is selected this is 0 and no figure is created.
    """
    # Both the images and their labels are required
    if a_X_list is None or a_y_list is None:
        return None
    X_list = list(a_X_list)
    y_list = list(a_y_list)

    if a_find_all_digits:
        # Index of the first sample of each digit 0-9; absent digits are skipped
        d_i_list = []
        for digit in range(10):
            try:
                d_i_list.append(y_list.index(digit))
            except ValueError:
                # Digit is not present in the input list -- move on to the next one
                pass
    else:
        # Indices of (up to) the first 10 samples
        d_i_list = list(range(min(10, len(y_list))))

    print("Indices:", d_i_list)

    n_plots = len(d_i_list)
    if n_plots == 0:
        # Nothing to draw; avoid creating an empty figure
        return 0

    # Interpolation method used when rendering the digit images
    interpolation = 'lanczos'
    # Lay the subplots out in rows of at most 5 columns
    n_cols = min(5, n_plots)
    n_rows = math.ceil(n_plots / n_cols)

    fig = plt.figure(figsize=(20, 9))
    for plot_no, sample_i in enumerate(d_i_list, start=1):
        ax = fig.add_subplot(n_rows, n_cols, plot_no)
        # Title each subplot with the sample's label and its index in the input
        ax.set_title(f"Label: {y_list[sample_i]}\nSample Index: {sample_i}")
        ax.imshow(X_list[sample_i], cmap=plt.cm.Greys, interpolation=interpolation)

    # Number of digits plotted
    return n_plots
# Function to plot example images
def _squeeze_except_first_axis(a_values):
    """Return ``a_values`` as an ndarray with every size-1 axis except axis 0 removed."""
    arr = np.asarray(a_values)
    # Never squeeze axis 0: with a single example it is the sample axis and must survive
    singleton_axes = tuple(ax for ax in range(1, arr.ndim) if arr.shape[ax] == 1)
    return arr.squeeze(axis=singleton_axes) if singleton_axes else arr


def plot_examples(a_X_list=None, a_y_list=None, a_find_labels=False, a_label=None):
    """
    Plot examples from the specified list, at most 5 subplots per row.

    a_X_list: Array-like of example inputs (each selected one is passed to imshow).
              Size-1 axes other than the first (e.g. a trailing channel axis)
              are removed.
    a_y_list: Array-like of example labels, parallel to a_X_list. A column
              vector of shape (n, 1) is flattened to (n,).
    a_find_labels:
        False (or any falsy value): Plot all examples. Every example gets its
            own subplot, so pass a small slice of a large dataset.
        True:
            If a_label is None, plot the first example of each unique label,
            ordered by label.
            If a_label is not None, plot all examples with that single label.
    Returns:
        None if either list is missing. Otherwise, the number of examples
        plotted. When nothing is selected this is 0 and no figure is created.
    """
    # Both the examples and their labels are required
    if a_X_list is None or a_y_list is None:
        return None
    X_list = list(_squeeze_except_first_axis(a_X_list))
    y_list = list(_squeeze_except_first_axis(a_y_list))

    # Decide which example indices to plot
    if not a_find_labels:
        # Plot all examples
        d_i_list = list(range(len(y_list)))
    elif a_label is None:
        # First index of each unique label, in sorted-label order
        first_index = {}
        for i, label in enumerate(y_list):
            first_index.setdefault(label, i)
        d_i_list = [first_index[label] for label in sorted(first_index)]
    else:
        # Every example carrying the requested label
        d_i_list = [i for i, label in enumerate(y_list) if label == a_label]

    n_plots = len(d_i_list)
    if n_plots == 0:
        # Nothing matched; avoid creating an empty figure
        return 0

    # Lay the subplots out in rows of at most 5 columns
    n_cols = min(5, n_plots)
    n_rows = math.ceil(n_plots / n_cols)

    fig = plt.figure(figsize=(20, 10))
    for plot_no, sample_i in enumerate(d_i_list, start=1):
        ax = fig.add_subplot(n_rows, n_cols, plot_no)
        # Title each subplot with the example's label and its index in the input
        ax.set_title(f"Label: {y_list[sample_i]}\nSample Index: {sample_i}")
        ax.imshow(X_list[sample_i], cmap=plt.cm.Greys, interpolation='lanczos')

    # Number of examples plotted
    return n_plots