From e9b6c75f104ad81982e0b3a8f3530f8370ce7782 Mon Sep 17 00:00:00 2001 From: Stefan Jansen Date: Sat, 16 May 2026 15:18:20 -0400 Subject: [PATCH] feat: add per-symbol percentage margin schedule --- README.md | 2 +- src/ml4t/backtest/accounting/policy.py | 32 ++++++++- src/ml4t/backtest/broker.py | 9 ++- src/ml4t/backtest/config.py | 17 ++++- src/ml4t/backtest/datafeed.py | 1 + src/ml4t/backtest/execution/rebalancer.py | 4 +- src/ml4t/backtest/execution/schedule.py | 1 + src/ml4t/backtest/result.py | 1 + src/ml4t/backtest/types.py | 9 +++ .../accounting/test_margin_account_policy.py | 68 +++++++++++++++++++ tests/contracts/test_execution_contracts.py | 2 +- tests/execution/test_rebalancer.py | 2 +- tests/execution/test_rebalancer_futures.py | 60 ++++++++++++++++ tests/execution/test_schedule.py | 2 +- tests/test_artifact_spec.py | 7 +- tests/test_broker.py | 2 +- tests/test_config_wiring.py | 40 +++++++---- tests/test_core.py | 2 +- tests/test_datafeed_memory.py | 2 +- tests/test_equity_curve.py | 2 +- tests/test_result.py | 2 +- tests/test_strategy_templates.py | 2 +- 22 files changed, 236 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index 3971b013..022b20f0 100644 --- a/README.md +++ b/README.md @@ -337,7 +337,7 @@ Benchmark on 250 assets x 20 years daily data (1.26M bars): ## Development ```bash -git clone https://github.com/ml4t/ml4t-backtest.git +git clone https://github.com/ml4t/backtest.git cd ml4t-backtest uv sync uv run pytest tests/ -q diff --git a/src/ml4t/backtest/accounting/policy.py b/src/ml4t/backtest/accounting/policy.py index 722b7dd2..77fc0131 100644 --- a/src/ml4t/backtest/accounting/policy.py +++ b/src/ml4t/backtest/accounting/policy.py @@ -229,6 +229,7 @@ def __init__( long_maintenance_margin: float = 0.25, short_maintenance_margin: float = 0.30, fixed_margin_schedule: dict[str, tuple[float, float]] | None = None, + margin_pct_schedule: dict[str, tuple[float, float]] | None = None, short_cash_policy: str = "credit", ) -> None: """Initialize unified account policy. @@ -246,6 +247,10 @@ def __init__( fixed_margin_schedule: Per-asset fixed dollar margin for futures. - Dict mapping asset symbol to (initial, maintenance) tuple - Example: {"ES": (12000, 6000)} + margin_pct_schedule: Per-asset percentage-of-notional margin schedule. + - Dict mapping asset symbol to (initial, maintenance) tuple + - Percentages are fractions of notional, not whole percents + - Example: {"ES": (0.05, 0.035)} short_cash_policy: How short proceeds affect spendable cash in non-levered accounts. One of {"credit", "lock_notional"}. @@ -258,6 +263,7 @@ def __init__( self.long_maintenance_margin = long_maintenance_margin self.short_maintenance_margin = short_maintenance_margin self.fixed_margin_schedule = fixed_margin_schedule or {} + self.margin_pct_schedule = margin_pct_schedule or {} if short_cash_policy not in {"credit", "credit_proceeds", "lock_notional"}: raise ValueError( "short_cash_policy must be 'credit', 'credit_proceeds', or " @@ -265,6 +271,15 @@ def __init__( ) self.short_cash_policy = short_cash_policy + overlapping_margin_assets = sorted( + set(self.fixed_margin_schedule) & set(self.margin_pct_schedule) + ) + if overlapping_margin_assets: + raise ValueError( + "fixed_margin_schedule and margin_pct_schedule cannot both define: " + f"{overlapping_margin_assets}" + ) + # Validate margin parameters if leverage is enabled if allow_leverage: if not 0.0 < initial_margin <= 1.0: @@ -306,6 +321,7 @@ def from_config(cls, config: BacktestConfig) -> UnifiedAccountPolicy: long_maintenance_margin=config.long_maintenance_margin, short_maintenance_margin=config.short_maintenance_margin, fixed_margin_schedule=config.fixed_margin_schedule, + margin_pct_schedule=config.margin_pct_schedule, short_cash_policy=config.short_cash_policy.value, ) @@ -384,7 +400,15 @@ def get_margin_requirement( ) -> float: """Calculate margin requirement for a position. - Uses fixed dollar margin for futures, percentage for equities. + Supports three margin models: + + - ``margin_pct_schedule``: per-asset percentage-of-notional margin. + This is the preferred price-aware approximation for futures when only + a stable scan ratio is known. + - ``fixed_margin_schedule``: per-asset fixed dollar margin per contract. + This models a single historical SPAN snapshot. + - account-wide percentage margin: fallback for assets not covered by a + per-asset schedule. Args: asset: Asset symbol @@ -395,6 +419,12 @@ def get_margin_requirement( Returns: Margin required in dollars """ + # Price-aware per-asset percentage margin (preferred for futures) + if asset in self.margin_pct_schedule: + initial, maintenance = self.margin_pct_schedule[asset] + margin_rate = initial if for_initial else maintenance + return abs(quantity * price) * margin_rate + # Check for fixed margin (futures) if asset in self.fixed_margin_schedule: initial, maintenance = self.fixed_margin_schedule[asset] diff --git a/src/ml4t/backtest/broker.py b/src/ml4t/backtest/broker.py index d34eabaf..e90ae006 100644 --- a/src/ml4t/backtest/broker.py +++ b/src/ml4t/backtest/broker.py @@ -73,6 +73,7 @@ def __init__( long_maintenance_margin: float = 0.25, short_maintenance_margin: float = 0.30, fixed_margin_schedule: dict[str, tuple[float, float]] | None = None, + margin_pct_schedule: dict[str, tuple[float, float]] | None = None, short_cash_policy: ShortCashPolicy = ShortCashPolicy.CREDIT, execution_limits: ExecutionLimits | None = None, market_impact_model: MarketImpactModel | None = None, @@ -144,11 +145,14 @@ def __init__( # This lets users specify margin once on ContractSpec rather than duplicating # it in both ContractSpec and BacktestConfig.fixed_margin_schedule. effective_margin_schedule = dict(fixed_margin_schedule or {}) + effective_margin_pct_schedule = dict(margin_pct_schedule or {}) if contract_specs: for symbol, spec in contract_specs.items(): if spec.margin is not None and symbol not in effective_margin_schedule: # Use spec.margin as initial margin, 50% as maintenance (industry standard) effective_margin_schedule[symbol] = (spec.margin, spec.margin * 0.5) + if spec.margin_pct is not None and symbol not in effective_margin_pct_schedule: + effective_margin_pct_schedule[symbol] = spec.margin_pct # Create AccountState with UnifiedAccountPolicy policy: AccountPolicy = UnifiedAccountPolicy( @@ -158,6 +162,7 @@ def __init__( long_maintenance_margin=long_maintenance_margin, short_maintenance_margin=short_maintenance_margin, fixed_margin_schedule=effective_margin_schedule or None, + margin_pct_schedule=effective_margin_pct_schedule or None, short_cash_policy=short_cash_policy.value, ) @@ -174,7 +179,8 @@ def __init__( self.initial_margin = initial_margin self.long_maintenance_margin = long_maintenance_margin self.short_maintenance_margin = short_maintenance_margin - self.fixed_margin_schedule = fixed_margin_schedule or {} + self.fixed_margin_schedule = effective_margin_schedule + self.margin_pct_schedule = effective_margin_pct_schedule self.short_cash_policy = short_cash_policy # Create Gatekeeper for order validation @@ -360,6 +366,7 @@ def from_config( long_maintenance_margin=config.long_maintenance_margin, short_maintenance_margin=config.short_maintenance_margin, fixed_margin_schedule=config.fixed_margin_schedule, + margin_pct_schedule=config.margin_pct_schedule, short_cash_policy=config.short_cash_policy, execution_limits=execution_limits, market_impact_model=market_impact_model, diff --git a/src/ml4t/backtest/config.py b/src/ml4t/backtest/config.py index 51174d90..efb0123c 100644 --- a/src/ml4t/backtest/config.py +++ b/src/ml4t/backtest/config.py @@ -29,6 +29,7 @@ from typing import Any import yaml + from ml4t.specs.base import serialize_artifact_value from ml4t.specs.market_data import FeedSpec, TimestampSemantics @@ -402,6 +403,7 @@ class BacktestConfig: long_maintenance_margin: float = 0.25 # Reg T standard for longs short_maintenance_margin: float = 0.30 # Reg T standard for shorts (higher!) fixed_margin_schedule: dict[str, tuple[float, float]] | None = None # For futures + margin_pct_schedule: dict[str, tuple[float, float]] | None = None # Price-aware futures margin short_cash_policy: ShortCashPolicy = ShortCashPolicy.CREDIT # === Execution Timing === @@ -511,6 +513,15 @@ def validate(self, warn: bool = True) -> list[str]: f"initial_margin ({self.initial_margin})" ) + fixed_assets = set(self.fixed_margin_schedule or {}) + pct_assets = set(self.margin_pct_schedule or {}) + overlapping_margin_assets = sorted(fixed_assets & pct_assets) + if overlapping_margin_assets: + issues.append( + "fixed_margin_schedule and margin_pct_schedule cannot both define: " + f"{overlapping_margin_assets}" + ) + if self.settlement_delay < 0 or self.settlement_delay > 5: issues.append( f"settlement_delay ({self.settlement_delay}) should be 0-5. " @@ -717,6 +728,7 @@ def to_dict(self) -> dict: "long_maintenance_margin": self.long_maintenance_margin, "short_maintenance_margin": self.short_maintenance_margin, "fixed_margin_schedule": self.fixed_margin_schedule, + "margin_pct_schedule": self.margin_pct_schedule, "short_cash_policy": self.short_cash_policy.value, }, "execution": { @@ -826,6 +838,7 @@ def from_dict( "long_maintenance_margin", "short_maintenance_margin", "fixed_margin_schedule", + "margin_pct_schedule", "short_cash_policy", }, "execution": {"execution_price", "mark_price", "execution_mode"}, @@ -946,6 +959,7 @@ def from_dict( long_maintenance_margin=acct_cfg.get("long_maintenance_margin", 0.25), short_maintenance_margin=acct_cfg.get("short_maintenance_margin", 0.30), fixed_margin_schedule=acct_cfg.get("fixed_margin_schedule"), + margin_pct_schedule=acct_cfg.get("margin_pct_schedule"), short_cash_policy=ShortCashPolicy(acct_cfg.get("short_cash_policy", "credit")), # Execution execution_price=ExecutionPrice(exec_cfg.get("execution_price", "open")), @@ -1171,8 +1185,7 @@ def from_user_config( if node: if not isinstance(node, dict): raise TypeError( - "Expected broker-specific override to be a mapping in " - f"{assumptions_path}" + f"Expected broker-specific override to be a mapping in {assumptions_path}" ) merged_data = _deep_merge_dicts(merged_data, node) else: diff --git a/src/ml4t/backtest/datafeed.py b/src/ml4t/backtest/datafeed.py index 396252ba..e290981b 100644 --- a/src/ml4t/backtest/datafeed.py +++ b/src/ml4t/backtest/datafeed.py @@ -9,6 +9,7 @@ from typing import Any import polars as pl + from ml4t.specs.market_data import FeedSpec diff --git a/src/ml4t/backtest/execution/rebalancer.py b/src/ml4t/backtest/execution/rebalancer.py index 2c77ff5b..89f24a9a 100644 --- a/src/ml4t/backtest/execution/rebalancer.py +++ b/src/ml4t/backtest/execution/rebalancer.py @@ -362,7 +362,9 @@ def _quantize_shares(self, shares: float, broker: Broker) -> float: """ use_fractional = self.config.allow_fractional if use_fractional is None: - use_fractional = getattr(broker, "share_type", ShareType.INTEGER) == ShareType.FRACTIONAL + use_fractional = ( + getattr(broker, "share_type", ShareType.INTEGER) == ShareType.FRACTIONAL + ) if self.config.round_lots: rounded_lots = round(shares / self.config.lot_size) * self.config.lot_size diff --git a/src/ml4t/backtest/execution/schedule.py b/src/ml4t/backtest/execution/schedule.py index a9b4f463..9f713537 100644 --- a/src/ml4t/backtest/execution/schedule.py +++ b/src/ml4t/backtest/execution/schedule.py @@ -9,6 +9,7 @@ from typing import Any import polars as pl + from ml4t.specs.market_data import FeedSpec, TimestampSemantics from ..calendar import get_schedule diff --git a/src/ml4t/backtest/result.py b/src/ml4t/backtest/result.py index f44ff56c..4454f122 100644 --- a/src/ml4t/backtest/result.py +++ b/src/ml4t/backtest/result.py @@ -28,6 +28,7 @@ from typing import TYPE_CHECKING, Any, Literal import polars as pl + from ml4t.specs.market_data import FeedSpec try: diff --git a/src/ml4t/backtest/types.py b/src/ml4t/backtest/types.py index 16b42d27..99722b4c 100644 --- a/src/ml4t/backtest/types.py +++ b/src/ml4t/backtest/types.py @@ -84,6 +84,14 @@ class ContractSpec: margin=15000.0, # Initial margin per contract ) + # Price-aware margin approximation + nq_spec = ContractSpec( + symbol="NQ", + asset_class=AssetClass.FUTURE, + multiplier=20.0, + margin_pct=(0.05, 0.035), # 5.0% initial, 3.5% maintenance + ) + # Apple stock aapl_spec = ContractSpec( symbol="AAPL", @@ -98,6 +106,7 @@ class ContractSpec: multiplier: float = 1.0 # Point value ($ per point move) tick_size: float = 0.01 # Minimum price increment margin: float | None = None # Initial margin per contract (overrides account default) + margin_pct: tuple[float, float] | None = None # (initial, maintenance) fractions of notional currency: str = "USD" diff --git a/tests/accounting/test_margin_account_policy.py b/tests/accounting/test_margin_account_policy.py index 091ae0a6..ba01495c 100644 --- a/tests/accounting/test_margin_account_policy.py +++ b/tests/accounting/test_margin_account_policy.py @@ -688,6 +688,74 @@ def test_mixed_portfolio_buying_power(self): assert bp > 0 +class TestMarginAccountPolicyFuturesMarginPct: + """Tests for price-aware percentage-of-notional futures margin.""" + + def test_futures_margin_pct_initial(self): + """Initial margin should scale with notional.""" + policy = UnifiedAccountPolicy( + allow_short_selling=True, + allow_leverage=True, + margin_pct_schedule={"ES": (0.05, 0.035)}, + ) + margin = policy.get_margin_requirement("ES", 2, 5000.0, for_initial=True) + assert margin == 500.0 + + def test_futures_margin_pct_maintenance(self): + """Maintenance margin should use maintenance schedule rate.""" + policy = UnifiedAccountPolicy( + allow_short_selling=True, + allow_leverage=True, + margin_pct_schedule={"ES": (0.05, 0.035)}, + ) + margin = policy.get_margin_requirement("ES", 2, 5000.0, for_initial=False) + assert margin == pytest.approx(350.0) + + def test_futures_margin_pct_tracks_price(self): + """Percentage margin should move with price.""" + policy = UnifiedAccountPolicy( + allow_short_selling=True, + allow_leverage=True, + margin_pct_schedule={"ES": (0.05, 0.035)}, + ) + margin_low = policy.get_margin_requirement("ES", 1, 4000.0, for_initial=True) + margin_high = policy.get_margin_requirement("ES", 1, 6000.0, for_initial=True) + assert margin_low == 200.0 + assert margin_high == 300.0 + + def test_margin_pct_schedule_short_same_as_long(self): + """Percentage-based futures margin should be direction-agnostic.""" + policy = UnifiedAccountPolicy( + allow_short_selling=True, + allow_leverage=True, + margin_pct_schedule={"ES": (0.05, 0.035)}, + ) + long_margin = policy.get_margin_requirement("ES", 2, 5000.0, for_initial=True) + short_margin = policy.get_margin_requirement("ES", -2, 5000.0, for_initial=True) + assert long_margin == short_margin == 500.0 + + def test_margin_pct_schedule_takes_precedence_over_global_margin(self): + """Per-asset percentage schedule should override account-wide margin.""" + policy = UnifiedAccountPolicy( + allow_short_selling=True, + allow_leverage=True, + initial_margin=0.5, + margin_pct_schedule={"ES": (0.05, 0.035)}, + ) + margin = policy.get_margin_requirement("ES", 1, 5000.0, for_initial=True) + assert margin == 250.0 + + def test_reject_overlapping_fixed_and_percentage_margin(self): + """A symbol must not define both fixed and percentage margin models.""" + with pytest.raises(ValueError, match="cannot both define"): + UnifiedAccountPolicy( + allow_short_selling=True, + allow_leverage=True, + fixed_margin_schedule={"ES": (12_000.0, 6_000.0)}, + margin_pct_schedule={"ES": (0.05, 0.035)}, + ) + + class TestMarginAccountPolicyMarginCall: """Tests for margin call detection.""" diff --git a/tests/contracts/test_execution_contracts.py b/tests/contracts/test_execution_contracts.py index e6e4dc3a..4819131c 100644 --- a/tests/contracts/test_execution_contracts.py +++ b/tests/contracts/test_execution_contracts.py @@ -4,7 +4,6 @@ import polars as pl import pytest -from ml4t.specs.market_data import FeedSpec from ml4t.backtest.config import ( BacktestConfig, @@ -16,6 +15,7 @@ from ml4t.backtest.engine import run_backtest from ml4t.backtest.strategy import Strategy from ml4t.backtest.types import ExecutionMode +from ml4t.specs.market_data import FeedSpec def _prices() -> pl.DataFrame: diff --git a/tests/execution/test_rebalancer.py b/tests/execution/test_rebalancer.py index e58ef131..0359dda5 100644 --- a/tests/execution/test_rebalancer.py +++ b/tests/execution/test_rebalancer.py @@ -3,7 +3,6 @@ from datetime import datetime import pytest -from ml4t.specs.market_data import FeedSpec from ml4t.backtest import ( Broker, @@ -13,6 +12,7 @@ from ml4t.backtest.execution.rebalancer import RebalanceConfig, TargetWeightExecutor from ml4t.backtest.execution.schedule import RebalanceSchedule from ml4t.backtest.models import NoCommission, NoSlippage +from ml4t.specs.market_data import FeedSpec class TestRebalanceConfig: diff --git a/tests/execution/test_rebalancer_futures.py b/tests/execution/test_rebalancer_futures.py index f092e2a3..dacfc90d 100644 --- a/tests/execution/test_rebalancer_futures.py +++ b/tests/execution/test_rebalancer_futures.py @@ -7,6 +7,8 @@ from datetime import datetime +import pytest + from ml4t.backtest import Broker, OrderSide from ml4t.backtest.execution.rebalancer import RebalanceConfig, TargetWeightExecutor from ml4t.backtest.models import NoCommission, NoSlippage @@ -438,6 +440,64 @@ def test_margin_enables_leveraged_futures(self): # 2.0 * $100K / ($5000 * 50) = 0.8 contracts assert abs(orders[0].quantity - 0.8) < 0.01 + def test_margin_pct_from_contract_spec(self): + """ContractSpec.margin_pct should be wired into the broker's pct schedule.""" + es_spec = ContractSpec( + symbol="ES", + asset_class=AssetClass.FUTURE, + multiplier=50.0, + margin_pct=(0.05, 0.035), + ) + broker = Broker( + initial_cash=100_000, + commission_model=NoCommission(), + slippage_model=NoSlippage(), + contract_specs={"ES": es_spec}, + allow_leverage=True, + ) + + policy = broker.account.policy + assert policy.margin_pct_schedule is not None + assert policy.margin_pct_schedule["ES"] == (0.05, 0.035) + + def test_explicit_margin_pct_schedule_takes_precedence(self): + """Explicit margin_pct_schedule should override ContractSpec.margin_pct.""" + es_spec = ContractSpec( + symbol="ES", + asset_class=AssetClass.FUTURE, + multiplier=50.0, + margin_pct=(0.05, 0.035), + ) + broker = Broker( + initial_cash=100_000, + commission_model=NoCommission(), + slippage_model=NoSlippage(), + contract_specs={"ES": es_spec}, + margin_pct_schedule={"ES": (0.06, 0.04)}, + allow_leverage=True, + ) + + policy = broker.account.policy + assert policy.margin_pct_schedule["ES"] == (0.06, 0.04) + + def test_rejects_contract_spec_with_both_margin_models(self): + """A single asset must not activate both fixed and percentage margin paths.""" + es_spec = ContractSpec( + symbol="ES", + asset_class=AssetClass.FUTURE, + multiplier=50.0, + margin=15_000.0, + margin_pct=(0.05, 0.035), + ) + with pytest.raises(ValueError, match="cannot both define"): + Broker( + initial_cash=100_000, + commission_model=NoCommission(), + slippage_model=NoSlippage(), + contract_specs={"ES": es_spec}, + allow_leverage=True, + ) + class TestMaxGrossLeverage: """Test the max_gross_leverage safety guardrail.""" diff --git a/tests/execution/test_schedule.py b/tests/execution/test_schedule.py index 00bca0aa..9a578169 100644 --- a/tests/execution/test_schedule.py +++ b/tests/execution/test_schedule.py @@ -5,13 +5,13 @@ from datetime import UTC, datetime import polars as pl -from ml4t.specs.market_data import FeedSpec from ml4t.backtest.execution import ( RebalanceCadence, RebalanceSchedule, resolve_rebalance_timestamps, ) +from ml4t.specs.market_data import FeedSpec def _make_weekday_series(start: str, end: str) -> pl.Series: diff --git a/tests/test_artifact_spec.py b/tests/test_artifact_spec.py index a37f7d25..c6a3f093 100644 --- a/tests/test_artifact_spec.py +++ b/tests/test_artifact_spec.py @@ -2,14 +2,13 @@ from pathlib import Path -from ml4t.diagnostic.artifacts import dump_spec, load_market_data_spec, load_spec -from ml4t.engineer.artifacts import FeatureSpec, LabelSpec, PredictionSpec -from ml4t.specs import ArtifactKind, FeedSpec, MarketDataSpec, TimestampSemantics - from ml4t.backtest.spec_bridge import ( market_data_spec_to_feed_spec, market_data_spec_to_runtime_metadata, ) +from ml4t.diagnostic.artifacts import dump_spec, load_market_data_spec, load_spec +from ml4t.engineer.artifacts import FeatureSpec, LabelSpec, PredictionSpec +from ml4t.specs import ArtifactKind, FeedSpec, MarketDataSpec, TimestampSemantics def test_market_data_spec_from_mapping_normalizes_timestamp_semantics() -> None: diff --git a/tests/test_broker.py b/tests/test_broker.py index 3531246c..627efbf7 100644 --- a/tests/test_broker.py +++ b/tests/test_broker.py @@ -3,7 +3,6 @@ from datetime import datetime import pytest -from ml4t.specs.market_data import FeedSpec from ml4t.backtest.broker import Broker from ml4t.backtest.config import ShareType @@ -16,6 +15,7 @@ OrderType, Position, ) +from ml4t.specs.market_data import FeedSpec @pytest.fixture diff --git a/tests/test_config_wiring.py b/tests/test_config_wiring.py index bbee8745..fe94e508 100644 --- a/tests/test_config_wiring.py +++ b/tests/test_config_wiring.py @@ -12,7 +12,6 @@ from datetime import datetime import pytest -from ml4t.specs.market_data import FeedSpec from ml4t.backtest import ( BacktestConfig, @@ -42,6 +41,7 @@ VolumeShareSlippage, ) from ml4t.backtest.types import OrderSide, Position +from ml4t.specs.market_data import FeedSpec # --------------------------------------------------------------------------- # Helpers @@ -708,7 +708,15 @@ class TestPresetRoundTrip: @pytest.mark.parametrize( "preset_name", - ["default", "fast", "backtrader", "vectorbt", "zipline", "realistic", "ibkr_us_stocks_fixed"], + [ + "default", + "fast", + "backtrader", + "vectorbt", + "zipline", + "realistic", + "ibkr_us_stocks_fixed", + ], ) def test_preset_creates_valid_config(self, preset_name): config = BacktestConfig.from_preset(preset_name) @@ -946,6 +954,19 @@ def test_to_dict_from_dict_roundtrip(self): assert restored.immediate_fill is True assert restored.mark_price == ExecutionPrice.QUOTE_MID + def test_margin_pct_schedule_roundtrip(self): + config = BacktestConfig(margin_pct_schedule={"ES": (0.05, 0.035)}) + restored = BacktestConfig.from_dict(config.to_dict()) + assert restored.margin_pct_schedule == {"ES": (0.05, 0.035)} + + def test_validate_rejects_overlapping_margin_models(self): + config = BacktestConfig( + fixed_margin_schedule={"ES": (12_000.0, 6_000.0)}, + margin_pct_schedule={"ES": (0.05, 0.035)}, + ) + issues = config.validate(warn=False) + assert any("cannot both define" in issue for issue in issues) + class TestFromDictDefaultParity: """from_dict({}) must produce the same defaults as BacktestConfig().""" @@ -1012,7 +1033,6 @@ def test_constructor_canonicalizes_feed_spec_metadata(self): assert config.timezone == "America/New_York" assert config.data_frequency == DataFrequency.MINUTE_1 - def test_resolved_feed_spec_preserves_explicit_runtime_over_feed_metadata(self): config = BacktestConfig( timezone="UTC", @@ -1146,10 +1166,7 @@ def test_from_user_config_uses_default_assumptions_and_global_defaults(self, tmp config_dir = tmp_path / "ml4t" config_dir.mkdir() (config_dir / "defaults.yaml").write_text( - "cash:\n" - " initial: 250000\n" - "calendar:\n" - " timezone: America/New_York\n" + "cash:\n initial: 250000\ncalendar:\n timezone: America/New_York\n" ) (config_dir / "assumptions.yaml").write_text( "default_assumptions:\n" @@ -1172,10 +1189,7 @@ def test_from_user_config_broker_override_beats_global_defaults(self, tmp_path): config_dir = tmp_path / "ml4t" config_dir.mkdir() (config_dir / "defaults.yaml").write_text( - "commission:\n" - " model: none\n" - "slippage:\n" - " model: none\n" + "commission:\n model: none\nslippage:\n model: none\n" ) (config_dir / "assumptions.yaml").write_text( "default_assumptions:\n" @@ -1202,9 +1216,7 @@ def test_from_user_config_requires_complete_assumptions_tuple(self, tmp_path): config_dir = tmp_path / "ml4t" config_dir.mkdir() (config_dir / "assumptions.yaml").write_text( - "default_assumptions:\n" - " broker: ibkr\n" - " region: us\n" + "default_assumptions:\n broker: ibkr\n region: us\n" ) with pytest.raises(ValueError, match="broker, region, asset_class, and plan"): diff --git a/tests/test_core.py b/tests/test_core.py index 80fdac61..85cdabdb 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -4,7 +4,6 @@ import polars as pl import pytest -from ml4t.specs.market_data import FeedSpec from ml4t.backtest import ( Broker, @@ -24,6 +23,7 @@ SlippageType, ) from ml4t.backtest.models import PercentageCommission, VolumeShareSlippage +from ml4t.specs.market_data import FeedSpec # === Test Data Generators === diff --git a/tests/test_datafeed_memory.py b/tests/test_datafeed_memory.py index 4fcd7086..48273b94 100644 --- a/tests/test_datafeed_memory.py +++ b/tests/test_datafeed_memory.py @@ -9,10 +9,10 @@ import polars as pl import pytest -from ml4t.specs.market_data import FeedSpec from ml4t.backtest import BacktestConfig, DataFeed from ml4t.backtest.config import DataFrequency +from ml4t.specs.market_data import FeedSpec class TestDataFeedMemoryEfficiency: diff --git a/tests/test_equity_curve.py b/tests/test_equity_curve.py index 72bae584..65223c8b 100644 --- a/tests/test_equity_curve.py +++ b/tests/test_equity_curve.py @@ -3,11 +3,11 @@ from datetime import datetime, timedelta import polars as pl -from ml4t.specs.market_data import FeedSpec from ml4t.backtest import BacktestConfig, DataFeed, Engine, Strategy from ml4t.backtest.analytics.equity import EquityCurve from ml4t.backtest.config import DataFrequency +from ml4t.specs.market_data import FeedSpec class TestEquityCurveAnnualization: diff --git a/tests/test_result.py b/tests/test_result.py index d1347538..ece09e32 100644 --- a/tests/test_result.py +++ b/tests/test_result.py @@ -10,7 +10,6 @@ import polars as pl import pytest -from ml4t.specs.market_data import FeedSpec from ml4t.backtest.config import BacktestConfig from ml4t.backtest.result import ( @@ -18,6 +17,7 @@ enrich_trades_with_signals, ) from ml4t.backtest.types import Fill, OrderSide, Trade +from ml4t.specs.market_data import FeedSpec @pytest.fixture diff --git a/tests/test_strategy_templates.py b/tests/test_strategy_templates.py index 966cc7b2..fdf7b27a 100644 --- a/tests/test_strategy_templates.py +++ b/tests/test_strategy_templates.py @@ -4,7 +4,6 @@ import numpy as np import polars as pl -from ml4t.specs.market_data import FeedSpec from ml4t.backtest import BacktestConfig, DataFeed, Engine from ml4t.backtest.execution.schedule import RebalanceSchedule @@ -14,6 +13,7 @@ MomentumStrategy, SignalFollowingStrategy, ) +from ml4t.specs.market_data import FeedSpec def make_price_data(