-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvisualization.py
More file actions
107 lines (88 loc) · 3.01 KB
/
visualization.py
File metadata and controls
107 lines (88 loc) · 3.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Visualization helpers for quadrotor simulations."""
from __future__ import annotations
import numpy as np
try: # pragma: no cover
from .simulator import SimulationLog
except ImportError: # pragma: no cover
from simulator import SimulationLog
def plot_trajectory(log: SimulationLog, axes=None):
    """Draw the x, y and z position histories against time.

    Args:
        log: Simulation record. ``log.time`` supplies the horizontal axis and
            columns 0, 1 and 2 of ``log.position`` are plotted as x, y and z.
        axes: Optional indexable collection of Axes to draw into; entries 0-2
            receive the x/y/z traces and the last entry gets the time label.
            When omitted, a new 3x1 figure with a shared x-axis is created.

    Returns:
        The axes that were drawn on (the supplied ones, or the new array).
    """
    import matplotlib.pyplot as plt

    if axes is None:
        figure, axes = plt.subplots(3, 1, figsize=(6, 6), sharex=True)
        figure.suptitle("Quadrotor position")

    for column, name in enumerate(("x", "y", "z")):
        target = axes[column]
        target.plot(log.time, log.position[:, column])
        target.set_ylabel(f"{name} [m]")
        # Light grid so values can be read off without cluttering the trace.
        target.grid(True, alpha=0.3)

    axes[-1].set_xlabel("time [s]")
    return axes
def _select_frames(n_samples: int, skip: int) -> list[int]:
    """Return every ``skip``-th sample index, always ending on the final sample.

    ``skip`` values below 1 are treated as 1. ``n_samples`` must be positive.
    """
    step = skip if skip >= 1 else 1
    frames = list(range(0, n_samples, step))
    # Finish the animation on the last recorded state even when the stride
    # does not land on it exactly.
    if frames[-1] != n_samples - 1:
        frames.append(n_samples - 1)
    return frames


def _axis_limits(position, arm_length: float):
    """Return per-axis ``(lower, upper)`` view limits enclosing path and airframe.

    ``position`` is indexed as ``(sample, axis)``; the result arrays hold one
    limit per column.
    """
    lower = position.min(axis=0)
    upper = position.max(axis=0)
    span = upper - lower
    center = 0.5 * (lower + upper)
    # Original heuristic: roomy fixed-size box for near-hover logs. On its
    # own it is smaller than span / 2 once a span exceeds ~2 m and clips the
    # flight, so never go below half the span plus one arm length.
    legacy_half_width = 2.0 * np.maximum(0.5, 0.2 * span)
    required_half_width = 0.5 * span + abs(arm_length)
    half_width = np.maximum(legacy_half_width, required_half_width)
    return center - half_width, center + half_width


def animate_log(
    log: SimulationLog,
    skip: int = 2,
    arm_length: float = 0.3,
    interval_ms: int = 5,
    repeat: bool = True,
    save_path: str | None = None,
):
    """Animate the quadrotor states as a 3D matplotlib animation.

    Each frame draws the trail of positions visited so far (gray) and the
    airframe as four arms of length ``arm_length`` from the hub: two along
    body x (red) and two along body y (blue). Each arm is mapped to the world
    frame as ``rotation[idx] @ point + position[idx]``.

    Args:
        log: Simulation record providing ``time`` (one entry per sample),
            ``position`` (one xyz row per sample, in metres) and ``rotation``
            (one 3x3 matrix per sample).
        skip: Stride between animated samples; values below 1 mean every
            sample. The final sample is always included.
        arm_length: Length of each drawn arm, in the same units as position.
        interval_ms: Delay between frames passed to ``FuncAnimation``.
        repeat: Whether the animation loops.
        save_path: If given, the animation is written there via
            ``Animation.save``; otherwise ``plt.show()`` is called.

    Returns:
        The ``matplotlib.animation.FuncAnimation`` instance. Keep a reference
        to it, or it may be garbage-collected before it runs.

    Raises:
        ValueError: If ``log.time`` is empty.
    """
    n_samples = len(log.time)
    if n_samples == 0:
        # Fail fast (before importing matplotlib) with a clear message rather
        # than an IndexError from the frame selection.
        raise ValueError("cannot animate an empty simulation log")

    import matplotlib.pyplot as plt
    from matplotlib import animation

    frames = _select_frames(n_samples, skip)
    pos = log.position
    rot = log.rotation

    fig = plt.figure(figsize=(7, 6))
    ax = fig.add_subplot(111, projection="3d")
    (path,) = ax.plot([], [], [], color="gray", lw=1.0, alpha=0.6)

    colors = ["tab:red", "tab:red", "tab:blue", "tab:blue"]
    # Each arm is a segment from the body origin to a tip on a body axis.
    segments_body = np.array(
        [
            [[0.0, 0.0, 0.0], [arm_length, 0.0, 0.0]],
            [[0.0, 0.0, 0.0], [-arm_length, 0.0, 0.0]],
            [[0.0, 0.0, 0.0], [0.0, arm_length, 0.0]],
            [[0.0, 0.0, 0.0], [0.0, -arm_length, 0.0]],
        ]
    )
    body_lines = [
        ax.plot([], [], [], color=c, lw=2.5, solid_capstyle="round")[0]
        for c in colors
    ]

    lower, upper = _axis_limits(pos, arm_length)
    ax.set_xlim(lower[0], upper[0])
    ax.set_ylim(lower[1], upper[1])
    ax.set_zlim(lower[2], upper[2])
    ax.set_xlabel("x [m]")
    ax.set_ylabel("y [m]")
    ax.set_zlabel("z [m]")
    ax.set_title("Quadrotor flight")

    def update(frame_idx: int):
        """Redraw trail and arms for the ``frame_idx``-th selected sample."""
        idx = frames[frame_idx]
        R = rot[idx]
        p = pos[idx]
        history = pos[: idx + 1]
        # 3D lines take x/y via set_data and z via set_3d_properties.
        path.set_data(history[:, 0], history[:, 1])
        path.set_3d_properties(history[:, 2])
        for seg_body, line in zip(segments_body, body_lines):
            seg_world = (R @ seg_body.T).T + p
            line.set_data(seg_world[:, 0], seg_world[:, 1])
            line.set_3d_properties(seg_world[:, 2])
        return body_lines + [path]

    ani = animation.FuncAnimation(
        fig,
        update,
        frames=len(frames),
        interval=interval_ms,
        blit=False,
        repeat=repeat,
    )
    if save_path:
        ani.save(save_path)
    else:
        plt.show()
    return ani