-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathegg_simulator.py
More file actions
executable file
·410 lines (358 loc) · 20.1 KB
/
egg_simulator.py
File metadata and controls
executable file
·410 lines (358 loc) · 20.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
import numpy as np
import paho.mqtt.client as mqtt
import time
import os
import json
import glob
import sys
import pandas as pd
# --- Configuration ---

# MQTT broker connection settings.
MQTT_HOST = "localhost"
MQTT_PORT = 1883
MQTT_TOPIC = "eeg/data"

# Data files are located next to this script.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
NPZ_DIR = os.path.join(_SCRIPT_DIR, "Npz")
CSV_FILE = os.path.join(_SCRIPT_DIR, "ESR.csv")

WINDOW_SIZE = 256    # samples per generated sample segment
SEND_INTERVAL = 2.0  # Default to 2 seconds instead of 1
class EEGSimulator:
    """Simulates a live EEG device by publishing data segments to an MQTT topic.

    Data comes either from a CSV file of per-row feature vectors (the default,
    when CSV_FILE exists) or from per-segment NPZ files in NPZ_DIR; synthetic
    sample NPZ files are generated when none are present.
    """

    def __init__(self, use_csv=True, interval=None):
        """Create the MQTT client and record streaming options.

        Args:
            use_csv: Prefer the CSV file over NPZ files when it exists.
            interval: Seconds to wait between published segments; falls back
                to the module-level SEND_INTERVAL when None.
        """
        self.client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2, client_id="eeg_simulator")
        self.client.on_connect = self.on_connect
        self.segment_counter = 0
        self.use_csv = use_csv
        self.send_interval = interval if interval is not None else SEND_INTERVAL
        print(f"Using send interval of {self.send_interval} seconds")

    def on_connect(self, client, userdata, flags, rc, properties=None):
        """MQTT connect callback; aborts the process when the connection fails."""
        if rc == 0:
            print("Connected to MQTT broker")
        else:
            print(f"Connection error: Code {rc}")
            sys.exit(1)

    def load_and_stream(self):
        """Stream from the CSV file when requested and present, else from NPZ files."""
        if self.use_csv and os.path.exists(CSV_FILE):
            self._stream_from_csv()
        else:
            self._stream_from_npz()

    def _publish_segment(self, data_list, label, window_size):
        """Publish one segment payload to MQTT, then sleep for the send interval.

        On a JSON serialization error an empty-data fallback payload is
        published instead, so downstream consumers still see the segment id.
        (This consolidates three verbatim copies of the same publish/fallback
        code from the NPZ streaming path.)
        """
        payload = {
            "timestamp": time.time(),
            "segment_id": self.segment_counter,
            "data": data_list,
            # Float NaN labels (possible with float label arrays) coerce to 0.
            "label": int(label) if not isinstance(label, float) or not np.isnan(label) else 0,
            "window_size": window_size
        }
        try:
            json_payload = json.dumps(payload)
            self.client.publish(MQTT_TOPIC, json_payload)
            print(f"Published segment {self.segment_counter} with {len(payload['data'])} data points")
            self.segment_counter += 1
        except TypeError as e:
            print(f"Error serializing payload: {e}")
            fallback_payload = {
                "timestamp": time.time(),
                "segment_id": self.segment_counter,
                "data": [],
                "label": 0,
                "error": str(e)
            }
            self.client.publish(MQTT_TOPIC, json.dumps(fallback_payload))
            self.segment_counter += 1
        time.sleep(self.send_interval)

    def _stream_from_csv(self):
        """Publish each CSV row as one segment (all non-index columns as features)."""
        try:
            print(f"Loading data from CSV file: {CSV_FILE}")
            df = pd.read_csv(CSV_FILE)
            print(f"CSV loaded successfully. Shape: {df.shape}")
            print(f"Columns in CSV: {df.columns.tolist()}")
            # Drop pandas' auto-generated index columns.
            feature_columns = [col for col in df.columns if not col.startswith('Unnamed:') and col != 'Unnamed']
            # 'y' is treated as the label column when present.
            target_column = 'y' if 'y' in feature_columns else None
            if target_column:
                feature_columns.remove(target_column)
                print(f"Using '{target_column}' as target column")
            print(f"Using {len(feature_columns)} feature columns")
            print(f"Using send interval of {self.send_interval} seconds between segments")
            total_rows = len(df)
            for idx, row in df.iterrows():
                try:
                    # Coerce every feature to float; unparseable cells become 0.0
                    # so a single bad cell never drops the whole row.
                    features = []
                    for col in feature_columns:
                        try:
                            features.append(float(row[col]))
                        except (ValueError, TypeError):
                            features.append(0.0)
                    try:
                        label = int(row[target_column]) if target_column and target_column in row else 0
                    except (ValueError, TypeError):
                        label = 0
                    payload = {
                        "timestamp": time.time(),
                        "segment_id": self.segment_counter,
                        "data": features,
                        "label": label,
                        "window_size": len(features)
                    }
                    try:
                        json_payload = json.dumps(payload)
                        self.client.publish(MQTT_TOPIC, json_payload)
                        print(f"Published segment {self.segment_counter} ({idx+1}/{total_rows}) with {len(features)} data points")
                        self.segment_counter += 1
                    except TypeError as e:
                        # Re-sanitize and retry rather than dropping the row.
                        print(f"Error serializing payload: {e}")
                        payload["data"] = [float(x) for x in features]
                        self.client.publish(MQTT_TOPIC, json.dumps(payload))
                        print(f"Published segment {self.segment_counter} with sanitized data")
                        self.segment_counter += 1
                    time.sleep(self.send_interval)
                except Exception as e:
                    print(f"Error processing row {idx}: {str(e)}")
                    import traceback
                    traceback.print_exc()
                    continue
            print(f"Finished streaming {total_rows} rows from CSV file")
        except Exception as e:
            print(f"Error streaming from CSV: {str(e)}")
            import traceback
            traceback.print_exc()
            sys.exit(1)

    def _stream_from_npz(self):
        """Publish segments from NPZ files in NPZ_DIR, creating sample data if needed."""
        try:
            if not os.path.exists(NPZ_DIR):
                os.makedirs(NPZ_DIR)
                print(f"Created directory {NPZ_DIR}")
                print("Please add NPZ files to this directory and restart the simulator")
                # Create a sample file if none exists
                self._create_sample_data()
            # Load files in numerical order of their segment index.
            npz_files = sorted(
                glob.glob(os.path.join(NPZ_DIR, 'segment_*.npz')),
                key=lambda x: int(x.split('_')[-1].split('.')[0])
            )
            if not npz_files:
                print(f"No NPZ files found in {NPZ_DIR}")
                print("Creating sample data...")
                self._create_sample_data()
                # Try again after creating sample data.
                npz_files = sorted(glob.glob(os.path.join(NPZ_DIR, 'segment_*.npz')))
                if not npz_files:
                    print("Failed to create sample data. Exiting.")
                    sys.exit(1)
            print(f"Loading {len(npz_files)} NPZ files...")
            # Inspect the first file to discover the data/label key names,
            # trying common alternatives when the defaults are absent.
            with np.load(npz_files[0]) as first_data:
                print(f"Available keys in NPZ file: {list(first_data.keys())}")
                data_key = 'data' if 'data' in first_data else None
                label_key = 'label' if 'label' in first_data else None
                if data_key is None:
                    for possible_key in ['segments', 'eeg_data', 'segment', 'X']:
                        if possible_key in first_data:
                            data_key = possible_key
                            break
                if label_key is None:
                    for possible_key in ['labels', 'y', 'targets', 'classes']:
                        if possible_key in first_data:
                            label_key = possible_key
                            break
                if data_key is None:
                    print("Error: Could not find EEG data in NPZ files.")
                    print(f"Available keys are: {list(first_data.keys())}")
                    sys.exit(1)
                data_array = first_data[data_key]
                print(f"Data structure: dimension {data_array.ndim}, shape {data_array.shape}")
                if label_key:
                    print(f"Label structure: dimension {first_data[label_key].ndim}, shape {first_data[label_key].shape}")
            # Process and stream each file; a bad file is skipped, not fatal.
            for file_path in npz_files:
                try:
                    with np.load(file_path) as data:
                        eeg_data = data[data_key]
                        if label_key and label_key in data:
                            label = data[label_key]
                            # Convert a 0-d label array to a scalar.
                            if label.ndim == 0:
                                label = label.item()
                        else:
                            label = 0
                        print(f"Processing file {os.path.basename(file_path)}")
                        if eeg_data.ndim == 0:
                            # Single scalar value.
                            self._publish_segment([eeg_data.item()], label, 1)
                        elif eeg_data.ndim == 1:
                            # 1D time series.
                            self._publish_segment(eeg_data.tolist(), label, len(eeg_data))
                        else:
                            # 2D+ (channels x time, or trials x channels x time).
                            if eeg_data.ndim == 2:
                                # A wide matrix is assumed to be time x channels;
                                # transpose so we work with channels x time.
                                if eeg_data.shape[1] > 100 and eeg_data.shape[0] < eeg_data.shape[1]:
                                    print(f"Transposing matrix from shape {eeg_data.shape} to {eeg_data.T.shape}")
                                    eeg_data = eeg_data.T
                                if eeg_data.shape[0] > 20:
                                    print(f"Using only first 10 channels from {eeg_data.shape[0]} channels")
                                    segments = [eeg_data[:10]]
                                else:
                                    segments = [eeg_data]
                            else:
                                # For 3D+, the first dimension indexes segments.
                                segments = [eeg_data[i] for i in range(eeg_data.shape[0])]
                            for segment in segments:
                                # Keep the payload small: send one channel (or a
                                # flattened, truncated view) per segment.
                                if isinstance(segment, np.ndarray):
                                    if segment.ndim == 2:
                                        if segment.shape[0] > segment.shape[1]:  # more channels than time points
                                            data_to_send = segment[0].tolist()
                                        else:  # more time points than channels
                                            data_to_send = segment[:, 0].tolist()
                                    elif segment.ndim == 1:
                                        data_to_send = segment.tolist()
                                    else:
                                        data_to_send = segment.flatten()[:256].tolist()
                                else:
                                    data_to_send = segment  # already a list
                                window_size = segment.shape[-1] if hasattr(segment, 'shape') and segment.ndim > 1 else len(segment)
                                self._publish_segment(data_to_send, label, window_size)
                except Exception as e:
                    print(f"Error processing file {file_path}: {str(e)}")
                    continue
        except Exception as e:
            print(f"Error: {str(e)}")
            sys.exit(1)

    def _create_sample_data(self):
        """Create 5 sample NPZ files for testing if none exist.

        Each file holds a noisy sine wave; every other file gets an amplitude
        spike and label 1 to simulate a seizure segment.
        """
        try:
            if not os.path.exists(NPZ_DIR):
                os.makedirs(NPZ_DIR)
            for i in range(5):
                t = np.linspace(0, 10, 256)  # 10 seconds, 256 samples
                frequency = 0.5 + np.random.random() * 2  # random 0.5-2.5 Hz
                amplitude = 2.0 + np.random.random() * 3.0
                noise = np.random.normal(0, 0.5, 256)
                data = amplitude * np.sin(2 * np.pi * frequency * t) + noise
                if i % 2 == 0:
                    # Every other file gets a "seizure" spike.
                    spike_idx = np.random.randint(50, 200)
                    data[spike_idx:spike_idx+25] = data[spike_idx:spike_idx+25] * 3 + 5
                    label = 1  # seizure
                else:
                    label = 0  # normal
                filename = os.path.join(NPZ_DIR, f'segment_{i+1:05d}.npz')
                np.savez(filename, data=data, label=label)
                # Fixed: the original printed a placeholder instead of the name.
                print(f"Created sample file {os.path.basename(filename)}")
            print(f"Created 5 sample NPZ files in {NPZ_DIR}")
            return True
        except Exception as e:
            print(f"Error creating sample data: {str(e)}")
            return False

    def run(self):
        """Connect to the broker, stream all data, then shut the client down."""
        try:
            print(f"Connecting to MQTT broker at {MQTT_HOST}:{MQTT_PORT}...")
            try:
                self.client.connect(MQTT_HOST, MQTT_PORT)
            except ConnectionRefusedError:
                print("\nERROR: Could not connect to MQTT broker.")
                print("Make sure an MQTT broker is running on localhost:1883.")
                print("You can install and start Mosquitto with these commands:")
                print(" sudo apt-get install mosquitto mosquitto-clients")
                print(" sudo systemctl start mosquitto")
                sys.exit(1)
            self.client.loop_start()
            # Give the network loop a moment to establish the connection.
            time.sleep(2)
            self.load_and_stream()
        finally:
            self.client.loop_stop()
            self.client.disconnect()
if __name__ == "__main__":
    import argparse

    # Command-line interface: choose the data source and the pacing interval.
    arg_parser = argparse.ArgumentParser(description='EEG Simulator')
    arg_parser.add_argument('--npz', action='store_true', help='Use NPZ files instead of CSV')
    arg_parser.add_argument('--interval', type=float, default=SEND_INTERVAL,
                            help=f'Time interval between segments in seconds (default: {SEND_INTERVAL})')
    cli_args = arg_parser.parse_args()

    # CSV is the default source unless --npz was given.
    from_csv = not cli_args.npz
    print("Using CSV file for simulation" if from_csv else "Using NPZ files for simulation")

    seg_interval = cli_args.interval
    print(f"Using interval of {seg_interval} seconds between segments")

    EEGSimulator(use_csv=from_csv, interval=seg_interval).run()