From 1b1e92f927d1a4a39ef5daae2dde9a73de067029 Mon Sep 17 00:00:00 2001 From: Yiyi Xu Date: Tue, 23 Jun 2026 17:03:52 +0200 Subject: [PATCH 1/7] Fix drone emergency stop on deployment interrupt --- swarm_gpt/core/backend.py | 12 ++++++- swarm_gpt/core/drone_swarm.py | 16 ++++++++-- tests/unit/test_backend.py | 58 ++++++++++++++++++++++++++++++++++ tests/unit/test_drone_swarm.py | 41 ++++++++++++++++++++++++ 4 files changed, 124 insertions(+), 3 deletions(-) diff --git a/swarm_gpt/core/backend.py b/swarm_gpt/core/backend.py index e34c9e2..6dc6f27 100644 --- a/swarm_gpt/core/backend.py +++ b/swarm_gpt/core/backend.py @@ -378,9 +378,19 @@ def deploy(self, drone_ids: list[int] | None = None) -> bool: ) self.music_manager.stop() swarm.goto(final_pos_dict, duration=2.0) # Transition from ideal point to hover pos - if self.settings["land_on_docks"]: # Commented out for demo + if self.settings["land_on_docks"]: # Commented out for demo swarm.goto(final_pos_dict, duration=3.0) # Hovering swarm.land(duration=1.5) # Landing + except BaseException: + try: + swarm.emergency_stop() + except Exception as e: + logger.error(f"Emergency stop after deployment interrupt failed: {e}") + try: + self.music_manager.stop() + except Exception as e: + logger.error(f"Stopping music after deployment interrupt failed: {e}") + raise finally: swarm.close() logger.info("Deployment successful") diff --git a/swarm_gpt/core/drone_swarm.py b/swarm_gpt/core/drone_swarm.py index eebfbf0..e732d8e 100644 --- a/swarm_gpt/core/drone_swarm.py +++ b/swarm_gpt/core/drone_swarm.py @@ -348,8 +348,20 @@ async def _close() -> None: def _run(self, coroutine: Awaitable[Any]) -> Any: """Run a cflib2 coroutine on the swarm event loop.""" if self._loop_thread is not None: - return asyncio.run_coroutine_threadsafe(coroutine, self._loop).result() - return self._loop.run_until_complete(coroutine) + future = asyncio.run_coroutine_threadsafe(coroutine, self._loop) + try: + return future.result() + except BaseException: + future.cancel() + raise + + task = self._loop.create_task(coroutine) + try: + return self._loop.run_until_complete(task) + except BaseException: + task.cancel() + self._loop.run_until_complete(asyncio.gather(task, return_exceptions=True)) + raise def _start_estimator_updater(self) -> None: """Start continuous mocap updates after the radio links are connected.""" diff --git a/tests/unit/test_backend.py b/tests/unit/test_backend.py index 8a33cb0..749c335 100644 --- a/tests/unit/test_backend.py +++ b/tests/unit/test_backend.py @@ -1,7 +1,10 @@ from pathlib import Path +import numpy as np +import pytest from conftest import virtual_crazyswarm_config +import swarm_gpt.core.drone_swarm as drone_swarm from swarm_gpt.core.backend import AppBackend @@ -46,3 +49,58 @@ def test_preset_metadata_and_delete(tmp_path: Path): app.delete_preset(preset_id) assert not (preset_dir / preset_id).exists() + + +def test_deploy_emergency_stops_before_close_on_keyboard_interrupt( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class InterruptingSwarm: + instances = [] + + def __init__(self, drones: dict, *, lighthouse: bool) -> None: + self.drones_by_uri = {drone["uri"]: drone for drone in drones.values()} + self.calls: list[str] = [] + self.instances.append(self) + + def get_obs(self, uri: str) -> dict[str, np.ndarray]: + pos = np.asarray(self.drones_by_uri[uri]["pos"], dtype=float) + if "goto" in self.calls: + pos = pos + np.array([0.0, 0.0, 0.5]) + return {"pos": pos, "quat": np.array([0.0, 0.0, 0.0, 1.0])} + + def goto(self, target: dict[str, np.ndarray], duration: float = 3.0) -> None: + self.calls.append("goto") + + def is_active(self, uri: str) -> bool: + return uri in self.drones_by_uri + + def execute_choreography(self, *args: object, **kwargs: object) -> None: + self.calls.append("execute_choreography") + raise KeyboardInterrupt + + def emergency_stop(self) -> None: + self.calls.append("emergency_stop") + + def land(self, height: float = 0.0, duration: float = 3.0) -> None: + self.calls.append("land") + + def close(self) -> None: + self.calls.append("close") + + config_path = virtual_crazyswarm_config(n_drones=1) + app = AppBackend(config_file=config_path) + app.settings["lighthouse"] = True + app.settings["land_on_docks"] = False + app.music_manager.song = "Crazyflie Drones Theme" + app.music_manager.verify_libvlc = lambda: True + app.music_manager.play = lambda *, wait, start_s, end_s: True + app.music_manager.stop = lambda: None + app.waypoints = {"time": np.array([[0.0, 1.0]])} + app.splines[0] = lambda t: np.array([0.0, 0.0, 1.0]) + monkeypatch.setattr(drone_swarm, "DroneSwarm", InterruptingSwarm) + + with pytest.raises(KeyboardInterrupt): + app.deploy() + + swarm = InterruptingSwarm.instances[0] + assert swarm.calls == ["goto", "execute_choreography", "emergency_stop", "close"] diff --git a/tests/unit/test_drone_swarm.py b/tests/unit/test_drone_swarm.py index d96394a..dc46132 100644 --- a/tests/unit/test_drone_swarm.py +++ b/tests/unit/test_drone_swarm.py @@ -93,6 +93,47 @@ def make_swarm(uris: list[str]) -> DroneSwarm: return swarm +def test_run_cancels_threadsafe_command_when_keyboard_interrupt_escapes( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class InterruptingFuture: + def __init__(self) -> None: + self.cancelled = False + + def result(self) -> None: + raise KeyboardInterrupt + + def cancel(self) -> None: + self.cancelled = True + + future = InterruptingFuture() + captured: dict[str, Any] = {} + swarm = make_swarm([]) + swarm._loop_thread = threading.Thread() + + async def command() -> None: + return None + + def run_coroutine_threadsafe( + coroutine: Any, loop: asyncio.AbstractEventLoop + ) -> InterruptingFuture: + captured["coroutine"] = coroutine + captured["loop"] = loop + return future + + monkeypatch.setattr(asyncio, "run_coroutine_threadsafe", run_coroutine_threadsafe) + + try: + with pytest.raises(KeyboardInterrupt): + swarm._run(command()) + finally: + captured["coroutine"].close() + swarm._loop.close() + + assert captured["loop"] is swarm._loop + assert future.cancelled + + def test_setpoint_sends_once_and_returns(): uris = ["radio://0/80/2M/E7E7E7E701", "radio://0/80/2M/E7E7E7E702"] swarm = make_swarm(uris) From 98cee043542d3e7000a42ffc3dbf710b9387f679 Mon Sep 17 00:00:00 2001 From: Yiyi Xu Date: Tue, 23 Jun 2026 17:40:58 +0200 Subject: [PATCH 2/7] attempt to estop again and adding it to frontend --- swarm_gpt/api/server.py | 13 ++++++ swarm_gpt/core/backend.py | 15 ++++++- swarm_gpt/core/drone_swarm.py | 31 ++++++++++---- tests/unit/test_api.py | 74 ++++++++++++++++++++++++++++++++++ tests/unit/test_backend.py | 21 ++++++++++ tests/unit/test_drone_swarm.py | 38 +++++++++++++++++ web/src/App.tsx | 39 ++++++++++++++++++ web/src/api.ts | 7 ++++ web/src/styles.css | 12 +++++- web/tests/player.spec.js | 48 ++++++++++++++++++++++ 10 files changed, 289 insertions(+), 9 deletions(-) diff --git a/swarm_gpt/api/server.py b/swarm_gpt/api/server.py index 568cc9a..bd6987b 100644 --- a/swarm_gpt/api/server.py +++ b/swarm_gpt/api/server.py @@ -472,6 +472,19 @@ def deploy(job_id: str) -> dict[str, Any]: _start_thread(job, lambda: _run_deploy_job(store, job)) return {"jobId": job.id} + @app.post("/api/jobs/{job_id}/emergency-stop") + def emergency_stop(job_id: str) -> dict[str, Any]: + try: + job = store.get(job_id) + except KeyError: + raise HTTPException(status_code=404, detail="Job not found") from None + try: + job.backend.emergency_stop_active_swarm() + except RuntimeError as exc: + raise HTTPException(status_code=409, detail=str(exc)) from exc + store.emit(job, "emergency_stop_sent", {}) + return {"jobId": job.id, "emergencyStopped": True} + @app.post("/api/jobs/{job_id}/preset") def save_preset(job_id: str) -> dict[str, Any]: try: diff --git a/swarm_gpt/core/backend.py b/swarm_gpt/core/backend.py index 6dc6f27..8a8e012 100644 --- a/swarm_gpt/core/backend.py +++ b/swarm_gpt/core/backend.py @@ -119,6 +119,7 @@ def __init__( self._preset: None | str = None self._strict_processing = strict_processing self._strict_drone_match = strict_drone_match + self._active_swarm: Any | None = None if set(self.songs) & set(self.presets): raise ValueError("Songs and presets must have unique names") @@ -299,6 +300,7 @@ def deploy(self, drone_ids: list[int] | None = None) -> bool: return False swarm = DroneSwarm(self.choreographer.drones, lighthouse=self.settings["lighthouse"]) + self._active_swarm = swarm logger.info("Swarm connected...") # generate references @@ -392,10 +394,21 @@ def deploy(self, drone_ids: list[int] | None = None) -> bool: logger.error(f"Stopping music after deployment interrupt failed: {e}") raise finally: - swarm.close() + try: + swarm.close() + finally: + self._active_swarm = None logger.info("Deployment successful") return True + def emergency_stop_active_swarm(self) -> None: + """Emergency-stop the currently active deployment swarm, if one exists.""" + swarm = self._active_swarm + if swarm is None: + raise RuntimeError("No active deployment swarm to emergency stop.") + swarm.emergency_stop() + self.music_manager.stop() + def load_preset(self, preset_id: str) -> str: """Load a preset response. diff --git a/swarm_gpt/core/drone_swarm.py b/swarm_gpt/core/drone_swarm.py index e732d8e..da3c62b 100644 --- a/swarm_gpt/core/drone_swarm.py +++ b/swarm_gpt/core/drone_swarm.py @@ -34,6 +34,7 @@ _DISCONNECT_ERRORS = (DisconnectedError, LinkError, TimeoutError) _LIGHTHOUSE_DECK_PARAM = "deck.bcLighthouse4" _POWER_CYCLE_BOOT_WAIT = 3.0 +_EMERGENCY_STOP_TIMEOUT = 0.5 _CommanderLevel = Literal["low", "high"] @@ -268,7 +269,7 @@ def emergency_stop(self, uri: str | None = None): else: self._validate_known_uris("uri", {uri: None}) uris = [uri] - self._run(self._parallel_by_uri("Emergency stop", uris, self._emergency_stop, timeout=0.5)) + self._run(self._emergency_stop_many(uris)) def reset(self): """Reset all active drones.""" @@ -312,12 +313,10 @@ async def _shutdown_leds(uri: str) -> None: await self._apply_drone_color(uri, np.zeros(4), "both") async def _close() -> None: - active_uris = [uri for uri in self.uris if uri in self.active_uris] + active_uris = [uri for uri in self.uris if uri in self.cfs] if active_uris: try: - await self._parallel_by_uri( - "Emergency stop", active_uris, self._emergency_stop, timeout=0.5 - ) + await self._emergency_stop_many(active_uris) await asyncio.sleep(0.1) await self._parallel_by_uri( "Shutdown LEDs", active_uris, _shutdown_leds, timeout=0.5 @@ -347,7 +346,7 @@ async def _close() -> None: def _run(self, coroutine: Awaitable[Any]) -> Any: """Run a cflib2 coroutine on the swarm event loop.""" - if self._loop_thread is not None: + if self._loop_thread is not None or self._loop.is_running(): future = asyncio.run_coroutine_threadsafe(coroutine, self._loop) try: return future.result() @@ -489,8 +488,26 @@ async def _parallel_by_uri( logger.error(f"{action_name} failed for {uri}: {result}") return results + async def _emergency_stop_many(self, uris: Iterable[str]) -> None: + """Best-effort emergency stop for every connected URI without cross-drone blocking.""" + target_uris = [uri for uri in uris if uri in self.cfs] + + async def _stop_one(uri: str) -> None: + try: + await asyncio.wait_for(self._emergency_stop(uri), timeout=_EMERGENCY_STOP_TIMEOUT) + self._commander_levels[uri] = None + except Exception as exc: + if isinstance(exc, _DISCONNECT_ERRORS): + self.active_uris.discard(uri) + self._commander_levels[uri] = None + logger.error(f"{uri} disconnected or unreachable. Emergency stop failed: {exc}") + else: + logger.error(f"Emergency stop failed for {uri}: {exc}") + + await asyncio.gather(*[_stop_one(uri) for uri in target_uris]) + async def _emergency_stop(self, uri: str) -> None: - await self._cf(uri).localization().emergency().send_emergency_stop() + await self.cfs[uri].localization().emergency().send_emergency_stop() async def _change_commander_level(self, uri: str, level: _CommanderLevel) -> None: """Switch commander level only when local state expects a different mode.""" diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index efeb997..c7b629e 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -1,3 +1,6 @@ +import threading +import time +from collections.abc import Generator from pathlib import Path from types import SimpleNamespace from urllib.parse import quote @@ -6,6 +9,7 @@ import pytest from fastapi.testclient import TestClient +import swarm_gpt.api.server as server from swarm_gpt.api.server import ApiConfig, _backend_from_config, create_app, normalize_playback from swarm_gpt.utils.llm_providers import DEFAULT_OPENAI_MODEL_CHOICES @@ -79,3 +83,73 @@ def test_library_returns_preset_display_metadata_and_delete(tmp_path: Path): delete_response.raise_for_status() assert delete_response.json() == {"deleted": preset_id} assert not (preset_dir / preset_id).exists() + + +def test_emergency_stop_endpoint_runs_while_deploy_is_active( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + class DeployingBackend: + def __init__(self) -> None: + self.songs = ["Test Song"] + self.presets: list[str] = [] + self.settings = {"axswarm": {"pos_min": [-1, -1, 0], "pos_max": [1, 1, 2]}} + self.music_manager = SimpleNamespace(song="Test Song") + self.splines: dict[int, object] = {} + self.deploy_entered = threading.Event() + self.stop_requested = threading.Event() + self.emergency_stop_calls = 0 + + def initial_prompt(self, selection: str) -> list[dict[str, str]]: + return [] + + def simulate(self) -> Generator[None, None, dict[str, object]]: + self.splines[0] = object() + states = np.zeros((1, 1, 13)) + states[:, :, 3:7] = [0, 0, 0, 1] + if False: + yield None + return {"timestamps": np.array([0.0]), "states": states, "num_drones": 1} + + def crop_window(self, song: str) -> tuple[float, float]: + return (0.0, 60.0) + + def deploy(self) -> bool: + self.deploy_entered.set() + self.stop_requested.wait(timeout=2.0) + return True + + def emergency_stop_active_swarm(self) -> None: + self.emergency_stop_calls += 1 + self.stop_requested.set() + + backends: list[DeployingBackend] = [] + + def backend_from_config(config: ApiConfig, provider: str, model_id: str) -> DeployingBackend: + backend = DeployingBackend() + backends.append(backend) + return backend + + (tmp_path / "Test Song.mp3").write_bytes(b"") + monkeypatch.setattr(server, "_backend_from_config", backend_from_config) + client = TestClient(create_app(ApiConfig(music_dir=tmp_path))) + + create_response = client.post( + "/api/jobs", json={"selection": "Test Song", "provider": "openai", "modelId": "gpt"} + ) + create_response.raise_for_status() + job_id = create_response.json()["jobId"] + backend = backends[0] + for _ in range(50): + if client.get(f"/api/jobs/{job_id}").json()["status"] == "ready": + break + time.sleep(0.01) + + deploy_response = client.post(f"/api/jobs/{job_id}/deploy") + deploy_response.raise_for_status() + assert backend.deploy_entered.wait(timeout=1.0) + + stop_response = client.post(f"/api/jobs/{job_id}/emergency-stop") + stop_response.raise_for_status() + + assert stop_response.json() == {"jobId": job_id, "emergencyStopped": True} + assert backend.emergency_stop_calls == 1 diff --git a/tests/unit/test_backend.py b/tests/unit/test_backend.py index 749c335..111cffb 100644 --- a/tests/unit/test_backend.py +++ b/tests/unit/test_backend.py @@ -104,3 +104,24 @@ def close(self) -> None: swarm = InterruptingSwarm.instances[0] assert swarm.calls == ["goto", "execute_choreography", "emergency_stop", "close"] + + +def test_emergency_stop_active_swarm_stops_live_swarm_and_music() -> None: + class ActiveSwarm: + def __init__(self) -> None: + self.calls: list[str] = [] + + def emergency_stop(self) -> None: + self.calls.append("emergency_stop") + + config_path = virtual_crazyswarm_config(n_drones=1) + app = AppBackend(config_file=config_path) + swarm = ActiveSwarm() + music_calls: list[str] = [] + app._active_swarm = swarm + app.music_manager.stop = lambda: music_calls.append("stop") + + app.emergency_stop_active_swarm() + + assert swarm.calls == ["emergency_stop"] + assert music_calls == ["stop"] diff --git a/tests/unit/test_drone_swarm.py b/tests/unit/test_drone_swarm.py index dc46132..4a79966 100644 --- a/tests/unit/test_drone_swarm.py +++ b/tests/unit/test_drone_swarm.py @@ -134,6 +134,24 @@ def run_coroutine_threadsafe( assert future.cancelled +def test_run_schedules_threadsafe_when_loop_is_already_running(): + swarm = make_swarm([]) + swarm._loop_thread = None + + async def command() -> str: + return "stopped" + + thread = threading.Thread(target=swarm._loop.run_forever) + thread.start() + + try: + assert swarm._run(command()) == "stopped" + finally: + swarm._loop.call_soon_threadsafe(swarm._loop.stop) + thread.join() + swarm._loop.close() + + def test_setpoint_sends_once_and_returns(): uris = ["radio://0/80/2M/E7E7E7E701", "radio://0/80/2M/E7E7E7E702"] swarm = make_swarm(uris) @@ -150,6 +168,26 @@ def test_setpoint_sends_once_and_returns(): assert cf.fake_param.values == [("commander.enHighLevel", 0)] +def test_emergency_stop_one_hung_drone_does_not_block_others(): + uris = ["radio://0/80/2M/E7E7E7E701", "radio://0/80/2M/E7E7E7E702"] + swarm = make_swarm(uris) + stopped: list[str] = [] + + async def emergency_stop(uri: str) -> None: + if uri == uris[0]: + await asyncio.sleep(1) + stopped.append(uri) + + swarm._emergency_stop = emergency_stop + + try: + swarm.emergency_stop() + finally: + swarm._loop.close() + + assert stopped == [uris[1]] + + def test_estimator_updater_copies_batch_and_skips_inactive_drones(): uris = [ "radio://0/80/2M/E7E7E7E701", diff --git a/web/src/App.tsx b/web/src/App.tsx index dc48d4c..1f5b72a 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -16,6 +16,7 @@ import { createJob, deletePreset, deployJob, + emergencyStopJob, getLibrary, getLlm, getPlayback, @@ -67,6 +68,7 @@ export function App() { const [presetNotice, setPresetNotice] = useState(null); const [savingPreset, setSavingPreset] = useState(false); const [deletingPreset, setDeletingPreset] = useState(null); + const [emergencyStopping, setEmergencyStopping] = useState(false); const socketRef = useRef(null); const providerInfo = llm?.providers.find((entry) => entry.id === provider); @@ -141,8 +143,13 @@ export function App() { if (event.type === "deploy_complete") { setStage("ready"); } + if (event.type === "emergency_stop_sent") { + setPresetNotice("Emergency stop sent."); + setEmergencyStopping(false); + } if (event.type === "failed") { setStage((current) => (current === "deploying" ? "ready" : "failed")); + setEmergencyStopping(false); setError(String(event.payload.message ?? "Job failed")); } }; @@ -158,6 +165,7 @@ export function App() { setPlayback(null); setError(null); setPresetNotice(null); + setEmergencyStopping(false); setProgress(0); setStage("thinking"); setDetailsOpen(false); @@ -195,6 +203,8 @@ export function App() { return; } setError(null); + setPresetNotice(null); + setEmergencyStopping(false); setStage("deploying"); try { await deployJob(jobId); @@ -204,6 +214,21 @@ export function App() { } }; + const emergencyStop = async () => { + if (!jobId) { + return; + } + setError(null); + setPresetNotice(null); + setEmergencyStopping(true); + try { + await emergencyStopJob(jobId); + setPresetNotice("Emergency stop sent."); + } finally { + setEmergencyStopping(false); + } + }; + const saveSafePreset = async () => { if (!jobId) { return; @@ -255,6 +280,7 @@ export function App() { setRefineText(""); setError(null); setPresetNotice(null); + setEmergencyStopping(false); }; if (stage === "playing" && playback) { @@ -463,6 +489,19 @@ export function App() { )} + {stage === "deploying" && ( +
+ +
+ )} + {stage === "failed" && (