diff --git a/README.md b/README.md
index a4fb7c6..b01b40f 100644
--- a/README.md
+++ b/README.md
@@ -69,23 +69,74 @@ Try our minimal demo - take a picture with your phone in any city and find its e
OrienterNet positions any image within a large area - try it with your own images!
-## Evaluation
+## Mapillary Geo-Localization (MGL) dataset
+
+To train and evaluate OrienterNet, we introduce a large crowd-sourced dataset of images captured across multiple cities through the [Mapillary platform](https://www.mapillary.com/app/). To obtain the dataset:
-#### Mapillary Geo-Localization dataset
+1. Create a developper account at [mapillary.com](https://www.mapillary.com/dashboard/developers) and obtain a free access token.
+2. Run the following script to download the data from Mapillary and prepare it:
+
+```bash
+python -m maploc.data.mapillary.prepare --token $YOUR_ACCESS_TOKEN
+```
+
+By default the data is written to the directory `./datasets/MGL/` and requires about 80 GB of free disk space.
+
+#### Using different OpenStreetMap data
[Click to expand]
-To obtain the dataset:
+Multiple sources of OpenStreetMap (OSM) data can be selected for the dataset scripts `maploc.data.[mapillary,kitti].prepare` using the `--osm_source` option, which can take the following values:
+- `PRECOMPUTED` (default): download pre-computed raster tiles that are hosted [here](https://cvg-data.inf.ethz.ch/OrienterNet_CVPR2023/tiles/).
+- `CACHED`: compute the raster tiles from raw OSM data downloaded from [Geofabrik](https://download.geofabrik.de/) in November 2021 and hosted [here](https://cvg-data.inf.ethz.ch/OrienterNet_CVPR2023/osm/). This is useful if you wish to use different OSM classes but want to compare the results to the pre-computed tiles.
+- `LATEST`: fetch the latest OSM data from [Geofabrik](https://download.geofabrik.de/). This requires that the [Osmium tool](https://osmcode.org/osmium-tool/) is available in your system, which can be downloaded via `apt-get install osmium-tool` on Ubuntu and `brew install osmium-tool` on macOS.
-1. Create a developper account at [mapillary.com](https://www.mapillary.com/dashboard/developers) and obtain a free access token.
-2. Run the following script to download the data from Mapillary and prepare it:
+
+
+#### Extending the dataset
+
+
+[Click to expand]
+By default, the dataset script fetches data that was queried early 2022 from 13 locations. The dataset can be extended by including additional cities or querying images recently uploaded to Mapillary. To proceed, follow these steps:
+1. For each new location, add an entry to `maploc.data.mapillary.config.location_to_params` following the format:
+```python
+ "location_name": {
+ "bbox": BoundaryBox((lat_min, long_min), (lat_max, long_max)),
+ "filters": {"is_pano": True},
+ # or other filters like creator_username, model, etc.
+ # all described at https://www.mapillary.com/developer/api-documentation#image
+ }
+```
+The bounding box can easily be selected using [this tool](https://boundingbox.klokantech.com/). We recommend searching for cities with a high density of 360 panoramic images on the [Mapillary platform](https://www.mapillary.com/app/).
+
+2. Query the corresponding images and split them into training and valiation subsets with:
```bash
-python -m maploc.data.mapillary.prepare --token $YOUR_ACCESS_TOKEN
+python -m maploc.data.mapillary.split --token $YOUR_ACCESS_TOKEN --output_filename splits_MGL_v2_{scene}.json --data_dir datasets/MGL_v2
```
+Note that, for the 13 default locations, running this script will produce results slightly different from the default split file `splits_MGL_13loc.json` since new images have been uploaded since 2022 and some others have been taken down.
+
+3. Fetch and prepare the resulting data:
+```bash
+python -m maploc.data.mapillary.prepare --token $YOUR_ACCESS_TOKEN --split_filename splits_MGL_v2_{scene}.json --data_dir datasets/MGL_v2
+```
+4. To train or evaluate with this new version of the dataset, add the following CLI flags:
+```bash
+python -m maploc.[train,evaluation...] [...] data.data_dir=datasets/MGL_v2 data.split=splits_MGL_v2_{scene}.json
+```
+
+
-By default the data is written to the directory `./datasets/MGL/`. Then run the evaluation with the pre-trained model:
+
+## Evaluation
+
+#### MGL dataset
+
+
+[Click to expand]
+
+Download the dataset [as described previously](#mapillary-geo-localization-mgl-dataset) and run the evaluation with the pre-trained model:
```bash
python -m maploc.evaluation.mapillary --experiment OrienterNet_MGL model.num_rotations=256
@@ -218,17 +269,6 @@ We provide several visualization notebooks:
- [Visualize predictions on the KITTI dataset](./notebooks/visualize_predictions_kitti.ipynb)
- [Visualize sequential predictions](./notebooks/visualize_predictions_sequences.ipynb)
-## OpenStreetMap data
-
-
-[Click to expand]
-
-To make sure that the results are consistent over time, we used OSM data downloaded from [Geofabrik](https://download.geofabrik.de/) in November 2021. By default, the dataset scripts `maploc.data.[mapillary,kitti].prepare` download pre-generated raster tiles. If you wish to use different OSM classes, you can pass `--generate_tiles`, which will download and use our prepared raw `.osm` XML files.
-
-You may alternatively download more recent files from [Geofabrik](https://download.geofabrik.de/). Download either compressed XML files as `.osm.bz2` or binary files `.osm.pbf`, which need to be converted to XML files `.osm`, for example using Osmium: ` osmium cat xx.osm.pbf -o xx.osm`.
-
-
-
## License
The MGL dataset is made available under the [CC-BY-SA](https://creativecommons.org/licenses/by-sa/4.0/) license following the data available on the Mapillary platform. The model implementation and the pre-trained weights follow a [CC-BY-NC](https://creativecommons.org/licenses/by-nc/2.0/) license. [OpenStreetMap data](https://www.openstreetmap.org/copyright) is licensed under the [Open Data Commons Open Database License](https://opendatacommons.org/licenses/odbl/).
diff --git a/maploc/data/kitti/prepare.py b/maploc/data/kitti/prepare.py
index 8adef2b..278a761 100644
--- a/maploc/data/kitti/prepare.py
+++ b/maploc/data/kitti/prepare.py
@@ -9,10 +9,10 @@
from tqdm.auto import tqdm
from ... import logger
-from ...osm.tiling import TileManager
+from ...osm.prepare import OSMDataSource, download_and_prepare_osm
from ...osm.viz import GeoPlotter
from ...utils.geo import BoundaryBox, Projection
-from ...utils.io import DATA_URL, download_file
+from ...utils.io import download_file
from .dataset import KittiDataModule
from .utils import parse_gps_file
@@ -20,11 +20,10 @@
def prepare_osm(
- data_dir,
- osm_path,
- output_path,
- tile_margin=512,
+ data_dir: Path,
+ osm_source: OSMDataSource,
ppm=2,
+ tile_margin=512,
):
all_latlon = []
for gps_path in data_dir.glob("2011_*/*/oxts/data/*.txt"):
@@ -34,21 +33,26 @@ def prepare_osm(
all_latlon = np.stack(all_latlon)
projection = Projection.from_points(all_latlon)
all_xy = projection.project(all_latlon)
- bbox_map = BoundaryBox(all_xy.min(0), all_xy.max(0)) + tile_margin
+ bbox_tiling = BoundaryBox(all_xy.min(0), all_xy.max(0)) + tile_margin
+
+ tiles_path = data_dir / KittiDataModule.default_cfg["tiles_filename"]
+ osm_path = data_dir / "karlsruhe.osm"
+ tile_manager = download_and_prepare_osm(
+ osm_source,
+ "kitti",
+ tiles_path,
+ bbox_tiling,
+ projection,
+ osm_path,
+ ppm=ppm,
+ )
plotter = GeoPlotter()
plotter.points(all_latlon, "red", name="GPS")
- plotter.bbox(projection.unproject(bbox_map), "blue", "tiling bounding box")
- plotter.fig.write_html(data_dir / "split_kitti.html")
-
- tile_manager = TileManager.from_bbox(
- projection,
- bbox_map,
- ppm,
- path=osm_path,
+ plotter.bbox(
+ projection.unproject(tile_manager.bbox), "black", "tiling bounding box"
)
- tile_manager.save(output_path)
- return tile_manager
+ plotter.fig.write_html(data_dir / "viz_kitti.html")
def download(data_dir: Path):
@@ -99,24 +103,13 @@ def download(data_dir: Path):
"--data_dir", type=Path, default=Path(KittiDataModule.default_cfg["data_dir"])
)
parser.add_argument("--pixel_per_meter", type=int, default=2)
- parser.add_argument("--generate_tiles", action="store_true")
+ parser.add_argument(
+ "--osm_source",
+ default=OSMDataSource.PRECOMPUTED.name,
+ choices=[e.name for e in OSMDataSource],
+ )
args = parser.parse_args()
args.data_dir.mkdir(exist_ok=True, parents=True)
download(args.data_dir)
-
- tiles_path = args.data_dir / KittiDataModule.default_cfg["tiles_filename"]
- if args.generate_tiles:
- logger.info("Generating the map tiles.")
- osm_filename = "karlsruhe.osm"
- osm_path = args.data_dir / osm_filename
- if not osm_path.exists():
- logger.info("Downloading OSM raw data.")
- download_file(DATA_URL + f"/osm/{osm_filename}", osm_path)
- if not osm_path.exists():
- raise FileNotFoundError(f"No OSM data file at {osm_path}.")
- prepare_osm(args.data_dir, osm_path, tiles_path, ppm=args.pixel_per_meter)
- (args.data_dir / ".downloaded").touch()
- else:
- logger.info("Downloading pre-generated map tiles.")
- download_file(DATA_URL + "/tiles/kitti.pkl", tiles_path)
+ prepare_osm(args.data_dir, OSMDataSource[args.osm_source], ppm=args.pixel_per_meter)
diff --git a/maploc/data/mapillary/config.py b/maploc/data/mapillary/config.py
new file mode 100644
index 0000000..1c84f5a
--- /dev/null
+++ b/maploc/data/mapillary/config.py
@@ -0,0 +1,154 @@
+from omegaconf import OmegaConf
+
+from ...utils.geo import BoundaryBox
+
+location_to_params = {
+ "sanfrancisco_soma": {
+ "bbox": BoundaryBox((37.770364, -122.410307), (37.795545, -122.388772)),
+ "bbox_val": BoundaryBox(
+ (37.788123419925945, -122.40053535863909),
+ (37.78897443253716, -122.3994618718349),
+ ),
+ "filters": {"model": "GoPro Max"},
+ "osm_file": "sanfrancisco.osm",
+ },
+ "sanfrancisco_hayes": {
+ "bbox": BoundaryBox((37.768634, -122.438415), (37.783894, -122.410605)),
+ "bbox_val": BoundaryBox(
+ (37.77682908567614, -122.42439593370665),
+ (37.7776996640339, -122.42329849537967),
+ ),
+ "filters": {"model": "GoPro Max"},
+ "osm_file": "sanfrancisco.osm",
+ },
+ "montrouge": {
+ "bbox": BoundaryBox((48.80874, 2.298958), (48.825276, 2.332989)),
+ "bbox_val": BoundaryBox(
+ (48.81554465300679, 2.315590378986898),
+ (48.816228935240346, 2.3166087395920103),
+ ),
+ "filters": {"model": "LG-R105"},
+ "osm_file": "paris.osm",
+ },
+ "amsterdam": {
+ "bbox": BoundaryBox((52.340679, 4.845284), (52.386299, 4.926147)),
+ "bbox_val": BoundaryBox(
+ (52.358275965541495, 4.876867175817335),
+ (52.35920971624303, 4.878370977965195),
+ ),
+ "filters": {"model": "GoPro Max"},
+ "osm_file": "amsterdam.osm",
+ },
+ "lemans": {
+ "bbox": BoundaryBox((47.995125, 0.185752), (48.014209, 0.224088)),
+ "bbox_val": BoundaryBox(
+ (48.00468200256593, 0.20130905922712253),
+ (48.00555356009431, 0.20251886369476968),
+ ),
+ "filters": {"creator_username": "sogefi"},
+ "osm_file": "lemans.osm",
+ },
+ "berlin": {
+ "bbox": BoundaryBox((52.459656, 13.416271), (52.499195, 13.469829)),
+ "bbox_val": BoundaryBox(
+ (52.47478263625299, 13.436060761632277),
+ (52.47610554128314, 13.438407628895831),
+ ),
+ "filters": {"is_pano": True},
+ "osm_file": "berlin.osm",
+ },
+ "nantes": {
+ "bbox": BoundaryBox((47.198289, -1.585839), (47.236161, -1.51318)),
+ "bbox_val": BoundaryBox(
+ (47.212224982547106, -1.555772859366718),
+ (47.213374064189956, -1.554270622470525),
+ ),
+ "filters": {"is_pano": True},
+ "osm_file": "nantes.osm",
+ },
+ "toulouse": {
+ "bbox": BoundaryBox((43.591434, 1.429457), (43.61343, 1.456653)),
+ "bbox_val": BoundaryBox(
+ (43.60314813839066, 1.4431497839062253),
+ (43.604433961018984, 1.4448508228862122),
+ ),
+ "filters": {"is_pano": True},
+ "osm_file": "toulouse.osm",
+ },
+ "vilnius": {
+ "bbox": BoundaryBox((54.672956, 25.258633), (54.696755, 25.296094)),
+ "bbox_val": BoundaryBox(
+ (54.68292611300143, 25.276979025529165),
+ (54.68349008447563, 25.27798847871685),
+ ),
+ "filters": {"is_pano": True},
+ "osm_file": "vilnius.osm",
+ },
+ "helsinki": {
+ "bbox": BoundaryBox(
+ (60.1449128318, 24.8975480117), (60.1770977471, 24.9816543235)
+ ),
+ "bbox_val": BoundaryBox(
+ (60.163825618884275, 24.930182541064955),
+ (60.16518598734065, 24.93274647451007),
+ ),
+ "filters": {"is_pano": True},
+ "osm_file": "helsinki.osm",
+ },
+ "milan": {
+ "bbox": BoundaryBox(
+ (45.4810977947, 9.1732723899), (45.5284238563, 9.2255987917)
+ ),
+ "bbox_val": BoundaryBox(
+ (45.502686834500466, 9.189078329923374),
+ (45.50329294217317, 9.189881944589828),
+ ),
+ "filters": {"is_pano": True},
+ "osm_file": "milan.osm",
+ },
+ "avignon": {
+ "bbox": BoundaryBox(
+ (43.9416178156, 4.7887045302), (43.9584848909, 4.8227015622)
+ ),
+ "bbox_val": BoundaryBox(
+ (43.94768786305171, 4.809099008430249),
+ (43.94827840894793, 4.809954737764413),
+ ),
+ "filters": {"is_pano": True},
+ "osm_file": "avignon.osm",
+ },
+ "paris": {
+ "bbox": BoundaryBox((48.833827, 2.306823), (48.889335, 2.39067)),
+ "bbox_val": BoundaryBox(
+ (48.85558288211851, 2.3427920762801526),
+ (48.85703370256603, 2.3449544861818654),
+ ),
+ "filters": {"is_pano": True},
+ "osm_file": "paris.osm",
+ },
+ # Add any new region/city here:
+ # "location_name": {
+ # "bbox": BoundaryBox((lat_min, long_min), (lat_max, long_max)),
+ # "filters": {"is_pano": True},
+ # # or other filters like creator_username, model, etc.
+ # # all described at https://www.mapillary.com/developer/api-documentation#image
+ # }
+ # Other fields (bbox_val, osm_file) will be deduced automatically.
+}
+
+default_cfg = OmegaConf.create(
+ {
+ "downsampling_resolution_meters": 3,
+ "target_num_val_images": 100,
+ "val_train_margin_meters": 25,
+ "max_num_train_images": 50_000,
+ "max_image_size": 512,
+ "do_legacy_pano_offset": True,
+ "min_dist_between_keyframes": 4,
+ "tiling": {
+ "tile_size": 128,
+ "margin": 128,
+ "ppm": 2,
+ },
+ }
+)
diff --git a/maploc/data/mapillary/dataset.py b/maploc/data/mapillary/dataset.py
index 8202026..dfdce42 100644
--- a/maploc/data/mapillary/dataset.py
+++ b/maploc/data/mapillary/dataset.py
@@ -253,18 +253,28 @@ def parse_splits(self, split_arg, names):
"val": [n for n in names if n[0] in scenes_val],
}
elif isinstance(split_arg, str):
- with (self.root / split_arg).open("r") as fp:
- splits = json.load(fp)
+ if (self.root / split_arg).exists():
+ # Common split file.
+ with (self.root / split_arg).open("r") as fp:
+ splits = json.load(fp)
+ else:
+ # Per-scene split file.
+ splits = defaultdict(dict)
+ for scene in self.cfg.scenes:
+ with (self.root / split_arg.format(scene=scene)).open("r") as fp:
+ scene_splits = json.load(fp)
+ for split_name in scene_splits:
+ splits[split_name][scene] = scene_splits[split_name]
splits = {
- k: {loc: set(ids) for loc, ids in split.items()}
- for k, split in splits.items()
+ split_name: {scene: set(ids) for scene, ids in split.items()}
+ for split_name, split in splits.items()
}
self.splits = {}
- for k, split in splits.items():
- self.splits[k] = [
- n
- for n in names
- if n[0] in split and int(n[-1].rsplit("_", 1)[0]) in split[n[0]]
+ for split_name, split in splits.items():
+ self.splits[split_name] = [
+ (scene, *arg, name)
+ for scene, *arg, name in names
+ if scene in split and int(name.rsplit("_", 1)[0]) in split[scene]
]
else:
raise ValueError(split_arg)
diff --git a/maploc/data/mapillary/download.py b/maploc/data/mapillary/download.py
index 816aa9f..410bfc4 100644
--- a/maploc/data/mapillary/download.py
+++ b/maploc/data/mapillary/download.py
@@ -2,7 +2,9 @@
import asyncio
import json
+from functools import partial
from pathlib import Path
+from typing import List, Optional, Sequence, Tuple
import httpx
import numpy as np
@@ -12,7 +14,7 @@
from opensfm.pymap import Shot
from ... import logger
-from ...utils.geo import Projection
+from ...utils.geo import BoundaryBox, Projection
semaphore = asyncio.Semaphore(100) # number of parallel threads.
image_filename = "{image_id}.jpg"
@@ -55,11 +57,13 @@ class MapillaryDownloader:
"sequence",
"sfm_cluster",
)
- image_info_url = (
- "https://graph.mapillary.com/{image_id}?access_token={token}&fields={fields}"
+ base_url = "https://graph.mapillary.com"
+ image_info_url = "{base}/{image_id}?access_token={token}&fields={fields}"
+ image_list_url = (
+ "{base}/images?access_token={token}&bbox={bbox}&limit={limit}&fields=id"
)
- seq_info_url = "https://graph.mapillary.com/image_ids?access_token={token}&sequence_id={seq_id}" # noqa E501
max_requests_per_minute = 50_000
+ max_num_results = 5_000 # maximum allowed by mapillary.com
def __init__(self, token: str):
self.token = token
@@ -70,27 +74,59 @@ def __init__(self, token: str):
@retry(times=5, exceptions=(httpx.RemoteProtocolError, httpx.ReadError))
async def call_api(self, url: str):
- async with self.limiter:
- r = await self.client.get(url)
+ async with semaphore:
+ async with self.limiter:
+ r = await self.client.get(url)
if not r.is_success:
- logger.error("Error in API call: %s", r.text)
+ logger.error("Error in API call: %s, retrying...", r.text)
+ raise httpx.ReadError(r.text)
return r
- async def get_image_info(self, image_id: int):
+ async def get_image_list(
+ self, bbox: BoundaryBox, is_pano: Optional[bool] = None, **filters
+ ):
+ # API bbox format: left, bottom, right, top (or minLon, minLat, maxLon, maxLat)
+ bbox = ",".join(map(str, (*bbox.min_[::-1], *bbox.max_[::-1])))
+ url = self.image_list_url.format(
+ base=self.base_url, token=self.token, bbox=bbox, limit=self.max_num_results
+ )
+ if is_pano is not None:
+ url += "&is_pano=" + ("true" if is_pano else "false")
+ for name, val in filters.items():
+ url += f"&{name}={val}"
+ r = await self.call_api(url)
+ if r.is_success:
+ info = json.loads(r.text)
+ image_ids = [int(d["id"]) for d in info["data"]]
+ return image_ids
+ # return json.loads(r.text)
+
+ async def get_image_info(
+ self, image_id: int, fields: Optional[Sequence[str]] = None
+ ):
url = self.image_info_url.format(
+ base=self.base_url,
image_id=image_id,
token=self.token,
- fields=",".join(self.image_fields),
+ fields=",".join(fields or self.image_fields),
)
r = await self.call_api(url)
if r.is_success:
return json.loads(r.text)
- async def get_sequence_info(self, seq_id: str):
- url = self.seq_info_url.format(seq_id=seq_id, token=self.token)
- r = await self.call_api(url)
- if r.is_success:
- return json.loads(r.text)
+ async def get_image_info_cached(
+ self, image_id: int, dir_: Optional[Path] = None, **kwargs
+ ):
+ if dir_ is None:
+ return await self.get_image_info(image_id, **kwargs)
+ path = dir_ / info_filename.format(image_id=image_id)
+ if path.exists():
+ info = json.loads(path.read_text())
+ else:
+ info = await self.get_image_info(image_id, **kwargs)
+ if info is not None:
+ path.write_text(json.dumps(info))
+ return info
async def download_image_pixels(self, url: str, path: Path):
r = await self.call_api(url)
@@ -99,15 +135,6 @@ async def download_image_pixels(self, url: str, path: Path):
fid.write(r.content)
return r.is_success
- async def get_image_info_cached(self, image_id: int, path: Path):
- if path.exists():
- info = json.loads(path.read_text())
- else:
- info = await self.get_image_info(image_id)
- if info is not None:
- path.write_text(json.dumps(info))
- return info
-
async def download_image_pixels_cached(self, url: str, path: Path):
if path.exists():
return True
@@ -115,64 +142,80 @@ async def download_image_pixels_cached(self, url: str, path: Path):
return await self.download_image_pixels(url, path)
-async def fetch_images_in_sequence(i, downloader):
- async with semaphore:
- info = await downloader.get_sequence_info(i)
- if info is None:
- image_ids = None
- else:
- image_ids = [int(d["id"]) for d in info["data"]]
- return i, image_ids
+async def _return_with_arg(item, fn):
+ ret = await fn(item)
+ return item, ret
-async def fetch_images_in_sequences(sequence_ids, downloader):
- seq_to_images_ids = {}
- tasks = [fetch_images_in_sequence(i, downloader) for i in sequence_ids]
+async def fetch_many(items, fn):
+ results = []
+ tasks = [_return_with_arg(item, fn) for item in items]
for task in tqdm.asyncio.tqdm.as_completed(tasks):
- i, image_ids = await task
- if image_ids is not None:
- seq_to_images_ids[i] = image_ids
- return seq_to_images_ids
+ results.append(await task)
+ return results
-async def fetch_image_info(i, downloader, dir_):
- async with semaphore:
- path = dir_ / info_filename.format(image_id=i)
- info = await downloader.get_image_info_cached(i, path)
- return i, info
-
-
-async def fetch_image_infos(image_ids, downloader, dir_):
- infos = {}
+async def fetch_image_infos(image_ids, downloader, **kwargs):
+ infos = await fetch_many(
+ image_ids, partial(downloader.get_image_info_cached, **kwargs)
+ )
+ infos = dict(infos)
num_fail = 0
- tasks = [fetch_image_info(i, downloader, dir_) for i in image_ids]
- for task in tqdm.asyncio.tqdm.as_completed(tasks):
- i, info = await task
- if info is None:
+ for i in image_ids:
+ if infos[i] is None:
+ del infos[i]
num_fail += 1
- else:
- infos[i] = info
return infos, num_fail
-async def fetch_image_pixels(i, url, downloader, dir_, overwrite=False):
- async with semaphore:
+async def fetch_images_pixels(image_urls, downloader, dir_, overwrite=False):
+ tasks = []
+ for i, url in image_urls:
path = dir_ / image_filename.format(image_id=i)
if overwrite:
path.unlink(missing_ok=True)
- success = await downloader.download_image_pixels_cached(url, path)
- return i, success
-
-
-async def fetch_images_pixels(image_urls, downloader, dir_):
+ tasks.append(downloader.download_image_pixels_cached(url, path))
num_fail = 0
- tasks = [fetch_image_pixels(*id_url, downloader, dir_) for id_url in image_urls]
for task in tqdm.asyncio.tqdm.as_completed(tasks):
- i, success = await task
+ success = await task
num_fail += not success
return num_fail
+def split_bbox(bbox: BoundaryBox) -> tuple[BoundaryBox]:
+ midpoint = bbox.center
+ return (
+ BoundaryBox(bbox.min_, midpoint),
+ BoundaryBox((bbox.min_[0], midpoint[1]), (midpoint[0], bbox.max_[1])),
+ BoundaryBox((midpoint[0], bbox.min_[1]), (bbox.max_[0], midpoint[1])),
+ BoundaryBox(midpoint, bbox.max_),
+ )
+
+
+async def fetch_image_list(
+ query_bbox: BoundaryBox,
+ downloader: MapillaryDownloader,
+ **filters,
+) -> Tuple[List[int], List[BoundaryBox]]:
+ """Because of the limit in number of returned results, we recursively split
+ the query area until each query is below this limit.
+ """
+ pool = [query_bbox]
+ finished = []
+ all_ids = []
+ while len(pool):
+ rets = await fetch_many(pool, partial(downloader.get_image_list, **filters))
+ pool = []
+ for bbox, ids in rets:
+ assert ids is not None
+ if len(ids) == downloader.max_num_results:
+ pool.extend(split_bbox(bbox))
+ else:
+ finished.append(bbox)
+ all_ids.extend(ids)
+ return all_ids, finished
+
+
def opensfm_camera_from_info(info: dict) -> Camera:
cam_type = info["camera_type"]
if cam_type == "perspective":
diff --git a/maploc/data/mapillary/prepare.py b/maploc/data/mapillary/prepare.py
index af8d96e..c9b2891 100644
--- a/maploc/data/mapillary/prepare.py
+++ b/maploc/data/mapillary/prepare.py
@@ -6,7 +6,7 @@
import shutil
from collections import defaultdict
from pathlib import Path
-from typing import List
+from typing import Dict, List, Optional, Sequence
import cv2
import numpy as np
@@ -21,11 +21,12 @@
from tqdm.contrib.concurrent import thread_map
from ... import logger
-from ...osm.tiling import TileManager
+from ...osm.prepare import OSMDataSource, download_and_prepare_osm
from ...osm.viz import GeoPlotter
from ...utils.geo import BoundaryBox, Projection
-from ...utils.io import DATA_URL, download_file, write_json
+from ...utils.io import write_json
from ..utils import decompose_rotmat
+from .config import default_cfg, location_to_params
from .dataset import MapillaryDataModule
from .download import (
MapillaryDownloader,
@@ -43,110 +44,11 @@
undistort_shot,
)
-location_to_params = {
- "sanfrancisco_soma": {
- "bbox": BoundaryBox(
- [-122.410307, 37.770364][::-1], [-122.388772, 37.795545][::-1]
- ),
- "camera_models": ["GoPro Max"],
- "osm_file": "sanfrancisco.osm",
- },
- "sanfrancisco_hayes": {
- "bbox": BoundaryBox(
- [-122.438415, 37.768634][::-1], [-122.410605, 37.783894][::-1]
- ),
- "camera_models": ["GoPro Max"],
- "osm_file": "sanfrancisco.osm",
- },
- "amsterdam": {
- "bbox": BoundaryBox([4.845284, 52.340679][::-1], [4.926147, 52.386299][::-1]),
- "camera_models": ["GoPro Max"],
- "osm_file": "amsterdam.osm",
- },
- "lemans": {
- "bbox": BoundaryBox([0.185752, 47.995125][::-1], [0.224088, 48.014209][::-1]),
- "owners": ["xXOocM1jUB4jaaeukKkmgw"], # sogefi
- "osm_file": "lemans.osm",
- },
- "berlin": {
- "bbox": BoundaryBox([13.416271, 52.459656][::-1], [13.469829, 52.499195][::-1]),
- "owners": ["LT3ajUxH6qsosamrOHIrFw"], # supaplex030
- "osm_file": "berlin.osm",
- },
- "montrouge": {
- "bbox": BoundaryBox([2.298958, 48.80874][::-1], [2.332989, 48.825276][::-1]),
- "owners": [
- "XtzGKZX2_VIJRoiJ8IWRNQ",
- "C4ENdWpJdFNf8CvnQd7NrQ",
- "e_ZBE6mFd7CYNjRSpLl-Lg",
- ], # overflorian, phyks, francois2
- "camera_models": ["LG-R105"],
- "osm_file": "paris.osm",
- },
- "nantes": {
- "bbox": BoundaryBox([-1.585839, 47.198289][::-1], [-1.51318, 47.236161][::-1]),
- "owners": [
- "jGdq3CL-9N-Esvj3mtCWew",
- "s-j5BH9JRIzsgORgaJF3aA",
- ], # c_mobilite, cartocite
- "osm_file": "nantes.osm",
- },
- "toulouse": {
- "bbox": BoundaryBox([1.429457, 43.591434][::-1], [1.456653, 43.61343][::-1]),
- "owners": ["MNkhq6MCoPsdQNGTMh3qsQ"], # tyndare
- "osm_file": "toulouse.osm",
- },
- "vilnius": {
- "bbox": BoundaryBox([25.258633, 54.672956][::-1], [25.296094, 54.696755][::-1]),
- "owners": ["bClduFF6Gq16cfwCdhWivw", "u5ukBseATUS8jUbtE43fcO"], # kedas, vms
- "osm_file": "vilnius.osm",
- },
- "helsinki": {
- "bbox": BoundaryBox(
- [24.8975480117, 60.1449128318][::-1], [24.9816543235, 60.1770977471][::-1]
- ),
- "camera_types": ["spherical", "equirectangular"],
- "osm_file": "helsinki.osm",
- },
- "milan": {
- "bbox": BoundaryBox(
- [9.1732723899, 45.4810977947][::-1],
- [9.2255987917, 45.5284238563][::-1],
- ),
- "camera_types": ["spherical", "equirectangular"],
- "osm_file": "milan.osm",
- },
- "avignon": {
- "bbox": BoundaryBox(
- [4.7887045302, 43.9416178156][::-1], [4.8227015622, 43.9584848909][::-1]
- ),
- "camera_types": ["spherical", "equirectangular"],
- "osm_file": "avignon.osm",
- },
- "paris": {
- "bbox": BoundaryBox([2.306823, 48.833827][::-1], [2.39067, 48.889335][::-1]),
- "camera_types": ["spherical", "equirectangular"],
- "osm_file": "paris.osm",
- },
-}
-
-
-default_cfg = OmegaConf.create(
- {
- "max_image_size": 512,
- "do_legacy_pano_offset": True,
- "min_dist_between_keyframes": 4,
- "tiling": {
- "tile_size": 128,
- "margin": 128,
- "ppm": 2,
- },
- }
-)
+DATA_FILENAME = "dump.json"
def get_pano_offset(image_info: dict, do_legacy: bool = False) -> float:
- if do_legacy:
+ if do_legacy and "sfm_cluster" in image_info:
seed = int(image_info["sfm_cluster"]["id"])
else:
seed = image_info["sequence"].__hash__()
@@ -205,7 +107,7 @@ def pack_shot_dict(shot: Shot, info: dict) -> dict:
capture_time=info["captured_at"],
gps_position=np.r_[latlong_gps, info["altitude"]],
compass_angle=info["compass_angle"],
- chunk_id=int(info["sfm_cluster"]["id"]),
+ chunk_id=int(info["sfm_cluster"]["id"]) if "sfm_cluster" in info else -1,
)
@@ -272,24 +174,18 @@ def process_sequence(
def process_location(
- location: str,
- data_dir: Path,
- split_path: Path,
+ output_dir: Path,
+ bbox: BoundaryBox,
+ splits: Dict[str, Sequence[int]],
token: str,
cfg: DictConfig,
- generate_tiles: bool = False,
):
- params = location_to_params[location]
- bbox = params["bbox"]
projection = Projection(*bbox.center)
+ image_ids = [i for split in splits.values() for i in split]
- splits = json.loads(split_path.read_text())
- image_ids = [i for split in splits.values() for i in split[location]]
-
- loc_dir = data_dir / location
- infos_dir = loc_dir / "image_infos"
- raw_image_dir = loc_dir / "images_raw"
- out_image_dir = loc_dir / "images"
+ infos_dir = output_dir / "image_infos"
+ raw_image_dir = output_dir / "images_raw"
+ out_image_dir = output_dir / "images"
for d in (infos_dir, raw_image_dir, out_image_dir):
d.mkdir(parents=True, exist_ok=True)
@@ -298,7 +194,7 @@ def process_location(
logger.info("Fetching metadata for all images.")
image_infos, num_fail = loop.run_until_complete(
- fetch_image_infos(image_ids, downloader, infos_dir)
+ fetch_image_infos(image_ids, downloader, dir_=infos_dir)
)
logger.info("%d failures (%.1f%%).", num_fail, 100 * num_fail / len(image_ids))
@@ -326,8 +222,27 @@ def process_location(
out_image_dir,
)
)
- write_json(loc_dir / "dump.json", dump)
+ write_json(output_dir / DATA_FILENAME, dump)
+
+ shutil.rmtree(raw_image_dir)
+
+
+def prepare_osm(
+ location: str,
+ output_dir: Path,
+ bbox: BoundaryBox,
+ cfg: DictConfig,
+ osm_dir: Path,
+ osm_source: OSMDataSource,
+ osm_filename: Optional[str] = None,
+):
+ tiles_path = output_dir / MapillaryDataModule.default_cfg["tiles_filename"]
+ if tiles_path.exists():
+ return
+ logger.info("Preparing OSM data.")
+ projection = Projection(*bbox.center)
+ dump = json.loads((output_dir / DATA_FILENAME).read_text())
# Get the view locations
view_ids = []
views_latlon = []
@@ -339,47 +254,72 @@ def process_location(
view_ids = np.array(view_ids)
views_xy = projection.project(views_latlon)
- tiles_path = loc_dir / MapillaryDataModule.default_cfg["tiles_filename"]
- if generate_tiles:
- logger.info("Creating the map tiles.")
- bbox_data = BoundaryBox(views_xy.min(0), views_xy.max(0))
- bbox_tiling = bbox_data + cfg.tiling.margin
- osm_dir = data_dir / "osm"
- osm_path = osm_dir / params["osm_file"]
- if not osm_path.exists():
- logger.info("Downloading OSM raw data.")
- download_file(DATA_URL + f"/osm/{params['osm_file']}", osm_path)
- if not osm_path.exists():
- raise FileNotFoundError(f"Cannot find OSM data file {osm_path}.")
- tile_manager = TileManager.from_bbox(
- projection,
- bbox_tiling,
- cfg.tiling.ppm,
- tile_size=cfg.tiling.tile_size,
- path=osm_path,
- )
- tile_manager.save(tiles_path)
- else:
- logger.info("Downloading pre-generated map tiles.")
- download_file(DATA_URL + f"/tiles/{location}.pkl", tiles_path)
- tile_manager = TileManager.load(tiles_path)
+ bbox_data = BoundaryBox(views_xy.min(0), views_xy.max(0))
+ bbox_tiling = bbox_data + cfg.tiling.margin
+ osm_path = osm_dir / (osm_filename or f"{location}.osm")
+ tile_manager = download_and_prepare_osm(
+ osm_source,
+ location,
+ tiles_path,
+ bbox_tiling,
+ projection,
+ osm_path,
+ ppm=cfg.tiling.ppm,
+ tile_size=cfg.tiling.tile_size,
+ )
# Visualize the data split
plotter = GeoPlotter()
- view_ids_val = set(splits["val"][location])
- is_val = np.array([int(i.rsplit("_", 1)[0]) in view_ids_val for i in view_ids])
- plotter.points(views_latlon[~is_val], "red", view_ids[~is_val], "train")
- plotter.points(views_latlon[is_val], "green", view_ids[is_val], "val")
+ plotter.points(views_latlon, "red", view_ids, "images")
plotter.bbox(bbox, "blue", "query bounding box")
plotter.bbox(
projection.unproject(tile_manager.bbox), "black", "tiling bounding box"
)
- geo_viz_path = loc_dir / f"split_{location}.html"
+ geo_viz_path = output_dir / f"viz_data_{location}.html"
plotter.fig.write_html(geo_viz_path)
- logger.info("Wrote split visualization to %s.", geo_viz_path)
+ logger.info("Wrote the visualization to %s.", geo_viz_path)
- shutil.rmtree(raw_image_dir)
- logger.info("Done processing for location %s.", location)
+
+def main(args: argparse.Namespace):
+ args.data_dir.mkdir(exist_ok=True, parents=True)
+
+ split_path = args.data_dir / args.split_filename
+ maybe_git_split = Path(__file__).parent / args.split_filename
+ if maybe_git_split.exists():
+ logger.info("Using official split file at %s.", maybe_git_split)
+ shutil.copy(maybe_git_split, args.data_dir)
+
+ cfg = OmegaConf.merge(default_cfg, OmegaConf.from_cli(args.dotlist))
+ for location in args.locations:
+ logger.info("Starting processing for location %s.", location)
+ if split_path.exists():
+ splits = json.loads(split_path.read_text())
+ splits = {split_name: val[location] for split_name, val in splits.items()}
+ else:
+ split_path_ = Path(str(split_path).format(scene=location))
+ if not split_path_.exists():
+ raise ValueError(f"Cannot find any split file at path {split_path}.")
+ logger.info("Using per-location split file at %s.", split_path_)
+ splits = json.loads(split_path_.read_text())
+
+ process_location(
+ args.data_dir / location,
+ location_to_params[location]["bbox"],
+ splits,
+ args.token,
+ cfg,
+ )
+
+ prepare_osm(
+ location,
+ args.data_dir / location,
+ location_to_params[location]["bbox"],
+ cfg,
+ args.data_dir / "osm",
+ OSMDataSource[args.osm_source],
+ location_to_params[location].get("osm_file"),
+ )
+ logger.info("Done processing for location %s.", location)
if __name__ == "__main__":
@@ -392,21 +332,10 @@ def process_location(
parser.add_argument(
"--data_dir", type=Path, default=MapillaryDataModule.default_cfg["data_dir"]
)
- parser.add_argument("--generate_tiles", action="store_true")
+ parser.add_argument(
+ "--osm_source",
+ default=OSMDataSource.PRECOMPUTED.name,
+ choices=[e.name for e in OSMDataSource],
+ )
parser.add_argument("dotlist", nargs="*")
- args = parser.parse_args()
-
- args.data_dir.mkdir(exist_ok=True, parents=True)
- shutil.copy(Path(__file__).parent / args.split_filename, args.data_dir)
- cfg_ = OmegaConf.merge(default_cfg, OmegaConf.from_cli(args.dotlist))
-
- for location in args.locations:
- logger.info("Starting processing for location %s.", location)
- process_location(
- location,
- args.data_dir,
- args.data_dir / args.split_filename,
- args.token,
- cfg_,
- args.generate_tiles,
- )
+ main(parser.parse_args())
diff --git a/maploc/data/mapillary/split.py b/maploc/data/mapillary/split.py
new file mode 100644
index 0000000..718a942
--- /dev/null
+++ b/maploc/data/mapillary/split.py
@@ -0,0 +1,221 @@
+import argparse
+import asyncio
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import numpy as np
+from omegaconf import DictConfig, OmegaConf
+
+from ... import logger
+from ...osm.parser import Groups
+from ...osm.prepare import OSMDataSource, download_and_prepare_osm
+from ...osm.viz import GeoPlotter
+from ...utils.geo import BoundaryBox, Projection
+from ...utils.io import write_json
+from .config import default_cfg, location_to_params
+from .dataset import MapillaryDataModule
+from .download import MapillaryDownloader, fetch_image_infos, fetch_image_list
+
+
+def grid_downsample(xy: np.ndarray, bbox: BoundaryBox, resolution: float) -> np.ndarray:
+ assert bbox.contains(xy).all()
+ extent = bbox.max_ - bbox.min_
+ size = np.ceil(extent / resolution).astype(int)
+ grid = np.full(size, -1)
+ idx = np.floor((xy - bbox.min_) / resolution).astype(int)
+ grid[tuple(idx.T)] = np.arange(len(xy))
+ indices = grid[grid >= 0]
+ return indices
+
+
+def find_validation_bbox(
+ xy: np.ndarray, target_num, size_upper_bound=500
+) -> BoundaryBox:
+ # Find the centroid of all points
+ center = np.median(xy, 0)
+ dist = np.linalg.norm(xy - center[None], axis=1)
+ center = xy[np.argmin(dist)]
+
+ bbox = BoundaryBox(center - size_upper_bound, center + size_upper_bound)
+ mask = bbox.contains(xy)
+ dist = np.abs(xy[mask] - center).max(-1)
+ dist.sort()
+ thresh = dist[target_num]
+ bbox_val = BoundaryBox(center - thresh, center + thresh)
+ return bbox_val
+
+
+def process_location(
+ split_path: Path,
+ output_dir: Path,
+ osm_dir: Path,
+ osm_source: OSMDataSource,
+ token: str,
+ cfg: DictConfig,
+ query_bbox: BoundaryBox,
+ filters: Dict[str, Any],
+ bbox_val: Optional[BoundaryBox] = None,
+ osm_filename: Optional[str] = None,
+):
+ split_path.parent.mkdir(parents=True, exist_ok=True)
+ output_dir.mkdir(parents=True, exist_ok=True)
+ projection = Projection(*query_bbox.center)
+ bbox_local = projection.project(query_bbox)
+
+ logger.info("Fetching OpenStreetMap data.")
+ tiles_path = output_dir / MapillaryDataModule.default_cfg["tiles_filename"]
+ bbox_tiling = query_bbox + cfg.tiling.margin
+ osm_path = osm_dir / (osm_filename or f"{output_dir.name}.osm")
+ tile_manager = download_and_prepare_osm(
+ osm_source,
+ output_dir.name,
+ tiles_path,
+ bbox_tiling,
+ projection,
+ osm_path,
+ ppm=cfg.tiling.ppm,
+ tile_size=cfg.tiling.tile_size,
+ )
+
+ downloader = MapillaryDownloader(token)
+ loop = asyncio.get_event_loop()
+
+ logger.info("Fetching the list of image with filter: %s", filters)
+ image_ids, bboxes = loop.run_until_complete(
+ fetch_image_list(query_bbox, downloader, **filters)
+ )
+ if not image_ids:
+ raise ValueError("Could not find any image!")
+ logger.info("Found %d images.", len(image_ids))
+
+ logger.info("Fetching the image coordinates.")
+ infos, num_fail = loop.run_until_complete(
+ fetch_image_infos(image_ids, downloader, fields=["computed_geometry"])
+ )
+ logger.info("%d failures (%.1f%%).", num_fail, 100 * num_fail / len(image_ids))
+
+ # discard images that don't have coordinates available
+ image_ids = np.array([i for i in infos if "computed_geometry" in infos[i]])
+ image_ids.sort()
+
+ # discard images outside of the query bbox
+ latlon = np.array(
+ [infos[i]["computed_geometry"]["coordinates"][::-1] for i in image_ids]
+ )
+ xy = projection.project(latlon)
+ valid = bbox_local.contains(xy)
+ image_ids = image_ids[valid]
+ latlon = latlon[valid]
+ xy = xy[valid]
+
+ # discard sequences that intersect with buildings
+ canvas = tile_manager.query(bbox_local)
+ building_mask = canvas.raster[0] == (Groups.areas.index("building") + 1)
+ is_building = building_mask[tuple(np.round(canvas.to_uv(xy)).astype(int).T[::-1])]
+ logger.info(
+ "%d images intersect with buildings (%.1f%%).",
+ is_building.sum(),
+ 100 * is_building.mean(),
+ )
+ # plotter = GeoPlotter()
+ # plotter.points(latlon[is_building], "red", name="building")
+ # plotter.points(latlon[~is_building], "green", name="no building")
+ # plotter.bbox(query_bbox, "blue", "query bounding box")
+ # plotter.fig.write_html(f"{split_path}_building_viz.html")
+ # TODO: filter sequences that intersect with buildings by at least 30%.
+ image_ids = image_ids[~is_building]
+ latlon = latlon[~is_building]
+ xy = xy[~is_building]
+
+ # downsample the images with a grid
+ indices = grid_downsample(xy, bbox_local, cfg.downsampling_resolution_meters)
+ image_ids = image_ids[indices]
+ latlon = latlon[indices]
+ xy = xy[indices]
+ logger.info("Filtered down to %d images.", len(image_ids))
+
+ if bbox_val is None:
+ bbox_val_local = find_validation_bbox(xy, cfg.target_num_val_images)
+ bbox_val = projection.unproject(bbox_val_local)
+ else:
+ bbox_val_local = projection.project(bbox_val)
+ logger.info("Using validation bounding box: %s.", bbox_val)
+ indices_val = np.where(bbox_val_local.contains(xy))[0]
+ bbox_not_train = bbox_val_local + cfg.val_train_margin_meters
+ indices_train = np.where(~bbox_not_train.contains(xy))[0]
+ if len(indices_train) > cfg.max_num_train_images:
+ indices_subsample = np.random.RandomState(0).choice(
+ len(indices_train), cfg.max_num_train_images
+ )
+ indices_train = indices_train[indices_subsample]
+ logger.info(
+ "Resulting split: %d val and %d train images.",
+ len(indices_val),
+ len(indices_train),
+ )
+
+ splits = {
+ "val": image_ids[indices_val].tolist(),
+ "train": image_ids[indices_train].tolist(),
+ }
+ write_json(split_path, splits)
+
+ # Visualize the data split
+ plotter = GeoPlotter()
+ plotter.points(latlon[indices_train], "red", image_ids[indices_train], "train")
+ plotter.points(latlon[indices_val], "green", image_ids[indices_val], "val")
+ plotter.bbox(query_bbox, "blue", "query bounding box")
+ plotter.bbox(bbox_val, "green", "validation bounding box")
+ plotter.bbox(projection.unproject(bbox_not_train), "black", "margin bounding box")
+ geo_viz_path = f"{split_path}_viz.html"
+ plotter.fig.write_html(geo_viz_path)
+ logger.info("Wrote split visualization to %s.", geo_viz_path)
+
+
+def main(args: argparse.Namespace):
+ args.data_dir.mkdir(exist_ok=True, parents=True)
+ cfg = OmegaConf.merge(default_cfg, OmegaConf.from_cli(args.dotlist))
+ for location in args.locations:
+ output_path = args.data_dir / args.output_filename.format(scene=location)
+ if output_path.exists() and not args.overwrite:
+ logger.info("Skipping processing for location %s.", location)
+ continue
+ logger.info("Starting processing for location %s.", location)
+ params = location_to_params[location]
+ process_location(
+ output_path,
+ args.data_dir / location,
+ args.data_dir / "osm",
+ OSMDataSource[args.osm_source],
+ args.token,
+ cfg,
+ params["bbox"],
+ {"start_captured_at": args.min_capture_date} | params["filters"],
+ bbox_val=None if args.force_auto_val_bbox else params.get("bbox_val"),
+ osm_filename=location_to_params[location].get("osm_file"),
+ )
+ logger.info("Done processing for location %s.", location)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--locations", type=str, nargs="+", default=list(location_to_params)
+ )
+ parser.add_argument("--token", type=str, required=True)
+ parser.add_argument("--min_capture_date", type=str, default="2015-01-01T00:00:00Z")
+ parser.add_argument(
+ "--output_filename", type=str, default="splits_MGL_v2_{scene}.json"
+ )
+ parser.add_argument(
+ "--data_dir", type=Path, default=MapillaryDataModule.default_cfg["data_dir"]
+ )
+ parser.add_argument(
+ "--osm_source",
+ default=OSMDataSource.PRECOMPUTED.name,
+ choices=[e.name for e in OSMDataSource],
+ )
+ parser.add_argument("--overwrite", action="store_true")
+ parser.add_argument("--force_auto_val_bbox", action="store_true")
+ parser.add_argument("dotlist", nargs="*")
+ main(parser.parse_args())
diff --git a/maploc/osm/download.py b/maploc/osm/download.py
index 2c1f7a2..47f053a 100644
--- a/maploc/osm/download.py
+++ b/maploc/osm/download.py
@@ -1,10 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
+import subprocess
from http.client import responses
from pathlib import Path
from typing import Any, Dict, Optional
+import shapely
import urllib3
from .. import logger
@@ -18,6 +20,7 @@ def get_osm(
cache_path: Optional[Path] = None,
overwrite: bool = False,
) -> Dict[str, Any]:
+ """Fetch OSM data using the OSM API. Suitable only for small areas."""
if not overwrite and cache_path is not None and cache_path.is_file():
return json.loads(cache_path.read_text())
@@ -33,3 +36,51 @@ def get_osm(
if cache_path is not None:
cache_path.write_bytes(result.data)
return result.json()
+
+
+def get_geofabrik_index() -> Dict[str, Any]:
+ """Fetch the index of all regions served by Geofabrik."""
+ result = urllib3.request(
+ "GET", "https://download.geofabrik.de/index-v1.json", timeout=10
+ )
+ if result.status != 200:
+ error = result.info()["error"]
+ raise ValueError(f"{result.status} {responses[result.status]}: {error}")
+ return json.loads(result.data)
+
+
+def get_geofabrik_url(bbox: BoundaryBox) -> str:
+ """Find the smallest Geofabrik region file that includes a given area."""
+ gf = get_geofabrik_index()
+ best_region = None
+ best_area = float("inf")
+ query_poly = shapely.box(*bbox.min_[::-1], *bbox.max_[::-1])
+ for i, region in enumerate(gf["features"]):
+ coords = region["geometry"]["coordinates"]
+ # fix the polygon format
+ coords = [c if len(c[0]) < 2 else (c[0], c[1:]) for c in coords]
+ poly = shapely.MultiPolygon([shapely.Polygon(*c) for c in coords])
+ if poly.contains(query_poly):
+ area = poly.area
+ if area < best_area:
+ best_area = area
+ best_region = region
+ return best_region["properties"]["urls"]["pbf"]
+
+
+def convert_osm_file(bbox: BoundaryBox, input_path: Path, output_path: Path):
+ """Convert and crop a binary .pbf file to an .osm XML file."""
+ bbox_str = ",".join(map(str, (*bbox.min_[::-1], *bbox.max_[::-1])))
+ cmd = [
+ "osmium",
+ "extract",
+ "-s",
+ "smart",
+ "-b",
+ bbox_str,
+ input_path.as_posix(),
+ "-o",
+ output_path.as_posix(),
+ "--overwrite",
+ ]
+ subprocess.run(cmd, check=True)
diff --git a/maploc/osm/prepare.py b/maploc/osm/prepare.py
new file mode 100644
index 0000000..c6a3c08
--- /dev/null
+++ b/maploc/osm/prepare.py
@@ -0,0 +1,65 @@
+import logging
+from enum import Enum, auto
+from pathlib import Path
+
+from ..utils.geo import BoundaryBox, Projection
+from ..utils.io import DATA_URL, download_file
+from .download import convert_osm_file, get_geofabrik_url
+from .tiling import TileManager
+
+logger = logging.getLogger(__name__)
+
+
+class OSMDataSource(Enum):
+ # Pre-computed map tiles.
+ PRECOMPUTED = auto()
+
+ # Re-compute the map tiles from cached OSM data.
+ CACHED = auto()
+
+ # Fetch the latest OSM data and re-compute the map tiles from them.
+ LATEST = auto()
+
+
+def download_and_prepare_osm(
+ source: OSMDataSource,
+ tiles_name: str,
+ tiles_path: Path,
+ bbox: BoundaryBox,
+ projection: Projection,
+ osm_path: Path,
+ **kwargs,
+) -> TileManager:
+ if source == OSMDataSource.PRECOMPUTED:
+ if not tiles_path.exists():
+ logger.info("Downloading pre-computed map tiles.")
+ download_file(DATA_URL + f"/tiles/{tiles_name}.pkl", tiles_path)
+ tile_manager = TileManager.load(tiles_path)
+ assert tile_manager.ppm == kwargs["ppm"]
+ assert tile_manager.bbox.contains(bbox)
+ else:
+ logger.info("Creating the map tiles.")
+ if source == OSMDataSource.CACHED:
+ if not osm_path.exists():
+ logger.info("Downloading cached OSM data.")
+ download_file(DATA_URL + f"/osm/{osm_path.name}", osm_path)
+ if not osm_path.exists():
+ raise FileNotFoundError(f"Cannot find OSM data file {osm_path}.")
+ elif source == OSMDataSource.LATEST:
+ logger.info("Downloading the latest OSM data.")
+ bbox_osm = projection.unproject(bbox + 2_000) # 2 km
+ url = get_geofabrik_url(bbox_osm)
+ tmp_path = osm_path.parent / Path(url).name
+ download_file(url, tmp_path)
+ convert_osm_file(bbox_osm, tmp_path, osm_path)
+ tmp_path.unlink()
+ else:
+ raise NotImplementedError("Unknown source: {osm_source}.")
+ tile_manager = TileManager.from_bbox(
+ projection,
+ bbox,
+ path=osm_path,
+ **kwargs,
+ )
+ tile_manager.save(tiles_path)
+ return tile_manager