Skip to content
36 changes: 18 additions & 18 deletions src/pypgstac/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,38 @@ classifiers = [
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
]
dependencies = [
"cachetools==5.3.*",
"fire==0.4.*",
"hydraters==0.1.*",
"orjson>=3.6.2",
"plpygis==0.2.*",
"pydantic>=1.7",
"python-dateutil==2.8.*",
"smart-open>=4.2",
"tenacity==8.1.*",
"version-parser>= 1.0.1",
"cachetools>=5.0.0",
"fire>=0.4.0",
"hydraters>=0.1.0",
"orjson>=3.6.0",
"plpygis>=0.2.0",
"pydantic-settings>=2.0.0",
"python-dateutil>=2.8.0",
"smart-open>=4.2.0",
"tenacity>=8.0.0",
"version-parser>=1.0.0",
]

[project.optional-dependencies]
test = [
"pytest",
"pytest-cov",
"pystac[validation]==1.*",
"pystac[validation]>=1.0.0",
"types-cachetools",
]
dev = [
"flake8==7.1.1",
"black>=24.10.0",
"mypy>=1.13.0",
"types-setuptools",
"ruff==0.8.2",
"black",
"mypy",
"ruff",
"pre-commit",
]
psycopg = [
"psycopg[binary]==3.1.*",
"psycopg-pool==3.1.*",
"psycopg[binary]>=3.1.9",
"psycopg-pool>=3.1",
]
migrations = [
"psycopg2-binary",
Expand Down
16 changes: 3 additions & 13 deletions src/pypgstac/src/pypgstac/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,13 @@
from psycopg import Connection, sql
from psycopg.types.json import set_json_dumps, set_json_loads
from psycopg_pool import ConnectionPool

try:
from pydantic.v1 import BaseSettings # type:ignore
except ImportError:
from pydantic import BaseSettings # type:ignore

from pydantic_settings import BaseSettings
from tenacity import retry, retry_if_exception_type, stop_after_attempt

logger = logging.getLogger(__name__)


def dumps(data: dict) -> str:
    """Encode ``data`` as JSON via orjson and return it as ``str``."""
    # orjson.dumps produces bytes; decode so callers receive text.
    encoded = orjson.dumps(data)
    return encoded.decode()


set_json_dumps(dumps)
set_json_dumps(orjson.dumps)
set_json_loads(orjson.loads)


Expand Down Expand Up @@ -304,4 +294,4 @@ def func(self, function_name: str, *args: Any) -> Generator:

def search(self, query: Union[dict, str, psycopg.types.json.Jsonb] = "{}") -> str:
    """Search PgSTAC.

    Calls the ``search`` database function through ``self.func`` with
    ``query``, takes the first column of the first row it yields, and
    returns that value serialized to a JSON string.
    """
    # Only the first row is consumed; its first column holds the result.
    first_row = next(self.func("search", query))
    # orjson.dumps returns bytes, so decode to honour the ``str`` return type.
    return orjson.dumps(first_row[0]).decode()
57 changes: 36 additions & 21 deletions src/pypgstac/src/pypgstac/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import (
Any,
BinaryIO,
Dict,
Generator,
Iterable,
Iterator,
Optional,
TextIO,
Tuple,
TypeVar,
Union,
)

Expand Down Expand Up @@ -55,7 +54,13 @@ class Partition:
requires_update: bool


def chunked_iterable(iterable: Iterable, size: Optional[int] = 10000) -> Iterable:
_T = TypeVar("_T")


def chunked_iterable(
iterable: Iterable[_T],
size: Optional[int] = 10000,
) -> Generator[Tuple[_T, ...], None, None]:
"""Chunk an iterable."""
it = iter(iterable)
while True:
Expand Down Expand Up @@ -84,19 +89,19 @@ class Methods(str, Enum):

@contextlib.contextmanager
def open_std(
filename: str,
filename: Optional[str],
mode: str = "r",
*args: Any,
**kwargs: Any,
) -> Generator[Any, None, None]:
"""Open files and i/o streams transparently."""
fh: Union[TextIO, BinaryIO]
if (
filename is None
or filename == "-"
or filename == "stdin"
or filename == "stdout"
):
if filename in {
None,
"-",
"stdin",
"stdout",
}:
stream = sys.stdin if "r" in mode else sys.stdout
fh = stream.buffer if "b" in mode else stream
close = False
Expand All @@ -114,13 +119,15 @@ def open_std(
pass


def read_json(file: Union[Path, str, Iterator[Any]] = "stdin") -> Iterable:
_ReadJsonFileType = Union[str, Iterable[Union[Dict, bytes, bytearray, memoryview, str]]]


def read_json(file: _ReadJsonFileType = "stdin") -> Generator[Any, None, None]:
"""Load data from an ndjson or json file."""
if file is None:
file = "stdin"
if isinstance(file, str):
open_file: Any = open_std(file, "r")
with open_file as f:
with open_std(file, "r") as f:
# Try reading line by line as ndjson
try:
for line in f:
Expand All @@ -146,6 +153,8 @@ def read_json(file: Union[Path, str, Iterator[Any]] = "stdin") -> Iterable:
yield line
else:
yield orjson.loads(line)
else:
raise TypeError(f"Unsupported read json from file of type {type(file)}")


class Loader:
Expand Down Expand Up @@ -197,7 +206,7 @@ def collection_json(self, collection_id: str) -> Tuple[Dict[str, Any], int, str]

def load_collections(
self,
file: Union[Path, str, Iterator[Any]] = "stdin",
file: _ReadJsonFileType = "stdin",
insert_mode: Optional[Methods] = Methods.insert,
) -> None:
"""Load a collections json or ndjson file."""
Expand Down Expand Up @@ -548,12 +557,14 @@ def _partition_update(self, item: Dict[str, Any]) -> str:

return partition_name

def read_dehydrated(self, file: Union[Path, str] = "stdin") -> Generator:
def read_dehydrated(
self,
file: str = "stdin",
) -> Generator[Dict[str, Any], None, None]:
if file is None:
file = "stdin"
if isinstance(file, str):
open_file: Any = open_std(file, "r")
with open_file as f:
with open_std(file, "r") as f:
# Note: if 'content' is changed to be anything
# but the last field, the logic below will break.
fields = [
Expand Down Expand Up @@ -581,19 +592,23 @@ def read_dehydrated(self, file: Union[Path, str] = "stdin") -> Generator:
item[field] = tab_split[i]
item["partition"] = self._partition_update(item)
yield item
else:
raise TypeError(
f"Unsupported read dehydrated from file of type {type(file)}",
)

def read_hydrated(
    self,
    file: _ReadJsonFileType = "stdin",
) -> Generator[Dict[str, Any], None, None]:
    """Yield formatted items read from a json or ndjson source.

    Each record produced by ``read_json(file)`` is passed through
    ``self.format_item``; the resulting dict then has its ``"partition"``
    key set to the value returned by ``self._partition_update(item)``
    before it is yielded.
    """
    for line in read_json(file):
        item = self.format_item(line)
        item["partition"] = self._partition_update(item)
        yield item

def load_items(
self,
file: Union[Path, str, Iterator[Any]] = "stdin",
file: _ReadJsonFileType = "stdin",
insert_mode: Optional[Methods] = Methods.insert,
dehydrated: Optional[bool] = False,
chunksize: Optional[int] = 10000,
Expand All @@ -619,7 +634,7 @@ def load_items(

logger.debug(f"Adding data to database took {time.perf_counter() - t} seconds.")

def format_item(self, _item: Union[Path, str, Dict[str, Any]]) -> Dict[str, Any]:
def format_item(self, _item: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
"""Format an item to insert into a record."""
out: Dict[str, Any] = {}
item: Dict[str, Any]
Expand Down