Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions swarm_gpt/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,19 @@ def deploy(job_id: str) -> dict[str, Any]:
_start_thread(job, lambda: _run_deploy_job(store, job))
return {"jobId": job.id}

@app.post("/api/jobs/{job_id}/emergency-stop")
def emergency_stop(job_id: str) -> dict[str, Any]:
try:
job = store.get(job_id)
except KeyError:
raise HTTPException(status_code=404, detail="Job not found") from None
try:
job.backend.emergency_stop_active_swarm()
except RuntimeError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
store.emit(job, "emergency_stop_sent", {})
return {"jobId": job.id, "emergencyStopped": True}

@app.post("/api/jobs/{job_id}/preset")
def save_preset(job_id: str) -> dict[str, Any]:
try:
Expand Down
14 changes: 13 additions & 1 deletion swarm_gpt/core/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
if TYPE_CHECKING:
from numpy.typing import NDArray as Array

from swarm_gpt.core.drone_swarm import DroneSwarm
from swarm_gpt.utils.llm_providers import LLMProvider

logging.basicConfig(level=logging.WARNING)
Expand Down Expand Up @@ -119,6 +120,7 @@ def __init__(
self._preset: None | str = None
self._strict_processing = strict_processing
self._strict_drone_match = strict_drone_match
self._active_swarm: DroneSwarm | None = None
if set(self.songs) & set(self.presets):
raise ValueError("Songs and presets must have unique names")

Expand Down Expand Up @@ -299,6 +301,7 @@ def deploy(self, drone_ids: list[int] | None = None) -> bool:
return False

swarm = DroneSwarm(self.choreographer.drones, lighthouse=self.settings["lighthouse"])
self._active_swarm = swarm
logger.info("Swarm connected...")

# generate references
Expand Down Expand Up @@ -378,14 +381,23 @@ def deploy(self, drone_ids: list[int] | None = None) -> bool:
)
self.music_manager.stop()
swarm.goto(final_pos_dict, duration=2.0) # Transition from ideal point to hover pos
if self.settings["land_on_docks"]: # Commented out for demo
if self.settings["land_on_docks"]:
swarm.goto(final_pos_dict, duration=3.0) # Hovering
swarm.land(duration=1.5) # Landing
finally:
self._active_swarm = None
swarm.close()
logger.info("Deployment successful")
return True

def emergency_stop_active_swarm(self) -> None:
"""Emergency-stop the currently active deployment swarm, if one exists."""
swarm = self._active_swarm
if swarm is None:
raise RuntimeError("No active deployment swarm to emergency stop.")
swarm.emergency_stop()
self.music_manager.stop()

Comment on lines +393 to +400

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we not simply closing the swarm? This will also trigger an estop and reduces the number of changes significantly.

def load_preset(self, preset_id: str) -> str:
"""Load a preset response.

Expand Down
9 changes: 7 additions & 2 deletions swarm_gpt/core/drone_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,13 @@ async def _close() -> None:
self.ros_connector.close()

def _run(self, coroutine: Awaitable[Any]) -> Any:
"""Run a cflib2 coroutine on the swarm event loop."""
if self._loop_thread is not None:
"""Run a cflib2 coroutine on the swarm event loop.

Dispatches cross-thread when the loop runs in another thread or is already running
(e.g. a deployment driving it), so an emergency stop from the request thread that
handles the frontend button still reaches the swarm mid-performance.
"""
if self._loop_thread is not None or self._loop.is_running():
return asyncio.run_coroutine_threadsafe(coroutine, self._loop).result()
return self._loop.run_until_complete(coroutine)

Expand Down
74 changes: 74 additions & 0 deletions tests/unit/test_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import threading
import time
from collections.abc import Generator
from pathlib import Path
from types import SimpleNamespace
from urllib.parse import quote
Expand All @@ -6,6 +9,7 @@
import pytest
from fastapi.testclient import TestClient

import swarm_gpt.api.server as server
from swarm_gpt.api.server import ApiConfig, _backend_from_config, create_app, normalize_playback
from swarm_gpt.utils.llm_providers import DEFAULT_OPENAI_MODEL_CHOICES

Expand Down Expand Up @@ -79,3 +83,73 @@ def test_library_returns_preset_display_metadata_and_delete(tmp_path: Path):
delete_response.raise_for_status()
assert delete_response.json() == {"deleted": preset_id}
assert not (preset_dir / preset_id).exists()


def test_emergency_stop_endpoint_runs_while_deploy_is_active(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
class DeployingBackend:
def __init__(self) -> None:
self.songs = ["Test Song"]
self.presets: list[str] = []
self.settings = {"axswarm": {"pos_min": [-1, -1, 0], "pos_max": [1, 1, 2]}}
self.music_manager = SimpleNamespace(song="Test Song")
self.splines: dict[int, object] = {}
self.deploy_entered = threading.Event()
self.stop_requested = threading.Event()
self.emergency_stop_calls = 0

def initial_prompt(self, selection: str) -> list[dict[str, str]]:
return []

def simulate(self) -> Generator[None, None, dict[str, object]]:
self.splines[0] = object()
states = np.zeros((1, 1, 13))
states[:, :, 3:7] = [0, 0, 0, 1]
if False:
yield None
return {"timestamps": np.array([0.0]), "states": states, "num_drones": 1}

def crop_window(self, song: str) -> tuple[float, float]:
return (0.0, 60.0)

def deploy(self) -> bool:
self.deploy_entered.set()
self.stop_requested.wait(timeout=2.0)
return True

def emergency_stop_active_swarm(self) -> None:
self.emergency_stop_calls += 1
self.stop_requested.set()

backends: list[DeployingBackend] = []

def backend_from_config(config: ApiConfig, provider: str, model_id: str) -> DeployingBackend:
backend = DeployingBackend()
backends.append(backend)
return backend

(tmp_path / "Test Song.mp3").write_bytes(b"")
monkeypatch.setattr(server, "_backend_from_config", backend_from_config)
client = TestClient(create_app(ApiConfig(music_dir=tmp_path)))

create_response = client.post(
"/api/jobs", json={"selection": "Test Song", "provider": "openai", "modelId": "gpt"}
)
create_response.raise_for_status()
job_id = create_response.json()["jobId"]
backend = backends[0]
for _ in range(50):
if client.get(f"/api/jobs/{job_id}").json()["status"] == "ready":
break
time.sleep(0.01)

deploy_response = client.post(f"/api/jobs/{job_id}/deploy")
deploy_response.raise_for_status()
assert backend.deploy_entered.wait(timeout=1.0)

stop_response = client.post(f"/api/jobs/{job_id}/emergency-stop")
stop_response.raise_for_status()

assert stop_response.json() == {"jobId": job_id, "emergencyStopped": True}
assert backend.emergency_stop_calls == 1
21 changes: 21 additions & 0 deletions tests/unit/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,24 @@ def test_preset_metadata_and_delete(tmp_path: Path):

app.delete_preset(preset_id)
assert not (preset_dir / preset_id).exists()


def test_emergency_stop_active_swarm_stops_live_swarm_and_music() -> None:
class ActiveSwarm:
def __init__(self) -> None:
self.calls: list[str] = []

def emergency_stop(self) -> None:
self.calls.append("emergency_stop")

config_path = virtual_crazyswarm_config(n_drones=1)
app = AppBackend(config_file=config_path)
swarm = ActiveSwarm()
music_calls: list[str] = []
app._active_swarm = swarm
app.music_manager.stop = lambda: music_calls.append("stop")

app.emergency_stop_active_swarm()

assert swarm.calls == ["emergency_stop"]
assert music_calls == ["stop"]
18 changes: 18 additions & 0 deletions tests/unit/test_drone_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,24 @@ def make_swarm(uris: list[str]) -> DroneSwarm:
return swarm


def test_run_schedules_threadsafe_when_loop_is_already_running():
swarm = make_swarm([])
swarm._loop_thread = None

async def command() -> str:
return "stopped"

thread = threading.Thread(target=swarm._loop.run_forever)
thread.start()

try:
assert swarm._run(command()) == "stopped"
finally:
swarm._loop.call_soon_threadsafe(swarm._loop.stop)
thread.join()
swarm._loop.close()


def test_setpoint_sends_once_and_returns():
uris = ["radio://0/80/2M/E7E7E7E701", "radio://0/80/2M/E7E7E7E702"]
swarm = make_swarm(uris)
Expand Down
39 changes: 39 additions & 0 deletions web/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
createJob,
deletePreset,
deployJob,
emergencyStopJob,
getLibrary,
getLlm,
getPlayback,
Expand Down Expand Up @@ -67,6 +68,7 @@ export function App() {
const [presetNotice, setPresetNotice] = useState<string | null>(null);
const [savingPreset, setSavingPreset] = useState(false);
const [deletingPreset, setDeletingPreset] = useState<string | null>(null);
const [emergencyStopping, setEmergencyStopping] = useState(false);
const socketRef = useRef<WebSocket | null>(null);

const providerInfo = llm?.providers.find((entry) => entry.id === provider);
Expand Down Expand Up @@ -141,8 +143,13 @@ export function App() {
if (event.type === "deploy_complete") {
setStage("ready");
}
if (event.type === "emergency_stop_sent") {
setPresetNotice("Emergency stop sent.");
setEmergencyStopping(false);
}
if (event.type === "failed") {
setStage((current) => (current === "deploying" ? "ready" : "failed"));
setEmergencyStopping(false);
setError(String(event.payload.message ?? "Job failed"));
}
};
Expand All @@ -158,6 +165,7 @@ export function App() {
setPlayback(null);
setError(null);
setPresetNotice(null);
setEmergencyStopping(false);
setProgress(0);
setStage("thinking");
setDetailsOpen(false);
Expand Down Expand Up @@ -195,6 +203,8 @@ export function App() {
return;
}
setError(null);
setPresetNotice(null);
setEmergencyStopping(false);
setStage("deploying");
try {
await deployJob(jobId);
Expand All @@ -204,6 +214,21 @@ export function App() {
}
};

const emergencyStop = async () => {
if (!jobId) {
return;
}
setError(null);
setPresetNotice(null);
setEmergencyStopping(true);
try {
await emergencyStopJob(jobId);
setPresetNotice("Emergency stop sent.");
} finally {
setEmergencyStopping(false);
}
};

const saveSafePreset = async () => {
if (!jobId) {
return;
Expand Down Expand Up @@ -255,6 +280,7 @@ export function App() {
setRefineText("");
setError(null);
setPresetNotice(null);
setEmergencyStopping(false);
};

if (stage === "playing" && playback) {
Expand Down Expand Up @@ -463,6 +489,19 @@ export function App() {
</div>
)}

{stage === "deploying" && (
<div className="ready-actions">
<button
className="danger-action"
disabled={emergencyStopping}
onClick={() => emergencyStop().catch((err: Error) => setError(err.message))}
>
<X size={18} />
{emergencyStopping ? "Stopping" : "E-stop"}
</button>
</div>
)}

{stage === "failed" && (
<div className="ready-actions">
<button className="secondary-action" onClick={reset}>
Expand Down
7 changes: 7 additions & 0 deletions web/src/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ export function deployJob(jobId: string) {
return request<{ jobId: string }>(`/api/jobs/${jobId}/deploy`, { method: "POST" });
}

export function emergencyStopJob(jobId: string) {
return request<{ jobId: string; emergencyStopped: boolean }>(
`/api/jobs/${jobId}/emergency-stop`,
{ method: "POST" }
);
}

export function savePreset(jobId: string) {
return request<{ preset: LibraryItem }>(`/api/jobs/${jobId}/preset`, { method: "POST" });
}
Expand Down
12 changes: 11 additions & 1 deletion web/src/styles.css
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ textarea {

.primary-action,
.secondary-action,
.danger-action,
.icon-button {
display: inline-flex;
align-items: center;
Expand Down Expand Up @@ -257,6 +258,14 @@ textarea {
padding: 0 14px;
}

.danger-action {
padding: 0 15px;
border: 1px solid rgba(255, 180, 168, 0.64);
color: #2b0806;
background: #ff8f7f;
font-weight: 900;
}

.compact {
min-height: 36px;
padding-inline: 11px;
Expand Down Expand Up @@ -490,7 +499,8 @@ button:disabled {
}

.primary-action,
.secondary-action {
.secondary-action,
.danger-action {
width: 100%;
}

Expand Down
Loading
Loading