diff --git a/swarm_gpt/api/server.py b/swarm_gpt/api/server.py index 568cc9a..bd6987b 100644 --- a/swarm_gpt/api/server.py +++ b/swarm_gpt/api/server.py @@ -472,6 +472,19 @@ def deploy(job_id: str) -> dict[str, Any]: _start_thread(job, lambda: _run_deploy_job(store, job)) return {"jobId": job.id} + @app.post("/api/jobs/{job_id}/emergency-stop") + def emergency_stop(job_id: str) -> dict[str, Any]: + try: + job = store.get(job_id) + except KeyError: + raise HTTPException(status_code=404, detail="Job not found") from None + try: + job.backend.emergency_stop_active_swarm() + except RuntimeError as exc: + raise HTTPException(status_code=409, detail=str(exc)) from exc + store.emit(job, "emergency_stop_sent", {}) + return {"jobId": job.id, "emergencyStopped": True} + @app.post("/api/jobs/{job_id}/preset") def save_preset(job_id: str) -> dict[str, Any]: try: diff --git a/swarm_gpt/core/backend.py b/swarm_gpt/core/backend.py index e34c9e2..2e79e61 100644 --- a/swarm_gpt/core/backend.py +++ b/swarm_gpt/core/backend.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: from numpy.typing import NDArray as Array + from swarm_gpt.core.drone_swarm import DroneSwarm from swarm_gpt.utils.llm_providers import LLMProvider logging.basicConfig(level=logging.WARNING) @@ -119,6 +120,7 @@ def __init__( self._preset: None | str = None self._strict_processing = strict_processing self._strict_drone_match = strict_drone_match + self._active_swarm: DroneSwarm | None = None if set(self.songs) & set(self.presets): raise ValueError("Songs and presets must have unique names") @@ -299,6 +301,7 @@ def deploy(self, drone_ids: list[int] | None = None) -> bool: return False swarm = DroneSwarm(self.choreographer.drones, lighthouse=self.settings["lighthouse"]) + self._active_swarm = swarm logger.info("Swarm connected...") # generate references @@ -378,14 +381,23 @@ def deploy(self, drone_ids: list[int] | None = None) -> bool: ) self.music_manager.stop() swarm.goto(final_pos_dict, duration=2.0) # Transition from ideal point to hover pos - if self.settings["land_on_docks"]: # Commented out for demo + if self.settings["land_on_docks"]: swarm.goto(final_pos_dict, duration=3.0) # Hovering swarm.land(duration=1.5) # Landing finally: + self._active_swarm = None swarm.close() logger.info("Deployment successful") return True + def emergency_stop_active_swarm(self) -> None: + """Emergency-stop the currently active deployment swarm, if one exists.""" + swarm = self._active_swarm + if swarm is None: + raise RuntimeError("No active deployment swarm to emergency stop.") + swarm.emergency_stop() + self.music_manager.stop() + def load_preset(self, preset_id: str) -> str: """Load a preset response. diff --git a/swarm_gpt/core/drone_swarm.py b/swarm_gpt/core/drone_swarm.py index eebfbf0..b2fa6b2 100644 --- a/swarm_gpt/core/drone_swarm.py +++ b/swarm_gpt/core/drone_swarm.py @@ -346,8 +346,13 @@ async def _close() -> None: self.ros_connector.close() def _run(self, coroutine: Awaitable[Any]) -> Any: - """Run a cflib2 coroutine on the swarm event loop.""" - if self._loop_thread is not None: + """Run a cflib2 coroutine on the swarm event loop. + + Dispatches cross-thread when the loop runs in another thread or is already running + (e.g. a deployment driving it), so an emergency stop from the request thread that + handles the frontend button still reaches the swarm mid-performance. + """ + if self._loop_thread is not None or self._loop.is_running(): return asyncio.run_coroutine_threadsafe(coroutine, self._loop).result() return self._loop.run_until_complete(coroutine) diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index efeb997..c7b629e 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -1,3 +1,6 @@ +import threading +import time +from collections.abc import Generator from pathlib import Path from types import SimpleNamespace from urllib.parse import quote @@ -6,6 +9,7 @@ import pytest from fastapi.testclient import TestClient +import swarm_gpt.api.server as server from swarm_gpt.api.server import ApiConfig, _backend_from_config, create_app, normalize_playback from swarm_gpt.utils.llm_providers import DEFAULT_OPENAI_MODEL_CHOICES @@ -79,3 +83,73 @@ def test_library_returns_preset_display_metadata_and_delete(tmp_path: Path): delete_response.raise_for_status() assert delete_response.json() == {"deleted": preset_id} assert not (preset_dir / preset_id).exists() + + +def test_emergency_stop_endpoint_runs_while_deploy_is_active( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + class DeployingBackend: + def __init__(self) -> None: + self.songs = ["Test Song"] + self.presets: list[str] = [] + self.settings = {"axswarm": {"pos_min": [-1, -1, 0], "pos_max": [1, 1, 2]}} + self.music_manager = SimpleNamespace(song="Test Song") + self.splines: dict[int, object] = {} + self.deploy_entered = threading.Event() + self.stop_requested = threading.Event() + self.emergency_stop_calls = 0 + + def initial_prompt(self, selection: str) -> list[dict[str, str]]: + return [] + + def simulate(self) -> Generator[None, None, dict[str, object]]: + self.splines[0] = object() + states = np.zeros((1, 1, 13)) + states[:, :, 3:7] = [0, 0, 0, 1] + if False: + yield None + return {"timestamps": np.array([0.0]), "states": states, "num_drones": 1} + + def crop_window(self, song: str) -> tuple[float, float]: + return (0.0, 60.0) + + def deploy(self) -> bool: + self.deploy_entered.set() + self.stop_requested.wait(timeout=2.0) + return True + + def emergency_stop_active_swarm(self) -> None: + self.emergency_stop_calls += 1 + self.stop_requested.set() + + backends: list[DeployingBackend] = [] + + def backend_from_config(config: ApiConfig, provider: str, model_id: str) -> DeployingBackend: + backend = DeployingBackend() + backends.append(backend) + return backend + + (tmp_path / "Test Song.mp3").write_bytes(b"") + monkeypatch.setattr(server, "_backend_from_config", backend_from_config) + client = TestClient(create_app(ApiConfig(music_dir=tmp_path))) + + create_response = client.post( + "/api/jobs", json={"selection": "Test Song", "provider": "openai", "modelId": "gpt"} + ) + create_response.raise_for_status() + job_id = create_response.json()["jobId"] + backend = backends[0] + for _ in range(50): + if client.get(f"/api/jobs/{job_id}").json()["status"] == "ready": + break + time.sleep(0.01) + + deploy_response = client.post(f"/api/jobs/{job_id}/deploy") + deploy_response.raise_for_status() + assert backend.deploy_entered.wait(timeout=1.0) + + stop_response = client.post(f"/api/jobs/{job_id}/emergency-stop") + stop_response.raise_for_status() + + assert stop_response.json() == {"jobId": job_id, "emergencyStopped": True} + assert backend.emergency_stop_calls == 1 diff --git a/tests/unit/test_backend.py b/tests/unit/test_backend.py index 8a33cb0..a4e9458 100644 --- a/tests/unit/test_backend.py +++ b/tests/unit/test_backend.py @@ -46,3 +46,24 @@ def test_preset_metadata_and_delete(tmp_path: Path): app.delete_preset(preset_id) assert not (preset_dir / preset_id).exists() + + +def test_emergency_stop_active_swarm_stops_live_swarm_and_music() -> None: + class ActiveSwarm: + def __init__(self) -> None: + self.calls: list[str] = [] + + def emergency_stop(self) -> None: + self.calls.append("emergency_stop") + + config_path = virtual_crazyswarm_config(n_drones=1) + app = AppBackend(config_file=config_path) + swarm = ActiveSwarm() + music_calls: list[str] = [] + app._active_swarm = swarm + app.music_manager.stop = lambda: music_calls.append("stop") + + app.emergency_stop_active_swarm() + + assert swarm.calls == ["emergency_stop"] + assert music_calls == ["stop"] diff --git a/tests/unit/test_drone_swarm.py b/tests/unit/test_drone_swarm.py index d96394a..d7ef99e 100644 --- a/tests/unit/test_drone_swarm.py +++ b/tests/unit/test_drone_swarm.py @@ -93,6 +93,24 @@ def make_swarm(uris: list[str]) -> DroneSwarm: return swarm +def test_run_schedules_threadsafe_when_loop_is_already_running(): + swarm = make_swarm([]) + swarm._loop_thread = None + + async def command() -> str: + return "stopped" + + thread = threading.Thread(target=swarm._loop.run_forever) + thread.start() + + try: + assert swarm._run(command()) == "stopped" + finally: + swarm._loop.call_soon_threadsafe(swarm._loop.stop) + thread.join() + swarm._loop.close() + + def test_setpoint_sends_once_and_returns(): uris = ["radio://0/80/2M/E7E7E7E701", "radio://0/80/2M/E7E7E7E702"] swarm = make_swarm(uris) diff --git a/web/src/App.tsx b/web/src/App.tsx index dc48d4c..1f5b72a 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -16,6 +16,7 @@ import { createJob, deletePreset, deployJob, + emergencyStopJob, getLibrary, getLlm, getPlayback, @@ -67,6 +68,7 @@ export function App() { const [presetNotice, setPresetNotice] = useState(null); const [savingPreset, setSavingPreset] = useState(false); const [deletingPreset, setDeletingPreset] = useState(null); + const [emergencyStopping, setEmergencyStopping] = useState(false); const socketRef = useRef(null); const providerInfo = llm?.providers.find((entry) => entry.id === provider); @@ -141,8 +143,13 @@ export function App() { if (event.type === "deploy_complete") { setStage("ready"); } + if (event.type === "emergency_stop_sent") { + setPresetNotice("Emergency stop sent."); + setEmergencyStopping(false); + } if (event.type === "failed") { setStage((current) => (current === "deploying" ? "ready" : "failed")); + setEmergencyStopping(false); setError(String(event.payload.message ?? "Job failed")); } }; @@ -158,6 +165,7 @@ export function App() { setPlayback(null); setError(null); setPresetNotice(null); + setEmergencyStopping(false); setProgress(0); setStage("thinking"); setDetailsOpen(false); @@ -195,6 +203,8 @@ export function App() { return; } setError(null); + setPresetNotice(null); + setEmergencyStopping(false); setStage("deploying"); try { await deployJob(jobId); @@ -204,6 +214,21 @@ export function App() { } }; + const emergencyStop = async () => { + if (!jobId) { + return; + } + setError(null); + setPresetNotice(null); + setEmergencyStopping(true); + try { + await emergencyStopJob(jobId); + setPresetNotice("Emergency stop sent."); + } finally { + setEmergencyStopping(false); + } + }; + const saveSafePreset = async () => { if (!jobId) { return; @@ -255,6 +280,7 @@ export function App() { setRefineText(""); setError(null); setPresetNotice(null); + setEmergencyStopping(false); }; if (stage === "playing" && playback) { @@ -463,6 +489,19 @@ export function App() { )} + {stage === "deploying" && ( +
+ +
+ )} + {stage === "failed" && (