diff --git a/orbit/export/georef_export.py b/orbit/export/georef_export.py index edd237e..df38508 100644 --- a/orbit/export/georef_export.py +++ b/orbit/export/georef_export.py @@ -11,10 +11,13 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple +import numpy as np + from orbit.models.project import Project from orbit.utils.coordinate_transform import ( AffineTransformer, CoordinateTransformer, + DroneAssistedTransformer, HomographyTransformer, HybridTransformer, ) @@ -78,9 +81,7 @@ def build_georef_data( Dictionary with georeferencing data """ # Determine transform method string - if isinstance(transformer, HybridTransformer): - method = "homography" - elif isinstance(transformer, HomographyTransformer): + if isinstance(transformer, (HybridTransformer, HomographyTransformer, DroneAssistedTransformer)): method = "homography" elif isinstance(transformer, AffineTransformer): method = "affine" @@ -112,6 +113,20 @@ def build_georef_data( except Exception: orbit_version = "unknown" + # If the transformer has an active adjustment, bake it into the exported + # matrices so downstream consumers get the fully-adjusted transform. + # For pixel→ENU (transform_matrix): effective = transform_matrix @ A_inv + # For ENU→pixel (inverse_matrix): effective = A @ inverse_matrix + transform_matrix = transformer.transform_matrix + inverse_matrix = transformer.inverse_matrix + if transformer.adjustment is not None and not transformer.adjustment.is_identity(): + A = transformer.adjustment.get_adjustment_matrix() + A_inv = np.linalg.inv(A) + if transform_matrix is not None: + transform_matrix = transform_matrix @ A_inv + if inverse_matrix is not None: + inverse_matrix = A @ inverse_matrix + # Build output structure data = { "format": "ORBIT Georeferencing Data", @@ -131,8 +146,8 @@ def build_georef_data( "longitude": transformer.reference_lon, "latitude": transformer.reference_lat, }, - "transformation_matrix": _matrix_to_list(transformer.transform_matrix), - "inverse_matrix": _matrix_to_list(transformer.inverse_matrix), + "transformation_matrix": _matrix_to_list(transform_matrix), + "inverse_matrix": _matrix_to_list(inverse_matrix), "scale_factors": { "x_meters_per_pixel": scale_x, "y_meters_per_pixel": scale_y, diff --git a/orbit/export/object_builder.py b/orbit/export/object_builder.py index 822a76d..ec806ef 100644 --- a/orbit/export/object_builder.py +++ b/orbit/export/object_builder.py @@ -293,10 +293,10 @@ def _project_onto_meter_centerline( min_dist = dist best_s = cumul_s + t_param * seg_len seg_hdg = np.arctan2(dy, dx) - # t-offset: positive = left of road direction (OpenDRIVE convention) - # In meter space (y-up), cross = direction × offset gives correct sign - cross = dx * (py - proj_y) - dy * (px - proj_x) - best_t = dist if cross >= 0 else -dist + # Signed perpendicular distance from the segment direction (positive = left). + # Using the cross product formula directly (not dist to clamped endpoint) + # ensures correct t even when the anchor is outside the road's s-range. 
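+                        # cross((dx, dy), (px - x1, py - y1)) equals seg_len times the
+                        # perpendicular offset, so dividing by seg_len gives signed meters.
+                        # E.g. segment (0,0)->(2,0), anchor (1,1): cross = 2*1 - 0*1 = 2,
+                        # t = 2/2 = +1 (one meter left of the segment in the y-up frame).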
+ best_t = (dx * (py - y1) - dy * (px - x1)) / seg_len best_hdg = seg_hdg cumul_s += seg_len @@ -345,6 +345,8 @@ def _set_type_attributes(self, object_elem: etree.Element, obj: RoadObject) -> N elif obj.type == ObjectType.GUARDRAIL: object_elem.set('type', 'barrier') object_elem.set('subtype', 'guardrail') + object_elem.set('orientation', obj.odr_orientation if obj.odr_orientation else 'none') + object_elem.set('hdg', f'{np.radians(obj.orientation):.6f}') object_elem.set('height', f"{obj.dimensions.get('height', 0.81):.2f}") object_elem.set('width', f"{obj.dimensions.get('width', 0.3):.2f}") if obj.validity_length: @@ -432,7 +434,7 @@ def _create_object_outline( self._create_circular_outline(outline, obj, 12) elif obj.type == ObjectType.GUARDRAIL: - self._create_polyline_outline(outline, obj) + self._create_polyline_outline(outline, obj, road_hdg) elif obj.type in (ObjectType.TREE_BROADLEAF, ObjectType.BUSH): self._create_circular_outline(outline, obj, 8) @@ -474,19 +476,28 @@ def _create_circular_outline( corner.set('z', '0.0') corner.set('height', f'{height:.2f}') - def _create_polyline_outline(self, outline: etree.Element, obj: RoadObject) -> None: - """Create polyline outline for guardrails.""" + def _create_polyline_outline(self, outline: etree.Element, obj: RoadObject, + road_hdg: float = 0.0) -> None: + """Create polyline outline for guardrails. + + Rotates metric offsets into road-local (u=s-direction, v=t-direction) + coordinates by applying -road_hdg rotation. + """ if not obj.points or len(obj.points) < 2: return height = obj.dimensions.get('height', 0.81) + cos_h = np.cos(-road_hdg) + sin_h = np.sin(-road_hdg) if self.transformer: pts_m = [self.transformer.pixel_to_meters(px, py) for px, py in obj.points] ref_mx, ref_my = pts_m[0] for pt_m in pts_m: - u = pt_m[0] - ref_mx - v = pt_m[1] - ref_my + dx = pt_m[0] - ref_mx + dy = pt_m[1] - ref_my + u = cos_h * dx - sin_h * dy + v = sin_h * dx + cos_h * dy corner = etree.SubElement(outline, 'cornerLocal') corner.set('u', f'{u:.4f}') corner.set('v', f'{v:.4f}') @@ -495,8 +506,10 @@ def _create_polyline_outline(self, outline: etree.Element, obj: RoadObject) -> N else: ref_x, ref_y = obj.points[0] for px, py in obj.points: - u = (px - ref_x) * self.scale_x - v = (py - ref_y) * self.scale_x + dx = (px - ref_x) * self.scale_x + dy = (py - ref_y) * self.scale_x + u = cos_h * dx - sin_h * dy + v = sin_h * dx + cos_h * dy corner = etree.SubElement(outline, 'cornerLocal') corner.set('u', f'{u:.4f}') corner.set('v', f'{v:.4f}') diff --git a/orbit/export/opendrive_writer.py b/orbit/export/opendrive_writer.py index 8f41e58..7bcd54a 100644 --- a/orbit/export/opendrive_writer.py +++ b/orbit/export/opendrive_writer.py @@ -979,14 +979,16 @@ def _compute_parampoly3_geometry(self, connecting_road: Road, connecting_road, path_meters, is_start=False ) - # Override with connected road headings for C1 continuity + # Override with connected road headings for C1 continuity. + # Departure from a road's start (or arrival at a road's end) means + # traveling opposite to the road's forward direction → flip by π. 
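+        # The flip itself is applied inside _override_with_road_heading via the new
+        # flip_at_start argument (see its docstring below).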
start_heading = self._override_with_road_heading( connecting_road.predecessor_id, connecting_road.predecessor_contact, - start_heading + start_heading, flip_at_start=True ) end_heading = self._override_with_road_heading( connecting_road.successor_id, connecting_road.successor_contact, - end_heading + end_heading, flip_at_start=False ) # Transform to local u/v frame (origin at start, u along heading) @@ -1055,14 +1057,26 @@ def _resolve_cr_heading(self, connecting_road: Road, path_meters: list, return 0.0 def _override_with_road_heading(self, road_id, contact_point, - current_heading: float) -> float: - """Override heading with connected road's actual heading if available.""" + current_heading: float, + flip_at_start: bool = True) -> float: + """Override heading with connected road's actual heading if available. + + flip_at_start controls directional semantics: + - True (predecessor/departure): flip when contact_point=="start" because + a CR departing from a road's start travels opposite the road's +s direction. + - False (successor/arrival): flip when contact_point=="end" because a CR + arriving at a road's end also travels opposite the road's +s direction. + """ road = self.road_map.get(road_id) if road and road.centerline_id: hdg = self._get_road_heading_at_contact_meters( road.centerline_id, contact_point ) if hdg is not None: + needs_flip = (flip_at_start and contact_point == "start") or \ + (not flip_at_start and contact_point == "end") + if needs_flip: + hdg += math.pi return hdg return current_heading @@ -1140,7 +1154,17 @@ def _build_cr_road_xml(self, connecting_road: Road, elif self.carla_compat: road_elem.append(etree.Element('signals')) - if self.carla_compat: + # Objects on connecting road (e.g. lampposts placed on connecting roads) + export_objects = getattr(self, '_export_objects', None) + if export_objects is None: + export_objects = self._get_export_objects() + cr_objects = self.object_builder.create_objects( + connecting_road, export_objects, connecting_road.inline_path, + geometry_elements=geometry_elements, road_length=road_length, + ) + if cr_objects is not None: + road_elem.append(cr_objects) + elif self.carla_compat: road_elem.append(etree.Element('objects')) return road_elem diff --git a/orbit/export/signal_builder.py b/orbit/export/signal_builder.py index 9328dd3..8a93f3b 100644 --- a/orbit/export/signal_builder.py +++ b/orbit/export/signal_builder.py @@ -44,7 +44,7 @@ def _project_point_onto_polyline(px: float, py: float, pts: List[tuple]): if dist < min_dist: min_dist = dist best_s = cumulative_s + t * seg_len - cross = (px - x1) * dy - (py - y1) * dx + cross = dx * (py - y1) - dy * (px - x1) # standard signed cross product best_t = (1.0 if cross >= 0 else -1.0) * dist cumulative_s += seg_len diff --git a/orbit/gui/dialogs/connecting_road_dialog.py b/orbit/gui/dialogs/connecting_road_dialog.py index ddf5567..23dc078 100644 --- a/orbit/gui/dialogs/connecting_road_dialog.py +++ b/orbit/gui/dialogs/connecting_road_dialog.py @@ -203,6 +203,28 @@ def setup_ui(self): convert_layout.addWidget(self.convert_to_polyline_btn) curve_layout.addRow("Convert:", convert_layout) + # Smooth Curve section — available for any geometry type + smooth_layout = self.add_form_group_with_info( + "Smooth Curve", + "Redistribute intermediate points along a smooth Bezier curve while " + "preserving the start/end positions and tangent angles." 
+ ) + self.smooth_curve_btn = QPushButton("Smooth Curve") + self.smooth_curve_btn.setToolTip( + "Redistribute the curve's control points so the path is smooth and " + "G1-continuous with adjacent roads." + ) + self.smooth_curve_btn.clicked.connect(self.on_smooth_curve) + self.smooth_points_spin = QSpinBox() + self.smooth_points_spin.setRange(10, 500) + self.smooth_points_spin.setValue(50) + self.smooth_points_spin.setSuffix(" pts") + self.smooth_points_spin.setToolTip("Number of output points for the smoothed curve") + smooth_row = QHBoxLayout() + smooth_row.addWidget(self.smooth_points_spin) + smooth_row.addWidget(self.smooth_curve_btn) + smooth_layout.addRow("", smooth_row) + # Create standard OK/Cancel buttons self.create_button_box() @@ -459,6 +481,51 @@ def on_regenerate_curve(self): if self.connecting_road.id in image_view.connecting_road_lanes_items: image_view.connecting_road_lanes_items[self.connecting_road.id].update_graphics() + def on_smooth_curve(self): + """Redistribute inline_path points along a smooth Bezier curve.""" + from orbit.gui.undo_commands import SmoothCRCommand + from orbit.utils.geometry import fit_smooth_curve_to_polyline, get_smooth_cr_tangents + + path = self.connecting_road.inline_path + if not path or len(path) < 2: + QMessageBox.warning(self, "Smooth Curve", "No curve points to smooth.") + return + + cr = self.connecting_road + # Prefer stored headings (authoritative, set when CR was generated) + if cr.stored_start_heading is not None and cr.stored_end_heading is not None: + start_hdg, end_hdg = cr.stored_start_heading, cr.stored_end_heading + else: + tangents = get_smooth_cr_tangents(cr, self.project) + if tangents is None: + QMessageBox.warning( + self, "Smooth Curve", + "Could not determine tangent directions from adjacent roads." 
+ ) + return + start_hdg, end_hdg = tangents + + old_path = list(path) + n_out = max(self.smooth_points_spin.value(), 2) + new_path = fit_smooth_curve_to_polyline(path, start_hdg, end_hdg, num_output_points=n_out) + cr.inline_path = new_path + + # Push undo command if main window has a stack + main_win = self.parent() + if hasattr(main_win, 'undo_stack') and hasattr(main_win, 'image_view'): + cmd = SmoothCRCommand( + main_win.image_view, + cr, + old_path, + new_path, + ) + main_win.undo_stack.push(cmd) + + # Update graphics + if hasattr(main_win, 'image_view'): + image_view = main_win.image_view + image_view.update_connecting_road_graphics(cr.id) + def accept(self): """Save changes and accept dialog.""" # Save predecessor/successor roads diff --git a/orbit/gui/dialogs/export_dialog.py b/orbit/gui/dialogs/export_dialog.py index 2aca250..c0d7df1 100644 --- a/orbit/gui/dialogs/export_dialog.py +++ b/orbit/gui/dialogs/export_dialog.py @@ -40,14 +40,16 @@ class ExportDialog(BaseDialog): """Dialog for OpenDrive export with preview.""" def __init__(self, project: Project, parent=None, xodr_schema_path: Optional[str] = None, - adjustment=None): + transformer_factory=None, adjustment=None): super().__init__("Export to OpenDrive", parent, min_width=700, min_height=600) self.project = project self.transformer: Optional[CoordinateTransformer] = None self.output_path: Optional[Path] = None - self.xodr_schema_path = xodr_schema_path # Path to XSD schema for validation (optional) - self._adjustment = adjustment # TransformAdjustment from interactive fine-tuning + self.xodr_schema_path = xodr_schema_path + self._transformer_factory = transformer_factory + # Legacy: adjustment kept for callers that don't provide a factory + self._adjustment = adjustment self.setup_ui() self.load_properties() @@ -252,14 +254,16 @@ def analyze_project(self): # Check georeferencing if self.project.has_georeferencing(): # Create transformer using project's method - self.transformer = create_transformer( - self.project.control_points, - self.project.transform_method, - use_validation=True, - ) - - if self.transformer and self._adjustment and not self._adjustment.is_identity(): - self.transformer.set_adjustment(self._adjustment) + if self._transformer_factory: + self.transformer = self._transformer_factory(use_validation=True) + else: + self.transformer = create_transformer( + self.project.control_points, + self.project.transform_method, + use_validation=True, + ) + if self.transformer and self._adjustment and not self._adjustment.is_identity(): + self.transformer.set_adjustment(self._adjustment) if self.transformer: method = self.project.transform_method.upper() @@ -453,22 +457,27 @@ def do_export(self): # Create a fresh transformer that uses pyproj for the export # projection. This ensures the homography/affine matrix is # computed in the same coordinate system written to the file. 
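+            # Prefer the caller-supplied factory so the export transformer is built
+            # the same way as the one used in analyze_project() above.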
-            export_transformer = create_transformer(
-                self.project.control_points,
-                self.project.transform_method,
-                use_validation=True,
-                export_proj_string=proj_string
-            )
+            if self._transformer_factory:
+                export_transformer = self._transformer_factory(
+                    use_validation=True,
+                    export_proj_string=proj_string,
+                )
+            else:
+                export_transformer = create_transformer(
+                    self.project.control_points,
+                    self.project.transform_method,
+                    use_validation=True,
+                    export_proj_string=proj_string
+                )
+            # Apply manual alignment adjustment if one is active. Guard on
+            # export_transformer: both the factory and create_transformer may
+            # return None, and the error path below must still run in that case.
+            if export_transformer and self._adjustment and not self._adjustment.is_identity():
+                export_transformer.set_adjustment(self._adjustment)

             if not export_transformer:
                 show_error(self, "Failed to create export transformer.", "Export Error")
                 self.export_btn.setEnabled(True)
                 self.status_label.setText("Export failed.")
                 return

-            # Apply manual alignment adjustment if one is active
-            if self._adjustment and not self._adjustment.is_identity():
-                export_transformer.set_adjustment(self._adjustment)
-
             # Compute origin offset: project the selected origin lat/lon
             # to get the offset that will be subtracted from all coordinates.
             origin_selection = self.origin_combo.currentData()
diff --git a/orbit/gui/dialogs/georeference_dialog.py b/orbit/gui/dialogs/georeference_dialog.py
index 69574a0..62cbd7c 100644
--- a/orbit/gui/dialogs/georeference_dialog.py
+++ b/orbit/gui/dialogs/georeference_dialog.py
@@ -16,6 +16,7 @@
     QComboBox,
     QDialogButtonBox,
     QDoubleSpinBox,
+    QFileDialog,
     QFormLayout,
     QGroupBox,
     QHBoxLayout,
@@ -48,6 +49,8 @@ class GeoreferenceDialog(BaseDialog):
     pick_point_requested = pyqtSignal()
     # Signal emitted when control points are modified (added/removed)
     control_points_changed = pyqtSignal()
+    # Signal emitted when drone metadata changes (load/clear)
+    drone_metadata_changed = pyqtSignal()

     def __init__(self, project: Project, parent=None, verbose: bool = False):
         super().__init__("Georeferencing", parent, min_width=900, min_height=700)
@@ -90,6 +93,7 @@ def setup_ui(self):
         self._create_control_points_section()
         self._create_add_point_section()
+        self._create_drone_log_section()
         self._create_status_section()
         self._create_uncertainty_section()
@@ -126,6 +130,9 @@ def _create_control_points_section(self):
         self.import_csv_btn = QPushButton("Import from CSV...")
         self.import_csv_btn.clicked.connect(self.import_from_csv)
         table_button_layout.addWidget(self.import_csv_btn)
+        self.import_georef_btn = QPushButton("Import from Georef...")
+        self.import_georef_btn.clicked.connect(self.import_from_georef)
+        table_button_layout.addWidget(self.import_georef_btn)
         table_button_layout.addStretch()
         points_layout.addLayout(table_button_layout)
@@ -179,6 +186,123 @@ def _create_add_point_section(self):
         add_group.setLayout(add_layout)
         self.get_main_layout().addWidget(add_group)

+    def _create_drone_log_section(self):
+        """Create the drone-assisted georeferencing section."""
+        drone_group = QGroupBox("Drone-Assisted Georeferencing (optional)")
+        drone_layout = QVBoxLayout()
+
+        self.drone_status_label = QLabel("Drone log: Not loaded")
+        drone_layout.addWidget(self.drone_status_label)
+
+        btn_row = QHBoxLayout()
+        self.load_drone_log_btn = QPushButton("Load Drone Log (video_stats.json)…")
+        self.load_drone_log_btn.clicked.connect(self._load_drone_log)
+        btn_row.addWidget(self.load_drone_log_btn)
+
+        self.clear_drone_log_btn = QPushButton("Clear")
+        self.clear_drone_log_btn.clicked.connect(self._clear_drone_log)
+        self.clear_drone_log_btn.setEnabled(self.project.drone_metadata is not None)
+
btn_row.addWidget(self.clear_drone_log_btn) + btn_row.addStretch() + drone_layout.addLayout(btn_row) + + self.heading_validation_label = QLabel() + self.heading_validation_label.setVisible(False) + drone_layout.addWidget(self.heading_validation_label) + + drone_group.setLayout(drone_layout) + self.get_main_layout().addWidget(drone_group) + + self._refresh_drone_status() + + def _refresh_drone_status(self): + """Update the drone log status label and heading validation line.""" + md = self.project.drone_metadata + if md is None: + self.drone_status_label.setText("Drone log: Not loaded") + self.clear_drone_log_btn.setEnabled(False) + self.heading_validation_label.setVisible(False) + return + + drone_str = md.drone_type or "Unknown drone" + hfov_str = f" | HFOV: {md.hfov_deg:.1f}°" if md.hfov_deg is not None else " | HFOV: from GCPs" + self.drone_status_label.setText( + f"Drone log: {drone_str} ({md.lens_type}) | " + f"Alt: {md.alt_agl:.1f} m | " + f"Heading: {md.gimbal_yaw:.1f}° | " + f"Pitch: {md.gimbal_pitch:.1f}°" + f"{hfov_str}" + ) + self.clear_drone_log_btn.setEnabled(True) + + # Always show heading info (declination always computable; GCP refinement shown when available) + self._update_heading_validation() + + def _update_heading_validation(self): + """Run heading estimation from GCPs and display the result.""" + try: + from orbit.utils.camera_model import DroneCameraModel + + md = self.project.drone_metadata + if md is None: + return + + if self.parent() and self.parent().image_view and self.parent().image_view.image_item: + px = self.parent().image_view.image_item.pixmap() + image_width, image_height = px.width(), px.height() + else: + self.heading_validation_label.setVisible(False) + return + + training_points = [cp for cp in self.project.control_points if not cp.is_validation] + model = DroneCameraModel( + metadata=md, + image_width=image_width, + image_height=image_height, + control_points=training_points if len(training_points) >= 2 else None, + ) + decl = model.declination_deg + refine = model.yaw_refinement_deg + effective = model.effective_yaw + ok = "✓" if abs(refine) < 5 else "⚠" + parts = [ + f"Log heading: {md.gimbal_yaw:.1f}°", + f"Declination: {decl:+.1f}°", + ] + if len(training_points) >= 2: + parts.append(f"GCP refinement: {refine:+.1f}°") + parts.append(f"Effective: {effective:.1f}°") + self.heading_validation_label.setText(f"{ok} " + " | ".join(parts)) + self.heading_validation_label.setVisible(True) + except Exception as e: + logger.debug(f"Heading validation failed: {e}") + self.heading_validation_label.setVisible(False) + + def _load_drone_log(self): + """Load a video_stats.json file and store drone metadata in project.""" + import json + + path, _ = QFileDialog.getOpenFileName( + self, "Load Drone Log", "", "Video Stats JSON (*.json);;All Files (*)" + ) + if not path: + return + try: + with open(path, encoding='utf-8') as f: + stats = json.load(f) + from orbit.models.project import DroneMetadata + self.project.drone_metadata = DroneMetadata.from_video_stats(stats) + self._refresh_drone_status() + self.drone_metadata_changed.emit() + except Exception as e: + show_error(self, f"Failed to load drone log:\n{e}", "Load Error") + + def _clear_drone_log(self): + """Remove drone metadata from project.""" + self.project.drone_metadata = None + self._refresh_drone_status() + self.drone_metadata_changed.emit() + def _create_status_section(self): """Create the status and validation section.""" status_group = QGroupBox("Status and Validation") @@ -525,8 +649,17 @@ def 
update_validation(self): # Determine minimum required min_required = 4 if self.project.transform_method == 'homography' else 3 + if self.project.transform_method == 'drone_assisted': + min_required = 0 # drone-assisted works with 0 GCPs (but GCPs improve focal length) - if len(training_points) < min_required: + if self.project.transform_method == 'drone_assisted' and self.project.drone_metadata is None: + self.validation_text.setText( + "Drone-assisted mode requires a drone log. Load one in the Drone Log section above." + ) + self.project.georef_validation = {} + return + + if self.project.transform_method != 'drone_assisted' and len(training_points) < min_required: self.validation_text.setText("Insufficient training points for validation.") self.project.georef_validation = {} self.analyze_gcp_btn.setEnabled(False) @@ -540,6 +673,7 @@ def update_validation(self): self.project.control_points, self.project.transform_method, use_validation=True, + drone_metadata=self.project.drone_metadata, ) if not transformer: @@ -547,6 +681,10 @@ def update_validation(self): self.project.georef_validation = {} return + # Update heading cross-check if drone mode + if self.project.drone_metadata is not None: + self._update_heading_validation() + # Build validation report report = [] report.append("=" * 60) @@ -671,6 +809,57 @@ def on_csv_import_cancelled(self): # Clear placer dialog reference when cancelled self.csv_placer_dialog = None + def import_from_georef(self): + """Import control points from a .georef file, replacing existing points.""" + path, _ = QFileDialog.getOpenFileName( + self, + "Import Georef File", + "", + "Georef Files (*_georef.json *.georef);;JSON Files (*.json);;All Files (*)", + ) + if not path: + return + + try: + import json + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + raw_points = data.get("control_points", []) + except Exception as e: + show_error(self, f"Failed to load georef file:\n{e}", "Import Error") + return + + if not raw_points: + show_warning(self, "The georef file contains no control points.", "Import Warning") + return + + new_points = [ + ControlPoint( + pixel_x=cp["pixel_x"], + pixel_y=cp["pixel_y"], + longitude=cp["longitude"], + latitude=cp["latitude"], + name=cp.get("name", ""), + is_validation=cp.get("is_validation", False), + ) + for cp in raw_points + ] + + self.project.control_points.clear() + self.project.control_points.extend(new_points) + + self.load_control_points() + self.update_status() + self.update_validation() + self.control_points_changed.emit() + + show_info( + self, + f"Imported {len(new_points)} control point(s) from georef file.\n" + "Drag the markers on the image to adjust their pixel positions.", + "Import Complete", + ) + def update_uncertainty_statistics(self): """Display uncertainty statistics based on current control points.""" training_points = [cp for cp in self.project.control_points if not cp.is_validation] @@ -1082,11 +1271,21 @@ def show_gcp_analysis(self): show_warning(self, "Need at least 4 training points for GCP analysis.", "Insufficient Points") return + # Get image dimensions (needed for drone-assisted mode) + image_width = image_height = 0 + parent_window = self.parent() + if hasattr(parent_window, 'image_view') and parent_window.image_view.image_item: + pixmap = parent_window.image_view.image_item.pixmap() + image_width, image_height = pixmap.width(), pixmap.height() + # Create transformer transformer = create_transformer( self.project.control_points, self.project.transform_method, use_validation=True, 
+ drone_metadata=self.project.drone_metadata, + image_width=image_width, + image_height=image_height, ) if not transformer: diff --git a/orbit/gui/dialogs/object_properties_dialog.py b/orbit/gui/dialogs/object_properties_dialog.py index 207bda6..c4f9226 100644 --- a/orbit/gui/dialogs/object_properties_dialog.py +++ b/orbit/gui/dialogs/object_properties_dialog.py @@ -104,7 +104,8 @@ def setup_ui(self): self.road_combo = QComboBox() self.road_combo.addItem("(None)", None) for road in self.project.roads: - self.road_combo.addItem(road.name or f"Road {road.id[:8]}", road.id) + label = f"Road {road.id}" + (f" – {road.name}" if road.name else "") + self.road_combo.addItem(label, road.id) self.road_combo.currentIndexChanged.connect(self.on_road_changed) road_layout.addRow("Assigned Road:", self.road_combo) diff --git a/orbit/gui/dialogs/opendrive_import_dialog.py b/orbit/gui/dialogs/opendrive_import_dialog.py index 5d87de5..4189117 100644 --- a/orbit/gui/dialogs/opendrive_import_dialog.py +++ b/orbit/gui/dialogs/opendrive_import_dialog.py @@ -6,6 +6,7 @@ QCheckBox, QDoubleSpinBox, QFileDialog, + QGridLayout, QGroupBox, QHBoxLayout, QLabel, @@ -16,6 +17,9 @@ QVBoxLayout, ) +from orbit.models.object import ObjectType +from orbit.utils.enum_formatting import format_enum_name + from ..utils.message_helpers import show_warning from .base_dialog import BaseDialog @@ -37,6 +41,8 @@ def __init__(self, has_georeferencing: bool = False, verbose: bool = False, pare self.has_georeferencing = has_georeferencing self.verbose = verbose + self._feature_checkboxes: dict = {} # key → QCheckBox + self._feature_group: QGroupBox | None = None self.setup_ui() self.load_properties() @@ -143,8 +149,9 @@ def setup_ui(self): self.get_main_layout().addLayout(button_layout) - # Update button state + # Update button state and scan file on path change self.file_path_edit.textChanged.connect(self._update_button_state) + self.file_path_edit.textChanged.connect(self._on_file_path_changed) def load_properties(self): """Load initial property values.""" @@ -186,6 +193,72 @@ def _update_button_state(self): has_file = bool(self.file_path_edit.text().strip()) self.import_btn.setEnabled(has_file) + def _on_file_path_changed(self, path: str): + """Scan file and update feature category checkboxes when file is selected.""" + from pathlib import Path + p = path.strip() + self._rebuild_feature_group(None) # Remove old group + if not p or not Path(p).is_file(): + return + try: + import importlib + mod = importlib.import_module('orbit.import.opendrive_importer') + counts = mod.scan_xodr_feature_categories(p) + except Exception: + return + if counts: + self._rebuild_feature_group(counts) + + def _rebuild_feature_group(self, counts: dict | None): + """Remove existing feature group and rebuild from counts (or hide if None).""" + layout = self.get_main_layout() + + # Remove old group if present + if self._feature_group is not None: + layout.removeWidget(self._feature_group) + self._feature_group.deleteLater() + self._feature_group = None + self._feature_checkboxes.clear() + + if not counts: + return + + feature_group = QGroupBox("Feature Categories") + feature_layout = QGridLayout() + row, col = 0, 0 + + # Signals and parking first + for key, label in [('signals', 'Signals'), ('parking', 'Parking')]: + count = counts.get(key, 0) + if count: + cb = QCheckBox(f"{label} ({count})") + cb.setChecked(True) + self._feature_checkboxes[key] = cb + feature_layout.addWidget(cb, row, col) + col += 1 + if col >= 3: + col = 0 + row += 1 + + # Per-ObjectType + 
for obj_type in ObjectType: + count = counts.get(obj_type, 0) + if not count: + continue + cb = QCheckBox(f"{format_enum_name(obj_type)} ({count})") + cb.setChecked(True) + self._feature_checkboxes[obj_type] = cb + feature_layout.addWidget(cb, row, col) + col += 1 + if col >= 3: + col = 0 + row += 1 + + feature_group.setLayout(feature_layout) + # Insert before the button row (last item) + layout.insertWidget(layout.count() - 1, feature_group) + self._feature_group = feature_group + def _update_info_text(self): """Update info text.""" html = """ @@ -207,7 +280,7 @@ def _update_info_text(self):
  • Lanes: Width (constant), type, road marks
  • Elevation: Stored and displayed in polyline properties
  • Signals: Speed limits, traffic lights, stop signs, give way, etc.
-  • Objects: Lampposts, guardrails, buildings, trees, bushes
+  • Objects: Lampposts, guardrails, buildings, trees, bushes, land use
  • Note: Lateral profiles (superelevation, crossfall) are not supported (2D only). @@ -240,6 +313,31 @@ def get_verbose(self) -> bool: """Get verbose output setting (from --verbose flag).""" return self.verbose + def get_import_filter(self) -> tuple: + """Return (import_signals, import_parking, import_object_types). + + import_object_types is None if all types are selected (or no filter shown). + """ + if not self._feature_checkboxes: + return True, True, None + + import_signals = self._feature_checkboxes.get('signals', None) + import_signals = import_signals.isChecked() if import_signals else True + + import_parking = self._feature_checkboxes.get('parking', None) + import_parking = import_parking.isChecked() if import_parking else True + + obj_type_boxes = {k: v for k, v in self._feature_checkboxes.items() + if isinstance(k, ObjectType)} + if not obj_type_boxes: + import_object_types = None + elif all(cb.isChecked() for cb in obj_type_boxes.values()): + import_object_types = None # All checked → no filter + else: + import_object_types = {ot for ot, cb in obj_type_boxes.items() if cb.isChecked()} + + return import_signals, import_parking, import_object_types + def accept(self): """Handle accept (validate before closing).""" file_path = self.get_file_path() diff --git a/orbit/gui/dialogs/preferences_dialog.py b/orbit/gui/dialogs/preferences_dialog.py index e27c6f5..7d1625b 100644 --- a/orbit/gui/dialogs/preferences_dialog.py +++ b/orbit/gui/dialogs/preferences_dialog.py @@ -6,9 +6,18 @@ """ from PyQt6.QtCore import Qt -from PyQt6.QtWidgets import QAbstractItemView, QComboBox, QDoubleSpinBox, QLineEdit, QListWidget, QListWidgetItem +from PyQt6.QtWidgets import ( + QAbstractItemView, + QCheckBox, + QComboBox, + QDoubleSpinBox, + QLineEdit, + QListWidget, + QListWidgetItem, +) from orbit.models import Project, SignLibraryManager +from orbit.utils.provenance import DEFAULT_TEMPLATE, is_dataprov_available from .base_dialog import BaseDialog, InfoIconLabel @@ -20,6 +29,8 @@ def __init__(self, project: Project, parent=None): super().__init__("Project Preferences", parent, min_width=500) self.project = project + from PyQt6.QtCore import QSettings + self.app_settings = QSettings() self.setup_ui() self.load_properties() @@ -47,11 +58,15 @@ def setup_ui(self): self.transform_method_combo = QComboBox() self.transform_method_combo.addItem("Affine (for orthophotos, satellite imagery)", "affine") self.transform_method_combo.addItem("Homography (for oblique drone imagery)", "homography") + self.transform_method_combo.addItem("Drone-assisted (requires drone log)", "drone_assisted") transform_label = InfoIconLabel( "Transformation Method:", "Affine: Best for nadir (straight down) aerial/satellite images. Requires 3+ control points.\n" - "Homography: Best for tilted camera drone images with perspective. Requires 4+ control points.", + "Homography: Best for tilted camera drone images with perspective. Requires 4+ control points.\n" + "Drone-assisted: Uses drone flight log (position, altitude, gimbal) for a physically-derived\n" + " homography. Requires drone log loaded in the Georeferencing dialog. 
Works even when\n" + " GCPs are nearly collinear (e.g., all along one road).", bold=False ) georef_layout.addRow(transform_label, self.transform_method_combo) @@ -157,6 +172,33 @@ def setup_ui(self): ) sign_layout.addRow(library_label, self.library_list) + # Provenance tracking section + prov_layout = self.add_form_group("Data Provenance") + + dataprov_available = is_dataprov_available() + hint = "" if dataprov_available else " (install dataprov to enable)" + + self.provenance_checkbox = QCheckBox(f"Create provenance sidecar files{hint}") + self.provenance_checkbox.setEnabled(dataprov_available) + self.provenance_checkbox.setToolTip( + "When enabled, a .prov.json file is written alongside each saved project " + "and each exported file, recording the tools and inputs used to create it." + ) + prov_layout.addRow("", self.provenance_checkbox) + + self.provenance_template_edit = QLineEdit() + self.provenance_template_edit.setPlaceholderText(DEFAULT_TEMPLATE) + self.provenance_template_edit.setEnabled(dataprov_available) + template_label = InfoIconLabel( + "File name template:", + "Template for provenance sidecar file names.\n" + "Variables: {dir} parent directory, {stem} filename without extension, " + "{ext} extension (with dot), {name} full filename.\n" + f"Default: {DEFAULT_TEMPLATE}", + bold=False, + ) + prov_layout.addRow(template_label, self.provenance_template_edit) + # Create standard OK/Cancel buttons self.create_button_box() @@ -168,9 +210,19 @@ def load_properties(self): # Transformation method if self.project.transform_method == 'homography': self.transform_method_combo.setCurrentIndex(1) + elif self.project.transform_method == 'drone_assisted': + self.transform_method_combo.setCurrentIndex(2) else: self.transform_method_combo.setCurrentIndex(0) + # Disable drone-assisted if no drone log is loaded + model = self.transform_method_combo.model() + drone_item = model.item(2) + if self.project.drone_metadata is None: + from PyQt6.QtGui import QColor + drone_item.setEnabled(False) + drone_item.setForeground(QColor('gray')) + # Traffic side if self.project.right_hand_traffic: self.traffic_combo.setCurrentIndex(0) @@ -195,6 +247,14 @@ def load_properties(self): if lib_id in enabled_libs: item.setSelected(True) + # Provenance settings (app-level, from QSettings) + self.provenance_checkbox.setChecked( + self.app_settings.value("provenance/enabled", False, type=bool) + ) + self.provenance_template_edit.setText( + self.app_settings.value("provenance/name_template", DEFAULT_TEMPLATE, type=str) + ) + def accept(self): """Save preferences and close dialog.""" # Save map name @@ -227,4 +287,9 @@ def accept(self): enabled_libs = ['se'] self.project.enabled_sign_libraries = enabled_libs + # Save provenance settings to app QSettings + self.app_settings.setValue("provenance/enabled", self.provenance_checkbox.isChecked()) + template = self.provenance_template_edit.text().strip() or DEFAULT_TEMPLATE + self.app_settings.setValue("provenance/name_template", template) + super().accept() diff --git a/orbit/gui/graphics/object_graphics.py b/orbit/gui/graphics/object_graphics.py index 0989382..5892bd0 100644 --- a/orbit/gui/graphics/object_graphics.py +++ b/orbit/gui/graphics/object_graphics.py @@ -18,7 +18,7 @@ ObjectType.TREE_CONIFER: QColor(34, 139, 34, 77), # Forest green ObjectType.BUSH: QColor(34, 139, 34, 77), # Forest green ObjectType.GUARDRAIL: QColor(25, 25, 112, 77), # Dark blue - ObjectType.LAMPPOST: QColor(255, 255, 255, 77), # White + ObjectType.LAMPPOST: QColor(255, 200, 0, 220), # Amber/yellow 
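+    # Point markers (unlike the translucent area fills above) use a high alpha so
+    # they remain visible over aerial imagery.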
# Land use areas ObjectType.LANDUSE_FOREST: QColor(0, 100, 0, 77), # Dark green ObjectType.LANDUSE_FARMLAND: QColor(210, 180, 100, 77), # Wheat/tan @@ -47,11 +47,11 @@ def create_lamppost_path(scale: float = 1.0) -> QPainterPath: path = QPainterPath() # Small circle for pole base - radius = 3.0 * scale + radius = 5.0 * scale path.addEllipse(-radius, -radius, radius * 2, radius * 2) # Orientation line (pointing direction) - line_length = 10.0 * scale + line_length = 14.0 * scale path.moveTo(0, 0) path.lineTo(line_length, 0) diff --git a/orbit/gui/graphics/object_graphics_item.py b/orbit/gui/graphics/object_graphics_item.py index 4a91e23..73106cb 100644 --- a/orbit/gui/graphics/object_graphics_item.py +++ b/orbit/gui/graphics/object_graphics_item.py @@ -372,13 +372,11 @@ def get_segment_at(self, scene_pos: QPointF, tolerance: float = 8.0) -> int: return -1 def set_selected(self, selected: bool): - """ - Set selection state of the object. - - Args: - selected: True to select, False to deselect - """ + """Set selection state of the object.""" self._is_selected = selected - self.selection_item.setVisible(selected) + try: + self.selection_item.setVisible(selected) + except RuntimeError: + return if self._is_polygon_with_points(): self.update_graphics() # Refresh vertex handles diff --git a/orbit/gui/image_view.py b/orbit/gui/image_view.py index e36920b..4e0fd70 100644 --- a/orbit/gui/image_view.py +++ b/orbit/gui/image_view.py @@ -14,12 +14,15 @@ from PyQt6.QtWidgets import ( QGraphicsEllipseItem, QGraphicsItem, + QGraphicsItemGroup, QGraphicsLineItem, QGraphicsPathItem, QGraphicsPixmapItem, QGraphicsRectItem, QGraphicsScene, + QGraphicsTextItem, QGraphicsView, + QInputDialog, QMenu, QMessageBox, ) @@ -41,6 +44,75 @@ from .utils.message_helpers import ask_yes_no, show_warning +class ControlPointItem(QGraphicsItemGroup): + """Draggable crosshair marker for a georeferencing control point.""" + + ARM_LENGTH = 10 + GAP = 3 + + def __init__(self, control_point, parent=None): + super().__init__(parent) + self.control_point = control_point + self.moved_callback = None # Set to callable(ControlPoint) after creation + + self.setFlag(QGraphicsItemGroup.GraphicsItemFlag.ItemIsMovable, True) + self.setFlag(QGraphicsItemGroup.GraphicsItemFlag.ItemSendsGeometryChanges, True) + self.setFlag(QGraphicsItemGroup.GraphicsItemFlag.ItemIsSelectable, True) + self.setAcceptHoverEvents(True) + + pen = QPen(QColor(0, 100, 255), 2) + dot_pen = QPen(QColor(0, 100, 255), 1) + dot_brush = QBrush(QColor(0, 100, 255)) + a, g = self.ARM_LENGTH, self.GAP + + for item in [ + QGraphicsLineItem(-a, 0, -g, 0), + QGraphicsLineItem(g, 0, a, 0), + QGraphicsLineItem(0, -a, 0, -g), + QGraphicsLineItem(0, g, 0, a), + ]: + item.setPen(pen) + item.setZValue(10) + self.addToGroup(item) + + dot = QGraphicsEllipseItem(-0.5, -0.5, 1, 1) + dot.setPen(dot_pen) + dot.setBrush(dot_brush) + dot.setZValue(10) + self.addToGroup(dot) + + if control_point.name: + font = QFont() + font.setBold(True) + font.setPointSize(10) + label = QGraphicsTextItem(control_point.name) + label.setDefaultTextColor(QColor(0, 100, 255)) + label.setFont(font) + label.setPos(15, -10) + label.setZValue(11) + self.addToGroup(label) + + self.setPos(control_point.pixel_x, control_point.pixel_y) + self.setZValue(10) + + def itemChange(self, change, value): + if change == QGraphicsItemGroup.GraphicsItemChange.ItemPositionHasChanged: + pos = self.pos() + self.control_point.pixel_x = pos.x() + self.control_point.pixel_y = pos.y() + if self.moved_callback: + 
self.moved_callback(self.control_point) + return super().itemChange(change, value) + + def hoverEnterEvent(self, event): + self.setCursor(Qt.CursorShape.SizeAllCursor) + super().hoverEnterEvent(event) + + def hoverLeaveEvent(self, event): + self.unsetCursor() + super().hoverLeaveEvent(event) + + class ImageView(QGraphicsView): """Interactive image view with polyline drawing and editing.""" @@ -90,6 +162,7 @@ class ImageView(QGraphicsView): # dragged_road_id, target_road_id, dragged_contact, target_contact road_link_requested = pyqtSignal(str, str, str, str) road_unlink_requested = pyqtSignal(str, str) # road_id, linked_road_id (for disconnect) + control_point_moved = pyqtSignal(object) # Emits ControlPoint when dragged def __init__(self, parent=None, verbose: bool = False): super().__init__(parent) @@ -282,9 +355,9 @@ def set_synthetic_canvas(self, width: int, height: int, color=None): pixmap = QPixmap(width, height) pixmap.fill(color) - # Clear scene and add synthetic pixmap + # Clear scene and all item tracking dicts self.scene.clear() - self.polyline_items.clear() + self._clear_item_dicts() self.image_item = self.scene.addPixmap(pixmap) self.image_item.setZValue(0) self.image_np = None @@ -317,9 +390,9 @@ def load_image(self, image_path: Path): ) pixmap = QPixmap.fromImage(q_image) - # Clear scene and add image + # Clear scene and all item tracking dicts self.scene.clear() - self.polyline_items.clear() + self._clear_item_dicts() self.image_item = self.scene.addPixmap(pixmap) self.image_item.setZValue(0) @@ -1031,49 +1104,11 @@ def update_object_scale_factors(self, scale_factor: float): item.update_scale_factor(scale_factor) def add_control_point_graphics(self, control_point): - """Add a control point marker to the graphics scene as a crosshair.""" - - x, y = control_point.pixel_x, control_point.pixel_y - - # Crosshair parameters - arm_length = 10 # Length of each arm from center - gap = 3 # Gap radius at center (so target pixel is visible) - - # Main crosshair pen (bright blue) - pen = QPen(QColor(0, 100, 255), 2) - - # Horizontal arms (left and right of center gap) - left_arm = self.scene.addLine(x - arm_length, y, x - gap, y, pen) - right_arm = self.scene.addLine(x + gap, y, x + arm_length, y, pen) - - # Vertical arms (top and bottom of center gap) - top_arm = self.scene.addLine(x, y - arm_length, x, y - gap, pen) - bottom_arm = self.scene.addLine(x, y + gap, x, y + arm_length, pen) - - # Tiny center dot for exact position reference - dot_pen = QPen(QColor(0, 100, 255), 1) - dot_brush = QBrush(QColor(0, 100, 255)) - center_dot = self.scene.addEllipse(x - 0.5, y - 0.5, 1, 1, dot_pen, dot_brush) - - # Set z-values and add to tracking list - for item in [left_arm, right_arm, top_arm, bottom_arm, center_dot]: - item.setZValue(10) - self.control_point_items.append(item) - - # Add label with CP name - if control_point.name: - from PyQt6.QtGui import QFont - from PyQt6.QtWidgets import QGraphicsTextItem - text_item = QGraphicsTextItem(control_point.name) - text_item.setDefaultTextColor(QColor(0, 100, 255)) # Bright blue - font = QFont() - font.setBold(True) - font.setPointSize(10) - text_item.setFont(font) - text_item.setPos(x + 15, y - 10) - text_item.setZValue(11) - self.scene.addItem(text_item) - self.control_point_items.append(text_item) + """Add a draggable control point marker to the graphics scene.""" + item = ControlPointItem(control_point) + item.moved_callback = lambda cp: self.control_point_moved.emit(cp) + self.scene.addItem(item) + self.control_point_items.append(item) def 
add_road_lanes_graphics(self, road: Road, scale_factors: tuple = None): """ @@ -1651,6 +1686,20 @@ def update_polyline(self, polyline_id: str): if self.soffsets_visible: self._update_soffset_labels(polyline_id) + def _clear_item_dicts(self): + """Clear all item tracking dicts after scene.clear() has been called.""" + self.polyline_items.clear() + self.junction_items.clear() + self.signal_items.clear() + self.object_items.clear() + self.parking_items.clear() + self.control_point_items.clear() + self.road_lanes_items.clear() + self.connecting_road_centerline_items.clear() + self.connecting_road_lanes_items.clear() + self.soffset_labels.clear() + self.junction_debug_items.clear() + def safe_remove_item(self, item: QGraphicsItem) -> bool: """ Safely remove a graphics item from the scene. @@ -1686,11 +1735,7 @@ def safe_remove_items(self, items: List[QGraphicsItem]): def clear(self): """Clear the view.""" self.scene.clear() - self.polyline_items.clear() - self.junction_items.clear() - self.control_point_items.clear() - self.road_lanes_items.clear() - self.soffset_labels.clear() # Clear s-offset labels + self._clear_item_dicts() self.project = None self.image_item = None self.image_np = None @@ -2445,6 +2490,18 @@ def select_object(self, object_id: str): x, y = obj.position self.centerOn(x, y) + def _adjacent_section_for_cr(self, cr, adjacent) -> int | None: + """Section number of `adjacent` road that touches `cr`, honoring contact point.""" + if not adjacent.lane_sections: + return None + if cr.predecessor_id == adjacent.id: + section_idx = 0 if cr.predecessor_contact == "start" else -1 + elif cr.successor_id == adjacent.id: + section_idx = 0 if cr.successor_contact == "start" else -1 + else: + section_idx = 0 + return adjacent.lane_sections[section_idx].section_number + def _get_connecting_road_lane_id(self, junction, connecting_road_id: str, source_lane_id: int) -> int | None: """ Determine which lane on a connecting road corresponds to a source lane. @@ -2536,50 +2593,67 @@ def find_connected_lanes(self, road_id: str, section_number: int, lane_id: int) if road.successor_id and road.successor_id in connected_ids: skip_successor = True - # 1. Check direct road predecessor/successor links (not through junctions) - # Skip if both roads are in the same junction - junction connections take precedence - if is_first_section and road.predecessor_id and not skip_predecessor: - pred_road = self.project.get_road(road.predecessor_id) - if pred_road and pred_road.lane_sections: - # Predecessor connects at its last section - pred_section = pred_road.lane_sections[-1].section_number - # Assume same lane exists in connected road (common case for continuous roads) - result['road_lanes'].append((road.predecessor_id, pred_section, lane_id)) - - if is_last_section and road.successor_id and not skip_successor: - succ_road = self.project.get_road(road.successor_id) - if succ_road and succ_road.lane_sections: - # Successor connects at its first section - succ_section = succ_road.lane_sections[0].section_number - # Assume same lane exists in connected road (common case for continuous roads) - result['road_lanes'].append((road.successor_id, succ_section, lane_id)) + # 1. Check direct road predecessor/successor links (not through junctions). + # Skip if both roads are in the same junction - junction connections take precedence. + # Also skip for connecting roads: their neighbor lane mapping requires sign flips at + # opposite-direction contacts, which are fully described by junction.lane_connections below. 
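+        # Connecting roads are instead handled by Case C below, which reads the
+        # junction's explicit from/to lane IDs rather than assuming equal lane numbers.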
+ if not road.is_connecting_road: + if is_first_section and road.predecessor_id and not skip_predecessor: + pred_road = self.project.get_road(road.predecessor_id) + if pred_road and pred_road.lane_sections: + pred_section = pred_road.lane_sections[-1].section_number + result['road_lanes'].append((road.predecessor_id, pred_section, lane_id)) + + if is_last_section and road.successor_id and not skip_successor: + succ_road = self.project.get_road(road.successor_id) + if succ_road and succ_road.lane_sections: + succ_section = succ_road.lane_sections[0].section_number + result['road_lanes'].append((road.successor_id, succ_section, lane_id)) # 2. Search all junctions for lane connections involving this lane for junction in self.project.junctions: for lane_conn in junction.lane_connections: - # Check if this lane is the source (find successor via junction) - if lane_conn.from_road_id == road_id and lane_conn.from_lane_id == lane_id: - # Only consider if this is the last section (connects to junction) - if is_last_section: - # Only add connecting road lane - don't show destination road beyond junction - if lane_conn.connecting_road_id: - conn_lane_id = self._get_connecting_road_lane_id( - junction, lane_conn.connecting_road_id, lane_id + # Case A: this road is the incoming source (find CR lane at junction). + if (lane_conn.from_road_id == road_id and lane_conn.from_lane_id == lane_id + and is_last_section and lane_conn.connecting_road_id): + conn_lane_id = lane_conn.connecting_lane_id + if conn_lane_id is None: + conn_lane_id = self._get_connecting_road_lane_id( + junction, lane_conn.connecting_road_id, lane_id + ) + if conn_lane_id is not None: + result['connecting_road_lanes'].append((lane_conn.connecting_road_id, conn_lane_id)) + + # Case B: this road is the outgoing destination (find CR lane at junction). + if (lane_conn.to_road_id == road_id and lane_conn.to_lane_id == lane_id + and is_first_section and lane_conn.connecting_road_id): + conn_lane_id = lane_conn.connecting_lane_id + if conn_lane_id is None: + conn_lane_id = self._get_connecting_road_lane_id( + junction, lane_conn.connecting_road_id, lane_conn.from_lane_id + ) + if conn_lane_id is not None: + result['connecting_road_lanes'].append((lane_conn.connecting_road_id, conn_lane_id)) + + # Case C: this road IS the connecting road — highlight incoming/outgoing road lanes + # using the stored from/to lane IDs (which already encode direction-flip semantics). 
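+                # Requires a stored connecting_lane_id; legacy data without one simply
+                # skips this case.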
+ if (lane_conn.connecting_road_id == road_id + and lane_conn.connecting_lane_id is not None + and lane_conn.connecting_lane_id == lane_id): + from_road = self.project.get_road(lane_conn.from_road_id) + if from_road and from_road.lane_sections: + from_section = self._adjacent_section_for_cr(road, from_road) + if from_section is not None: + result['road_lanes'].append( + (lane_conn.from_road_id, from_section, lane_conn.from_lane_id) ) - if conn_lane_id is not None: - result['connecting_road_lanes'].append((lane_conn.connecting_road_id, conn_lane_id)) - - # Check if this lane is the destination (find predecessor via junction) - if lane_conn.to_road_id == road_id and lane_conn.to_lane_id == lane_id: - # Only consider if this is the first section (connects from junction) - if is_first_section: - # Only add connecting road lane - don't show source road beyond junction - if lane_conn.connecting_road_id: - conn_lane_id = self._get_connecting_road_lane_id( - junction, lane_conn.connecting_road_id, lane_conn.from_lane_id + to_road = self.project.get_road(lane_conn.to_road_id) + if to_road and to_road.lane_sections: + to_section = self._adjacent_section_for_cr(road, to_road) + if to_section is not None: + result['road_lanes'].append( + (lane_conn.to_road_id, to_section, lane_conn.to_lane_id) ) - if conn_lane_id is not None: - result['connecting_road_lanes'].append((lane_conn.connecting_road_id, conn_lane_id)) return result @@ -2873,18 +2947,24 @@ def select_connecting_road_lane(self, connecting_road_id: str, lane_id: int): if not conn_road or not conn_road.is_connecting_road: continue - # Check if this lane connection corresponds to the selected lane - expected_lane = self._get_connecting_road_lane_id( - junction, connecting_road_id, lane_conn.from_lane_id - ) - if expected_lane != lane_id: - continue + # Check if this lane connection corresponds to the selected lane. + # Prefer the stored connecting_lane_id (already encodes direction flips); + # fall back to the ordinal helper only for legacy data without it. 
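+            # (_get_connecting_road_lane_id maps lanes by ordinal position and can
+            # mis-map them at opposite-direction contacts.)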
+ if lane_conn.connecting_lane_id is not None: + if lane_conn.connecting_lane_id != lane_id: + continue + else: + expected_lane = self._get_connecting_road_lane_id( + junction, connecting_road_id, lane_conn.from_lane_id + ) + if expected_lane != lane_id: + continue - # Highlight the from_road lane (last section) + # Highlight the from_road lane at the section that touches the CR from_road = self.project.get_road(lane_conn.from_road_id) if from_road and from_road.lane_sections: - from_section = from_road.lane_sections[-1].section_number - if lane_conn.from_road_id in self.road_lanes_items: + from_section = self._adjacent_section_for_cr(conn_road, from_road) + if from_section is not None and lane_conn.from_road_id in self.road_lanes_items: lanes_item = self.road_lanes_items[lane_conn.from_road_id] for lane_polygon in lanes_item.lane_items: if (isinstance(lane_polygon, InteractiveLanePolygon) and @@ -2894,11 +2974,11 @@ def select_connecting_road_lane(self, connecting_road_id: str, lane_id: int): lane_polygon.set_linked(True) self.linked_lane_polygons.append(lane_polygon) - # Highlight the to_road lane (first section) + # Highlight the to_road lane at the section that touches the CR to_road = self.project.get_road(lane_conn.to_road_id) if to_road and to_road.lane_sections: - to_section = to_road.lane_sections[0].section_number - if lane_conn.to_road_id in self.road_lanes_items: + to_section = self._adjacent_section_for_cr(conn_road, to_road) + if to_section is not None and lane_conn.to_road_id in self.road_lanes_items: lanes_item = self.road_lanes_items[lane_conn.to_road_id] for lane_polygon in lanes_item.lane_items: if (isinstance(lane_polygon, InteractiveLanePolygon) and @@ -3304,6 +3384,14 @@ def _show_centerline_point_menu(self, view_pos, polyline_id: str, point_index: i disconnect_action = menu.addAction(f"Disconnect from '{linked_name}'") linked_road_id = road.successor_id + # Smooth Curve option — always shown for road centerlines + menu.addSeparator() + smooth_action = menu.addAction("Smooth Road Curve") + smooth_action.setToolTip( + "Redistribute polyline points along a smooth Bezier curve, " + "keeping start/end positions and tangent directions." 
+ ) + # Show menu and get selected action action = menu.exec(self.mapToGlobal(view_pos)) @@ -3311,6 +3399,13 @@ def _show_centerline_point_menu(self, view_pos, polyline_id: str, point_index: i self._delete_point(polyline_id, point_index) elif disconnect_action and action == disconnect_action and road and linked_road_id: self.road_unlink_requested.emit(road.id, linked_road_id) + elif action == smooth_action: + n_pts, ok = QInputDialog.getInt( + self, "Smooth Road Curve", "Number of output points:", + value=50, min=5, max=500, step=5, + ) + if ok: + self._smooth_road_polyline(polyline_id, n_pts) elif action == split_section_action and road: # Warn if creating a small section s_coords = road.calculate_centerline_s_coordinates(polyline.points) @@ -3340,6 +3435,143 @@ def _show_centerline_point_menu(self, view_pos, polyline_id: str, point_index: i # Emit signal for MainWindow to handle road splitting self.road_split_requested.emit(road.id, polyline_id, point_index) + def _smooth_road_polyline(self, polyline_id: str, num_points: int = 50) -> None: + """Smooth a regular road's centerline using adjacent road tangents.""" + import math as _math + + from orbit.gui.project_controller import get_contact_pos_heading + from orbit.gui.undo_commands import ModifyPolylineCommand + from orbit.utils.geometry import fit_smooth_curve_to_polyline + + item = self.polyline_items.get(polyline_id) + if not item: + return + pts = list(item.polyline.points) + if len(pts) < 3: + return + + # Default tangents from first/last segment + start_hdg = _math.atan2(pts[1][1] - pts[0][1], pts[1][0] - pts[0][0]) + end_hdg = _math.atan2(pts[-1][1] - pts[-2][1], pts[-1][0] - pts[-2][0]) + + # Override with adjacent road tangents when available (more accurate) + road = self._find_road_by_centerline(polyline_id) + if road and self.project: + if road.predecessor_id: + pred_road = self.project.get_road(road.predecessor_id) + if pred_road: + pred_pl = self.project.get_polyline(pred_road.centerline_id) + if pred_pl: + _, h = get_contact_pos_heading(pred_pl, road.predecessor_contact) + if road.predecessor_contact == "start": + h += _math.pi + start_hdg = h + if road.successor_id: + succ_road = self.project.get_road(road.successor_id) + if succ_road: + succ_pl = self.project.get_polyline(succ_road.centerline_id) + if succ_pl: + _, h = get_contact_pos_heading(succ_pl, road.successor_contact) + if road.successor_contact == "end": + h += _math.pi + end_hdg = h + + # Use num_points; ensure at least original count so we don't lose resolution + n_out = max(num_points, 2) + new_pts = fit_smooth_curve_to_polyline(pts, start_hdg, end_hdg, num_output_points=n_out) + + old_geo = list(item.polyline.geo_points) if item.polyline.geo_points else None + + # Apply first (convention: caller applies, then pushes undo command) + item.polyline.points = new_pts + item.polyline.geo_points = None # pixel coords are now primary + item.update_graphics() + + cmd = ModifyPolylineCommand( + self.parent(), + polyline_id, + old_points=pts, + new_points=new_pts, + old_geo_points=old_geo, + new_geo_points=None, + description="Smooth Road Curve", + ) + if self.parent() and hasattr(self.parent(), 'undo_stack'): + self.parent().undo_stack.push(cmd) + + def _show_cr_centerline_menu(self, view_pos, conn_road_id: str, point_index: int) -> None: + """Context menu for a connecting road centerline (right-click on point or segment).""" + from orbit.gui.undo_commands import SmoothCRCommand + from orbit.utils.geometry import fit_smooth_curve_to_polyline, get_smooth_cr_tangents + + 
item = self.connecting_road_centerline_items.get(conn_road_id) + if not item: + return + cr = item.connecting_road + + menu = QMenu() + + # "Delete Point" only for polyline CRs with a clicked point + delete_action = None + if cr.geometry_type == "polyline" and point_index >= 0 and len(cr.inline_path) > 2: + delete_action = menu.addAction("Delete Point") + menu.addSeparator() + + smooth_action = menu.addAction("Smooth Curve") + smooth_action.setToolTip( + "Redistribute the curve's points along a smooth Bezier using " + "adjacent road tangents. Works for all geometry types." + ) + + action = menu.exec(self.mapToGlobal(view_pos)) + + if delete_action and action == delete_action: + cr.inline_path.pop(point_index) + if cr.inline_geo_path is not None: + transformer = self._get_geo_transformer() + if transformer: + cr.inline_geo_path = [ + transformer.pixel_to_geo(x, y) for x, y in cr.inline_path + ] + elif 0 <= point_index < len(cr.inline_geo_path): + cr.inline_geo_path.pop(point_index) + item.update_graphics() + if conn_road_id in self.connecting_road_lanes_items: + self.connecting_road_lanes_items[conn_road_id].update_graphics() + + elif action == smooth_action: + path = cr.inline_path + if not path or len(path) < 2: + return + # Prefer stored headings (set by _regenerate_parampoly3_cr, authoritative) + if cr.stored_start_heading is not None and cr.stored_end_heading is not None: + start_hdg, end_hdg = cr.stored_start_heading, cr.stored_end_heading + else: + tangents = get_smooth_cr_tangents(cr, self.project) + if tangents is None: + return + start_hdg, end_hdg = tangents + + n_pts, ok = QInputDialog.getInt( + self, "Smooth Curve", "Number of output points:", + value=50, min=5, max=500, step=5, + ) + if not ok: + return + + old_path = list(path) + n_out = max(n_pts, 2) + new_path = fit_smooth_curve_to_polyline(path, start_hdg, end_hdg, num_output_points=n_out) + + # Apply first, then push for undo + cr.inline_path = new_path + self.update_connecting_road_graphics(conn_road_id) + + mw = self.parent() + if mw and hasattr(mw, 'undo_stack'): + cmd = SmoothCRCommand(self, cr, old_path, new_path) + mw.undo_stack.push(cmd) + def _show_boundary_point_menu(self, view_pos, polyline_id: str, point_index: int): """ Show context menu for boundary polyline point with Delete option. 
@@ -4092,26 +4324,13 @@ def _handle_right_click_context_menu(self, scene_pos, event: QMouseEvent): self._show_boundary_point_menu(event.pos(), polyline_id, point_index) return - # Connecting road point deletion (polyline geometry only) + # Connecting road right-click: point or segment for conn_road_id, item in self.connecting_road_centerline_items.items(): - if item.connecting_road.geometry_type != "polyline": - continue - point_index = item.get_point_at(scene_pos) - if point_index >= 0: - connecting_road = item.connecting_road - if len(connecting_road.inline_path) > 2: - connecting_road.inline_path.pop(point_index) - if connecting_road.inline_geo_path is not None: - transformer = self._get_geo_transformer() - if transformer: - connecting_road.inline_geo_path = [ - transformer.pixel_to_geo(x, y) for x, y in connecting_road.inline_path - ] - elif 0 <= point_index < len(connecting_road.inline_geo_path): - connecting_road.inline_geo_path.pop(point_index) - item.update_graphics() - if conn_road_id in self.connecting_road_lanes_items: - self.connecting_road_lanes_items[conn_road_id].update_graphics() + cr = item.connecting_road + point_index = item.get_point_at(scene_pos) if cr.geometry_type == "polyline" else -1 + segment_index = item.get_segment_at(scene_pos) if point_index < 0 else -1 + if point_index >= 0 or segment_index >= 0: + self._show_cr_centerline_menu(event.pos(), conn_road_id, point_index) return def mouseMoveEvent(self, event: QMouseEvent): diff --git a/orbit/gui/main_window.py b/orbit/gui/main_window.py index be3e1c6..9c367c2 100644 --- a/orbit/gui/main_window.py +++ b/orbit/gui/main_window.py @@ -24,6 +24,7 @@ from orbit.models import Project from orbit.utils.coordinate_transform import TransformAdjustment from orbit.utils.logging_config import get_logger +from orbit.utils.provenance import is_dataprov_available, record_export, record_project_save from .image_view import ImageView from .project_controller import ProjectController @@ -60,6 +61,10 @@ def __init__(self, image_path: Optional[Path] = None, verbose: bool = False, self._original_transformer = None # Saved original transformer self._aerial_transformer = None # Transformer for aerial tile image self._aerial_zoom = 18 # Default tile zoom level + self._original_cp_pixels: list = [] # Saved CP pixel positions for round-trip restore + self._original_cr_paths: dict = {} # Saved connecting road inline_paths for round-trip restore + self._original_junction_centers: dict = {} # Saved junction center_point pixel coords for round-trip restore + self._original_view_adjustment = None # Saved current_adjustment before aerial switch # Adjustment ghost overlay (shows unadjusted geometry positions) self._adjustment_ghost_overlay = None @@ -635,6 +640,46 @@ def _remember_directory(self, file_path: str) -> None: if file_path: self._last_file_directory = str(Path(file_path).parent) + def _provenance_setting_enabled(self) -> bool: + """Return True if provenance tracking is requested via settings.""" + return self.settings.value("provenance/enabled", False, type=bool) + + def _provenance_enabled(self) -> bool: + """Return True if provenance tracking is enabled and dataprov is available.""" + return is_dataprov_available() and self._provenance_setting_enabled() + + def _check_provenance_ready(self) -> bool: + """Return False (and show an error) if provenance is enabled but dataprov is missing.""" + if self._provenance_setting_enabled() and not is_dataprov_available(): + show_error( + self, + "Provenance tracking is enabled in Preferences, but 
the dataprov " + "package is not installed.\n\n" + "Install dataprov or disable provenance tracking in Preferences.", + "Provenance Unavailable", + ) + return False + return True + + def _provenance_template(self) -> str: + from orbit.utils.provenance import DEFAULT_TEMPLATE + return self.settings.value("provenance/name_template", DEFAULT_TEMPLATE, type=str) + + def _record_project_provenance(self, orbit_path: Path, start_time) -> None: + """Record a provenance step for a project save, if enabled.""" + if not self._provenance_enabled(): + return + record_project_save(self.project, orbit_path, start_time, self._provenance_template()) + + def _record_export_provenance(self, output_path: Path, operation: str, output_format: str, start_time) -> None: + """Record a provenance step for an export, if enabled.""" + if not self._provenance_enabled(): + return + record_export( + output_path, self.current_project_file, operation, output_format, + start_time, self._provenance_template(), + ) + def open_project(self): """Open an existing project file.""" if not self.check_unsaved_changes(): @@ -721,17 +766,20 @@ def _ensure_original_view_for_save(self): def save_project(self) -> bool: """Save the current project. Returns False if the user cancels.""" + from datetime import datetime, timezone self._ensure_original_view_for_save() if not self._prompt_and_handle_unapplied_adjustment(): return False self._sync_adjustment_to_project() if self.current_project_file: try: + start_time = datetime.now(timezone.utc) self.project.save(self.current_project_file) self.undo_stack.setClean() self.modified = False self.update_window_title() self.statusBar().showMessage(f"Project saved: {self.current_project_file}") + self._record_project_provenance(self.current_project_file, start_time) return True except Exception as e: show_error(self, f"Failed to save project:\n{str(e)}", "Error") @@ -741,26 +789,34 @@ def save_project(self) -> bool: def save_project_as(self) -> bool: """Save the project with a new name. 
Returns False if the user cancels.""" + from datetime import datetime, timezone self._ensure_original_view_for_save() if not self._prompt_and_handle_unapplied_adjustment(): return False self._sync_adjustment_to_project() + suggested = self._last_file_directory + if (not self.current_project_file + and self.project.image_path + and Path(self.project.image_path).stem): + suggested = str(Path(self._last_file_directory) / (Path(self.project.image_path).stem + ".orbit")) file_path, _ = QFileDialog.getSaveFileName( self, "Save Project As", - self._last_file_directory, + suggested, "ORBIT Projects (*.orbit);;JSON Files (*.json);;All Files (*)" ) if file_path: self._remember_directory(file_path) try: + start_time = datetime.now(timezone.utc) self.current_project_file = Path(file_path) self.project.save(self.current_project_file) self.undo_stack.setClean() self.modified = False self.update_window_title() self.statusBar().showMessage(f"Project saved: {file_path}") + self._record_project_provenance(self.current_project_file, start_time) return True except Exception as e: show_error(self, f"Failed to save project:\n{str(e)}", "Error") @@ -849,6 +905,9 @@ def export_to_opendrive(self): if not self._prompt_and_handle_unapplied_adjustment(): return + if not self._check_provenance_ready(): + return + # Check if we have any roads if not self.project.roads: show_warning(self, "Cannot export: No roads defined in the project.\n" @@ -856,14 +915,17 @@ def export_to_opendrive(self): return # Show export dialog with optional schema path for validation - adjustment = self.image_view.current_adjustment if hasattr(self.image_view, 'current_adjustment') else None + from datetime import datetime, timezone + start_time = datetime.now(timezone.utc) dialog = ExportDialog( self.project, self, xodr_schema_path=self.xodr_schema_path, - adjustment=adjustment, + transformer_factory=self._make_transformer_factory(), ) if dialog.exec() == QDialog.DialogCode.Accepted: self.statusBar().showMessage("Export completed successfully") + if dialog.output_path: + self._record_export_provenance(dialog.output_path, "road network OpenDRIVE export", "XODR", start_time) else: self.statusBar().showMessage("Export cancelled") @@ -876,6 +938,9 @@ def export_to_osm(self): if not self._prompt_and_handle_unapplied_adjustment(): return + if not self._check_provenance_ready(): + return + # Check if any element has geo coordinates has_geo = any( project_polyline.has_geo_coords() @@ -905,6 +970,8 @@ def export_to_osm(self): self._remember_directory(file_path) try: + from datetime import datetime, timezone + start_time = datetime.now(timezone.utc) # Create transformer for pixel→geo conversion (needed for connecting # roads that only have pixel coordinates, e.g. 
roundabout entries/exits) transformer = self._create_transformer(use_validation=True) @@ -916,6 +983,7 @@ def export_to_osm(self): if success: show_info(self, message, "OSM Export") self.statusBar().showMessage("OSM export completed") + self._record_export_provenance(_Path(file_path), "road network OSM export", "OSM", start_time) else: show_warning(self, message, "OSM Export") except Exception as e: @@ -953,6 +1021,9 @@ def export_georeferencing(self): """Export georeferencing parameters to JSON file.""" from orbit.export import export_georeferencing + if not self._check_provenance_ready(): + return + # Check if we have enough control points if len(self.project.control_points) < 3: show_warning( @@ -979,6 +1050,8 @@ def export_georeferencing(self): # Resolve any unapplied adjustment before exporting — downstream tools do not # support the adjustment field, so the exported matrices must be fully committed. + # (For drone-assisted mode _has_unapplied_adjustment always returns False; + # the adjustment is stored in project.transform_adjustment instead.) if not self._prompt_and_handle_unapplied_adjustment(): return @@ -999,10 +1072,9 @@ def export_georeferencing(self): else: proj_string = base_transformer.get_utm_projection_string() - from orbit.utils.coordinate_transform import create_transformer as _create_transformer - transformer = _create_transformer( - self.project.control_points, - self.project.transform_method, + # Use self._create_transformer so drone_metadata and image dimensions + # are included (the module-level create_transformer lacks those kwargs). + transformer = self._create_transformer( use_validation=True, export_proj_string=proj_string, ) @@ -1011,6 +1083,9 @@ def export_georeferencing(self): "Please check your control points.", "Transformation Error") return + # Apply any stored adjustment so the exported matrices are fully baked. + self._apply_active_adjustment(transformer) + # Get image size if self.image_view.image_item: image_size = ( @@ -1044,8 +1119,11 @@ def export_georeferencing(self): self._remember_directory(file_path) # Export + from datetime import datetime, timezone + start_time = datetime.now(timezone.utc) if export_georeferencing(self.project, Path(file_path), transformer, image_size, self.current_project_file): self.statusBar().showMessage(f"Georeferencing exported to {file_path}") + self._record_export_provenance(Path(file_path), "georeferencing parameter export", "JSON", start_time) else: show_error(self, "Failed to export georeferencing parameters.", "Export Error") @@ -1219,9 +1297,13 @@ def import_osm_data(self): # Check if custom radius was requested (georef mode only) custom_radius = dialog.get_custom_radius() if custom_radius is not None: - center_lon, center_lat = transformer.pixel_to_geo( - image_width / 2.0, image_height / 2.0 - ) + # Use the geographic centroid of the control points as the center. + # Using transformer.pixel_to_geo(image_width/2, image_height/2) is + # unreliable when control points only cover a small portion of the + # image — the transformer extrapolates badly far from its training data. 
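# ---------------------------------------------------------------------------
# Reviewer aside (illustrative sketch, not part of the patch): what the
# centroid-plus-radius path below feeds into calculate_bbox_from_center.
# One plausible implementation under the usual ~111,320 m-per-degree
# approximation; the control-point coordinates are made up.
import math

def bbox_from_center(center_lat, center_lon, radius_m):
    dlat = radius_m / 111_320.0
    dlon = radius_m / (111_320.0 * math.cos(math.radians(center_lat)))
    return (center_lat - dlat, center_lon - dlon,
            center_lat + dlat, center_lon + dlon)

cps = [(8.4031, 49.0069), (8.4058, 49.0081), (8.4102, 49.0074)]  # (lon, lat)
center_lon = sum(lon for lon, _ in cps) / len(cps)
center_lat = sum(lat for _, lat in cps) / len(cps)
print(bbox_from_center(center_lat, center_lon, 250.0))
# ---------------------------------------------------------------------------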
+                all_cps = transformer.all_control_points
+                center_lon = sum(cp.longitude for cp in all_cps) / len(all_cps)
+                center_lat = sum(cp.latitude for cp in all_cps) / len(all_cps)
                bbox = calculate_bbox_from_center(center_lat, center_lon, custom_radius)

            # Build ImportOptions
@@ -1278,6 +1360,7 @@ def _setup_osm_import(self, OSMImportDialog, calculate_bbox_from_image,
            show_error(self, "Failed to create coordinate transformer.\n"
                       "Please check your control points.", "Transformation Error")
            return None
+        self._apply_active_adjustment(transformer)
        try:
            bbox = calculate_bbox_from_image(image_width, image_height, transformer)
        except Exception as e:
@@ -1348,6 +1431,15 @@ def _process_osm_import_result(self, result, source_type, file_path):
        show_info(self, msg, "Import Successful")
        self.project.openstreetmap_used = True
        self.modified = True
+        # Record import source for provenance tracking (match on type+path; timestamps differ per run)
+        from datetime import datetime, timezone
+        src_entry = {
+            "type": "osm_file" if source_type == "file" else "osm_api",
+            "path": str(file_path) if file_path else "https://overpass-api.de/api/interpreter",
+            "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
+        }
+        if not any(e.get("type") == src_entry["type"] and e.get("path") == src_entry["path"] for e in self.project.source_files):
+            self.project.source_files.append(src_entry)
        self.image_view.load_project(self.project)
        self.elements_tree.refresh_tree()
        self.road_tree.refresh_tree()
@@ -1407,6 +1499,7 @@ def import_opendrive_file(self):
            scale = dialog.get_scale()
            auto_georeference = dialog.get_auto_georeference()
            verbose = dialog.get_verbose()
+            import_signals, import_parking, import_object_types = dialog.get_import_filter()

            # Override transformer if forcing synthetic mode
            if force_synthetic:
@@ -1447,7 +1540,10 @@ def import_opendrive_file(self):
                import_mode=import_mode,
                scale_pixels_per_meter=scale,
                auto_create_control_points=auto_georeference,
-                verbose=verbose
+                verbose=verbose,
+                import_signals=import_signals,
+                import_parking=import_parking,
+                import_object_types=import_object_types,
            )

            # Show progress dialog
@@ -1479,6 +1575,16 @@ def import_opendrive_file(self):
            show_opendrive_import_report(result, self)

            if result.success:
+                # Record import source for provenance tracking (match on type+path; timestamps differ per run)
+                from datetime import datetime, timezone
+                src_entry = {
+                    "type": "xodr",
+                    "path": str(file_path),
+                    "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
+                }
+                if not any(e.get("type") == src_entry["type"] and e.get("path") == src_entry["path"] for e in self.project.source_files):
+                    self.project.source_files.append(src_entry)
+
                # Align connecting road paths to lane centers before rendering
                scale_factors = self.get_current_scale()
                self._align_all_junction_connecting_roads(scale_factors)
@@ -1855,6 +1961,12 @@ def open_georeferencing(self):
        # Connect control points changed signal for real-time visualization updates
        dialog.control_points_changed.connect(self.on_control_points_changed)

+        # Connect drone metadata changed signal
+        dialog.drone_metadata_changed.connect(self.on_control_points_changed)
+
+        # Connect control point drag signal for live matrix updates without full refresh
+        self.image_view.control_point_moved.connect(self.on_control_point_dragged)
+
        # Connect dialog finished signal
        dialog.finished.connect(lambda result: self.on_georef_dialog_closed(result))

@@ -1881,8 +1993,23 @@ def on_georef_dialog_closed(self, result):
            # Update lane graphics with new scale
            self.update_affected_road_lanes()

-        # Clean up reference
+        # Clean up reference and disconnect drag signal
        self.georef_dialog = None
+        try:
+            self.image_view.control_point_moved.disconnect(self.on_control_point_dragged)
+        except 
RuntimeError: + pass # Already disconnected + + def on_control_point_dragged(self, control_point): + """Handle a control point being dragged on the image canvas. + + Updates the transformer and scale display without recreating graphics items. + """ + self._invalidate_cached_transformer() + self.update_scale_display() + # Keep georef dialog validation display up to date if open + if self.georef_dialog and hasattr(self.georef_dialog, 'update_validation'): + self.georef_dialog.update_validation() def on_control_points_changed(self): """Handle control points being added/removed in georeferencing dialog.""" @@ -1963,6 +2090,7 @@ def _create_transformer(self, **kwargs): self.project.transform_method, image_width=image_width, image_height=image_height, + drone_metadata=self.project.drone_metadata, **kwargs, ) @@ -1970,28 +2098,104 @@ def _invalidate_cached_transformer(self): """Invalidate the cached transformer while preserving the active adjustment.""" self._cached_transformer = None + def _make_transformer_factory(self): + """Return a factory callable for creating correctly configured export transformers. + + The returned callable accepts the same kwargs as ``create_transformer`` + (e.g. ``use_validation``, ``export_proj_string``) and automatically + includes drone metadata, image dimensions, and any active adjustment. + """ + def factory(**kwargs): + t = self._create_transformer(**kwargs) + if t: + self._apply_active_adjustment(t) + return t + return factory + + def _compose_with_drone_base( + self, new_adj: 'TransformAdjustment' + ) -> 'TransformAdjustment': + """Compose new_adj on top of any existing stored drone adjustment. + + For drone-assisted transformers the stored project adjustment is the + accumulated base; a new UI delta must be composed on top of it so that + successive adjustments build on each other rather than starting fresh. + Returns new_adj unchanged for non-drone or when no base is stored. + """ + if (self.project.transform_method != 'drone_assisted' + or not self.project.transform_adjustment): + return new_adj + from orbit.utils.adjustment_fitter import decompose_to_adjustment + base = TransformAdjustment.from_dict(self.project.transform_adjustment) + M = new_adj.get_adjustment_matrix() @ base.get_adjustment_matrix() + return decompose_to_adjustment(M, new_adj.pivot_x, new_adj.pivot_y) + def _apply_active_adjustment(self, transformer): - """Apply the project's persisted adjustment to a transformer, if any.""" + """Apply the project's persisted adjustment to a transformer, if any. + + For drone-assisted transformers the 'applied' adjustment is stored in + project.transform_adjustment rather than current_adjustment (which is + kept at identity after baking). Both sources are checked. + """ if transformer is None: return adj = self.image_view.current_adjustment if adj and not adj.is_identity(): - transformer.set_adjustment(adj) + # For drone-assisted, compose the live UI delta on top of the stored + # base so that the transformer sees the total (accumulated) adjustment. 
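# ---------------------------------------------------------------------------
# Reviewer aside (illustrative sketch, not part of the patch): the matrix
# order used by _compose_with_drone_base above. Applying the stored base
# first and the new UI delta second is the product delta @ base, because
# homogeneous transforms compose right-to-left. All values are made up.
import numpy as np

def make_adjustment(tx, ty, theta):
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, tx], [s, c, ty], [0.0, 0.0, 1.0]])

base = make_adjustment(5.0, 0.0, 0.0)             # previously stored correction
delta = make_adjustment(0.0, 2.0, np.pi / 180.0)  # fresh UI adjustment
total = delta @ base                              # order matches the code above
p = np.array([100.0, 200.0, 1.0])
assert np.allclose(total @ p, delta @ (base @ p))
# ---------------------------------------------------------------------------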
+            transformer.set_adjustment(self._compose_with_drone_base(adj))
+            return
+        # Drone-assisted: fall back to permanently stored project adjustment
+        if (self.project.transform_method == 'drone_assisted'
+                and self.project.transform_adjustment):
+            stored = TransformAdjustment.from_dict(self.project.transform_adjustment)
+            if not stored.is_identity():
+                transformer.set_adjustment(stored)

    def _sync_adjustment_to_project(self):
-        """Clear any stored adjustment from the project (adjustments are resolved before save)."""
+        """Clear or preserve adjustment for saving.
+
+        For homography/affine: clear (already baked into CP positions).
+        For drone-assisted: keep project.transform_adjustment as-is; it IS
+        the permanent correction for the physics-based transformer.
+        """
+        if self.project.transform_method == 'drone_assisted':
+            return
        self.project.transform_adjustment = None

    def _has_unapplied_adjustment(self) -> bool:
-        """Return True if there is an active non-identity adjustment that has not been baked."""
+        """Return True if there is an active non-identity adjustment that has not been baked.
+
+        For drone-assisted transformers the adjustment cannot be baked into CP
+        positions; it is instead stored in project.transform_adjustment. When
+        using drone-assisted mode the adjustment is always considered 'applied'
+        (either it's stored in the project or hasn't been computed yet), so we
+        never block saves with an "unapplied adjustment" prompt.
+        """
+        if self.project.transform_method == 'drone_assisted':
+            return False
        adj = self.image_view.current_adjustment
        return adj is not None and not adj.is_identity()

    def _bake_adjustment_into_control_points(self):
-        """Bake the current adjustment into CP pixel positions and clear it (no dialog)."""
+        """Bake the current adjustment into CP pixel positions and clear it (no dialog).
+
+        For drone-assisted transformers the adjustment is stored in the project
+        instead of CP pixels (see apply_adjustment_to_control_points).
+        """
        adj = self.image_view.current_adjustment
        if adj is None or adj.is_identity():
            return
+        if self.project.transform_method == 'drone_assisted':
+            # Can't bake into CPs — persist in project and re-apply to transformer
+            self.project.transform_adjustment = adj.to_dict()
+            self.image_view.reset_adjustment()
+            self._remove_adjustment_ghost()
+            self._invalidate_cached_transformer()
+            self._cached_transformer = self._create_transformer(use_validation=True)
+            self._apply_active_adjustment(self._cached_transformer)
+            self.refresh_imported_geometry()
+            return
        for cp in self.project.control_points:
            cp.pixel_x, cp.pixel_y = adj.apply_to_point(cp.pixel_x, cp.pixel_y)
        self.image_view.reset_adjustment()
@@ -2601,8 +2805,13 @@ def apply_adjustment_to_control_points(self):
        """
        Apply current adjustment to control points.

-        This "bakes" the adjustment into the control point positions,
-        then recomputes the transformation with the new positions.
+        For homography/affine transformers this "bakes" the adjustment into
+        the control point pixel positions and recomputes the transformation.
+
+        For drone-assisted transformers the core physics matrix is independent
+        of CP pixel positions, so baking into CPs has no effect. Instead the
+        adjustment is stored permanently in the project and kept active on the
+        transformer so it persists across saves/reloads.
""" adjustment = self.image_view.get_adjustment() if adjustment is None or adjustment.is_identity(): @@ -2613,7 +2822,45 @@ def apply_adjustment_to_control_points(self): self.statusBar().showMessage("No control points to adjust") return - # Confirm with user + is_drone = self.project.transform_method == 'drone_assisted' + + if is_drone: + # For drone-assisted: can't bake into CPs because the physics matrix + # ignores CP pixel positions. Persist the adjustment in the project so + # it survives save/reload, and keep it applied to the transformer. + if not ask_yes_no( + self, + "The drone-assisted transformer is physics-based and cannot absorb " + "the adjustment via control point positions.\n\n" + "The adjustment will be stored permanently with the project and " + "re-applied automatically on every reload.\n\n" + "Continue?", + "Apply Adjustment" + ): + return + + self.project.transform_adjustment = self._compose_with_drone_base( + adjustment).to_dict() + # Reset current_adjustment so the panel and save-prompt see no pending + # adjustment (the stored project value is the authoritative source now). + self.image_view.reset_adjustment() + self._remove_adjustment_ghost() + + # Rebuild transformer; _apply_active_adjustment will re-apply the + # stored project adjustment via the drone-assisted fallback path. + self._invalidate_cached_transformer() + self._cached_transformer = self._create_transformer(use_validation=True) + self._apply_active_adjustment(self._cached_transformer) + + self.refresh_imported_geometry() + self.modified = True + self.update_window_title() + self.statusBar().showMessage( + "Adjustment stored permanently for drone-assisted transformer" + ) + return + + # --- homography / affine: bake adjustment into CP pixel positions --- if not ask_yes_no( self, "This will modify the pixel positions of all control points " @@ -2790,8 +3037,33 @@ def _switch_to_aerial(self): show_warning(self, "Cannot create coordinate transformer.") self.toggle_aerial_action.setChecked(False) return + # Save the current_adjustment so we can restore it precisely on return. + # Without this, current_adjustment keeps the aerial value and causes + # _apply_active_adjustment to corrupt the original transformer on the + # next aerial switch. + self._original_view_adjustment = self.image_view.current_adjustment self._apply_active_adjustment(self._original_transformer) self._original_image_np = self.image_view.image_np.copy() if self.image_view.image_np is not None else None + # Save exact user-placed pixel positions so the round-trip can restore them + # precisely (geo_to_pixel on the least-squares transformer does not reproduce + # training point positions exactly when there are more than the minimum points). + self._original_cp_pixels = [(cp.pixel_x, cp.pixel_y) for cp in self.project.control_points] + + # Save connecting road pixel paths so geo_to_pixel round-trip errors can't corrupt them. + # geo_to_pixel may return out-of-bounds coords for physics-based transformers when a + # road is outside the camera's field of view. 
+ self._original_cr_paths = {} + self._original_junction_centers = {} + for junction in self.project.junctions: + if junction.center_point: + self._original_junction_centers[junction.id] = { + 'center': junction.center_point, + 'roundabout': junction.roundabout_center, + } + for cr_id in junction.connecting_road_ids: + cr = self.project.get_road(cr_id) + if cr and cr.inline_path: + self._original_cr_paths[cr_id] = list(cr.inline_path) # Remove ghost overlay before switching (will be rebuilt on return) self._remove_adjustment_ghost() @@ -2818,6 +3090,10 @@ def _switch_to_aerial(self): self.toggle_aerial_action.setChecked(False) self._original_image_np = None self._original_transformer = None + self._original_cp_pixels = [] + self._original_cr_paths = {} + self._original_view_adjustment = None + self._original_junction_centers = {} return # Build initial affine transformer from raw tile image bounds @@ -2831,6 +3107,10 @@ def _switch_to_aerial(self): self.toggle_aerial_action.setChecked(False) self._original_image_np = None self._original_transformer = None + self._original_cp_pixels = [] + self._original_cr_paths = {} + self._original_view_adjustment = None + self._original_junction_centers = {} return # Resize aerial image so its pixels/meter matches the original image. @@ -2854,8 +3134,18 @@ def _switch_to_aerial(self): self.toggle_aerial_action.setChecked(False) self._original_image_np = None self._original_transformer = None + self._original_cp_pixels = [] + self._original_cr_paths = {} + self._original_view_adjustment = None + self._original_junction_centers = {} return + # Ensure CR/junction geo coords are consistent with the original + # transformer before reprojecting. Stale geo_coords (from a previous + # transformer or from the junction analyzer's endpoint snapping) would + # cause reproject_project_geometry to place CRs at wrong aerial pixels. + self._resync_junction_geo_coords(self._original_transformer) + # Re-project all geometry into the aerial pixel space count = reproject_project_geometry( self.project, self._original_transformer, self._aerial_transformer, @@ -2897,6 +3187,14 @@ def _switch_to_original(self): self.project, self._aerial_transformer, self._original_transformer, ) + # Restore exact user-placed CP pixel positions. reproject_project_geometry + # computes them via geo_to_pixel, which doesn't reproduce the original positions + # exactly when the least-squares transform has non-zero residuals (>min points). + if self._original_cp_pixels: + for cp, (px, py) in zip(self.project.control_points, self._original_cp_pixels): + cp.pixel_x = px + cp.pixel_y = py + # Restore adjustment and reposition geo-derived entities if saved_adjustment is not None: self._original_transformer.set_adjustment(saved_adjustment) @@ -2904,12 +3202,61 @@ def _switch_to_original(self): self.image_view.swap_background(self._original_image_np) self._cached_transformer = self._original_transformer + # Restore the original view's adjustment into current_adjustment. + # While in aerial view, current_adjustment held the aerial adjustment; leaving + # it there would cause _apply_active_adjustment to corrupt the original transformer + # on the next aerial switch, and incorrectly trigger update_all_from_geo_coords + # with the wrong (aerial) adjustment value. + self.image_view.current_adjustment = self._original_view_adjustment + # Recompute pixel positions from geo coords using the adjusted transformer - # so entities land in the correct adjusted positions. 
Control points are - # NOT updated here — they keep their unadjusted positions from reprojection. - adj = self.image_view.current_adjustment - if adj and not adj.is_identity(): + # so entities land in the correct adjusted positions. + adj = self._original_view_adjustment + drone_adj = ( + self.project.transform_method == 'drone_assisted' + and self.project.transform_adjustment + and not TransformAdjustment.from_dict( + self.project.transform_adjustment).is_identity() + ) + if (adj and not adj.is_identity()) or drone_adj: self.image_view.update_all_from_geo_coords(self._cached_transformer) + + # Restore connecting road pixel paths AFTER update_all_from_geo_coords, because + # that call re-runs geo_to_pixel via the drone transformer which returns large + # out-of-bounds coordinates for CRs whose geo positions are outside the camera's + # field of view, undoing the reproject_project_geometry result. We restore the + # exact pre-aerial pixel positions as the authoritative source for all CRs + # (in-FOV and out-of-FOV alike) so that subsequent load_project picks them up. + if self._original_cr_paths: + for junction in self.project.junctions: + for cr_id in junction.connecting_road_ids: + if cr_id in self._original_cr_paths: + cr = self.project.get_road(cr_id) + if cr: + cr.inline_path = list(self._original_cr_paths[cr_id]) + + # Restore junction center_point pixel positions for the same reason: the drone + # camera model's geo_to_pixel gives out-of-bounds values for junctions that are + # outside its field of view. + if self._original_junction_centers: + for junction in self.project.junctions: + saved = self._original_junction_centers.get(junction.id) + if saved: + junction.center_point = saved['center'] + junction.roundabout_center = saved['roundabout'] + + # Regenerate CR paths from current road positions. After restoring + # _original_cr_paths the saved positions may no longer match roads that + # were moved in aerial view, leaving gaps at junction endpoints. + # Parampoly3 CRs need a full spline rebuild; lane-aligned CRs need + # their lane-offset endpoints recomputed; non-aligned polyline CRs are + # handled by the endpoint snap inside _resync_junction_geo_coords. + self._regenerate_all_junction_crs() + + # Resync geo coords to match the restored pixel positions so the + # next reproject (or save) sees consistent geo_coords. + self._resync_junction_geo_coords(self._original_transformer) + self._aerial_view_active = False # Refresh scene @@ -2925,6 +3272,10 @@ def _switch_to_original(self): self._original_image_np = None self._original_transformer = None self._aerial_transformer = None + self._original_cp_pixels = [] + self._original_cr_paths = {} + self._original_junction_centers = {} + self._original_view_adjustment = None self.toggle_aerial_action.setText("&Aerial Map View") self.toggle_aerial_action.setEnabled(True) @@ -2953,6 +3304,14 @@ def _initialize_and_refresh_geo_coords(self): except Exception: return + # For drone-assisted projects with a saved adjustment, apply it before + # syncing geo coords from pixel positions. Saved pixel positions are in + # "adjusted" space (the user placed them while the adjustment was active), + # so pixel→geo conversion must use the same adjusted transformer that was + # active at save time. Without this, _resync_junction_geo_coords produces + # wrong geo coords and _restore_adjustment_from_project double-shifts them. 
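# ---------------------------------------------------------------------------
# Reviewer aside (illustrative sketch, not part of the patch): the invariant
# the ordering described above preserves. Saved pixel positions live in
# adjusted space, so the resync must round-trip through the same adjusted
# mapping that was active at save time. DummyTransformer is a stand-in with
# a simple linear mapping; the real transformer is drone-assisted.
class DummyTransformer:
    def pixel_to_geo(self, x, y):
        return x * 1e-5 + 8.4, y * -1e-5 + 49.0

    def geo_to_pixel(self, lon, lat):
        return (lon - 8.4) / 1e-5, (lat - 49.0) / -1e-5

t = DummyTransformer()
px, py = 1024.0, 768.0
rx, ry = t.geo_to_pixel(*t.pixel_to_geo(px, py))
assert abs(rx - px) < 1e-6 and abs(ry - py) < 1e-6
# ---------------------------------------------------------------------------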
+ self._apply_active_adjustment(transformer) + # Initialize geo_path for connecting roads that don't have it (legacy support) for junction in self.project.junctions: for cr_id in junction.connecting_road_ids: @@ -2964,10 +3323,125 @@ def _initialize_and_refresh_geo_coords(self): # (only for CRs without lane connections; lane-aligned CRs are skipped) self._snap_connecting_road_endpoints() + # Resync CR/junction geo coords from pixel positions to ensure + # consistency with the active transformer. Without this, a transformer + # change (e.g. switching to drone_assisted) leaves stale geo_coords + # that cause reproject_project_geometry to produce wrong pixel positions. + self._resync_junction_geo_coords(transformer) + def _snap_connecting_road_endpoints(self): """Snap CR pixel endpoints to match connected road endpoints.""" self.controller.snap_connecting_road_endpoints() + def _resync_junction_geo_coords(self, transformer): + """Recompute CR and junction geo coords from pixel positions. + + Connecting road pixel paths are authoritative (generated by curve + fitting); their inline_geo_path may become stale when the transformer + changes (e.g. switching to drone_assisted). Resyncing ensures + geo_to_pixel(geo) == pixel for every CR, which is required for + reproject_project_geometry to produce correct results. + + After resyncing, CR endpoints are snapped to the connected polyline + endpoints (both pixel and geo) so that no gap appears after reprojection. + """ + if transformer is None: + return + for junction in self.project.junctions: + if junction.center_point: + lon, lat = transformer.pixel_to_geo( + junction.center_point[0], junction.center_point[1], + ) + junction.geo_center_point = (lon, lat) + if junction.roundabout_center: + rlon, rlat = transformer.pixel_to_geo( + junction.roundabout_center[0], + junction.roundabout_center[1], + ) + junction.geo_roundabout_center = (rlon, rlat) + + # Build set of CRs with lane connections (those are lane-aligned) + aligned_cr_ids = { + c.connecting_road_id + for c in (junction.lane_connections or []) + if c.connecting_road_id + } + + for cr_id in junction.connecting_road_ids: + cr = self.project.get_road(cr_id) + if not cr or not cr.inline_path: + continue + + # Snap pixel endpoints to connected polyline endpoints + if cr_id not in aligned_cr_ids: + self._snap_cr_endpoints_pixel(cr) + + # Recompute full geo path from (snapped) pixel + cr.inline_geo_path = [ + transformer.pixel_to_geo(x, y) + for x, y in cr.inline_path + ] + + # Snap geo endpoints to connected polyline geo endpoints + # so that reproject produces identical pixel positions for + # the CR endpoint and its connecting polyline endpoint. 
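# ---------------------------------------------------------------------------
# Reviewer aside (illustrative sketch, not part of the patch): why the geo
# endpoints must be snapped rather than merely "close". At roughly 10 px per
# meter, a 1e-5 degree longitude mismatch (~0.73 m at 49°N) already
# reprojects to a visible multi-pixel gap. All numbers are made up.
import math

PX_PER_M = 10.0
M_PER_DEG_LON = 111_320.0 * math.cos(math.radians(49.0))  # ≈ 73 km per degree

cr_lon, road_lon = 8.40001, 8.40002  # endpoints 1e-5 degrees apart
gap_px = abs(cr_lon - road_lon) * M_PER_DEG_LON * PX_PER_M
print(f"gap after reprojection: {gap_px:.1f} px")  # ≈ 7 px
# ---------------------------------------------------------------------------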
+ if cr_id not in aligned_cr_ids: + self._snap_cr_endpoints_geo(cr) + + def _snap_cr_endpoints_pixel(self, cr): + """Snap a CR's pixel endpoints to the connected polyline endpoints.""" + pred = self.project.get_road(cr.predecessor_id) + succ = self.project.get_road(cr.successor_id) + pred_pl = self.project.get_polyline(pred.centerline_id) if pred else None + succ_pl = self.project.get_polyline(succ.centerline_id) if succ else None + if pred_pl and pred_pl.points: + cr.inline_path[0] = ( + pred_pl.points[-1] if cr.predecessor_contact == 'end' + else pred_pl.points[0] + ) + if succ_pl and succ_pl.points: + cr.inline_path[-1] = ( + succ_pl.points[-1] if cr.successor_contact == 'end' + else succ_pl.points[0] + ) + + def _snap_cr_endpoints_geo(self, cr): + """Snap a CR's geo endpoints to the connected polyline geo endpoints.""" + if not cr.inline_geo_path: + return + pred = self.project.get_road(cr.predecessor_id) + succ = self.project.get_road(cr.successor_id) + pred_pl = self.project.get_polyline(pred.centerline_id) if pred else None + succ_pl = self.project.get_polyline(succ.centerline_id) if succ else None + if pred_pl and pred_pl.geo_points: + cr.inline_geo_path[0] = ( + pred_pl.geo_points[-1] if cr.predecessor_contact == 'end' + else pred_pl.geo_points[0] + ) + if succ_pl and succ_pl.geo_points: + cr.inline_geo_path[-1] = ( + succ_pl.geo_points[-1] if cr.successor_contact == 'end' + else succ_pl.geo_points[0] + ) + + def _regenerate_all_junction_crs(self): + """Regenerate all CR paths from current road positions after a view switch. + + Called after _original_cr_paths is restored so that CRs which connect + to roads that were moved in aerial view are rebuilt from the new road + positions rather than the stale saved positions. + """ + # Rebuild parampoly3 CRs from connected road endpoints + for junction in self.project.junctions: + for cr_id in junction.connecting_road_ids: + cr = self.project.get_road(cr_id) + if cr and cr.geometry_type == "parampoly3": + self.controller._regenerate_parampoly3_cr(cr) + + # Re-apply lane-alignment offsets for all junctions that have them + scale_factors = self.get_current_scale() + self.controller.align_all_junction_crs(scale_factors) + def update_affected_road_lanes(self): """Update lane graphics for all roads with centerlines.""" # Get current scale factors if available diff --git a/orbit/gui/undo_commands.py b/orbit/gui/undo_commands.py index 44b3b6c..ffdfbbd 100644 --- a/orbit/gui/undo_commands.py +++ b/orbit/gui/undo_commands.py @@ -13,6 +13,7 @@ from orbit.models import Junction, LaneConnection, ParkingSpace, Polyline, Road, RoadObject, Signal if TYPE_CHECKING: + from .image_view import ImageView from .main_window import MainWindow @@ -1349,3 +1350,35 @@ def _apply_data(self, data: dict): scale_factor = self.main_window.get_current_scale() self.main_window.image_view.add_parking_graphics(parking, scale_factor) self.main_window._refresh_trees() + + +class SmoothCRCommand(QUndoCommand): + """Command for smoothing a connecting road's inline_path.""" + + def __init__( + self, + image_view: 'ImageView', + cr_road: 'Road', + old_inline_path: list, + new_inline_path: list, + description: str = "Smooth Connecting Road Curve", + ): + super().__init__(description) + self.image_view = image_view + self.cr_road = cr_road + self.old_inline_path = [tuple(p) for p in old_inline_path] + self.new_inline_path = [tuple(p) for p in new_inline_path] + self._first_redo = True + + def redo(self): + if self._first_redo: + self._first_redo = False + return + 
        self._apply(self.new_inline_path)
+
+    def undo(self):
+        self._apply(self.old_inline_path)
+
+    def _apply(self, path: list):
+        self.cr_road.inline_path = list(path)
+        self.image_view.update_connecting_road_graphics(self.cr_road.id)
diff --git a/orbit/gui/widgets/elements_tree.py b/orbit/gui/widgets/elements_tree.py
index dee33f6..ad70c29 100644
--- a/orbit/gui/widgets/elements_tree.py
+++ b/orbit/gui/widgets/elements_tree.py
@@ -356,14 +356,14 @@ def create_object_item(self, obj) -> QTreeWidgetItem:
        if obj.road_id and self.project:
            road = self.project.get_road(obj.road_id)
            if road:
-                road_name = road.name or f"Road {road.id[:8]}"
+                road_name = f"Road {road.id}" + (f" – {road.name}" if road.name else "")
                road_info = f" → {road_name}"
            else:
                cr = self.project.get_road(obj.road_id)
                if cr and cr.is_connecting_road:
                    road_info = f" → CR {cr.id[:8]}"

-        text = f"{display_name} ({category}){road_info}"
+        text = f"#{obj.id} {display_name} ({category}){road_info}"
        item = QTreeWidgetItem([text])
        item.setData(0, Qt.ItemDataRole.UserRole, {"type": "object", "id": obj.id})
diff --git a/orbit/import/opendrive_importer.py b/orbit/import/opendrive_importer.py
index f2dcab0..d376172 100644
--- a/orbit/import/opendrive_importer.py
+++ b/orbit/import/opendrive_importer.py
@@ -38,6 +38,80 @@
 _DEFAULT_LANE_WIDTH = 3.5  # Fallback when ODR has no lane width

+def classify_xodr_object_type(odr_type: str, odr_subtype: str = "") -> Optional[ObjectType]:
+    """Map an OpenDRIVE object type+subtype to an ORBIT ObjectType.
+
+    Module-level so it can be used both by the importer and the file scanner.
+    """
+    t = odr_type.lower()
+    s = odr_subtype.lower()
+
+    if 'lamp' in t or 'pole' in t:
+        return ObjectType.LAMPPOST
+    elif 'guard' in t or 'rail' in t or 'barrier' in t:
+        return ObjectType.GUARDRAIL
+    elif 'building' in t or 'house' in t:
+        return ObjectType.BUILDING
+    elif 'vegetation' in t or 'tree' in t or 'bush' in t or 'shrub' in t or 'forest' in t:
+        if 'forest' in s or 'forest' in t:
+            return ObjectType.LANDUSE_FOREST
+        elif 'meadow' in s:
+            return ObjectType.LANDUSE_MEADOW
+        elif 'scrub' in s:
+            return ObjectType.LANDUSE_SCRUB
+        elif 'conifer' in s or 'pine' in s or 'conifer' in t or 'pine' in t:
+            return ObjectType.TREE_CONIFER
+        elif 'tree' in s or 'tree' in t:
+            return ObjectType.TREE_BROADLEAF
+        elif 'bush' in s or 'shrub' in s or 'bush' in t or 'shrub' in t:
+            return ObjectType.BUSH
+        return ObjectType.TREE_BROADLEAF
+    elif 'land' in t:
+        if 'farmland' in s or 'farm' in s:
+            return ObjectType.LANDUSE_FARMLAND
+        elif 'meadow' in s:
+            return ObjectType.LANDUSE_MEADOW
+    elif 'water' in t:
+        if 'wetland' in s:
+            return ObjectType.NATURAL_WETLAND
+        return ObjectType.NATURAL_WATER
+
+    return None
+
+
+def scan_xodr_feature_categories(file_path: str) -> Dict[str, int]:
+    """Scan an xodr file and return counts of feature categories present.
+
+    Uses lxml iterparse for efficiency — does not build a full parse tree.
+
+    Returns a dict with keys:
+        "signals" — count of <signal> elements
+        "parking" — count of <object> elements with a <parkingSpace> child
+        ObjectType — count per ORBIT ObjectType (for the remaining <object> elements)
+    """
+    from lxml import etree
+
+    counts: Dict[str, int] = {}
+    # We need to see the children of <object> to detect parkingSpace, so collect
+    # objects at the 'end' event, when all children have been parsed.
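# ---------------------------------------------------------------------------
# Reviewer aside (illustrative usage sketch, not part of the patch): the shape
# of the dict returned by scan_xodr_feature_categories above, fed into the new
# ImportOptions filters. "network.xodr" is a hypothetical path, and the sketch
# assumes the remaining ImportOptions fields all have defaults.
counts = scan_xodr_feature_categories("network.xodr")
present_types = {k for k in counts if isinstance(k, ObjectType)}
options = ImportOptions(
    import_signals=counts.get("signals", 0) > 0,
    import_parking=counts.get("parking", 0) > 0,
    import_object_types=present_types or None,  # None means "no filtering"
)
# ---------------------------------------------------------------------------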
+ for _event, elem in etree.iterparse(file_path, events=('end',), recover=True): + tag = elem.tag + if tag == 'signal': + counts['signals'] = counts.get('signals', 0) + 1 + elif tag == 'object': + if elem.find('parkingSpace') is not None: + counts['parking'] = counts.get('parking', 0) + 1 + else: + obj_type = classify_xodr_object_type( + elem.get('type', ''), elem.get('subtype', '') + ) + if obj_type is not None: + counts[obj_type] = counts.get(obj_type, 0) + 1 + elem.clear() + + return counts + + class ImportMode(Enum): """Import mode: add to existing or replace.""" ADD = "add" @@ -51,6 +125,10 @@ class ImportOptions: scale_pixels_per_meter: float = 10.0 # For synthetic mode auto_create_control_points: bool = False # Auto-create georeferencing control points verbose: bool = False # Print debug information + # Feature category filters (None = import all) + import_signals: bool = True + import_parking: bool = True + import_object_types: Optional[Set[ObjectType]] = None # None = all object types @dataclass @@ -304,20 +382,30 @@ def _import_signals_and_objects(self, options: ImportOptions, result: ImportResu if not road_id: continue - for odr_signal in odr_road.signals: - try: - if self._import_signal(odr_signal, road_id, odr_road, options): - result.signals_imported += 1 - except Exception as e: - result.warnings.append(f"Failed to import signal {odr_signal.id}: {e}") + if options.import_signals: + for odr_signal in odr_road.signals: + try: + if self._import_signal(odr_signal, road_id, odr_road, options): + result.signals_imported += 1 + except Exception as e: + result.warnings.append(f"Failed to import signal {odr_signal.id}: {e}") for odr_object in odr_road.objects: try: if odr_object.is_parking: - if self._import_parking(odr_object, road_id, odr_road, options): - result.parking_imported += 1 - elif self._import_object(odr_object, road_id, odr_road, options): - result.objects_imported += 1 + if options.import_parking: + if self._import_parking(odr_object, road_id, odr_road, options): + result.parking_imported += 1 + else: + # Check object type filter before importing + if options.import_object_types is not None: + obj_type = classify_xodr_object_type( + odr_object.type, odr_object.subtype + ) + if obj_type not in options.import_object_types: + continue + if self._import_object(odr_object, road_id, odr_road, options): + result.objects_imported += 1 except Exception as e: result.warnings.append(f"Failed to import object {odr_object.id}: {e}") @@ -689,6 +777,7 @@ def _import_connecting_road( # Add to project and junction self.project.add_road(connecting_road) junction.add_connecting_road(connecting_road.id) + self.odr_road_to_orbit[odr_road.id] = connecting_road.id # Add predecessor/successor roads to junction's connected_road_ids if predecessor_orbit_id and predecessor_orbit_id not in junction.connected_road_ids: @@ -699,7 +788,7 @@ def _import_connecting_road( # Create LaneConnection objects from junction connection data self._import_cr_lane_connections( odr_road, junction_odr_id, junction, connecting_road.id, - successor_orbit_id, options + successor_orbit_id, predecessor_orbit_id, options ) if options.verbose: @@ -767,7 +856,8 @@ def _build_cr_geometry_kwargs(param_poly3, points_pixel): return geom_kwargs, stored_start_heading, stored_end_heading def _import_cr_lane_connections(self, odr_road, junction_odr_id, junction, - connecting_road_id, successor_orbit_id, options): + connecting_road_id, successor_orbit_id, + predecessor_orbit_id, options): """Create LaneConnection objects from ODR 
junction connection data.""" odr_junction = None for j in self.odr_data.junctions: @@ -783,13 +873,20 @@ def _import_cr_lane_connections(self, odr_road, junction_odr_id, junction, continue from_road_id = self.odr_road_to_orbit.get(odr_conn.incoming_road, "") - to_road_id = successor_orbit_id + + # When contactPoint="start", traffic enters the CR at its start and exits at + # its end (successor road). When contactPoint="end", traffic enters at the end + # and exits at the start (predecessor road). + if odr_conn.contact_point == "start": + to_road_id = successor_orbit_id + else: + to_road_id = predecessor_orbit_id if not from_road_id or not to_road_id: continue for lane_link in odr_conn.lane_links: - to_lane = lane_link.from_lane # Fallback + to_lane = lane_link.to_lane # Fallback: connecting lane ID if odr_road.lane_sections: section = odr_road.lane_sections[0] for odr_lane in section.right_lanes + section.left_lanes: @@ -1259,9 +1356,6 @@ def _import_signal(self, odr_signal: ODRSignal, road_id: str, odr_road: ODRRoad, # Convert signal type signal_type, value = self._convert_signal_type(odr_signal) - if signal_type is None: - return False # Unsupported signal type - # Calculate pixel position from s,t coordinates position_pixel = self._calculate_position_from_st( odr_signal.s, @@ -1299,6 +1393,10 @@ def _import_signal(self, odr_signal: ODRSignal, road_id: str, odr_road: ODRRoad, signal.country = odr_signal.country # Lane validity (which lanes this signal applies to) signal.validity_lanes = odr_signal.validity_lanes + # Preserve custom type/subtype for CUSTOM signals (enables round-trip) + if signal_type == SignalType.CUSTOM: + signal.custom_type = odr_signal.type + signal.custom_subtype = odr_signal.subtype self.project.add_signal(signal) return True @@ -1306,20 +1404,20 @@ def _import_signal(self, odr_signal: ODRSignal, road_id: str, odr_road: ODRRoad, def _import_object(self, odr_object: ODRObject, road_id: str, odr_road: ODRRoad, options: ImportOptions) -> bool: """Import object from OpenDrive.""" # Convert object type - object_type = self._convert_object_type(odr_object.type) + object_type = self._convert_object_type(odr_object.type, odr_object.subtype) if object_type is None: return False # Unsupported object type - # Calculate pixel position from s,t coordinates - position_pixel = self._calculate_position_from_st( - odr_object.s, - odr_object.t, - odr_road + # Get metric anchor position and road heading at s,t + metric_result = self._get_metric_position_and_heading( + odr_object.s, odr_object.t, odr_road ) - - if not position_pixel: + if not metric_result: return False + anchor_x, anchor_y, road_hdg = metric_result + + position_pixel = self.coord_transform.metric_to_pixel(anchor_x, anchor_y) # Convert pixel position to geo coords for storage as source of truth geo_position = None @@ -1339,6 +1437,7 @@ def _import_object(self, odr_object: ODRObject, road_id: str, odr_road: ODRRoad, obj.name = odr_object.name obj.z_offset = odr_object.z_offset obj.orientation = math.degrees(odr_object.hdg) # Convert radians to degrees + obj.odr_orientation = odr_object.orientation_str # Preserve directional marker # Set dimensions if odr_object.radius > 0: @@ -1354,20 +1453,29 @@ def _import_object(self, odr_object: ODRObject, road_id: str, odr_road: ODRRoad, obj.pitch = odr_object.pitch obj.roll = odr_object.roll + # Reconstruct polyline/polygon points from cornerLocal outline + if odr_object.corner_points: + pixel_pts, geo_pts = self._corners_to_pixel_points( + odr_object.corner_points, anchor_x, 
anchor_y, road_hdg
+            )
+            if pixel_pts:
+                obj.points = pixel_pts
+                obj.geo_points = geo_pts
+
        self.project.add_object(obj)
        return True

    def _import_parking(self, odr_object: ODRObject, road_id: str, odr_road: ODRRoad, options: ImportOptions) -> bool:
        """Import parking space from OpenDrive object with parkingSpace child."""
-        # Calculate pixel position from s,t coordinates
-        position_pixel = self._calculate_position_from_st(
-            odr_object.s,
-            odr_object.t,
-            odr_road
+        # Get metric anchor position and road heading at s,t
+        metric_result = self._get_metric_position_and_heading(
+            odr_object.s, odr_object.t, odr_road
        )
-
-        if not position_pixel:
+        if not metric_result:
            return False
+        anchor_x, anchor_y, road_hdg = metric_result
+
+        position_pixel = self.coord_transform.metric_to_pixel(anchor_x, anchor_y)

        # Convert pixel position to geo coords for storage as source of truth
        geo_position = None
@@ -1411,6 +1519,15 @@ def _import_parking(self, odr_object: ODRObject, road_id: str, odr_road: ODRRoad
        if odr_object.length > 0:
            parking.length = odr_object.length

+        # Reconstruct polygon points from cornerLocal outline
+        if odr_object.corner_points:
+            pixel_pts, geo_pts = self._corners_to_pixel_points(
+                odr_object.corner_points, anchor_x, anchor_y, road_hdg
+            )
+            if pixel_pts:
+                parking.points = pixel_pts
+                parking.geo_points = geo_pts
+
        self.project.add_parking(parking)
        return True

@@ -1431,7 +1548,24 @@ def _calculate_position_from_st(
        Returns:
            Tuple of (pixel_x, pixel_y) or None if calculation fails
        """
-        # Find geometry element containing s
+        result = self._get_metric_position_and_heading(s, t, odr_road)
+        if result is None:
+            return None
+        x_metric, y_metric, _ = result
+        return self.coord_transform.metric_to_pixel(x_metric, y_metric)
+
+    def _get_metric_position_and_heading(
+        self,
+        s: float,
+        t: float,
+        odr_road: ODRRoad
+    ) -> Optional[Tuple[float, float, float]]:
+        """
+        Calculate metric position and road heading at s,t on the road.
+
+        Returns:
+            (x_metric, y_metric, road_hdg) or None if calculation fails
+        """
+        geom_element = None
        for geom in odr_road.geometry:
            if s >= geom.s and s < geom.s + geom.length:
                geom_element = geom
                break

        if not geom_element:
-            # Use last geometry element if s is beyond end
            if odr_road.geometry:
                geom_element = odr_road.geometry[-1]
            else:
                return None

-        # Calculate position along this geometry segment
        ds = s - geom_element.s
-
-        # Get position and heading at ds
-        # For simplicity, use linear interpolation along segment
-        # TODO: Use proper geometry evaluation for arcs/spirals
+        # Linear evaluation along the segment's start heading; arc and spiral
+        # geometries are still approximated by their chord direction here.
        cos_hdg = math.cos(geom_element.hdg)
        sin_hdg = math.sin(geom_element.hdg)
        x_center = geom_element.x + ds * cos_hdg
        y_center = geom_element.y + ds * sin_hdg
-
-        # Apply lateral offset (perpendicular to heading)
        x_metric = x_center - t * sin_hdg
        y_metric = y_center + t * cos_hdg

-        # Transform to pixels
-        return self.coord_transform.metric_to_pixel(x_metric, y_metric)
+        return x_metric, y_metric, geom_element.hdg
+
+    def _corners_to_pixel_points(
+        self,
+        corner_points: List[Tuple[float, float]],
+        anchor_x: float,
+        anchor_y: float,
+        road_hdg: float,
+    ) -> Tuple[List[Tuple[float, float]], Optional[List[Tuple[float, float]]]]:
+        """
+        Convert cornerLocal (u,v) points to pixel and geo coordinates.
+
+        cornerLocal u,v are in road-local frame (u=s-direction, v=t-direction)
+        relative to the object anchor. Rotating by +road_hdg (undoing the
+        exporter's -road_hdg rotation) yields global metric offsets.
+ + Returns: + (pixel_points, geo_points) where geo_points may be None if no + transformer is available for geo conversion. + """ + cos_h = math.cos(road_hdg) + sin_h = math.sin(road_hdg) + pixel_pts: List[Tuple[float, float]] = [] + geo_pts: List[Tuple[float, float]] = [] + has_geo = self.orbit_transformer is not None + + for u, v in corner_points: + # Inverse rotation: local (u,v) → global metric (dx,dy) + dx = cos_h * u - sin_h * v + dy = sin_h * u + cos_h * v + x_m = anchor_x + dx + y_m = anchor_y + dy + px, py = self.coord_transform.metric_to_pixel(x_m, y_m) + pixel_pts.append((px, py)) + if has_geo: + lon, lat = self.orbit_transformer.pixel_to_geo(px, py) + geo_pts.append((lon, lat)) + + return pixel_pts, (geo_pts if has_geo else None) def _convert_road_type(self, odr_road_type: str) -> RoadType: """Convert OpenDrive road type to ORBIT RoadType.""" @@ -1589,27 +1752,11 @@ def _convert_signal_type(self, odr_signal: ODRSignal) -> Tuple[Optional[SignalTy if 'priority' in type_lower: return (SignalType.PRIORITY_ROAD, None) - return (None, None) - - def _convert_object_type(self, odr_object_type: str) -> Optional[ObjectType]: - """Convert OpenDrive object type to ORBIT ObjectType.""" - type_lower = odr_object_type.lower() - - if 'lamp' in type_lower or 'pole' in type_lower: - return ObjectType.LAMPPOST - elif 'guard' in type_lower or 'rail' in type_lower or 'barrier' in type_lower: - return ObjectType.GUARDRAIL - elif 'building' in type_lower or 'house' in type_lower: - return ObjectType.BUILDING - elif 'tree' in type_lower: - if 'conifer' in type_lower or 'pine' in type_lower: - return ObjectType.TREE_CONIFER - else: - return ObjectType.TREE_BROADLEAF - elif 'bush' in type_lower or 'shrub' in type_lower: - return ObjectType.BUSH + return (SignalType.CUSTOM, None) - return None + def _convert_object_type(self, odr_object_type: str, odr_subtype: str = "") -> Optional[ObjectType]: + """Convert OpenDrive object type+subtype to ORBIT ObjectType.""" + return classify_xodr_object_type(odr_object_type, odr_subtype) def _track_unsupported_features(self, result: ImportResult): """Track unsupported OpenDrive features that were skipped.""" diff --git a/orbit/import/opendrive_parser.py b/orbit/import/opendrive_parser.py index 5c24cec..758657e 100644 --- a/orbit/import/opendrive_parser.py +++ b/orbit/import/opendrive_parser.py @@ -280,6 +280,7 @@ class ODRObject: t: Lateral offset from reference line (meters) z_offset: Height above road surface (meters) type: Object type + subtype: Object subtype (e.g. 
'guardrail', 'tree', 'forest') name: Object name orientation: Orientation angle (radians) length: Object length (meters) @@ -290,6 +291,7 @@ class ODRObject: pitch: Pitch angle (radians) roll: Roll angle (radians) validity_length: Validity length along road (meters, for objects spanning distance) + corner_points: List of (u, v) cornerLocal coordinates from outline element is_parking: True if this object is a parking space parking_access: Parking access type (if is_parking is True) parking_restrictions: Parking restrictions text (if is_parking is True) @@ -299,8 +301,10 @@ class ODRObject: t: float z_offset: float = 0.0 type: str = "" + subtype: str = "" name: str = "" orientation: float = 0.0 + orientation_str: str = "none" # Raw OpenDRIVE orientation: "+", "-", "none", or numeric string length: float = 0.0 width: float = 0.0 height: float = 0.0 @@ -309,6 +313,7 @@ class ODRObject: pitch: float = 0.0 roll: float = 0.0 validity_length: Optional[float] = None + corner_points: List[Tuple[float, float]] = field(default_factory=list) # Parking-specific attributes is_parking: bool = False parking_access: str = "standard" @@ -481,6 +486,13 @@ class OpenDriveData: junction_groups: List[ODRJunctionGroup] = field(default_factory=list) +def _parse_object_orientation(value: str) -> float: + """Parse object orientation, handling direction specifiers '+', '-', 'none'.""" + if value in ('+', '-', 'none'): + return 0.0 + return float(value) + + class OpenDriveParser: """Parser for ASAM OpenDrive XML files.""" @@ -926,14 +938,28 @@ def _parse_object(self, object_elem: etree.Element) -> Optional[ODRObject]: parking_access = parking_space_elem.get('access', 'standard') parking_restrictions = parking_space_elem.get('restrictions', '') + # Parse outline/cornerLocal points + corner_points: List[Tuple[float, float]] = [] + outline_elem = object_elem.find('outline') + if outline_elem is not None: + for corner in outline_elem.findall('cornerLocal'): + try: + u = float(corner.get('u', '0')) + v = float(corner.get('v', '0')) + corner_points.append((u, v)) + except (ValueError, TypeError): + pass + return ODRObject( id=object_id, s=float(object_elem.get('s', '0')), t=float(object_elem.get('t', '0')), z_offset=float(object_elem.get('zOffset', '0')), type=object_elem.get('type', ''), + subtype=object_elem.get('subtype', ''), name=object_elem.get('name', ''), - orientation=float(object_elem.get('orientation', '0')), + orientation=_parse_object_orientation(object_elem.get('orientation', '0')), + orientation_str=object_elem.get('orientation', 'none'), length=float(object_elem.get('length', '0')), width=float(object_elem.get('width', '0')), height=float(object_elem.get('height', '0')), @@ -942,6 +968,7 @@ def _parse_object(self, object_elem: etree.Element) -> Optional[ODRObject]: pitch=float(object_elem.get('pitch', '0')), roll=float(object_elem.get('roll', '0')), validity_length=float(object_elem.get('validLength')) if object_elem.get('validLength') else None, + corner_points=corner_points, is_parking=is_parking, parking_access=parking_access, parking_restrictions=parking_restrictions diff --git a/orbit/import/osm_importer.py b/orbit/import/osm_importer.py index b7ea5fb..f859c92 100644 --- a/orbit/import/osm_importer.py +++ b/orbit/import/osm_importer.py @@ -1092,9 +1092,13 @@ def _attach_signals_to_roads(self, osm_data: OSMData, options: ImportOptions) -> if not road_ids: continue - # Attach to the first road found (usually there's only one for traffic signals) - # If multiple roads share this node (junction), we pick 
the first one - road_id = road_ids[0] + # Attach to the road containing this node. If multiple roads share the node + # (e.g. at a junction), prefer regular roads over connecting roads. + candidate_ids = [ + rid for rid in road_ids + if not (self.project.get_road(rid) or Road()).junction_id + ] + road_id = candidate_ids[0] if candidate_ids else road_ids[0] road = self.project.get_road(road_id) if not road or not road.centerline_id: continue diff --git a/orbit/import/osm_to_orbit.py b/orbit/import/osm_to_orbit.py index 1491fe6..79b0146 100644 --- a/orbit/import/osm_to_orbit.py +++ b/orbit/import/osm_to_orbit.py @@ -104,9 +104,11 @@ def calculate_bbox_from_image(image_width: int, image_height: int, """ Calculate bounding box for OSM query from image dimensions. - Uses control points to define the area, with optional image corner inclusion - for better coverage. This prevents issues with homography extrapolation - far from control points. + Uses the geographic extent of the control points as a reliable anchor, + then extends it to cover the full image using the affine sub-transformer + (linear extrapolation) and a scale-based estimate. The affine is used + because it extrapolates predictably outside the CP convex hull, unlike + homography which can be non-monotonic far from training data. Args: image_width: Width of image in pixels @@ -117,51 +119,46 @@ def calculate_bbox_from_image(image_width: int, image_height: int, Returns: Tuple of (min_lat, min_lon, max_lat, max_lon) """ - # Get control point locations (these are known to be accurate) + # Get control point locations (known to be accurate) control_points = transformer.all_control_points control_lons = [cp.longitude for cp in control_points] control_lats = [cp.latitude for cp in control_points] - # Start with control point bounds + # Start with control point geographic bounds min_lon, max_lon = min(control_lons), max(control_lons) min_lat, max_lat = min(control_lats), max(control_lats) - # Try to include image corners if they're reasonable - # (i.e., not too far from control point area due to extrapolation) - corners_pixel = [ - (0, 0), - (image_width, 0), - (image_width, image_height), - (0, image_height) + # Use the affine sub-transformer for extrapolation so that homography's + # non-monotonic behaviour outside the CP convex hull does not produce + # unreliable corner coordinates. 
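The failure mode this comment guards against is easy to demonstrate. A minimal numpy sketch, with matrices invented for illustration (not ORBIT's fitted transformers): the homography's projective denominator shrinks toward zero outside the fitted region, so extrapolated coordinates blow up and eventually change sign, while zeroing the perspective row (the affine case) keeps extrapolation linear.

import numpy as np

# Synthetic pixel->geo homography; the perspective row makes w = 1 - 2e-4 * y.
H = np.array([[0.01,  0.0,   12.94],
              [0.0,  -0.01,  57.72],
              [0.0,  -2e-4,  1.0]])
A = H.copy()
A[2] = [0.0, 0.0, 1.0]  # affine approximation: w is constant

def apply(M, x, y):
    p = M @ np.array([x, y, 1.0])
    return p[0] / p[2], p[1] / p[2]

for y in (0, 2000, 4000, 4900):  # w -> 0 as y approaches 5000
    print(y, apply(H, 0.0, y), apply(A, 0.0, y))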
+ extrap_transformer = getattr(transformer, '_affine', transformer) + + # Sample corners and edge midpoints for better image boundary coverage + sample_pixels = [ + (0, 0), (image_width, 0), + (image_width, image_height), (0, image_height), + (image_width // 2, 0), (image_width // 2, image_height), + (0, image_height // 2), (image_width, image_height // 2), ] - # Calculate control point extent in pixels - control_pixels = [(cp.pixel_x, cp.pixel_y) for cp in control_points] - cp_min_x = min(x for x, y in control_pixels) - cp_max_x = max(x for x, y in control_pixels) - cp_min_y = min(y for x, y in control_pixels) - cp_max_y = max(y for x, y in control_pixels) - cp_extent = max(cp_max_x - cp_min_x, cp_max_y - cp_min_y) - - # Only include corners that are within reasonable distance of control points - # (within 2x the control point extent to avoid bad extrapolation) - max_distance = cp_extent * 2.0 - cp_center_x = (cp_min_x + cp_max_x) / 2 - cp_center_y = (cp_min_y + cp_max_y) / 2 - - for corner_x, corner_y in corners_pixel: - distance = ((corner_x - cp_center_x)**2 + (corner_y - cp_center_y)**2)**0.5 - if distance <= max_distance: - try: - lon, lat = transformer.pixel_to_geo(corner_x, corner_y) - # Expand bounds to include this corner + # Geographic centroid of CPs — used to reject wildly extrapolated values + cp_center_lon = sum(control_lons) / len(control_lons) + cp_center_lat = sum(control_lats) / len(control_lats) + cp_geo_span = max(max_lon - min_lon, max_lat - min_lat) + # Allow extrapolated points up to 5× the CP geographic span from the centroid + max_geo_dist = cp_geo_span * 5.0 + + for px, py in sample_pixels: + try: + lon, lat = extrap_transformer.pixel_to_geo(px, py) + dist = math.sqrt((lon - cp_center_lon) ** 2 + (lat - cp_center_lat) ** 2) + if dist <= max_geo_dist: min_lon = min(min_lon, lon) max_lon = max(max_lon, lon) min_lat = min(min_lat, lat) max_lat = max(max_lat, lat) - except Exception: - # Skip corners that fail to transform - pass + except Exception: + pass # Add buffer lon_buffer = (max_lon - min_lon) * (buffer_percent / 100.0) diff --git a/orbit/models/object.py b/orbit/models/object.py index 178f591..937d3c6 100644 --- a/orbit/models/object.py +++ b/orbit/models/object.py @@ -212,6 +212,8 @@ def __init__( # OpenDRIVE orientation angles for round-trip preservation self.pitch: float = 0.0 # Pitch angle in radians self.roll: float = 0.0 # Roll angle in radians + # OpenDRIVE directional orientation marker for round-trip ("+", "-", "none") + self.odr_orientation: str = "none" # Original OSM tags for round-trip export self.osm_tags: Optional[Dict[str, str]] = None @@ -298,6 +300,8 @@ def to_dict(self) -> dict: data['pitch'] = self.pitch if self.roll != 0.0: data['roll'] = self.roll + if self.odr_orientation != "none": + data['odr_orientation'] = self.odr_orientation if self.osm_tags: data['osm_tags'] = self.osm_tags return data @@ -339,6 +343,8 @@ def from_dict(cls, data: dict) -> 'RoadObject': # OpenDRIVE orientation angles obj.pitch = data.get('pitch', 0.0) obj.roll = data.get('roll', 0.0) + # OpenDRIVE directional orientation marker + obj.odr_orientation = data.get('odr_orientation', 'none') # OSM tags obj.osm_tags = data.get('osm_tags') diff --git a/orbit/models/project.py b/orbit/models/project.py index 592f02e..2117483 100644 --- a/orbit/models/project.py +++ b/orbit/models/project.py @@ -77,6 +77,94 @@ def from_dict(cls, data: Dict[str, Any]) -> 'ControlPoint': ) +@dataclass +class DroneMetadata: + """Flight-log statistics for the video sequence used to create 
this project. + + Extracted from a drone log by FieldDataLab's extract_flightlog_params tool. + All angular values in degrees; position in WGS84; altitude in metres AGL. + hfov_deg is the horizontal field of view at native sensor width; present only + when the drone/lens combination is known to the log-extraction tool. + """ + latitude: float + longitude: float + alt_agl: float + gimbal_yaw: float # world-frame compass bearing, 0=N, CW positive + gimbal_pitch: float # below horizontal: 0=horizontal, -90=nadir + gimbal_roll: float = 0.0 + drone_type: Optional[str] = None + lens_type: str = "standard" + hfov_deg: Optional[float] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + d: Dict[str, Any] = { + 'latitude': self.latitude, + 'longitude': self.longitude, + 'alt_agl': self.alt_agl, + 'gimbal_yaw': self.gimbal_yaw, + 'gimbal_pitch': self.gimbal_pitch, + 'gimbal_roll': self.gimbal_roll, + 'drone_type': self.drone_type, + 'lens_type': self.lens_type, + } + if self.hfov_deg is not None: + d['hfov_deg'] = self.hfov_deg + return d + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'DroneMetadata': + """Create from dictionary.""" + return cls( + latitude=data['latitude'], + longitude=data['longitude'], + alt_agl=data['alt_agl'], + gimbal_yaw=data['gimbal_yaw'], + gimbal_pitch=data['gimbal_pitch'], + gimbal_roll=data.get('gimbal_roll', 0.0), + drone_type=data.get('drone_type'), + lens_type=data.get('lens_type', 'standard'), + hfov_deg=data.get('hfov_deg'), + ) + + @classmethod + def from_video_stats(cls, stats: Dict[str, Any], sequence_id: int = 0) -> 'DroneMetadata': + """Parse a video_stats.json dict (from extract_flightlog_params). + + Uses mean values from the specified sequence. Raises ValueError if + the required OSD/gimbal fields are absent. 
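A minimal usage sketch, assuming a video_stats.json written by the extract_flightlog_params tool (the file path here is illustrative):

import json

from orbit.models.project import DroneMetadata

with open("flight_logs/video_stats.json") as fh:
    stats = json.load(fh)

md = DroneMetadata.from_video_stats(stats, sequence_id=0)
print(md.latitude, md.longitude, md.alt_agl)  # mean OSD position and altitude
print(md.gimbal_yaw, md.gimbal_pitch)         # mean gimbal angles, degrees
print(md.hfov_deg)                            # None unless the camera is known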
+ """ + sequences = stats.get('sequences', []) + if not sequences: + raise ValueError("video_stats JSON contains no sequences") + seq = next((s for s in sequences if s['sequence_id'] == sequence_id), sequences[0]) + osd = seq['stats']['osd'] + gimbal = seq['stats']['gimbal'] + + def mean(section: Dict[str, Any], key: str) -> float: + try: + return float(section[key]['mean']) + except (KeyError, TypeError, ValueError) as exc: + raise ValueError(f"Missing field {key!r} in video_stats stats") from exc + + hfov_deg: Optional[float] = None + camera = stats.get('camera') + if isinstance(camera, dict) and 'hfov_deg' in camera: + hfov_deg = float(camera['hfov_deg']) + + return cls( + latitude=mean(osd, 'latitude'), + longitude=mean(osd, 'longitude'), + alt_agl=mean(osd, 'height_agl'), + gimbal_yaw=mean(gimbal, 'yaw'), + gimbal_pitch=mean(gimbal, 'pitch'), + gimbal_roll=mean(gimbal, 'roll'), + drone_type=stats.get('drone_type'), + lens_type=stats.get('lens_type', 'standard'), + hfov_deg=hfov_deg, + ) + + @dataclass class Project: """ @@ -108,7 +196,7 @@ class Project: parking_spaces: List[ParkingSpace] = field(default_factory=list) control_points: List[ControlPoint] = field(default_factory=list) right_hand_traffic: bool = True # Default to right-hand traffic - transform_method: str = 'homography' # Default to homography for drone images + transform_method: str = 'homography' # 'affine', 'homography', or 'drone_assisted' country_code: str = 'se' # Default to Sweden map_name: str = '' # Name for OpenDrive export (defaults to image filename when loaded) openstreetmap_used: bool = False # Flag for OpenStreetMap attribution @@ -135,6 +223,8 @@ class Project: synthetic_canvas_width: Optional[int] = None # Synthetic canvas width in pixels (no real image) synthetic_canvas_height: Optional[int] = None # Synthetic canvas height in pixels (no real image) transform_adjustment: Optional[Dict[str, float]] = None # Persisted geo-alignment adjustment + drone_metadata: Optional['DroneMetadata'] = None # Flight log data for drone-assisted georef + source_files: List[Dict[str, str]] = field(default_factory=list) # Imported input files for provenance tracking metadata: Dict[str, Any] = field(default_factory=dict) # ID counters for sequential ID generation (one per entity type) @@ -482,7 +572,17 @@ def link_lane_connections_to_connecting_roads(self) -> int: if lc.connecting_lane_id is None and lc.connecting_road_id: cr = self.get_road(lc.connecting_road_id) if cr and cr.is_connecting_road: - if lc.from_lane_id < 0: + if cr.cr_lane_count_left > 0 and cr.cr_lane_count_right > 0: + # Bidirectional CR: use predecessor/successor orientation to + # assign direction, not from_lane_id sign. Traffic flowing + # pred→succ uses right lanes (-1); succ→pred uses left lanes (1). 
+ if lc.from_road_id == cr.predecessor_id: + lc.connecting_lane_id = -1 + elif lc.from_road_id == cr.successor_id: + lc.connecting_lane_id = 1 + else: + lc.connecting_lane_id = -1 if lc.from_lane_id < 0 else 1 + elif lc.from_lane_id < 0: lc.connecting_lane_id = -1 if cr.cr_lane_count_right > 0 else 1 else: lc.connecting_lane_id = 1 if cr.cr_lane_count_left > 0 else -1 @@ -1636,6 +1736,8 @@ def to_dict(self) -> Dict[str, Any]: 'synthetic_canvas_width': self.synthetic_canvas_width, 'synthetic_canvas_height': self.synthetic_canvas_height, 'transform_adjustment': self.transform_adjustment, + 'drone_metadata': self.drone_metadata.to_dict() if self.drone_metadata else None, + 'source_files': self.source_files, 'id_counters': { 'polyline': self._next_polyline_id, 'road': self._next_road_id, @@ -1696,6 +1798,11 @@ def from_dict(cls, data: Dict[str, Any]) -> 'Project': synthetic_canvas_width=data.get('synthetic_canvas_width'), synthetic_canvas_height=data.get('synthetic_canvas_height'), transform_adjustment=data.get('transform_adjustment'), + drone_metadata=( + DroneMetadata.from_dict(data['drone_metadata']) + if data.get('drone_metadata') else None + ), + source_files=data.get('source_files', []), metadata=data.get('metadata', {}) ) @@ -1753,6 +1860,7 @@ def clear(self) -> None: self.parking_spaces.clear() self.control_points.clear() self.image_path = None + self.source_files = [] self.metadata = { 'version': _get_version(), 'created': datetime.now().isoformat(), diff --git a/orbit/utils/__init__.py b/orbit/utils/__init__.py index a25d220..512594a 100644 --- a/orbit/utils/__init__.py +++ b/orbit/utils/__init__.py @@ -3,6 +3,7 @@ from .coordinate_transform import ( AffineTransformer, CoordinateTransformer, + DroneAssistedTransformer, HomographyTransformer, HybridTransformer, TransformMethod, @@ -21,6 +22,7 @@ __all__ = [ 'CoordinateTransformer', 'AffineTransformer', + 'DroneAssistedTransformer', 'HomographyTransformer', 'HybridTransformer', 'TransformMethod', diff --git a/orbit/utils/camera_model.py b/orbit/utils/camera_model.py new file mode 100644 index 0000000..5ac8ebd --- /dev/null +++ b/orbit/utils/camera_model.py @@ -0,0 +1,290 @@ +"""Drone camera model for physically-derived georeferencing. + +Computes a homography matrix from drone flight parameters (position, altitude, +gimbal orientation) instead of fitting it to ground control points alone. Works +even when GCPs are nearly collinear (e.g., all along a single road). + +Coordinate conventions: +- World frame: ENU (East, North, Up), origin = drone ground nadir +- Gimbal yaw: 0=North, positive=CW (East), negative=CCW (West), world-frame absolute + (DJI reports magnetic heading; this module corrects to true north automatically) +- Gimbal pitch: 0=horizontal, -90=nadir (DJI convention) +- Result: transform_matrix maps pixel [u,v,1] → ENU [E,N,w] relative to nadir +""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import numpy as np +from scipy.optimize import minimize_scalar + + +@dataclass +class DroneMetadata: + """Flight statistics for a single video sequence, extracted from a drone log. + + All angular values in degrees. Position in WGS84. Altitude in meters AGL. + hfov_deg is the horizontal field of view at native sensor width — read from + the drone log tool's camera lookup table when available. 
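The pinhole relation implied by hfov_deg is f = (W / 2) / tan(hfov / 2), which is also how DroneCameraModel resolves the focal length below. A quick numeric check:

import math

def focal_px(image_width: int, hfov_deg: float) -> float:
    # Half the sensor width subtends half the horizontal field of view.
    return (image_width / 2.0) / math.tan(math.radians(hfov_deg / 2.0))

print(focal_px(4000, 90.0))  # 2000.0 px (tan 45 deg = 1)
print(focal_px(3840, 71.5))  # ~2667 px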
+ """ + latitude: float + longitude: float + alt_agl: float + gimbal_yaw: float # world-frame compass bearing, 0=N, CW positive + gimbal_pitch: float # below horizontal: 0=horizontal, -90=nadir + gimbal_roll: float = 0.0 + drone_type: Optional[str] = None + lens_type: str = "standard" + hfov_deg: Optional[float] = None # horizontal FOV in degrees; None = solve from GCPs + + +def get_magnetic_declination(latitude: float, longitude: float, altitude_m: float = 0.0) -> float: + """Return magnetic declination in degrees at the given WGS84 position. + + Positive = east declination (magnetic north is east of true north). + Uses the geomag WMM implementation; falls back to 0.0 if unavailable. + """ + try: + import geomag # type: ignore[import-untyped] + return float(geomag.declination(latitude, longitude, altitude_m / 1000.0)) + except Exception: + return 0.0 + + +def _rotation_matrix(yaw_deg: float, pitch_deg: float, roll_deg: float) -> np.ndarray: + """Build a 3×3 rotation matrix mapping world ENU → camera frame. + + Rows are the camera X (right), Y (down), Z (forward) axes in world ENU. + Gimbal convention: yaw=compass bearing (0=N, CW+), pitch=below horizontal. + """ + theta = math.radians(yaw_deg) # compass bearing of camera + phi = math.radians(-pitch_deg) # tilt below horizontal (phi>0 = looking down) + rho = math.radians(roll_deg) # roll + + # Camera axes in world ENU [E, N, U]: + # Right (X): perpendicular to viewing direction, 90° CW from bearing in horizontal plane + x_cam = np.array([math.cos(theta), -math.sin(theta), 0.0]) + # Forward (Z): bearing θ, tilted φ below horizontal + z_cam = np.array([math.sin(theta) * math.cos(phi), + math.cos(theta) * math.cos(phi), + -math.sin(phi)]) + # Down (Y): complete right-handed frame (Z × X = Y) + y_cam = np.cross(z_cam, x_cam) + y_cam /= np.linalg.norm(y_cam) # normalise for numerical safety + + # Apply roll around Z-axis (camera forward) + if abs(roll_deg) > 1e-6: + cr, sr = math.cos(rho), math.sin(rho) + x_rolled = cr * x_cam - sr * y_cam + y_rolled = sr * x_cam + cr * y_cam + x_cam, y_cam = x_rolled, y_rolled + + R = np.stack([x_cam, y_cam, z_cam], axis=0) # shape (3, 3) + return R + + +def _build_projection_matrix( + yaw_deg: float, + pitch_deg: float, + roll_deg: float, + alt_agl: float, + focal_length_px: float, + cx: float, + cy: float, +) -> np.ndarray: + """Build 3×3 matrix M that projects ENU ground point [E, N, 1] → pixel [u*w, v*w, w]. + + M = K × [r1, r2, t] where r1, r2 are columns of R and t = -h × R[:,2]. + inverse(M) maps pixel → ENU ground = the transform_matrix for CoordinateTransformer. + """ + R = _rotation_matrix(yaw_deg, pitch_deg, roll_deg) + K = np.array([[focal_length_px, 0, cx], + [0, focal_length_px, cy], + [0, 0, 1]], dtype=np.float64) + + # Translation: camera at [0, 0, alt_agl], so t = R × (-C) = -alt_agl × R[:,2] + t = -alt_agl * R[:, 2] + + # 3×3 matrix [r1, r2, t] = first two cols of R + translation + M_world = np.column_stack([R[:, 0], R[:, 1], t]) # shape (3, 3) + return K @ M_world + + +class DroneCameraModel: + """Compute pixel↔geographic transform from drone flight parameters. + + The model assumes a flat ground plane at the drone's ground level. + Reference origin (local ENU) is the drone's ground nadir point. + + Heading refinement is applied in two stages: + 1. Magnetic declination: auto-computed from lat/lon (DJI logs magnetic heading). + 2. GCP yaw refinement: 1-D optimisation over residual yaw offset if GCPs provided. + + Focal length is resolved in priority order: + 1. 
Explicit `focal_length_px` argument + 2. Computed from `metadata.hfov_deg` and image dimensions + 3. Solved from GCPs via 1-D least-squares (requires ≥ 2 GCPs) + Raises ValueError if none of the above is available. + """ + + def __init__( + self, + metadata: DroneMetadata, + image_width: int, + image_height: int, + control_points: Optional[list] = None, + focal_length_px: Optional[float] = None, + ): + self.metadata = metadata + self.image_width = image_width + self.image_height = image_height + self.cx = image_width / 2.0 + self.cy = image_height / 2.0 + + # Stage 1: resolve focal length + f = focal_length_px + if f is None and metadata.hfov_deg is not None: + f = (image_width / 2.0) / math.tan(math.radians(metadata.hfov_deg / 2.0)) + if f is None and control_points: + f = self._solve_focal_length(control_points) + if f is None: + raise ValueError( + "Cannot determine focal length: provide hfov_deg in drone log or at " + "least 2 control points." + ) + self.focal_length_px: float = f + + # Stage 2: correct magnetic → true heading via declination + self.declination_deg: float = get_magnetic_declination( + metadata.latitude, metadata.longitude, metadata.alt_agl + ) + corrected_yaw = metadata.gimbal_yaw + self.declination_deg + + # Re-solve focal length with declination-corrected yaw if it was GCP-derived + if focal_length_px is None and metadata.hfov_deg is None and control_points: + f = self._solve_focal_length(control_points, yaw_override=corrected_yaw) + self.focal_length_px = f + + # Stage 3: GCP yaw refinement — find residual offset after declination correction + self.yaw_refinement_deg: float = 0.0 + if control_points and len(control_points) >= 2: + self.yaw_refinement_deg = self._refine_yaw(corrected_yaw, control_points) + + self.effective_yaw: float = corrected_yaw + self.yaw_refinement_deg + + self._M = _build_projection_matrix( + self.effective_yaw, metadata.gimbal_pitch, metadata.gimbal_roll, + metadata.alt_agl, f, self.cx, self.cy, + ) + self.transform_matrix = np.linalg.inv(self._M) # pixel → ENU + self.projection_matrix = self._M # ENU → pixel + + def _refine_yaw(self, base_yaw: float, control_points: list) -> float: + """Find residual yaw offset (on top of base_yaw) that minimises GCP reprojection. + + Returns the offset in degrees. Searches ±30° around base_yaw. + """ + cos_lat = math.cos(math.radians(self.metadata.latitude)) + R_earth = 6371000.0 + + def _rms(offset: float) -> float: + M = _build_projection_matrix( + base_yaw + offset, self.metadata.gimbal_pitch, self.metadata.gimbal_roll, + self.metadata.alt_agl, self.focal_length_px, self.cx, self.cy, + ) + total = 0.0 + for cp in control_points: + east = (cp.longitude - self.metadata.longitude) * R_earth * cos_lat * math.pi / 180.0 + north = (cp.latitude - self.metadata.latitude) * R_earth * math.pi / 180.0 + ph = M @ np.array([east, north, 1.0]) + if abs(ph[2]) < 1e-10: + return 1e12 + u, v = ph[0] / ph[2], ph[1] / ph[2] + total += (u - cp.pixel_x) ** 2 + (v - cp.pixel_y) ** 2 + return total + + result = minimize_scalar(_rms, bounds=(-30.0, 30.0), method='bounded') + return float(result.x) + + def _solve_focal_length(self, control_points: list, yaw_override: Optional[float] = None) -> float: + """Find focal length that minimises GCP reprojection error (1-D optimisation). + + Uses the ENU reference at drone nadir. Uses yaw_override if provided (for + declination-corrected heading), else falls back to metadata.gimbal_yaw. 
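Both GCP objectives in this class convert latitude/longitude differences into local ENU metres with the same equirectangular approximation. Isolated into a sketch:

import math

R_EARTH = 6_371_000.0  # mean Earth radius, metres

def latlon_to_local_enu(lat, lon, ref_lat, ref_lon):
    # One degree of latitude spans R * pi / 180 metres;
    # a degree of longitude shrinks by cos(latitude).
    east = (lon - ref_lon) * R_EARTH * math.cos(math.radians(ref_lat)) * math.pi / 180.0
    north = (lat - ref_lat) * R_EARTH * math.pi / 180.0
    return east, north

print(latlon_to_local_enu(58.7, 12.0, 57.7, 12.0))  # (0.0, ~111195 m): 1 deg latitude
print(latlon_to_local_enu(57.7, 13.0, 57.7, 12.0))  # (~59417 m, 0.0): 1 deg longitude at 57.7 N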
+ """ + if len(control_points) < 2: + raise ValueError("Need at least 2 GCPs to solve focal length.") + + yaw = yaw_override if yaw_override is not None else self.metadata.gimbal_yaw + cos_lat = math.cos(math.radians(self.metadata.latitude)) + R_earth = 6371000.0 + + def reprojection_error(f_val: float) -> float: + M = _build_projection_matrix( + yaw, self.metadata.gimbal_pitch, + self.metadata.gimbal_roll, self.metadata.alt_agl, + f_val, self.cx, self.cy, + ) + total = 0.0 + for cp in control_points: + east = (cp.longitude - self.metadata.longitude) * R_earth * cos_lat * math.pi / 180.0 + north = (cp.latitude - self.metadata.latitude) * R_earth * math.pi / 180.0 + g = np.array([east, north, 1.0]) + ph = M @ g + if abs(ph[2]) < 1e-10: + return 1e12 + u = ph[0] / ph[2] + v = ph[1] / ph[2] + total += (u - cp.pixel_x) ** 2 + (v - cp.pixel_y) ** 2 + return total + + result = minimize_scalar(reprojection_error, bounds=(500, 8000), method='bounded') + return float(result.x) + + def pixel_to_enu(self, pixel_x: float, pixel_y: float) -> Tuple[float, float]: + """Convert image pixel to ENU offset (meters) from drone nadir.""" + p = np.array([pixel_x, pixel_y, 1.0]) + g = self.transform_matrix @ p + return g[0] / g[2], g[1] / g[2] + + def enu_to_pixel(self, east: float, north: float) -> Tuple[float, float]: + """Convert ENU offset (meters) from drone nadir to image pixel.""" + g = np.array([east, north, 1.0]) + p = self.projection_matrix @ g + return p[0] / p[2], p[1] / p[2] + + def estimate_heading_from_gcps( + self, control_points: list + ) -> Tuple[float, float]: + """Estimate gimbal heading that best fits the GCPs. + + Returns (best_heading_deg, rmse_pixels). Useful as a cross-check + against the value recorded in the drone log. + """ + if len(control_points) < 2: + raise ValueError("Need at least 2 GCPs to estimate heading.") + + cos_lat = math.cos(math.radians(self.metadata.latitude)) + R_earth = 6371000.0 + + def _rms_for_heading(yaw_deg: float) -> float: + M = _build_projection_matrix( + yaw_deg, self.metadata.gimbal_pitch, self.metadata.gimbal_roll, + self.metadata.alt_agl, self.focal_length_px, self.cx, self.cy, + ) + total = 0.0 + for cp in control_points: + east = (cp.longitude - self.metadata.longitude) * R_earth * cos_lat * math.pi / 180.0 + north = (cp.latitude - self.metadata.latitude) * R_earth * math.pi / 180.0 + ph = M @ np.array([east, north, 1.0]) + if abs(ph[2]) < 1e-10: + return 1e12 + u, v = ph[0] / ph[2], ph[1] / ph[2] + total += (u - cp.pixel_x) ** 2 + (v - cp.pixel_y) ** 2 + return math.sqrt(total / len(control_points)) + + result = minimize_scalar(_rms_for_heading, bounds=(-180, 180), method='bounded') + best_yaw = float(result.x) + rmse = float(result.fun) + return best_yaw, rmse diff --git a/orbit/utils/coordinate_transform.py b/orbit/utils/coordinate_transform.py index 6539804..3f89559 100644 --- a/orbit/utils/coordinate_transform.py +++ b/orbit/utils/coordinate_transform.py @@ -28,7 +28,7 @@ from .logging_config import get_logger if TYPE_CHECKING: - from orbit.models.project import ControlPoint + from orbit.models.project import ControlPoint, DroneMetadata logger = get_logger(__name__) @@ -37,6 +37,7 @@ class TransformMethod(Enum): """Transformation method for georeferencing.""" AFFINE = auto() HOMOGRAPHY = auto() + DRONE_ASSISTED = auto() @dataclass @@ -1025,6 +1026,8 @@ def __init__(self, control_points: List['ControlPoint'], self.all_control_points = control_points self._export_proj_string = export_proj_string self._export_proj = None + if export_proj_string: + 
self._export_proj = Proj(export_proj_string) # Create both sub-transformers from the same control points self._homography = HomographyTransformer( @@ -1224,6 +1227,107 @@ def set_adjustment(self, adjustment): self.adjustment = adjustment +class DroneAssistedTransformer(CoordinateTransformer): + """Physically-derived homography from drone flight parameters + optional GCP refinement. + + Uses a pinhole camera model (position, altitude, gimbal angles, focal length) + to compute a homography for the ground plane. GCPs refine focal length when + hfov_deg is not available from the drone log. + + Reference origin is the drone's ground nadir (lat_drone, lon_drone). + The transform_matrix maps pixel [u,v,1] → local ENU [E,N,w] in metres. + """ + + def __init__( + self, + metadata: 'DroneMetadata', + control_points: List['ControlPoint'], + image_width: int, + image_height: int, + use_validation: bool = True, + export_proj_string: Optional[str] = None, + ): + from .camera_model import DroneCameraModel + + super().__init__(control_points, use_validation, export_proj_string=export_proj_string) + + # Override reference point to drone nadir + self.reference_lat = metadata.latitude + self.reference_lon = metadata.longitude + + self._model = DroneCameraModel( + metadata=metadata, + image_width=image_width, + image_height=image_height, + control_points=self.training_points if self.training_points else None, + ) + + # transform_matrix: pixel → ENU (local metres relative to drone nadir) + # inverse_matrix: ENU → pixel + self.transform_matrix = self._model.transform_matrix + self.inverse_matrix = self._model.projection_matrix + + self.compute_reprojection_error() + if self.validation_points: + self.compute_validation_error() + + def compute_transformation(self): + """Delegated to DroneCameraModel in __init__.""" + + def pixel_to_geo(self, pixel_x: float, pixel_y: float) -> Tuple[float, float]: + """Convert pixel to geographic coordinates (longitude, latitude).""" + if self.transform_matrix is None: + raise RuntimeError("Transformation not initialized") + + if self.adjustment is not None: + pixel_x, pixel_y = self.adjustment.apply_inverse_to_point(pixel_x, pixel_y) + + p = np.array([pixel_x, pixel_y, 1.0]) + g = self.transform_matrix @ p + east = g[0] / g[2] + north = g[1] / g[2] + lat, lon = self.meters_to_latlon(east, north) + return lon, lat + + def geo_to_pixel(self, longitude: float, latitude: float) -> Tuple[float, float]: + """Convert geographic coordinates to pixel coordinates.""" + if self.inverse_matrix is None: + raise RuntimeError("Transformation not initialized") + + east, north = self.latlon_to_meters(latitude, longitude) + g = np.array([east, north, 1.0]) + p = self.inverse_matrix @ g + pixel_x = p[0] / p[2] + pixel_y = p[1] / p[2] + + if self.adjustment is not None: + pixel_x, pixel_y = self.adjustment.apply_to_point(pixel_x, pixel_y) + + return pixel_x, pixel_y + + def geo_to_pixel_unadjusted(self, longitude: float, latitude: float) -> Tuple[float, float]: + """Convert geographic coordinates to pixel coordinates without adjustment.""" + east, north = self.latlon_to_meters(latitude, longitude) + g = np.array([east, north, 1.0]) + p = self.inverse_matrix @ g + return p[0] / p[2], p[1] / p[2] + + def get_scale_factor(self) -> Tuple[float, float]: + """Approximate metres-per-pixel scale at image centre.""" + cx, cy = self._model.cx, self._model.cy + offset = 10.0 + mx1, my1 = self._model.pixel_to_enu(cx - offset, cy) + mx2, my2 = self._model.pixel_to_enu(cx + offset, cy) + scale_x = math.sqrt((mx2 
- mx1) ** 2 + (my2 - my1) ** 2) / (2 * offset) + mx1, my1 = self._model.pixel_to_enu(cx, cy - offset) + mx2, my2 = self._model.pixel_to_enu(cx, cy + offset) + scale_y = math.sqrt((mx2 - mx1) ** 2 + (my2 - my1) ** 2) / (2 * offset) + return scale_x, scale_y + + def _set_reference_point(self): + """Overridden: reference point is drone nadir, set in __init__.""" + + def create_transformer( control_points: List['ControlPoint'], method: Union[str, TransformMethod] = TransformMethod.HOMOGRAPHY, @@ -1231,6 +1335,7 @@ def create_transformer( export_proj_string: Optional[str] = None, image_width: int = 0, image_height: int = 0, + drone_metadata: Optional['DroneMetadata'] = None, ) -> Optional[CoordinateTransformer]: """ Create a coordinate transformer from control points. @@ -1238,20 +1343,46 @@ def create_transformer( Args: control_points: List of control points method: Transformation method - either TransformMethod enum or string - ('affine' or 'homography') + ('affine', 'homography', or 'drone_assisted') use_validation: If True, separate validation points from training export_proj_string: If set, latlon_to_meters/meters_to_latlon use this pyproj projection instead of equirectangular approximation. + drone_metadata: Required when method is 'drone_assisted'. Returns: CoordinateTransformer if successful, None if insufficient points """ - if not control_points: - return None - # Convert string to enum if needed if isinstance(method, str): - method = TransformMethod.HOMOGRAPHY if method == 'homography' else TransformMethod.AFFINE + if method == 'drone_assisted': + method = TransformMethod.DRONE_ASSISTED + elif method == 'homography': + method = TransformMethod.HOMOGRAPHY + else: + method = TransformMethod.AFFINE + + if method == TransformMethod.DRONE_ASSISTED: + if drone_metadata is None: + logger.error("drone_assisted method requires drone_metadata") + return None + if image_width <= 0 or image_height <= 0: + logger.error("drone_assisted method requires image_width and image_height") + return None + try: + return DroneAssistedTransformer( + metadata=drone_metadata, + control_points=control_points, + image_width=image_width, + image_height=image_height, + use_validation=use_validation, + export_proj_string=export_proj_string, + ) + except (ValueError, np.linalg.LinAlgError) as e: + logger.error(f"Error creating DroneAssistedTransformer: {e}") + return None + + if not control_points: + return None # Separate training and validation points if use_validation: diff --git a/orbit/utils/geometry.py b/orbit/utils/geometry.py index 51420d2..fe20a9a 100644 --- a/orbit/utils/geometry.py +++ b/orbit/utils/geometry.py @@ -9,6 +9,8 @@ from typing import TYPE_CHECKING, List, Optional, Tuple if TYPE_CHECKING: + from orbit.models.project import Project + from orbit.models.road import Road from orbit.utils.coordinate_transform import CoordinateTransformer @@ -2120,3 +2122,136 @@ def shorten_geo_points( result = result[:-(segment_idx + 1)] + [new_end] return result + + +def fit_smooth_curve_to_polyline( + points: List[Tuple[float, float]], + start_tangent_rad: float, + end_tangent_rad: float, + tangent_scale: float = 1.0, + num_output_points: Optional[int] = None, +) -> List[Tuple[float, float]]: + """Redistribute polyline points along a smooth cubic Bezier curve. + + Fixes first and last points; replaces all intermediate points with + arc-length-uniform samples so the curve is G1-continuous at both ends. + + Args: + points: Current pixel-space polyline (≥ 2 points). 
+ start_tangent_rad: Desired heading at the start (radians, screen coords). + end_tangent_rad: Desired heading at the end (radians, screen coords). + tangent_scale: Controls how strongly the tangents pull. 1.0 = chord/3. + num_output_points: Number of output points. Defaults to len(points). + Use a larger value to reduce miter spike artifacts in offset polylines. + + Returns: + New list of num_output_points (x, y) tuples. + """ + if len(points) < 2: + return list(points) + if len(points) == 2: + return [points[0], points[-1]] + + p0 = points[0] + p3 = points[-1] + n = num_output_points if num_output_points is not None else len(points) + + # Bezier handle length ≈ 1/3 of the chord, scaled by tangent_scale + chord = math.sqrt((p3[0] - p0[0]) ** 2 + (p3[1] - p0[1]) ** 2) + handle = chord / 3.0 * tangent_scale + + p1 = ( + p0[0] + math.cos(start_tangent_rad) * handle, + p0[1] + math.sin(start_tangent_rad) * handle, + ) + p2 = ( + p3[0] - math.cos(end_tangent_rad) * handle, + p3[1] - math.sin(end_tangent_rad) * handle, + ) + ctrl = [p0, p1, p2, p3] + + # Over-sample the Bezier then re-parameterize by arc length + oversample = max(200, n * 20) + dense = sample_bezier(ctrl, oversample) + + # Build cumulative arc-length table + arc = [0.0] + for i in range(1, len(dense)): + dx = dense[i][0] - dense[i - 1][0] + dy = dense[i][1] - dense[i - 1][1] + arc.append(arc[-1] + math.sqrt(dx * dx + dy * dy)) + total = arc[-1] + if total < 1e-9: + return list(points) + + # Sample at n arc-length-uniform positions + result = [p0] + for k in range(1, n - 1): + target = total * k / (n - 1) + # Binary search for the right dense-sample interval + lo, hi = 0, len(arc) - 1 + while lo < hi - 1: + mid = (lo + hi) // 2 + if arc[mid] < target: + lo = mid + else: + hi = mid + t = (target - arc[lo]) / (arc[hi] - arc[lo]) if arc[hi] > arc[lo] else 0.0 + x = dense[lo][0] + t * (dense[hi][0] - dense[lo][0]) + y = dense[lo][1] + t * (dense[hi][1] - dense[lo][1]) + result.append((x, y)) + result.append(p3) + + # Snap pts[1] and pts[-2] to exact tangent direction so that + # calculate_offset_polyline sees the correct perpendicular at both ends. + # Distance from the endpoint is preserved; only direction is corrected. + if n >= 3: + d1 = math.sqrt((result[1][0] - p0[0]) ** 2 + (result[1][1] - p0[1]) ** 2) + if d1 > 1e-9: + result[1] = ( + p0[0] + math.cos(start_tangent_rad) * d1, + p0[1] + math.sin(start_tangent_rad) * d1, + ) + d2 = math.sqrt((result[-2][0] - p3[0]) ** 2 + (result[-2][1] - p3[1]) ** 2) + if d2 > 1e-9: + result[-2] = ( + p3[0] - math.cos(end_tangent_rad) * d2, + p3[1] - math.sin(end_tangent_rad) * d2, + ) + + return result + + +def get_smooth_cr_tangents( + cr_road: 'Road', + project: 'Project', +) -> Optional[Tuple[float, float]]: + """Derive (start_heading, end_heading) for a connecting road from its neighbours. + + Uses the same logic as ProjectController._regenerate_parampoly3_cr. + + Returns: + (start_heading_rad, end_heading_rad) in screen/pixel space, or None if + the adjacent roads or polylines cannot be found. 
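The two helpers compose; a hedged sketch of re-smoothing a connecting road so its end tangents match the neighbouring roads (the cr and project objects are assumed to exist):

from orbit.utils.geometry import fit_smooth_curve_to_polyline, get_smooth_cr_tangents

tangents = get_smooth_cr_tangents(cr, project)  # cr, project: assumed in scope
if tangents is not None:
    start_hdg, end_hdg = tangents
    cr.inline_path = fit_smooth_curve_to_polyline(
        cr.inline_path,
        start_tangent_rad=start_hdg,
        end_tangent_rad=end_hdg,
        num_output_points=max(9, len(cr.inline_path)),  # denser sampling tames miter spikes
    )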
+ """ + from orbit.gui.project_controller import get_contact_pos_heading + + pred_road = project.get_road(cr_road.predecessor_id) + succ_road = project.get_road(cr_road.successor_id) + if not pred_road or not succ_road: + return None + + pred_pl = project.get_polyline(pred_road.centerline_id) + succ_pl = project.get_polyline(succ_road.centerline_id) + if not pred_pl or not succ_pl: + return None + + _, start_hdg = get_contact_pos_heading(pred_pl, cr_road.predecessor_contact) + if cr_road.predecessor_contact == "start": + start_hdg += math.pi + + _, end_hdg = get_contact_pos_heading(succ_pl, cr_road.successor_contact) + if cr_road.successor_contact == "end": + end_hdg += math.pi + + return start_hdg, end_hdg diff --git a/orbit/utils/provenance.py b/orbit/utils/provenance.py new file mode 100644 index 0000000..58727ce --- /dev/null +++ b/orbit/utils/provenance.py @@ -0,0 +1,218 @@ +"""Optional data provenance tracking via the dataprov library. + +Provenance sidecar files are created alongside .orbit project files and exports +when the user enables the feature in Preferences and the dataprov package is installed. + +File names are resolved from a configurable template stored in QSettings +(key ``provenance/name_template``, default ``{stem}{ext}.prov.json``). + +Template variables (resolved against the target output file's path): + {dir} — parent directory of the output file + {stem} — filename stem (without extension) + {ext} — file extension including the leading dot + {name} — full filename (stem + ext) +""" + +from __future__ import annotations + +import importlib.util +from datetime import datetime, timezone +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from orbit.models.project import Project + +DEFAULT_TEMPLATE = "{stem}{ext}.prov.json" +_ORBIT_SOURCE = "RISE Research Institutes of Sweden" + + +def is_dataprov_available() -> bool: + """Return True if the dataprov package is importable.""" + return importlib.util.find_spec("dataprov") is not None + + +def prov_path_for(file_path: Path | str, template: str = DEFAULT_TEMPLATE) -> Path: + """Resolve the provenance sidecar path from *file_path* and *template*. + + Template variables ``{dir}``, ``{stem}``, ``{ext}``, ``{name}`` are substituted + from *file_path*. If *template* contains ``{dir}``, the result is used as-is; + otherwise the resolved path is placed in the same directory as *file_path*. + """ + p = Path(file_path) + resolved = template.format( + dir=str(p.parent), + stem=p.stem, + ext=p.suffix, + name=p.name, + ) + result = Path(resolved) + # If the template didn't include {dir}, put it alongside the original file. + if "{dir}" not in template and not result.is_absolute(): + result = p.parent / result + return result + + +def _now_iso() -> str: + return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") + + +def record_project_save( + project: "Project", + orbit_path: Path, + start_time: datetime, + template: str = DEFAULT_TEMPLATE, +) -> bool: + """Create or update the provenance sidecar for an .orbit project file. + + Records the source image and any imported files as inputs. + Returns True on success, False if dataprov is unavailable or an error occurs. 
+ """ + if not is_dataprov_available(): + return False + + try: + from dataprov import ProvenanceChain + + prov_file = prov_path_for(orbit_path, template) + ended_at = _now_iso() + started_at = start_time.isoformat().replace("+00:00", "Z") + + initial_source = str(project.image_path) if project.image_path else str(orbit_path) + + chain = ProvenanceChain.load_or_create( + str(prov_file), + entity_id="orbit_project", + initial_source=initial_source, + description=f"ORBIT project: {project.map_name or orbit_path.stem}", + tags=["orbit", "road-annotation", "opendrive"], + ) + + inputs: list[str] = [] + input_formats: list[str] = [] + + if project.image_path: + inputs.append(str(project.image_path)) + input_formats.append(_format_for(project.image_path)) + + for src in getattr(project, "source_files", []): + path = src.get("path", "") + src_type = src.get("type", "") + if path and path != "api": + inputs.append(path) + input_formats.append(_format_for(Path(path))) + elif src_type == "osm_api": + # API import – use URL placeholder + inputs.append(src.get("path", "https://overpass-api.de/api/interpreter")) + input_formats.append("XML") + + if not inputs: + # Nothing useful to record + inputs = [str(orbit_path)] + input_formats = ["ORBIT"] + + chain.add( + started_at=started_at, + ended_at=ended_at, + tool_name="orbit", + tool_version=_orbit_version(), + operation="road network annotation", + inputs=inputs, + input_formats=input_formats, + outputs=[str(orbit_path)], + output_formats=["ORBIT"], + source=_ORBIT_SOURCE, + capture_agent=True, + capture_environment=True, + ) + + chain.save(str(prov_file)) + return True + + except Exception as exc: # noqa: BLE001 + import sys + print(f"Warning: provenance recording failed: {exc}", file=sys.stderr) + return False + + +def record_export( + output_path: Path, + orbit_path: Path | None, + operation: str, + output_format: str, + start_time: datetime, + template: str = DEFAULT_TEMPLATE, +) -> bool: + """Create a provenance sidecar for an exported file. + + Links to the .orbit project's own provenance chain when available. + Returns True on success, False if dataprov is unavailable or an error occurs. 
+ """ + if not is_dataprov_available(): + return False + + try: + from dataprov import ProvenanceChain + + prov_file = prov_path_for(output_path, template) + ended_at = _now_iso() + started_at = start_time.isoformat().replace("+00:00", "Z") + + inputs: list[str] = [] + input_formats: list[str] = [] + input_provenance_files: list[str | None] = [] + + if orbit_path is not None: + inputs.append(str(orbit_path)) + input_formats.append("ORBIT") + orbit_prov = prov_path_for(orbit_path, template) + input_provenance_files.append(str(orbit_prov) if orbit_prov.exists() else None) + + chain = ProvenanceChain.create( + entity_id=f"orbit_{output_format.lower()}_export", + initial_source=str(orbit_path) if orbit_path else str(output_path), + description=f"ORBIT export: {output_path.name}", + tags=["orbit", "export", output_format.lower()], + ) + + chain.add( + started_at=started_at, + ended_at=ended_at, + tool_name="orbit", + tool_version=_orbit_version(), + operation=operation, + inputs=inputs, + input_formats=input_formats, + outputs=[str(output_path)], + output_formats=[output_format], + input_provenance_files=input_provenance_files if any(input_provenance_files) else None, + source=_ORBIT_SOURCE, + capture_agent=True, + capture_environment=True, + ) + + chain.save(str(prov_file)) + return True + + except Exception as exc: # noqa: BLE001 + import sys + print(f"Warning: provenance recording failed: {exc}", file=sys.stderr) + return False + + +def _format_for(path: Path) -> str: + """Return a short format label for a file extension.""" + return { + ".jpg": "JPEG", ".jpeg": "JPEG", + ".png": "PNG", ".tif": "TIFF", ".tiff": "TIFF", + ".bmp": "BMP", ".orbit": "ORBIT", ".xodr": "XODR", + ".osm": "OSM", ".json": "JSON", + }.get(path.suffix.lower(), path.suffix.lstrip(".").upper() or "UNKNOWN") + + +def _orbit_version() -> str: + try: + from importlib.metadata import version + return version("orbit") + except Exception: + return "unknown" diff --git a/pyproject.toml b/pyproject.toml index 62b1ac3..4f18375 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "orbit" -version = "0.10.1" +version = "0.11.0" description = "OpenDrive Road Builder from Imagery Tool" readme = "README.md" requires-python = ">=3.10" @@ -28,6 +28,7 @@ dependencies = [ "lxml>=4.9.0", "pyproj>=3.6.0", "xmlschema>=4.2.0", + "geomag>=0.9.2015", ] [project.urls] @@ -38,6 +39,9 @@ Issues = "https://github.com/RI-SE/ORBIT/issues" Changelog = "https://github.com/RI-SE/ORBIT/blob/main/CHANGELOG.md" [project.optional-dependencies] +provenance = [ + "dataprov>=3.0", +] dev = [ "pytest>=7.4.0", "pytest-cov>=4.1.0", diff --git a/tests/unit/test_import/test_opendrive_importer.py b/tests/unit/test_import/test_opendrive_importer.py index 61de8ac..b083ce9 100644 --- a/tests/unit/test_import/test_opendrive_importer.py +++ b/tests/unit/test_import/test_opendrive_importer.py @@ -411,13 +411,14 @@ def test_traffic_signal(self, importer, mock_signal): assert signal_type == SignalType.TRAFFIC_SIGNALS def test_unsupported_type(self, importer, mock_signal): - """Unsupported type returns None.""" + """Unrecognized type falls back to CUSTOM for round-trip preservation.""" + from orbit.models.signal import SignalType mock_signal.type = "unknown_sign_type" mock_signal.country = "" signal_type, value = importer._convert_signal_type(mock_signal) - assert signal_type is None + assert signal_type == SignalType.CUSTOM assert value is None diff --git a/tests/unit/test_utils/test_adjustment_roundtrip.py 
b/tests/unit/test_utils/test_adjustment_roundtrip.py index 89f12b0..dc44c44 100644 --- a/tests/unit/test_utils/test_adjustment_roundtrip.py +++ b/tests/unit/test_utils/test_adjustment_roundtrip.py @@ -2,8 +2,10 @@ import pytest +from orbit.models.junction import Junction from orbit.models.polyline import LineType, Polyline, RoadMarkType from orbit.models.project import ControlPoint, Project +from orbit.models.road import Road from orbit.utils.coordinate_transform import ( HybridTransformer, TransformAdjustment, @@ -336,3 +338,92 @@ def test_project_transform_adjustment_cleared_after_bake(self): # Verify it's serialized as None data = project.to_dict() assert data.get('transform_adjustment') is None + + +class TestCRGeoResyncRoundTrip: + """CR geo_coords must be consistent with the active transformer. + + When CR inline_geo_path is stale (e.g. computed with a different + transformer), reproject_project_geometry uses geo_to_pixel on those + stale values and produces wrong pixel positions. Resyncing geo from + pixel before the reproject eliminates this. + """ + + @staticmethod + def _make_project_with_inconsistent_cr(): + """Create a project where CR geo_coords are intentionally inconsistent.""" + cps = [ + ControlPoint(pixel_x=100, pixel_y=100, + longitude=12.940, latitude=57.720, name="A"), + ControlPoint(pixel_x=500, pixel_y=100, + longitude=12.945, latitude=57.720, name="B"), + ControlPoint(pixel_x=300, pixel_y=400, + longitude=12.9425, latitude=57.718, name="C"), + ControlPoint(pixel_x=500, pixel_y=400, + longitude=12.945, latitude=57.718, name="D"), + ] + t = create_transformer(cps, "affine") + + project = Project(control_points=cps) + + # Two polylines meeting at a junction + pl1 = Polyline(id="pl1", line_type=LineType.CENTERLINE) + pl1.points = [(200, 200), (300, 300)] + pl1.geo_points = [t.pixel_to_geo(x, y) for x, y in pl1.points] + pl2 = Polyline(id="pl2", line_type=LineType.CENTERLINE) + pl2.points = [(300, 300), (400, 200)] + pl2.geo_points = [t.pixel_to_geo(x, y) for x, y in pl2.points] + project.polylines.extend([pl1, pl2]) + + # A connecting road between them (pixel path is correct) + cr = Road(name="CR1", junction_id="j1", + inline_path=[(300, 300), (320, 280), (340, 260)]) + # Set WRONG geo_coords (simulating stale data from a different transformer) + cr.inline_geo_path = [(12.940, 57.718), (12.941, 57.719), (12.942, 57.720)] + project.add_road(cr) + + j = Junction(center_point=(300, 300)) + j.geo_center_point = t.pixel_to_geo(300, 300) + j.add_connecting_road(cr.id) + project.junctions.append(j) + + return project, t, cr + + def test_stale_geo_causes_wrong_reproject(self): + """Without resync, stale CR geo_coords produce wrong pixel positions.""" + project, orig_t, cr = self._make_project_with_inconsistent_cr() + original_px = list(cr.inline_path) + + aerial_t = create_transformer_from_bounds( + 800, 600, 12.939, 57.717, 12.946, 57.721) + + reproject_project_geometry(project, orig_t, aerial_t) + # Store aerial positions + reproject_project_geometry(project, aerial_t, orig_t) + + # Without resync, pixel positions should NOT match original + # because the stale geo_coords map to different pixel positions + for i, (ox, oy) in enumerate(original_px): + ax, ay = cr.inline_path[i] + if abs(ax - ox) > 1.0 or abs(ay - oy) > 1.0: + return # Found expected mismatch — test passes + pytest.fail("Expected mismatched positions but got near-identical ones") + + def test_resync_then_reproject_preserves_positions(self): + """After resyncing geo from pixel, round-trip preserves positions 
exactly.""" + project, orig_t, cr = self._make_project_with_inconsistent_cr() + original_px = list(cr.inline_path) + + # Resync geo from pixel + cr.inline_geo_path = [orig_t.pixel_to_geo(x, y) for x, y in cr.inline_path] + + aerial_t = create_transformer_from_bounds( + 800, 600, 12.939, 57.717, 12.946, 57.721) + + reproject_project_geometry(project, orig_t, aerial_t) + reproject_project_geometry(project, aerial_t, orig_t) + + for i, (ox, oy) in enumerate(original_px): + ax, ay = cr.inline_path[i] + assert abs(ax - ox) < 0.1, f"Point {i} x: {ax} != {ox}" + assert abs(ay - oy) < 0.1, f"Point {i} y: {ay} != {oy}" diff --git a/tests/unit/test_utils/test_drone_camera.py b/tests/unit/test_utils/test_drone_camera.py new file mode 100644 index 0000000..f6057ec --- /dev/null +++ b/tests/unit/test_utils/test_drone_camera.py @@ -0,0 +1,320 @@ +"""Tests for DroneCameraModel and DroneAssistedTransformer.""" + +import math + +import numpy as np +import pytest + +from orbit.models.project import ControlPoint, DroneMetadata +from orbit.utils.camera_model import ( + DroneCameraModel, + _build_projection_matrix, + _rotation_matrix, +) +from orbit.utils.coordinate_transform import DroneAssistedTransformer + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def nadir_metadata(): + """Pure nadir camera (pitch=-90), facing North (yaw=0), 100 m altitude.""" + return DroneMetadata( + latitude=57.0, + longitude=12.0, + alt_agl=100.0, + gimbal_yaw=0.0, + gimbal_pitch=-90.0, + gimbal_roll=0.0, + hfov_deg=90.0, # 45° from centre → f = W/2 + ) + + +@pytest.fixture +def tilted_metadata(): + """South-facing camera, 10° off-nadir, 120 m altitude.""" + return DroneMetadata( + latitude=57.73975, + longitude=12.89942, + alt_agl=120.0, + gimbal_yaw=180.0, # facing South + gimbal_pitch=-80.0, # 10° off nadir + gimbal_roll=0.0, + hfov_deg=71.5, + ) + + +# --------------------------------------------------------------------------- +# _rotation_matrix +# --------------------------------------------------------------------------- + +class TestRotationMatrix: + def test_nadir_identity_columns(self): + """At pure nadir / north-facing, Z-axis (forward) should point straight down [0,0,-1].""" + R = _rotation_matrix(yaw_deg=0.0, pitch_deg=-90.0, roll_deg=0.0) + np.testing.assert_allclose(R[2], [0.0, 0.0, -1.0], atol=1e-10) + + def test_orthonormal(self): + """Rotation matrix must be orthonormal (R @ R.T == I).""" + R = _rotation_matrix(yaw_deg=45.0, pitch_deg=-77.7, roll_deg=2.0) + np.testing.assert_allclose(R @ R.T, np.eye(3), atol=1e-10) + + def test_determinant_one(self): + """det(R) must be +1 (proper rotation, not reflection).""" + R = _rotation_matrix(yaw_deg=-133.6, pitch_deg=-77.7, roll_deg=0.0) + assert abs(np.linalg.det(R) - 1.0) < 1e-10 + + +# --------------------------------------------------------------------------- +# _build_projection_matrix +# --------------------------------------------------------------------------- + +class TestBuildProjectionMatrix: + def test_nadir_centre_projects_to_image_centre(self): + """Ground nadir [E=0, N=0] must project to image centre for pure nadir camera.""" + W, H = 4000, 3000 + cx, cy = W / 2.0, H / 2.0 + f = cx # for hfov=90° tan(45°)=1 → f = cx + M = _build_projection_matrix(0.0, -90.0, 0.0, 100.0, f, cx, cy) + g = np.array([0.0, 0.0, 1.0]) # ENU ground point at nadir + p = M @ g + u, v = p[0] / p[2], p[1] / p[2] + np.testing.assert_allclose([u, 
v], [cx, cy], atol=0.5) + + def test_invertibility(self): + """M must be invertible.""" + M = _build_projection_matrix(226.4, -77.7, 0.0, 120.0, 2664.0, 1920.0, 1080.0) + det = np.linalg.det(M) + assert abs(det) > 1e-6 + + +# --------------------------------------------------------------------------- +# DroneCameraModel +# --------------------------------------------------------------------------- + +class TestDroneCameraModel: + def test_focal_from_hfov(self, nadir_metadata): + """Focal length resolves from hfov_deg without GCPs.""" + model = DroneCameraModel(nadir_metadata, image_width=4000, image_height=3000) + expected_f = 2000.0 # (4000/2) / tan(45°) + assert abs(model.focal_length_px - expected_f) < 1.0 + + def test_raises_without_hfov_or_gcps(self): + md = DroneMetadata( + latitude=57.0, longitude=12.0, alt_agl=100.0, + gimbal_yaw=0.0, gimbal_pitch=-90.0, + ) + with pytest.raises(ValueError, match="focal length"): + DroneCameraModel(md, image_width=1920, image_height=1080) + + def test_pixel_to_enu_roundtrip(self, nadir_metadata): + """pixel_to_enu → enu_to_pixel should reconstruct original pixel within 0.5 px.""" + model = DroneCameraModel(nadir_metadata, image_width=4000, image_height=3000) + test_pixels = [(1000, 800), (3500, 2500), (200, 200)] + for u, v in test_pixels: + east, north = model.pixel_to_enu(u, v) + u2, v2 = model.enu_to_pixel(east, north) + np.testing.assert_allclose([u2, v2], [u, v], atol=0.5, + err_msg=f"Roundtrip failed for ({u},{v})") + + def test_nadir_ground_track_at_image_centre(self, nadir_metadata): + """For pure-nadir camera, image centre must map to ENU origin (0, 0).""" + model = DroneCameraModel(nadir_metadata, image_width=4000, image_height=3000) + east, north = model.pixel_to_enu(2000.0, 1500.0) + np.testing.assert_allclose([east, north], [0.0, 0.0], atol=1.0) + + def test_gcp_reprojection(self, tilted_metadata): + """With hfov_deg given, nadir (ENU origin) projects to a finite pixel location.""" + model = DroneCameraModel(tilted_metadata, image_width=3840, image_height=2160) + u, v = model.enu_to_pixel(0.0, 0.0) + assert math.isfinite(u) and math.isfinite(v) + + def test_focal_solved_from_gcps(self): + """When hfov_deg is absent, focal length should be solved from GCPs.""" + md = DroneMetadata( + latitude=57.0, longitude=12.0, alt_agl=100.0, + gimbal_yaw=0.0, gimbal_pitch=-90.0, + ) + # Build a real model with known f to generate GCPs + f_true = 2000.0 + W, H = 4000, 3000 + from orbit.utils.camera_model import _build_projection_matrix + M = _build_projection_matrix(0.0, -90.0, 0.0, 100.0, f_true, W/2, H/2) + M_inv = np.linalg.inv(M) + + R_earth = 6_371_000.0 + cos_lat = math.cos(math.radians(57.0)) + + def enu_to_latlon(east, north): + lat = md.latitude + math.degrees(north / R_earth) + lon = md.longitude + math.degrees(east / (R_earth * cos_lat)) + return lat, lon + + gcps = [] + for (pu, pv) in [(500, 500), (3500, 500), (500, 2500), (3500, 2500)]: + g = M_inv @ np.array([pu, pv, 1.0]) + e, n = g[0] / g[2], g[1] / g[2] + lat, lon = enu_to_latlon(e, n) + gcps.append(ControlPoint(name="g", pixel_x=pu, pixel_y=pv, latitude=lat, longitude=lon)) + + model = DroneCameraModel(md, image_width=W, image_height=H, control_points=gcps) + assert abs(model.focal_length_px - f_true) / f_true < 0.05 # within 5% + + def test_estimate_heading_from_gcps_returns_float(self, nadir_metadata): + """estimate_heading_from_gcps returns (yaw_float, rmse_float) without error.""" + model = DroneCameraModel(nadir_metadata, image_width=4000, image_height=3000) + # Need at least 
2 GCPs + cp1 = ControlPoint(name="a", pixel_x=1000.0, pixel_y=1000.0, + latitude=57.001, longitude=12.0) + cp2 = ControlPoint(name="b", pixel_x=3000.0, pixel_y=2000.0, + latitude=57.0, longitude=12.001) + yaw, rmse = model.estimate_heading_from_gcps([cp1, cp2]) + assert isinstance(yaw, float) and math.isfinite(yaw) + assert isinstance(rmse, float) and rmse >= 0.0 + + +# --------------------------------------------------------------------------- +# DroneAssistedTransformer +# --------------------------------------------------------------------------- + +class TestDroneAssistedTransformer: + def _make_transformer(self, metadata, gcps=None): + return DroneAssistedTransformer( + metadata=metadata, + control_points=gcps or [], + image_width=4000, + image_height=3000, + ) + + def test_reference_point_is_drone_nadir(self, nadir_metadata): + """Reference lat/lon should equal drone nadir, not CP centroid.""" + tr = self._make_transformer(nadir_metadata) + assert tr.reference_lat == nadir_metadata.latitude + assert tr.reference_lon == nadir_metadata.longitude + + def test_pixel_to_geo_roundtrip(self, nadir_metadata): + """pixel_to_geo → geo_to_pixel roundtrip within 0.5 px.""" + tr = self._make_transformer(nadir_metadata) + for u, v in [(500, 500), (2000, 1500), (3800, 2800)]: + lon, lat = tr.pixel_to_geo(u, v) + u2, v2 = tr.geo_to_pixel(lon, lat) + np.testing.assert_allclose([u2, v2], [u, v], atol=0.5, + err_msg=f"Roundtrip failed for ({u},{v})") + + def test_image_centre_near_drone_nadir(self, nadir_metadata): + """For pure-nadir camera, image centre should be close to drone lat/lon.""" + tr = self._make_transformer(nadir_metadata) + lon, lat = tr.pixel_to_geo(2000.0, 1500.0) + assert abs(lat - nadir_metadata.latitude) < 0.001 + assert abs(lon - nadir_metadata.longitude) < 0.001 + + def test_get_scale_factor_positive(self, nadir_metadata): + """get_scale_factor should return positive x/y metres-per-pixel values.""" + tr = self._make_transformer(nadir_metadata) + sx, sy = tr.get_scale_factor() + assert sx > 0 + assert sy > 0 + + def test_transform_matrix_shape(self, nadir_metadata): + """transform_matrix and inverse_matrix must be 3×3.""" + tr = self._make_transformer(nadir_metadata) + assert tr.transform_matrix.shape == (3, 3) + assert tr.inverse_matrix.shape == (3, 3) + + def test_with_gcps_reprojection_error(self): + """With consistent synthetic GCPs, reprojection error should be near zero.""" + md = DroneMetadata( + latitude=57.0, longitude=12.0, alt_agl=100.0, + gimbal_yaw=0.0, gimbal_pitch=-90.0, hfov_deg=90.0, + ) + # Generate GCPs consistent with the camera model + f = 2000.0 + W, H = 4000, 3000 + M = _build_projection_matrix(0.0, -90.0, 0.0, 100.0, f, W/2, H/2) + M_inv = np.linalg.inv(M) + R_earth = 6_371_000.0 + cos_lat = math.cos(math.radians(57.0)) + + def enu_to_latlon(e, n): + return ( + md.latitude + math.degrees(n / R_earth), + md.longitude + math.degrees(e / (R_earth * cos_lat)), + ) + + gcps = [] + for pu, pv in [(500, 400), (3500, 400), (500, 2600), (3500, 2600), (2000, 1500)]: + g = M_inv @ np.array([pu, pv, 1.0]) + e, n = g[0] / g[2], g[1] / g[2] + lat, lon = enu_to_latlon(e, n) + gcps.append(ControlPoint(name="g", pixel_x=pu, pixel_y=pv, + latitude=lat, longitude=lon)) + + tr = DroneAssistedTransformer(metadata=md, control_points=gcps, + image_width=W, image_height=H) + assert tr.reprojection_error is not None + assert tr.reprojection_error['rmse_pixels'] < 1.0 + + +# --------------------------------------------------------------------------- +# DroneMetadata.from_video_stats 
+
+
+# ---------------------------------------------------------------------------
+# DroneMetadata.from_video_stats
+# ---------------------------------------------------------------------------
+
+class TestDroneMetadataFromVideoStats:
+    def _sample_stats(self, **overrides):
+        stats = {
+            "drone_type": "Mavic3Pro",
+            "lens_type": "standard",
+            "sequences": [{
+                "sequence_id": 0,
+                "stats": {
+                    "osd": {
+                        "latitude": {"mean": 57.73975, "std": 0.0},
+                        "longitude": {"mean": 12.89942, "std": 0.0},
+                        "height_agl": {"mean": 119.92, "std": 0.5},
+                    },
+                    "gimbal": {
+                        "yaw": {"mean": -133.6, "std": 0.0},
+                        "pitch": {"mean": -77.7, "std": 0.0},
+                        "roll": {"mean": 0.0, "std": 0.0},
+                    },
+                },
+            }],
+        }
+        stats.update(overrides)
+        return stats
+
+    def test_basic_parsing(self):
+        md = DroneMetadata.from_video_stats(self._sample_stats())
+        assert abs(md.latitude - 57.73975) < 1e-5
+        assert abs(md.longitude - 12.89942) < 1e-5
+        assert abs(md.alt_agl - 119.92) < 0.01
+        assert abs(md.gimbal_yaw - (-133.6)) < 0.01
+        assert abs(md.gimbal_pitch - (-77.7)) < 0.01
+        assert md.drone_type == "Mavic3Pro"
+        assert md.lens_type == "standard"
+
+    def test_hfov_from_camera_section(self):
+        stats = self._sample_stats()
+        stats["camera"] = {"hfov_deg": 71.5}
+        md = DroneMetadata.from_video_stats(stats)
+        assert md.hfov_deg == 71.5
+
+    def test_no_hfov_without_camera_section(self):
+        md = DroneMetadata.from_video_stats(self._sample_stats())
+        assert md.hfov_deg is None
+
+    def test_raises_on_empty_sequences(self):
+        stats = self._sample_stats()
+        stats["sequences"] = []
+        with pytest.raises(ValueError, match="no sequences"):
+            DroneMetadata.from_video_stats(stats)
+
+    def test_roundtrip_to_from_dict(self):
+        md = DroneMetadata.from_video_stats(self._sample_stats())
+        md2 = DroneMetadata.from_dict(md.to_dict())
+        assert md.latitude == md2.latitude
+        assert md.longitude == md2.longitude
+        assert md.alt_agl == md2.alt_agl
+        assert md.gimbal_yaw == md2.gimbal_yaw
+        assert md.drone_type == md2.drone_type
diff --git a/uv.lock b/uv.lock
index a9c7dcc..d526eac 100644
--- a/uv.lock
+++ b/uv.lock
@@ -144,6 +144,15 @@ toml = [
     { name = "tomli", marker = "python_full_version <= '3.11'" },
 ]
 
+[[package]]
+name = "dataprov"
+version = "3.2.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/ae/a5/a873ee8fa3ca9f3ae02bc1584119e4515cf656e203152796d564ad67d9eb/dataprov-3.2.0.tar.gz", hash = "sha256:e2c76cc3198c671e5953d20e2c39a0264ef4c91f1e04e9ff28343386995d9b93", size = 86968, upload-time = "2026-04-14T12:22:18.831Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/00/c5/520700bad30e9428bea54d54925656a001a99977ae54a79bae1bb7c0f8f9/dataprov-3.2.0-py3-none-any.whl", hash = "sha256:098891f4496e25d18185d53644b6724f7863181ae83b74ef5cbce83083513ac3", size = 48009, upload-time = "2026-04-14T12:22:17.642Z" },
+]
+
 [[package]]
 name = "distlib"
 version = "0.4.0"
@@ -183,6 +192,12 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/a4/a5/842ae8f0c08b61d6484b52f99a03510a3a72d23141942d216ebe81fefbce/filelock-3.25.2-py3-none-any.whl", hash = "sha256:ca8afb0da15f229774c9ad1b455ed96e85a81373065fb10446672f64444ddf70", size = 26759, upload-time = "2026-03-11T20:45:37.437Z" },
 ]
 
+[[package]]
+name = "geomag"
+version = "0.9.2015"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/0d/1f/95b5d9db29b89735af88713e7fb49b20052d1ce5a05e8e4ea0f2f96da4b3/geomag-0.9.2015.zip", hash = "sha256:fa3f7544beca2dd706c3078d9d58cb35734d00de2591270a335216bb8ba310d5", size = 8071, upload-time = "2014-12-18T01:15:53.201Z" }
+
 [[package]]
"identify" version = "2.6.18" @@ -512,6 +527,7 @@ name = "orbit" version = "0.10.1" source = { editable = "." } dependencies = [ + { name = "geomag" }, { name = "lxml" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, @@ -532,6 +548,9 @@ dev = [ { name = "pytest-mock" }, { name = "ruff" }, ] +provenance = [ + { name = "dataprov" }, +] [package.dev-dependencies] dev = [ @@ -540,6 +559,8 @@ dev = [ [package.metadata] requires-dist = [ + { name = "dataprov", marker = "extra == 'provenance'", specifier = ">=3.0" }, + { name = "geomag", specifier = ">=0.9.2015" }, { name = "lxml", specifier = ">=4.9.0" }, { name = "numpy", specifier = ">=1.24.0" }, { name = "opencv-python", specifier = ">=4.8.0" }, @@ -553,7 +574,7 @@ requires-dist = [ { name = "scipy", specifier = ">=1.11.0" }, { name = "xmlschema", specifier = ">=4.2.0" }, ] -provides-extras = ["dev"] +provides-extras = ["provenance", "dev"] [package.metadata.requires-dev] dev = [{ name = "pre-commit", specifier = ">=4.5.1" }]