-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
133 lines (103 loc) · 3.6 KB
/
main.py
File metadata and controls
133 lines (103 loc) · 3.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from tffdataset.DatasetUtils import DatasetID, getDataset
from tffdataset.FedDataset import FedDataset, PartitioningScheme
from tffmodel.FedCoreModel import FedCoreModel
from tffmodel.FedApiModel import FedApiModel
from tffmodel.KerasModel import KerasModel
from tffmodel.ModelUtils import ModelUtils
import getopt
import logging
import sys
import tensorflow as tf
# Experiment settings shared by the whole script: the dataset helpers
# (getDataset, FedDataset), the three model wrappers and
# ModelUtils.printEvaluations all receive this dict.
config = dict(
    seed=13,
    dataset_id=DatasetID.Mnist,
    part_scheme=PartitioningScheme.ROUND_ROBIN,
    num_workers=4,
    num_train_rounds=10,
    log_dir="./log/training",
    log_level=logging.DEBUG,  # applied to the "main.py" logger in main()
)
def trainLocalKeras(dataset, config):
    """Train a centralized (non-federated) Keras model and evaluate it.

    Args:
        dataset: Centralized dataset. Its ``train`` attribute initializes the
            model, the whole object is passed to ``fit`` and its ``val``
            attribute is used for evaluation.
        config: Experiment settings forwarded to ``KerasModel``.

    Returns:
        Whatever ``KerasModel.evaluate`` returns for ``dataset.val``.
    """
    local_model = KerasModel(config)
    local_model.initModel(dataset.train)
    # NOTE(review): fit() receives the full dataset object, not dataset.train;
    # presumably KerasModel picks the splits itself — confirm.
    local_model.fit(dataset)
    return local_model.evaluate(dataset.val)
def trainFedApi(dataset, fed_dataset, config):
    """Train the federated model built with the TFF API wrapper and evaluate it.

    Args:
        dataset: Centralized dataset. Unused by the active code path; kept for
            the centralized-evaluation alternative noted below.
        fed_dataset: Partitioned dataset. The model is fit on it and evaluated
            on its ``val`` attribute.
        config: Experiment settings forwarded to ``FedApiModel``.

    Returns:
        Whatever ``FedApiModel.evaluate`` returns for ``fed_dataset.val``.
    """
    api_model = FedApiModel(config)
    api_model.fit(fed_dataset)
    # Alternative: api_model.evaluateCentralized(dataset.val)
    return api_model.evaluate(fed_dataset.val)
def trainFedCore(dataset, fed_dataset, config):
    """Train the federated model built with the TFF core wrapper and evaluate it.

    Args:
        dataset: Centralized dataset. Unused by the active code path; kept for
            the centralized-evaluation alternative noted below.
        fed_dataset: Partitioned dataset. The model is fit on it and evaluated
            on its ``val`` attribute.
        config: Experiment settings forwarded to ``FedCoreModel``.

    Returns:
        Whatever ``FedCoreModel.evaluate`` returns for ``fed_dataset.val``.
    """
    core_model = FedCoreModel(config)
    core_model.fit(fed_dataset)
    # Alternative: core_model.evaluateCentralized(dataset.val)
    return core_model.evaluate(fed_dataset.val)
def main(argv):
    """Train a local Keras model and two federated variants, then log a comparison.

    Loads the dataset selected by the module-level ``config`` and builds the
    federated partitions from it. It then trains the Keras, TFF-API and
    TFF-core models in sequence and logs the summary returned by
    ``ModelUtils.printEvaluations``.

    Args:
        argv: Argument vector in ``sys.argv`` form. ``argv[0]`` is the program
            name shown in usage messages. ``-h``/``--help`` prints usage and
            exits. ``-l`` is accepted by the parser but currently has no
            effect, and positional arguments are ignored.

    Raises:
        SystemExit: status 0 (``code`` None) after ``-h``/``--help``, and
            status 2 when an unrecognised option is given.
    """
    # Handle the command line before any logging or dataset work so that
    # --help and usage errors return immediately.
    try:
        opts, _args = getopt.getopt(argv[1:], "hl", ["help"])
    except getopt.GetoptError as err:
        # Diagnostics belong on stderr; include getopt's reason for failing.
        print("Wrong usage.", file=sys.stderr)
        print(err, file=sys.stderr)
        print("Usage:", argv[0], file=sys.stderr)
        sys.exit(2)
    for opt, _arg in opts:
        if opt in ("-h", "--help"):
            print("Usage:", argv[0])
            sys.exit()

    # With no handler anywhere in the logging hierarchy, INFO records only
    # reach logging.lastResort, which drops anything below WARNING, so the
    # final summary would never be shown. basicConfig() is a no-op when the
    # root logger already has a handler, so it only fills that gap.
    logging.basicConfig()
    logger = logging.getLogger("main.py")
    logger.setLevel(config["log_level"])

    # obtain the dataset (either load or compute the response labels)
    dataset = getDataset(config)
    dataset.load()

    # construct data partitions for federated execution
    fed_dataset = FedDataset(config)
    fed_dataset.construct(dataset)
    fed_dataset.batch()
    dataset.batch()

    # Train the three variants in sequence and collect their metrics by name.
    evaluations = {}
    evaluations["keras"] = trainLocalKeras(dataset, config)
    evaluations["fedapi"] = trainFedApi(dataset, fed_dataset, config)
    evaluations["fedcore"] = trainFedCore(dataset, fed_dataset, config)
    logger.info(ModelUtils.printEvaluations(evaluations, config))

    # Disabled hyperparameter sweeps, kept for reference:
    # model_abbrvs = [
    #     "c10_avg_dr25",
    #     "c20_c10_avg_dr25",
    #     "c40_c20_c10_avg_dr25",
    #     "c32_c64_c16_avg_fl_dr50",
    #     "c64_c32_avg_dr50_dr25",
    #     "c64_c32_c32_avg_dr50_dr25",
    #     "c96_c64_c32_avg_dr50_dr25"
    # ]
    # model_evaluations = dict()
    # for ma in model_abbrvs:
    #     config["model"] = ma
    #     model_evaluations[ma] = trainLocalKeras(dataset, config)
    # print(model_evaluations)
    # learning_rates = [
    #     0.1,
    #     0.01,
    #     0.001,
    #     0.0005,
    #     0.0001,
    #     0.00005,
    #     0.00001
    # ]
    # model_evaluations = dict()
    # for lr in learning_rates:
    #     print(f'Training with learning rate {lr}')
    #     config["learning_rate"] = lr
    #     model_evaluations[lr] = trainLocalKeras(dataset, config)
    # print(model_evaluations)
# Script entry point: hand main() the full argv, program name included.
if __name__ == "__main__":
    main(argv=sys.argv)