Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 25 additions & 22 deletions scripts/heat_map_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ def get_ao_from_mission_file(file_path):
# Extract Mission Waypoint coordinates
if "mission" in plan and "items" in plan["mission"]:
for item in plan["mission"]["items"]:
# Only select waypoints within a mission (command 16)
if item["command"] != 16:
continue
# Waypoint params are usually [p1, p2, p3, p4, lat, lon, alt]
if "params" in item and len(item["params"]) >= 6:
lat, lon = item["params"][4], item["params"][5]
Expand All @@ -186,7 +189,7 @@ def get_ao_from_mission_file(file_path):
# Calculate Bounding Box
lats = [p[0] for p in points]
lons = [p[1] for p in points]

min_lat, max_lat = min(lats), max(lats)
min_lon, max_lon = min(lons), max(lons)

Expand All @@ -197,15 +200,15 @@ def get_ao_from_mission_file(file_path):
lat_dist_m = (max_lat - min_lat) * 111111
# Longitude distance depends on the latitude
lon_dist_m = (max_lon - min_lon) * 111111 * math.cos(math.radians(center_lat))

# Ensure a minimum size of 50 m if the mission is a single point
ao = {
"lat": min_lat,
"lon": min_lon,
"length_m": max(50, lat_dist_m),
"width_m": max(50, lon_dist_m)
}

print(f"[INFO] Mission AO Calculated: Center({ao['lat']}, {ao['lon']}), Size({ao['width_m']:.1f}m x {ao['length_m']:.1f}m)")
return ao

Expand All @@ -224,7 +227,7 @@ def get_ao_from_mission_file(file_path):
sys.path.insert(0, sam_path)

import mask_generator_SAM3

# Start up SAM 3
mask_gen = mask_generator_SAM3.SAM3_Model()
prompt_template = ""
Expand Down Expand Up @@ -266,9 +269,9 @@ def get_ao_from_mission_file(file_path):
# Run waypoint gen script
print("Generating waypoints...")
new_mission = waypoint_gen.generate_mission_waypoints(
input_file=input_path,
output_file=defines.GENERATED_WAYPOINT,
strategy="kmeans",
input_file=input_path,
output_file=defines.GENERATED_WAYPOINT,
strategy="kmeans",
cell_area=defines.WP_CALL_AREA,
altitude=defines.UAV_ALT
)
Expand Down Expand Up @@ -315,9 +318,9 @@ def get_ao_from_mission_file(file_path):
# Run waypoint gen script
print("Generating waypoints...")
new_mission = waypoint_gen.generate_mission_waypoints(
input_file=input_path,
output_file=defines.GENERATED_WAYPOINT,
strategy="kmeans",
input_file=input_path,
output_file=defines.GENERATED_WAYPOINT,
strategy="kmeans",
cell_area=defines.WP_CALL_AREA,
altitude=defines.UAV_ALT
)
Expand Down Expand Up @@ -394,10 +397,10 @@ def get_ao_from_mission_file(file_path):
window_coords.append((y, y + tile_h, x, x + tile_w))
x += stride_w
# Handle right edge if it doesn't fit perfectly
if x < width and (width - tile_w) > (x - stride_w):
if x < width and (width - tile_w) > (x - stride_w):
window_coords.append((y, y + tile_h, width - tile_w, width))
y += stride_h

# Handle bottom edge if it doesn't fit perfectly
if y < height and (height - tile_h) > (y - stride_h):
# Add a row for the bottom edge
Expand All @@ -423,7 +426,7 @@ def get_ao_from_mission_file(file_path):

for i, (y1, y2, x1, x2) in enumerate(window_coords):
print(f"\n--- Processing Window {i+1}/{total_windows} [y:{y1}-{y2}, x:{x1}-{x2}] ---")

chunk_img = original_img[y1:y2, x1:x2]
chunk_path = f"temp_chunk_{i}.jpg"
cv2.imwrite(chunk_path, chunk_img)
Expand All @@ -432,27 +435,27 @@ def get_ao_from_mission_file(file_path):
for label in positive_labels:
mode = "add" if label in positive_labels else "subtract"
full_prompt = prompt_template + label

result = mask_gen.generate_mask(full_prompt, chunk_path)

seg_success, mask_path, _ = result

if seg_success:
current_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

if current_mask is not None:
# Resize mask to fit the specific chunk dimensions
current_mask = cv2.resize(current_mask, (x2 - x1, y2 - y1), interpolation=cv2.INTER_NEAREST)
binary_mask = (current_mask > 0).astype(np.float32)

# --- ACCUMULATION & AVERAGING LOGIC ---
if mode == "add":
full_heatmap_accumulator[y1:y2, x1:x2] += binary_mask
else:
full_heatmap_accumulator[y1:y2, x1:x2] -= binary_mask

successful_generations += 1

# --- MEMORY CLEANUP ---
del result, current_mask, binary_mask
else:
Expand All @@ -461,16 +464,16 @@ def get_ao_from_mission_file(file_path):
else:
# --- MEMORY CLEANUP ---
del result

# Increment the count for every pixel in this window
count_accumulator[y1:y2, x1:x2] += 1.0

# Force Python to collect garbage, then empty the CUDA cache
gc.collect()
torch.cuda.empty_cache()
# New line to keep things running nicely
print("")

# Clean up tile chunk
if os.path.exists(chunk_path):
os.remove(chunk_path)
Expand Down Expand Up @@ -539,4 +542,4 @@ def get_ao_from_mission_file(file_path):
with open(metadata_path, "w") as f:
json.dump(metadata, f, indent=4)

print(f"[SUCCESS] Metadata exported to: {metadata_path}")
print(f"[SUCCESS] Metadata exported to: {metadata_path}")