From e070e5a763594a41aa219455d0b6c2ee09fc0cad Mon Sep 17 00:00:00 2001 From: Paul-Edouard Sarlin Date: Mon, 19 Aug 2024 12:10:22 +0200 Subject: [PATCH 01/12] Extend MapillaryDownloader to list images --- maploc/data/mapillary/download.py | 167 +++++++++++++++++++----------- 1 file changed, 105 insertions(+), 62 deletions(-) diff --git a/maploc/data/mapillary/download.py b/maploc/data/mapillary/download.py index 816aa9f..410bfc4 100644 --- a/maploc/data/mapillary/download.py +++ b/maploc/data/mapillary/download.py @@ -2,7 +2,9 @@ import asyncio import json +from functools import partial from pathlib import Path +from typing import List, Optional, Sequence, Tuple import httpx import numpy as np @@ -12,7 +14,7 @@ from opensfm.pymap import Shot from ... import logger -from ...utils.geo import Projection +from ...utils.geo import BoundaryBox, Projection semaphore = asyncio.Semaphore(100) # number of parallel threads. image_filename = "{image_id}.jpg" @@ -55,11 +57,13 @@ class MapillaryDownloader: "sequence", "sfm_cluster", ) - image_info_url = ( - "https://graph.mapillary.com/{image_id}?access_token={token}&fields={fields}" + base_url = "https://graph.mapillary.com" + image_info_url = "{base}/{image_id}?access_token={token}&fields={fields}" + image_list_url = ( + "{base}/images?access_token={token}&bbox={bbox}&limit={limit}&fields=id" ) - seq_info_url = "https://graph.mapillary.com/image_ids?access_token={token}&sequence_id={seq_id}" # noqa E501 max_requests_per_minute = 50_000 + max_num_results = 5_000 # maximum allowed by mapillary.com def __init__(self, token: str): self.token = token @@ -70,27 +74,59 @@ def __init__(self, token: str): @retry(times=5, exceptions=(httpx.RemoteProtocolError, httpx.ReadError)) async def call_api(self, url: str): - async with self.limiter: - r = await self.client.get(url) + async with semaphore: + async with self.limiter: + r = await self.client.get(url) if not r.is_success: - logger.error("Error in API call: %s", r.text) + logger.error("Error in API call: %s, retrying...", r.text) + raise httpx.ReadError(r.text) return r - async def get_image_info(self, image_id: int): + async def get_image_list( + self, bbox: BoundaryBox, is_pano: Optional[bool] = None, **filters + ): + # API bbox format: left, bottom, right, top (or minLon, minLat, maxLon, maxLat) + bbox = ",".join(map(str, (*bbox.min_[::-1], *bbox.max_[::-1]))) + url = self.image_list_url.format( + base=self.base_url, token=self.token, bbox=bbox, limit=self.max_num_results + ) + if is_pano is not None: + url += "&is_pano=" + ("true" if is_pano else "false") + for name, val in filters.items(): + url += f"&{name}={val}" + r = await self.call_api(url) + if r.is_success: + info = json.loads(r.text) + image_ids = [int(d["id"]) for d in info["data"]] + return image_ids + # return json.loads(r.text) + + async def get_image_info( + self, image_id: int, fields: Optional[Sequence[str]] = None + ): url = self.image_info_url.format( + base=self.base_url, image_id=image_id, token=self.token, - fields=",".join(self.image_fields), + fields=",".join(fields or self.image_fields), ) r = await self.call_api(url) if r.is_success: return json.loads(r.text) - async def get_sequence_info(self, seq_id: str): - url = self.seq_info_url.format(seq_id=seq_id, token=self.token) - r = await self.call_api(url) - if r.is_success: - return json.loads(r.text) + async def get_image_info_cached( + self, image_id: int, dir_: Optional[Path] = None, **kwargs + ): + if dir_ is None: + return await self.get_image_info(image_id, **kwargs) + path = dir_ / info_filename.format(image_id=image_id) + if path.exists(): + info = json.loads(path.read_text()) + else: + info = await self.get_image_info(image_id, **kwargs) + if info is not None: + path.write_text(json.dumps(info)) + return info async def download_image_pixels(self, url: str, path: Path): r = await self.call_api(url) @@ -99,15 +135,6 @@ async def download_image_pixels(self, url: str, path: Path): fid.write(r.content) return r.is_success - async def get_image_info_cached(self, image_id: int, path: Path): - if path.exists(): - info = json.loads(path.read_text()) - else: - info = await self.get_image_info(image_id) - if info is not None: - path.write_text(json.dumps(info)) - return info - async def download_image_pixels_cached(self, url: str, path: Path): if path.exists(): return True @@ -115,64 +142,80 @@ async def download_image_pixels_cached(self, url: str, path: Path): return await self.download_image_pixels(url, path) -async def fetch_images_in_sequence(i, downloader): - async with semaphore: - info = await downloader.get_sequence_info(i) - if info is None: - image_ids = None - else: - image_ids = [int(d["id"]) for d in info["data"]] - return i, image_ids +async def _return_with_arg(item, fn): + ret = await fn(item) + return item, ret -async def fetch_images_in_sequences(sequence_ids, downloader): - seq_to_images_ids = {} - tasks = [fetch_images_in_sequence(i, downloader) for i in sequence_ids] +async def fetch_many(items, fn): + results = [] + tasks = [_return_with_arg(item, fn) for item in items] for task in tqdm.asyncio.tqdm.as_completed(tasks): - i, image_ids = await task - if image_ids is not None: - seq_to_images_ids[i] = image_ids - return seq_to_images_ids + results.append(await task) + return results -async def fetch_image_info(i, downloader, dir_): - async with semaphore: - path = dir_ / info_filename.format(image_id=i) - info = await downloader.get_image_info_cached(i, path) - return i, info - - -async def fetch_image_infos(image_ids, downloader, dir_): - infos = {} +async def fetch_image_infos(image_ids, downloader, **kwargs): + infos = await fetch_many( + image_ids, partial(downloader.get_image_info_cached, **kwargs) + ) + infos = dict(infos) num_fail = 0 - tasks = [fetch_image_info(i, downloader, dir_) for i in image_ids] - for task in tqdm.asyncio.tqdm.as_completed(tasks): - i, info = await task - if info is None: + for i in image_ids: + if infos[i] is None: + del infos[i] num_fail += 1 - else: - infos[i] = info return infos, num_fail -async def fetch_image_pixels(i, url, downloader, dir_, overwrite=False): - async with semaphore: +async def fetch_images_pixels(image_urls, downloader, dir_, overwrite=False): + tasks = [] + for i, url in image_urls: path = dir_ / image_filename.format(image_id=i) if overwrite: path.unlink(missing_ok=True) - success = await downloader.download_image_pixels_cached(url, path) - return i, success - - -async def fetch_images_pixels(image_urls, downloader, dir_): + tasks.append(downloader.download_image_pixels_cached(url, path)) num_fail = 0 - tasks = [fetch_image_pixels(*id_url, downloader, dir_) for id_url in image_urls] for task in tqdm.asyncio.tqdm.as_completed(tasks): - i, success = await task + success = await task num_fail += not success return num_fail +def split_bbox(bbox: BoundaryBox) -> tuple[BoundaryBox]: + midpoint = bbox.center + return ( + BoundaryBox(bbox.min_, midpoint), + BoundaryBox((bbox.min_[0], midpoint[1]), (midpoint[0], bbox.max_[1])), + BoundaryBox((midpoint[0], bbox.min_[1]), (bbox.max_[0], midpoint[1])), + BoundaryBox(midpoint, bbox.max_), + ) + + +async def fetch_image_list( + query_bbox: BoundaryBox, + downloader: MapillaryDownloader, + **filters, +) -> Tuple[List[int], List[BoundaryBox]]: + """Because of the limit in number of returned results, we recursively split + the query area until each query is below this limit. + """ + pool = [query_bbox] + finished = [] + all_ids = [] + while len(pool): + rets = await fetch_many(pool, partial(downloader.get_image_list, **filters)) + pool = [] + for bbox, ids in rets: + assert ids is not None + if len(ids) == downloader.max_num_results: + pool.extend(split_bbox(bbox)) + else: + finished.append(bbox) + all_ids.extend(ids) + return all_ids, finished + + def opensfm_camera_from_info(info: dict) -> Camera: cam_type = info["camera_type"] if cam_type == "perspective": From 0c0018d9de4e2b411b42946ede7a8c4e68bbef12 Mon Sep 17 00:00:00 2001 From: Paul-Edouard Sarlin Date: Mon, 19 Aug 2024 12:11:10 +0200 Subject: [PATCH 02/12] Add code to list Mapillary images and split them --- maploc/data/mapillary/config.py | 146 ++++++++++++++++++++++++++ maploc/data/mapillary/prepare.py | 102 +------------------ maploc/data/mapillary/split.py | 170 +++++++++++++++++++++++++++++++ 3 files changed, 318 insertions(+), 100 deletions(-) create mode 100644 maploc/data/mapillary/config.py create mode 100644 maploc/data/mapillary/split.py diff --git a/maploc/data/mapillary/config.py b/maploc/data/mapillary/config.py new file mode 100644 index 0000000..2bcba8f --- /dev/null +++ b/maploc/data/mapillary/config.py @@ -0,0 +1,146 @@ +from omegaconf import OmegaConf + +from ...utils.geo import BoundaryBox + +location_to_params = { + "sanfrancisco_soma": { + "bbox": BoundaryBox((37.770364, -122.410307), (37.795545, -122.388772)), + "bbox_val": BoundaryBox( + (37.788123419925945, -122.40053535863909), + (37.78897443253716, -122.3994618718349), + ), + "filters": {"model": "GoPro Max"}, + "osm_file": "sanfrancisco.osm", + }, + "sanfrancisco_hayes": { + "bbox": BoundaryBox((37.768634, -122.438415), (37.783894, -122.410605)), + "bbox_val": BoundaryBox( + (37.77682908567614, -122.42439593370665), + (37.7776996640339, -122.42329849537967), + ), + "filters": {"model": "GoPro Max"}, + "osm_file": "sanfrancisco.osm", + }, + "montrouge": { + "bbox": BoundaryBox((48.80874, 2.298958), (48.825276, 2.332989)), + "bbox_val": BoundaryBox( + (48.81554465300679, 2.315590378986898), + (48.816228935240346, 2.3166087395920103), + ), + "filters": {"model": "LG-R105"}, + "osm_file": "paris.osm", + }, + "amsterdam": { + "bbox": BoundaryBox((52.340679, 4.845284), (52.386299, 4.926147)), + "bbox_val": BoundaryBox( + (52.358275965541495, 4.876867175817335), + (52.35920971624303, 4.878370977965195), + ), + "filters": {"model": "GoPro Max"}, + "osm_file": "amsterdam.osm", + }, + "lemans": { + "bbox": BoundaryBox((47.995125, 0.185752), (48.014209, 0.224088)), + "bbox_val": BoundaryBox( + (48.00468200256593, 0.20130905922712253), + (48.00555356009431, 0.20251886369476968), + ), + "filters": {"creator_username": "sogefi"}, + "osm_file": "lemans.osm", + }, + "berlin": { + "bbox": BoundaryBox((52.459656, 13.416271), (52.499195, 13.469829)), + "bbox_val": BoundaryBox( + (52.47478263625299, 13.436060761632277), + (52.47610554128314, 13.438407628895831), + ), + "filters": {"is_pano": True}, + "osm_file": "berlin.osm", + }, + "nantes": { + "bbox": BoundaryBox((47.198289, -1.585839), (47.236161, -1.51318)), + "bbox_val": BoundaryBox( + (47.212224982547106, -1.555772859366718), + (47.213374064189956, -1.554270622470525), + ), + "filters": {"is_pano": True}, + "osm_file": "nantes.osm", + }, + "toulouse": { + "bbox": BoundaryBox((43.591434, 1.429457), (43.61343, 1.456653)), + "bbox_val": BoundaryBox( + (43.60314813839066, 1.4431497839062253), + (43.604433961018984, 1.4448508228862122), + ), + "filters": {"is_pano": True}, + "osm_file": "toulouse.osm", + }, + "vilnius": { + "bbox": BoundaryBox((54.672956, 25.258633), (54.696755, 25.296094)), + "bbox_val": BoundaryBox( + (54.68292611300143, 25.276979025529165), + (54.68349008447563, 25.27798847871685), + ), + "filters": {"is_pano": True}, + "osm_file": "vilnius.osm", + }, + "helsinki": { + "bbox": BoundaryBox( + (60.1449128318, 24.8975480117), (60.1770977471, 24.9816543235) + ), + "bbox_val": BoundaryBox( + (60.163825618884275, 24.930182541064955), + (60.16518598734065, 24.93274647451007), + ), + "filters": {"is_pano": True}, + "osm_file": "helsinki.osm", + }, + "milan": { + "bbox": BoundaryBox( + (45.4810977947, 9.1732723899), (45.5284238563, 9.2255987917) + ), + "bbox_val": BoundaryBox( + (45.502686834500466, 9.189078329923374), + (45.50329294217317, 9.189881944589828), + ), + "filters": {"is_pano": True}, + "osm_file": "milan.osm", + }, + "avignon": { + "bbox": BoundaryBox( + (43.9416178156, 4.7887045302), (43.9584848909, 4.8227015622) + ), + "bbox_val": BoundaryBox( + (43.94768786305171, 4.809099008430249), + (43.94827840894793, 4.809954737764413), + ), + "filters": {"is_pano": True}, + "osm_file": "avignon.osm", + }, + "paris": { + "bbox": BoundaryBox((48.833827, 2.306823), (48.889335, 2.39067)), + "bbox_val": BoundaryBox( + (48.85558288211851, 2.3427920762801526), + (48.85703370256603, 2.3449544861818654), + ), + "filters": {"is_pano": True}, + "osm_file": "paris.osm", + }, +} + +default_cfg = OmegaConf.create( + { + "downsampling_resolution_meters": 2, + "target_num_val_images": 100, + "val_train_margin_meters": 25, + "max_num_train_images": 50_000, + "max_image_size": 512, + "do_legacy_pano_offset": True, + "min_dist_between_keyframes": 4, + "tiling": { + "tile_size": 128, + "margin": 128, + "ppm": 2, + }, + } +) diff --git a/maploc/data/mapillary/prepare.py b/maploc/data/mapillary/prepare.py index af8d96e..bbaed39 100644 --- a/maploc/data/mapillary/prepare.py +++ b/maploc/data/mapillary/prepare.py @@ -26,6 +26,7 @@ from ...utils.geo import BoundaryBox, Projection from ...utils.io import DATA_URL, download_file, write_json from ..utils import decompose_rotmat +from .config import default_cfg, location_to_params from .dataset import MapillaryDataModule from .download import ( MapillaryDownloader, @@ -43,106 +44,7 @@ undistort_shot, ) -location_to_params = { - "sanfrancisco_soma": { - "bbox": BoundaryBox( - [-122.410307, 37.770364][::-1], [-122.388772, 37.795545][::-1] - ), - "camera_models": ["GoPro Max"], - "osm_file": "sanfrancisco.osm", - }, - "sanfrancisco_hayes": { - "bbox": BoundaryBox( - [-122.438415, 37.768634][::-1], [-122.410605, 37.783894][::-1] - ), - "camera_models": ["GoPro Max"], - "osm_file": "sanfrancisco.osm", - }, - "amsterdam": { - "bbox": BoundaryBox([4.845284, 52.340679][::-1], [4.926147, 52.386299][::-1]), - "camera_models": ["GoPro Max"], - "osm_file": "amsterdam.osm", - }, - "lemans": { - "bbox": BoundaryBox([0.185752, 47.995125][::-1], [0.224088, 48.014209][::-1]), - "owners": ["xXOocM1jUB4jaaeukKkmgw"], # sogefi - "osm_file": "lemans.osm", - }, - "berlin": { - "bbox": BoundaryBox([13.416271, 52.459656][::-1], [13.469829, 52.499195][::-1]), - "owners": ["LT3ajUxH6qsosamrOHIrFw"], # supaplex030 - "osm_file": "berlin.osm", - }, - "montrouge": { - "bbox": BoundaryBox([2.298958, 48.80874][::-1], [2.332989, 48.825276][::-1]), - "owners": [ - "XtzGKZX2_VIJRoiJ8IWRNQ", - "C4ENdWpJdFNf8CvnQd7NrQ", - "e_ZBE6mFd7CYNjRSpLl-Lg", - ], # overflorian, phyks, francois2 - "camera_models": ["LG-R105"], - "osm_file": "paris.osm", - }, - "nantes": { - "bbox": BoundaryBox([-1.585839, 47.198289][::-1], [-1.51318, 47.236161][::-1]), - "owners": [ - "jGdq3CL-9N-Esvj3mtCWew", - "s-j5BH9JRIzsgORgaJF3aA", - ], # c_mobilite, cartocite - "osm_file": "nantes.osm", - }, - "toulouse": { - "bbox": BoundaryBox([1.429457, 43.591434][::-1], [1.456653, 43.61343][::-1]), - "owners": ["MNkhq6MCoPsdQNGTMh3qsQ"], # tyndare - "osm_file": "toulouse.osm", - }, - "vilnius": { - "bbox": BoundaryBox([25.258633, 54.672956][::-1], [25.296094, 54.696755][::-1]), - "owners": ["bClduFF6Gq16cfwCdhWivw", "u5ukBseATUS8jUbtE43fcO"], # kedas, vms - "osm_file": "vilnius.osm", - }, - "helsinki": { - "bbox": BoundaryBox( - [24.8975480117, 60.1449128318][::-1], [24.9816543235, 60.1770977471][::-1] - ), - "camera_types": ["spherical", "equirectangular"], - "osm_file": "helsinki.osm", - }, - "milan": { - "bbox": BoundaryBox( - [9.1732723899, 45.4810977947][::-1], - [9.2255987917, 45.5284238563][::-1], - ), - "camera_types": ["spherical", "equirectangular"], - "osm_file": "milan.osm", - }, - "avignon": { - "bbox": BoundaryBox( - [4.7887045302, 43.9416178156][::-1], [4.8227015622, 43.9584848909][::-1] - ), - "camera_types": ["spherical", "equirectangular"], - "osm_file": "avignon.osm", - }, - "paris": { - "bbox": BoundaryBox([2.306823, 48.833827][::-1], [2.39067, 48.889335][::-1]), - "camera_types": ["spherical", "equirectangular"], - "osm_file": "paris.osm", - }, -} - - -default_cfg = OmegaConf.create( - { - "max_image_size": 512, - "do_legacy_pano_offset": True, - "min_dist_between_keyframes": 4, - "tiling": { - "tile_size": 128, - "margin": 128, - "ppm": 2, - }, - } -) +DATA_FILENAME = "dump.json" def get_pano_offset(image_info: dict, do_legacy: bool = False) -> float: diff --git a/maploc/data/mapillary/split.py b/maploc/data/mapillary/split.py new file mode 100644 index 0000000..d6bc575 --- /dev/null +++ b/maploc/data/mapillary/split.py @@ -0,0 +1,170 @@ +import argparse +import asyncio +from pathlib import Path +from typing import Any, Dict, Optional + +import numpy as np +from omegaconf import DictConfig, OmegaConf + +from ... import logger +from ...osm.viz import GeoPlotter +from ...utils.geo import BoundaryBox, Projection +from ...utils.io import write_json +from .config import default_cfg, location_to_params +from .dataset import MapillaryDataModule +from .download import MapillaryDownloader, fetch_image_infos, fetch_image_list + + +def grid_downsample(xy: np.ndarray, bbox: BoundaryBox, resolution: float) -> np.ndarray: + assert bbox.contains(xy).all() + extent = bbox.max_ - bbox.min_ + size = np.ceil(extent / resolution).astype(int) + grid = np.full(size, -1) + idx = np.floor((xy - bbox.min_) / resolution).astype(int) + grid[tuple(idx.T)] = np.arange(len(xy)) + indices = grid[grid >= 0] + return indices + + +def find_validation_bbox( + xy: np.ndarray, target_num, size_upper_bound=500 +) -> BoundaryBox: + # Find the centroid of all points + center = np.median(xy, 0) + dist = np.linalg.norm(xy - center[None], axis=1) + center = xy[np.argmin(dist)] + + bbox = BoundaryBox(center - size_upper_bound, center + size_upper_bound) + mask = bbox.contains(xy) + dist = np.abs(xy[mask] - center).max(-1) + dist.sort() + thresh = dist[target_num] + bbox_val = BoundaryBox(center - thresh, center + thresh) + return bbox_val + + +def process_location( + output_path: Path, + token: str, + cfg: DictConfig, + query_bbox: BoundaryBox, + filters: Dict[str, Any], + bbox_val: Optional[BoundaryBox] = None, +): + output_path.parent.mkdir(parents=True, exist_ok=True) + downloader = MapillaryDownloader(token) + loop = asyncio.get_event_loop() + projection = Projection(*query_bbox.center) + bbox_local = projection.project(query_bbox) + + logger.info("Fetching the list of image with filter: %s", filters) + image_ids, bboxes = loop.run_until_complete( + fetch_image_list(query_bbox, downloader, **filters) + ) + if not image_ids: + raise ValueError("Could not find any image!") + logger.info("Found %d images.", len(image_ids)) + + logger.info("Fetching the image coordinates.") + infos, num_fail = loop.run_until_complete( + fetch_image_infos(image_ids, downloader, fields=["computed_geometry"]) + ) + logger.info("%d failures (%.1f%%).", num_fail, 100 * num_fail / len(image_ids)) + + # discard images that don't have coordinates available + image_ids = np.array([i for i in infos if "computed_geometry" in infos[i]]) + image_ids.sort() + + # discard images outside of the query bbox + latlon = np.array( + [infos[i]["computed_geometry"]["coordinates"][::-1] for i in image_ids] + ) + xy = projection.project(latlon) + valid = bbox_local.contains(xy) + image_ids = image_ids[valid] + latlon = latlon[valid] + xy = xy[valid] + + # downsample the images with a grid + indices = grid_downsample(xy, bbox_local, cfg.downsampling_resolution_meters) + image_ids = image_ids[indices] + latlon = latlon[indices] + xy = xy[indices] + logger.info("Filtered down to %d images.", len(image_ids)) + + if bbox_val is None: + bbox_val_local = find_validation_bbox(xy, cfg.target_num_val_images) + bbox_val = projection.unproject(bbox_val_local) + else: + bbox_val_local = projection.project(bbox_val) + logger.info("Using validation bounding box: %s.", bbox_val) + indices_val = np.where(bbox_val_local.contains(xy))[0] + bbox_not_train = bbox_val_local + cfg.val_train_margin_meters + indices_train = np.where(~bbox_not_train.contains(xy))[0] + if len(indices_train) > cfg.max_num_train_images: + indices_subsample = np.random.RandomState(0).choice( + len(indices_train), cfg.max_num_train_images + ) + indices_train = indices_train[indices_subsample] + logger.info( + "Resulting split: %d val and %d train images.", + len(indices_val), + len(indices_train), + ) + + splits = { + "val": image_ids[indices_val].tolist(), + "train": image_ids[indices_train].tolist(), + } + write_json(output_path, splits) + + # Visualize the data split + plotter = GeoPlotter() + plotter.points(latlon[indices_train], "red", image_ids[indices_train], "train") + plotter.points(latlon[indices_val], "green", image_ids[indices_val], "val") + plotter.bbox(query_bbox, "blue", "query bounding box") + plotter.bbox(bbox_val, "green", "validation bounding box") + plotter.bbox(projection.unproject(bbox_not_train), "black", "margin bounding box") + geo_viz_path = f"{output_path}_viz.html" + plotter.fig.write_html(geo_viz_path) + logger.info("Wrote split visualization to %s.", geo_viz_path) + + +def main(args: argparse.Namespace): + args.data_dir.mkdir(exist_ok=True, parents=True) + cfg = OmegaConf.merge(default_cfg, OmegaConf.from_cli(args.dotlist)) + for location in args.locations: + output_path = args.data_dir / args.output_filename.format(scene=location) + if output_path.exists() and not args.overwrite: + logger.info("Skipping processing for location %s.", location) + continue + logger.info("Starting processing for location %s.", location) + params = location_to_params[location] + process_location( + output_path, + args.token, + cfg, + params["bbox"], + params["filters"] | {"start_captured_at": args.min_capture_date}, + None if args.force_auto_val_bbox else params.get("bbox_val"), + ) + logger.info("Done processing for location %s.", location) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--locations", type=str, nargs="+", default=list(location_to_params) + ) + parser.add_argument("--token", type=str, required=True) + parser.add_argument("--min_capture_date", type=str, default="2015-01-01T00:00:00Z") + parser.add_argument( + "--output_filename", type=str, default="splits_MGL_v2_{scene}.json" + ) + parser.add_argument( + "--data_dir", type=Path, default=MapillaryDataModule.default_cfg["data_dir"] + ) + parser.add_argument("--overwrite", action="store_true") + parser.add_argument("--force_auto_val_bbox", action="store_true") + parser.add_argument("dotlist", nargs="*") + main(parser.parse_args()) From dded9aeaa71e8d9b68656bf468cc227e95a1b6c7 Mon Sep 17 00:00:00 2001 From: Paul-Edouard Sarlin Date: Mon, 19 Aug 2024 12:42:06 +0200 Subject: [PATCH 03/12] Add utils to download OSM data for large areas --- maploc/osm/download.py | 46 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/maploc/osm/download.py b/maploc/osm/download.py index 2c1f7a2..3f7e0cc 100644 --- a/maploc/osm/download.py +++ b/maploc/osm/download.py @@ -1,10 +1,12 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. import json +import subprocess from http.client import responses from pathlib import Path from typing import Any, Dict, Optional +import shapely import urllib3 from .. import logger @@ -18,6 +20,7 @@ def get_osm( cache_path: Optional[Path] = None, overwrite: bool = False, ) -> Dict[str, Any]: + """Fetch OSM data using the OSM API. Suitable only for small areas.""" if not overwrite and cache_path is not None and cache_path.is_file(): return json.loads(cache_path.read_text()) @@ -33,3 +36,46 @@ def get_osm( if cache_path is not None: cache_path.write_bytes(result.data) return result.json() + + +def get_geofabrik_index() -> Dict[str, Any]: + """Fetch the index of all regions served by Geofabrik.""" + with urllib3.request.urlopen("https://download.geofabrik.de/index-v1.json") as url: + return json.load(url) + + +def get_geofabrik_url(bbox: BoundaryBox) -> str: + """Find the smallest Geofabrik region file that includes a given area.""" + gf = get_geofabrik_index() + best_region = None + best_area = float("inf") + query_poly = shapely.box(*bbox.min_[::-1], *bbox.max_[::-1]) + for i, region in enumerate(gf["features"]): + coords = region["geometry"]["coordinates"] + # fix the polygon format + coords = [c if len(c[0]) < 2 else (c[0], c[1:]) for c in coords] + poly = shapely.MultiPolygon([shapely.Polygon(*c) for c in coords]) + if poly.contains(query_poly): + area = poly.area + if area < best_area: + best_area = area + best_region = region + return best_region["properties"]["urls"]["pbf"] + + +def convert_osm_file(bbox: BoundaryBox, input_path: Path, output_path: Path): + """Convert and crop a binary .pbf file to an .osm XML file.""" + bbox_str = ",".join(map(str, (*bbox.min_[::-1], *bbox.max_[::-1]))) + cmd = [ + "osmium", + "extract", + "-s", + "smart", + "-b", + bbox_str, + input_path.as_posix(), + "-o", + output_path.as_posix(), + "--overwrite", + ] + subprocess.run(cmd, check=True) From 91b31fcb0cb843d066568e68837d9a719f299f18 Mon Sep 17 00:00:00 2001 From: Paul-Edouard Sarlin Date: Mon, 19 Aug 2024 12:48:01 +0200 Subject: [PATCH 04/12] Download OSM data for any new region --- maploc/data/mapillary/prepare.py | 157 ++++++++++++++++++++----------- 1 file changed, 104 insertions(+), 53 deletions(-) diff --git a/maploc/data/mapillary/prepare.py b/maploc/data/mapillary/prepare.py index bbaed39..21e5a21 100644 --- a/maploc/data/mapillary/prepare.py +++ b/maploc/data/mapillary/prepare.py @@ -5,8 +5,9 @@ import json import shutil from collections import defaultdict +from enum import Enum, auto from pathlib import Path -from typing import List +from typing import Dict, List, Optional, Sequence import cv2 import numpy as np @@ -21,6 +22,7 @@ from tqdm.contrib.concurrent import thread_map from ... import logger +from ...osm.download import convert_osm_file, get_geofabrik_url from ...osm.tiling import TileManager from ...osm.viz import GeoPlotter from ...utils.geo import BoundaryBox, Projection @@ -174,24 +176,18 @@ def process_sequence( def process_location( - location: str, - data_dir: Path, - split_path: Path, + output_dir: Path, + bbox: BoundaryBox, + splits: Dict[str, Sequence[int]], token: str, cfg: DictConfig, - generate_tiles: bool = False, ): - params = location_to_params[location] - bbox = params["bbox"] projection = Projection(*bbox.center) + image_ids = [i for split in splits.values() for i in split] - splits = json.loads(split_path.read_text()) - image_ids = [i for split in splits.values() for i in split[location]] - - loc_dir = data_dir / location - infos_dir = loc_dir / "image_infos" - raw_image_dir = loc_dir / "images_raw" - out_image_dir = loc_dir / "images" + infos_dir = output_dir / "image_infos" + raw_image_dir = output_dir / "images_raw" + out_image_dir = output_dir / "images" for d in (infos_dir, raw_image_dir, out_image_dir): d.mkdir(parents=True, exist_ok=True) @@ -228,8 +224,28 @@ def process_location( out_image_dir, ) ) - write_json(loc_dir / "dump.json", dump) + write_json(output_dir / DATA_FILENAME, dump) + + shutil.rmtree(raw_image_dir) + + +class OSMDataSource(Enum): + PRECOMPUTED = auto() + CACHED = auto() + LATEST = auto() + +def prepare_osm( + location: str, + output_dir: Path, + bbox: BoundaryBox, + cfg: DictConfig, + osm_dir: Path, + osm_source: OSMDataSource, + osm_filename: Optional[str] = None, +): + projection = Projection(*bbox.center) + dump = json.reads((output_dir / DATA_FILENAME).read_text()) # Get the view locations view_ids = [] views_latlon = [] @@ -241,18 +257,31 @@ def process_location( view_ids = np.array(view_ids) views_xy = projection.project(views_latlon) - tiles_path = loc_dir / MapillaryDataModule.default_cfg["tiles_filename"] - if generate_tiles: + tiles_path = output_dir / MapillaryDataModule.default_cfg["tiles_filename"] + if osm_source == OSMDataSource.PRECOMPUtED: + logger.info("Downloading pre-computed map tiles.") + download_file(DATA_URL + f"/tiles/{location}.pkl", tiles_path) + tile_manager = TileManager.load(tiles_path) + else: logger.info("Creating the map tiles.") bbox_data = BoundaryBox(views_xy.min(0), views_xy.max(0)) bbox_tiling = bbox_data + cfg.tiling.margin - osm_dir = data_dir / "osm" - osm_path = osm_dir / params["osm_file"] - if not osm_path.exists(): - logger.info("Downloading OSM raw data.") - download_file(DATA_URL + f"/osm/{params['osm_file']}", osm_path) - if not osm_path.exists(): - raise FileNotFoundError(f"Cannot find OSM data file {osm_path}.") + osm_filename = osm_filename or f"{location}.osm" + osm_path = osm_dir / osm_filename + if osm_source == OSMDataSource.CACHED: + if not osm_path.exists(): + logger.info("Downloading OSM raw data.") + download_file(DATA_URL + f"/osm/{osm_filename}", osm_path) + if not osm_path.exists(): + raise FileNotFoundError(f"Cannot find OSM data file {osm_path}.") + elif osm_source == OSMDataSource.LATEST: + bbox_osm = projection.unproject(bbox_data + 2_000) # 2 km + url = get_geofabrik_url(bbox_osm) + tmp_path = osm_dir / Path(url).name + download_file(url, tmp_path) + convert_osm_file(bbox_osm, tmp_path, osm_path) + else: + raise NotImplementedError("Unknown source {osm_source}.") tile_manager = TileManager.from_bbox( projection, bbox_tiling, @@ -261,27 +290,60 @@ def process_location( path=osm_path, ) tile_manager.save(tiles_path) - else: - logger.info("Downloading pre-generated map tiles.") - download_file(DATA_URL + f"/tiles/{location}.pkl", tiles_path) - tile_manager = TileManager.load(tiles_path) # Visualize the data split plotter = GeoPlotter() - view_ids_val = set(splits["val"][location]) - is_val = np.array([int(i.rsplit("_", 1)[0]) in view_ids_val for i in view_ids]) - plotter.points(views_latlon[~is_val], "red", view_ids[~is_val], "train") - plotter.points(views_latlon[is_val], "green", view_ids[is_val], "val") + plotter.points(views_latlon, "red", view_ids, "images") plotter.bbox(bbox, "blue", "query bounding box") plotter.bbox( projection.unproject(tile_manager.bbox), "black", "tiling bounding box" ) - geo_viz_path = loc_dir / f"split_{location}.html" + geo_viz_path = output_dir / f"viz_data_{location}.html" plotter.fig.write_html(geo_viz_path) - logger.info("Wrote split visualization to %s.", geo_viz_path) + logger.info("Wrote the visualization to %s.", geo_viz_path) - shutil.rmtree(raw_image_dir) - logger.info("Done processing for location %s.", location) + +def main(args: argparse.Namespace): + args.data_dir.mkdir(exist_ok=True, parents=True) + + split_path = args.data_dir / args.split_filename + maybe_git_split = Path(__file__).parent / args.split_filename + if maybe_git_split.exists(): + logger.info("Using official split file at %s.", maybe_git_split) + shutil.copy(maybe_git_split, args.data_dir) + + cfg = OmegaConf.merge(default_cfg, OmegaConf.from_cli(args.dotlist)) + for location in args.locations: + logger.info("Starting processing for location %s.", location) + if split_path.exists(): + splits = json.loads(split_path.read_text()) + splits = {split_name: val[location] for split_name, val in splits.items()} + else: + split_path_ = Path(str(split_path.format(scene=location))) + if not split_path_.exists(): + raise ValueError(f"Cannot find any split file at path {split_path}.") + logger.info("Using per-location split file at %s.", split_path_) + splits = json.loads(split_path_.read_text()) + + process_location( + args.data_dir / location, + location_to_params[location]["bbox"], + splits, + args.token, + cfg, + ) + + logger.info("Preparing OSM data.") + prepare_osm( + location, + args.data_dir / location, + location_to_params[location]["bbox"], + cfg, + args.data_dir / "osm", + OSMDataSource[args.osm_source], + location_to_params[location].get("osm_file"), + ) + logger.info("Done processing for location %s.", location) if __name__ == "__main__": @@ -294,21 +356,10 @@ def process_location( parser.add_argument( "--data_dir", type=Path, default=MapillaryDataModule.default_cfg["data_dir"] ) - parser.add_argument("--generate_tiles", action="store_true") + parser.add_argument( + "--osm_source", + default=OSMDataSource.PRECOMPUTED.name, + choices=[e.name for e in OSMDataSource], + ) parser.add_argument("dotlist", nargs="*") - args = parser.parse_args() - - args.data_dir.mkdir(exist_ok=True, parents=True) - shutil.copy(Path(__file__).parent / args.split_filename, args.data_dir) - cfg_ = OmegaConf.merge(default_cfg, OmegaConf.from_cli(args.dotlist)) - - for location in args.locations: - logger.info("Starting processing for location %s.", location) - process_location( - location, - args.data_dir, - args.data_dir / args.split_filename, - args.token, - cfg_, - args.generate_tiles, - ) + main(parser.parse_args()) From d0e55aaf7cc87bd15e8606522fe2de82553f830e Mon Sep 17 00:00:00 2001 From: Paul-Edouard Sarlin Date: Mon, 19 Aug 2024 12:48:14 +0200 Subject: [PATCH 05/12] Support per-location split file --- maploc/data/mapillary/dataset.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/maploc/data/mapillary/dataset.py b/maploc/data/mapillary/dataset.py index 8202026..dfdce42 100644 --- a/maploc/data/mapillary/dataset.py +++ b/maploc/data/mapillary/dataset.py @@ -253,18 +253,28 @@ def parse_splits(self, split_arg, names): "val": [n for n in names if n[0] in scenes_val], } elif isinstance(split_arg, str): - with (self.root / split_arg).open("r") as fp: - splits = json.load(fp) + if (self.root / split_arg).exists(): + # Common split file. + with (self.root / split_arg).open("r") as fp: + splits = json.load(fp) + else: + # Per-scene split file. + splits = defaultdict(dict) + for scene in self.cfg.scenes: + with (self.root / split_arg.format(scene=scene)).open("r") as fp: + scene_splits = json.load(fp) + for split_name in scene_splits: + splits[split_name][scene] = scene_splits[split_name] splits = { - k: {loc: set(ids) for loc, ids in split.items()} - for k, split in splits.items() + split_name: {scene: set(ids) for scene, ids in split.items()} + for split_name, split in splits.items() } self.splits = {} - for k, split in splits.items(): - self.splits[k] = [ - n - for n in names - if n[0] in split and int(n[-1].rsplit("_", 1)[0]) in split[n[0]] + for split_name, split in splits.items(): + self.splits[split_name] = [ + (scene, *arg, name) + for scene, *arg, name in names + if scene in split and int(name.rsplit("_", 1)[0]) in split[scene] ] else: raise ValueError(split_arg) From 6721949d4083f58a05aa9575b21af82b1a41ddfc Mon Sep 17 00:00:00 2001 From: Paul-Edouard Sarlin Date: Mon, 19 Aug 2024 12:52:50 +0200 Subject: [PATCH 06/12] Minor improvements --- maploc/data/mapillary/config.py | 6 ++++++ maploc/data/mapillary/split.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/maploc/data/mapillary/config.py b/maploc/data/mapillary/config.py index 2bcba8f..df46e24 100644 --- a/maploc/data/mapillary/config.py +++ b/maploc/data/mapillary/config.py @@ -126,6 +126,12 @@ "filters": {"is_pano": True}, "osm_file": "paris.osm", }, + # Add any new region/city here: + # "location_name": { + # "bbox": BoundaryBox((lat_min, long_min), (lat_max, long_max)), + # "filters": {"is_pano": True}, # or other filers + # } + # Other fields (bbox_val, osm_file) will be deduced automatically. } default_cfg = OmegaConf.create( diff --git a/maploc/data/mapillary/split.py b/maploc/data/mapillary/split.py index d6bc575..281a6c9 100644 --- a/maploc/data/mapillary/split.py +++ b/maploc/data/mapillary/split.py @@ -145,7 +145,7 @@ def main(args: argparse.Namespace): args.token, cfg, params["bbox"], - params["filters"] | {"start_captured_at": args.min_capture_date}, + {"start_captured_at": args.min_capture_date} | params["filters"], None if args.force_auto_val_bbox else params.get("bbox_val"), ) logger.info("Done processing for location %s.", location) From 351bbf92ed49894fb936efba8c8cb8e78cd80ab8 Mon Sep 17 00:00:00 2001 From: Paul-Edouard Sarlin <15985472+sarlinpe@users.noreply.github.com> Date: Mon, 19 Aug 2024 17:13:17 +0200 Subject: [PATCH 07/12] Update README.md --- README.md | 74 +++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 56 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index a4fb7c6..09499ae 100644 --- a/README.md +++ b/README.md @@ -69,23 +69,72 @@ Try our minimal demo - take a picture with your phone in any city and find its e OrienterNet positions any image within a large area - try it with your own images!

-## Evaluation +## Mapillary Geo-Localization (MGL) dataset + +To train and evaluate OrienterNet, we introduce a large crowd-sourced dataset of images captured across multiple cities through the [Mapillary platform](https://www.mapillary.com/app/). To obtain the dataset: -#### Mapillary Geo-Localization dataset +1. Create a developper account at [mapillary.com](https://www.mapillary.com/dashboard/developers) and obtain a free access token. +2. Run the following script to download the data from Mapillary and prepare it: + +```bash +python -m maploc.data.mapillary.prepare --token $YOUR_ACCESS_TOKEN +``` + +By default the data is written to the directory `./datasets/MGL/` and requires about 80 GB of free disk space. + +#### Using different OpenStreetMap data
[Click to expand] -To obtain the dataset: +Multiple sources of OpenStreetMap (OSM) data can be selected for the dataset scripts `maploc.data.[mapillary,kitti].prepare` using the `--osm_source` option, which can take the following values: +- `PRECOMPUTED` (default): download pre-computed raster tiles that are hosted [here](https://cvg-data.inf.ethz.ch/OrienterNet_CVPR2023/tiles/). +- `CACHED`: compute the raster tiles from raw OSM data downloaded from [Geofabrik](https://download.geofabrik.de/) in November 2021 and hosted [here](https://cvg-data.inf.ethz.ch/OrienterNet_CVPR2023/osm/). This is useful if you wish to use different OSM classes but want to compare the results to the pre-computed tiles. +- `LATEST`: fetch the latest OSM data from [Geofabrik](https://download.geofabrik.de/). This requires that the [Osmium tool](https://osmcode.org/osmium-tool/) is available in your system, which can be downloaded via `apt-get install osmium-tool` on Ubuntu and `brew install osmium-tool` on macOS. -1. Create a developper account at [mapillary.com](https://www.mapillary.com/dashboard/developers) and obtain a free access token. -2. Run the following script to download the data from Mapillary and prepare it: +
+ +#### Extending the dataset + +
+[Click to expand] + +By default, the dataset script fetches data that was queried early 2022 from 12 cities. The dataset can be extended by including additional cities or querying images recently uploaded to Mapillary. To proceed, follow these steps: +1. For each new location, add an entry to `maploc.data.mapillary.config.location_to_params` following the format: +```python + "location_name": { + "bbox": BoundaryBox((lat_min, long_min), (lat_max, long_max)), + "filters": {"is_pano": True}, + # or other filters like creator_username, model, etc. + # all described at https://www.mapillary.com/developer/api-documentation#image + } +``` +The bounding box can easily be selected using [this tool](https://boundingbox.klokantech.com/). We recommend searching for cities with a high density of 360 panoramic images on the [Mapillary platform](https://www.mapillary.com/app/). +2. Query the corresponding images and split them into training and valiation subsets with: ```bash -python -m maploc.data.mapillary.prepare --token $YOUR_ACCESS_TOKEN +python -m maploc.data.mapillary.split --token $YOUR_ACCESS_TOKEN --output_filename splits_MGL_v2_{scene}.json --data_dir datasets/MGL_v2 +``` +3. Fetch and prepare the resulting data: +```bash +python -m maploc.data.mapillary.prepare --token $YOUR_ACCESS_TOKEN --split_filename splits_MGL_v2_{scene}.json --data_dir datasets/MGL_v2 +``` +4. To train or evaluate with this new version of the dataset, add the following CLI flags: +```bash +python -m maploc.[train,evaluation...] [...] data.data_dir=datasets/MGL_v2 data.split=splits_MGL_v2_{scene}.json ``` -By default the data is written to the directory `./datasets/MGL/`. Then run the evaluation with the pre-trained model: +
+ + +## Evaluation + +#### MGL dataset + +
+[Click to expand] + +Download the dataset [as described previously](#mapillary-geo-localization-mgl-dataset) and run the evaluation with the pre-trained model: ```bash python -m maploc.evaluation.mapillary --experiment OrienterNet_MGL model.num_rotations=256 @@ -218,17 +267,6 @@ We provide several visualization notebooks: - [Visualize predictions on the KITTI dataset](./notebooks/visualize_predictions_kitti.ipynb) - [Visualize sequential predictions](./notebooks/visualize_predictions_sequences.ipynb) -## OpenStreetMap data - -
-[Click to expand] - -To make sure that the results are consistent over time, we used OSM data downloaded from [Geofabrik](https://download.geofabrik.de/) in November 2021. By default, the dataset scripts `maploc.data.[mapillary,kitti].prepare` download pre-generated raster tiles. If you wish to use different OSM classes, you can pass `--generate_tiles`, which will download and use our prepared raw `.osm` XML files. - -You may alternatively download more recent files from [Geofabrik](https://download.geofabrik.de/). Download either compressed XML files as `.osm.bz2` or binary files `.osm.pbf`, which need to be converted to XML files `.osm`, for example using Osmium: ` osmium cat xx.osm.pbf -o xx.osm`. - -
- ## License The MGL dataset is made available under the [CC-BY-SA](https://creativecommons.org/licenses/by-sa/4.0/) license following the data available on the Mapillary platform. The model implementation and the pre-trained weights follow a [CC-BY-NC](https://creativecommons.org/licenses/by-nc/2.0/) license. [OpenStreetMap data](https://www.openstreetmap.org/copyright) is licensed under the [Open Data Commons Open Database License](https://opendatacommons.org/licenses/odbl/). From 7faa8355283c7100bac1d11d4e23bcac89fe9989 Mon Sep 17 00:00:00 2001 From: Paul-Edouard Sarlin Date: Mon, 19 Aug 2024 17:26:38 +0200 Subject: [PATCH 08/12] Bug fixes --- maploc/data/mapillary/prepare.py | 6 +++--- maploc/osm/download.py | 9 +++++++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/maploc/data/mapillary/prepare.py b/maploc/data/mapillary/prepare.py index 21e5a21..e63b75c 100644 --- a/maploc/data/mapillary/prepare.py +++ b/maploc/data/mapillary/prepare.py @@ -196,7 +196,7 @@ def process_location( logger.info("Fetching metadata for all images.") image_infos, num_fail = loop.run_until_complete( - fetch_image_infos(image_ids, downloader, infos_dir) + fetch_image_infos(image_ids, downloader, dir_=infos_dir) ) logger.info("%d failures (%.1f%%).", num_fail, 100 * num_fail / len(image_ids)) @@ -245,7 +245,7 @@ def prepare_osm( osm_filename: Optional[str] = None, ): projection = Projection(*bbox.center) - dump = json.reads((output_dir / DATA_FILENAME).read_text()) + dump = json.loads((output_dir / DATA_FILENAME).read_text()) # Get the view locations view_ids = [] views_latlon = [] @@ -319,7 +319,7 @@ def main(args: argparse.Namespace): splits = json.loads(split_path.read_text()) splits = {split_name: val[location] for split_name, val in splits.items()} else: - split_path_ = Path(str(split_path.format(scene=location))) + split_path_ = Path(str(split_path).format(scene=location)) if not split_path_.exists(): raise ValueError(f"Cannot find any split file at path {split_path}.") logger.info("Using per-location split file at %s.", split_path_) diff --git a/maploc/osm/download.py b/maploc/osm/download.py index 3f7e0cc..47f053a 100644 --- a/maploc/osm/download.py +++ b/maploc/osm/download.py @@ -40,8 +40,13 @@ def get_osm( def get_geofabrik_index() -> Dict[str, Any]: """Fetch the index of all regions served by Geofabrik.""" - with urllib3.request.urlopen("https://download.geofabrik.de/index-v1.json") as url: - return json.load(url) + result = urllib3.request( + "GET", "https://download.geofabrik.de/index-v1.json", timeout=10 + ) + if result.status != 200: + error = result.info()["error"] + raise ValueError(f"{result.status} {responses[result.status]}: {error}") + return json.loads(result.data) def get_geofabrik_url(bbox: BoundaryBox) -> str: From f7d2b8142e64431e8ebf70944fa589b4a0789e02 Mon Sep 17 00:00:00 2001 From: Paul-Edouard Sarlin Date: Mon, 19 Aug 2024 17:28:24 +0200 Subject: [PATCH 09/12] Extend multiple OSM sources to KITTI --- maploc/data/kitti/prepare.py | 61 ++++++++++++++---------------- maploc/data/mapillary/prepare.py | 57 ++++++++-------------------- maploc/osm/prepare.py | 64 ++++++++++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 76 deletions(-) create mode 100644 maploc/osm/prepare.py diff --git a/maploc/data/kitti/prepare.py b/maploc/data/kitti/prepare.py index 8adef2b..278a761 100644 --- a/maploc/data/kitti/prepare.py +++ b/maploc/data/kitti/prepare.py @@ -9,10 +9,10 @@ from tqdm.auto import tqdm from ... import logger -from ...osm.tiling import TileManager +from ...osm.prepare import OSMDataSource, download_and_prepare_osm from ...osm.viz import GeoPlotter from ...utils.geo import BoundaryBox, Projection -from ...utils.io import DATA_URL, download_file +from ...utils.io import download_file from .dataset import KittiDataModule from .utils import parse_gps_file @@ -20,11 +20,10 @@ def prepare_osm( - data_dir, - osm_path, - output_path, - tile_margin=512, + data_dir: Path, + osm_source: OSMDataSource, ppm=2, + tile_margin=512, ): all_latlon = [] for gps_path in data_dir.glob("2011_*/*/oxts/data/*.txt"): @@ -34,21 +33,26 @@ def prepare_osm( all_latlon = np.stack(all_latlon) projection = Projection.from_points(all_latlon) all_xy = projection.project(all_latlon) - bbox_map = BoundaryBox(all_xy.min(0), all_xy.max(0)) + tile_margin + bbox_tiling = BoundaryBox(all_xy.min(0), all_xy.max(0)) + tile_margin + + tiles_path = data_dir / KittiDataModule.default_cfg["tiles_filename"] + osm_path = data_dir / "karlsruhe.osm" + tile_manager = download_and_prepare_osm( + osm_source, + "kitti", + tiles_path, + bbox_tiling, + projection, + osm_path, + ppm=ppm, + ) plotter = GeoPlotter() plotter.points(all_latlon, "red", name="GPS") - plotter.bbox(projection.unproject(bbox_map), "blue", "tiling bounding box") - plotter.fig.write_html(data_dir / "split_kitti.html") - - tile_manager = TileManager.from_bbox( - projection, - bbox_map, - ppm, - path=osm_path, + plotter.bbox( + projection.unproject(tile_manager.bbox), "black", "tiling bounding box" ) - tile_manager.save(output_path) - return tile_manager + plotter.fig.write_html(data_dir / "viz_kitti.html") def download(data_dir: Path): @@ -99,24 +103,13 @@ def download(data_dir: Path): "--data_dir", type=Path, default=Path(KittiDataModule.default_cfg["data_dir"]) ) parser.add_argument("--pixel_per_meter", type=int, default=2) - parser.add_argument("--generate_tiles", action="store_true") + parser.add_argument( + "--osm_source", + default=OSMDataSource.PRECOMPUTED.name, + choices=[e.name for e in OSMDataSource], + ) args = parser.parse_args() args.data_dir.mkdir(exist_ok=True, parents=True) download(args.data_dir) - - tiles_path = args.data_dir / KittiDataModule.default_cfg["tiles_filename"] - if args.generate_tiles: - logger.info("Generating the map tiles.") - osm_filename = "karlsruhe.osm" - osm_path = args.data_dir / osm_filename - if not osm_path.exists(): - logger.info("Downloading OSM raw data.") - download_file(DATA_URL + f"/osm/{osm_filename}", osm_path) - if not osm_path.exists(): - raise FileNotFoundError(f"No OSM data file at {osm_path}.") - prepare_osm(args.data_dir, osm_path, tiles_path, ppm=args.pixel_per_meter) - (args.data_dir / ".downloaded").touch() - else: - logger.info("Downloading pre-generated map tiles.") - download_file(DATA_URL + "/tiles/kitti.pkl", tiles_path) + prepare_osm(args.data_dir, OSMDataSource[args.osm_source], ppm=args.pixel_per_meter) diff --git a/maploc/data/mapillary/prepare.py b/maploc/data/mapillary/prepare.py index e63b75c..6c11178 100644 --- a/maploc/data/mapillary/prepare.py +++ b/maploc/data/mapillary/prepare.py @@ -5,7 +5,6 @@ import json import shutil from collections import defaultdict -from enum import Enum, auto from pathlib import Path from typing import Dict, List, Optional, Sequence @@ -22,11 +21,10 @@ from tqdm.contrib.concurrent import thread_map from ... import logger -from ...osm.download import convert_osm_file, get_geofabrik_url -from ...osm.tiling import TileManager +from ...osm.prepare import OSMDataSource, download_and_prepare_osm from ...osm.viz import GeoPlotter from ...utils.geo import BoundaryBox, Projection -from ...utils.io import DATA_URL, download_file, write_json +from ...utils.io import write_json from ..utils import decompose_rotmat from .config import default_cfg, location_to_params from .dataset import MapillaryDataModule @@ -229,12 +227,6 @@ def process_location( shutil.rmtree(raw_image_dir) -class OSMDataSource(Enum): - PRECOMPUTED = auto() - CACHED = auto() - LATEST = auto() - - def prepare_osm( location: str, output_dir: Path, @@ -258,38 +250,19 @@ def prepare_osm( views_xy = projection.project(views_latlon) tiles_path = output_dir / MapillaryDataModule.default_cfg["tiles_filename"] - if osm_source == OSMDataSource.PRECOMPUtED: - logger.info("Downloading pre-computed map tiles.") - download_file(DATA_URL + f"/tiles/{location}.pkl", tiles_path) - tile_manager = TileManager.load(tiles_path) - else: - logger.info("Creating the map tiles.") - bbox_data = BoundaryBox(views_xy.min(0), views_xy.max(0)) - bbox_tiling = bbox_data + cfg.tiling.margin - osm_filename = osm_filename or f"{location}.osm" - osm_path = osm_dir / osm_filename - if osm_source == OSMDataSource.CACHED: - if not osm_path.exists(): - logger.info("Downloading OSM raw data.") - download_file(DATA_URL + f"/osm/{osm_filename}", osm_path) - if not osm_path.exists(): - raise FileNotFoundError(f"Cannot find OSM data file {osm_path}.") - elif osm_source == OSMDataSource.LATEST: - bbox_osm = projection.unproject(bbox_data + 2_000) # 2 km - url = get_geofabrik_url(bbox_osm) - tmp_path = osm_dir / Path(url).name - download_file(url, tmp_path) - convert_osm_file(bbox_osm, tmp_path, osm_path) - else: - raise NotImplementedError("Unknown source {osm_source}.") - tile_manager = TileManager.from_bbox( - projection, - bbox_tiling, - cfg.tiling.ppm, - tile_size=cfg.tiling.tile_size, - path=osm_path, - ) - tile_manager.save(tiles_path) + bbox_data = BoundaryBox(views_xy.min(0), views_xy.max(0)) + bbox_tiling = bbox_data + cfg.tiling.margin + osm_path = osm_dir / (osm_filename or f"{location}.osm") + tile_manager = download_and_prepare_osm( + osm_source, + location, + tiles_path, + bbox_tiling, + projection, + osm_path, + ppm=cfg.tiling.ppm, + tile_size=cfg.tiling.tile_size, + ) # Visualize the data split plotter = GeoPlotter() diff --git a/maploc/osm/prepare.py b/maploc/osm/prepare.py new file mode 100644 index 0000000..5f488f4 --- /dev/null +++ b/maploc/osm/prepare.py @@ -0,0 +1,64 @@ +import logging +from enum import Enum, auto +from pathlib import Path + +from ..utils.geo import BoundaryBox, Projection +from ..utils.io import DATA_URL, download_file +from .download import convert_osm_file, get_geofabrik_url +from .tiling import TileManager + +logger = logging.getLogger(__name__) + + +class OSMDataSource(Enum): + # Pre-computed map tiles. + PRECOMPUTED = auto() + + # Re-compute the map tiles from cached OSM data. + CACHED = auto() + + # Fetch the latest OSM data and re-compute the map tiles from them. + LATEST = auto() + + +def download_and_prepare_osm( + source: OSMDataSource, + tiles_name: str, + tiles_path: Path, + bbox: BoundaryBox, + projection: Projection, + osm_path: Path, + **kwargs, +) -> TileManager: + if source == OSMDataSource.PRECOMPUTED: + logger.info("Downloading pre-computed map tiles.") + download_file(DATA_URL + f"/tiles/{tiles_name}.pkl", tiles_path) + tile_manager = TileManager.load(tiles_path) + assert tile_manager.ppm == kwargs["ppm"] + assert tile_manager.bbox.contains(bbox) + else: + logger.info("Creating the map tiles.") + if source == OSMDataSource.CACHED: + if not osm_path.exists(): + logger.info("Downloading cached OSM data.") + download_file(DATA_URL + f"/osm/{osm_path.name}", osm_path) + if not osm_path.exists(): + raise FileNotFoundError(f"Cannot find OSM data file {osm_path}.") + elif source == OSMDataSource.LATEST: + logger.info("Downloading the latest OSM data.") + bbox_osm = projection.unproject(bbox + 2_000) # 2 km + url = get_geofabrik_url(bbox_osm) + tmp_path = osm_path.parent / Path(url).name + download_file(url, tmp_path) + convert_osm_file(bbox_osm, tmp_path, osm_path) + tmp_path.unlink() + else: + raise NotImplementedError("Unknown source: {osm_source}.") + tile_manager = TileManager.from_bbox( + projection, + bbox, + path=osm_path, + **kwargs, + ) + tile_manager.save(tiles_path) + return tile_manager From 77c1eb4e07ebdd25d756b8f6e4df0ff29aa1c21e Mon Sep 17 00:00:00 2001 From: Paul-Edouard Sarlin Date: Mon, 19 Aug 2024 17:48:07 +0200 Subject: [PATCH 10/12] Minor fixes --- maploc/data/mapillary/config.py | 4 +++- maploc/data/mapillary/prepare.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/maploc/data/mapillary/config.py b/maploc/data/mapillary/config.py index df46e24..6053dfa 100644 --- a/maploc/data/mapillary/config.py +++ b/maploc/data/mapillary/config.py @@ -129,7 +129,9 @@ # Add any new region/city here: # "location_name": { # "bbox": BoundaryBox((lat_min, long_min), (lat_max, long_max)), - # "filters": {"is_pano": True}, # or other filers + # "filters": {"is_pano": True}, + # # or other filters like creator_username, model, etc. + # # all described at https://www.mapillary.com/developer/api-documentation#image # } # Other fields (bbox_val, osm_file) will be deduced automatically. } diff --git a/maploc/data/mapillary/prepare.py b/maploc/data/mapillary/prepare.py index 6c11178..1531381 100644 --- a/maploc/data/mapillary/prepare.py +++ b/maploc/data/mapillary/prepare.py @@ -48,7 +48,7 @@ def get_pano_offset(image_info: dict, do_legacy: bool = False) -> float: - if do_legacy: + if do_legacy and "sfm_cluster" in image_info: seed = int(image_info["sfm_cluster"]["id"]) else: seed = image_info["sequence"].__hash__() @@ -107,7 +107,7 @@ def pack_shot_dict(shot: Shot, info: dict) -> dict: capture_time=info["captured_at"], gps_position=np.r_[latlong_gps, info["altitude"]], compass_angle=info["compass_angle"], - chunk_id=int(info["sfm_cluster"]["id"]), + chunk_id=int(info["sfm_cluster"]["id"]) if "sfm_cluster" in info else -1, ) From 9c02bfa8c006e5c32ea96877b047a6017611bd2a Mon Sep 17 00:00:00 2001 From: Paul-Edouard Sarlin <15985472+sarlinpe@users.noreply.github.com> Date: Mon, 19 Aug 2024 18:03:14 +0200 Subject: [PATCH 11/12] Update README.md --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 09499ae..b01b40f 100644 --- a/README.md +++ b/README.md @@ -99,7 +99,7 @@ Multiple sources of OpenStreetMap (OSM) data can be selected for the dataset scr
[Click to expand] -By default, the dataset script fetches data that was queried early 2022 from 12 cities. The dataset can be extended by including additional cities or querying images recently uploaded to Mapillary. To proceed, follow these steps: +By default, the dataset script fetches data that was queried early 2022 from 13 locations. The dataset can be extended by including additional cities or querying images recently uploaded to Mapillary. To proceed, follow these steps: 1. For each new location, add an entry to `maploc.data.mapillary.config.location_to_params` following the format: ```python "location_name": { @@ -115,6 +115,8 @@ The bounding box can easily be selected using [this tool](https://boundingbox.kl ```bash python -m maploc.data.mapillary.split --token $YOUR_ACCESS_TOKEN --output_filename splits_MGL_v2_{scene}.json --data_dir datasets/MGL_v2 ``` +Note that, for the 13 default locations, running this script will produce results slightly different from the default split file `splits_MGL_13loc.json` since new images have been uploaded since 2022 and some others have been taken down. + 3. Fetch and prepare the resulting data: ```bash python -m maploc.data.mapillary.prepare --token $YOUR_ACCESS_TOKEN --split_filename splits_MGL_v2_{scene}.json --data_dir datasets/MGL_v2 From 60cfc42341ee4cd8843de846b2c20e15b445fd45 Mon Sep 17 00:00:00 2001 From: Paul-Edouard Sarlin Date: Tue, 20 Aug 2024 22:53:05 +0200 Subject: [PATCH 12/12] Filter imgaes that intersect with buildings --- maploc/data/mapillary/config.py | 2 +- maploc/data/mapillary/prepare.py | 7 +++- maploc/data/mapillary/split.py | 65 ++++++++++++++++++++++++++++---- maploc/osm/prepare.py | 5 ++- 4 files changed, 67 insertions(+), 12 deletions(-) diff --git a/maploc/data/mapillary/config.py b/maploc/data/mapillary/config.py index 6053dfa..1c84f5a 100644 --- a/maploc/data/mapillary/config.py +++ b/maploc/data/mapillary/config.py @@ -138,7 +138,7 @@ default_cfg = OmegaConf.create( { - "downsampling_resolution_meters": 2, + "downsampling_resolution_meters": 3, "target_num_val_images": 100, "val_train_margin_meters": 25, "max_num_train_images": 50_000, diff --git a/maploc/data/mapillary/prepare.py b/maploc/data/mapillary/prepare.py index 1531381..c9b2891 100644 --- a/maploc/data/mapillary/prepare.py +++ b/maploc/data/mapillary/prepare.py @@ -236,6 +236,11 @@ def prepare_osm( osm_source: OSMDataSource, osm_filename: Optional[str] = None, ): + tiles_path = output_dir / MapillaryDataModule.default_cfg["tiles_filename"] + if tiles_path.exists(): + return + + logger.info("Preparing OSM data.") projection = Projection(*bbox.center) dump = json.loads((output_dir / DATA_FILENAME).read_text()) # Get the view locations @@ -249,7 +254,6 @@ def prepare_osm( view_ids = np.array(view_ids) views_xy = projection.project(views_latlon) - tiles_path = output_dir / MapillaryDataModule.default_cfg["tiles_filename"] bbox_data = BoundaryBox(views_xy.min(0), views_xy.max(0)) bbox_tiling = bbox_data + cfg.tiling.margin osm_path = osm_dir / (osm_filename or f"{location}.osm") @@ -306,7 +310,6 @@ def main(args: argparse.Namespace): cfg, ) - logger.info("Preparing OSM data.") prepare_osm( location, args.data_dir / location, diff --git a/maploc/data/mapillary/split.py b/maploc/data/mapillary/split.py index 281a6c9..718a942 100644 --- a/maploc/data/mapillary/split.py +++ b/maploc/data/mapillary/split.py @@ -7,6 +7,8 @@ from omegaconf import DictConfig, OmegaConf from ... import logger +from ...osm.parser import Groups +from ...osm.prepare import OSMDataSource, download_and_prepare_osm from ...osm.viz import GeoPlotter from ...utils.geo import BoundaryBox, Projection from ...utils.io import write_json @@ -44,19 +46,40 @@ def find_validation_bbox( def process_location( - output_path: Path, + split_path: Path, + output_dir: Path, + osm_dir: Path, + osm_source: OSMDataSource, token: str, cfg: DictConfig, query_bbox: BoundaryBox, filters: Dict[str, Any], bbox_val: Optional[BoundaryBox] = None, + osm_filename: Optional[str] = None, ): - output_path.parent.mkdir(parents=True, exist_ok=True) - downloader = MapillaryDownloader(token) - loop = asyncio.get_event_loop() + split_path.parent.mkdir(parents=True, exist_ok=True) + output_dir.mkdir(parents=True, exist_ok=True) projection = Projection(*query_bbox.center) bbox_local = projection.project(query_bbox) + logger.info("Fetching OpenStreetMap data.") + tiles_path = output_dir / MapillaryDataModule.default_cfg["tiles_filename"] + bbox_tiling = query_bbox + cfg.tiling.margin + osm_path = osm_dir / (osm_filename or f"{output_dir.name}.osm") + tile_manager = download_and_prepare_osm( + osm_source, + output_dir.name, + tiles_path, + bbox_tiling, + projection, + osm_path, + ppm=cfg.tiling.ppm, + tile_size=cfg.tiling.tile_size, + ) + + downloader = MapillaryDownloader(token) + loop = asyncio.get_event_loop() + logger.info("Fetching the list of image with filter: %s", filters) image_ids, bboxes = loop.run_until_complete( fetch_image_list(query_bbox, downloader, **filters) @@ -85,6 +108,25 @@ def process_location( latlon = latlon[valid] xy = xy[valid] + # discard sequences that intersect with buildings + canvas = tile_manager.query(bbox_local) + building_mask = canvas.raster[0] == (Groups.areas.index("building") + 1) + is_building = building_mask[tuple(np.round(canvas.to_uv(xy)).astype(int).T[::-1])] + logger.info( + "%d images intersect with buildings (%.1f%%).", + is_building.sum(), + 100 * is_building.mean(), + ) + # plotter = GeoPlotter() + # plotter.points(latlon[is_building], "red", name="building") + # plotter.points(latlon[~is_building], "green", name="no building") + # plotter.bbox(query_bbox, "blue", "query bounding box") + # plotter.fig.write_html(f"{split_path}_building_viz.html") + # TODO: filter sequences that intersect with buildings by at least 30%. + image_ids = image_ids[~is_building] + latlon = latlon[~is_building] + xy = xy[~is_building] + # downsample the images with a grid indices = grid_downsample(xy, bbox_local, cfg.downsampling_resolution_meters) image_ids = image_ids[indices] @@ -116,7 +158,7 @@ def process_location( "val": image_ids[indices_val].tolist(), "train": image_ids[indices_train].tolist(), } - write_json(output_path, splits) + write_json(split_path, splits) # Visualize the data split plotter = GeoPlotter() @@ -125,7 +167,7 @@ def process_location( plotter.bbox(query_bbox, "blue", "query bounding box") plotter.bbox(bbox_val, "green", "validation bounding box") plotter.bbox(projection.unproject(bbox_not_train), "black", "margin bounding box") - geo_viz_path = f"{output_path}_viz.html" + geo_viz_path = f"{split_path}_viz.html" plotter.fig.write_html(geo_viz_path) logger.info("Wrote split visualization to %s.", geo_viz_path) @@ -142,11 +184,15 @@ def main(args: argparse.Namespace): params = location_to_params[location] process_location( output_path, + args.data_dir / location, + args.data_dir / "osm", + OSMDataSource[args.osm_source], args.token, cfg, params["bbox"], {"start_captured_at": args.min_capture_date} | params["filters"], - None if args.force_auto_val_bbox else params.get("bbox_val"), + bbox_val=None if args.force_auto_val_bbox else params.get("bbox_val"), + osm_filename=location_to_params[location].get("osm_file"), ) logger.info("Done processing for location %s.", location) @@ -164,6 +210,11 @@ def main(args: argparse.Namespace): parser.add_argument( "--data_dir", type=Path, default=MapillaryDataModule.default_cfg["data_dir"] ) + parser.add_argument( + "--osm_source", + default=OSMDataSource.PRECOMPUTED.name, + choices=[e.name for e in OSMDataSource], + ) parser.add_argument("--overwrite", action="store_true") parser.add_argument("--force_auto_val_bbox", action="store_true") parser.add_argument("dotlist", nargs="*") diff --git a/maploc/osm/prepare.py b/maploc/osm/prepare.py index 5f488f4..c6a3c08 100644 --- a/maploc/osm/prepare.py +++ b/maploc/osm/prepare.py @@ -31,8 +31,9 @@ def download_and_prepare_osm( **kwargs, ) -> TileManager: if source == OSMDataSource.PRECOMPUTED: - logger.info("Downloading pre-computed map tiles.") - download_file(DATA_URL + f"/tiles/{tiles_name}.pkl", tiles_path) + if not tiles_path.exists(): + logger.info("Downloading pre-computed map tiles.") + download_file(DATA_URL + f"/tiles/{tiles_name}.pkl", tiles_path) tile_manager = TileManager.load(tiles_path) assert tile_manager.ppm == kwargs["ppm"] assert tile_manager.bbox.contains(bbox)