39 changes: 26 additions & 13 deletions pinecone_datasets/catalog.py
@@ -1,11 +1,11 @@
import warnings
import os
import posixpath
import json
from typing import List, Optional, Union, TYPE_CHECKING

import logging
import platform
from pydantic import BaseModel, ValidationError, Field

from .cfg import Storage
from .fs import get_cloud_fs
from .dataset import Dataset
@@ -17,7 +17,6 @@
else:
pd = None


logger = logging.getLogger(__name__)


@@ -34,12 +33,28 @@ def __init__(self, base_path: Optional[str] = None, **kwargs):
base_path: str = Field(default=None)
datasets: List[DatasetMetadata] = Field(default_factory=list)

def _join_cloud_path(self, *paths: str) -> str:
"""
Join path components for cloud storage paths using forward slashes.
This ensures consistent behavior across platforms (Windows, Linux, Mac).

Args:
*paths: Path components to join

Returns:
str: Joined path with forward slashes
"""
return posixpath.join(*paths)

def load(self, **kwargs) -> "Catalog":
"""Loads metadata about all datasets from the catalog."""
fs = get_cloud_fs(self.base_path, **kwargs)
collected_datasets = []

metadata_files_glob_path = os.path.join(self.base_path, "*", "metadata.json")
metadata_files_glob_path = self._join_cloud_path(self.base_path, "*", "metadata.json")

logger.debug(f"Searching for datasets with glob pattern: {metadata_files_glob_path}")

for metadata_path in fs.glob(metadata_files_glob_path):
with fs.open(metadata_path) as f:
try:
@@ -49,7 +64,6 @@ def load(self, **kwargs) -> "Catalog":
f"Not a JSON: Invalid metadata.json for {metadata_path}, skipping"
)
continue

try:
this_dataset = DatasetMetadata(**this_dataset_json)
collected_datasets.append(this_dataset)
@@ -63,31 +77,30 @@ def load(self, **kwargs) -> "Catalog":
logger.info(f"Loaded {len(self.datasets)} datasets from {self.base_path}")
return self

def list_datasets(self, as_df: bool) -> Union[List[str], "pd.DataFrame"]:
def list_datasets(self, as_df: bool = False) -> Union[List[str], "pd.DataFrame"]:
"""Lists all datasets in the catalog."""
if self.datasets is None or len(self.datasets) == 0:
self.load()

import pandas as pd

if as_df:
return pd.DataFrame([ds.model_dump() for ds in self.datasets])
else:
return [dataset.name for dataset in self.datasets]

def load_dataset(self, dataset_id: str, **kwargs) -> "Dataset":
"""Loads the dataset from the catalog."""
ds_path = os.path.join(str(self.base_path), dataset_id)
ds_path = self._join_cloud_path(str(self.base_path), dataset_id)
return Dataset.from_path(dataset_path=ds_path, **kwargs)

def save_dataset(
self,
dataset: "Dataset",
**kwargs,
self,
dataset: "Dataset",
**kwargs,
):
"""
Save a dataset to the catalog.
"""
ds_path = os.path.join(self.base_path, dataset.metadata.name)
ds_path = self._join_cloud_path(self.base_path, dataset.metadata.name)
DatasetFSWriter.write_dataset(dataset_path=ds_path, dataset=dataset, **kwargs)
logger.info(f"Saved dataset {dataset.metadata.name} to {ds_path}")
logger.info(f"Saved dataset {dataset.metadata.name} to {ds_path}")
6 changes: 4 additions & 2 deletions pinecone_datasets/dataset_fsreader.py
@@ -1,4 +1,5 @@
import os
import posixpath
import json
import logging
import warnings
@@ -63,15 +64,16 @@ def _does_datatype_exist(
dataset_path: str,
data_type: Literal["documents", "queries"],
) -> bool:
return fs.exists(os.path.join(dataset_path, data_type))
path = posixpath.join(dataset_path, data_type)
return fs.exists(path)

@staticmethod
def _safe_read_from_path(
fs: CloudOrLocalFS,
dataset_path: str,
data_type: Literal["documents", "queries"],
) -> pd.DataFrame:
read_path_str = os.path.join(dataset_path, data_type, "*.parquet")
read_path_str = posixpath.join(dataset_path, data_type, "*.parquet")
read_path = fs.glob(read_path_str)
if DatasetFSReader._does_datatype_exist(fs, dataset_path, data_type):
# First, collect all the dataframes
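
On the reader side, the same idea applies when globbing for parquet shards. A rough sketch of the read pattern, under the assumption of an fsspec-style filesystem (illustrative only, not the library's exact code; read_parquet_shards is a made-up name):

import posixpath
import pandas as pd
import fsspec

def read_parquet_shards(
    fs: fsspec.AbstractFileSystem, dataset_path: str, data_type: str
) -> pd.DataFrame:
    # Build the glob with forward slashes so it matches object keys
    # regardless of the local OS.
    pattern = posixpath.join(dataset_path, data_type, "*.parquet")
    frames = []
    for shard in fs.glob(pattern):
        with fs.open(shard, "rb") as f:
            frames.append(pd.read_parquet(f))
    # Concatenate all shards into a single dataframe.
    return pd.concat(frames, ignore_index=True)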