Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@

All notable changes to `ml4t-models` will be documented in this file.

## 0.1.0a4

- Aligned conditional autoencoder training with shuffled mini-batches, BatchNorm hidden layers,
validation-best checkpoints, and per-member extraction.
- Aligned stochastic discount factor checkpoints with phase-local training, validation-best
tracking, and legacy checkpoint-label compatibility.
- Kept IPCA factor extraction consistent with the final normalized loading matrix.

## 0.1.0a1

- Declared Python 3.14 support in package metadata.
Expand Down
2 changes: 1 addition & 1 deletion src/ml4t/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from importlib import import_module

__version__ = "0.1.0a3"
__version__ = "0.1.0a4"

from ml4t.models.api import (
AssetMapper,
Expand Down
1 change: 1 addition & 0 deletions src/ml4t/models/_internal/cae_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(
in_features = n_characteristics
for units in hidden_units:
layers.append(nn.Linear(in_features, units))
layers.append(nn.BatchNorm1d(units))
layers.append(nn.ReLU())
in_features = units
layers.append(nn.Linear(in_features, n_factors))
Expand Down
3 changes: 2 additions & 1 deletion src/ml4t/models/configs/latent_factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class CAEConfig(LatentFactorConfig):
default_checkpoint: int | None = None
lr: float = 1e-3
lambda_l1: float = 1e-4
batch_size: int = 10_000


@dataclass(frozen=True, slots=True)
Expand All @@ -80,7 +81,7 @@ class StochasticDiscountFactorConfig(BaseModelConfig):
n_epochs_cond: int = 1024
checkpoint_interval: int | None = None
checkpoint_epochs: tuple[int, ...] = ()
default_checkpoint: int | None = None
default_checkpoint: int | tuple[str, int] | None = None
expected_return_mapper: str = "linear"
beta_state_dim: int = 4
beta_hidden_dim: int = 64
Expand Down
257 changes: 241 additions & 16 deletions src/ml4t/models/latent_factors/cae.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,33 @@ def __init__(self, config: CAEConfig) -> None:
self._n_characteristics: int | None = None
self._n_instruments: int | None = None
self._history: tuple[dict[str, float | str], ...] = ()
self._fit_default_checkpoint: int | None = None

@property
def available_checkpoints(self) -> tuple[int, ...]:
    """Return the keys of the stored checkpoint states in ascending order."""
    ordered_keys = sorted(self._checkpoint_states)
    return tuple(ordered_keys)

def fit(self, batch: PanelBatch) -> FitSummary:
def fit(
self,
batch: PanelBatch,
*,
validation_batch: PanelBatch | None = None,
patience: int = 50,
) -> FitSummary:
cross_section = _require_cross_section(batch)
if cross_section.returns is None:
raise ValueError("CAE requires returns in the training batch")
validation_cross_section = (
_require_cross_section(validation_batch) if validation_batch is not None else None
)
if validation_cross_section is not None and validation_cross_section.returns is None:
raise ValueError("validation_batch requires returns for CAE validation loss")
if self.config.n_factors < 1:
raise ValueError(f"n_factors must be positive; got {self.config.n_factors}")
if self.config.n_ensemble < 1:
raise ValueError(f"n_ensemble must be positive; got {self.config.n_ensemble}")
if self.config.batch_size < 1:
raise ValueError(f"batch_size must be positive; got {self.config.batch_size}")
if self.config.task_type == "classification" and cross_section.factor_returns is None:
raise ValueError(
"Classification CAE requires factor_returns for managed-portfolio construction"
Expand Down Expand Up @@ -72,9 +86,34 @@ def fit(self, batch: PanelBatch) -> FitSummary:
device=device,
)
mask_train = _resolve_mask(cross_section)
flat_chars, flat_portfolios, flat_returns = _flatten_training_panel(
torch=torch,
characteristics=chars_train,
managed_portfolios=portfolios_train,
returns=returns_train,
mask=mask_train,
device=device,
)
if int(flat_returns.shape[0]) == 0:
raise ValueError("CAE received no valid training observations")

validation_tensors = None
if validation_cross_section is not None:
validation_portfolio_returns = _portfolio_returns(validation_cross_section)
assert validation_portfolio_returns is not None
validation_tensors = _prepare_validation_tensors(
torch=torch,
cross_section=validation_cross_section,
managed_portfolios=_compute_managed_portfolios(
characteristics=validation_cross_section.characteristics,
returns=validation_portfolio_returns,
),
device=device,
)

self._checkpoint_states = defaultdict(list)
loss_sums = dict.fromkeys(checkpoint_epochs, 0.0)
loss_sums: dict[int, float] = dict.fromkeys(checkpoint_epochs, 0.0)
val_best_losses: list[float] = []

for ensemble_idx in range(self.config.n_ensemble):
seed = self.config.seed + ensemble_idx
Expand All @@ -89,22 +128,22 @@ def fit(self, batch: PanelBatch) -> FitSummary:
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=self.config.lr)

best_val_loss = float("inf")
best_state: dict[str, Any] | None = None
patience_counter = 0
for epoch in range(1, self.config.n_epochs + 1):
model.train()
epoch_loss = 0.0
n_batches = 0
order = torch.randperm(flat_returns.shape[0], device=device)

for date_idx in range(cross_section.n_periods):
valid = (
mask_train[date_idx] & torch.isfinite(returns_train[date_idx]).cpu().numpy()
)
if not valid.any():
for start in range(0, int(flat_returns.shape[0]), self.config.batch_size):
batch_idx = order[start : start + self.config.batch_size]
if batch_idx.numel() == 1 and self.config.hidden_units:
continue

features_t = chars_train[date_idx, valid]
target_t = returns_train[date_idx, valid]
portfolios_t = portfolios_train[date_idx, valid]
scores_t = model(features_t, portfolios_t)
scores_t = model(flat_chars[batch_idx], flat_portfolios[batch_idx])
target_t = flat_returns[batch_idx]

if self.config.task_type == "classification":
main_loss = torch.nn.functional.binary_cross_entropy_with_logits(
Expand All @@ -127,28 +166,63 @@ def fit(self, batch: PanelBatch) -> FitSummary:
self._checkpoint_states[epoch].append(_cpu_state_dict(model))
loss_sums[epoch] += mean_loss

if validation_tensors is not None:
val_loss = _validation_loss(
torch=torch,
model=model,
validation_tensors=validation_tensors,
task_type=self.config.task_type,
)
if val_loss < best_val_loss:
best_val_loss = val_loss
best_state = _cpu_state_dict(model)
patience_counter = 0
else:
patience_counter += 1
if patience_counter >= patience:
break

if best_state is not None:
self._checkpoint_states[0].append(best_state)
val_best_losses.append(best_val_loss)

history: list[dict[str, float | str]] = []
for epoch in checkpoint_epochs:
if epoch in self._checkpoint_states:
history.append(
{
"epoch": float(epoch),
"train_loss": loss_sums[epoch] / self.config.n_ensemble,
}
)
if val_best_losses:
history.append(
{
"epoch": float(epoch),
"train_loss": loss_sums[epoch] / self.config.n_ensemble,
"epoch": 0.0,
"checkpoint": "validation_best",
"val_loss": float(np.mean(val_best_losses)),
}
)
self._history = tuple(history)
self._asset_ids = cross_section.asset_ids
self._n_characteristics = cross_section.characteristics.shape[2]
self._n_instruments = managed_portfolios.shape[2]
self._fit_default_checkpoint = (
0 if val_best_losses else _default_checkpoint(self.config, self.available_checkpoints)
)
self._mark_fitted()

return FitSummary(
converged=True,
train_metrics={
"n_train_periods": float(cross_section.n_periods),
"n_checkpoints": float(len(checkpoint_epochs)),
"n_checkpoints": float(len(self.available_checkpoints)),
"n_ensemble": float(self.config.n_ensemble),
},
best_epoch=_default_checkpoint(self.config, checkpoint_epochs),
val_metrics=(
{"best_val_loss": float(np.mean(val_best_losses))} if val_best_losses else {}
),
best_epoch=self._fit_default_checkpoint,
history=self._history,
notes=("Neural betas and linear factors stored at configurable checkpoints.",),
)
Expand Down Expand Up @@ -177,7 +251,7 @@ def extract(
nn = _import_cae_nn()
selected_checkpoint = _select_checkpoint(
checkpoint=checkpoint,
configured_default=self.config.default_checkpoint,
configured_default=self._fit_default_checkpoint,
available=self.available_checkpoints,
)
device = resolve_device(torch, self.config.device)
Expand Down Expand Up @@ -237,6 +311,80 @@ def extract(
},
)

def extract_per_member(
    self,
    batch: PanelBatch,
    *,
    checkpoint: int | None = None,
) -> list[LatentFactorState]:
    """Extract one latent-factor state per ensemble member stored at a checkpoint.

    Each state dict saved under the selected checkpoint is loaded into a
    fresh ``ConditionalAutoencoder`` and evaluated on ``batch``, so the
    result has one entry per stored member.

    Args:
        batch: Panel batch to extract from; it is passed through
            ``_require_cross_section``.
        checkpoint: Checkpoint to use. It is resolved by
            ``_select_checkpoint`` together with the default recorded at
            fit time and the available checkpoints.

    Returns:
        One ``LatentFactorState`` per stored member, in storage order, each
        tagged with ``ensemble_member`` and ``n_ensemble`` in its metadata.

    Raises:
        RuntimeError: If the model has not been fitted or holds no
            checkpoint states.
        ValueError: If the third dimension of ``characteristics`` differs
            from the fitted number of characteristics.
    """
    cross_section = _require_cross_section(batch)
    if not (
        self.is_fitted
        and self._n_characteristics is not None
        and self._n_instruments is not None
        and self._checkpoint_states
    ):
        raise RuntimeError("CAE model must be fitted before extract_per_member()")
    if cross_section.characteristics.shape[2] != self._n_characteristics:
        raise ValueError(
            "characteristics feature dimension does not match fitted CAE model; "
            f"expected {self._n_characteristics}, got {cross_section.characteristics.shape[2]}"
        )

    torch = import_torch()
    nn = _import_cae_nn()
    selected_checkpoint = _select_checkpoint(
        checkpoint=checkpoint,
        configured_default=self._fit_default_checkpoint,
        available=self.available_checkpoints,
    )
    device = resolve_device(torch, self.config.device)
    mask = _resolve_mask(cross_section)

    # Managed portfolios are only built when portfolio returns are available.
    portfolio_returns = _portfolio_returns(cross_section)
    if portfolio_returns is None:
        managed_portfolios = None
    else:
        managed_portfolios = _compute_managed_portfolios(
            characteristics=cross_section.characteristics,
            returns=portfolio_returns,
        )

    member_state_dicts = self._checkpoint_states[selected_checkpoint]
    n_members = len(member_state_dicts)

    states: list[LatentFactorState] = []
    for member_idx, state_dict in enumerate(member_state_dicts):
        # Rebuild the network for this member and load its saved weights.
        model = nn.ConditionalAutoencoder(
            n_characteristics=self._n_characteristics,
            n_instruments=self._n_instruments,
            n_factors=self.config.n_factors,
            hidden_units=self.config.hidden_units,
        ).to(device)
        model.load_state_dict(deepcopy(state_dict))
        model.eval()

        asset_betas, factor_returns = _extract_cae_state(
            torch=torch,
            model=model,
            characteristics=cross_section.characteristics,
            managed_portfolios=managed_portfolios,
            mask=mask,
            n_factors=self.config.n_factors,
            device=device,
        )
        member_metadata = {
            "model_name": self.config.model_name,
            "persistent_entities": False,
            "task_type": self.config.task_type,
            "ensemble_member": member_idx,
            "n_ensemble": n_members,
        }
        member_state = LatentFactorState(
            asset_betas=asset_betas,
            factor_returns=factor_returns,
            checkpoint_epoch=selected_checkpoint,
            timestamps=cross_section.timestamps,
            asset_ids=cross_section.asset_ids or self._asset_ids,
            metadata=member_metadata,
        )
        states.append(member_state)
    return states


def _require_cross_section(batch: PanelBatch) -> CrossSectionBatch:
if not isinstance(batch, CrossSectionBatch):
Expand Down Expand Up @@ -306,6 +454,83 @@ def _compute_managed_portfolios(
)


def _flatten_training_panel(
    *,
    torch: Any,
    characteristics: Any,
    managed_portfolios: Any,
    returns: Any,
    mask: np.ndarray,
    device: Any,
) -> tuple[Any, Any, Any]:
    """Keep only the observations that are masked in and have a finite return.

    Args:
        torch: The imported ``torch`` module.
        characteristics: Tensor indexed by the boolean selection.
        managed_portfolios: Tensor indexed by the same selection.
        returns: Tensor of returns; non-finite entries are dropped.
        mask: Boolean array marking usable observations.
        device: Device on which the boolean mask tensor is created.

    Returns:
        ``(characteristics, managed_portfolios, returns)`` restricted to the
        selected observations.
    """
    keep = torch.as_tensor(mask, dtype=torch.bool, device=device)
    keep = keep & torch.isfinite(returns)
    selected_characteristics = characteristics[keep]
    selected_portfolios = managed_portfolios[keep]
    selected_returns = returns[keep]
    return selected_characteristics, selected_portfolios, selected_returns


def _prepare_validation_tensors(
    *,
    torch: Any,
    cross_section: CrossSectionBatch,
    managed_portfolios: np.ndarray,
    device: Any,
) -> tuple[Any, Any, Any, Any]:
    """Convert a validation cross-section into tensors on ``device``.

    Args:
        torch: The imported ``torch`` module.
        cross_section: Validation batch; its ``returns`` must not be ``None``.
        managed_portfolios: Managed-portfolio array for the validation batch.
        device: Target device for all returned tensors.

    Returns:
        ``(characteristics, managed_portfolios, returns, mask)`` where the
        first three are float32 tensors and ``mask`` is a bool tensor.

    Raises:
        ValueError: If no observation is both masked in and has a finite
            return.
    """
    assert cross_section.returns is not None
    returns_np = np.asarray(cross_section.returns, dtype=np.float32)
    mask_np = _resolve_mask(cross_section)
    usable = mask_np & np.isfinite(returns_np)
    if not np.any(usable):
        raise ValueError("validation_batch contains no valid CAE validation observations")

    def _to_float_tensor(values: Any) -> Any:
        # Go through a float32 NumPy array first, then hand it to torch.
        as_float32 = np.asarray(values, dtype=np.float32)
        return torch.as_tensor(as_float32, dtype=torch.float32, device=device)

    characteristics_t = _to_float_tensor(cross_section.characteristics)
    portfolios_t = _to_float_tensor(managed_portfolios)
    returns_t = _to_float_tensor(returns_np)
    mask_t = torch.as_tensor(mask_np, dtype=torch.bool, device=device)
    return characteristics_t, portfolios_t, returns_t, mask_t


def _validation_loss(
    *,
    torch: Any,
    model: Any,
    validation_tensors: tuple[Any, Any, Any, Any],
    task_type: str,
) -> float:
    """Compute the mean per-date validation loss with the model in eval mode.

    Each date along the first axis of the validation tensors is scored on
    its own, using only observations that are masked in and have a finite
    return. The per-date losses are then averaged with equal weight.

    Args:
        torch: The imported ``torch`` module.
        model: Callable module taking ``(characteristics, managed_portfolios)``
            and exposing ``training``, ``eval()`` and ``train()``.
        validation_tensors: ``(characteristics, managed_portfolios, returns,
            mask)`` tensors.
        task_type: ``"classification"`` selects binary cross-entropy on
            logits with targets clamped to ``[0, 1]``; any other value
            selects mean squared error.

    Returns:
        The mean of the per-date losses, or ``inf`` when no date has a valid
        observation.
    """
    characteristics, managed_portfolios, returns, mask = validation_tensors
    was_training = model.training
    model.eval()
    losses: list[float] = []
    try:
        with torch.no_grad():
            for date_idx in range(characteristics.shape[0]):
                valid = mask[date_idx] & torch.isfinite(returns[date_idx])
                if not bool(valid.any()):
                    continue
                scores_t = model(
                    characteristics[date_idx, valid],
                    managed_portfolios[date_idx, valid],
                )
                target_t = returns[date_idx, valid]
                if task_type == "classification":
                    loss = torch.nn.functional.binary_cross_entropy_with_logits(
                        scores_t,
                        target_t.clamp(0.0, 1.0),
                    )
                else:
                    loss = torch.nn.functional.mse_loss(scores_t, target_t)
                losses.append(float(loss.item()))
    finally:
        # Restore the caller's mode even if the forward pass or loss raised,
        # so a failed validation pass never leaves the model stuck in eval.
        if was_training:
            model.train()
    return float(np.mean(losses)) if losses else float("inf")


def _cpu_state_dict(model: Any) -> dict[str, Any]:
    """Snapshot ``model.state_dict()`` as detached, cloned CPU tensors.

    Every entry is detached, moved to the CPU and cloned, so the snapshot
    does not share storage with the live model.
    """
    snapshot: dict[str, Any] = {}
    for name, tensor in model.state_dict().items():
        snapshot[name] = tensor.detach().cpu().clone()
    return snapshot

Expand Down
Loading
Loading