diff --git a/.gitignore b/.gitignore index 86a9bee..62acad8 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ node_modules/ backend/.sqlx/*.json backend/.sqlx/query-* target/ +backend/shared-logic/src/signal_processing/moss/checkpoints/ +backend/shared-logic/src/signal_processing/moss/moss_models/*.npz diff --git a/backend/shared-logic/src/signal_processing/moss/README.txt b/backend/shared-logic/src/signal_processing/moss/README.txt new file mode 100644 index 0000000..7615421 --- /dev/null +++ b/backend/shared-logic/src/signal_processing/moss/README.txt @@ -0,0 +1,157 @@ +================================================================ +MOSS BCI Platform — Predict Package +Mental State Classifier using NeuroLM + Muse 2 +================================================================ +Version: 1.0 | March 2026 | UBC MINT Team + +---------------------------------------------------------------- +WHAT THIS DOES +---------------------------------------------------------------- +Given a Muse 2 EEG recording (CSV from Mind Monitor app), +this tool predicts your mental state using a frozen NeuroLM +foundation model + trained classifiers. 
+ +Available tasks: + activity — what you were doing: eat / game / read / rest / toy / tv + focus — attention level: relaxed / neutral / concentrating + emotion — emotional state: neutral / anger / fear / happiness / sadness + stress — stress level: Low / Moderate / High (experimental, not reliable) + +---------------------------------------------------------------- +REQUIREMENTS +---------------------------------------------------------------- +- Windows 10/11 (Mac/Linux also works with minor path changes) +- Miniconda or Anaconda: https://www.anaconda.com/download +- ~4GB free disk space (for NeuroLM weights + environment) +- Muse 2 headband + Mind Monitor app (iOS/Android, ~$15) + +---------------------------------------------------------------- +ONE-TIME SETUP (do this once) +---------------------------------------------------------------- +1. Install Miniconda if you don't have it + https://docs.anaconda.com/miniconda/ + +2. Double-click setup.bat (or run it from Anaconda Prompt) + This will: + - Create a Python environment called "MOSS" + - Install all required packages + - Takes about 5-10 minutes + +3. Download NeuroLM model weights (ONE required file, ~500MB): + https://huggingface.co/username/neurolm (ask Natalia for link) + + Place the file here: + MOSS\checkpoints\checkpoints\NeuroLM-B.pt + + Your folder structure should look like: + MOSS\ + checkpoints\ + checkpoints\ + NeuroLM-B.pt <-- put it here + moss_models\ + muse2_classifier.pkl + focus_classifier.pkl + emotion_classifier.pkl + stress_classifier.pkl + muse2_predict.py + setup.bat + predict.bat + README.txt + +---------------------------------------------------------------- +RECORDING YOUR EEG +---------------------------------------------------------------- +1. Open Mind Monitor app on your phone +2. Connect your Muse 2 headband +3. Press record — sit still and do your task for at least 2 minutes + (longer = more reliable prediction) +4. 
Export the CSV: + Mind Monitor → Menu → Export CSV → save to your computer + +The CSV will have columns like: + TimeStamp, RAW_TP9, RAW_AF7, RAW_AF8, RAW_TP10, ... + +---------------------------------------------------------------- +RUNNING A PREDICTION +---------------------------------------------------------------- +Option A — Double-click predict.bat + It will ask you to: + 1. Paste the path to your CSV file + 2. Choose a task (activity / focus / emotion / stress) + +Option B — Run from Anaconda Prompt manually: + conda activate MOSS + cd path\to\MOSS + python muse2_predict.py --input "path\to\your_recording.csv" --task activity + + Change --task to: activity, focus, emotion, or stress + +---------------------------------------------------------------- +EXAMPLE OUTPUT +---------------------------------------------------------------- + MOSS Prediction + =============== + Input: my_recording.csv + Task: focus + Model: trained on 4 subjects, 633 segments + + Segment-by-segment predictions: + [ 0s-4s] relaxed 94.2% ██████████████████ + [ 2s-6s] relaxed 87.1% █████████████████ + [ 4s-8s] concentrating 78.3% ███████████████ + [ 6s-10s] neutral 65.4% ████████████ + ... + + Overall prediction: RELAXED (67% of segments) + + Class probabilities (mean across all segments): + relaxed 58.1% ███████████████████████ + neutral 24.3% █████████ + concentrating 17.6% ███████ + +---------------------------------------------------------------- +CLASSIFIER PERFORMANCE (what to expect) +---------------------------------------------------------------- + Task Classes Accuracy Chance Notes + -------- ------- -------- ------ ----- + Activity 6 91.7% 16.7% Very reliable + Focus 3 71.9% 33.3% Reliable + Emotion 5 45.5% 20.0% Use with caution + Stress 3 28.0% 33.3% Not reliable yet + +Accuracy is Leave-One-Subject-Out cross-validation — +meaning the model was tested on people it had never seen before. 
+ +---------------------------------------------------------------- +TIPS FOR BEST RESULTS +---------------------------------------------------------------- +- Record at least 2 minutes (ideally 5+) for stable predictions +- Sit still — jaw clenching and movement create artifacts +- Make sure headband fits snugly (check Mind Monitor signal quality) +- Do one clearly defined task per recording +- Green signal quality bars in Mind Monitor = good contact + +---------------------------------------------------------------- +TROUBLESHOOTING +---------------------------------------------------------------- +"No module named X" + → Re-run setup.bat or run: conda activate MOSS + +"File not found: NeuroLM-B.pt" + → Make sure checkpoint is at MOSS\checkpoints\checkpoints\NeuroLM-B.pt + +"Recording too short" + → Record at least 4 seconds; 2+ minutes recommended + +"ERROR loading CSV" + → Check that your CSV has RAW_TP9/AF7/AF8/TP10 columns + → Export directly from Mind Monitor (not Muse Direct) + +---------------------------------------------------------------- +CONTACT +---------------------------------------------------------------- +Questions? Contact Natalia (UBC MINT Team) +Project: MOSS — Modular Open-Source Signal System +GitHub: [link TBD] + +================================================================ diff --git a/backend/shared-logic/src/signal_processing/moss/classifier.py b/backend/shared-logic/src/signal_processing/moss/classifier.py new file mode 100644 index 0000000..8d73e0e --- /dev/null +++ b/backend/shared-logic/src/signal_processing/moss/classifier.py @@ -0,0 +1,232 @@ +""" +MOSS - classifier.py +==================== +Handles MLP classifier training, saving, loading, and prediction. + +Input: (N, 768) numpy embeddings from encoder.py +Output: predicted labels + confidence scores + +Each task (activity, focus, emotion, stress) has its own saved .pkl file. +New tasks can be added by training a new classifier on embeddings for that task. 
+ +Used by: coordinator.py +""" + +import os +import pickle +import numpy as np +from typing import Optional +from sklearn.preprocessing import StandardScaler +from sklearn.neural_network import MLPClassifier +from sklearn.model_selection import StratifiedKFold +from sklearn.metrics import accuracy_score, balanced_accuracy_score +from sklearn.utils.class_weight import compute_sample_weight +from collections import Counter + +# ── Default paths ────────────────────────────────────────────────────────────── +DEFAULT_MODELS_DIR = os.path.join(os.path.dirname(__file__), 'moss_models') + +# ── Task → classifier file mapping ──────────────────────────────────────────── +TASK_CLASSIFIER_MAP = { + 'activity': 'muse2_classifier.pkl', + 'focus': 'focus_classifier.pkl', + 'emotion': 'emotion_classifier.pkl', + 'stress': 'stress_classifier.pkl', +} + + +class MossClassifier: + """ + Thin wrapper around sklearn MLP for MOSS mental state classification. + + Handles: + - Training with optional class balancing + - Saving/loading to .pkl + - Predicting labels + confidence scores from embeddings + """ + + def __init__(self, + task: str, + label_names: list[str], + models_dir: str = DEFAULT_MODELS_DIR): + """ + Args: + task: task name (e.g. 'activity', 'focus', 'emotion') + label_names: ordered list of class names (index = class id) + models_dir: directory where .pkl files are saved/loaded + """ + self.task = task + self.label_names = label_names + self.models_dir = models_dir + self.clf = None + self.scaler = None + os.makedirs(models_dir, exist_ok=True) + + @property + def pkl_path(self) -> str: + filename = TASK_CLASSIFIER_MAP.get(self.task, f'{self.task}_classifier.pkl') + return os.path.join(self.models_dir, filename) + + def train(self, + embeddings: np.ndarray, + labels: np.ndarray, + balance_classes: bool = True, + n_splits: int = 5) -> dict: + """ + Train MLP classifier on embeddings with optional k-fold CV evaluation. 
+ + Args: + embeddings: (N, 768) array + labels: (N,) integer class labels + balance_classes: use sample weights to handle class imbalance + n_splits: number of CV folds (set to 0 to skip CV) + + Returns: + results: dict with accuracy, balanced_accuracy, per-fold scores + """ + results = {} + + # Optional cross-validation evaluation + if n_splits > 1: + skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42) + fold_accs, fold_bals = [], [] + + for tr, te in skf.split(embeddings, labels): + scaler = StandardScaler() + X_tr = scaler.fit_transform(embeddings[tr]) + X_te = scaler.transform(embeddings[te]) + + clf = self._make_mlp() + sw = compute_sample_weight('balanced', labels[tr]) if balance_classes else None + clf.fit(X_tr, labels[tr], sw) + + preds = clf.predict(X_te) + fold_accs.append(accuracy_score(labels[te], preds)) + fold_bals.append(balanced_accuracy_score(labels[te], preds)) + + results['cv_accuracy'] = float(np.mean(fold_accs)) + results['cv_balanced_accuracy'] = float(np.mean(fold_bals)) + results['cv_fold_accuracies'] = [float(x) for x in fold_accs] + + # Train final classifier on all data + self.scaler = StandardScaler() + X_all = self.scaler.fit_transform(embeddings) + self.clf = self._make_mlp() + sw = compute_sample_weight('balanced', labels) if balance_classes else None + self.clf.fit(X_all, labels, sw) + + results['n_samples'] = len(labels) + results['n_classes'] = len(np.unique(labels)) + results['label_names'] = self.label_names + results['class_distribution'] = { + self.label_names[k]: int(v) + for k, v in sorted(Counter(labels).items()) + } + + return results + + def _make_mlp(self) -> MLPClassifier: + return MLPClassifier( + hidden_layer_sizes=(256, 128), + max_iter=500, + random_state=42, + early_stopping=True, + n_iter_no_change=20 + ) + + def save(self) -> str: + """Save trained classifier + scaler to .pkl. Returns path.""" + if self.clf is None or self.scaler is None: + raise RuntimeError("Classifier not trained yet. 
Call train() first.") + + bundle = { + 'classifier': self.clf, + 'scaler': self.scaler, + 'label_names': self.label_names, + 'activities': self.label_names, # kept for predict.py compatibility + 'task': self.task, + } + with open(self.pkl_path, 'wb') as f: + pickle.dump(bundle, f) + + return self.pkl_path + + def load(self) -> None: + """Load classifier + scaler from .pkl.""" + if not os.path.exists(self.pkl_path): + raise FileNotFoundError( + f"No classifier found for task '{self.task}' at {self.pkl_path}\n" + f"Train the classifier first using the appropriate train script." + ) + with open(self.pkl_path, 'rb') as f: + bundle = pickle.load(f) + + self.clf = bundle['classifier'] + self.scaler = bundle['scaler'] + self.label_names = bundle.get('label_names', bundle.get('activities', [])) + + def predict(self, embeddings: np.ndarray) -> tuple[list[str], np.ndarray]: + """ + Predict mental state labels for a batch of embeddings. + + Args: + embeddings: (N, 768) numpy array + + Returns: + labels: list of N predicted label strings + confidences: (N, n_classes) probability array + """ + if self.clf is None: + self.load() + + X = self.scaler.transform(embeddings) + pred_indices = self.clf.predict(X) + probabilities = self.clf.predict_proba(X) + pred_labels = [self.label_names[i] for i in pred_indices] + + return pred_labels, probabilities + + def predict_majority(self, embeddings: np.ndarray) -> tuple[str, float, np.ndarray]: + """ + Predict a single label for a recording via majority vote across segments. 
+ + Args: + embeddings: (N, 768) array for all segments in a recording + + Returns: + label: overall predicted label string + confidence: fraction of segments that voted for this label + mean_proba: (n_classes,) mean probability across all segments + """ + labels, probas = self.predict(embeddings) + counts = Counter(labels) + top_label = counts.most_common(1)[0][0] + confidence = counts.most_common(1)[0][1] / len(labels) + mean_proba = probas.mean(axis=0) + + return top_label, confidence, mean_proba + + +def load_classifier(task: str, + models_dir: str = DEFAULT_MODELS_DIR) -> 'MossClassifier': + """ + Convenience function to load a saved classifier by task name. + + Args: + task: 'activity', 'focus', 'emotion', or 'stress' + models_dir: directory containing .pkl files + + Returns: + loaded MossClassifier ready for prediction + """ + clf = MossClassifier(task=task, label_names=[], models_dir=models_dir) + clf.load() + return clf + + +if __name__ == '__main__': + # Quick test — load activity classifier and print info + clf = load_classifier('activity') + print(f"Task: {clf.task}") + print(f"Labels: {clf.label_names}") + print(f"Classifier: {clf.clf}") diff --git a/backend/shared-logic/src/signal_processing/moss/coordinator.py b/backend/shared-logic/src/signal_processing/moss/coordinator.py new file mode 100644 index 0000000..68c2aed --- /dev/null +++ b/backend/shared-logic/src/signal_processing/moss/coordinator.py @@ -0,0 +1,188 @@ +""" +MOSS - coordinator.py +===================== +Python coordinator layer — orchestrates preprocessing, encoding, and classification. + +This is the main entry point for the Rust signal_processor.rs to call into. +It exposes a clean JSON interface so Rust can invoke it via subprocess or PyO3. 
+ +Usage (from command line / Rust subprocess): + python coordinator.py predict --input recording.csv --task activity + python coordinator.py predict --input recording.csv --task focus + python coordinator.py predict --input recording.csv --task emotion + +Output (JSON to stdout): + { + "task": "activity", + "overall_label": "rest", + "confidence": 0.83, + "segments": [ + {"start_s": 0, "end_s": 4, "label": "rest", "confidence": 0.91}, + ... + ], + "class_probabilities": {"eat": 0.02, "game": 0.01, "read": 0.03, ...}, + "duration_s": 45.2, + "n_segments": 21, + "status": "ok" + } + +Architecture: + Rust (signal_processor.rs) + ↓ subprocess / PyO3 + coordinator.py ← you are here + ↓ ↓ ↓ + preprocessing.py → encoder.py → classifier.py +""" + +import os +import sys +import json +import argparse +import traceback +import numpy as np + +# Add parent dir to path so model/ is importable +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) +sys.path.insert(0, os.path.dirname(__file__)) + +from preprocessing import preprocess +from encoder import NeuroLMEncoder +from classifier import load_classifier + +# ── Paths ────────────────────────────────────────────────────────────────────── +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +CHECKPOINT = os.path.join(SCRIPT_DIR, 'checkpoints', 'checkpoints', 'NeuroLM-B.pt') +MODELS_DIR = os.path.join(SCRIPT_DIR, 'moss_models') + +# Singleton encoder — loaded once and reused across calls +_encoder: NeuroLMEncoder = None + +def get_encoder() -> NeuroLMEncoder: + """Lazy-load the NeuroLM encoder (expensive, so only once).""" + global _encoder + if _encoder is None: + _encoder = NeuroLMEncoder( + checkpoint_path=CHECKPOINT, + neurolm_dir=os.path.dirname(SCRIPT_DIR) + ) + return _encoder + + +def predict(input_path: str, task: str) -> dict: + """ + Full prediction pipeline for one recording and one task. 
+ + Args: + input_path: path to Muse 2 CSV file + task: 'activity', 'focus', 'emotion', or 'stress' + + Returns: + result dict (see module docstring for schema) + """ + # 1. Preprocess + segments, src_fs, duration = preprocess(input_path) + + +def predict_from_array(raw: np.ndarray, src_fs: int, task: str) -> dict: + """ + Full prediction pipeline from a raw numpy array. + Use this for real-time streaming from the Muse 2 headset. + + Args: + raw: (n_samples, 4) float32 array [TP9, AF7, AF8, TP10] + src_fs: sample rate of the incoming data (e.g. 256 for Muse 2) + task: 'activity', 'focus', 'emotion', or 'stress' + + Returns: + result dict (same schema as predict()) + + Example: + raw = np.array(lsl_samples, dtype=np.float32) # from LSL stream + result = predict_from_array(raw, src_fs=256, task='focus') + print(result['overall_label']) # e.g. 'concentrating' + """ + from preprocessing import from_array + + # 1. Preprocess from array + segments, duration = from_array(raw, src_fs) + src_fs_out = src_fs + + # 2. Encode + encoder = get_encoder() + embeddings = encoder.encode(segments) # (N, 768) + + # 3. Classify + clf = load_classifier(task, models_dir=MODELS_DIR) + seg_labels, seg_probas = clf.predict(embeddings) + overall_label, confidence, mean_proba = clf.predict_majority(embeddings) + + # 4. 
Build output + step_s = 2.0 # 50% overlap → 2s step + win_s = 4.0 + + segment_results = [] + for i, (label, proba) in enumerate(zip(seg_labels, seg_probas)): + segment_results.append({ + 'start_s': round(i * step_s, 1), + 'end_s': round(i * step_s + win_s, 1), + 'label': label, + 'confidence': round(float(proba.max()), 4), + }) + + class_probabilities = { + name: round(float(p), 4) + for name, p in sorted( + zip(clf.label_names, mean_proba), + key=lambda x: -x[1] + ) + } + + return { + 'status': 'ok', + 'task': task, + 'overall_label': overall_label, + 'confidence': round(confidence, 4), + 'segments': segment_results, + 'class_probabilities': class_probabilities, + 'duration_s': round(duration, 2), + 'n_segments': len(segments), + 'src_sample_rate_hz': src_fs, + } + + +def main(): + parser = argparse.ArgumentParser( + description='MOSS coordinator — EEG mental state prediction' + ) + subparsers = parser.add_subparsers(dest='command') + + # predict command + pred_parser = subparsers.add_parser('predict', help='Run prediction on a recording') + pred_parser.add_argument('--input', required=True, help='Path to Muse 2 CSV file') + pred_parser.add_argument('--task', required=True, + choices=['activity', 'focus', 'emotion', 'stress'], + help='Mental state task to predict') + pred_parser.add_argument('--pretty', action='store_true', + help='Pretty-print JSON output') + + args = parser.parse_args() + + if args.command == 'predict': + try: + result = predict(args.input, args.task) + except Exception as e: + result = { + 'status': 'error', + 'error': str(e), + 'trace': traceback.format_exc() + } + + indent = 2 if args.pretty else None + print(json.dumps(result, indent=indent)) + + else: + parser.print_help() + + +if __name__ == '__main__': + main() diff --git a/backend/shared-logic/src/signal_processing/moss/encoder.py b/backend/shared-logic/src/signal_processing/moss/encoder.py new file mode 100644 index 0000000..01d84ae --- /dev/null +++ 
b/backend/shared-logic/src/signal_processing/moss/encoder.py @@ -0,0 +1,221 @@ +""" +MOSS - encoder.py +================= +Handles NeuroLM model loading and EEG embedding extraction. + +Input: list of (4, 800) numpy arrays from preprocessing.py +Output: (N, 768) numpy array of embeddings — one per segment + +The NeuroLM encoder is frozen (never trained/fine-tuned). +It acts as a universal EEG feature extractor across tasks and devices. + +Used by: coordinator.py +""" + +import os +import sys +import numpy as np +import torch +from einops import rearrange +from typing import Optional + +# ── Constants ────────────────────────────────────────────────────────────────── +DEFAULT_CHECKPOINT = os.path.join( + os.path.dirname(__file__), 'checkpoints', 'checkpoints', 'NeuroLM-B.pt' +) +EEG_MAX_LEN = 276 # max token sequence length (matches training) +PATCH_SIZE = 200 # samples per token patch +EMB_DIM = 768 # NeuroLM-B output embedding dimension + +# Muse 2 channel names in standard 10-20 order +MUSE_CHANS = ['TP9', 'AF7', 'AF8', 'TP10'] + +# Full standard 10-20 vocabulary used by NeuroLM +STANDARD_1020 = [ + 'FP1','FPZ','FP2','AF9','AF7','AF5','AF3','AF1','AFZ','AF2','AF4','AF6','AF8','AF10', + 'F9','F7','F5','F3','F1','FZ','F2','F4','F6','F8','F10', + 'FT9','FT7','FC5','FC3','FC1','FCZ','FC2','FC4','FC6','FT8','FT10', + 'T9','T7','C5','C3','C1','CZ','C2','C4','C6','T8','T10', + 'TP9','TP7','CP5','CP3','CP1','CPZ','CP2','CP4','CP6','TP8','TP10', + 'P9','P7','P5','P3','P1','PZ','P2','P4','P6','P8','P10', + 'PO9','PO7','PO5','PO3','PO1','POZ','PO2','PO4','PO6','PO8','PO10', + 'O1','OZ','O2','O9','CB1','CB2','IZ','O10', + 'T3','T5','T4','T6','M1','M2','A1','A2', + 'CFC1','CFC2','CFC3','CFC4','CFC5','CFC6','CFC7','CFC8', + 'CCP1','CCP2','CCP3','CCP4','CCP5','CCP6','CCP7','CCP8', + 'T1','T2','FTT9h','TTP7h','TPP9h','FTT10h','TPP8h','TPP10h', + 'FP1-F7','F7-T7','T7-P7','P7-O1','FP2-F8','F8-T8','T8-P8','P8-O2', + 'pad','I1','I2' +] + + +class NeuroLMEncoder: + """ + Wrapper 
around the frozen NeuroLM-B model for EEG embedding extraction. + + Usage: + encoder = NeuroLMEncoder() + embeddings = encoder.encode(segments) # (N, 768) + """ + + def __init__(self, + checkpoint_path: str = DEFAULT_CHECKPOINT, + device: str = 'cpu', + neurolm_dir: Optional[str] = None): + """ + Load NeuroLM-B from checkpoint. + + Args: + checkpoint_path: path to NeuroLM-B.pt + device: 'cpu' or 'cuda' + neurolm_dir: path to NeuroLM source (adds to sys.path) + """ + self.device = torch.device(device) + + # Add NeuroLM source to path + if neurolm_dir is None: + neurolm_dir = os.path.dirname(os.path.dirname(__file__)) + if neurolm_dir not in sys.path: + sys.path.insert(0, neurolm_dir) + + self.model = self._load_model(checkpoint_path) + print(f"NeuroLM-B loaded from {checkpoint_path}") + + def _load_model(self, checkpoint_path: str): + """Load and return the frozen NeuroLM model.""" + from model.model_neurolm import NeuroLM + from model.model import GPTConfig + + if not os.path.exists(checkpoint_path): + raise FileNotFoundError( + f"NeuroLM checkpoint not found at: {checkpoint_path}\n" + f"Please download NeuroLM-B.pt and place it there." + ) + + ckpt = torch.load(checkpoint_path, map_location=self.device, weights_only=False) + cfg = GPTConfig(**ckpt['model_args']) + model = NeuroLM(cfg, init_from='scratch') + + # Strip torch.compile prefix if present + sd = ckpt['model'] + for k in list(sd.keys()): + if k.startswith('_orig_mod.'): + sd[k[10:]] = sd.pop(k) + + model.load_state_dict(sd) + model.to(self.device) + model.eval() + return model + + def _segment_to_tensors(self, seg: np.ndarray): + """ + Convert a (4, 800) EEG segment to NeuroLM input tensors. 
+ + Replicates the exact tokenization used during NeuroLM pretraining: + (4 chans, 800 samples) → (16 tokens, 200 samples) → padded to 276 + """ + n_chans, n_total = seg.shape # 4, 800 + T = PATCH_SIZE # 200 + n_time = n_total // T # 4 time windows + + # Normalize and reshape into tokens + data = torch.FloatTensor(seg / 100.0) + std = data.std() + if std > 0: + data = (data - data.mean()) / std + + # (4, 800) → (16, 200): time-major interleaved with channels + data = rearrange(data, 'N (A T) -> (A N) T', T=T) + valid_len = data.shape[0] # 16 + + # Pad to EEG_MAX_LEN + X_eeg = torch.zeros((EEG_MAX_LEN, T)) + X_eeg[:valid_len] = data + + eeg_mask = torch.ones(EEG_MAX_LEN) + eeg_mask[valid_len:] = 0 + + # Channel index tokens: channel names repeated per time window, then 'pad' + chans = MUSE_CHANS * n_time + ['pad'] * (EEG_MAX_LEN - valid_len) + input_chans = torch.IntTensor([STANDARD_1020.index(c) for c in chans]) + + # Time index tokens + input_time = ( + [i for i in range(n_time) for _ in range(n_chans)] + + [0] * (EEG_MAX_LEN - valid_len) + ) + input_time = torch.IntTensor(input_time) + + return ( + X_eeg.unsqueeze(0), # (1, 276, 200) + input_chans.unsqueeze(0), # (1, 276) + input_time.unsqueeze(0), # (1, 276) + eeg_mask.bool().unsqueeze(0), # (1, 276) + ) + + @torch.no_grad() + def encode_segment(self, seg: np.ndarray) -> np.ndarray: + """ + Encode a single (4, 800) EEG segment into a 768-dim embedding. 
+ + Args: + seg: numpy array of shape (4, 800) + + Returns: + embedding: numpy array of shape (768,) + """ + X_eeg, input_chans, input_time, eeg_mask = self._segment_to_tensors(seg) + + # 4D attention mask + mask = eeg_mask.unsqueeze(1).repeat(1, X_eeg.size(1), 1).unsqueeze(1) + + tokens = self.model.tokenizer( + X_eeg, input_chans, input_time, mask, + return_all_tokens=True + ) # (1, 276, 400) + + # Mean-pool over valid tokens only → project to 768-dim + valid_len = int(eeg_mask.sum().item()) + emb = tokens[0, :valid_len, :].mean(dim=0) + emb = self.model.encode_transform_layer(emb) + + return emb.cpu().numpy() + + def encode(self, segments: list[np.ndarray], + verbose: bool = False) -> np.ndarray: + """ + Encode a list of EEG segments into embeddings. + + Args: + segments: list of (4, 800) numpy arrays from preprocessing.py + verbose: print progress every 10 segments + + Returns: + embeddings: numpy array of shape (N, 768) + """ + embeddings = [] + for i, seg in enumerate(segments): + emb = self.encode_segment(seg) + embeddings.append(emb) + if verbose and (i + 1) % 10 == 0: + print(f" Encoded [{i+1}/{len(segments)}]") + + return np.array(embeddings) # (N, 768) + + +if __name__ == '__main__': + # Quick test + import sys + sys.path.insert(0, os.path.dirname(__file__)) + from preprocessing import preprocess + + if len(sys.argv) < 2: + print("Usage: python encoder.py path/to/recording.csv") + sys.exit(1) + + segs, fs, dur = preprocess(sys.argv[1]) + print(f"Loaded {len(segs)} segments") + + encoder = NeuroLMEncoder() + embeddings = encoder.encode(segs, verbose=True) + print(f"Embeddings shape: {embeddings.shape}") diff --git a/backend/shared-logic/src/signal_processing/moss/model/__init__.py b/backend/shared-logic/src/signal_processing/moss/model/__init__.py new file mode 100644 index 0000000..6e9fa1a --- /dev/null +++ b/backend/shared-logic/src/signal_processing/moss/model/__init__.py @@ -0,0 +1,4 @@ +from .model_vq import VQ_Align +from .model_neural_transformer 
import NTConfig +from .model_neural_transformer import NeuralTransformer +from .model_neurolm import NeuroLM diff --git a/backend/shared-logic/src/signal_processing/moss/model/__pycache__/__init__.cpython-312.pyc b/backend/shared-logic/src/signal_processing/moss/model/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..e7b4c71 Binary files /dev/null and b/backend/shared-logic/src/signal_processing/moss/model/__pycache__/__init__.cpython-312.pyc differ diff --git a/backend/shared-logic/src/signal_processing/moss/model/__pycache__/model.cpython-312.pyc b/backend/shared-logic/src/signal_processing/moss/model/__pycache__/model.cpython-312.pyc new file mode 100644 index 0000000..8086b77 Binary files /dev/null and b/backend/shared-logic/src/signal_processing/moss/model/__pycache__/model.cpython-312.pyc differ diff --git a/backend/shared-logic/src/signal_processing/moss/model/__pycache__/model_neural_transformer.cpython-312.pyc b/backend/shared-logic/src/signal_processing/moss/model/__pycache__/model_neural_transformer.cpython-312.pyc new file mode 100644 index 0000000..987c2d8 Binary files /dev/null and b/backend/shared-logic/src/signal_processing/moss/model/__pycache__/model_neural_transformer.cpython-312.pyc differ diff --git a/backend/shared-logic/src/signal_processing/moss/model/__pycache__/model_neurolm.cpython-312.pyc b/backend/shared-logic/src/signal_processing/moss/model/__pycache__/model_neurolm.cpython-312.pyc new file mode 100644 index 0000000..2658b60 Binary files /dev/null and b/backend/shared-logic/src/signal_processing/moss/model/__pycache__/model_neurolm.cpython-312.pyc differ diff --git a/backend/shared-logic/src/signal_processing/moss/model/__pycache__/model_vq.cpython-312.pyc b/backend/shared-logic/src/signal_processing/moss/model/__pycache__/model_vq.cpython-312.pyc new file mode 100644 index 0000000..706df70 Binary files /dev/null and b/backend/shared-logic/src/signal_processing/moss/model/__pycache__/model_vq.cpython-312.pyc 
differ diff --git a/backend/shared-logic/src/signal_processing/moss/model/__pycache__/norm_ema_quantizer.cpython-312.pyc b/backend/shared-logic/src/signal_processing/moss/model/__pycache__/norm_ema_quantizer.cpython-312.pyc new file mode 100644 index 0000000..ae551cf Binary files /dev/null and b/backend/shared-logic/src/signal_processing/moss/model/__pycache__/norm_ema_quantizer.cpython-312.pyc differ diff --git a/backend/shared-logic/src/signal_processing/moss/model/model.py b/backend/shared-logic/src/signal_processing/moss/model/model.py new file mode 100644 index 0000000..55a49fa --- /dev/null +++ b/backend/shared-logic/src/signal_processing/moss/model/model.py @@ -0,0 +1,407 @@ +""" +by Wei-Bang Jiang +https://github.com/935963004/NeuroLM +""" + +import math +import inspect +from dataclasses import dataclass + +import torch +import torch.nn as nn +from torch.nn import functional as F + + +class LayerNorm(nn.Module): + + """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ + + def __init__(self, ndim, bias): + super().__init__() + self.weight = nn.Parameter(torch.ones(ndim)) + self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None + + def forward(self, input): + return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) + + +class CausalSelfAttention(nn.Module): + + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # regularization + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + self.n_head = config.n_head + self.n_embd = config.n_embd + self.dropout = config.dropout + # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 + self.flash = 
hasattr(torch.nn.functional, 'scaled_dot_product_attention') + if not self.flash: + print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) + .view(1, 1, config.block_size, config.block_size)) + + def forward(self, x, mask=None): + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + if self.flash: + # efficient attention using Flash Attention CUDA kernels + if mask is None: + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) + elif mask == 'unmask': + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=False) + else: + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout if self.training else 0, is_causal=False) + + else: + # manual implementation of attention + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + #att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + if mask is not None: + att = att.masked_fill(mask == 0, float('-inf')) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 
2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class MLP(nn.Module): + + def __init__(self, config): + super().__init__() + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) + self.gelu = nn.GELU() + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) + self.dropout = nn.Dropout(config.dropout) + + def forward(self, x): + x = self.c_fc(x) + x = self.gelu(x) + x = self.c_proj(x) + x = self.dropout(x) + return x + +class Block(nn.Module): + + def __init__(self, config): + super().__init__() + self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) + self.attn = CausalSelfAttention(config) + self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) + self.mlp = MLP(config) + + def forward(self, x, mask=None): + x = x + self.attn(self.ln_1(x), mask) + x = x + self.mlp(self.ln_2(x)) + return x + + +@dataclass +class GPTConfig: + + block_size: int = 1024 + vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + dropout: float = 0.0 + bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster + + +class GPT(nn.Module): + + def __init__(self, config): + super().__init__() + assert config.vocab_size is not None + assert config.block_size is not None + self.config = config + + self.transformer = nn.ModuleDict(dict( + wte = nn.Embedding(config.vocab_size, config.n_embd), + wpe = nn.Embedding(config.block_size, config.n_embd), + drop = nn.Dropout(config.dropout), + h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), + ln_f = LayerNorm(config.n_embd, bias=config.bias), + )) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + # with weight tying when using torch.compile() some warnings get generated: + # "UserWarning: functional_call was passed multiple values for tied weights. + # This behavior is deprecated and will be an error in future versions" + # not 100% sure what this is, so far seems to be harmless. TODO investigate + self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying + + # init all weights + self.apply(self._init_weights) + # apply special scaled init to the residual projections, per GPT-2 paper + for pn, p in self.named_parameters(): + if pn.endswith('c_proj.weight'): + torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) + + # report number of parameters + print("number of parameters: %.2fM" % (self.get_num_params()/1e6,)) + + def get_num_params(self, non_embedding=True): + """ + Return the number of parameters in the model. + For non-embedding count (default), the position embeddings get subtracted. + The token embeddings would too, except due to the parameter sharing these + params are actually used as weights in the final layer, so we include them. 
+ """ + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.transformer.wpe.weight.numel() + return n_params + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def forward(self, x_eeg=None, y_eeg=None, x_text=None, y_text=None, eeg_time_idx=None, eeg_mask=None, eeg_text_mask=None, lm_head=True): + if x_eeg is not None and x_text is None: + device = x_eeg.device + pos_emb = self.transformer.wpe(eeg_time_idx) + x = x_eeg + pos_emb + mask = eeg_mask + if y_eeg is not None: + targets = y_eeg + self.config.vocab_size + else: + targets = None + elif x_text is not None and x_eeg is None: + device = x_text.device + b, t = x_text.size() + assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) + + # forward the GPT model itself + tok_emb = self.transformer.wte(x_text) # token embeddings of shape (b, t, n_embd) + pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) + x = tok_emb + pos_emb + mask = None + targets = y_text + elif x_text is not None and x_eeg is not None: + device = x_eeg.device + pos_emb = self.transformer.wpe(eeg_time_idx) + x_eeg = x_eeg + pos_emb + + b, t = x_text.size() + assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + eeg_max_time = torch.max(eeg_time_idx) + pos = torch.arange(eeg_max_time + 1, eeg_max_time + 1 + t, dtype=torch.long, device=device) # shape (t) + + # forward the GPT model itself + tok_emb = self.transformer.wte(x_text) # token embeddings of shape (b, t, n_embd) + pos_emb = self.transformer.wpe(pos) # position 
embeddings of shape (t, n_embd) + x_text = tok_emb + pos_emb + x = torch.cat((x_eeg, x_text), dim=1) + + mask = eeg_text_mask + if y_eeg is None or y_text is None: + targets = None + elif y_eeg == 'nan' or y_text == 'nan': + targets = 'nan' + else: + y_eeg = y_eeg + self.config.vocab_size + targets = torch.cat((y_eeg, y_text), dim=-1) + + x = self.transformer.drop(x) + + for block in self.transformer.h: + x = block(x, mask) + x = self.transformer.ln_f(x) + + if not lm_head: + return x + + if targets is not None: + # if we are given some desired targets also calculate the loss + logits = self.lm_head(x) + if targets == 'nan': + return logits, None, None + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) + + _, preds = logits.max(dim=-1) + not_ignore = targets.ne(-1) + num_targets = not_ignore.long().sum().item() + correct = (targets == preds) & not_ignore + correct = correct.float().sum() + accuracy = correct / num_targets + return logits, loss, accuracy + else: + # inference-time mini-optimization: only forward the lm_head on the very last position + logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim + loss = None + accuracy = None + + return logits, loss, accuracy + + def enlarge_lm_head(self, vocab_size): + assert vocab_size >= self.config.vocab_size + new_lm_head = nn.Linear(in_features=self.config.n_embd, out_features=vocab_size, bias=False) + new_lm_head.weight.data[:self.lm_head.weight.shape[0]] = self.lm_head.weight.data + self.lm_head = new_lm_head + + def enlarge_wte(self, vocab_size): + assert vocab_size >= self.config.vocab_size + new_wte = nn.Embedding(vocab_size, self.config.n_embd) + new_wte.weight.data[:self.config.vocab_size] = self.transformer.wte.weight.data + self.transformer.wte = new_wte + self.config.vocab_size = vocab_size + + def crop_block_size(self, block_size): + # model surgery to decrease the block size if necessary + # e.g. 
we may load the GPT2 pretrained model checkpoint (block size 1024) + # but want to use a smaller block size for some smaller, simpler model + assert block_size <= self.config.block_size + self.config.block_size = block_size + self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size]) + for block in self.transformer.h: + if hasattr(block.attn, 'bias'): + block.attn.bias = block.attn.bias[:,:,:block_size,:block_size] + + @classmethod + def from_pretrained(cls, model_type, override_args=None): + assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} + override_args = override_args or {} # default to empty dict + # only dropout can be overridden see more notes below + assert all(k == 'dropout' for k in override_args) + from transformers import GPT2LMHeadModel + print("loading weights from pretrained gpt: %s" % model_type) + + # n_layer, n_head and n_embd are determined from model_type + config_args = { + 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params + 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params + 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params + 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params + }[model_type] + print("forcing vocab_size=50257, block_size=1024, bias=True") + config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints + config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints + config_args['bias'] = True # always True for GPT model checkpoints + # we can override the dropout rate, if desired + if 'dropout' in override_args: + print(f"overriding dropout rate to {override_args['dropout']}") + config_args['dropout'] = override_args['dropout'] + # create a from-scratch initialized minGPT model + config = GPTConfig(**config_args) + model = GPT(config) + sd = model.state_dict() + sd_keys = sd.keys() + sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param + 
+
+        # init a huggingface/transformers model
+        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
+        sd_hf = model_hf.state_dict()
+
+        # copy while ensuring all of the parameters are aligned and match in names and shapes
+        sd_keys_hf = sd_hf.keys()
+        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
+        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
+        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
+        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
+        # this means that we have to transpose these weights when we import them
+        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
+        for k in sd_keys_hf:
+            if any(k.endswith(w) for w in transposed):
+                # special treatment for the Conv1D weights we need to transpose
+                assert sd_hf[k].shape[::-1] == sd[k].shape
+                with torch.no_grad():
+                    sd[k].copy_(sd_hf[k].t())
+            else:
+                # vanilla copy over the other parameters
+                assert sd_hf[k].shape == sd[k].shape
+                with torch.no_grad():
+                    sd[k].copy_(sd_hf[k])
+
+        return model
+
+    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
+        # start with all of the candidate parameters
+        param_dict = {pn: p for pn, p in self.named_parameters()}
+        # filter out those that do not require grad
+        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
+        # create optim groups. Any parameter that is 2D will be weight decayed, otherwise no.
+        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] + nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] + optim_groups = [ + {'params': decay_params, 'weight_decay': weight_decay}, + {'params': nodecay_params, 'weight_decay': 0.0} + ] + num_decay_params = sum(p.numel() for p in decay_params) + num_nodecay_params = sum(p.numel() for p in nodecay_params) + print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") + print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") + # Create AdamW optimizer and use the fused version if it is available + fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters + use_fused = fused_available and device_type == 'cuda' + extra_args = dict(fused=True) if use_fused else dict() + optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) + print(f"using fused AdamW: {use_fused}") + + return optimizer + + def estimate_mfu(self, fwdbwd_per_iter, dt): + """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """ + # first estimate the number of flops we do per iteration. 
+ # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 + N = self.get_num_params() + cfg = self.config + L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size + flops_per_token = 6*N + 12*L*H*Q*T + flops_per_fwdbwd = flops_per_token * T + flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter + # express our flops throughput as ratio of A100 bfloat16 peak flops + flops_achieved = flops_per_iter * (1.0/dt) # per second + flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS + mfu = flops_achieved / flops_promised + return mfu + + @torch.no_grad() + def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): + """ + Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete + the sequence max_new_tokens times, feeding the predictions back into the model each time. + Most likely you'll want to make sure to be in model.eval() mode of operation for this. + """ + for _ in range(max_new_tokens): + # if the sequence context is growing too long we must crop it at block_size + idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] + # forward the model to get the logits for the index in the sequence + logits, _ = self(x_text=idx_cond) + # pluck the logits at the final step and scale by desired temperature + logits = logits[:, -1, :] / temperature + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = -float('Inf') + # apply softmax to convert logits to (normalized) probabilities + probs = F.softmax(logits, dim=-1) + # sample from the distribution + idx_next = torch.multinomial(probs, num_samples=1) + # append sampled index to the running sequence and continue + idx = torch.cat((idx, idx_next), dim=1) + + return idx diff --git a/backend/shared-logic/src/signal_processing/moss/model/model_neural_transformer.py 
b/backend/shared-logic/src/signal_processing/moss/model/model_neural_transformer.py
new file mode 100644
index 0000000..3877058
--- /dev/null
+++ b/backend/shared-logic/src/signal_processing/moss/model/model_neural_transformer.py
@@ -0,0 +1,142 @@
+"""
+by Wei-Bang Jiang
+https://github.com/935963004/NeuroLM
+"""
+
+import math
+import torch.nn as nn
+from model.model import Block
+from einops import rearrange
+from dataclasses import dataclass
+
+
+class TemporalConv(nn.Module):
+    """ EEG to Patch Embedding
+    """
+    def __init__(self, in_chans=1, out_chans=8):
+        '''
+        in_chans: in_chans of nn.Conv2d()
+        out_chans: out_chans of nn.Conv2d(), determining the output dimension
+        '''
+        super().__init__()
+        self.conv1 = nn.Conv2d(in_chans, out_chans, kernel_size=(1, 15), stride=(1, 8), padding=(0, 7))
+        self.gelu1 = nn.GELU()
+        self.norm1 = nn.GroupNorm(4, out_chans)
+        self.conv2 = nn.Conv2d(out_chans, out_chans, kernel_size=(1, 3), padding=(0, 1))
+        self.gelu2 = nn.GELU()
+        self.norm2 = nn.GroupNorm(4, out_chans)
+        self.conv3 = nn.Conv2d(out_chans, out_chans, kernel_size=(1, 3), padding=(0, 1))
+        self.norm3 = nn.GroupNorm(4, out_chans)
+        self.gelu3 = nn.GELU()
+        self.l = nn.Sequential(
+            nn.Linear(400, 768),
+            nn.GELU()
+        )
+
+    def forward(self, x, **kwargs):
+        B, NA, T = x.shape
+        x = x.unsqueeze(1)
+        x = self.gelu1(self.norm1(self.conv1(x)))
+        x = self.gelu2(self.norm2(self.conv2(x)))
+        x = self.gelu3(self.norm3(self.conv3(x)))
+        x = rearrange(x, 'B C NA T -> B NA (T C)')
+        x = self.l(x)
+        return x
+
+
+@dataclass
+class NTConfig:
+    block_size: int = 1024
+    patch_size: int = 200
+    num_classes: int = 0
+    in_chans: int = 1
+    out_chans: int = 16
+    use_mean_pooling: bool = True
+    init_scale: float = 0.001
+    n_layer: int = 12
+    n_head: int = 10
+    n_embd: int = 400
+    dropout: float = 0.0
+    bias: bool = False # True: bias in Linears and LayerNorms, like GPT-2.
False: a bit better and faster + + +class NeuralTransformer(nn.Module): + def __init__(self, config, **kwargs): + super().__init__() + self.num_classes = config.num_classes + + # To identify whether it is neural tokenizer or neural decoder. + # For the neural decoder, use linear projection (PatchEmbed) to project codebook dimension to hidden dimension. + # Otherwise, use TemporalConv to extract temporal features from EEG signals. + self.patch_embed = TemporalConv(out_chans=config.out_chans) if config.in_chans == 1 else nn.Linear(config.in_chans, config.n_embd) + self.patch_size = config.patch_size + + self.pos_embed = nn.Embedding(256, config.n_embd) + self.time_embed = nn.Embedding(64, config.n_embd) + + self.rel_pos_bias = None + + self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)]) + self.norm = nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.n_embd, eps=1e-6) + self.fc_norm = nn.LayerNorm(config.n_embd, eps=1e-6) if config.use_mean_pooling else None + self.head = nn.Linear(config.n_embd, self.num_classes) if self.num_classes > 0 else nn.Identity() + + self.pos_drop = nn.Dropout(p=config.dropout) + + if isinstance(self.head, nn.Linear): + nn.init.trunc_normal_(self.head.weight, std=.02) + self.apply(self._init_weights) + self.fix_init_weight() + + if isinstance(self.head, nn.Linear): + self.head.weight.data.mul_(config.init_scale) + self.head.bias.data.mul_(config.init_scale) + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.c_proj.weight.data, layer_id + 1) + rescale(layer.mlp.c_proj.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) 
+ + def forward_features(self, x, input_chans=None, input_times=None, mask=None, return_all_tokens=False, **kwargs): + batch_size, n, t = x.shape + x = self.patch_embed(x) + + # add position and temporal embeddings + pos_embed_used = self.pos_embed(input_chans) + x = x + pos_embed_used + time_embed = self.time_embed(input_times) + x = x + time_embed + + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x, mask) + + x = self.norm(x) + if self.fc_norm is not None: + if return_all_tokens: + return self.fc_norm(x) + else: + return self.fc_norm(x.mean(1)) + else: + return x + + def forward(self, x, input_chans=None, input_times=None, mask=None, return_all_tokens=False, **kwargs): + ''' + x: [batch size, sequence length, patch size] + ''' + x = self.forward_features(x, input_chans, input_times, mask, return_all_tokens=return_all_tokens, **kwargs) + x = self.head(x) + return x diff --git a/backend/shared-logic/src/signal_processing/moss/model/model_neurolm.py b/backend/shared-logic/src/signal_processing/moss/model/model_neurolm.py new file mode 100644 index 0000000..03a86d6 --- /dev/null +++ b/backend/shared-logic/src/signal_processing/moss/model/model_neurolm.py @@ -0,0 +1,191 @@ +""" +by Wei-Bang Jiang +https://github.com/935963004/NeuroLM +""" + +import inspect +import torch +import torch.nn as nn +from torch.nn import functional as F +from model.model import * +from torch.autograd import Function +from model.model_neural_transformer import NTConfig +from model.model_neural_transformer import NeuralTransformer +from collections import OrderedDict + + +class ReverseLayerF(Function): + @staticmethod + def forward(ctx, x, alpha): + ctx.alpha = alpha + return x.view_as(x) + + @staticmethod + def backward(ctx, grad_output): + output = grad_output.neg() * ctx.alpha + return output, None + + +class NeuroLM(nn.Module): + + def __init__(self, + GPT_config, + tokenizer_ckpt_path=None, + init_from='gpt2', + n_embd=768, + eeg_vocab_size=8192, + ): + super().__init__() + + 
if init_from == 'scratch': + self.GPT2 = GPT(GPT_config) + elif init_from.startswith('gpt2'): + override_args = dict(dropout=0.0) + self.GPT2 = GPT.from_pretrained(init_from, override_args) + self.GPT2.enlarge_wte(50304) + self.GPT2.enlarge_lm_head(self.GPT2.config.vocab_size + eeg_vocab_size) + + if tokenizer_ckpt_path is not None: + print('loading weight from VQ_align') + encoder_args = dict(n_layer=12, n_head=10, n_embd=400, block_size=1024, + bias=False, dropout=0., num_classes=0, in_chans=1, out_chans=16) + tokenizer_checkpoint = torch.load(tokenizer_ckpt_path) + tokenizer_checkpoint_model_args = tokenizer_checkpoint['encoder_args'] + # force these config attributes to be equal otherwise we can't even resume training + # the rest of the attributes (e.g. dropout) can stay as desired from command line + for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias']: + encoder_args[k] = tokenizer_checkpoint_model_args[k] + tokenizer_checkpoint_model_args = tokenizer_checkpoint['decoder_args'] + # create the model + encoder_conf = NTConfig(**encoder_args) + self.tokenizer = NeuralTransformer(encoder_conf) + tokenizer_state_dict = tokenizer_checkpoint['model'] + # fix the keys of the state dictionary :( + # honestly no idea how checkpoints sometimes get this prefix, have to debug more + unwanted_prefix = '_orig_mod.' 
+ for k,v in list(tokenizer_state_dict.items()): + if k.startswith(unwanted_prefix): + tokenizer_state_dict[k[len(unwanted_prefix):]] = tokenizer_state_dict.pop(k) + + all_keys = list(tokenizer_state_dict.keys()) + new_dict = OrderedDict() + for key in all_keys: + if key.startswith('VQ.encoder.'): + new_dict[key[11:]] = tokenizer_state_dict[key] + self.tokenizer.load_state_dict(new_dict) + else: + encoder_args = dict(n_layer=12, n_head=12, n_embd=768, block_size=1024, + bias=False, dropout=0., num_classes=0, in_chans=1, out_chans=16) + encoder_conf = NTConfig(**encoder_args) + self.tokenizer = NeuralTransformer(encoder_conf) + + for p in self.tokenizer.parameters(): + p.requires_grad = False + + self.pos_embed = nn.Embedding(256, self.GPT2.config.n_embd) + + # task layer + self.encode_transform_layer = nn.Sequential( + nn.Linear(n_embd, self.GPT2.config.n_embd), + nn.GELU(), + ) if n_embd != self.GPT2.config.n_embd else nn.Identity() + + self.encode_transform_layer.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x_eeg=None, y_eeg=None, x_text=None, y_text=None, input_chans=None, input_time=None, input_mask=None, eeg_mask=None, eeg_text_mask=None): + """ + x_eeg: shape [B, N1, T] + x_text: shape [B, N2] + """ + if x_eeg is not None: + input_mask = input_mask.unsqueeze(1).repeat(1, x_eeg.size(1), 1).unsqueeze(1) + x_eeg = self.tokenizer(x_eeg, input_chans, input_time, input_mask, return_all_tokens=True) + x_eeg = self.encode_transform_layer(x_eeg) + x_eeg += self.pos_embed(input_chans) + + logits, loss, accuracy = self.GPT2(x_eeg, y_eeg, x_text, y_text, input_time, eeg_mask, eeg_text_mask) + + log = {} + split="train" if self.training else "val" + if loss is not None: + 
log[f'{split}/loss'] = loss.item() + if accuracy is not None: + log[f'{split}/accuracy'] = accuracy.item() + + return loss, log, logits + + def get_num_params(self, non_embedding=True): + """ + Return the number of parameters in the model. + For non-embedding count (default), the position embeddings get subtracted. + The token embeddings would too, except due to the parameter sharing these + params are actually used as weights in the final layer, so we include them. + """ + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.GPT2.transformer.wpe.weight.numel() + return n_params + + @torch.no_grad() + def generate(self, x_eeg, x_text, input_chans, input_time, input_mask, eeg_mask=None, eeg_text_mask=None, max_new_tokens=10, temperature=1.0, top_k=1): + if x_eeg is not None: + input_mask = input_mask.unsqueeze(1).repeat(1, x_eeg.size(1), 1).unsqueeze(1) + x_eeg = self.tokenizer(x_eeg, input_chans, input_time, input_mask, return_all_tokens=True) + x_eeg = self.encode_transform_layer(x_eeg) + x_eeg += self.pos_embed(input_chans) + #input_time = torch.zeros((x_eeg.size(0), x_eeg.size(1)), device=x_eeg.device).int() + + for _ in range(max_new_tokens): + logits, _, _ = self.GPT2(x_eeg=x_eeg, x_text=x_text, eeg_time_idx=input_time, eeg_mask=eeg_mask, eeg_text_mask=eeg_text_mask) + logits = logits[:, -1, :50257] / temperature + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = -float('Inf') + # apply softmax to convert logits to (normalized) probabilities + probs = F.softmax(logits, dim=-1) + # sample from the distribution + idx_next = torch.multinomial(probs, num_samples=1) + #_, idx_next = logits.max(-1) + + x_text = torch.cat((x_text, idx_next), dim=1) + if eeg_text_mask is not None: + eeg_text_mask = torch.cat((eeg_text_mask, torch.zeros((eeg_text_mask.size(0), eeg_text_mask.size(1), eeg_text_mask.size(2), 1), device=eeg_text_mask.device)), dim=-1) + eeg_text_mask = 
torch.cat((eeg_text_mask, torch.ones((eeg_text_mask.size(0), eeg_text_mask.size(1), 1, eeg_text_mask.size(3)), device=eeg_text_mask.device)), dim=-2) + + return x_text + + def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): + # start with all of the candidate parameters + param_dict = {pn: p for pn, p in self.named_parameters()} + # filter out those that do not require grad + param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} + # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. + # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. + decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] + nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] + optim_groups = [ + {'params': decay_params, 'weight_decay': weight_decay}, + {'params': nodecay_params, 'weight_decay': 0.0} + ] + num_decay_params = sum(p.numel() for p in decay_params) + num_nodecay_params = sum(p.numel() for p in nodecay_params) + print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") + print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") + # Create AdamW optimizer and use the fused version if it is available + fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters + use_fused = fused_available and device_type == 'cuda' + extra_args = dict(fused=True) if use_fused else dict() + optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) + print(f"using fused AdamW: {use_fused}") + + return optimizer diff --git a/backend/shared-logic/src/signal_processing/moss/model/model_vq.py b/backend/shared-logic/src/signal_processing/moss/model/model_vq.py new file mode 100644 index 0000000..aa89785 --- /dev/null +++ b/backend/shared-logic/src/signal_processing/moss/model/model_vq.py @@ -0,0 +1,264 @@ +""" +by Wei-Bang Jiang 
+https://github.com/935963004/NeuroLM +""" + +import torch +from torch import nn +import torch.nn.functional as F +import inspect + +from model.model_neural_transformer import NeuralTransformer +from model.norm_ema_quantizer import NormEMAVectorQuantizer + +from torch.autograd import Function +from transformers import GPT2LMHeadModel + + +class ReverseLayerF(Function): + @staticmethod + def forward(ctx, x, alpha): + ctx.alpha = alpha + return x.view_as(x) + + @staticmethod + def backward(ctx, grad_output): + output = grad_output.neg() * ctx.alpha + return output, None + + +class VQ(nn.Module): + def __init__(self, + encoder_config, + decoder_config, + n_embed=8192, + embed_dim=128, + decay=0.99, + quantize_kmeans_init=True, + decoder_out_dim=200, + smooth_l1_loss = False, + **kwargs + ): + super().__init__() + print(kwargs) + if decoder_config.in_chans != embed_dim: + print(f"Rewrite the in_chans in decoder from {decoder_config.in_chans} to {embed_dim}") + decoder_config.in_chans = embed_dim + + # encoder & decode params + print('Final encoder config', encoder_config) + self.encoder = NeuralTransformer(encoder_config) + + print('Final decoder config', decoder_config) + self.decoder_freq = NeuralTransformer(decoder_config) + self.decoder_raw = NeuralTransformer(decoder_config) + + self.quantize = NormEMAVectorQuantizer( + n_embed=n_embed, embedding_dim=embed_dim, beta=1.0, kmeans_init=quantize_kmeans_init, decay=decay, + ) + + self.decoder_out_dim = decoder_out_dim + + # task layer + self.encode_task_layer = nn.Sequential( + nn.Linear(encoder_config.n_embd, encoder_config.n_embd), + nn.Tanh(), + nn.Linear(encoder_config.n_embd, embed_dim) # for quantize + ) + self.decode_task_layer_freq = nn.Sequential( + nn.Linear(decoder_config.n_embd, decoder_config.n_embd), + nn.Tanh(), + nn.Linear(decoder_config.n_embd, self.decoder_out_dim // 2), + ) + self.decode_task_layer_raw = nn.Sequential( + nn.Linear(decoder_config.n_embd, decoder_config.n_embd), + nn.Tanh(), + 
nn.Linear(decoder_config.n_embd, self.decoder_out_dim), + ) + + self.kwargs = kwargs + + self.encode_task_layer.apply(self._init_weights) + self.decode_task_layer_freq.apply(self._init_weights) + self.decode_task_layer_raw.apply(self._init_weights) + + self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + # @torch.jit.ignore + # def no_weight_decay(self): + # return {'quantize.embedding.weight', 'decoder.pos_embed', 'decoder.time_embed', + # 'encoder.pos_embed', 'encoder.time_embed'} + + @property + def device(self): + return self.decoder.cls_token.device + + def get_number_of_tokens(self): + return self.quantize.n_e + + def get_tokens(self, data, input_chans=None, input_times=None, mask=None, **kwargs): + quantize, embed_ind, loss, _ = self.encode(data, input_chans, input_times, mask) + return embed_ind.view(data.size(0), data.size(1)) + + def encode(self, x, input_chans=None, input_time=None, mask=None): + batch_size, n, t = x.shape + encoder_features = self.encoder(x, input_chans, input_time, mask, return_all_tokens=True) + + with torch.cuda.amp.autocast(enabled=False): + to_quantizer_features = self.encode_task_layer(encoder_features.type_as(self.encode_task_layer[-1].weight)) + + quantize, loss, embed_ind = self.quantize(to_quantizer_features) + + return quantize, embed_ind, loss, encoder_features + + def decode(self, quantize, input_chans=None, input_time=None, mask=None, **kwargs): + # reshape tokens to feature maps for patch embed in decoder + decoder_features_freq = self.decoder_freq(quantize, input_chans, input_time, mask, return_all_tokens=True) + decoder_features_raw = self.decoder_raw(quantize, input_chans, input_time, mask, return_all_tokens=True) 
+ rec_freq = self.decode_task_layer_freq(decoder_features_freq) + rec_raw = self.decode_task_layer_raw(decoder_features_raw) + return rec_freq, rec_raw + + def get_codebook_indices(self, x, input_chans=None, input_time=None, input_mask=None, **kwargs): + if input_mask is None: + mask = None + else: + mask = input_mask.unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1) + return self.get_tokens(x, input_chans, input_time, mask, **kwargs) + + def calculate_rec_loss(self, rec, target): + rec_loss = self.loss_fn(rec, target) + return rec_loss + + def forward(self, x, y_freq, y_raw, input_chans=None, input_time=None, input_mask=None, **kwargs): + """ + x: shape [B, N, T] + """ + mask = input_mask.unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1) + quantize, embed_ind, emb_loss, encoder_features = self.encode(x, input_chans, input_time, mask) + + xrec_freq, xrec_raw = self.decode(quantize, input_chans, input_time, mask) + + loss_freq_mask = input_mask.unsqueeze(-1).repeat(1, 1, xrec_freq.size(-1)) + loss_raw_mask = input_mask.unsqueeze(-1).repeat(1, 1, xrec_raw.size(-1)) + rec_freq_loss = self.calculate_rec_loss(xrec_freq * loss_freq_mask, y_freq) + rec_raw_loss = self.calculate_rec_loss(xrec_raw * loss_raw_mask, y_raw) + loss = emb_loss + rec_freq_loss + rec_raw_loss + + log = {} + split="train" if self.training else "val" + log[f'{split}/quant_loss'] = emb_loss.detach().mean() + log[f'{split}/rec_freq_loss'] = rec_freq_loss.detach().mean() + log[f'{split}/rec_raw_loss'] = rec_raw_loss.detach().mean() + log[f'{split}/total_loss'] = loss.detach().mean() + + return loss, encoder_features, log + + def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): + # start with all of the candidate parameters + param_dict = {pn: p for pn, p in self.named_parameters()} + # filter out those that do not require grad + param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} + # create optim groups. 
Any parameters that is 2D will be weight decayed, otherwise no. + # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. + decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] + nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] + optim_groups = [ + {'params': decay_params, 'weight_decay': weight_decay}, + {'params': nodecay_params, 'weight_decay': 0.0} + ] + num_decay_params = sum(p.numel() for p in decay_params) + num_nodecay_params = sum(p.numel() for p in nodecay_params) + print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") + print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") + # Create AdamW optimizer and use the fused version if it is available + fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters + use_fused = fused_available and device_type == 'cuda' + extra_args = dict(fused=True) if use_fused else dict() + optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) + print(f"using fused AdamW: {use_fused}") + + return optimizer + + +class VQ_Align(nn.Module): + def __init__(self, + encoder_config, + decoder_config, + ): + super(VQ_Align, self).__init__() + self.VQ = VQ(encoder_config, decoder_config) + self.domain_classifier = nn.Sequential( + nn.Linear(decoder_config.n_embd, 256), + nn.GELU(), + nn.Linear(256, 2) + ) + + model_hf = GPT2LMHeadModel.from_pretrained('gpt2') + sd_hf = model_hf.state_dict() + self.wte = nn.Embedding(50257, 768, _freeze=True) + self.wte.weight.data = sd_hf['transformer.wte.weight'] + + self.domain_classifier.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + 
nn.init.constant_(m.weight, 1.0) + + def forward(self, x, y_freq=None, y_raw=None, input_chans=None, input_time=None, input_mask=None, alpha=0): + if y_freq is not None: + loss, encoder_features, log = self.VQ(x, y_freq, y_raw, input_chans, input_time, input_mask) + reverse_x = ReverseLayerF.apply(encoder_features, alpha) + domain_out = self.domain_classifier(reverse_x) + target = torch.full((domain_out.size(0), domain_out.size(1)), fill_value=-1, device=x.device) + target[input_mask == True] = 0 + domain_loss = F.cross_entropy(domain_out.view(-1, domain_out.size(-1)), target.view(-1), ignore_index=-1) + split="train" if self.training else "val" + log[f'{split}/domain_loss'] = domain_loss.detach().item() + return loss, domain_loss, log + else: + x = self.wte(x).detach() + domain_out = self.domain_classifier(x) + domain_loss = F.cross_entropy(domain_out.view(-1, domain_out.size(-1)), torch.ones((x.size(0) * x.size(1),), device=x.device).long(), ignore_index=-1) + return domain_loss + + def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): + # start with all of the candidate parameters + param_dict = {pn: p for pn, p in self.named_parameters()} + # filter out those that do not require grad + param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} + # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. + # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] + nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] + optim_groups = [ + {'params': decay_params, 'weight_decay': weight_decay}, + {'params': nodecay_params, 'weight_decay': 0.0} + ] + num_decay_params = sum(p.numel() for p in decay_params) + num_nodecay_params = sum(p.numel() for p in nodecay_params) + print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") + print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") + # Create AdamW optimizer and use the fused version if it is available + fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters + use_fused = fused_available and device_type == 'cuda' + extra_args = dict(fused=True) if use_fused else dict() + optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) + print(f"using fused AdamW: {use_fused}") + + return optimizer + \ No newline at end of file diff --git a/backend/shared-logic/src/signal_processing/moss/model/norm_ema_quantizer.py b/backend/shared-logic/src/signal_processing/moss/model/norm_ema_quantizer.py new file mode 100644 index 0000000..40239b6 --- /dev/null +++ b/backend/shared-logic/src/signal_processing/moss/model/norm_ema_quantizer.py @@ -0,0 +1,201 @@ +""" +by Wei-Bang Jiang +https://github.com/935963004/NeuroLM +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as distributed +from einops import rearrange, repeat + + +def l2norm(t): + return F.normalize(t, p = 2, dim = -1) + +def ema_inplace(moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay)) + +def sample_vectors(samples, num): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device = device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), 
device = device) + + return samples[indices] + +def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False): + dim, dtype, device = samples.shape[-1], samples.dtype, samples.device + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + if use_cosine_sim: + dists = samples @ means.t() + else: + diffs = rearrange(samples, 'n d -> n () d') \ + - rearrange(means, 'c d -> () c d') + dists = -(diffs ** 2).sum(dim = -1) + + buckets = dists.max(dim = -1).indices + bins = torch.bincount(buckets, minlength = num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype = dtype) + new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d = dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + if use_cosine_sim: + new_means = l2norm(new_means) + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''): + super().__init__() + self.num_tokens = num_tokens + self.codebook_dim = codebook_dim + self.decay = decay + self.eps = eps + if codebook_init_path == '': + if not kmeans_init: + weight = torch.randn(num_tokens, codebook_dim) + weight = l2norm(weight) + else: + weight = torch.zeros(num_tokens, codebook_dim) + self.register_buffer('initted', torch.Tensor([not kmeans_init])) + else: + print(f"load init codebook weight from {codebook_init_path}") + codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu') + weight = codebook_ckpt_weight.clone() + self.register_buffer('initted', torch.Tensor([True])) + + self.weight = nn.Parameter(weight, requires_grad = False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False) + self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False) + # self.register_buffer('initted', 
torch.Tensor([not kmeans_init])) + self.update = True + + @torch.jit.ignore + def init_embed_(self, data): + if self.initted: + return + print("Performing Kemans init for codebook") + embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim = True) + self.weight.data.copy_(embed) + self.cluster_size.data.copy_(cluster_size) + self.initted.data.copy_(torch.Tensor([True])) + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) + + def embed_avg_ema_update(self, new_embed_avg): + self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = ( + (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n + ) + #normalize embedding average with smoothed cluster size + embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) + # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1)) + self.weight.data.copy_(embed_normalized) + +def norm_ema_inplace(moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay)) + moving_avg.data.copy_(l2norm(moving_avg.data)) + +class NormEMAVectorQuantizer(nn.Module): + def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, + statistic_code_usage=True, kmeans_init=False, codebook_init_path=''): + super().__init__() + self.codebook_dim = embedding_dim + self.num_tokens = n_embed + self.beta = beta + self.decay = decay + + # learnable = True if orthogonal_reg_weight > 0 else False + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path) + + self.statistic_code_usage = statistic_code_usage + if statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(n_embed)) + if distributed.is_available() and 
distributed.is_initialized(): + print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!") + self.all_reduce_fn = distributed.all_reduce + else: + self.all_reduce_fn = nn.Identity() + + def reset_cluster_size(self, device): + if self.statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(self.num_tokens)) + self.cluster_size = self.cluster_size.to(device) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + #z, 'b c h w -> b h w c' + #z = rearrange(z, 'b c h w -> b h w c') + z = l2norm(z) # by JWB, z: (b, n, c) + z_flattened = z.reshape(-1, self.codebook_dim) + self.embedding.init_embed_(z_flattened) + + d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ + self.embedding.weight.pow(2).sum(dim=1) - 2 * \ + torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + + if not self.training: + with torch.no_grad(): + cluster_size = encodings.sum(0) + self.all_reduce_fn(cluster_size) + ema_inplace(self.cluster_size, cluster_size, self.decay) + + if self.training and self.embedding.update: + #EMA cluster size + + bins = encodings.sum(0) + self.all_reduce_fn(bins) + + # self.embedding.cluster_size_ema_update(bins) + ema_inplace(self.cluster_size, bins, self.decay) + + zero_mask = (bins == 0) + bins = bins.masked_fill(zero_mask, 1.) 
+ + embed_sum = z_flattened.t() @ encodings + self.all_reduce_fn(embed_sum) + + embed_normalized = (embed_sum / bins.unsqueeze(0)).t() + embed_normalized = l2norm(embed_normalized) + + embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight, + embed_normalized) + + norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + #z_q, 'b h w c -> b c h w' + #z_q = rearrange(z_q, 'b h w c -> b c h w') + return z_q, loss, encoding_indices + diff --git a/backend/shared-logic/src/signal_processing/moss/moss_models/emotion_classifier.pkl b/backend/shared-logic/src/signal_processing/moss/moss_models/emotion_classifier.pkl new file mode 100644 index 0000000..842ff53 Binary files /dev/null and b/backend/shared-logic/src/signal_processing/moss/moss_models/emotion_classifier.pkl differ diff --git a/backend/shared-logic/src/signal_processing/moss/moss_models/emotion_embeddings.npz b/backend/shared-logic/src/signal_processing/moss/moss_models/emotion_embeddings.npz new file mode 100644 index 0000000..915a18f Binary files /dev/null and b/backend/shared-logic/src/signal_processing/moss/moss_models/emotion_embeddings.npz differ diff --git a/backend/shared-logic/src/signal_processing/moss/moss_models/emotion_results.csv b/backend/shared-logic/src/signal_processing/moss/moss_models/emotion_results.csv new file mode 100644 index 0000000..7f2596f --- /dev/null +++ b/backend/shared-logic/src/signal_processing/moss/moss_models/emotion_results.csv @@ -0,0 +1,46 @@ +subject,accuracy,balanced_acc,n_segments +1,37.4,12.7,107 +103,51.4,24.0,107 +106,29.5,25.0,105 +107,64.8,25.0,108 +117,19.0,29.0,105 +12,56.5,30.2,115 +121,43.1,31.9,109 +122,37.2,34.2,113 +124,40.0,12.0,110 +126,56.8,21.5,111 +128,63.3,23.4,109 +129,58.9,20.9,112 +131,49.5,33.5,111 +137,51.4,32.7,109 
+138,50.0,22.4,108 +14,31.2,29.7,109 +141,30.9,44.8,110 +15,32.4,21.4,111 +159,43.6,24.3,110 +16,42.2,26.5,109 +160,46.3,31.5,108 +2,20.8,20.7,101 +24,64.5,28.7,110 +25,62.9,42.1,105 +34,20.0,25.4,110 +35,72.0,31.2,82 +36,60.4,30.0,111 +4,59.3,19.9,108 +40,35.4,18.7,96 +45,27.3,24.9,110 +47,48.6,35.8,107 +49,56.2,20.1,112 +55,25.7,34.6,113 +56,45.0,23.4,109 +61,45.9,27.6,109 +66,47.7,24.1,109 +68,55.4,30.1,101 +72,52.3,46.9,44 +76,36.7,27.6,109 +77,40.0,19.8,100 +79,63.3,31.0,109 +81,53.2,21.7,111 +86,56.0,35.4,109 +91,4.6,8.8,108 +92,60.0,33.7,110 diff --git a/backend/shared-logic/src/signal_processing/moss/moss_models/focus_classifier.pkl b/backend/shared-logic/src/signal_processing/moss/moss_models/focus_classifier.pkl new file mode 100644 index 0000000..412820e Binary files /dev/null and b/backend/shared-logic/src/signal_processing/moss/moss_models/focus_classifier.pkl differ diff --git a/backend/shared-logic/src/signal_processing/moss/moss_models/focus_embeddings.npz b/backend/shared-logic/src/signal_processing/moss/moss_models/focus_embeddings.npz new file mode 100644 index 0000000..1e58b51 Binary files /dev/null and b/backend/shared-logic/src/signal_processing/moss/moss_models/focus_embeddings.npz differ diff --git a/backend/shared-logic/src/signal_processing/moss/moss_models/focus_results.csv b/backend/shared-logic/src/signal_processing/moss/moss_models/focus_results.csv new file mode 100644 index 0000000..942fe13 --- /dev/null +++ b/backend/shared-logic/src/signal_processing/moss/moss_models/focus_results.csv @@ -0,0 +1,6 @@ +subject,accuracy,balanced_acc,n_segments +name,100.0,100.0,29 +subjecta,80.6,80.9,170 +subjectb,84.6,85.7,149 +subjectc,35.8,39.3,148 +subjectd,58.4,66.2,137 diff --git a/backend/shared-logic/src/signal_processing/moss/moss_models/muse2_classifier.pkl b/backend/shared-logic/src/signal_processing/moss/moss_models/muse2_classifier.pkl new file mode 100644 index 0000000..d802abc Binary files /dev/null and 
b/backend/shared-logic/src/signal_processing/moss/moss_models/muse2_classifier.pkl differ diff --git a/backend/shared-logic/src/signal_processing/moss/moss_models/muse2_embeddings.npz b/backend/shared-logic/src/signal_processing/moss/moss_models/muse2_embeddings.npz new file mode 100644 index 0000000..1ef7ba9 Binary files /dev/null and b/backend/shared-logic/src/signal_processing/moss/moss_models/muse2_embeddings.npz differ diff --git a/backend/shared-logic/src/signal_processing/moss/moss_models/muse2_results.csv b/backend/shared-logic/src/signal_processing/moss/moss_models/muse2_results.csv new file mode 100644 index 0000000..22e32cf --- /dev/null +++ b/backend/shared-logic/src/signal_processing/moss/moss_models/muse2_results.csv @@ -0,0 +1,9 @@ +volunteer,accuracy,balanced_acc,chance +1,93.6,93.7,16.7 +2,86.7,86.8,16.7 +3,88.3,88.2,25.0 +4,90.1,90.3,16.7 +5,92.1,92.1,16.7 +6,96.2,96.2,16.7 +7,93.1,93.0,16.7 +8,93.7,93.7,16.7 diff --git a/backend/shared-logic/src/signal_processing/moss/moss_models/stress_classifier.pkl b/backend/shared-logic/src/signal_processing/moss/moss_models/stress_classifier.pkl new file mode 100644 index 0000000..958a968 Binary files /dev/null and b/backend/shared-logic/src/signal_processing/moss/moss_models/stress_classifier.pkl differ diff --git a/backend/shared-logic/src/signal_processing/moss/moss_models/stress_embeddings.npz b/backend/shared-logic/src/signal_processing/moss/moss_models/stress_embeddings.npz new file mode 100644 index 0000000..fe65245 Binary files /dev/null and b/backend/shared-logic/src/signal_processing/moss/moss_models/stress_embeddings.npz differ diff --git a/backend/shared-logic/src/signal_processing/moss/moss_models/stress_results.csv b/backend/shared-logic/src/signal_processing/moss/moss_models/stress_results.csv new file mode 100644 index 0000000..d912efb --- /dev/null +++ b/backend/shared-logic/src/signal_processing/moss/moss_models/stress_results.csv @@ -0,0 +1,26 @@ 
+subject,true_label,pred_label,correct,n_segments +S01,Moderate,High,False,53 +S02,Moderate,High,False,46 +S03,High,High,True,39 +S04,Low,Moderate,False,47 +S05,Moderate,High,False,86 +S06,Low,High,False,61 +S07,High,Moderate,False,80 +S08,Moderate,High,False,72 +S09,Moderate,High,False,130 +S10,High,Moderate,False,93 +S11,Low,High,False,68 +S12,High,High,True,85 +S13,High,Low,False,64 +S14,Low,Low,True,93 +S15,High,Moderate,False,59 +S16,High,Low,False,60 +S17,Low,High,False,57 +S18,Moderate,High,False,42 +S19,Moderate,Low,False,22 +S21,High,High,True,24 +S22,Low,Low,True,23 +S23,High,High,True,23 +S24,Low,Moderate,False,17 +S25,High,Moderate,False,22 +S26,Moderate,Moderate,True,21 diff --git a/backend/shared-logic/src/signal_processing/moss/muse2_predict.py b/backend/shared-logic/src/signal_processing/moss/muse2_predict.py new file mode 100644 index 0000000..8712601 --- /dev/null +++ b/backend/shared-logic/src/signal_processing/moss/muse2_predict.py @@ -0,0 +1,240 @@ +""" +MOSS - Phase 2: Predict +======================= +Run activity prediction on any new Muse 2 recording. 
+ +Usage: + python muse2_predict.py --input "path/to/your_recording.csv" + python muse2_predict.py --input "path/to/your_recording.csv" --task activity + +Supported tasks (as more classifiers are trained): + activity - eat / game / read / rest / toy / tv + stress - Low / Moderate / High + focus - relaxed / neutral / concentrating + emotion - neutral / anger / fear / happiness / sadness + +Your CSV must have columns: RAW_TP9, RAW_AF7, RAW_AF8, RAW_TP10 +(standard Mind Monitor export format) +""" + +import os +import sys +import argparse +import pickle +import numpy as np +import pandas as pd +import torch +from scipy import signal as scipy_signal +from einops import rearrange + +# ── Paths ────────────────────────────────────────────────────────────────────── +NEUROLM_DIR = r"C:\Users\kiara\NeuroLM" +SAVE_DIR = r"C:\Users\kiara\NeuroLM\moss_models" + +SRC_FS = 256 +TGT_FS = 200 +WIN_SEC = 4 +STEP_SEC = 2 +WIN_SAMPLES = TGT_FS * WIN_SEC +STEP_SAMPLES = TGT_FS * STEP_SEC +EEG_MAX_LEN = 276 + +MUSE_CHANS = ['TP9', 'AF7', 'AF8', 'TP10'] +RAW_COLS = ['RAW_TP9', 'RAW_AF7', 'RAW_AF8', 'RAW_TP10'] + +# Mind Monitor also sometimes uses these column names +ALT_COLS = ['RAW_TP9', 'RAW_AF7', 'RAW_AF8', 'RAW_TP10', + 'channel1', 'channel2', 'channel3', 'channel4'] + +DEVICE = torch.device('cpu') + +sys.path.insert(0, NEUROLM_DIR) +from model.model_neurolm import NeuroLM +from model.model import GPTConfig + +standard_1020 = [ + 'FP1','FPZ','FP2','AF9','AF7','AF5','AF3','AF1','AFZ','AF2','AF4','AF6','AF8','AF10', + 'F9','F7','F5','F3','F1','FZ','F2','F4','F6','F8','F10', + 'FT9','FT7','FC5','FC3','FC1','FCZ','FC2','FC4','FC6','FT8','FT10', + 'T9','T7','C5','C3','C1','CZ','C2','C4','C6','T8','T10', + 'TP9','TP7','CP5','CP3','CP1','CPZ','CP2','CP4','CP6','TP8','TP10', + 'P9','P7','P5','P3','P1','PZ','P2','P4','P6','P8','P10', + 'PO9','PO7','PO5','PO3','PO1','POZ','PO2','PO4','PO6','PO8','PO10', + 'O1','OZ','O2','O9','CB1','CB2','IZ','O10', + 'T3','T5','T4','T6','M1','M2','A1','A2', 
+ 'CFC1','CFC2','CFC3','CFC4','CFC5','CFC6','CFC7','CFC8', + 'CCP1','CCP2','CCP3','CCP4','CCP5','CCP6','CCP7','CCP8', + 'T1','T2','FTT9h','TTP7h','TPP9h','FTT10h','TPP8h','TPP10h', + 'FP1-F7','F7-T7','T7-P7','P7-O1','FP2-F8','F8-T8','T8-P8','P8-O2', + 'pad','I1','I2' +] + +# ── Model loading ────────────────────────────────────────────────────────────── +def load_neurolm(): + checkpoint_path = os.path.join(NEUROLM_DIR, 'checkpoints', 'checkpoints', 'NeuroLM-B.pt') + ckpt = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False) + cfg = GPTConfig(**ckpt['model_args']) + model = NeuroLM(cfg, init_from='scratch') + sd = ckpt['model'] + for k in list(sd.keys()): + if k.startswith('_orig_mod.'): + sd[k[10:]] = sd.pop(k) + model.load_state_dict(sd) + model.eval() + return model + +# ── EEG processing ──────────────────────────────────────────────────────────── +def load_recording(filepath): + """Load a Muse 2 CSV - handles both RAW_TP9 and channel1/2/3/4 formats.""" + df = pd.read_csv(filepath) + cols = df.columns.tolist() + + if 'RAW_TP9' in cols: + eeg = df[RAW_COLS].dropna().values.astype(np.float32) + elif 'channel1' in cols: + # export.csv style - map channel1->TP9, channel2->AF7, etc. + eeg = df[['channel1','channel2','channel3','channel4']].dropna().values.astype(np.float32) + else: + raise ValueError(f"Unrecognized columns. Expected RAW_TP9/AF7/AF8/TP10 or channel1-4. 
Got: {cols[:10]}") + + # Detect sample rate from timestamps if available + src_fs = SRC_FS + if 'TimeStamp' in cols or 'time' in cols: + ts_col = 'TimeStamp' if 'TimeStamp' in cols else 'time' + try: + ts = pd.to_datetime(df[ts_col]) + dt = (ts.iloc[-1] - ts.iloc[0]).total_seconds() + src_fs = int(round(len(df) / dt)) + print(f" Detected sample rate: {src_fs} Hz") + except: + pass + + n_out = int(eeg.shape[0] * TGT_FS / src_fs) + resampled = np.stack([scipy_signal.resample(eeg[:, c], n_out) for c in range(4)], axis=1) + return resampled, src_fs + +def segment_eeg(eeg_np): + segs, start = [], 0 + while start + WIN_SAMPLES <= eeg_np.shape[0]: + segs.append(eeg_np[start:start + WIN_SAMPLES, :].T) + start += STEP_SAMPLES + return segs + +def segment_to_tensors(seg): + n_chans, n_total = seg.shape + T, n_time = 200, n_total // 200 + data = torch.FloatTensor(seg / 100.0) + std = data.std() + if std > 0: + data = (data - data.mean()) / std + data = rearrange(data, 'N (A T) -> (A N) T', T=T) + valid_len = data.shape[0] + X_eeg = torch.zeros((EEG_MAX_LEN, T)) + X_eeg[:valid_len] = data + eeg_mask = torch.ones(EEG_MAX_LEN) + eeg_mask[valid_len:] = 0 + chans = MUSE_CHANS * n_time + ['pad'] * (EEG_MAX_LEN - valid_len) + input_chans = torch.IntTensor([standard_1020.index(c) for c in chans]) + input_time = [i for i in range(n_time) for _ in range(n_chans)] + [0] * (EEG_MAX_LEN - valid_len) + input_time = torch.IntTensor(input_time) + return X_eeg.unsqueeze(0), input_chans.unsqueeze(0), input_time.unsqueeze(0), eeg_mask.bool().unsqueeze(0) + +@torch.no_grad() +def embed_segment(model, seg): + X_eeg, input_chans, input_time, eeg_mask = segment_to_tensors(seg) + mask = eeg_mask.unsqueeze(1).repeat(1, X_eeg.size(1), 1).unsqueeze(1) + tokens = model.tokenizer(X_eeg, input_chans, input_time, mask, return_all_tokens=True) + valid_len = int(eeg_mask.sum().item()) + emb = tokens[0, :valid_len, :].mean(dim=0) + return model.encode_transform_layer(emb).numpy() + +# ── Main 
─────────────────────────────────────────────────────────────────────── +def main(): + parser = argparse.ArgumentParser(description='MOSS - Predict mental state from Muse 2 EEG') + parser.add_argument('--input', required=True, help='Path to your Muse 2 CSV recording') + parser.add_argument('--task', default='activity', help='Task to run (default: activity)') + args = parser.parse_args() + + clf_map = {'activity': 'muse2_classifier.pkl', 'stress': 'stress_classifier.pkl', 'focus': 'focus_classifier.pkl', 'emotion': 'emotion_classifier.pkl'} + clf_file = clf_map.get(args.task, f'{args.task}_classifier.pkl') + clf_path = os.path.join(SAVE_DIR, clf_file) + if not os.path.exists(clf_path): + print(f"ERROR: No classifier found at {clf_path}") + print("Run muse2_train.py first to train the model.") + sys.exit(1) + + print(f"\n{'='*55}") + print(f" MOSS Prediction") + print(f"{'='*55}") + print(f" Input: {args.input}") + print(f" Task: {args.task}") + + # Load classifier + with open(clf_path, 'rb') as f: + bundle = pickle.load(f) + clf = bundle['classifier'] + scaler = bundle['scaler'] + activities = bundle['activities'] + print(f" Model: trained on {bundle['trained_on']}") + + # Load recording + print(f"\nLoading recording...") + eeg, src_fs = load_recording(args.input) + duration = eeg.shape[0] / TGT_FS + print(f" Duration: {duration:.1f} seconds ({eeg.shape[0]} samples at {TGT_FS}Hz)") + + segs = segment_eeg(eeg) + print(f" Segments: {len(segs)} x {WIN_SEC}s windows") + + if len(segs) == 0: + print(f"\nERROR: Recording too short. 
Need at least {WIN_SEC}s, got {duration:.1f}s") + sys.exit(1) + + # Encode with NeuroLM + print(f"\nRunning NeuroLM encoder...") + model = load_neurolm() + embeddings = [] + for i, seg in enumerate(segs): + emb = embed_segment(model, seg) + embeddings.append(emb) + if (i+1) % 10 == 0: + print(f" [{i+1}/{len(segs)}]") + + X = scaler.transform(np.array(embeddings)) + + # Predict + preds = clf.predict(X) + proba = clf.predict_proba(X) + pred_names = [activities[p] for p in preds] + + # ── Results ─────────────────────────────────────────────────────────────── + print(f"\n{'='*55}") + print(f" Results") + print(f"{'='*55}") + + # Per-segment predictions + print(f"\n Segment-by-segment predictions:") + for i, (pred, prob) in enumerate(zip(pred_names, proba)): + t_start = i * STEP_SEC + t_end = t_start + WIN_SEC + conf = prob.max() * 100 + bar = '█' * int(conf / 5) + print(f" [{t_start:4.0f}s-{t_end:.0f}s] {pred:<8} {conf:5.1f}% {bar}") + + # Overall prediction (majority vote) + from collections import Counter + counts = Counter(pred_names) + top_pred = counts.most_common(1)[0][0] + top_pct = counts.most_common(1)[0][1] / len(pred_names) * 100 + + print(f"\n Overall prediction: {top_pred.upper()} ({top_pct:.0f}% of segments)") + print(f"\n Class probabilities (mean across all segments):") + mean_proba = proba.mean(axis=0) + for act, p in sorted(zip(activities, mean_proba), key=lambda x: -x[1]): + bar = '█' * int(p * 40) + print(f" {act:<8} {p*100:5.1f}% {bar}") + + print(f"\n{'='*55}\n") + +if __name__ == '__main__': + main() diff --git a/backend/shared-logic/src/signal_processing/moss/predict.bat b/backend/shared-logic/src/signal_processing/moss/predict.bat new file mode 100644 index 0000000..9d611f8 --- /dev/null +++ b/backend/shared-logic/src/signal_processing/moss/predict.bat @@ -0,0 +1,80 @@ +@echo off +echo ================================================================ +echo MOSS BCI Platform - Mental State Prediction +echo 
================================================================
+echo.
+
+:: Check environment exists
+call conda activate MOSS 2>nul
+if %errorlevel% neq 0 (
+    echo ERROR: MOSS environment not found.
+    echo Please run setup.bat first.
+    pause
+    exit /b 1
+)
+
+:: Check checkpoint exists
+if not exist "%~dp0checkpoints\checkpoints\NeuroLM-B.pt" (
+    echo ERROR: NeuroLM-B.pt not found.
+    echo.
+    echo Please download the model weights and place them at:
+    echo   %~dp0checkpoints\checkpoints\NeuroLM-B.pt
+    echo.
+    echo Contact Natalia ^(UBC MINT Team^) for the download link.
+    pause
+    exit /b 1
+)
+
+:: Clear any values left over from a previous run in the same console.
+:: "set /p" keeps the old value on empty input, and "if not defined TASK"
+:: below would accept a stale TASK — so reset both before prompting.
+set "INPUT_FILE="
+set "TASK="
+
+:: Get input file
+echo Drag and drop your Muse 2 CSV file here, then press Enter:
+echo ^(or type the full path manually^)
+echo.
+set /p INPUT_FILE="CSV path: "
+
+:: Strip surrounding quotes if user dragged file
+set INPUT_FILE=%INPUT_FILE:"=%
+
+:: Check file exists
+if not exist "%INPUT_FILE%" (
+    echo.
+    echo ERROR: File not found: %INPUT_FILE%
+    pause
+    exit /b 1
+)
+
+:: Choose task
+echo.
+echo Choose a task:
+echo   1. activity - what you were doing ^(eat/game/read/rest/toy/tv^)
+echo   2. focus    - attention level ^(relaxed/neutral/concentrating^)
+echo   3. emotion  - emotional state ^(neutral/anger/fear/happiness/sadness^)
+echo   4. stress   - stress level ^(Low/Moderate/High^) ^[experimental^]
+echo.
+set /p TASK_NUM="Enter number (1-4): "
+
+if "%TASK_NUM%"=="1" set TASK=activity
+if "%TASK_NUM%"=="2" set TASK=focus
+if "%TASK_NUM%"=="3" set TASK=emotion
+if "%TASK_NUM%"=="4" set TASK=stress
+
+if not defined TASK (
+    echo Invalid choice. Please enter 1, 2, 3, or 4.
+    pause
+    exit /b 1
+)
+
+:: Run prediction
+echo.
+echo ================================================================
+echo   Running %TASK% prediction...
+echo ================================================================
+echo.
+
+cd /d "%~dp0"
+python muse2_predict.py --input "%INPUT_FILE%" --task %TASK%
+
+echo.
+echo ================================================================
+echo  Done! 
Press any key to close.
+echo ================================================================
+pause
diff --git a/backend/shared-logic/src/signal_processing/moss/preprocessing.py b/backend/shared-logic/src/signal_processing/moss/preprocessing.py
new file mode 100644
index 0000000..b4e4836
--- /dev/null
+++ b/backend/shared-logic/src/signal_processing/moss/preprocessing.py
@@ -0,0 +1,231 @@
+"""
+MOSS - preprocessing.py
+=======================
+Handles all EEG signal preprocessing:
+  - CSV loading (Mind Monitor + MuseLSL formats)
+  - Sample rate detection
+  - Resampling to 200Hz
+  - Segmentation into 4-second windows
+
+Input:  path to a Muse 2 CSV file
+Output: list of (4, 800) numpy arrays — one per segment
+
+Used by: coordinator.py
+"""
+
+import numpy as np
+import pandas as pd
+from scipy import signal as scipy_signal
+from typing import Optional
+
+# ── Constants ──────────────────────────────────────────────────────────────────
+TGT_FS = 200          # NeuroLM expected sample rate (Hz)
+WIN_SEC = 4           # window length in seconds
+STEP_SEC = 2          # step size (50% overlap)
+WIN_SAMPLES = TGT_FS * WIN_SEC    # 800 samples per window
+STEP_SAMPLES = TGT_FS * STEP_SEC  # 400 samples per step
+DEFAULT_FS = 256      # fallback sample rate if detection fails
+
+# Column names for different Muse export formats
+# (each list is in the same 4-channel order; see CHANNEL_NAMES below)
+MIND_MONITOR_COLS = ['RAW_TP9', 'RAW_AF7', 'RAW_AF8', 'RAW_TP10']
+MUSELSL_COLS = ['TP9', 'AF7', 'AF8', 'TP10']
+EXPORT_COLS = ['channel1', 'channel2', 'channel3', 'channel4']
+
+# Canonical channel order used throughout MOSS
+CHANNEL_NAMES = ['TP9', 'AF7', 'AF8', 'TP10']
+
+
+def load_csv(filepath: str) -> tuple[np.ndarray, int]:
+    """
+    Load a Muse 2 CSV file and return raw EEG as a numpy array. 
+ + Supports three export formats: + - Mind Monitor: RAW_TP9, RAW_AF7, RAW_AF8, RAW_TP10 + - MuseLSL: TP9, AF7, AF8, TP10 + - Direct export: channel1, channel2, channel3, channel4 + + Returns: + eeg: np.ndarray of shape (n_samples, 4), dtype float32 + src_fs: detected sample rate in Hz + """ + df = pd.read_csv(filepath) + cols = df.columns.tolist() + + # Detect column format + if 'RAW_TP9' in cols: + eeg_cols = MIND_MONITOR_COLS + ts_col = 'TimeStamp' if 'TimeStamp' in cols else None + elif 'TP9' in cols: + eeg_cols = MUSELSL_COLS + ts_col = 'timestamps' if 'timestamps' in cols else None + elif 'channel1' in cols: + eeg_cols = EXPORT_COLS + ts_col = 'time' if 'time' in cols else None + else: + raise ValueError( + f"Unrecognized CSV format. Expected RAW_TP9/AF7/AF8/TP10, " + f"TP9/AF7/AF8/TP10, or channel1-4. Got columns: {cols[:10]}" + ) + + df = df[([ts_col] if ts_col else []) + eeg_cols].dropna(subset=eeg_cols) + eeg = df[eeg_cols].values.astype(np.float32) + + # Detect sample rate from timestamps + src_fs = _detect_sample_rate(df, ts_col) + + return eeg, src_fs + + +def _detect_sample_rate(df: pd.DataFrame, ts_col: Optional[str]) -> int: + """Estimate sample rate from timestamp column, fallback to DEFAULT_FS.""" + if ts_col is None or ts_col not in df.columns: + return DEFAULT_FS + try: + ts = pd.to_datetime(df[ts_col]) + dt = (ts.iloc[-1] - ts.iloc[0]).total_seconds() + if dt > 0: + fs = int(round(len(df) / dt)) + return max(100, min(512, fs)) # clamp to sane range + except Exception: + pass + + try: + # MuseLSL format: unix epoch floats + ts = df[ts_col].values.astype(float) + diffs = np.diff(ts) + diffs = diffs[diffs > 0] + if len(diffs) > 10: + fs = int(round(1.0 / np.median(diffs))) + return max(100, min(512, fs)) + except Exception: + pass + + return DEFAULT_FS + + +def resample(eeg: np.ndarray, src_fs: int, tgt_fs: int = TGT_FS) -> np.ndarray: + """ + Resample EEG from src_fs to tgt_fs using Fourier method. 
+ + Args: + eeg: (n_samples, 4) array + src_fs: source sample rate + tgt_fs: target sample rate (default 200Hz) + + Returns: + resampled: (n_out, 4) array at tgt_fs + """ + if src_fs == tgt_fs: + return eeg + n_out = int(eeg.shape[0] * tgt_fs / src_fs) + return np.stack( + [scipy_signal.resample(eeg[:, c], n_out) for c in range(eeg.shape[1])], + axis=1 + ).astype(np.float32) + + +def from_array(raw: np.ndarray, + src_fs: int, + channel_order: Optional[list[str]] = None) -> tuple[list[np.ndarray], float]: + """ + Preprocess a raw EEG numpy array directly — no CSV needed. + Use this for real-time streaming from the Muse 2 headset via LSL. + + Args: + raw: (n_samples, 4) float32 array, channels in order + [TP9, AF7, AF8, TP10] by default + src_fs: sample rate of the incoming data (e.g. 256 for Muse 2) + channel_order: list of 4 channel names if different from default + default: ['TP9', 'AF7', 'AF8', 'TP10'] + + Returns: + segments: list of (4, 800) numpy arrays + duration: recording duration in seconds + + Example (LSL stream): + from pylsl import StreamInlet, resolve_stream + streams = resolve_stream('type', 'EEG') + inlet = StreamInlet(streams[0]) + samples, _ = inlet.pull_chunk(max_samples=1024) + raw = np.array(samples, dtype=np.float32) + segments, duration = from_array(raw, src_fs=256) + """ + if raw.ndim != 2 or raw.shape[1] != 4: + raise ValueError( + f"Expected (n_samples, 4) array, got shape {raw.shape}" + ) + + eeg = raw.astype(np.float32) + eeg = resample(eeg, src_fs) + + duration = eeg.shape[0] / TGT_FS + if duration < WIN_SEC: + raise ValueError( + f"Array too short: {duration:.1f}s. Need at least {WIN_SEC}s." + ) + + segments = segment(eeg) + return segments, duration + + +def segment(eeg: np.ndarray, + win_samples: int = WIN_SAMPLES, + step_samples: int = STEP_SAMPLES) -> list[np.ndarray]: + """ + Slice EEG into overlapping windows. 
+
+    Args:
+        eeg: (n_samples, 4) array at TGT_FS
+        win_samples: samples per window (default 800 = 4s @ 200Hz)
+        step_samples: step size (default 400 = 2s, 50% overlap)
+
+    Returns:
+        segments: list of (4, win_samples) arrays — channels first
+    """
+    segs, start = [], 0
+    # Trailing samples that do not fill a complete window are dropped.
+    while start + win_samples <= eeg.shape[0]:
+        seg = eeg[start:start + win_samples, :].T   # (4, 800)
+        segs.append(seg)
+        start += step_samples
+    return segs
+
+
+def preprocess(filepath: str) -> tuple[list[np.ndarray], int, float]:
+    """
+    Full preprocessing pipeline: load → resample → segment.
+
+    Args:
+        filepath: path to Muse 2 CSV file
+
+    Returns:
+        segments: list of (4, 800) numpy arrays
+        src_fs: detected source sample rate
+        duration: recording duration in seconds
+
+    Raises:
+        ValueError: if CSV format unrecognized or recording too short
+    """
+    eeg, src_fs = load_csv(filepath)
+    eeg = resample(eeg, src_fs)
+
+    # Duration is measured after resampling, i.e. at TGT_FS.
+    duration = eeg.shape[0] / TGT_FS
+    if duration < WIN_SEC:
+        raise ValueError(
+            f"Recording too short: {duration:.1f}s. Need at least {WIN_SEC}s."
+        )
+
+    segments = segment(eeg)
+    return segments, src_fs, duration
+
+
+if __name__ == '__main__':
+    # Quick test: run the full pipeline on a CSV given on the command line.
+    import sys
+    if len(sys.argv) < 2:
+        print("Usage: python preprocessing.py path/to/recording.csv")
+        sys.exit(1)
+
+    segs, fs, dur = preprocess(sys.argv[1])
+    print(f"Sample rate detected: {fs} Hz")
+    print(f"Duration: {dur:.1f}s")
+    print(f"Segments: {len(segs)} x {segs[0].shape}")
diff --git a/backend/shared-logic/src/signal_processing/moss/setup.bat b/backend/shared-logic/src/signal_processing/moss/setup.bat
new file mode 100644
index 0000000..e517bd6
--- /dev/null
+++ b/backend/shared-logic/src/signal_processing/moss/setup.bat
@@ -0,0 +1,40 @@
+@echo off
+echo ================================================================
+echo   MOSS BCI Platform - Environment Setup
+echo ================================================================
+echo.
+echo This will create a Python environment and install all packages. 
+echo This takes about 5-10 minutes. Please wait...
+echo.
+
+:: Create conda environment
+call conda create -n MOSS python=3.11 -y
+if %errorlevel% neq 0 (
+    echo ERROR: Failed to create conda environment.
+    echo Make sure Miniconda/Anaconda is installed and try again.
+    pause
+    exit /b 1
+)
+
+:: Activate and install packages
+call conda activate MOSS
+:: BUG FIX: activation was unchecked — a failed activation would silently
+:: pip-install everything into the base environment instead of MOSS.
+if %errorlevel% neq 0 (
+    echo ERROR: Failed to activate the MOSS environment.
+    echo Try running this script from an Anaconda Prompt.
+    pause
+    exit /b 1
+)
+
+echo.
+echo Installing PyTorch (CPU)...
+call pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+
+echo.
+echo Installing other packages...
+call pip install numpy pandas scipy scikit-learn einops transformers
+
+echo.
+echo ================================================================
+echo   Setup complete!
+echo ================================================================
+echo.
+echo Next step: place NeuroLM-B.pt in:
+echo   %~dp0checkpoints\checkpoints\NeuroLM-B.pt
+echo.
+echo Then double-click predict.bat to run predictions.
+echo.
+pause