Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
//!
//! The trained model is serialised as JSON and hot-loaded at runtime so that
//! the classification thresholds adapt to the specific room and ESP32 placement.
//!
//! Classes are discovered dynamically from training data filenames instead of
//! being hardcoded, so new activity classes can be added just by recording data
//! with the appropriate filename convention.

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
Expand All @@ -20,9 +24,8 @@ use std::path::{Path, PathBuf};
/// Extended feature vector: 7 server features + 8 subcarrier-derived features = 15.
const N_FEATURES: usize = 15;

/// Default class names for backward compatibility with old saved models.
/// Classes are otherwise discovered dynamically from training-data filenames,
/// so no `N_CLASSES` constant exists any more — the class count is the length
/// of the discovered `class_names` vector.
const DEFAULT_CLASSES: &[&str] = &["absent", "present_still", "present_moving", "active"];

/// Extract extended feature vector from a JSONL frame (features + raw amplitudes).
pub fn features_from_frame(frame: &serde_json::Value) -> [f64; N_FEATURES] {
Expand Down Expand Up @@ -124,36 +127,48 @@ pub struct ClassStats {
pub struct AdaptiveModel {
/// Per-class feature statistics (centroid + spread).
pub class_stats: Vec<ClassStats>,
/// Logistic regression weights: [N_CLASSES x (N_FEATURES + 1)] (last = bias).
pub weights: Vec<[f64; N_FEATURES + 1]>,
/// Logistic regression weights: [n_classes x (N_FEATURES + 1)] (last = bias).
/// Dynamic: the outer Vec length equals the number of discovered classes.
pub weights: Vec<Vec<f64>>,
/// Global feature normalisation: mean and stddev across all training data.
pub global_mean: [f64; N_FEATURES],
pub global_std: [f64; N_FEATURES],
/// Training metadata.
pub trained_frames: usize,
pub training_accuracy: f64,
pub version: u32,
/// Dynamically discovered class names (in index order).
#[serde(default = "default_class_names")]
pub class_names: Vec<String>,
}

/// Backward-compatible fallback for models saved without class_names.
fn default_class_names() -> Vec<String> {
DEFAULT_CLASSES.iter().map(|s| s.to_string()).collect()
}

impl Default for AdaptiveModel {
    /// An untrained model: zero weights for the default class set, identity
    /// normalisation (mean 0, std 1), and no per-class statistics.
    fn default() -> Self {
        let n_classes = DEFAULT_CLASSES.len();
        Self {
            class_stats: Vec::new(),
            // One weight row per class, N_FEATURES weights + 1 bias each.
            weights: vec![vec![0.0; N_FEATURES + 1]; n_classes],
            global_mean: [0.0; N_FEATURES],
            // std of 1.0 so normalisation is a no-op before training.
            global_std: [1.0; N_FEATURES],
            trained_frames: 0,
            training_accuracy: 0.0,
            version: 1,
            class_names: default_class_names(),
        }
    }
}

impl AdaptiveModel {
/// Classify a raw feature vector. Returns (class_label, confidence).
pub fn classify(&self, raw_features: &[f64; N_FEATURES]) -> (&'static str, f64) {
if self.weights.is_empty() || self.class_stats.is_empty() {
return ("present_still", 0.5);
pub fn classify(&self, raw_features: &[f64; N_FEATURES]) -> (String, f64) {
let n_classes = self.weights.len();
if n_classes == 0 || self.class_stats.is_empty() {
return ("present_still".to_string(), 0.5);
}

// Normalise features.
Expand All @@ -163,8 +178,8 @@ impl AdaptiveModel {
}

// Compute logits: w·x + b for each class.
let mut logits = [0.0f64; N_CLASSES];
for c in 0..N_CLASSES.min(self.weights.len()) {
let mut logits: Vec<f64> = vec![0.0; n_classes];
for c in 0..n_classes {
let w = &self.weights[c];
let mut z = w[N_FEATURES]; // bias
for i in 0..N_FEATURES {
Expand All @@ -176,16 +191,20 @@ impl AdaptiveModel {
// Softmax.
let max_logit = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let exp_sum: f64 = logits.iter().map(|z| (z - max_logit).exp()).sum();
let mut probs = [0.0f64; N_CLASSES];
for c in 0..N_CLASSES {
let mut probs: Vec<f64> = vec![0.0; n_classes];
for c in 0..n_classes {
probs[c] = ((logits[c] - max_logit).exp()) / exp_sum;
}

// Pick argmax.
let (best_c, best_p) = probs.iter().enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.unwrap();
let label = if best_c < CLASSES.len() { CLASSES[best_c] } else { "present_still" };
let label = if best_c < self.class_names.len() {
self.class_names[best_c].clone()
} else {
"present_still".to_string()
};
(label, *best_p)
}

Expand Down Expand Up @@ -228,48 +247,88 @@ fn load_recording(path: &Path, class_idx: usize) -> Vec<Sample> {
}).collect()
}

/// Map a recording filename to a class name (String).
///
/// Well-known activity keywords map to the four canonical class names for
/// backward compatibility; any other `train_<class>_*.jsonl` file contributes
/// `<class>` as a dynamically discovered class. Returns None when no class can
/// be determined (e.g. an empty class segment such as `train_.jsonl`).
fn classify_recording_name(name: &str) -> Option<String> {
    let lower = name.to_lowercase();

    // Check common patterns first for backward compat with the original
    // hardcoded class set.
    if lower.contains("empty") || lower.contains("absent") { return Some("absent".into()); }
    if lower.contains("still") || lower.contains("sitting") || lower.contains("standing") { return Some("present_still".into()); }
    if lower.contains("walking") || lower.contains("moving") { return Some("present_moving".into()); }
    if lower.contains("active") || lower.contains("exercise") || lower.contains("running") { return Some("active".into()); }

    // Fallback: extract class from filename structure train_<class>_*.jsonl.
    // strip_prefix/strip_suffix remove at most ONE occurrence, so a class name
    // that itself starts with "train" is not mangled the way the repeated
    // stripping of trim_start_matches would.
    let stem = lower.strip_prefix("train_").unwrap_or(&lower);
    let stem = stem.strip_suffix(".jsonl").unwrap_or(stem);
    let class_name = stem.split('_').next().unwrap_or(stem);
    if class_name.is_empty() {
        None
    } else {
        Some(class_name.to_string())
    }
}

/// Train a model from labeled JSONL recordings in a directory.
///
/// Recordings are matched to classes by filename pattern. Classes are discovered
/// dynamically from the training data filenames:
/// - `*empty*` / `*absent*`   → absent
/// - `*still*` / `*sitting*`  → present_still
/// - `*walking*` / `*moving*` → present_moving
/// - `*active*` / `*exercise*`→ active
/// - Any other `train_<class>_*.jsonl` → <class>
pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, String> {
// Scan for train_* files.
let mut samples: Vec<Sample> = Vec::new();
let entries = std::fs::read_dir(recordings_dir)
.map_err(|e| format!("Cannot read {}: {}", recordings_dir.display(), e))?;

for entry in entries.flatten() {
// First pass: scan filenames to discover all unique class names.
let entries: Vec<_> = std::fs::read_dir(recordings_dir)
.map_err(|e| format!("Cannot read {}: {}", recordings_dir.display(), e))?
.flatten()
.collect();

let mut class_map: HashMap<String, usize> = HashMap::new();
let mut class_names: Vec<String> = Vec::new();

// Collect (entry, class_name) pairs for files that match.
let mut file_classes: Vec<(PathBuf, String, String)> = Vec::new(); // (path, fname, class_name)
for entry in &entries {
let fname = entry.file_name().to_string_lossy().to_string();
if !fname.starts_with("train_") || !fname.ends_with(".jsonl") {
continue;
}
if let Some(class_idx) = classify_recording_name(&fname) {
let loaded = load_recording(&entry.path(), class_idx);
eprintln!(" Loaded {}: {} frames → class '{}'",
fname, loaded.len(), CLASSES[class_idx]);
samples.extend(loaded);
if let Some(class_name) = classify_recording_name(&fname) {
if !class_map.contains_key(&class_name) {
let idx = class_names.len();
class_map.insert(class_name.clone(), idx);
class_names.push(class_name.clone());
}
file_classes.push((entry.path(), fname, class_name));
}
}

let n_classes = class_names.len();
if n_classes == 0 {
return Err("No training samples found. Record data with train_* prefix.".into());
}

// Second pass: load recordings with the discovered class indices.
let mut samples: Vec<Sample> = Vec::new();
for (path, fname, class_name) in &file_classes {
let class_idx = class_map[class_name];
let loaded = load_recording(path, class_idx);
eprintln!(" Loaded {}: {} frames → class '{}'",
fname, loaded.len(), class_name);
samples.extend(loaded);
}

if samples.is_empty() {
return Err("No training samples found. Record data with train_* prefix.".into());
}

let n = samples.len();
eprintln!("Total training samples: {n}");
eprintln!("Total training samples: {n} across {n_classes} classes: {:?}", class_names);

// ── Compute global normalisation stats ──
let mut global_mean = [0.0f64; N_FEATURES];
Expand All @@ -289,9 +348,9 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
}

// ── Compute per-class statistics ──
let mut class_sums = vec![[0.0f64; N_FEATURES]; N_CLASSES];
let mut class_sq = vec![[0.0f64; N_FEATURES]; N_CLASSES];
let mut class_counts = vec![0usize; N_CLASSES];
let mut class_sums = vec![[0.0f64; N_FEATURES]; n_classes];
let mut class_sq = vec![[0.0f64; N_FEATURES]; n_classes];
let mut class_counts = vec![0usize; n_classes];
for s in &samples {
let c = s.class_idx;
class_counts[c] += 1;
Expand All @@ -302,7 +361,7 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
}

let mut class_stats = Vec::new();
for c in 0..N_CLASSES {
for c in 0..n_classes {
let cnt = class_counts[c].max(1) as f64;
let mut mean = [0.0; N_FEATURES];
let mut stddev = [0.0; N_FEATURES];
Expand All @@ -311,7 +370,7 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
stddev[i] = ((class_sq[c][i] / cnt) - mean[i] * mean[i]).max(0.0).sqrt();
}
class_stats.push(ClassStats {
label: CLASSES[c].to_string(),
label: class_names[c].clone(),
count: class_counts[c],
mean,
stddev,
Expand All @@ -328,7 +387,7 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
}).collect();

// ── Train logistic regression via mini-batch SGD ──
let mut weights = vec![[0.0f64; N_FEATURES + 1]; N_CLASSES];
let mut weights: Vec<Vec<f64>> = vec![vec![0.0f64; N_FEATURES + 1]; n_classes];
let lr = 0.1;
let epochs = 200;
let batch_size = 32;
Expand All @@ -348,36 +407,36 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
}

let mut epoch_loss = 0.0f64;
let mut batch_count = 0;
let mut _batch_count = 0;

for batch_start in (0..norm_samples.len()).step_by(batch_size) {
let batch_end = (batch_start + batch_size).min(norm_samples.len());
let batch = &norm_samples[batch_start..batch_end];

// Accumulate gradients.
let mut grad = vec![[0.0f64; N_FEATURES + 1]; N_CLASSES];
let mut grad: Vec<Vec<f64>> = vec![vec![0.0f64; N_FEATURES + 1]; n_classes];

for (x, target) in batch {
// Forward: softmax.
let mut logits = [0.0f64; N_CLASSES];
for c in 0..N_CLASSES {
let mut logits: Vec<f64> = vec![0.0; n_classes];
for c in 0..n_classes {
logits[c] = weights[c][N_FEATURES]; // bias
for i in 0..N_FEATURES {
logits[c] += weights[c][i] * x[i];
}
}
let max_l = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let exp_sum: f64 = logits.iter().map(|z| (z - max_l).exp()).sum();
let mut probs = [0.0f64; N_CLASSES];
for c in 0..N_CLASSES {
let mut probs: Vec<f64> = vec![0.0; n_classes];
for c in 0..n_classes {
probs[c] = ((logits[c] - max_l).exp()) / exp_sum;
}

// Cross-entropy loss.
epoch_loss += -(probs[*target].max(1e-15)).ln();

// Gradient: prob - one_hot(target).
for c in 0..N_CLASSES {
for c in 0..n_classes {
let delta = probs[c] - if c == *target { 1.0 } else { 0.0 };
for i in 0..N_FEATURES {
grad[c][i] += delta * x[i];
Expand All @@ -389,12 +448,12 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
// Update weights.
let bs = batch.len() as f64;
let current_lr = lr * (1.0 - epoch as f64 / epochs as f64); // linear decay
for c in 0..N_CLASSES {
for c in 0..n_classes {
for i in 0..=N_FEATURES {
weights[c][i] -= current_lr * grad[c][i] / bs;
}
}
batch_count += 1;
_batch_count += 1;
}

if epoch % 50 == 0 || epoch == epochs - 1 {
Expand All @@ -406,8 +465,8 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
// ── Evaluate accuracy ──
let mut correct = 0;
for (x, target) in &norm_samples {
let mut logits = [0.0f64; N_CLASSES];
for c in 0..N_CLASSES {
let mut logits: Vec<f64> = vec![0.0; n_classes];
for c in 0..n_classes {
logits[c] = weights[c][N_FEATURES];
for i in 0..N_FEATURES {
logits[c] += weights[c][i] * x[i];
Expand All @@ -422,12 +481,12 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
eprintln!("Training accuracy: {correct}/{n} = {accuracy:.1}%");

// ── Per-class accuracy ──
let mut class_correct = vec![0usize; N_CLASSES];
let mut class_total = vec![0usize; N_CLASSES];
let mut class_correct = vec![0usize; n_classes];
let mut class_total = vec![0usize; n_classes];
for (x, target) in &norm_samples {
class_total[*target] += 1;
let mut logits = [0.0f64; N_CLASSES];
for c in 0..N_CLASSES {
let mut logits: Vec<f64> = vec![0.0; n_classes];
for c in 0..n_classes {
logits[c] = weights[c][N_FEATURES];
for i in 0..N_FEATURES {
logits[c] += weights[c][i] * x[i];
Expand All @@ -438,9 +497,9 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
.unwrap().0;
if pred == *target { class_correct[*target] += 1; }
}
for c in 0..N_CLASSES {
for c in 0..n_classes {
let tot = class_total[c].max(1);
eprintln!(" {}: {}/{} ({:.0}%)", CLASSES[c], class_correct[c], tot,
eprintln!(" {}: {}/{} ({:.0}%)", class_names[c], class_correct[c], tot,
class_correct[c] as f64 / tot as f64 * 100.0);
}

Expand All @@ -452,6 +511,7 @@ pub fn train_from_recordings(recordings_dir: &Path) -> Result<AdaptiveModel, Str
trained_frames: n,
training_accuracy: accuracy,
version: 1,
class_names,
})
}

Expand Down
Loading