def _build_fill_expression(name: str, value: Any, dtype: pl.DataType | None) -> pl.Expr:
    """Build the literal expression used to materialize a missing column.

    Args:
        name: Name given to the resulting column.
        value: Literal repeated for every row; ``None`` yields a null column.
        dtype: Data type for the literal, or ``None`` to let Polars infer it
            from ``value``.

    Returns:
        A ``pl.lit`` expression aliased to ``name``.
    """
    # ``pl.lit`` treats ``dtype=None`` as "infer from the value", so one call
    # covers both the typed and the untyped case.
    return pl.lit(value, dtype=dtype).alias(name)


def _align_frame_columns(
    df: pl.DataFrame,
    columns: list[str],
    schema: dict[str, pl.DataType],
    fill_values: dict[str, Any] | None = None,
) -> pl.DataFrame:
    """Return ``df`` with exactly ``columns``, in that order.

    Args:
        df: Frame to align.
        columns: Target column names in their final order.
        schema: Column name to dtype mapping consulted for columns that have
            to be created; names absent from it get an inferred dtype.
        fill_values: Optional per-column literal for created columns. Columns
            without an entry are filled with nulls.

    Returns:
        A frame in which every name in ``columns`` that ``df`` lacked has been
        added, and whose columns are selected in the order given by ``columns``.
    """
    fill_values = fill_values or {}

    # Snapshot the existing names once so the membership test is a set lookup.
    present = set(df.columns)
    missing = [col for col in columns if col not in present]

    if missing:
        df = df.with_columns(
            [_build_fill_expression(col, fill_values.get(col), schema.get(col)) for col in missing]
        )

    return df.select(columns)


def align_frames_for_concat(
    left: pl.DataFrame,
    right: pl.DataFrame,
    *,
    left_fill_values: dict[str, Any] | None = None,
    right_fill_values: dict[str, Any] | None = None,
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Align two DataFrames to a shared column set for safe vertical concatenation.

    The shared column order is every column of ``left`` followed by the columns
    that only ``right`` has, in ``right``'s order. A column missing from one
    frame is created there with the dtype it has in the other frame.

    Columns present in both frames are left untouched: this function does not
    cast them, so differing dtypes on a shared column are not reconciled here.

    Args:
        left: First frame; its column order takes precedence.
        right: Second frame.
        left_fill_values: Literal per column for columns created in ``left``.
            Columns without an entry are filled with nulls.
        right_fill_values: Literal per column for columns created in ``right``.
            Columns without an entry are filled with nulls.

    Returns:
        ``(left, right)`` with identical column names in identical order.
    """
    left_columns = left.columns
    known = set(left_columns)
    columns = [*left_columns, *(col for col in right.columns if col not in known)]

    # ``left`` is unpacked last so its dtype wins for names both frames define.
    # The mapping is only consulted for columns a frame is missing.
    schema = {**right.schema, **left.schema}

    return (
        _align_frame_columns(left, columns, schema, left_fill_values),
        _align_frame_columns(right, columns, schema, right_fill_values),
    )
columns.extend(col for col in df.columns if col not in OHLCV_SCHEMA) + + return df.sort("timestamp").select(columns) def load_symbols(self, symbols: list[str]) -> pl.DataFrame: """Load OHLCV data for multiple symbols. diff --git a/src/ml4t/data/futures/book_downloader.py b/src/ml4t/data/futures/book_downloader.py index 625d5b6..a7b901a 100644 --- a/src/ml4t/data/futures/book_downloader.py +++ b/src/ml4t/data/futures/book_downloader.py @@ -40,6 +40,7 @@ import structlog import yaml +from ml4t.data.core.schemas import align_frames_for_concat from ml4t.data.storage.data_profile import ( DatasetProfile, generate_profile, @@ -267,6 +268,7 @@ def download_product_ohlcv( # Check if file exists and merge if output_path.exists(): existing = pl.read_parquet(output_path) + existing, year_df = align_frames_for_concat(existing, year_df) year_df = pl.concat([existing, year_df]) year_df = year_df.unique(subset=["ts_event", "symbol"], keep="last") diff --git a/src/ml4t/data/futures/downloader.py b/src/ml4t/data/futures/downloader.py index 9f22bcb..fa17f79 100644 --- a/src/ml4t/data/futures/downloader.py +++ b/src/ml4t/data/futures/downloader.py @@ -38,6 +38,8 @@ from databento import Historical from databento.common.error import BentoClientError, BentoServerError +from ml4t.data.core.schemas import align_frames_for_concat + from .config import ( DEFAULT_PRODUCTS, DefinitionsConfig, @@ -802,11 +804,7 @@ def _update_product(self, product: str) -> bool: # Merge and deduplicate if existing_data.height > 0: - # Ensure same columns - common_cols = set(existing_data.columns) & set(new_data.columns) - existing_data = existing_data.select(sorted(common_cols)) - new_data = new_data.select(sorted(common_cols)) - + existing_data, new_data = align_frames_for_concat(existing_data, new_data) merged = pl.concat([existing_data, new_data]) # Deduplicate based on key columns diff --git a/src/ml4t/data/managers/storage_manager.py b/src/ml4t/data/managers/storage_manager.py index 579e674..a995c25 
100644 --- a/src/ml4t/data/managers/storage_manager.py +++ b/src/ml4t/data/managers/storage_manager.py @@ -17,6 +17,8 @@ import structlog from tenacity import RetryError +from ml4t.data.core.schemas import align_frames_for_concat + if TYPE_CHECKING: from ml4t.data.managers.fetch_manager import FetchManager @@ -151,18 +153,11 @@ def _merge_data(self, existing: pl.DataFrame, new: pl.DataFrame) -> pl.DataFrame Returns: Merged DataFrame with duplicates removed """ - # Ensure new_df has same columns as existing_df for concatenation - for col in existing.columns: - if col not in new.columns: - if col == "dividends": - new = new.with_columns(pl.lit(0.0).alias(col)) - elif col == "splits": - new = new.with_columns(pl.lit(1.0).alias(col)) - else: - new = new.with_columns(pl.lit(None).alias(col)) - - # Ensure column order matches - new = new.select(existing.columns) + existing, new = align_frames_for_concat( + existing, + new, + right_fill_values={"dividends": 0.0, "splits": 1.0}, + ) # Merge data: concatenate and remove duplicates merged_df = pl.concat([existing, new]) diff --git a/src/ml4t/data/storage/chunked.py b/src/ml4t/data/storage/chunked.py index 912a825..94bef11 100644 --- a/src/ml4t/data/storage/chunked.py +++ b/src/ml4t/data/storage/chunked.py @@ -10,6 +10,7 @@ import structlog from ml4t.data.core.models import DataObject, Metadata +from ml4t.data.core.schemas import align_frames_for_concat from ml4t.data.utils.locking import file_lock logger = structlog.get_logger() @@ -448,6 +449,7 @@ def write(self, data_object: DataObject) -> str: existing_df = pl.read_parquet(chunk_path) # Merge data + existing_df, chunk_df = align_frames_for_concat(existing_df, chunk_df) merged_df = ( pl.concat([existing_df, chunk_df]) .unique( diff --git a/src/ml4t/data/storage/hive.py b/src/ml4t/data/storage/hive.py index ff3f1b2..1612761 100644 --- a/src/ml4t/data/storage/hive.py +++ b/src/ml4t/data/storage/hive.py @@ -14,6 +14,8 @@ import polars as pl +from ml4t.data.core.schemas import 
align_frames_for_concat + from .backend import StorageBackend, StorageConfig if TYPE_CHECKING: @@ -510,6 +512,7 @@ def update_combined_file( # Read existing data if self.exists(key): existing_df = self.read(key).collect() + existing_df, data = align_frames_for_concat(existing_df, data) combined = pl.concat([existing_df, data]) else: combined = data diff --git a/src/ml4t/data/update_manager.py b/src/ml4t/data/update_manager.py index 0e31c93..e6b6396 100644 --- a/src/ml4t/data/update_manager.py +++ b/src/ml4t/data/update_manager.py @@ -11,6 +11,7 @@ import polars as pl import structlog +from ml4t.data.core.schemas import align_frames_for_concat from ml4t.data.providers.base import BaseProvider from ml4t.data.storage.hive import HiveStorage from ml4t.data.storage.metadata_tracker import MetadataTracker, UpdateRecord @@ -458,6 +459,7 @@ def apply_strategy( if not new_rows.is_empty(): # Append new rows + existing_df, new_rows = align_frames_for_concat(existing_df, new_rows) combined = pl.concat([existing_df, new_rows]) storage.delete(key) storage.write(combined, key) @@ -511,6 +513,7 @@ def apply_strategy( if not gap_data.is_empty(): # Merge with existing data + existing_df, gap_data = align_frames_for_concat(existing_df, gap_data) combined = pl.concat([existing_df, gap_data]).sort("timestamp") storage.delete(key) storage.write(combined, key) @@ -564,6 +567,7 @@ def apply_strategy( ) # Combine all data + existing_filtered, new_data = align_frames_for_concat(existing_filtered, new_data) combined = pl.concat([existing_filtered, new_data]).sort("timestamp") storage.delete(key) @@ -578,6 +582,7 @@ def apply_strategy( rows_after=len(combined), ) # No overlap, just append + existing_df, new_data = align_frames_for_concat(existing_df, new_data) combined = pl.concat([existing_df, new_data]).sort("timestamp") storage.delete(key) storage.write(combined, key) diff --git a/tests/core/test_schemas.py b/tests/core/test_schemas.py index 9f54f62..9a4eab3 100644 --- 
class TestAlignFramesForConcat:
    """Test frame alignment for concat-safe updates."""

    def test_aligns_reordered_columns(self):
        """Frames with the same columns in a different order adopt the left order."""
        left = pl.DataFrame(
            {
                "timestamp": [datetime(2024, 1, 1, tzinfo=UTC)],
                "open": [100.0],
                "close": [101.0],
            }
        )
        right = pl.DataFrame(
            {
                "close": [102.0],
                "timestamp": [datetime(2024, 1, 2, tzinfo=UTC)],
                "open": [101.0],
            }
        )

        aligned_left, aligned_right = align_frames_for_concat(left, right)

        assert aligned_left.columns == ["timestamp", "open", "close"]
        assert aligned_right.columns == ["timestamp", "open", "close"]

        # Alignment exists so that a strict vertical concat succeeds; run it
        # and check that values stayed attached to their column names.
        combined = pl.concat([aligned_left, aligned_right])
        assert combined["open"].to_list() == [100.0, 101.0]
        assert combined["close"].to_list() == [101.0, 102.0]

    def test_preserves_new_columns_and_fills_existing_rows(self):
        """A column only the right frame has is kept and null-filled on the left."""
        left = pl.DataFrame(
            {
                "timestamp": [datetime(2024, 1, 1, tzinfo=UTC)],
                "close": [101.0],
            }
        )
        right = pl.DataFrame(
            {
                "close": [102.0],
                "timestamp": [datetime(2024, 1, 2, tzinfo=UTC)],
                "volume": [1000.0],
            }
        )

        aligned_left, aligned_right = align_frames_for_concat(left, right)

        assert aligned_left.columns == ["timestamp", "close", "volume"]
        assert aligned_right.columns == ["timestamp", "close", "volume"]
        assert aligned_left["volume"].to_list() == [None]
        assert aligned_right["volume"].to_list() == [1000.0]

        # The created column must take the dtype of the frame that already had
        # it, otherwise the strict concat below would be rejected.
        assert aligned_left.schema["volume"] == pl.Float64
        combined = pl.concat([aligned_left, aligned_right])
        assert combined["volume"].to_list() == [None, 1000.0]

    def test_applies_fill_values_for_missing_columns(self):
        """Explicit fill values replace nulls for columns created on the right."""
        left = pl.DataFrame(
            {
                "timestamp": [datetime(2024, 1, 1, tzinfo=UTC)],
                "dividends": [0.25],
                "splits": [2.0],
            }
        )
        right = pl.DataFrame({"timestamp": [datetime(2024, 1, 2, tzinfo=UTC)]})

        aligned_left, aligned_right = align_frames_for_concat(
            left,
            right,
            right_fill_values={"dividends": 0.0, "splits": 1.0},
        )

        assert aligned_right["dividends"].to_list() == [0.0]
        assert aligned_right["splits"].to_list() == [1.0]

        # Filled columns must concatenate cleanly with the original values.
        combined = pl.concat([aligned_left, aligned_right])
        assert combined["dividends"].to_list() == [0.25, 0.0]
        assert combined["splits"].to_list() == [2.0, 1.0]
def test_append_only_strategy_with_reordered_columns(
    self, storage: HiveStorage, existing_data: pl.DataFrame, new_data: pl.DataFrame
):
    """APPEND_ONLY accepts incoming rows whose columns arrive in a shuffled order."""
    shuffled_order = ["close", "timestamp", "volume", "open", "low", "high"]

    updater = IncrementalUpdater()
    storage.write(existing_data, "test")
    shuffled = new_data.select(shuffled_order)

    outcome = updater.apply_strategy(
        storage,
        "test",
        shuffled,
        UpdateStrategy.APPEND_ONLY,
    )
    assert outcome.success is True

    # What was persisted must keep the stored column layout and must reach
    # the newest timestamp of the incoming data.
    persisted = storage.read("test").collect().sort("timestamp")
    assert persisted.columns == existing_data.columns
    assert persisted["timestamp"].max() == shuffled["timestamp"].max()