Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
- uses: actions/checkout@v6

- name: Install uv
uses: astral-sh/setup-uv@v8
uses: astral-sh/setup-uv@v7
with:
version: "latest"

Expand All @@ -40,7 +40,7 @@ jobs:
- uses: actions/checkout@v6

- name: Install uv
uses: astral-sh/setup-uv@v8
uses: astral-sh/setup-uv@v7
with:
version: "latest"

Expand All @@ -65,7 +65,7 @@ jobs:
- uses: actions/checkout@v6

- name: Install uv
uses: astral-sh/setup-uv@v8
uses: astral-sh/setup-uv@v7
with:
version: "latest"

Expand All @@ -87,7 +87,7 @@ jobs:
- uses: actions/checkout@v6

- name: Install uv
uses: astral-sh/setup-uv@v8
uses: astral-sh/setup-uv@v7
with:
version: "latest"

Expand Down Expand Up @@ -121,7 +121,7 @@ jobs:
- uses: actions/checkout@v6

- name: Install uv
uses: astral-sh/setup-uv@v8
uses: astral-sh/setup-uv@v7
with:
version: "latest"

Expand All @@ -148,7 +148,7 @@ jobs:
- uses: actions/checkout@v6

- name: Install uv
uses: astral-sh/setup-uv@v8
uses: astral-sh/setup-uv@v7
with:
version: "latest"

Expand Down Expand Up @@ -180,7 +180,7 @@ jobs:
fetch-depth: 0

- name: Install uv
uses: astral-sh/setup-uv@v8
uses: astral-sh/setup-uv@v7
with:
version: "latest"

Expand Down
42 changes: 41 additions & 1 deletion src/ml4t/data/core/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,51 @@

from __future__ import annotations

from typing import ClassVar
from typing import Any, ClassVar

import polars as pl


def _build_fill_expression(name: str, value: Any, dtype: pl.DataType | None) -> pl.Expr:
    """Return a literal expression holding ``value``, aliased to ``name``.

    Args:
        name: Column name given to the resulting expression.
        value: Literal placed in the column.
        dtype: Dtype passed to ``pl.lit`` when not ``None``; when ``None`` the
            literal is built without an explicit dtype.

    Returns:
        The aliased literal expression.
    """
    literal = pl.lit(value) if dtype is None else pl.lit(value, dtype=dtype)
    return literal.alias(name)


def _align_frame_columns(
    df: pl.DataFrame,
    columns: list[str],
    schema: dict[str, pl.DataType],
    fill_values: dict[str, Any] | None = None,
) -> pl.DataFrame:
    """Return ``df`` restricted to ``columns``, in that order, adding absent ones.

    Args:
        df: Frame to align.
        columns: Target column names, in the order the result should have.
        schema: Column name to dtype lookup used when creating an absent
            column; a name not in the mapping gets no explicit dtype.
        fill_values: Optional per-column literal for absent columns. Columns
            without an entry are filled with ``None``.

    Returns:
        A frame whose columns are exactly ``columns``.
    """
    defaults = fill_values or {}
    present = df.columns

    additions = []
    for column in columns:
        if column in present:
            continue
        additions.append(
            _build_fill_expression(column, defaults.get(column), schema.get(column))
        )

    if additions:
        df = df.with_columns(additions)

    return df.select(columns)


def align_frames_for_concat(
    left: pl.DataFrame,
    right: pl.DataFrame,
    *,
    left_fill_values: dict[str, Any] | None = None,
    right_fill_values: dict[str, Any] | None = None,
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Align two DataFrames to a shared column set for safe vertical concatenation.

    The shared column list is ``left``'s columns in their existing order,
    followed by the columns that appear only in ``right``. A column absent
    from one frame is added to it using the dtype the other frame has for it.

    Only column presence and order are aligned here; dtypes of columns that
    both frames already contain are left untouched.

    Args:
        left: First frame; its column order takes precedence.
        right: Second frame.
        left_fill_values: Optional per-column literals for columns added to
            ``left``. Columns without an entry are filled with ``None``.
        right_fill_values: Optional per-column literals for columns added to
            ``right``. Columns without an entry are filled with ``None``.

    Returns:
        ``(aligned_left, aligned_right)`` with identical column names and order.
    """
    left_names = left.columns
    columns = list(left_names)
    columns.extend(name for name in right.columns if name not in left_names)

    # Start from right's dtypes, then let left's entries override shared names.
    schema = dict(right.schema)
    schema.update(left.schema)

    aligned_left = _align_frame_columns(left, columns, schema, left_fill_values)
    aligned_right = _align_frame_columns(right, columns, schema, right_fill_values)
    return aligned_left, aligned_right


class MultiAssetSchema:
"""Schema definition and validation for multi-asset DataFrames.

Expand Down
7 changes: 6 additions & 1 deletion src/ml4t/data/etfs/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import structlog
import yaml

from ml4t.data.core.schemas import align_frames_for_concat
from ml4t.data.storage.data_profile import (
ProfileMixin,
generate_profile,
Expand Down Expand Up @@ -259,6 +260,7 @@ def update(self) -> dict[str, int]:
# Ensure datetime precision matches (existing may be ns, new data is μs)
existing = existing.with_columns(pl.col("timestamp").cast(pl.Datetime("us")))
new_data = new_data.with_columns(pl.col("timestamp").cast(pl.Datetime("us")))
existing, new_data = align_frames_for_concat(existing, new_data)
combined = pl.concat([existing, new_data])
combined = combined.unique(subset=["timestamp"], maintain_order=True)
combined = combined.sort("timestamp")
Expand Down Expand Up @@ -400,7 +402,10 @@ def load_ohlcv(self, symbol: str) -> pl.DataFrame:
if "symbol" not in df.columns:
df = df.with_columns(pl.lit(symbol).alias("symbol"))

return df.sort("timestamp")
columns = [col for col in OHLCV_SCHEMA if col in df.columns]
columns.extend(col for col in df.columns if col not in OHLCV_SCHEMA)

return df.sort("timestamp").select(columns)

def load_symbols(self, symbols: list[str]) -> pl.DataFrame:
"""Load OHLCV data for multiple symbols.
Expand Down
2 changes: 2 additions & 0 deletions src/ml4t/data/futures/book_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import structlog
import yaml

from ml4t.data.core.schemas import align_frames_for_concat
from ml4t.data.storage.data_profile import (
DatasetProfile,
generate_profile,
Expand Down Expand Up @@ -267,6 +268,7 @@ def download_product_ohlcv(
# Check if file exists and merge
if output_path.exists():
existing = pl.read_parquet(output_path)
existing, year_df = align_frames_for_concat(existing, year_df)
year_df = pl.concat([existing, year_df])
year_df = year_df.unique(subset=["ts_event", "symbol"], keep="last")

Expand Down
8 changes: 3 additions & 5 deletions src/ml4t/data/futures/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from databento import Historical
from databento.common.error import BentoClientError, BentoServerError

from ml4t.data.core.schemas import align_frames_for_concat

from .config import (
DEFAULT_PRODUCTS,
DefinitionsConfig,
Expand Down Expand Up @@ -802,11 +804,7 @@ def _update_product(self, product: str) -> bool:

# Merge and deduplicate
if existing_data.height > 0:
# Ensure same columns
common_cols = set(existing_data.columns) & set(new_data.columns)
existing_data = existing_data.select(sorted(common_cols))
new_data = new_data.select(sorted(common_cols))

existing_data, new_data = align_frames_for_concat(existing_data, new_data)
merged = pl.concat([existing_data, new_data])

# Deduplicate based on key columns
Expand Down
19 changes: 7 additions & 12 deletions src/ml4t/data/managers/storage_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import structlog
from tenacity import RetryError

from ml4t.data.core.schemas import align_frames_for_concat

if TYPE_CHECKING:
from ml4t.data.managers.fetch_manager import FetchManager

Expand Down Expand Up @@ -151,18 +153,11 @@ def _merge_data(self, existing: pl.DataFrame, new: pl.DataFrame) -> pl.DataFrame
Returns:
Merged DataFrame with duplicates removed
"""
# Ensure new_df has same columns as existing_df for concatenation
for col in existing.columns:
if col not in new.columns:
if col == "dividends":
new = new.with_columns(pl.lit(0.0).alias(col))
elif col == "splits":
new = new.with_columns(pl.lit(1.0).alias(col))
else:
new = new.with_columns(pl.lit(None).alias(col))

# Ensure column order matches
new = new.select(existing.columns)
existing, new = align_frames_for_concat(
existing,
new,
right_fill_values={"dividends": 0.0, "splits": 1.0},
Comment on lines +156 to +159

Copilot AI Apr 29, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`_merge_data` now preserves columns that exist only in `new`, but the current call only supplies `right_fill_values`. If `new` introduces optional equity columns like `dividends`/`splits` that were absent from `existing` (e.g., older parquet written without them), `align_frames_for_concat` will add those columns to `existing` filled with nulls. Downstream logic typically treats missing dividends as `0.0` and splits as `1.0` (see prior behavior in this method), so consider also passing the same defaults via `left_fill_values` to avoid nulls in historical rows when these columns are introduced later.

Suggested change
existing, new = align_frames_for_concat(
existing,
new,
right_fill_values={"dividends": 0.0, "splits": 1.0},
fill_values = {"dividends": 0.0, "splits": 1.0}
existing, new = align_frames_for_concat(
existing,
new,
left_fill_values=fill_values,
right_fill_values=fill_values,

Copilot uses AI. Check for mistakes.
)

# Merge data: concatenate and remove duplicates
merged_df = pl.concat([existing, new])
Expand Down
2 changes: 2 additions & 0 deletions src/ml4t/data/storage/chunked.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import structlog

from ml4t.data.core.models import DataObject, Metadata
from ml4t.data.core.schemas import align_frames_for_concat
from ml4t.data.utils.locking import file_lock

logger = structlog.get_logger()
Expand Down Expand Up @@ -448,6 +449,7 @@ def write(self, data_object: DataObject) -> str:
existing_df = pl.read_parquet(chunk_path)

# Merge data
existing_df, chunk_df = align_frames_for_concat(existing_df, chunk_df)
merged_df = (
pl.concat([existing_df, chunk_df])
.unique(
Expand Down
3 changes: 3 additions & 0 deletions src/ml4t/data/storage/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import polars as pl

from ml4t.data.core.schemas import align_frames_for_concat

from .backend import StorageBackend, StorageConfig

if TYPE_CHECKING:
Expand Down Expand Up @@ -510,6 +512,7 @@ def update_combined_file(
# Read existing data
if self.exists(key):
existing_df = self.read(key).collect()
existing_df, data = align_frames_for_concat(existing_df, data)
combined = pl.concat([existing_df, data])
else:
combined = data
Expand Down
5 changes: 5 additions & 0 deletions src/ml4t/data/update_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import polars as pl
import structlog

from ml4t.data.core.schemas import align_frames_for_concat
from ml4t.data.providers.base import BaseProvider
from ml4t.data.storage.hive import HiveStorage
from ml4t.data.storage.metadata_tracker import MetadataTracker, UpdateRecord
Expand Down Expand Up @@ -458,6 +459,7 @@ def apply_strategy(

if not new_rows.is_empty():
# Append new rows
existing_df, new_rows = align_frames_for_concat(existing_df, new_rows)
combined = pl.concat([existing_df, new_rows])
storage.delete(key)
storage.write(combined, key)
Expand Down Expand Up @@ -511,6 +513,7 @@ def apply_strategy(

if not gap_data.is_empty():
# Merge with existing data
existing_df, gap_data = align_frames_for_concat(existing_df, gap_data)
combined = pl.concat([existing_df, gap_data]).sort("timestamp")
storage.delete(key)
storage.write(combined, key)
Expand Down Expand Up @@ -564,6 +567,7 @@ def apply_strategy(
)

# Combine all data
existing_filtered, new_data = align_frames_for_concat(existing_filtered, new_data)
combined = pl.concat([existing_filtered, new_data]).sort("timestamp")

storage.delete(key)
Expand All @@ -578,6 +582,7 @@ def apply_strategy(
rows_after=len(combined),
)
# No overlap, just append
existing_df, new_data = align_frames_for_concat(existing_df, new_data)
combined = pl.concat([existing_df, new_data]).sort("timestamp")
storage.delete(key)
storage.write(combined, key)
Expand Down
68 changes: 67 additions & 1 deletion tests/core/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import polars as pl
import pytest

from ml4t.data.core.schemas import MultiAssetSchema
from ml4t.data.core.schemas import MultiAssetSchema, align_frames_for_concat


class TestMultiAssetSchemaConstants:
Expand Down Expand Up @@ -616,6 +616,72 @@ def test_cast_to_schema_no_casting_needed(self):
assert df_cast.equals(df)


class TestAlignFramesForConcat:
    """Exercise align_frames_for_concat on reordered, extra, and missing columns."""

    def test_aligns_reordered_columns(self):
        """Both outputs adopt the left frame's column order."""
        first_day = datetime(2024, 1, 1, tzinfo=UTC)
        second_day = datetime(2024, 1, 2, tzinfo=UTC)
        left = pl.DataFrame({"timestamp": [first_day], "open": [100.0], "close": [101.0]})
        right = pl.DataFrame({"close": [102.0], "timestamp": [second_day], "open": [101.0]})

        expected_order = ["timestamp", "open", "close"]
        for frame in align_frames_for_concat(left, right):
            assert frame.columns == expected_order

    def test_preserves_new_columns_and_fills_existing_rows(self):
        """A right-only column is kept and appears as null on the left."""
        first_day = datetime(2024, 1, 1, tzinfo=UTC)
        second_day = datetime(2024, 1, 2, tzinfo=UTC)
        left = pl.DataFrame({"timestamp": [first_day], "close": [101.0]})
        right = pl.DataFrame(
            {"close": [102.0], "timestamp": [second_day], "volume": [1000.0]}
        )

        aligned_left, aligned_right = align_frames_for_concat(left, right)

        expected_order = ["timestamp", "close", "volume"]
        assert aligned_left.columns == expected_order
        assert aligned_right.columns == expected_order
        assert aligned_left["volume"].to_list() == [None]
        assert aligned_right["volume"].to_list() == [1000.0]

    def test_applies_fill_values_for_missing_columns(self):
        """Supplied right-side fill values populate columns the right frame lacks."""
        left = pl.DataFrame(
            {
                "timestamp": [datetime(2024, 1, 1, tzinfo=UTC)],
                "dividends": [0.25],
                "splits": [2.0],
            }
        )
        right = pl.DataFrame({"timestamp": [datetime(2024, 1, 2, tzinfo=UTC)]})
        defaults = {"dividends": 0.0, "splits": 1.0}

        aligned = align_frames_for_concat(left, right, right_fill_values=defaults)
        aligned_right = aligned[1]

        for column, expected in (("dividends", [0.0]), ("splits", [1.0])):
            assert aligned_right[column].to_list() == expected


class TestIntegration:
"""Integration tests combining multiple operations."""

Expand Down
Loading
Loading