-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
50 lines (39 loc) · 2 KB
/
utils.py
File metadata and controls
50 lines (39 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import numpy as np
from sklearn.datasets import make_classification, make_regression
from sklearn.preprocessing import StandardScaler
def generate_and_preprocess_data(n_workers, config):
    """Generate a synthetic dataset and split it into non-IID worker shards.

    A dataset is drawn with scikit-learn (fixed ``random_state=203``), the
    features are standardized with ``StandardScaler``, and a constant column
    of ones is appended as a bias term. Samples are then ordered by target
    value and cut into ``n_workers`` contiguous shards, so each worker sees a
    different slice of the target distribution.

    Args:
        n_workers: Number of shards to create. Must satisfy
            ``1 <= n_workers <= n_samples`` so that no shard is empty.
        config: Mapping with the keys
            ``'problem_type'`` (``'logistic'`` or ``'quadratic'``),
            ``'n_samples'``, ``'n_features'``, ``'n_informative_features'``
            and, optionally, ``'classification_sep'`` (defaults to 0.8 and is
            only used for ``'logistic'``).

    Returns:
        A tuple ``(worker_data, n_features_bias, X_scaled_bias, y)`` where
        ``worker_data`` is a list with one ``{'X': ..., 'y': ...}`` dict per
        worker, ``n_features_bias`` is the column count of the scaled feature
        matrix including the bias column, and ``X_scaled_bias`` / ``y`` are
        the full (unsharded) features and targets.

    Raises:
        NotImplementedError: If ``problem_type`` is not supported.
        ValueError: If ``n_workers`` is smaller than 1 or larger than
            ``n_samples``.
    """
    problem_type = config['problem_type']
    n_samples = config['n_samples']
    n_features = config['n_features']
    n_informative = config['n_informative_features']
    class_sep = config.get('classification_sep', 0.8)

    print("Generating Non-IID data")

    # Reject bad arguments before paying for data generation. Previously an
    # oversized n_workers only failed later, inside np.min() on an empty
    # shard, with an unrelated "zero-size array" message.
    if problem_type not in ('logistic', 'quadratic'):
        raise NotImplementedError(f"Wrong {problem_type}")
    if n_workers < 1:
        raise ValueError(f"n_workers must be at least 1, got {n_workers}")
    if n_workers > n_samples:
        raise ValueError(
            f"n_workers ({n_workers}) cannot exceed n_samples ({n_samples}); "
            "every worker needs at least one sample"
        )

    if problem_type == 'logistic':
        X, y = make_classification(
            n_samples=n_samples,
            n_features=n_features,
            n_informative=n_informative,
            # Every non-informative feature is made redundant.
            n_redundant=n_features - n_informative,
            n_clusters_per_class=1,
            flip_y=0.05,
            class_sep=class_sep,
            random_state=203,
        )
        y = 2 * y - 1  # Normalize to -1, 1
    else:
        # The ground-truth coefficients are returned but not needed here.
        X, y, _ = make_regression(
            n_samples=n_samples,
            n_features=n_features,
            n_informative=n_informative,
            noise=10.0,
            coef=True,
            random_state=203,
        )

    X_scaled = StandardScaler().fit_transform(X)
    # Append a column of ones so a bias/intercept can be learned as a weight.
    X_scaled_bias = np.hstack([X_scaled, np.ones((X_scaled.shape[0], 1))])
    n_features_bias = X_scaled_bias.shape[1]

    # Force non-IID by sorting on the target, then hand out contiguous chunks.
    sorted_indices = np.argsort(y)
    worker_indices = np.array_split(sorted_indices, n_workers)

    worker_data = []
    for i, idx in enumerate(worker_indices):
        X_local_data = X_scaled_bias[idx, :]
        y_local_data = y[idx]
        worker_data.append({'X': X_local_data, 'y': y_local_data})
        print(f"Worker {i}: {len(idx)} samples, Target y range: [{np.min(y_local_data):.2f}, {np.max(y_local_data):.2f}], Mean y: {np.mean(y_local_data):.2f}")

    print(f"Generated {n_samples} samples, {n_features_bias} features")
    return worker_data, n_features_bias, X_scaled_bias, y