-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathexplore.py
More file actions
37 lines (27 loc) · 1.09 KB
/
explore.py
File metadata and controls
37 lines (27 loc) · 1.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import tensorflow as tf
import tensorflow_federated as tff
import tensorflow_datasets as tfds
import collections
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
# Tunables for the input pipeline built by `preprocess` below.
NUM_EPOCHS = 5  # NOTE(review): not referenced by `preprocess` in this file.
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10

# Load the federated EMNIST digits-only split (train and test client data).
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
    only_digits=True, cache_dir=None)

# Take the local dataset of the first client as a representative example.
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

# Count the client's examples by iterating. `len()` on a tf.data.Dataset raises
# TypeError when the dataset's cardinality is unknown, which depends on how the
# client data is backed, so iteration is the portable way to get the size.
print(sum(1 for _ in example_dataset))
def preprocess(dataset, batch_size, *, shuffle_buffer=None,
               prefetch_buffer=None, seed=1):
    """Shuffle, batch, and reshape a client dataset.

    Each element of ``dataset`` is expected to be a mapping with ``'pixels'``
    and ``'label'`` entries, as in the EMNIST client data loaded by this
    script (TODO confirm before reusing with other datasets).

    Args:
        dataset: A ``tf.data.Dataset`` of individual examples.
        batch_size: Number of examples per batch. The last batch may be
            smaller if the dataset size is not a multiple of ``batch_size``.
        shuffle_buffer: Shuffle buffer size. ``None`` (the default) uses the
            module-level ``SHUFFLE_BUFFER``.
        prefetch_buffer: Prefetch buffer size. ``None`` (the default) uses the
            module-level ``PREFETCH_BUFFER``.
        seed: Shuffle seed; fixed by default so runs are reproducible.

    Returns:
        A ``tf.data.Dataset`` of ``OrderedDict(x=..., y=...)`` batches, where
        ``x`` is reshaped to ``[-1, 28, 28]`` (one 28x28 image per example)
        and ``y`` is reshaped to a column ``[-1, 1]`` of labels.
    """
    # Resolve defaults at call time so later changes to the module constants
    # are honored, exactly as the original hard-coded references were.
    if shuffle_buffer is None:
        shuffle_buffer = SHUFFLE_BUFFER
    if prefetch_buffer is None:
        prefetch_buffer = PREFETCH_BUFFER

    def batch_format_fn(element):
        """Reshape a batch into 28x28 images ``x`` and a label column ``y``."""
        return collections.OrderedDict(
            x=tf.reshape(element['pixels'], [-1, 28, 28]),
            y=tf.reshape(element['label'], [-1, 1]))

    return (dataset
            .shuffle(shuffle_buffer, seed=seed)
            .batch(batch_size)
            .map(batch_format_fn)
            .prefetch(prefetch_buffer))
# Build batches of 93 examples from the example client, then convert the first
# batch's tensors to NumPy arrays and show the first image.
processed = preprocess(example_dataset, 93)
first_batch = next(iter(processed))
sample_batch = tf.nest.map_structure(lambda tensor: tensor.numpy(), first_batch)
print(sample_batch['x'][0])