def _handle_edge_prediction(self, payload: dict):
    """Forward a pre-computed prediction from an edge-inference IoT device.

    No model is run here: the species and confidence reported by the device
    are passed straight to ``self.echo_api_send_detection_event``.

    Keys read from ``payload``:
        gps_data        -- mapping with numeric ``lat`` and ``lon`` (required)
        timestamp       -- defaults to the current epoch second, as a string
        sensor_id       -- defaults to ``"unknown_edge_node"``
        gps_uncertainty -- defaults to ``10.0``
        species         -- defaults to ``"Unknown"``
        confidence      -- defaults to ``0.0``

    The message is dropped (logged, nothing sent) when the coordinates are
    missing, non-numeric, or outside valid latitude/longitude bounds.
    """
    gps_data = payload.get("gps_data")
    try:
        # Subscripting also rejects a null or non-mapping "gps_data" (TypeError),
        # which the previous `.get()` chain turned into an AttributeError.
        lat = float(gps_data["lat"])
        lon = float(gps_data["lon"])
    except (KeyError, TypeError, ValueError):
        print("Edge prediction missing GPS coordinates, skipping.")
        return

    # Chained comparisons are False for NaN, so this rejects NaN as well.
    if not (-90.0 <= lat <= 90.0 and -180.0 <= lon <= 180.0):
        print("Edge prediction has out-of-range GPS coordinates, skipping.")
        return

    audio_event = {
        "timestamp": payload.get("timestamp", str(int(time.time()))),
        "sensorId": payload.get("sensor_id", "unknown_edge_node"),
        # Separate lists per field so a later in-place change to one position
        # cannot silently alter the other two.
        "microphoneLLA": [lat, lon, 0.0],
        "animalEstLLA": [lat, lon, 0.0],
        "animalTrueLLA": [lat, lon, 0.0],
        "animalLLAUncertainty": payload.get("gps_uncertainty", 10.0),
        # Edge devices transmit no audio, so the clip is empty.
        "audioClip": "",
    }
    species = payload.get("species", "Unknown")
    confidence = payload.get("confidence", 0.0)
    print(f"Edge prediction received: {species} ({confidence}%)", flush=True)
    # Second argument is the sample rate; 0 because there is no audio clip.
    self.echo_api_send_detection_event(audio_event, 0, species, confidence)
def _patched_open(path, *args, **kwargs): builtins.open = _patched_open -from echo_engine import EchoEngine # noqa: E402 +from echo_engine_iot import EchoEngine # noqa: E402 builtins.open = _real_open diff --git a/src/Components/IoT/edge_inference/README.md b/src/Components/IoT/edge_inference/README.md new file mode 100644 index 000000000..785944429 --- /dev/null +++ b/src/Components/IoT/edge_inference/README.md @@ -0,0 +1,216 @@ +# IoT Edge Inference — EfficientNetV2 TFLite on Raspberry Pi + +## Overview + +This document covers all changes made to support **on-device ML inference** for Project Echo IoT nodes, and explains how to deploy and run the edge client. + +Previously, IoT devices sent raw audio over MQTT to the engine server, which ran the model. Now a Raspberry Pi can run the EfficientNetV2 TFLite model locally and publish only the prediction result — no audio is transmitted over the network. + +--- + +## Architecture + +### Before (server-side inference) +``` +RPi → [raw audio bytes, base64, ~500KB/clip] → MQTT → Engine → TF Serving → API → MongoDB +``` + +### After (edge inference — this PR) +``` +RPi → [species + confidence + GPS, ~300 bytes] → MQTT → Engine → API → MongoDB +``` + +Both modes are supported simultaneously on the same MQTT topic. The engine detects which path to use based on the payload. 
+ +--- + +## Files Changed + +### New files + +| File | Purpose | +|---|---| +| `src/Components/IoT/edge_inference/iot_edge_client.py` | RPi script — records audio, runs TFLite, publishes prediction | +| `src/Components/IoT/edge_inference/requirements.txt` | RPi Python dependencies | +| `src/Components/IoT/edge_inference/README.md` | This file | + +### Modified files + +| File | Change | +|---|---| +| `src/Components/Engine/echo_engine_iot.py` | Added `_handle_edge_prediction()` method; `on_iot_message()` now routes on `payload["type"]` | +| `src/Components/Engine/echo_engine_iot.py` | Removed duplicate `import base64` | +| `src/Components/Engine/echo_engine_iot.py` | Hardcoded API URL replaced with `self.config['API_URL']` | +| `src/Components/Engine/echo_engine.json` | Added `"API_URL"` field | +| `src/Components/Engine/Engine.Dockerfile` | Now copies `echo_engine_iot.py` as `echo_engine.py` so IoT engine runs in Docker | +| `src/Components/Engine/requirements.txt` | Added `scikit-learn` (was imported but missing) | +| `src/Components/Engine/test_iot_integration.py` | Updated import to `from echo_engine_iot import EchoEngine` so tests run locally | +| `src/Prototypes/engine/torch_impl/light_echo_engine.json` | Fixed `MQTT_CLIENT_URL` from `"mqtt-broker"` → `"ts-mqtt-server-cont"` (was causing DNS failure) | +| `src/Prototypes/engine/torch_impl/requirements.txt` | Added `paho-mqtt==1.6.1`, `pymongo`, `geopy`, `google-cloud-storage` | + +--- + +## Engine: How the Two Modes Work + +The engine's `on_iot_message()` handler now routes based on the payload `type` field: + +``` +MQTT message received + │ + ▼ +payload["type"] == "prediction"? + Yes ──► _handle_edge_prediction() ──► POST /engine/event (no ML on server) + No ──► audio_file present? 
**Raw audio payload** (sent by the original `client.py` — still supported). `audio_file` must be a non-empty string of base64-encoded audio; the engine rejects a missing or empty value:
```json
{
  "audio_file": "<base64-encoded audio bytes>",
  "gps_data": {"lat": -37.8136, "lon": 144.9631},
  "sensor_id": "rpi_node_1"
}
```
| `--lat` | `-37.8136` | Fixed latitude (used when `--use-gps` is not set, and as the fallback when `gpsd` cannot be read) |
| `--lon` | `144.9631` | Fixed longitude (used when `--use-gps` is not set, and as the fallback when `gpsd` cannot be read) |
If the engine container is running, it will log lines of the form:
```
Received IoT message...
IoT Predicted class : <species>
IoT Predicted probability : <probability>
```
+ +Model files expected in ./models/: + efficientnetv2_project_echo.tflite + class_mapping.json + preprocess_config.json + +Usage: + python iot_edge_client.py [--broker BROKER] [--port PORT] [--topic TOPIC] + [--sensor-id ID] [--lat LAT] [--lon LON] + [--interval SECS] [--use-gps] +""" + +import argparse +import json +import time +import threading +from pathlib import Path + +import numpy as np +import sounddevice as sd +import librosa +import paho.mqtt.client as mqtt + +try: + import tflite_runtime.interpreter as tflite +except ImportError: + import tensorflow.lite as tflite # fallback if full TF installed + +# --------------------------------------------------------------------------- +# Paths (relative to this script) +# --------------------------------------------------------------------------- +_DIR = Path(__file__).parent +MODEL_PATH = _DIR / "models" / "efficientnetv2_project_echo.tflite" +CLASS_MAPPING_PATH = _DIR / "models" / "class_mapping.json" +PREPROCESS_CFG_PATH = _DIR / "models" / "preprocess_config.json" + + +# --------------------------------------------------------------------------- +# Model loading +# --------------------------------------------------------------------------- +def load_model(): + with open(PREPROCESS_CFG_PATH) as f: + cfg = json.load(f) + with open(CLASS_MAPPING_PATH) as f: + index_to_label = json.load(f)["index_to_label"] + + interpreter = tflite.Interpreter(model_path=str(MODEL_PATH)) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + print(f"TFLite model loaded. 
def preprocess(audio: np.ndarray, cfg: dict) -> np.ndarray:
    """Turn a mono waveform into a standardised log-mel tensor for the model.

    The clip is forced to exactly ``target_sr * duration_s`` samples
    (zero-padded at the end, or truncated), converted to a mel power
    spectrogram, expressed in dB relative to its own peak, then shifted and
    scaled to zero mean / unit variance.

    Args:
        audio: 1-D sample array; assumed to already be at ``cfg["target_sr"]``
            since no resampling happens here.
        cfg: Preprocessing settings; must contain ``target_sr``,
            ``duration_s``, ``n_mels``, ``hop_length``, ``fmin`` and ``fmax``.

    Returns:
        A float32 array with two leading singleton axes added, i.e. NCHW
        ``[1, 1, n_mels, time_frames]``.

    Raises:
        KeyError: if any of the required ``cfg`` keys is absent.
    """
    # Read every setting first so a missing key fails before any DSP work.
    sr, clip_seconds, mel_bands, hop, freq_lo, freq_hi = (
        cfg[key]
        for key in ("target_sr", "duration_s", "n_mels", "hop_length", "fmin", "fmax")
    )

    wanted = int(sr * clip_seconds)
    missing = wanted - len(audio)
    if missing > 0:
        clip = np.pad(audio, (0, missing), mode="constant")
    else:
        clip = audio[:wanted]

    power_spec = librosa.feature.melspectrogram(
        y=clip,
        sr=sr,
        n_mels=mel_bands,
        hop_length=hop,
        fmin=freq_lo,
        fmax=freq_hi,
    )
    log_spec = librosa.power_to_db(power_spec, ref=np.max).astype(np.float32)

    # The small epsilon keeps the division finite for a constant spectrogram.
    centred = log_spec - np.mean(log_spec)
    scaled = centred / (np.std(log_spec) + 1e-6)

    # Prepend batch and channel axes.
    return scaled[None, None, :, :].astype(np.float32)
def get_gps(fallback_lat: float, fallback_lon: float, timeout_s: float = 5.0) -> dict:
    """Return the current position from gpsd, or the fallback coordinates.

    A position is accepted only when gpsd reports ``mode == 3`` and both
    ``lat`` and ``lon`` hold real values rather than the ``"n/a"``
    placeholder. Reading is best-effort: if ``gps3`` is not installed, gpsd
    cannot be read, or no fix arrives within ``timeout_s`` seconds, the
    fallback is returned and the reason is printed.

    Args:
        fallback_lat: Latitude to report when no live fix is available.
        fallback_lon: Longitude to report when no live fix is available.
        timeout_s: Maximum time to wait for a fix. Previously the wait had no
            bound, so a receiver without a fix stalled the caller's loop.

    Returns:
        ``{"lat": float, "lon": float}``.
    """
    fallback = {"lat": fallback_lat, "lon": fallback_lon}

    try:
        from gps3 import gps3
    except ImportError:
        print("gps3 is not installed; using fallback coordinates.", flush=True)
        return fallback

    deadline = time.monotonic() + timeout_s
    try:
        gps_socket = gps3.GPSDSocket()
        data_stream = gps3.DataStream()
        gps_socket.connect()
        gps_socket.watch()
        try:
            for new_data in gps_socket:
                if time.monotonic() >= deadline:
                    break
                if not new_data:
                    # Nothing ready yet; yield briefly instead of spinning.
                    time.sleep(0.05)
                    continue
                data_stream.unpack(new_data)
                lat = data_stream.TPV.get("lat")
                lon = data_stream.TPV.get("lon")
                mode = data_stream.TPV.get("mode")
                if mode == 3 and lat != "n/a" and lon != "n/a":
                    return {"lat": float(lat), "lon": float(lon)}
        finally:
            # A socket is opened on every call, so release it every time.
            gps_socket.close()
    except Exception as exc:  # deliberate best-effort: GPS must never stop the client
        print(f"GPS read failed ({exc!r}); using fallback coordinates.", flush=True)
        return fallback

    print(f"No GPS fix within {timeout_s}s; using fallback coordinates.", flush=True)
    return fallback
def main():
    """Run the edge client: record, classify on-device, publish over MQTT.

    Parses the command-line options, loads the TFLite model, connects to the
    broker, then loops until Ctrl+C. Each pass takes a position, records one
    clip, runs inference, prints the top predictions, publishes the
    prediction payload at QoS 1 and sleeps for ``--interval`` seconds. The
    MQTT network loop is stopped and the client disconnected on exit.
    """
    parser = argparse.ArgumentParser(description="IoT edge inference client")
    parser.add_argument("--broker", default="broker.hivemq.com")
    parser.add_argument("--port", type=int, default=1883)
    parser.add_argument("--topic", default="iot/data/test")
    parser.add_argument("--sensor-id", default="rpi_edge_node_1")
    parser.add_argument("--lat", type=float, default=-37.8136)
    parser.add_argument("--lon", type=float, default=144.9631)
    parser.add_argument("--gps-uncertainty", type=float, default=10.0)
    parser.add_argument("--interval", type=float, default=10.0,
                        help="Seconds between recordings (default: 10)")
    parser.add_argument("--use-gps", action="store_true",
                        help="Read GPS from gpsd instead of using fixed --lat/--lon")
    args = parser.parse_args()

    interpreter, input_details, output_details, cfg, index_to_label = load_model()
    client = connect_mqtt(args.broker, args.port)

    print("Edge inference client running. Ctrl+C to stop.\n", flush=True)
    try:
        while True:
            if args.use_gps:
                # The fixed coordinates double as the fallback position.
                gps = get_gps(args.lat, args.lon)
            else:
                gps = {"lat": args.lat, "lon": args.lon}

            audio = record_audio(cfg["duration_s"], cfg["target_sr"])
            x = preprocess(audio, cfg)
            species, confidence, top5 = run_inference(
                interpreter, input_details, output_details, x, index_to_label
            )

            print(f" → {species} {confidence:.1f}%", flush=True)
            # top5[0] is the winning class already printed above.
            for entry in top5[1:]:
                print(f" {entry['label']}: {entry['confidence']:.1f}%", flush=True)

            payload = build_payload(
                species, confidence, top5,
                args.sensor_id, gps, args.gps_uncertainty,
            )
            result = client.publish(args.topic, json.dumps(payload), qos=1)
            # Only claim success when the client accepted the message;
            # previously the success line was printed unconditionally.
            if result.rc == mqtt.MQTT_ERR_SUCCESS:
                print(f"Published to {args.topic}\n", flush=True)
            else:
                print(f"Publish to {args.topic} failed (rc={result.rc})\n", flush=True)

            time.sleep(args.interval)

    except KeyboardInterrupt:
        print("Stopping.", flush=True)
    finally:
        client.loop_stop()
        client.disconnect()
+ +tflite-runtime +numpy +librosa +sounddevice +paho-mqtt==1.6.1 +gps3 # optional — only needed if --use-gps flag is used diff --git a/src/Prototypes/engine/torch_impl/requirements.txt b/src/Prototypes/engine/torch_impl/requirements.txt index 28ddd43cb..3228878bc 100644 --- a/src/Prototypes/engine/torch_impl/requirements.txt +++ b/src/Prototypes/engine/torch_impl/requirements.txt @@ -1,13 +1,24 @@ tensorflow==2.10.0 librosa==0.9.2 soundfile==0.13.1 -requests==2.32.5 -pymongo==4.17.0 -google-cloud-storage==3.9.0 -geopy==2.4.1 -pandas==2.3.3 -scikit-learn==1.6.1 -diskcache==5.6.3 +soxr==0.5.0.post1 +sympy==1.14.0 +tensorboard==2.20.0 +tensorboard-data-server==0.7.2 +tensorboardx==2.6.4 +threadpoolctl==3.6.0 +torch==2.8.0 +torchaudio==2.8.0 +torchvision==0.23.0 +tqdm==4.67.1 +triton==3.4.0 +typing-extensions==4.14.1 +tzdata==2025.2 +umap-learn==0.5.9.post2 +urllib3==2.5.0 +werkzeug==3.1.3 +wheel==0.45.1 +gradio paho-mqtt==1.6.1 pymongo geopy