From 6a784c49f99f78671f5074bda84bff1945b6a418 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Tue, 6 Jan 2026 11:42:13 -0800
Subject: [PATCH 01/18] initial commit

---
 .../airtable/example_config_with_airtable.yml |  66 +++
 .../airtable/filter_n_create_dataset_tag.py   | 143 +++++
 examples/airtable/sklearn_wrapper.py          | 174 ++++++
 examples/airtable/test_airtable_connection.py | 114 ++++
 pyproject.toml                                |   8 +-
 viscy/representation/airtable_callback.py     | 120 +++++
 viscy/representation/airtable_fov_registry.py | 494 ++++++++++++++++++
 7 files changed, 1116 insertions(+), 3 deletions(-)
 create mode 100644 examples/airtable/example_config_with_airtable.yml
 create mode 100644 examples/airtable/filter_n_create_dataset_tag.py
 create mode 100644 examples/airtable/sklearn_wrapper.py
 create mode 100644 examples/airtable/test_airtable_connection.py
 create mode 100644 viscy/representation/airtable_callback.py
 create mode 100644 viscy/representation/airtable_fov_registry.py

diff --git a/examples/airtable/example_config_with_airtable.yml b/examples/airtable/example_config_with_airtable.yml
new file mode 100644
index 000000000..60dbe558c
--- /dev/null
+++ b/examples/airtable/example_config_with_airtable.yml
@@ -0,0 +1,66 @@
+# Example config showing Airtable integration with Lightning training
+# Usage: viscy fit -c examples/airtable/example_config_with_airtable.yml
+
+seed_everything: true
+
+trainer:
+  accelerator: gpu
+  devices: 1
+  max_epochs: 100
+  check_val_every_n_epoch: 1
+  log_every_n_steps: 50
+
+  # Add Airtable logging callback
+  callbacks:
+    - class_path: viscy.representation.airtable_callback.AirtableLoggingCallback
+      init_args:
+        base_id: "appXXXXXXXXXXXXXX"  # Replace with your Airtable base ID
+        dataset_id: "recYYYYYYYYYYYYYY"  # Replace with your dataset record ID
+        model_name: null  # Auto-generate from model class and timestamp
+        log_metrics: false  # Set to true to store metrics in Airtable (otherwise use TensorBoard)
+
+model:
+  class_path: viscy.representation.contrastive.ContrastiveModule
+  init_args:
+    # Your model config here
+    backbone: resnet50
+    embedding_len: 256
+
+data:
+  class_path: viscy.data.triplet.TripletDataModule
+  init_args:
+    data_path: /hpc/data/your_plate.zarr
+    tracks_path: /hpc/tracks/your_tracks/
+    source_channel: [Phase]
+    z_range: [0, 5]
+    initial_yx_patch_size: [512, 512]
+    final_yx_patch_size: [224, 224]
+    split_ratio: 0.8
+    batch_size: 16
+    num_workers: 8
+
+    # FOV selection from Airtable dataset definition
+    fit_include_wells: ["B3", "B4", "C3"]
+    fit_exclude_fovs: []
+
+    # Data augmentation
+    augmentations:
+      - class_path: viscy.transforms.RandAffined
+        init_args:
+          keys: [Phase]
+          prob: 0.8
+          rotate_range: [3.14, 0.0, 0.0]
+          scale_range: [0.1, 0.1, 0.1]
+      - class_path: viscy.transforms.RandGaussianNoised
+        init_args:
+          keys: [Phase]
+          prob: 0.5
+          mean: 0.0
+          std: 0.1
+
+# Dataset metadata (for reference, not used by training)
+dataset_metadata:
+  airtable_id: "recYYYYYYYYYYYYYY"
+  name: "RPE1_infection_v2"
+  version: "v2"
+  description: "RPE1 cells, infection experiment, wells B3-C3"

diff --git a/examples/airtable/filter_n_create_dataset_tag.py b/examples/airtable/filter_n_create_dataset_tag.py
new file mode 100644
index 000000000..d44ec7e01
--- /dev/null
+++ b/examples/airtable/filter_n_create_dataset_tag.py
@@ -0,0 +1,143 @@
+"""Filter FOVs using pandas and create dataset tags."""
+
+# %%
+import os
+
+from viscy.representation.airtable_fov_registry import AirtableFOVRegistry
+
+BASE_ID = os.getenv("AIRTABLE_BASE_ID")
+registry = AirtableFOVRegistry(base_id=BASE_ID)
+
+# %%
+# EXAMPLE 1: Get all FOVs as DataFrame and explore
+print("=" * 70)
+print("Getting all FOVs as DataFrame")
+print("=" * 70)
+
+df_fovs = registry.list_fovs()
+print(f"\nTotal FOVs: {len(df_fovs)}")
+print("\nDataFrame columns:")
+print(df_fovs.columns.tolist())
+print("\nFirst few rows:")
+print(df_fovs.head())
+
+# %%
+# EXAMPLE 2: Filter by plate and rows B and C using pandas
+print("\n" + "=" * 70)
+print("Filter: Plate RPE1_plate1, Rows B and C, Good quality")
+print("=" * 70)
+
+# Get all FOVs as DataFrame
+df = registry.list_fovs()
+
+# Filter with pandas - simple and powerful!
+filtered = df[
+    (df["plate_name"] == "RPE1_plate1")
+    & (df["quality"] == "Good")
+    & (df["row"].isin(["B", "C"]))
+]
+
+print(f"\nTotal FOVs after filtering: {len(filtered)}")
+print("\nBreakdown by well:")
+print(filtered.groupby("well_id").size())
+
+# Create dataset from filtered FOVs
+fov_ids = filtered["fov_id"].tolist()
+
+try:
+    dataset_id = registry.create_dataset_from_fovs(
+        dataset_name="RPE1_rows_BC_good",
+        fov_ids=fov_ids,
+        version="0.0.1",  # Semantic versioning
+        purpose="training",
+        description="Good quality FOVs from rows B and C",
+    )
+    print(f"\n✓ Created dataset: {dataset_id}")
+    print(f"  Contains {len(fov_ids)} FOVs")
+except ValueError as e:
+    print(f"\n⚠ {e}")
+
+# %%
+# EXAMPLE 3: Group by plate and show summary
+print("\n" + "=" * 70)
+print("Group by plate and show summary")
+print("=" * 70)
+
+df_all = registry.list_fovs()
+
+# Filter for good quality only
+df_all = df_all[df_all["quality"] == "Good"]
+
+grouped = df_all.groupby("plate_name")
+
+for plate_name, group in grouped:
+    print(f"\n{plate_name}:")
+    print(f"  Total FOVs: {len(group)}")
+    print(f"  Wells: {group['well_id'].unique()}")
+    print(f"  Rows: {group['row'].unique()}")
+
+# %%
+# EXAMPLE 4: Complex filtering - specific rows and columns
+print("\n" + "=" * 70)
+print("Complex Filter: Rows B/C AND Columns 3/4")
+print("=" * 70)
+
+df = registry.list_fovs()
+
+# Complex pandas filter: plate, quality, rows B or C, AND columns 3 or 4
+filtered = df[
+    (df["plate_name"] == "RPE1_plate1")
+    & (df["quality"] == "Good")
+    & (df["row"].isin(["B", "C"]))
+    & (df["column"].isin(["3", "4"]))
+]
+
+print(f"\nFOVs matching criteria: {len(filtered)}")
+print("\nBy well:")
+print(filtered.groupby(["row", "column"]).size())
+
+print("\nFOV IDs:")
+for fov_id in filtered["fov_id"]:
+    print(f"  {fov_id}")
+
+# %%
+# EXAMPLE 5: Exclude specific FOVs
+print("\n" + "=" * 70)
+print("Exclude specific FOVs from dataset")
+print("=" * 70)
+
+df = registry.list_fovs()
+
+# Start with good quality FOVs from specific plate
+filtered = df[(df["plate_name"] == "RPE1_plate1") & (df["quality"] == "Good")]
+
+print(f"\nBefore exclusion: {len(filtered)} FOVs")
+
+# List of FOVs to exclude (e.g., known contamination)
+exclude_list = ["RPE1_plate1_B_3_2", "RPE1_plate1_C_4_1"]
+
+# Filter out excluded FOVs
+filtered = filtered[~filtered["fov_id"].isin(exclude_list)]
+
+print(f"Excluded: {len(exclude_list)} FOVs")
+print(f"After exclusion: {len(filtered)} FOVs")
+
+# %%
+# EXAMPLE 6: Summary statistics
+print("\n" + "=" * 70)
+print("Summary Statistics")
+print("=" * 70)
+
+df = registry.list_fovs()
+
+print("\nFOVs per plate:")
+print(df.groupby("plate_name").size())
+
+print("\nFOVs per quality:")
+print(df.groupby("quality").size())
+
+print("\nFOVs per row (across all plates):")
+print(df.groupby("row").size().sort_index())
+
+print("\nWells with most FOVs:")
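+# Top 10 wells by descending FOV count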
+print(df.groupby("well_id").size().sort_values(ascending=False).head(10)) diff --git a/examples/airtable/sklearn_wrapper.py b/examples/airtable/sklearn_wrapper.py new file mode 100644 index 000000000..5e4ba4993 --- /dev/null +++ b/examples/airtable/sklearn_wrapper.py @@ -0,0 +1,174 @@ +import matplotlib.pyplot as plt +import torch +from lightning.pytorch import LightningModule +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import ( + ConfusionMatrixDisplay, + accuracy_score, + classification_report, + confusion_matrix, + f1_score, +) + + +class SklearnLogisticRegressionModule(LightningModule): + """ + Wrap sklearn LogisticRegression in Lightning for experiment tracking. + + This module collects features/labels during training_step, then trains + the sklearn model at the end of each epoch. This pattern allows us to + use Lightning's logging infrastructure while using sklearn's optimized + solvers. + + Parameters + ---------- + input_dim : int + Feature dimension + lr : float + Inverse regularization strength (C parameter for LogisticRegression) + solver : str + Solver algorithm + max_iter : int + Maximum iterations for solver + class_weight : str | None + Class weighting strategy + """ + + def __init__( + self, + input_dim: int = 768, + lr: float = 1.0, + solver: str = "lbfgs", + max_iter: int = 1000, + class_weight: str | None = "balanced", + ): + super().__init__() + self.save_hyperparameters() + + # Sklearn model (trained incrementally per epoch) + self.model = LogisticRegression( + C=lr, + solver=solver, + max_iter=max_iter, + class_weight=class_weight, + random_state=42, + ) + + # Storage for batch features/labels + self.train_features = [] + self.train_labels = [] + self.val_features = [] + self.val_labels = [] + + # Store example input for Lightning compatibility + self.example_input_array = torch.rand(2, input_dim) + + def forward(self, x): + """Sklearn doesn't have gradients, so forward just returns input.""" + return x + + def training_step(self, batch, batch_idx): + """Collect features and labels for end-of-epoch training.""" + features, labels = batch + self.train_features.append(features.cpu()) + self.train_labels.append(labels.cpu()) + return None # No loss to backprop + + def on_train_epoch_end(self): + """Train sklearn model on collected features.""" + X_train = torch.cat(self.train_features, dim=0).numpy() + y_train = torch.cat(self.train_labels, dim=0).numpy() + + # Train sklearn model + self.model.fit(X_train, y_train) + + # Compute training metrics + y_pred = self.model.predict(X_train) + train_acc = accuracy_score(y_train, y_pred) + train_f1 = f1_score(y_train, y_pred, average="binary") + + # Log metrics (viscy convention: metric/category/stage) + self.log_dict( + { + "metric/accuracy/train": train_acc, + "metric/f1_score/train": train_f1, + }, + on_step=False, + on_epoch=True, + ) + + print(f"\n Training - Accuracy: {train_acc:.3f}, F1: {train_f1:.3f}") + + # Clear storage + self.train_features.clear() + self.train_labels.clear() + + def validation_step(self, batch, batch_idx): + """Collect validation features and labels.""" + features, labels = batch + self.val_features.append(features.cpu()) + self.val_labels.append(labels.cpu()) + return None + + def on_validation_epoch_end(self): + """Evaluate sklearn model on validation set.""" + # Skip if model not fitted yet (happens during sanity check) + if not hasattr(self.model, "classes_"): + self.val_features.clear() + self.val_labels.clear() + return + + X_val = torch.cat(self.val_features, 
dim=0).numpy() + y_val = torch.cat(self.val_labels, dim=0).numpy() + + # Predict + y_pred = self.model.predict(X_val) + + # Compute metrics + val_acc = accuracy_score(y_val, y_pred) + val_f1 = f1_score(y_val, y_pred, average="binary") + + # Log metrics + self.log_dict( + { + "metric/accuracy/val": val_acc, + "metric/f1_score/val": val_f1, + }, + on_step=False, + on_epoch=True, + ) + + print(f" Validation - Accuracy: {val_acc:.3f}, F1: {val_f1:.3f}") + + # Log confusion matrix + cm = confusion_matrix(y_val, y_pred) + if hasattr(self.logger, "experiment"): + fig, ax = plt.subplots(figsize=(6, 6)) + ConfusionMatrixDisplay(cm, display_labels=["Uninfected", "Infected"]).plot( + ax=ax, cmap="Blues" + ) + ax.set_title(f"Confusion Matrix - Epoch {self.current_epoch}") + + # Save and log figure + fig_path = f"/tmp/confusion_matrix_epoch_{self.current_epoch}.png" + fig.savefig(fig_path, dpi=100, bbox_inches="tight") + self.logger.experiment.log_artifact( + self.logger.run_id, fig_path, artifact_path="plots" + ) + plt.close(fig) + + # Print classification report + print( + "\n" + + classification_report( + y_val, y_pred, target_names=["Uninfected", "Infected"] + ) + ) + + # Clear storage + self.val_features.clear() + self.val_labels.clear() + + def configure_optimizers(self): + """No optimizer needed for sklearn.""" + return None diff --git a/examples/airtable/test_airtable_connection.py b/examples/airtable/test_airtable_connection.py new file mode 100644 index 000000000..c3ee19761 --- /dev/null +++ b/examples/airtable/test_airtable_connection.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 +"""Test Airtable connection and setup.""" + +# %% +import os + +from pyairtable import Api + +print("=" * 70) +print("Testing Airtable Connection") +print("=" * 70) + +# Check environment variables +# TODO: Add these ENVIRONMENT VARIABLES TO BASHRC or export them in the node +api_key = os.getenv("AIRTABLE_API_KEY") +base_id = os.getenv("AIRTABLE_BASE_ID") + +print("\n1. Environment Variables:") +print( + f" AIRTABLE_API_KEY: {'✓ Set' if api_key else '✗ Not set'} ({api_key[:10] if api_key else 'N/A'}...)" +) +print( + f" AIRTABLE_BASE_ID: {'✓ Set' if base_id else '✗ Not set'} ({base_id if base_id else 'N/A'})" +) + +if not api_key or not base_id: + print("\n❌ ERROR: Environment variables not set!") + print("\nRun these commands in your shell:") + print(' export AIRTABLE_API_KEY="patXXXXXXXXXXXXXX"') + print(' export AIRTABLE_BASE_ID="appXXXXXXXXXXXXXX"') + print("\nOr add them to your ~/.bashrc") + exit(1) + +# Test API connection +print("\n2. Testing API Connection...") +try: + api = Api(api_key) + print(" ✓ API initialized") +except Exception as e: + print(f" ✗ Failed to initialize API: {e}") + exit(1) + +# Test Datasets table +print("\n3. Testing Datasets Table...") +try: + datasets_table = api.table(base_id, "Datasets") + records = datasets_table.all() + print(" ✓ Connected to Datasets table") + print(f" ✓ Found {len(records)} record(s)") + + if records: + print("\n Existing datasets:") + for record in records: + fields = record["fields"] + name = fields.get("name", "N/A") + version = fields.get("version", "N/A") + print(f" - {name} (v{version})") + else: + print(" (Table is empty - this is OK for first run)") + +except Exception as e: + print(f" ✗ Failed to access Datasets table: {e}") + print("\n Make sure you created a table named 'Datasets' (case-sensitive)") + exit(1) + +# Test Models table +print("\n4. 
Testing Models Table...") +try: + models_table = api.table(base_id, "Models") + records = models_table.all() + print(" ✓ Connected to Models table") + print(f" ✓ Found {len(records)} record(s)") + + if records: + print("\n Existing models:") + for record in records: + fields = record["fields"] + name = fields.get("model_name", "N/A") + acc = fields.get("test_accuracy", "N/A") + print(f" - {name} (accuracy: {acc})") + else: + print(" (Table is empty - this is OK for first run)") + +except Exception as e: + print(f" ✗ Failed to access Models table: {e}") + print("\n Make sure you created a table named 'Models' (case-sensitive)") + exit(1) + +# Test creating a dummy record +print("\n5. Testing Write Permissions...") +try: + test_record = datasets_table.create( + { + "name": "connection_test", + "version": "v0", + "hpc_path": "/tmp/test", + "sha256": "test_hash", + "created_date": "2024-12-19T00:00:00", + } + ) + print(f" ✓ Successfully created test record (ID: {test_record['id']})") + + # Clean up + datasets_table.delete(test_record["id"]) + print(" ✓ Successfully deleted test record") + +except Exception as e: + print(f" ✗ Failed to write to Datasets table: {e}") + print("\n Check your API token has 'data.records:write' scope") + exit(1) + +print("\n" + "=" * 70) +print("✅ SUCCESS: Airtable is configured correctly!") +# %% diff --git a/pyproject.toml b/pyproject.toml index 4dd3c254e..2e43579e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,9 @@ optional-dependencies.metrics = [ "cellpose>=3.0.10", "imbalanced-learn>=0.12", "mahotas", + "mlflow", "ptflops>=0.7", + "pyairtable", "scikit-learn>=1.1.3", "torchmetrics[detection]>=1.6.3", "umap-learn", @@ -62,14 +64,14 @@ optional-dependencies.phate = [ optional-dependencies.visual = [ "cmap", "dash", + "gradio>=5.49.1", + "graphviz", "ipykernel", "nbformat", "plotly", "seaborn", "torchview", - "gradio>=5.49.1", - ] scripts.viscy = "viscy.cli:main" @@ -87,4 +89,4 @@ lint.per-file-ignores."applications/**/*.py" = [ "E402" ] lint.per-file-ignores."examples/**/*.ipynb" = [ "E402" ] lint.per-file-ignores."examples/**/*.py" = [ "E402", "F821" ] lint.per-file-ignores."tests/**/*.py" = [ "E402" ] -lint.isort.known-first-party = [ "viscy" ] \ No newline at end of file +lint.isort.known-first-party = [ "viscy" ] diff --git a/viscy/representation/airtable_callback.py b/viscy/representation/airtable_callback.py new file mode 100644 index 000000000..a39f4a45f --- /dev/null +++ b/viscy/representation/airtable_callback.py @@ -0,0 +1,120 @@ +"""Lightning callback to log training results to Airtable.""" + +import getpass +from typing import Any + +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import Callback + +from viscy.representation.airtable_registry import AirtableDatasetRegistry + + +class AirtableLoggingCallback(Callback): + """ + Log model training to Airtable after training completes. + + This callback automatically records: + - Best model checkpoint path + - Who trained the model + - When it was trained + - Link to the dataset used + + Parameters + ---------- + base_id : str + Airtable base ID + dataset_id : str + Airtable dataset record ID (from config) + model_name : str | None + Custom model name. If None, auto-generates from model class and timestamp. + log_metrics : bool + Whether to log metrics to Airtable (default: False). + If False, metrics should be viewed in TensorBoard. 
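+        Only metrics whose keys contain "val" or "test" are recorded, and non-numeric values are skipped.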
+ + Examples + -------- + Add to config YAML: + + >>> trainer: + >>> callbacks: + >>> - class_path: viscy.representation.airtable_callback.AirtableLoggingCallback + >>> init_args: + >>> base_id: "appXXXXXXXXXXXXXX" + >>> dataset_id: "recYYYYYYYYYYYYYY" + + Or add programmatically: + + >>> callback = AirtableLoggingCallback( + >>> base_id="appXXXXXXXXXXXXXX", + >>> dataset_id="recYYYYYYYYYYYYYY" + >>> ) + >>> trainer = Trainer(callbacks=[callback]) + """ + + def __init__( + self, + base_id: str, + dataset_id: str, + model_name: str | None = None, + log_metrics: bool = False, + ): + super().__init__() + self.registry = AirtableDatasetRegistry(base_id=base_id) + self.dataset_id = dataset_id + self.model_name = model_name + self.log_metrics = log_metrics + + def on_fit_end(self, trainer: Trainer, pl_module: Any) -> None: + """Log model to Airtable after training completes.""" + # Get best checkpoint path + checkpoint_path = None + if trainer.checkpoint_callback: + checkpoint_path = trainer.checkpoint_callback.best_model_path + if not checkpoint_path: # Fallback to last checkpoint + checkpoint_path = trainer.checkpoint_callback.last_model_path + + # Generate model name + if self.model_name: + model_name = self.model_name + else: + model_class = pl_module.__class__.__name__ + logger_version = trainer.logger.version if trainer.logger else "unknown" + model_name = f"{model_class}_{logger_version}" + + # Optionally collect metrics + metrics = None + if self.log_metrics and trainer.callback_metrics: + metrics = {} + for key, value in trainer.callback_metrics.items(): + # Only log test metrics or validation metrics + if "test" in key or "val" in key: + try: + metrics[key] = float(value) + except (TypeError, ValueError): + pass # Skip non-numeric metrics + + # Get logger run ID (works with TensorBoard or MLflow) + run_id = None + if trainer.logger: + if hasattr(trainer.logger, "run_id"): + run_id = trainer.logger.run_id # MLflow + elif hasattr(trainer.logger, "version"): + run_id = str(trainer.logger.version) # TensorBoard + + # Log to Airtable + try: + model_id = self.registry.log_model_training( + dataset_id=self.dataset_id, + mlflow_run_id=run_id or "unknown", + model_name=model_name, + checkpoint_path=str(checkpoint_path) if checkpoint_path else None, + trained_by=getpass.getuser(), + metrics=metrics, + ) + print(f"\n✓ Model logged to Airtable (record ID: {model_id})") + print(f" Model name: {model_name}") + print(f" Checkpoint: {checkpoint_path}") + print(f" Dataset ID: {self.dataset_id}") + except Exception as e: + print(f"\n✗ Failed to log to Airtable: {e}") + # Don't fail training if Airtable logging fails diff --git a/viscy/representation/airtable_fov_registry.py b/viscy/representation/airtable_fov_registry.py new file mode 100644 index 000000000..f4bdbf59d --- /dev/null +++ b/viscy/representation/airtable_fov_registry.py @@ -0,0 +1,494 @@ +"""FOV-level dataset registry with Airtable.""" + +import getpass +import json +import os +from datetime import datetime +from typing import Any + +import pandas as pd +from pyairtable import Api + + +class AirtableFOVRegistry: + """ + Interface to Airtable for FOV-level dataset management. + + Use this to: + - Register individual FOVs from HCS plates + - Create dataset "tags" (collections of FOVs) + - Query which FOVs are in each dataset + - Generate training configs from dataset tags + + Parameters + ---------- + base_id : str + Airtable base ID + api_key : str | None + Airtable API key. If None, reads from AIRTABLE_API_KEY env var. 
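+        The key must grant access to the "FOVs", "Datasets", and "Models" tables, all of which are bound at construction.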
+ + Examples + -------- + >>> registry = AirtableFOVRegistry(base_id="appXXXXXXXXXXXXXX") + >>> + >>> # Create dataset from FOV selection + >>> registry.create_dataset_from_fovs( + ... dataset_name="RPE1_infection_v2", + ... fov_ids=["FOV_001", "FOV_002", "FOV_004"], + ... version="v2", + ... purpose="training" + ... ) + >>> + >>> # Get all FOV paths for a dataset + >>> fov_paths = registry.get_dataset_fov_paths("RPE1_infection_v2") + >>> print(fov_paths) + >>> # ['/hpc/data/rpe1.zarr/B/3/0', '/hpc/data/rpe1.zarr/B/3/1', ...] + """ + + def __init__( + self, + base_id: str, + api_key: str | None = None, + ): + api_key = api_key or os.getenv("AIRTABLE_API_KEY") + if not api_key: + raise ValueError("Airtable API key required (set AIRTABLE_API_KEY)") + + self.api = Api(api_key) + self.base_id = base_id + self.fovs_table = self.api.table(base_id, "FOVs") + self.datasets_table = self.api.table(base_id, "Datasets") + self.models_table = self.api.table(base_id, "Models") + + def register_fov( + self, + fov_id: str, + plate_name: str, + well_id: str, + row: str, + column: str, + fov_name: str, + fov_path: str, + quality: str = "Good", + metadata: dict[str, Any] | None = None, + ) -> str: + """ + Register a single FOV in Airtable. + + Parameters + ---------- + fov_id : str + Human-readable identifier (e.g., "RPE1_plate1_B_3_0") + plate_name : str + Name of the plate this FOV belongs to + well_id : str + Well identifier as row_column (e.g., "B_3") + row : str + Well row (e.g., "B") + column : str + Well column (e.g., "3") + fov_name : str + FOV index within well (e.g., "0", "1", "2") + fov_path : str + Full path to FOV (e.g., "/hpc/data/plate.zarr/B/3/0") + quality : str + Quality assessment ("Good", "Poor", "Contaminated", etc.) + metadata : dict | None + Additional metadata (cell_count, timestamp, etc.) + + Returns + ------- + str + Airtable record ID + """ + record = { + "fov_id": fov_id, + "plate_name": plate_name, + "well_id": well_id, + "row": row, + "column": column, + "fov_name": fov_name, + "fov_path": fov_path, + "quality": quality, + } + + if metadata: + # Store as JSON string in notes field + record["notes"] = json.dumps(metadata) + + created = self.fovs_table.create(record) + return created["id"] + + def create_dataset_from_fovs( + self, + dataset_name: str, + fov_ids: list[str], + version: str, + purpose: str = "training", + description: str | None = None, + ) -> str: + """ + Create a dataset (tag) from a list of FOV IDs. + + Parameters + ---------- + dataset_name : str + Name for this dataset collection + fov_ids : list[str] + List of FOV IDs to include (e.g., ["FOV_001", "FOV_002"]) + version : str + Semantic version (e.g., "0.0.1", "0.1.0", "1.0.0") + REQUIRED - forces conscious versioning + purpose : str + Purpose of this dataset ("training", "validation", "test") + description : str | None + Human-readable description + + Returns + ------- + str + Airtable dataset record ID + + Examples + -------- + >>> registry.create_dataset_from_fovs( + ... dataset_name="RPE1_clean_wells", + ... fov_ids=["FOV_001", "FOV_002", "FOV_004"], + ... version="0.0.1", + ... description="High-quality FOVs from wells B3-B4" + ... 
) + """ + # Validate semantic version format + import re + + if not re.match(r"^\d+\.\d+\.\d+$", version): + raise ValueError( + f"Version must be semantic version format (e.g., '0.0.1', '1.0.0'), got: '{version}'" + ) + + # Check if dataset with same name + version exists (use DataFrame) + df_datasets = self.list_datasets() + + if len(df_datasets) > 0: + existing = df_datasets[ + (df_datasets["name"] == dataset_name) + & (df_datasets["version"] == version) + ] + + if len(existing) > 0: + raise ValueError( + f"Dataset '{dataset_name}' version '{version}' already exists. " + f"To create a new version, increment the version number (e.g., '0.0.2')." + ) + + # Show existing versions (helpful feedback) + existing_versions = df_datasets[df_datasets["name"] == dataset_name] + if len(existing_versions) > 0: + versions = sorted(existing_versions["version"].tolist()) + print(f"ℹ Dataset '{dataset_name}' existing versions: {versions}") + print(f" Creating new version: '{version}'") + + # Get Airtable record IDs for these FOV IDs (ensure unique) + fov_record_ids = [] + seen_fov_ids = set() + + for fov_id in fov_ids: + if fov_id in seen_fov_ids: + continue # Skip duplicates + + formula = f"{{fov_id}}='{fov_id}'" + records = self.fovs_table.all(formula=formula) + if records: + fov_record_ids.append(records[0]["id"]) + seen_fov_ids.add(fov_id) + else: + raise ValueError(f"FOV '{fov_id}' not found in FOVs table") + + # Remove any duplicate record IDs (shouldn't happen, but extra safety) + fov_record_ids = list(dict.fromkeys(fov_record_ids)) + + # Create dataset record + dataset_record = { + "name": dataset_name, + "fovs": fov_record_ids, # Linked records (unique) + "version": version, # Semantic version (required) + "purpose": purpose, + "created_date": datetime.now().isoformat(), + "created_by": getpass.getuser(), + "num_fovs": len(fov_record_ids), + } + + if description: + dataset_record["description"] = description + + created = self.datasets_table.create(dataset_record) + return created["id"] + + def create_dataset_from_query( + self, + dataset_name: str, + version: str, + plate_name: str | None = None, + well_ids: list[str] | None = None, + quality: str | None = None, + exclude_fov_ids: list[str] | None = None, + **kwargs, + ) -> str: + """ + Create a dataset by filtering FOVs with pandas. + + Parameters + ---------- + dataset_name : str + Name for this dataset + version : str + Semantic version (e.g., "0.0.1") - REQUIRED + plate_name : str | None + Filter by plate name + well_ids : list[str] | None + Filter by well identifiers (e.g., ["B_3", "B_4"]) + quality : str | None + Filter by quality ("Good", "Poor", etc.) + exclude_fov_ids : list[str] | None + FOV IDs to exclude + **kwargs + Additional arguments for create_dataset_from_fovs + + Returns + ------- + str + Airtable dataset record ID + + Examples + -------- + >>> # Create dataset from all good-quality FOVs in specific wells + >>> registry.create_dataset_from_query( + ... dataset_name="RPE1_infection_training", + ... version="0.0.1", + ... plate_name="RPE1_plate1", + ... well_ids=["B_3", "B_4"], + ... quality="Good", + ... exclude_fov_ids=["RPE1_plate1_B_3_2"] + ... 
) + """ + # Get all FOVs as DataFrame + df = self.list_fovs() + + # Apply filters with pandas + if plate_name: + df = df[df["plate_name"] == plate_name] + + if quality: + df = df[df["quality"] == quality] + + if well_ids: + df = df[df["well_id"].isin(well_ids)] + + # Exclude specified FOVs + if exclude_fov_ids: + df = df[~df["fov_id"].isin(exclude_fov_ids)] + + fov_ids = df["fov_id"].tolist() + + print(f"Found {len(fov_ids)} FOVs matching criteria") + + # Create dataset + return self.create_dataset_from_fovs( + dataset_name=dataset_name, version=version, fov_ids=fov_ids, **kwargs + ) + + def get_dataset_fov_paths( + self, dataset_name: str, version: str | None = None + ) -> list[str]: + """ + Get list of FOV paths for a dataset. + + Parameters + ---------- + dataset_name : str + Dataset name + version : str | None + Specific version (if None, returns latest) + + Returns + ------- + list[str] + List of FOV paths + + Examples + -------- + >>> paths = registry.get_dataset_fov_paths("RPE1_infection_v2") + >>> print(paths) + >>> # ['/hpc/data/rpe1.zarr/B/3/0', '/hpc/data/rpe1.zarr/B/3/1', ...] + """ + # Get all datasets as DataFrame + df_datasets = self.list_datasets() + + if len(df_datasets) == 0: + raise ValueError(f"Dataset '{dataset_name}' not found") + + # Filter by name + filtered = df_datasets[df_datasets["name"] == dataset_name] + + if len(filtered) == 0: + raise ValueError(f"Dataset '{dataset_name}' not found") + + # Filter by version if specified, otherwise get latest + if version: + filtered = filtered[filtered["version"] == version] + if len(filtered) == 0: + raise ValueError( + f"Dataset '{dataset_name}' version '{version}' not found" + ) + else: + # Get latest version (sort by created_date) + filtered = filtered.sort_values("created_date", ascending=False) + + # Get the first (or only) matching dataset + dataset_row = filtered.iloc[0] + + # Get linked FOV record IDs + fov_record_ids = dataset_row.get("fovs", []) + if not fov_record_ids or len(fov_record_ids) == 0: + return [] + + # Fetch FOV paths + fov_paths = [] + for fov_id in fov_record_ids: + fov_record = self.fovs_table.get(fov_id) + fov_paths.append(fov_record["fields"]["fov_path"]) + + return fov_paths + + def get_dataset( + self, dataset_name: str, version: str | None = None + ) -> dict[str, Any]: + """ + Get full dataset information including FOV details. 
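+        If version is omitted, the most recently created version (by created_date) is returned.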
+ + Parameters + ---------- + dataset_name : str + Dataset name + version : str | None + Specific version + + Returns + ------- + dict + Dataset info with FOV paths and metadata + """ + # Get all datasets as DataFrame + df_datasets = self.list_datasets() + + if len(df_datasets) == 0: + raise ValueError(f"Dataset '{dataset_name}' not found") + + # Filter by name + filtered = df_datasets[df_datasets["name"] == dataset_name] + + if len(filtered) == 0: + raise ValueError(f"Dataset '{dataset_name}' not found") + + # Filter by version if specified, otherwise get latest + if version: + filtered = filtered[filtered["version"] == version] + if len(filtered) == 0: + raise ValueError( + f"Dataset '{dataset_name}' version '{version}' not found" + ) + else: + # Get latest version (sort by created_date) + filtered = filtered.sort_values("created_date", ascending=False) + + # Get the first (or only) matching dataset + dataset_row = filtered.iloc[0] + dataset = dataset_row.to_dict() + + # Add FOV paths + dataset["fov_paths"] = self.get_dataset_fov_paths(dataset_name, version) + + return dataset + + def list_datasets( + self, purpose: str | None = None, as_dataframe: bool = True + ) -> pd.DataFrame | list[dict]: + """ + List all datasets. + + Parameters + ---------- + purpose : str | None + Filter by purpose ("training", "validation", "test") + as_dataframe : bool + If True, return pandas DataFrame. If False, return list of dicts. + + Returns + ------- + pd.DataFrame | list[dict] + Dataset records as DataFrame or list of dicts + + Examples + -------- + >>> registry.list_datasets(purpose="training") + >>> # Returns DataFrame with columns: id, name, version, purpose, ... + """ + # Fetch all datasets (sorted by most recent first) + records = self.datasets_table.all(sort=["-created_date"]) + data = [{"id": r["id"], **r["fields"]} for r in records] + + # Convert to DataFrame or list + if as_dataframe: + df = pd.DataFrame(data) + # Filter by purpose if specified + if purpose and len(df) > 0: + df = df[df["purpose"] == purpose] + return df + else: + # Filter list if purpose specified + if purpose: + data = [d for d in data if d.get("purpose") == purpose] + return data + + def list_fovs(self, as_dataframe: bool = True) -> pd.DataFrame | list[dict]: + """ + Get all FOVs as a DataFrame (or list of dicts). + + Use pandas for filtering - much simpler and more powerful than + building Airtable formulas. + + Parameters + ---------- + as_dataframe : bool + If True, return pandas DataFrame. If False, return list of dicts. + + Returns + ------- + pd.DataFrame | list[dict] + All FOV records + + Examples + -------- + >>> # Get all FOVs + >>> df = registry.list_fovs() + >>> + >>> # Filter with pandas (simple and powerful!) 
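+        >>> # Each row carries the Airtable record "id" plus every field value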
+ >>> filtered = df[df['plate_name'] == 'RPE1_plate1'] + >>> filtered = df[df['quality'] == 'Good'] + >>> filtered = df[df['row'] == 'B'] + >>> filtered = df[df['row'].isin(['B', 'C'])] + >>> filtered = df[(df['row'] == 'B') & (df['column'] == '3')] + >>> + >>> # Exclude FOVs + >>> filtered = df[~df['fov_id'].isin(['RPE1_plate1_B_3_2'])] + >>> + >>> # Group and analyze + >>> df.groupby('plate_name').size() + >>> df.groupby(['row', 'column']).size() + """ + records = self.fovs_table.all() + data = [{"id": r["id"], **r["fields"]} for r in records] + + if as_dataframe: + return pd.DataFrame(data) + return data From 241a2cdc38a2dadf7c57d886d5b51118dbbf6f0e Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 6 Jan 2026 11:47:16 -0800 Subject: [PATCH 02/18] move the files for the airtable to higher hierarchy --- .../airtable/filter_n_create_dataset_tag.py | 2 +- viscy/representation/airtable_callback.py | 120 ----- viscy/representation/airtable_fov_registry.py | 494 ------------------ 3 files changed, 1 insertion(+), 615 deletions(-) delete mode 100644 viscy/representation/airtable_callback.py delete mode 100644 viscy/representation/airtable_fov_registry.py diff --git a/examples/airtable/filter_n_create_dataset_tag.py b/examples/airtable/filter_n_create_dataset_tag.py index d44ec7e01..31b94ebc3 100644 --- a/examples/airtable/filter_n_create_dataset_tag.py +++ b/examples/airtable/filter_n_create_dataset_tag.py @@ -3,7 +3,7 @@ # %% import os -from viscy.representation.airtable_fov_registry import AirtableFOVRegistry +from viscy.airtable.airtable_fov_registry import AirtableFOVRegistry BASE_ID = os.getenv("AIRTABLE_BASE_ID") registry = AirtableFOVRegistry(base_id=BASE_ID) diff --git a/viscy/representation/airtable_callback.py b/viscy/representation/airtable_callback.py deleted file mode 100644 index a39f4a45f..000000000 --- a/viscy/representation/airtable_callback.py +++ /dev/null @@ -1,120 +0,0 @@ -"""Lightning callback to log training results to Airtable.""" - -import getpass -from typing import Any - -from lightning.pytorch import Trainer -from lightning.pytorch.callbacks import Callback - -from viscy.representation.airtable_registry import AirtableDatasetRegistry - - -class AirtableLoggingCallback(Callback): - """ - Log model training to Airtable after training completes. - - This callback automatically records: - - Best model checkpoint path - - Who trained the model - - When it was trained - - Link to the dataset used - - Parameters - ---------- - base_id : str - Airtable base ID - dataset_id : str - Airtable dataset record ID (from config) - model_name : str | None - Custom model name. If None, auto-generates from model class and timestamp. - log_metrics : bool - Whether to log metrics to Airtable (default: False). - If False, metrics should be viewed in TensorBoard. 
- - Examples - -------- - Add to config YAML: - - >>> trainer: - >>> callbacks: - >>> - class_path: viscy.representation.airtable_callback.AirtableLoggingCallback - >>> init_args: - >>> base_id: "appXXXXXXXXXXXXXX" - >>> dataset_id: "recYYYYYYYYYYYYYY" - - Or add programmatically: - - >>> callback = AirtableLoggingCallback( - >>> base_id="appXXXXXXXXXXXXXX", - >>> dataset_id="recYYYYYYYYYYYYYY" - >>> ) - >>> trainer = Trainer(callbacks=[callback]) - """ - - def __init__( - self, - base_id: str, - dataset_id: str, - model_name: str | None = None, - log_metrics: bool = False, - ): - super().__init__() - self.registry = AirtableDatasetRegistry(base_id=base_id) - self.dataset_id = dataset_id - self.model_name = model_name - self.log_metrics = log_metrics - - def on_fit_end(self, trainer: Trainer, pl_module: Any) -> None: - """Log model to Airtable after training completes.""" - # Get best checkpoint path - checkpoint_path = None - if trainer.checkpoint_callback: - checkpoint_path = trainer.checkpoint_callback.best_model_path - if not checkpoint_path: # Fallback to last checkpoint - checkpoint_path = trainer.checkpoint_callback.last_model_path - - # Generate model name - if self.model_name: - model_name = self.model_name - else: - model_class = pl_module.__class__.__name__ - logger_version = trainer.logger.version if trainer.logger else "unknown" - model_name = f"{model_class}_{logger_version}" - - # Optionally collect metrics - metrics = None - if self.log_metrics and trainer.callback_metrics: - metrics = {} - for key, value in trainer.callback_metrics.items(): - # Only log test metrics or validation metrics - if "test" in key or "val" in key: - try: - metrics[key] = float(value) - except (TypeError, ValueError): - pass # Skip non-numeric metrics - - # Get logger run ID (works with TensorBoard or MLflow) - run_id = None - if trainer.logger: - if hasattr(trainer.logger, "run_id"): - run_id = trainer.logger.run_id # MLflow - elif hasattr(trainer.logger, "version"): - run_id = str(trainer.logger.version) # TensorBoard - - # Log to Airtable - try: - model_id = self.registry.log_model_training( - dataset_id=self.dataset_id, - mlflow_run_id=run_id or "unknown", - model_name=model_name, - checkpoint_path=str(checkpoint_path) if checkpoint_path else None, - trained_by=getpass.getuser(), - metrics=metrics, - ) - print(f"\n✓ Model logged to Airtable (record ID: {model_id})") - print(f" Model name: {model_name}") - print(f" Checkpoint: {checkpoint_path}") - print(f" Dataset ID: {self.dataset_id}") - except Exception as e: - print(f"\n✗ Failed to log to Airtable: {e}") - # Don't fail training if Airtable logging fails diff --git a/viscy/representation/airtable_fov_registry.py b/viscy/representation/airtable_fov_registry.py deleted file mode 100644 index f4bdbf59d..000000000 --- a/viscy/representation/airtable_fov_registry.py +++ /dev/null @@ -1,494 +0,0 @@ -"""FOV-level dataset registry with Airtable.""" - -import getpass -import json -import os -from datetime import datetime -from typing import Any - -import pandas as pd -from pyairtable import Api - - -class AirtableFOVRegistry: - """ - Interface to Airtable for FOV-level dataset management. - - Use this to: - - Register individual FOVs from HCS plates - - Create dataset "tags" (collections of FOVs) - - Query which FOVs are in each dataset - - Generate training configs from dataset tags - - Parameters - ---------- - base_id : str - Airtable base ID - api_key : str | None - Airtable API key. If None, reads from AIRTABLE_API_KEY env var. 
- - Examples - -------- - >>> registry = AirtableFOVRegistry(base_id="appXXXXXXXXXXXXXX") - >>> - >>> # Create dataset from FOV selection - >>> registry.create_dataset_from_fovs( - ... dataset_name="RPE1_infection_v2", - ... fov_ids=["FOV_001", "FOV_002", "FOV_004"], - ... version="v2", - ... purpose="training" - ... ) - >>> - >>> # Get all FOV paths for a dataset - >>> fov_paths = registry.get_dataset_fov_paths("RPE1_infection_v2") - >>> print(fov_paths) - >>> # ['/hpc/data/rpe1.zarr/B/3/0', '/hpc/data/rpe1.zarr/B/3/1', ...] - """ - - def __init__( - self, - base_id: str, - api_key: str | None = None, - ): - api_key = api_key or os.getenv("AIRTABLE_API_KEY") - if not api_key: - raise ValueError("Airtable API key required (set AIRTABLE_API_KEY)") - - self.api = Api(api_key) - self.base_id = base_id - self.fovs_table = self.api.table(base_id, "FOVs") - self.datasets_table = self.api.table(base_id, "Datasets") - self.models_table = self.api.table(base_id, "Models") - - def register_fov( - self, - fov_id: str, - plate_name: str, - well_id: str, - row: str, - column: str, - fov_name: str, - fov_path: str, - quality: str = "Good", - metadata: dict[str, Any] | None = None, - ) -> str: - """ - Register a single FOV in Airtable. - - Parameters - ---------- - fov_id : str - Human-readable identifier (e.g., "RPE1_plate1_B_3_0") - plate_name : str - Name of the plate this FOV belongs to - well_id : str - Well identifier as row_column (e.g., "B_3") - row : str - Well row (e.g., "B") - column : str - Well column (e.g., "3") - fov_name : str - FOV index within well (e.g., "0", "1", "2") - fov_path : str - Full path to FOV (e.g., "/hpc/data/plate.zarr/B/3/0") - quality : str - Quality assessment ("Good", "Poor", "Contaminated", etc.) - metadata : dict | None - Additional metadata (cell_count, timestamp, etc.) - - Returns - ------- - str - Airtable record ID - """ - record = { - "fov_id": fov_id, - "plate_name": plate_name, - "well_id": well_id, - "row": row, - "column": column, - "fov_name": fov_name, - "fov_path": fov_path, - "quality": quality, - } - - if metadata: - # Store as JSON string in notes field - record["notes"] = json.dumps(metadata) - - created = self.fovs_table.create(record) - return created["id"] - - def create_dataset_from_fovs( - self, - dataset_name: str, - fov_ids: list[str], - version: str, - purpose: str = "training", - description: str | None = None, - ) -> str: - """ - Create a dataset (tag) from a list of FOV IDs. - - Parameters - ---------- - dataset_name : str - Name for this dataset collection - fov_ids : list[str] - List of FOV IDs to include (e.g., ["FOV_001", "FOV_002"]) - version : str - Semantic version (e.g., "0.0.1", "0.1.0", "1.0.0") - REQUIRED - forces conscious versioning - purpose : str - Purpose of this dataset ("training", "validation", "test") - description : str | None - Human-readable description - - Returns - ------- - str - Airtable dataset record ID - - Examples - -------- - >>> registry.create_dataset_from_fovs( - ... dataset_name="RPE1_clean_wells", - ... fov_ids=["FOV_001", "FOV_002", "FOV_004"], - ... version="0.0.1", - ... description="High-quality FOVs from wells B3-B4" - ... 
) - """ - # Validate semantic version format - import re - - if not re.match(r"^\d+\.\d+\.\d+$", version): - raise ValueError( - f"Version must be semantic version format (e.g., '0.0.1', '1.0.0'), got: '{version}'" - ) - - # Check if dataset with same name + version exists (use DataFrame) - df_datasets = self.list_datasets() - - if len(df_datasets) > 0: - existing = df_datasets[ - (df_datasets["name"] == dataset_name) - & (df_datasets["version"] == version) - ] - - if len(existing) > 0: - raise ValueError( - f"Dataset '{dataset_name}' version '{version}' already exists. " - f"To create a new version, increment the version number (e.g., '0.0.2')." - ) - - # Show existing versions (helpful feedback) - existing_versions = df_datasets[df_datasets["name"] == dataset_name] - if len(existing_versions) > 0: - versions = sorted(existing_versions["version"].tolist()) - print(f"ℹ Dataset '{dataset_name}' existing versions: {versions}") - print(f" Creating new version: '{version}'") - - # Get Airtable record IDs for these FOV IDs (ensure unique) - fov_record_ids = [] - seen_fov_ids = set() - - for fov_id in fov_ids: - if fov_id in seen_fov_ids: - continue # Skip duplicates - - formula = f"{{fov_id}}='{fov_id}'" - records = self.fovs_table.all(formula=formula) - if records: - fov_record_ids.append(records[0]["id"]) - seen_fov_ids.add(fov_id) - else: - raise ValueError(f"FOV '{fov_id}' not found in FOVs table") - - # Remove any duplicate record IDs (shouldn't happen, but extra safety) - fov_record_ids = list(dict.fromkeys(fov_record_ids)) - - # Create dataset record - dataset_record = { - "name": dataset_name, - "fovs": fov_record_ids, # Linked records (unique) - "version": version, # Semantic version (required) - "purpose": purpose, - "created_date": datetime.now().isoformat(), - "created_by": getpass.getuser(), - "num_fovs": len(fov_record_ids), - } - - if description: - dataset_record["description"] = description - - created = self.datasets_table.create(dataset_record) - return created["id"] - - def create_dataset_from_query( - self, - dataset_name: str, - version: str, - plate_name: str | None = None, - well_ids: list[str] | None = None, - quality: str | None = None, - exclude_fov_ids: list[str] | None = None, - **kwargs, - ) -> str: - """ - Create a dataset by filtering FOVs with pandas. - - Parameters - ---------- - dataset_name : str - Name for this dataset - version : str - Semantic version (e.g., "0.0.1") - REQUIRED - plate_name : str | None - Filter by plate name - well_ids : list[str] | None - Filter by well identifiers (e.g., ["B_3", "B_4"]) - quality : str | None - Filter by quality ("Good", "Poor", etc.) - exclude_fov_ids : list[str] | None - FOV IDs to exclude - **kwargs - Additional arguments for create_dataset_from_fovs - - Returns - ------- - str - Airtable dataset record ID - - Examples - -------- - >>> # Create dataset from all good-quality FOVs in specific wells - >>> registry.create_dataset_from_query( - ... dataset_name="RPE1_infection_training", - ... version="0.0.1", - ... plate_name="RPE1_plate1", - ... well_ids=["B_3", "B_4"], - ... quality="Good", - ... exclude_fov_ids=["RPE1_plate1_B_3_2"] - ... 
) - """ - # Get all FOVs as DataFrame - df = self.list_fovs() - - # Apply filters with pandas - if plate_name: - df = df[df["plate_name"] == plate_name] - - if quality: - df = df[df["quality"] == quality] - - if well_ids: - df = df[df["well_id"].isin(well_ids)] - - # Exclude specified FOVs - if exclude_fov_ids: - df = df[~df["fov_id"].isin(exclude_fov_ids)] - - fov_ids = df["fov_id"].tolist() - - print(f"Found {len(fov_ids)} FOVs matching criteria") - - # Create dataset - return self.create_dataset_from_fovs( - dataset_name=dataset_name, version=version, fov_ids=fov_ids, **kwargs - ) - - def get_dataset_fov_paths( - self, dataset_name: str, version: str | None = None - ) -> list[str]: - """ - Get list of FOV paths for a dataset. - - Parameters - ---------- - dataset_name : str - Dataset name - version : str | None - Specific version (if None, returns latest) - - Returns - ------- - list[str] - List of FOV paths - - Examples - -------- - >>> paths = registry.get_dataset_fov_paths("RPE1_infection_v2") - >>> print(paths) - >>> # ['/hpc/data/rpe1.zarr/B/3/0', '/hpc/data/rpe1.zarr/B/3/1', ...] - """ - # Get all datasets as DataFrame - df_datasets = self.list_datasets() - - if len(df_datasets) == 0: - raise ValueError(f"Dataset '{dataset_name}' not found") - - # Filter by name - filtered = df_datasets[df_datasets["name"] == dataset_name] - - if len(filtered) == 0: - raise ValueError(f"Dataset '{dataset_name}' not found") - - # Filter by version if specified, otherwise get latest - if version: - filtered = filtered[filtered["version"] == version] - if len(filtered) == 0: - raise ValueError( - f"Dataset '{dataset_name}' version '{version}' not found" - ) - else: - # Get latest version (sort by created_date) - filtered = filtered.sort_values("created_date", ascending=False) - - # Get the first (or only) matching dataset - dataset_row = filtered.iloc[0] - - # Get linked FOV record IDs - fov_record_ids = dataset_row.get("fovs", []) - if not fov_record_ids or len(fov_record_ids) == 0: - return [] - - # Fetch FOV paths - fov_paths = [] - for fov_id in fov_record_ids: - fov_record = self.fovs_table.get(fov_id) - fov_paths.append(fov_record["fields"]["fov_path"]) - - return fov_paths - - def get_dataset( - self, dataset_name: str, version: str | None = None - ) -> dict[str, Any]: - """ - Get full dataset information including FOV details. 
- - Parameters - ---------- - dataset_name : str - Dataset name - version : str | None - Specific version - - Returns - ------- - dict - Dataset info with FOV paths and metadata - """ - # Get all datasets as DataFrame - df_datasets = self.list_datasets() - - if len(df_datasets) == 0: - raise ValueError(f"Dataset '{dataset_name}' not found") - - # Filter by name - filtered = df_datasets[df_datasets["name"] == dataset_name] - - if len(filtered) == 0: - raise ValueError(f"Dataset '{dataset_name}' not found") - - # Filter by version if specified, otherwise get latest - if version: - filtered = filtered[filtered["version"] == version] - if len(filtered) == 0: - raise ValueError( - f"Dataset '{dataset_name}' version '{version}' not found" - ) - else: - # Get latest version (sort by created_date) - filtered = filtered.sort_values("created_date", ascending=False) - - # Get the first (or only) matching dataset - dataset_row = filtered.iloc[0] - dataset = dataset_row.to_dict() - - # Add FOV paths - dataset["fov_paths"] = self.get_dataset_fov_paths(dataset_name, version) - - return dataset - - def list_datasets( - self, purpose: str | None = None, as_dataframe: bool = True - ) -> pd.DataFrame | list[dict]: - """ - List all datasets. - - Parameters - ---------- - purpose : str | None - Filter by purpose ("training", "validation", "test") - as_dataframe : bool - If True, return pandas DataFrame. If False, return list of dicts. - - Returns - ------- - pd.DataFrame | list[dict] - Dataset records as DataFrame or list of dicts - - Examples - -------- - >>> registry.list_datasets(purpose="training") - >>> # Returns DataFrame with columns: id, name, version, purpose, ... - """ - # Fetch all datasets (sorted by most recent first) - records = self.datasets_table.all(sort=["-created_date"]) - data = [{"id": r["id"], **r["fields"]} for r in records] - - # Convert to DataFrame or list - if as_dataframe: - df = pd.DataFrame(data) - # Filter by purpose if specified - if purpose and len(df) > 0: - df = df[df["purpose"] == purpose] - return df - else: - # Filter list if purpose specified - if purpose: - data = [d for d in data if d.get("purpose") == purpose] - return data - - def list_fovs(self, as_dataframe: bool = True) -> pd.DataFrame | list[dict]: - """ - Get all FOVs as a DataFrame (or list of dicts). - - Use pandas for filtering - much simpler and more powerful than - building Airtable formulas. - - Parameters - ---------- - as_dataframe : bool - If True, return pandas DataFrame. If False, return list of dicts. - - Returns - ------- - pd.DataFrame | list[dict] - All FOV records - - Examples - -------- - >>> # Get all FOVs - >>> df = registry.list_fovs() - >>> - >>> # Filter with pandas (simple and powerful!) 
- >>> filtered = df[df['plate_name'] == 'RPE1_plate1'] - >>> filtered = df[df['quality'] == 'Good'] - >>> filtered = df[df['row'] == 'B'] - >>> filtered = df[df['row'].isin(['B', 'C'])] - >>> filtered = df[(df['row'] == 'B') & (df['column'] == '3')] - >>> - >>> # Exclude FOVs - >>> filtered = df[~df['fov_id'].isin(['RPE1_plate1_B_3_2'])] - >>> - >>> # Group and analyze - >>> df.groupby('plate_name').size() - >>> df.groupby(['row', 'column']).size() - """ - records = self.fovs_table.all() - data = [{"id": r["id"], **r["fields"]} for r in records] - - if as_dataframe: - return pd.DataFrame(data) - return data From bd477086e5cc08d30723165998c00da0af43f67b Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 6 Jan 2026 11:51:00 -0800 Subject: [PATCH 03/18] add the registry and callback --- viscy/airtable/airtable_callback.py | 120 ++++++ viscy/airtable/airtable_fov_registry.py | 494 ++++++++++++++++++++++++ 2 files changed, 614 insertions(+) create mode 100644 viscy/airtable/airtable_callback.py create mode 100644 viscy/airtable/airtable_fov_registry.py diff --git a/viscy/airtable/airtable_callback.py b/viscy/airtable/airtable_callback.py new file mode 100644 index 000000000..a39f4a45f --- /dev/null +++ b/viscy/airtable/airtable_callback.py @@ -0,0 +1,120 @@ +"""Lightning callback to log training results to Airtable.""" + +import getpass +from typing import Any + +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import Callback + +from viscy.representation.airtable_registry import AirtableDatasetRegistry + + +class AirtableLoggingCallback(Callback): + """ + Log model training to Airtable after training completes. + + This callback automatically records: + - Best model checkpoint path + - Who trained the model + - When it was trained + - Link to the dataset used + + Parameters + ---------- + base_id : str + Airtable base ID + dataset_id : str + Airtable dataset record ID (from config) + model_name : str | None + Custom model name. If None, auto-generates from model class and timestamp. + log_metrics : bool + Whether to log metrics to Airtable (default: False). + If False, metrics should be viewed in TensorBoard. 
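+        Airtable failures are caught and printed; a logging error never aborts training.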
+ + Examples + -------- + Add to config YAML: + + >>> trainer: + >>> callbacks: + >>> - class_path: viscy.representation.airtable_callback.AirtableLoggingCallback + >>> init_args: + >>> base_id: "appXXXXXXXXXXXXXX" + >>> dataset_id: "recYYYYYYYYYYYYYY" + + Or add programmatically: + + >>> callback = AirtableLoggingCallback( + >>> base_id="appXXXXXXXXXXXXXX", + >>> dataset_id="recYYYYYYYYYYYYYY" + >>> ) + >>> trainer = Trainer(callbacks=[callback]) + """ + + def __init__( + self, + base_id: str, + dataset_id: str, + model_name: str | None = None, + log_metrics: bool = False, + ): + super().__init__() + self.registry = AirtableDatasetRegistry(base_id=base_id) + self.dataset_id = dataset_id + self.model_name = model_name + self.log_metrics = log_metrics + + def on_fit_end(self, trainer: Trainer, pl_module: Any) -> None: + """Log model to Airtable after training completes.""" + # Get best checkpoint path + checkpoint_path = None + if trainer.checkpoint_callback: + checkpoint_path = trainer.checkpoint_callback.best_model_path + if not checkpoint_path: # Fallback to last checkpoint + checkpoint_path = trainer.checkpoint_callback.last_model_path + + # Generate model name + if self.model_name: + model_name = self.model_name + else: + model_class = pl_module.__class__.__name__ + logger_version = trainer.logger.version if trainer.logger else "unknown" + model_name = f"{model_class}_{logger_version}" + + # Optionally collect metrics + metrics = None + if self.log_metrics and trainer.callback_metrics: + metrics = {} + for key, value in trainer.callback_metrics.items(): + # Only log test metrics or validation metrics + if "test" in key or "val" in key: + try: + metrics[key] = float(value) + except (TypeError, ValueError): + pass # Skip non-numeric metrics + + # Get logger run ID (works with TensorBoard or MLflow) + run_id = None + if trainer.logger: + if hasattr(trainer.logger, "run_id"): + run_id = trainer.logger.run_id # MLflow + elif hasattr(trainer.logger, "version"): + run_id = str(trainer.logger.version) # TensorBoard + + # Log to Airtable + try: + model_id = self.registry.log_model_training( + dataset_id=self.dataset_id, + mlflow_run_id=run_id or "unknown", + model_name=model_name, + checkpoint_path=str(checkpoint_path) if checkpoint_path else None, + trained_by=getpass.getuser(), + metrics=metrics, + ) + print(f"\n✓ Model logged to Airtable (record ID: {model_id})") + print(f" Model name: {model_name}") + print(f" Checkpoint: {checkpoint_path}") + print(f" Dataset ID: {self.dataset_id}") + except Exception as e: + print(f"\n✗ Failed to log to Airtable: {e}") + # Don't fail training if Airtable logging fails diff --git a/viscy/airtable/airtable_fov_registry.py b/viscy/airtable/airtable_fov_registry.py new file mode 100644 index 000000000..f4bdbf59d --- /dev/null +++ b/viscy/airtable/airtable_fov_registry.py @@ -0,0 +1,494 @@ +"""FOV-level dataset registry with Airtable.""" + +import getpass +import json +import os +from datetime import datetime +from typing import Any + +import pandas as pd +from pyairtable import Api + + +class AirtableFOVRegistry: + """ + Interface to Airtable for FOV-level dataset management. + + Use this to: + - Register individual FOVs from HCS plates + - Create dataset "tags" (collections of FOVs) + - Query which FOVs are in each dataset + - Generate training configs from dataset tags + + Parameters + ---------- + base_id : str + Airtable base ID + api_key : str | None + Airtable API key. If None, reads from AIRTABLE_API_KEY env var. 
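+        Methods return Airtable record IDs ("rec..."), distinct from the human-readable fov_id field.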
+ + Examples + -------- + >>> registry = AirtableFOVRegistry(base_id="appXXXXXXXXXXXXXX") + >>> + >>> # Create dataset from FOV selection + >>> registry.create_dataset_from_fovs( + ... dataset_name="RPE1_infection_v2", + ... fov_ids=["FOV_001", "FOV_002", "FOV_004"], + ... version="v2", + ... purpose="training" + ... ) + >>> + >>> # Get all FOV paths for a dataset + >>> fov_paths = registry.get_dataset_fov_paths("RPE1_infection_v2") + >>> print(fov_paths) + >>> # ['/hpc/data/rpe1.zarr/B/3/0', '/hpc/data/rpe1.zarr/B/3/1', ...] + """ + + def __init__( + self, + base_id: str, + api_key: str | None = None, + ): + api_key = api_key or os.getenv("AIRTABLE_API_KEY") + if not api_key: + raise ValueError("Airtable API key required (set AIRTABLE_API_KEY)") + + self.api = Api(api_key) + self.base_id = base_id + self.fovs_table = self.api.table(base_id, "FOVs") + self.datasets_table = self.api.table(base_id, "Datasets") + self.models_table = self.api.table(base_id, "Models") + + def register_fov( + self, + fov_id: str, + plate_name: str, + well_id: str, + row: str, + column: str, + fov_name: str, + fov_path: str, + quality: str = "Good", + metadata: dict[str, Any] | None = None, + ) -> str: + """ + Register a single FOV in Airtable. + + Parameters + ---------- + fov_id : str + Human-readable identifier (e.g., "RPE1_plate1_B_3_0") + plate_name : str + Name of the plate this FOV belongs to + well_id : str + Well identifier as row_column (e.g., "B_3") + row : str + Well row (e.g., "B") + column : str + Well column (e.g., "3") + fov_name : str + FOV index within well (e.g., "0", "1", "2") + fov_path : str + Full path to FOV (e.g., "/hpc/data/plate.zarr/B/3/0") + quality : str + Quality assessment ("Good", "Poor", "Contaminated", etc.) + metadata : dict | None + Additional metadata (cell_count, timestamp, etc.) + + Returns + ------- + str + Airtable record ID + """ + record = { + "fov_id": fov_id, + "plate_name": plate_name, + "well_id": well_id, + "row": row, + "column": column, + "fov_name": fov_name, + "fov_path": fov_path, + "quality": quality, + } + + if metadata: + # Store as JSON string in notes field + record["notes"] = json.dumps(metadata) + + created = self.fovs_table.create(record) + return created["id"] + + def create_dataset_from_fovs( + self, + dataset_name: str, + fov_ids: list[str], + version: str, + purpose: str = "training", + description: str | None = None, + ) -> str: + """ + Create a dataset (tag) from a list of FOV IDs. + + Parameters + ---------- + dataset_name : str + Name for this dataset collection + fov_ids : list[str] + List of FOV IDs to include (e.g., ["FOV_001", "FOV_002"]) + version : str + Semantic version (e.g., "0.0.1", "0.1.0", "1.0.0") + REQUIRED - forces conscious versioning + purpose : str + Purpose of this dataset ("training", "validation", "test") + description : str | None + Human-readable description + + Returns + ------- + str + Airtable dataset record ID + + Examples + -------- + >>> registry.create_dataset_from_fovs( + ... dataset_name="RPE1_clean_wells", + ... fov_ids=["FOV_001", "FOV_002", "FOV_004"], + ... version="0.0.1", + ... description="High-quality FOVs from wells B3-B4" + ... 
) + """ + # Validate semantic version format + import re + + if not re.match(r"^\d+\.\d+\.\d+$", version): + raise ValueError( + f"Version must be semantic version format (e.g., '0.0.1', '1.0.0'), got: '{version}'" + ) + + # Check if dataset with same name + version exists (use DataFrame) + df_datasets = self.list_datasets() + + if len(df_datasets) > 0: + existing = df_datasets[ + (df_datasets["name"] == dataset_name) + & (df_datasets["version"] == version) + ] + + if len(existing) > 0: + raise ValueError( + f"Dataset '{dataset_name}' version '{version}' already exists. " + f"To create a new version, increment the version number (e.g., '0.0.2')." + ) + + # Show existing versions (helpful feedback) + existing_versions = df_datasets[df_datasets["name"] == dataset_name] + if len(existing_versions) > 0: + versions = sorted(existing_versions["version"].tolist()) + print(f"ℹ Dataset '{dataset_name}' existing versions: {versions}") + print(f" Creating new version: '{version}'") + + # Get Airtable record IDs for these FOV IDs (ensure unique) + fov_record_ids = [] + seen_fov_ids = set() + + for fov_id in fov_ids: + if fov_id in seen_fov_ids: + continue # Skip duplicates + + formula = f"{{fov_id}}='{fov_id}'" + records = self.fovs_table.all(formula=formula) + if records: + fov_record_ids.append(records[0]["id"]) + seen_fov_ids.add(fov_id) + else: + raise ValueError(f"FOV '{fov_id}' not found in FOVs table") + + # Remove any duplicate record IDs (shouldn't happen, but extra safety) + fov_record_ids = list(dict.fromkeys(fov_record_ids)) + + # Create dataset record + dataset_record = { + "name": dataset_name, + "fovs": fov_record_ids, # Linked records (unique) + "version": version, # Semantic version (required) + "purpose": purpose, + "created_date": datetime.now().isoformat(), + "created_by": getpass.getuser(), + "num_fovs": len(fov_record_ids), + } + + if description: + dataset_record["description"] = description + + created = self.datasets_table.create(dataset_record) + return created["id"] + + def create_dataset_from_query( + self, + dataset_name: str, + version: str, + plate_name: str | None = None, + well_ids: list[str] | None = None, + quality: str | None = None, + exclude_fov_ids: list[str] | None = None, + **kwargs, + ) -> str: + """ + Create a dataset by filtering FOVs with pandas. + + Parameters + ---------- + dataset_name : str + Name for this dataset + version : str + Semantic version (e.g., "0.0.1") - REQUIRED + plate_name : str | None + Filter by plate name + well_ids : list[str] | None + Filter by well identifiers (e.g., ["B_3", "B_4"]) + quality : str | None + Filter by quality ("Good", "Poor", etc.) + exclude_fov_ids : list[str] | None + FOV IDs to exclude + **kwargs + Additional arguments for create_dataset_from_fovs + + Returns + ------- + str + Airtable dataset record ID + + Examples + -------- + >>> # Create dataset from all good-quality FOVs in specific wells + >>> registry.create_dataset_from_query( + ... dataset_name="RPE1_infection_training", + ... version="0.0.1", + ... plate_name="RPE1_plate1", + ... well_ids=["B_3", "B_4"], + ... quality="Good", + ... exclude_fov_ids=["RPE1_plate1_B_3_2"] + ... 
) + """ + # Get all FOVs as DataFrame + df = self.list_fovs() + + # Apply filters with pandas + if plate_name: + df = df[df["plate_name"] == plate_name] + + if quality: + df = df[df["quality"] == quality] + + if well_ids: + df = df[df["well_id"].isin(well_ids)] + + # Exclude specified FOVs + if exclude_fov_ids: + df = df[~df["fov_id"].isin(exclude_fov_ids)] + + fov_ids = df["fov_id"].tolist() + + print(f"Found {len(fov_ids)} FOVs matching criteria") + + # Create dataset + return self.create_dataset_from_fovs( + dataset_name=dataset_name, version=version, fov_ids=fov_ids, **kwargs + ) + + def get_dataset_fov_paths( + self, dataset_name: str, version: str | None = None + ) -> list[str]: + """ + Get list of FOV paths for a dataset. + + Parameters + ---------- + dataset_name : str + Dataset name + version : str | None + Specific version (if None, returns latest) + + Returns + ------- + list[str] + List of FOV paths + + Examples + -------- + >>> paths = registry.get_dataset_fov_paths("RPE1_infection_v2") + >>> print(paths) + >>> # ['/hpc/data/rpe1.zarr/B/3/0', '/hpc/data/rpe1.zarr/B/3/1', ...] + """ + # Get all datasets as DataFrame + df_datasets = self.list_datasets() + + if len(df_datasets) == 0: + raise ValueError(f"Dataset '{dataset_name}' not found") + + # Filter by name + filtered = df_datasets[df_datasets["name"] == dataset_name] + + if len(filtered) == 0: + raise ValueError(f"Dataset '{dataset_name}' not found") + + # Filter by version if specified, otherwise get latest + if version: + filtered = filtered[filtered["version"] == version] + if len(filtered) == 0: + raise ValueError( + f"Dataset '{dataset_name}' version '{version}' not found" + ) + else: + # Get latest version (sort by created_date) + filtered = filtered.sort_values("created_date", ascending=False) + + # Get the first (or only) matching dataset + dataset_row = filtered.iloc[0] + + # Get linked FOV record IDs + fov_record_ids = dataset_row.get("fovs", []) + if not fov_record_ids or len(fov_record_ids) == 0: + return [] + + # Fetch FOV paths + fov_paths = [] + for fov_id in fov_record_ids: + fov_record = self.fovs_table.get(fov_id) + fov_paths.append(fov_record["fields"]["fov_path"]) + + return fov_paths + + def get_dataset( + self, dataset_name: str, version: str | None = None + ) -> dict[str, Any]: + """ + Get full dataset information including FOV details. 
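+
+        When version is omitted, the matching records are sorted by
+        created_date and the newest one is returned.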
+ + Parameters + ---------- + dataset_name : str + Dataset name + version : str | None + Specific version + + Returns + ------- + dict + Dataset info with FOV paths and metadata + """ + # Get all datasets as DataFrame + df_datasets = self.list_datasets() + + if len(df_datasets) == 0: + raise ValueError(f"Dataset '{dataset_name}' not found") + + # Filter by name + filtered = df_datasets[df_datasets["name"] == dataset_name] + + if len(filtered) == 0: + raise ValueError(f"Dataset '{dataset_name}' not found") + + # Filter by version if specified, otherwise get latest + if version: + filtered = filtered[filtered["version"] == version] + if len(filtered) == 0: + raise ValueError( + f"Dataset '{dataset_name}' version '{version}' not found" + ) + else: + # Get latest version (sort by created_date) + filtered = filtered.sort_values("created_date", ascending=False) + + # Get the first (or only) matching dataset + dataset_row = filtered.iloc[0] + dataset = dataset_row.to_dict() + + # Add FOV paths + dataset["fov_paths"] = self.get_dataset_fov_paths(dataset_name, version) + + return dataset + + def list_datasets( + self, purpose: str | None = None, as_dataframe: bool = True + ) -> pd.DataFrame | list[dict]: + """ + List all datasets. + + Parameters + ---------- + purpose : str | None + Filter by purpose ("training", "validation", "test") + as_dataframe : bool + If True, return pandas DataFrame. If False, return list of dicts. + + Returns + ------- + pd.DataFrame | list[dict] + Dataset records as DataFrame or list of dicts + + Examples + -------- + >>> registry.list_datasets(purpose="training") + >>> # Returns DataFrame with columns: id, name, version, purpose, ... + """ + # Fetch all datasets (sorted by most recent first) + records = self.datasets_table.all(sort=["-created_date"]) + data = [{"id": r["id"], **r["fields"]} for r in records] + + # Convert to DataFrame or list + if as_dataframe: + df = pd.DataFrame(data) + # Filter by purpose if specified + if purpose and len(df) > 0: + df = df[df["purpose"] == purpose] + return df + else: + # Filter list if purpose specified + if purpose: + data = [d for d in data if d.get("purpose") == purpose] + return data + + def list_fovs(self, as_dataframe: bool = True) -> pd.DataFrame | list[dict]: + """ + Get all FOVs as a DataFrame (or list of dicts). + + Use pandas for filtering - much simpler and more powerful than + building Airtable formulas. + + Parameters + ---------- + as_dataframe : bool + If True, return pandas DataFrame. If False, return list of dicts. + + Returns + ------- + pd.DataFrame | list[dict] + All FOV records + + Examples + -------- + >>> # Get all FOVs + >>> df = registry.list_fovs() + >>> + >>> # Filter with pandas (simple and powerful!) 
+ >>> filtered = df[df['plate_name'] == 'RPE1_plate1'] + >>> filtered = df[df['quality'] == 'Good'] + >>> filtered = df[df['row'] == 'B'] + >>> filtered = df[df['row'].isin(['B', 'C'])] + >>> filtered = df[(df['row'] == 'B') & (df['column'] == '3')] + >>> + >>> # Exclude FOVs + >>> filtered = df[~df['fov_id'].isin(['RPE1_plate1_B_3_2'])] + >>> + >>> # Group and analyze + >>> df.groupby('plate_name').size() + >>> df.groupby(['row', 'column']).size() + """ + records = self.fovs_table.all() + data = [{"id": r["id"], **r["fields"]} for r in records] + + if as_dataframe: + return pd.DataFrame(data) + return data From 4d1f91266b9651a1306050af5af064ce46aa91d9 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 6 Jan 2026 11:53:28 -0800 Subject: [PATCH 04/18] add the dataset callback --- viscy/airtable/airtable_callback.py | 2 +- viscy/airtable/airtable_dataset_registry.py | 197 ++++++++++++++++++++ 2 files changed, 198 insertions(+), 1 deletion(-) create mode 100644 viscy/airtable/airtable_dataset_registry.py diff --git a/viscy/airtable/airtable_callback.py b/viscy/airtable/airtable_callback.py index a39f4a45f..4dbf93ab9 100644 --- a/viscy/airtable/airtable_callback.py +++ b/viscy/airtable/airtable_callback.py @@ -6,7 +6,7 @@ from lightning.pytorch import Trainer from lightning.pytorch.callbacks import Callback -from viscy.representation.airtable_registry import AirtableDatasetRegistry +from viscy.airtable.airtable_dataset_registry import AirtableDatasetRegistry class AirtableLoggingCallback(Callback): diff --git a/viscy/airtable/airtable_dataset_registry.py b/viscy/airtable/airtable_dataset_registry.py new file mode 100644 index 000000000..535ab13e4 --- /dev/null +++ b/viscy/airtable/airtable_dataset_registry.py @@ -0,0 +1,197 @@ +"""Dataset registry integration with Airtable for experiment tracking.""" + +import os +from datetime import datetime +from typing import Any + +from pyairtable import Api + + +class AirtableDatasetRegistry: + """ + Interface to Airtable for dataset registry. + + Airtable acts as source of truth for: + - Dataset paths on HPC + - Dataset versions and metadata + - Links between datasets and trained models + + Parameters + ---------- + base_id : str + Airtable base ID + api_key : str | None + Airtable API key. If None, reads from AIRTABLE_API_KEY env var. + + Examples + -------- + >>> registry = AirtableDatasetRegistry(base_id="appXXXXXXXXXXXXXX") + >>> + >>> # Get dataset info + >>> dataset = registry.get_dataset("rpe1_fucci_embeddings", version="v2") + >>> print(dataset['hpc_path']) + >>> + >>> # Record that a model was trained with this dataset + >>> registry.log_model_training( + ... dataset_id=dataset['id'], + ... mlflow_run_id="run_123", + ... metrics={"accuracy": 0.89} + ... ) + """ + + def __init__( + self, + base_id: str, + api_key: str | None = None, + ): + api_key = api_key or os.getenv("AIRTABLE_API_KEY") + if not api_key: + raise ValueError("Airtable API key required (set AIRTABLE_API_KEY)") + + self.api = Api(api_key) + self.base_id = base_id + self.datasets_table = self.api.table(base_id, "Datasets") + self.models_table = self.api.table(base_id, "Models") + + def get_dataset(self, name: str, version: str | None = None) -> dict[str, Any]: + """ + Retrieve dataset record from Airtable. + + Parameters + ---------- + name : str + Dataset name + version : str | None + Specific version (e.g., "v2"). If None, returns latest. 
+ + Returns + ------- + dict + Airtable record with fields: + - id: Airtable record ID + - hpc_path: Path to dataset on HPC + - version: Dataset version + - sha256: Dataset hash + - created_date: Creation timestamp + """ + if version: + formula = f"AND({{name}}='{name}', {{version}}='{version}')" + else: + formula = f"{{name}}='{name}'" + + records = self.datasets_table.all(formula=formula, sort=["-created_date"]) + + if not records: + raise ValueError( + f"Dataset '{name}' (version={version}) not found in Airtable" + ) + + record = records[0] + return {"id": record["id"], **record["fields"]} + + def log_model_training( + self, + dataset_id: str, + mlflow_run_id: str, + model_name: str | None = None, + metrics: dict[str, float] | None = None, + checkpoint_path: str | None = None, + trained_by: str | None = None, + ) -> str: + """ + Log that a model was trained using a dataset. + + Creates entry in Models table and updates Datasets table. + + Parameters + ---------- + dataset_id : str + Airtable record ID of dataset used + mlflow_run_id : str + MLflow run ID for experiment tracking + model_name : str | None + Human-readable model name + metrics : dict | None + Training metrics (accuracy, f1_score, etc.) + checkpoint_path : str | None + Path to saved model checkpoint + trained_by : str | None + Username of person who trained the model + + Returns + ------- + str + Airtable record ID of created model entry + """ + # Create model record + model_record = { + "model_name": model_name or f"model_{datetime.now():%Y%m%d_%H%M%S}", + "dataset": [dataset_id], # Link to dataset + "mlflow_run_id": mlflow_run_id, + "trained_date": datetime.now().isoformat(), + } + + if metrics: + model_record.update(metrics) + + if checkpoint_path: + model_record["checkpoint_path"] = checkpoint_path + + if trained_by: + model_record["trained_by"] = trained_by + + created = self.models_table.create(model_record) + + # Update dataset record to track usage + dataset = self.datasets_table.get(dataset_id) + models_trained_str = dataset["fields"].get("models_trained", "") + + # Handle models_trained as comma-separated string + if models_trained_str: + models_list = [m.strip() for m in models_trained_str.split(",")] + models_list.append(mlflow_run_id) + new_models_str = ", ".join(models_list) + else: + new_models_str = mlflow_run_id + + self.datasets_table.update( + dataset_id, + {"models_trained": new_models_str, "last_used": datetime.now().isoformat()}, + ) + + return created["id"] + + def list_datasets(self, formula: str | None = None) -> list[dict]: + """ + List all datasets in registry. + + Parameters + ---------- + formula : str | None + Optional Airtable formula for filtering + + Returns + ------- + list[dict] + List of dataset records + """ + records = self.datasets_table.all(formula=formula, sort=["-created_date"]) + return [{"id": r["id"], **r["fields"]} for r in records] + + def get_models_for_dataset(self, dataset_id: str) -> list[dict]: + """ + Get all models trained on a specific dataset. 
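+
+        Matching is done with an Airtable formula that joins the linked
+        dataset record IDs and substring-searches for dataset_id, e.g.
+        FIND('recXXXXXXXXXXXXXX', ARRAYJOIN({dataset})) for a placeholder
+        record ID.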
+
+        Parameters
+        ----------
+        dataset_id : str
+            Airtable record ID of dataset
+
+        Returns
+        -------
+        list[dict]
+            List of model records
+        """
+        formula = f"FIND('{dataset_id}', ARRAYJOIN({{dataset}}))"
+        records = self.models_table.all(formula=formula, sort=["-trained_date"])
+        return [{"id": r["id"], **r["fields"]} for r in records]

From 93de3d5776dbad2393c87183536062248fff38b8 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Tue, 6 Jan 2026 13:32:03 -0800
Subject: [PATCH 05/18] edit test connection to match new airtable base

---
 examples/airtable/test_airtable_connection.py | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/examples/airtable/test_airtable_connection.py b/examples/airtable/test_airtable_connection.py
index c3ee19761..f5797cc65 100644
--- a/examples/airtable/test_airtable_connection.py
+++ b/examples/airtable/test_airtable_connection.py
@@ -43,8 +43,8 @@
 # Test Datasets table
 print("\n3. Testing Datasets Table...")
 try:
-    datasets_table = api.table(base_id, "Datasets")
-    records = datasets_table.all()
+    models_table = api.table(base_id, "Models")
+    records = models_table.all()
-    print("   ✓ Connected to Datasets table")
+    print("   ✓ Connected to Models table")
     print(f"   ✓ Found {len(records)} record(s)")
 
@@ -89,19 +89,16 @@
 # Test creating a dummy record
 print("\n5. Testing Write Permissions...")
 try:
-    test_record = datasets_table.create(
+    test_record = models_table.create(
         {
-            "name": "connection_test",
-            "version": "v0",
-            "hpc_path": "/tmp/test",
-            "sha256": "test_hash",
-            "created_date": "2024-12-19T00:00:00",
+            "model_name": "connection_test",
+            "model_family": "DynaCLR",
         }
     )
     print(f"   ✓ Successfully created test record (ID: {test_record['id']})")
 
     # Clean up
-    datasets_table.delete(test_record["id"])
+    models_table.delete(test_record["id"])
     print("   ✓ Successfully deleted test record")
 
 except Exception as e:

From 6a75bb3c4350bfe48dc29b7fb4d4785369255bd6 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Tue, 6 Jan 2026 13:49:14 -0800
Subject: [PATCH 06/18] rename the files more intuitive

---
 .../{airtable_callback.py => callbacks.py}    |   0
 viscy/airtable/datasets.py                    | 496 ++++++++++++++++++
 ...table_dataset_registry.py => manifests.py} |  12 +-
 3 files changed, 499 insertions(+), 9 deletions(-)
 rename viscy/airtable/{airtable_callback.py => callbacks.py} (100%)
 create mode 100644 viscy/airtable/datasets.py
 rename viscy/airtable/{airtable_dataset_registry.py => manifests.py} (92%)

diff --git a/viscy/airtable/airtable_callback.py b/viscy/airtable/callbacks.py
similarity index 100%
rename from viscy/airtable/airtable_callback.py
rename to viscy/airtable/callbacks.py
diff --git a/viscy/airtable/datasets.py b/viscy/airtable/datasets.py
new file mode 100644
index 000000000..b7b294009
--- /dev/null
+++ b/viscy/airtable/datasets.py
@@ -0,0 +1,496 @@
+"""FOV-level dataset registry with Airtable."""
+
+import getpass
+import json
+import os
+from datetime import datetime
+from typing import Any
+
+import pandas as pd
+from pyairtable import Api
+
+# TODO: update the usage examples in the docstrings
+
+
+class AirtableDatasets:
+    """
+    Interface to Airtable for FOV-level dataset management.
+
+    Use this to:
+    - Register individual FOVs from HCS plates
+    - Create dataset "tags" (collections of FOVs)
+    - Query which FOVs are in each dataset
+    - Generate training configs from dataset tags
+
+    Parameters
+    ----------
+    base_id : str
+        Airtable base ID
+    api_key : str | None
+        Airtable API key. If None, reads from AIRTABLE_API_KEY env var.
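+
+    Notes
+    -----
+    The base is expected to contain three tables named FOVs, Datasets
+    and Models. A quick connectivity check, with a placeholder base ID,
+    might look like:
+
+    >>> registry = AirtableDatasets(base_id="appXXXXXXXXXXXXXX")
+    >>> registry.fovs_table.all(max_records=1)  # doctest: +SKIP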
+ + Examples + -------- + >>> registry = AirtableDatasets(base_id="appXXXXXXXXXXXXXX") + >>> + >>> # Create dataset from FOV selection + >>> registry.create_manifest_from_datasets( + ... dataset_name="RPE1_infection_v2", + ... fov_ids=["FOV_001", "FOV_002", "FOV_004"], + ... version="v2", + ... purpose="training" + ... ) + >>> + >>> # Get all FOV paths for a dataset + >>> fov_paths = registry.get_dataset_fov_paths("RPE1_infection_v2") + >>> print(fov_paths) + >>> # ['/hpc/data/rpe1.zarr/B/3/0', '/hpc/data/rpe1.zarr/B/3/1', ...] + """ + + def __init__( + self, + base_id: str, + api_key: str | None = None, + ): + api_key = api_key or os.getenv("AIRTABLE_API_KEY") + if not api_key: + raise ValueError("Airtable API key required (set AIRTABLE_API_KEY)") + + self.api = Api(api_key) + self.base_id = base_id + self.fovs_table = self.api.table(base_id, "FOVs") + self.datasets_table = self.api.table(base_id, "Datasets") + self.models_table = self.api.table(base_id, "Models") + + def register_fov( + self, + fov_id: str, + plate_name: str, + well_id: str, + row: str, + column: str, + fov_name: str, + fov_path: str, + quality: str = "Good", + metadata: dict[str, Any] | None = None, + ) -> str: + """ + Register a single FOV in Airtable. + + Parameters + ---------- + fov_id : str + Human-readable identifier (e.g., "RPE1_plate1_B_3_0") + plate_name : str + Name of the plate this FOV belongs to + well_id : str + Well identifier as row_column (e.g., "B_3") + row : str + Well row (e.g., "B") + column : str + Well column (e.g., "3") + fov_name : str + FOV index within well (e.g., "0", "1", "2") + fov_path : str + Full path to FOV (e.g., "/hpc/data/plate.zarr/B/3/0") + quality : str + Quality assessment ("Good", "Poor", "Contaminated", etc.) + metadata : dict | None + Additional metadata (cell_count, timestamp, etc.) + + Returns + ------- + str + Airtable record ID + """ + record = { + "fov_id": fov_id, + "plate_name": plate_name, + "well_id": well_id, + "row": row, + "column": column, + "fov_name": fov_name, + "fov_path": fov_path, + "quality": quality, + } + + if metadata: + # Store as JSON string in notes field + record["notes"] = json.dumps(metadata) + + created = self.fovs_table.create(record) + return created["id"] + + def create_manifest_from_datasets( + self, + dataset_name: str, + fov_ids: list[str], + version: str, + purpose: str = "training", + description: str | None = None, + ) -> str: + """ + Create a dataset (tag) from a list of FOV IDs. + + Parameters + ---------- + dataset_name : str + Name for this dataset collection + fov_ids : list[str] + List of FOV IDs to include (e.g., ["FOV_001", "FOV_002"]) + version : str + Semantic version (e.g., "0.0.1", "0.1.0", "1.0.0") + REQUIRED - forces conscious versioning + purpose : str + Purpose of this dataset ("training", "validation", "test") + description : str | None + Human-readable description + + Returns + ------- + str + Airtable dataset record ID + + Examples + -------- + >>> registry.create_manifest_from_datasets( + ... dataset_name="RPE1_clean_wells", + ... fov_ids=["FOV_001", "FOV_002", "FOV_004"], + ... version="0.0.1", + ... description="High-quality FOVs from wells B3-B4" + ... 
) + """ + # Validate semantic version format + import re + + if not re.match(r"^\d+\.\d+\.\d+$", version): + raise ValueError( + f"Version must be semantic version format (e.g., '0.0.1', '1.0.0'), got: '{version}'" + ) + + # Check if dataset with same name + version exists (use DataFrame) + df_datasets = self.list_datasets() + + if len(df_datasets) > 0: + existing = df_datasets[ + (df_datasets["name"] == dataset_name) + & (df_datasets["version"] == version) + ] + + if len(existing) > 0: + raise ValueError( + f"Dataset '{dataset_name}' version '{version}' already exists. " + f"To create a new version, increment the version number (e.g., '0.0.2')." + ) + + # Show existing versions (helpful feedback) + existing_versions = df_datasets[df_datasets["name"] == dataset_name] + if len(existing_versions) > 0: + versions = sorted(existing_versions["version"].tolist()) + print(f"ℹ Dataset '{dataset_name}' existing versions: {versions}") + print(f" Creating new version: '{version}'") + + # Get Airtable record IDs for these FOV IDs (ensure unique) + fov_record_ids = [] + seen_fov_ids = set() + + for fov_id in fov_ids: + if fov_id in seen_fov_ids: + continue # Skip duplicates + + formula = f"{{fov_id}}='{fov_id}'" + records = self.fovs_table.all(formula=formula) + if records: + fov_record_ids.append(records[0]["id"]) + seen_fov_ids.add(fov_id) + else: + raise ValueError(f"FOV '{fov_id}' not found in FOVs table") + + # Remove any duplicate record IDs (shouldn't happen, but extra safety) + fov_record_ids = list(dict.fromkeys(fov_record_ids)) + + # Create dataset record + dataset_record = { + "name": dataset_name, + "fovs": fov_record_ids, # Linked records (unique) + "version": version, # Semantic version (required) + "purpose": purpose, + "created_date": datetime.now().isoformat(), + "created_by": getpass.getuser(), + "num_fovs": len(fov_record_ids), + } + + if description: + dataset_record["description"] = description + + created = self.datasets_table.create(dataset_record) + return created["id"] + + def create_dataset_from_query( + self, + dataset_name: str, + version: str, + plate_name: str | None = None, + well_ids: list[str] | None = None, + quality: str | None = None, + exclude_fov_ids: list[str] | None = None, + **kwargs, + ) -> str: + """ + Create a dataset by filtering FOVs with pandas. + + Parameters + ---------- + dataset_name : str + Name for this dataset + version : str + Semantic version (e.g., "0.0.1") - REQUIRED + plate_name : str | None + Filter by plate name + well_ids : list[str] | None + Filter by well identifiers (e.g., ["B_3", "B_4"]) + quality : str | None + Filter by quality ("Good", "Poor", etc.) + exclude_fov_ids : list[str] | None + FOV IDs to exclude + **kwargs + Additional arguments for create_manifest_from_datasets + + Returns + ------- + str + Airtable dataset record ID + + Examples + -------- + >>> # Create dataset from all good-quality FOVs in specific wells + >>> registry.create_dataset_from_query( + ... dataset_name="RPE1_infection_training", + ... version="0.0.1", + ... plate_name="RPE1_plate1", + ... well_ids=["B_3", "B_4"], + ... quality="Good", + ... exclude_fov_ids=["RPE1_plate1_B_3_2"] + ... 
) + """ + # Get all FOVs as DataFrame + df = self.list_fovs() + + # Apply filters with pandas + if plate_name: + df = df[df["plate_name"] == plate_name] + + if quality: + df = df[df["quality"] == quality] + + if well_ids: + df = df[df["well_id"].isin(well_ids)] + + # Exclude specified FOVs + if exclude_fov_ids: + df = df[~df["fov_id"].isin(exclude_fov_ids)] + + fov_ids = df["fov_id"].tolist() + + print(f"Found {len(fov_ids)} FOVs matching criteria") + + # Create dataset + return self.create_manifest_from_datasets( + dataset_name=dataset_name, version=version, fov_ids=fov_ids, **kwargs + ) + + def get_dataset_fov_paths( + self, dataset_name: str, version: str | None = None + ) -> list[str]: + """ + Get list of FOV paths for a dataset. + + Parameters + ---------- + dataset_name : str + Dataset name + version : str | None + Specific version (if None, returns latest) + + Returns + ------- + list[str] + List of FOV paths + + Examples + -------- + >>> paths = registry.get_dataset_fov_paths("RPE1_infection_v2") + >>> print(paths) + >>> # ['/hpc/data/rpe1.zarr/B/3/0', '/hpc/data/rpe1.zarr/B/3/1', ...] + """ + # Get all datasets as DataFrame + df_datasets = self.list_datasets() + + if len(df_datasets) == 0: + raise ValueError(f"Dataset '{dataset_name}' not found") + + # Filter by name + filtered = df_datasets[df_datasets["name"] == dataset_name] + + if len(filtered) == 0: + raise ValueError(f"Dataset '{dataset_name}' not found") + + # Filter by version if specified, otherwise get latest + if version: + filtered = filtered[filtered["version"] == version] + if len(filtered) == 0: + raise ValueError( + f"Dataset '{dataset_name}' version '{version}' not found" + ) + else: + # Get latest version (sort by created_date) + filtered = filtered.sort_values("created_date", ascending=False) + + # Get the first (or only) matching dataset + dataset_row = filtered.iloc[0] + + # Get linked FOV record IDs + fov_record_ids = dataset_row.get("fovs", []) + if not fov_record_ids or len(fov_record_ids) == 0: + return [] + + # Fetch FOV paths + fov_paths = [] + for fov_id in fov_record_ids: + fov_record = self.fovs_table.get(fov_id) + fov_paths.append(fov_record["fields"]["fov_path"]) + + return fov_paths + + def get_dataset( + self, dataset_name: str, version: str | None = None + ) -> dict[str, Any]: + """ + Get full dataset information including FOV details. 
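+
+        When version is omitted, the matching records are sorted by
+        created_date and the newest one is returned.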
+ + Parameters + ---------- + dataset_name : str + Dataset name + version : str | None + Specific version + + Returns + ------- + dict + Dataset info with FOV paths and metadata + """ + # Get all datasets as DataFrame + df_datasets = self.list_datasets() + + if len(df_datasets) == 0: + raise ValueError(f"Dataset '{dataset_name}' not found") + + # Filter by name + filtered = df_datasets[df_datasets["name"] == dataset_name] + + if len(filtered) == 0: + raise ValueError(f"Dataset '{dataset_name}' not found") + + # Filter by version if specified, otherwise get latest + if version: + filtered = filtered[filtered["version"] == version] + if len(filtered) == 0: + raise ValueError( + f"Dataset '{dataset_name}' version '{version}' not found" + ) + else: + # Get latest version (sort by created_date) + filtered = filtered.sort_values("created_date", ascending=False) + + # Get the first (or only) matching dataset + dataset_row = filtered.iloc[0] + dataset = dataset_row.to_dict() + + # Add FOV paths + dataset["fov_paths"] = self.get_dataset_fov_paths(dataset_name, version) + + return dataset + + def list_datasets( + self, purpose: str | None = None, as_dataframe: bool = True + ) -> pd.DataFrame | list[dict]: + """ + List all datasets. + + Parameters + ---------- + purpose : str | None + Filter by purpose ("training", "validation", "test") + as_dataframe : bool + If True, return pandas DataFrame. If False, return list of dicts. + + Returns + ------- + pd.DataFrame | list[dict] + Dataset records as DataFrame or list of dicts + + Examples + -------- + >>> registry.list_datasets(purpose="training") + >>> # Returns DataFrame with columns: id, name, version, purpose, ... + """ + # Fetch all datasets (sorted by most recent first) + records = self.datasets_table.all(sort=["-created_date"]) + data = [{"id": r["id"], **r["fields"]} for r in records] + + # Convert to DataFrame or list + if as_dataframe: + df = pd.DataFrame(data) + # Filter by purpose if specified + if purpose and len(df) > 0: + df = df[df["purpose"] == purpose] + return df + else: + # Filter list if purpose specified + if purpose: + data = [d for d in data if d.get("purpose") == purpose] + return data + + def list_fovs(self, as_dataframe: bool = True) -> pd.DataFrame | list[dict]: + """ + Get all FOVs as a DataFrame (or list of dicts). + + Use pandas for filtering - much simpler and more powerful than + building Airtable formulas. + + Parameters + ---------- + as_dataframe : bool + If True, return pandas DataFrame. If False, return list of dicts. + + Returns + ------- + pd.DataFrame | list[dict] + All FOV records + + Examples + -------- + >>> # Get all FOVs + >>> df = registry.list_fovs() + >>> + >>> # Filter with pandas (simple and powerful!) 
+ >>> filtered = df[df['plate_name'] == 'RPE1_plate1'] + >>> filtered = df[df['quality'] == 'Good'] + >>> filtered = df[df['row'] == 'B'] + >>> filtered = df[df['row'].isin(['B', 'C'])] + >>> filtered = df[(df['row'] == 'B') & (df['column'] == '3')] + >>> + >>> # Exclude FOVs + >>> filtered = df[~df['fov_id'].isin(['RPE1_plate1_B_3_2'])] + >>> + >>> # Group and analyze + >>> df.groupby('plate_name').size() + >>> df.groupby(['row', 'column']).size() + """ + records = self.fovs_table.all() + data = [{"id": r["id"], **r["fields"]} for r in records] + + if as_dataframe: + return pd.DataFrame(data) + return data diff --git a/viscy/airtable/airtable_dataset_registry.py b/viscy/airtable/manifests.py similarity index 92% rename from viscy/airtable/airtable_dataset_registry.py rename to viscy/airtable/manifests.py index 535ab13e4..141863e57 100644 --- a/viscy/airtable/airtable_dataset_registry.py +++ b/viscy/airtable/manifests.py @@ -28,7 +28,7 @@ class AirtableDatasetRegistry: >>> registry = AirtableDatasetRegistry(base_id="appXXXXXXXXXXXXXX") >>> >>> # Get dataset info - >>> dataset = registry.get_dataset("rpe1_fucci_embeddings", version="v2") + >>> dataset = registry.get_manifest("rpe1_fucci_embeddings", version="v2") >>> print(dataset['hpc_path']) >>> >>> # Record that a model was trained with this dataset @@ -53,7 +53,7 @@ def __init__( self.datasets_table = self.api.table(base_id, "Datasets") self.models_table = self.api.table(base_id, "Models") - def get_dataset(self, name: str, version: str | None = None) -> dict[str, Any]: + def get_manifest(self, name: str, version: str | None = None) -> dict[str, Any]: """ Retrieve dataset record from Airtable. @@ -66,13 +66,7 @@ def get_dataset(self, name: str, version: str | None = None) -> dict[str, Any]: Returns ------- - dict - Airtable record with fields: - - id: Airtable record ID - - hpc_path: Path to dataset on HPC - - version: Dataset version - - sha256: Dataset hash - - created_date: Creation timestamp + TODO: typing for the headers in manifest """ if version: formula = f"AND({{name}}='{name}', {{version}}='{version}')" From 886352356620819dd6bb4ca350e6f447e92b0f90 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 6 Jan 2026 13:49:31 -0800 Subject: [PATCH 07/18] remove the fovregistry --- viscy/airtable/airtable_fov_registry.py | 494 ------------------------ 1 file changed, 494 deletions(-) delete mode 100644 viscy/airtable/airtable_fov_registry.py diff --git a/viscy/airtable/airtable_fov_registry.py b/viscy/airtable/airtable_fov_registry.py deleted file mode 100644 index f4bdbf59d..000000000 --- a/viscy/airtable/airtable_fov_registry.py +++ /dev/null @@ -1,494 +0,0 @@ -"""FOV-level dataset registry with Airtable.""" - -import getpass -import json -import os -from datetime import datetime -from typing import Any - -import pandas as pd -from pyairtable import Api - - -class AirtableFOVRegistry: - """ - Interface to Airtable for FOV-level dataset management. - - Use this to: - - Register individual FOVs from HCS plates - - Create dataset "tags" (collections of FOVs) - - Query which FOVs are in each dataset - - Generate training configs from dataset tags - - Parameters - ---------- - base_id : str - Airtable base ID - api_key : str | None - Airtable API key. If None, reads from AIRTABLE_API_KEY env var. - - Examples - -------- - >>> registry = AirtableFOVRegistry(base_id="appXXXXXXXXXXXXXX") - >>> - >>> # Create dataset from FOV selection - >>> registry.create_dataset_from_fovs( - ... dataset_name="RPE1_infection_v2", - ... 
fov_ids=["FOV_001", "FOV_002", "FOV_004"], - ... version="v2", - ... purpose="training" - ... ) - >>> - >>> # Get all FOV paths for a dataset - >>> fov_paths = registry.get_dataset_fov_paths("RPE1_infection_v2") - >>> print(fov_paths) - >>> # ['/hpc/data/rpe1.zarr/B/3/0', '/hpc/data/rpe1.zarr/B/3/1', ...] - """ - - def __init__( - self, - base_id: str, - api_key: str | None = None, - ): - api_key = api_key or os.getenv("AIRTABLE_API_KEY") - if not api_key: - raise ValueError("Airtable API key required (set AIRTABLE_API_KEY)") - - self.api = Api(api_key) - self.base_id = base_id - self.fovs_table = self.api.table(base_id, "FOVs") - self.datasets_table = self.api.table(base_id, "Datasets") - self.models_table = self.api.table(base_id, "Models") - - def register_fov( - self, - fov_id: str, - plate_name: str, - well_id: str, - row: str, - column: str, - fov_name: str, - fov_path: str, - quality: str = "Good", - metadata: dict[str, Any] | None = None, - ) -> str: - """ - Register a single FOV in Airtable. - - Parameters - ---------- - fov_id : str - Human-readable identifier (e.g., "RPE1_plate1_B_3_0") - plate_name : str - Name of the plate this FOV belongs to - well_id : str - Well identifier as row_column (e.g., "B_3") - row : str - Well row (e.g., "B") - column : str - Well column (e.g., "3") - fov_name : str - FOV index within well (e.g., "0", "1", "2") - fov_path : str - Full path to FOV (e.g., "/hpc/data/plate.zarr/B/3/0") - quality : str - Quality assessment ("Good", "Poor", "Contaminated", etc.) - metadata : dict | None - Additional metadata (cell_count, timestamp, etc.) - - Returns - ------- - str - Airtable record ID - """ - record = { - "fov_id": fov_id, - "plate_name": plate_name, - "well_id": well_id, - "row": row, - "column": column, - "fov_name": fov_name, - "fov_path": fov_path, - "quality": quality, - } - - if metadata: - # Store as JSON string in notes field - record["notes"] = json.dumps(metadata) - - created = self.fovs_table.create(record) - return created["id"] - - def create_dataset_from_fovs( - self, - dataset_name: str, - fov_ids: list[str], - version: str, - purpose: str = "training", - description: str | None = None, - ) -> str: - """ - Create a dataset (tag) from a list of FOV IDs. - - Parameters - ---------- - dataset_name : str - Name for this dataset collection - fov_ids : list[str] - List of FOV IDs to include (e.g., ["FOV_001", "FOV_002"]) - version : str - Semantic version (e.g., "0.0.1", "0.1.0", "1.0.0") - REQUIRED - forces conscious versioning - purpose : str - Purpose of this dataset ("training", "validation", "test") - description : str | None - Human-readable description - - Returns - ------- - str - Airtable dataset record ID - - Examples - -------- - >>> registry.create_dataset_from_fovs( - ... dataset_name="RPE1_clean_wells", - ... fov_ids=["FOV_001", "FOV_002", "FOV_004"], - ... version="0.0.1", - ... description="High-quality FOVs from wells B3-B4" - ... ) - """ - # Validate semantic version format - import re - - if not re.match(r"^\d+\.\d+\.\d+$", version): - raise ValueError( - f"Version must be semantic version format (e.g., '0.0.1', '1.0.0'), got: '{version}'" - ) - - # Check if dataset with same name + version exists (use DataFrame) - df_datasets = self.list_datasets() - - if len(df_datasets) > 0: - existing = df_datasets[ - (df_datasets["name"] == dataset_name) - & (df_datasets["version"] == version) - ] - - if len(existing) > 0: - raise ValueError( - f"Dataset '{dataset_name}' version '{version}' already exists. 
" - f"To create a new version, increment the version number (e.g., '0.0.2')." - ) - - # Show existing versions (helpful feedback) - existing_versions = df_datasets[df_datasets["name"] == dataset_name] - if len(existing_versions) > 0: - versions = sorted(existing_versions["version"].tolist()) - print(f"ℹ Dataset '{dataset_name}' existing versions: {versions}") - print(f" Creating new version: '{version}'") - - # Get Airtable record IDs for these FOV IDs (ensure unique) - fov_record_ids = [] - seen_fov_ids = set() - - for fov_id in fov_ids: - if fov_id in seen_fov_ids: - continue # Skip duplicates - - formula = f"{{fov_id}}='{fov_id}'" - records = self.fovs_table.all(formula=formula) - if records: - fov_record_ids.append(records[0]["id"]) - seen_fov_ids.add(fov_id) - else: - raise ValueError(f"FOV '{fov_id}' not found in FOVs table") - - # Remove any duplicate record IDs (shouldn't happen, but extra safety) - fov_record_ids = list(dict.fromkeys(fov_record_ids)) - - # Create dataset record - dataset_record = { - "name": dataset_name, - "fovs": fov_record_ids, # Linked records (unique) - "version": version, # Semantic version (required) - "purpose": purpose, - "created_date": datetime.now().isoformat(), - "created_by": getpass.getuser(), - "num_fovs": len(fov_record_ids), - } - - if description: - dataset_record["description"] = description - - created = self.datasets_table.create(dataset_record) - return created["id"] - - def create_dataset_from_query( - self, - dataset_name: str, - version: str, - plate_name: str | None = None, - well_ids: list[str] | None = None, - quality: str | None = None, - exclude_fov_ids: list[str] | None = None, - **kwargs, - ) -> str: - """ - Create a dataset by filtering FOVs with pandas. - - Parameters - ---------- - dataset_name : str - Name for this dataset - version : str - Semantic version (e.g., "0.0.1") - REQUIRED - plate_name : str | None - Filter by plate name - well_ids : list[str] | None - Filter by well identifiers (e.g., ["B_3", "B_4"]) - quality : str | None - Filter by quality ("Good", "Poor", etc.) - exclude_fov_ids : list[str] | None - FOV IDs to exclude - **kwargs - Additional arguments for create_dataset_from_fovs - - Returns - ------- - str - Airtable dataset record ID - - Examples - -------- - >>> # Create dataset from all good-quality FOVs in specific wells - >>> registry.create_dataset_from_query( - ... dataset_name="RPE1_infection_training", - ... version="0.0.1", - ... plate_name="RPE1_plate1", - ... well_ids=["B_3", "B_4"], - ... quality="Good", - ... exclude_fov_ids=["RPE1_plate1_B_3_2"] - ... ) - """ - # Get all FOVs as DataFrame - df = self.list_fovs() - - # Apply filters with pandas - if plate_name: - df = df[df["plate_name"] == plate_name] - - if quality: - df = df[df["quality"] == quality] - - if well_ids: - df = df[df["well_id"].isin(well_ids)] - - # Exclude specified FOVs - if exclude_fov_ids: - df = df[~df["fov_id"].isin(exclude_fov_ids)] - - fov_ids = df["fov_id"].tolist() - - print(f"Found {len(fov_ids)} FOVs matching criteria") - - # Create dataset - return self.create_dataset_from_fovs( - dataset_name=dataset_name, version=version, fov_ids=fov_ids, **kwargs - ) - - def get_dataset_fov_paths( - self, dataset_name: str, version: str | None = None - ) -> list[str]: - """ - Get list of FOV paths for a dataset. 
- - Parameters - ---------- - dataset_name : str - Dataset name - version : str | None - Specific version (if None, returns latest) - - Returns - ------- - list[str] - List of FOV paths - - Examples - -------- - >>> paths = registry.get_dataset_fov_paths("RPE1_infection_v2") - >>> print(paths) - >>> # ['/hpc/data/rpe1.zarr/B/3/0', '/hpc/data/rpe1.zarr/B/3/1', ...] - """ - # Get all datasets as DataFrame - df_datasets = self.list_datasets() - - if len(df_datasets) == 0: - raise ValueError(f"Dataset '{dataset_name}' not found") - - # Filter by name - filtered = df_datasets[df_datasets["name"] == dataset_name] - - if len(filtered) == 0: - raise ValueError(f"Dataset '{dataset_name}' not found") - - # Filter by version if specified, otherwise get latest - if version: - filtered = filtered[filtered["version"] == version] - if len(filtered) == 0: - raise ValueError( - f"Dataset '{dataset_name}' version '{version}' not found" - ) - else: - # Get latest version (sort by created_date) - filtered = filtered.sort_values("created_date", ascending=False) - - # Get the first (or only) matching dataset - dataset_row = filtered.iloc[0] - - # Get linked FOV record IDs - fov_record_ids = dataset_row.get("fovs", []) - if not fov_record_ids or len(fov_record_ids) == 0: - return [] - - # Fetch FOV paths - fov_paths = [] - for fov_id in fov_record_ids: - fov_record = self.fovs_table.get(fov_id) - fov_paths.append(fov_record["fields"]["fov_path"]) - - return fov_paths - - def get_dataset( - self, dataset_name: str, version: str | None = None - ) -> dict[str, Any]: - """ - Get full dataset information including FOV details. - - Parameters - ---------- - dataset_name : str - Dataset name - version : str | None - Specific version - - Returns - ------- - dict - Dataset info with FOV paths and metadata - """ - # Get all datasets as DataFrame - df_datasets = self.list_datasets() - - if len(df_datasets) == 0: - raise ValueError(f"Dataset '{dataset_name}' not found") - - # Filter by name - filtered = df_datasets[df_datasets["name"] == dataset_name] - - if len(filtered) == 0: - raise ValueError(f"Dataset '{dataset_name}' not found") - - # Filter by version if specified, otherwise get latest - if version: - filtered = filtered[filtered["version"] == version] - if len(filtered) == 0: - raise ValueError( - f"Dataset '{dataset_name}' version '{version}' not found" - ) - else: - # Get latest version (sort by created_date) - filtered = filtered.sort_values("created_date", ascending=False) - - # Get the first (or only) matching dataset - dataset_row = filtered.iloc[0] - dataset = dataset_row.to_dict() - - # Add FOV paths - dataset["fov_paths"] = self.get_dataset_fov_paths(dataset_name, version) - - return dataset - - def list_datasets( - self, purpose: str | None = None, as_dataframe: bool = True - ) -> pd.DataFrame | list[dict]: - """ - List all datasets. - - Parameters - ---------- - purpose : str | None - Filter by purpose ("training", "validation", "test") - as_dataframe : bool - If True, return pandas DataFrame. If False, return list of dicts. - - Returns - ------- - pd.DataFrame | list[dict] - Dataset records as DataFrame or list of dicts - - Examples - -------- - >>> registry.list_datasets(purpose="training") - >>> # Returns DataFrame with columns: id, name, version, purpose, ... 
- """ - # Fetch all datasets (sorted by most recent first) - records = self.datasets_table.all(sort=["-created_date"]) - data = [{"id": r["id"], **r["fields"]} for r in records] - - # Convert to DataFrame or list - if as_dataframe: - df = pd.DataFrame(data) - # Filter by purpose if specified - if purpose and len(df) > 0: - df = df[df["purpose"] == purpose] - return df - else: - # Filter list if purpose specified - if purpose: - data = [d for d in data if d.get("purpose") == purpose] - return data - - def list_fovs(self, as_dataframe: bool = True) -> pd.DataFrame | list[dict]: - """ - Get all FOVs as a DataFrame (or list of dicts). - - Use pandas for filtering - much simpler and more powerful than - building Airtable formulas. - - Parameters - ---------- - as_dataframe : bool - If True, return pandas DataFrame. If False, return list of dicts. - - Returns - ------- - pd.DataFrame | list[dict] - All FOV records - - Examples - -------- - >>> # Get all FOVs - >>> df = registry.list_fovs() - >>> - >>> # Filter with pandas (simple and powerful!) - >>> filtered = df[df['plate_name'] == 'RPE1_plate1'] - >>> filtered = df[df['quality'] == 'Good'] - >>> filtered = df[df['row'] == 'B'] - >>> filtered = df[df['row'].isin(['B', 'C'])] - >>> filtered = df[(df['row'] == 'B') & (df['column'] == '3')] - >>> - >>> # Exclude FOVs - >>> filtered = df[~df['fov_id'].isin(['RPE1_plate1_B_3_2'])] - >>> - >>> # Group and analyze - >>> df.groupby('plate_name').size() - >>> df.groupby(['row', 'column']).size() - """ - records = self.fovs_table.all() - data = [{"id": r["id"], **r["fields"]} for r in records] - - if as_dataframe: - return pd.DataFrame(data) - return data From 37eb7798595c32d368cec06d045c8cf27da9b7df Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 6 Jan 2026 14:52:24 -0800 Subject: [PATCH 08/18] update the manifest and datasets to wrangle our new fov database --- .../airtable/filter_n_create_dataset_tag.py | 140 +++--- viscy/airtable/callbacks.py | 24 +- viscy/airtable/datasets.py | 398 ++++++++++-------- viscy/airtable/manifests.py | 63 ++- 4 files changed, 327 insertions(+), 298 deletions(-) diff --git a/examples/airtable/filter_n_create_dataset_tag.py b/examples/airtable/filter_n_create_dataset_tag.py index 31b94ebc3..840d33e64 100644 --- a/examples/airtable/filter_n_create_dataset_tag.py +++ b/examples/airtable/filter_n_create_dataset_tag.py @@ -1,143 +1,121 @@ -"""Filter FOVs using pandas and create dataset tags.""" +"""Filter datasets using pandas and create manifest tags.""" # %% -import os -from viscy.airtable.airtable_fov_registry import AirtableFOVRegistry +from viscy.airtable.datasets import AirtableDatasets -BASE_ID = os.getenv("AIRTABLE_BASE_ID") -registry = AirtableFOVRegistry(base_id=BASE_ID) +# BASE_ID = os.getenv("AIRTABLE_BASE_ID") +BASE_ID = "app8vqaoWyOwa0sB5" +registry = AirtableDatasets(base_id=BASE_ID) # %% -# EXAMPLE 1: Get all FOVs as DataFrame and explore +# EXAMPLE 1: Get all dataset records as DataFrame and explore print("=" * 70) -print("Getting all FOVs as DataFrame") +print("Getting all dataset records as DataFrame") print("=" * 70) -df_fovs = registry.list_fovs() -print(f"\nTotal FOVs: {len(df_fovs)}") +df_datasets = registry.list_datasets() +print(f"\nTotal dataset records: {len(df_datasets)}") print("\nDataFrame columns:") -print(df_fovs.columns.tolist()) +print(df_datasets.columns.tolist()) print("\nFirst few rows:") -print(df_fovs.head()) +print(df_datasets.head()) # %% -# EXAMPLE 2: Filter by plate and rows B and C using pandas +# EXAMPLE 2: Filter 
by dataset and specific wells using pandas
print("\n" + "=" * 70)
print("Filter: Dataset 2024_11_07_A549_SEC61_DENV, Wells B/1 and B/2")
print("=" * 70)

# Get all dataset records as DataFrame
df = registry.list_datasets()

# Filter with pandas - simple and powerful!
filtered = df[
    (df["Dataset"] == "2024_11_07_A549_SEC61_DENV")
    & (df["Well ID"].isin(["B/1", "B/2"]))
]

print(f"\nTotal dataset records after filtering: {len(filtered)}")
print("\nBreakdown by well:")
print(filtered.groupby("Well ID").size())

# Create manifest from filtered dataset records
fov_ids = filtered["FOV_ID"].tolist()

try:
    manifest_id = registry.create_manifest_from_datasets(
        manifest_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2",
        fov_ids=fov_ids,
        version="0.0.1",  # Semantic versioning
        purpose="training",
        description="Dataset records from wells B/1 and B/2",
    )
    print(f"\n✓ Created manifest: {manifest_id}")
    print(f"  Contains {len(fov_ids)} dataset records")
except ValueError as e:
    print(f"\n⚠ {e}")

# %%
# Delete the manifest entry demo
registry.delete_manifest(manifest_id)
print(f"Deleted manifest: {manifest_id}")

# %%
# EXAMPLE 3: Group by dataset and show summary
print("\n" + "=" * 70)
print("Group by dataset and show summary")
print("=" * 70)

df_all = registry.list_datasets()

grouped = df_all.groupby("Dataset")

for dataset_name, group in grouped:
    print(f"\n{dataset_name}:")
    print(f"  Total records: {len(group)}")
    print(f"  Wells: {sorted(group['Well ID'].unique())}")

# %%
# EXAMPLE 4: Filter by multiple wells
print("\n" + "=" * 70)
print("Filter: Multiple specific wells")
print("=" * 70)

df = registry.list_datasets()

# Filter for specific wells from a dataset
filtered = df[
    (df["Dataset"] == "2024_11_07_A549_SEC61_DENV")
    & (df["Well ID"].isin(["B/3", "B/4", "C/3", "C/4"]))
]

print(f"\nDataset records matching criteria: {len(filtered)}")
print("\nBy well:")
print(filtered.groupby("Well ID").size())

print("\nFOV IDs:")
for fov_id in filtered["FOV_ID"]:
    print(f"  {fov_id}")

# %%
# EXAMPLE 
5: Exclude specific FOVs +# EXAMPLE 5: Summary statistics print("\n" + "=" * 70) -print("Exclude specific FOVs from dataset") +print("Summary Statistics") print("=" * 70) -df = registry.list_fovs() +df = registry.list_datasets() -# Start with good quality FOVs from specific plate -filtered = df[(df["plate_name"] == "RPE1_plate1") & (df["quality"] == "Good")] +print("\nDataset records per source dataset:") +print(df.groupby("Dataset").size()) -print(f"\nBefore exclusion: {len(filtered)} FOVs") +print("\nWells with most dataset records:") +print(df.groupby("Well ID").size().sort_values(ascending=False).head(10)) -# List of FOVs to exclude (e.g., known contamination) -exclude_list = ["RPE1_plate1_B_3_2", "RPE1_plate1_C_4_1"] +print("\nTotal unique wells:") +print(f"{df['Well ID'].nunique()} wells") -# Filter out excluded FOVs -filtered = filtered[~filtered["fov_id"].isin(exclude_list)] - -print(f"Excluded: {len(exclude_list)} FOVs") -print(f"After exclusion: {len(filtered)} FOVs") +print("\nTotal unique FOV IDs:") +print(f"{df['FOV_ID'].nunique()} FOV IDs") # %% -# EXAMPLE 6: Summary statistics -print("\n" + "=" * 70) -print("Summary Statistics") -print("=" * 70) - -df = registry.list_fovs() - -print("\nFOVs per plate:") -print(df.groupby("plate_name").size()) - -print("\nFOVs per quality:") -print(df.groupby("quality").size()) - -print("\nFOVs per row (across all plates):") -print(df.groupby("row").size().sort_index()) - -print("\nWells with most FOVs:") -print(df.groupby("well_id").size().sort_values(ascending=False).head(10)) diff --git a/viscy/airtable/callbacks.py b/viscy/airtable/callbacks.py index 4dbf93ab9..1eeb411e9 100644 --- a/viscy/airtable/callbacks.py +++ b/viscy/airtable/callbacks.py @@ -6,7 +6,7 @@ from lightning.pytorch import Trainer from lightning.pytorch.callbacks import Callback -from viscy.airtable.airtable_dataset_registry import AirtableDatasetRegistry +from viscy.airtable.manifests import AirtableManifests class AirtableLoggingCallback(Callback): @@ -17,14 +17,14 @@ class AirtableLoggingCallback(Callback): - Best model checkpoint path - Who trained the model - When it was trained - - Link to the dataset used + - Link to the manifest used Parameters ---------- base_id : str Airtable base ID - dataset_id : str - Airtable dataset record ID (from config) + manifest_id : str + Airtable manifest record ID (from config) model_name : str | None Custom model name. If None, auto-generates from model class and timestamp. 
log_metrics : bool @@ -37,16 +37,16 @@ class AirtableLoggingCallback(Callback): >>> trainer: >>> callbacks: - >>> - class_path: viscy.representation.airtable_callback.AirtableLoggingCallback + >>> - class_path: viscy.airtable.callbacks.AirtableLoggingCallback >>> init_args: >>> base_id: "appXXXXXXXXXXXXXX" - >>> dataset_id: "recYYYYYYYYYYYYYY" + >>> manifest_id: "recYYYYYYYYYYYYYY" Or add programmatically: >>> callback = AirtableLoggingCallback( >>> base_id="appXXXXXXXXXXXXXX", - >>> dataset_id="recYYYYYYYYYYYYYY" + >>> manifest_id="recYYYYYYYYYYYYYY" >>> ) >>> trainer = Trainer(callbacks=[callback]) """ @@ -54,13 +54,13 @@ class AirtableLoggingCallback(Callback): def __init__( self, base_id: str, - dataset_id: str, + manifest_id: str, model_name: str | None = None, log_metrics: bool = False, ): super().__init__() - self.registry = AirtableDatasetRegistry(base_id=base_id) - self.dataset_id = dataset_id + self.registry = AirtableManifests(base_id=base_id) + self.manifest_id = manifest_id self.model_name = model_name self.log_metrics = log_metrics @@ -104,7 +104,7 @@ def on_fit_end(self, trainer: Trainer, pl_module: Any) -> None: # Log to Airtable try: model_id = self.registry.log_model_training( - dataset_id=self.dataset_id, + manifest_id=self.manifest_id, mlflow_run_id=run_id or "unknown", model_name=model_name, checkpoint_path=str(checkpoint_path) if checkpoint_path else None, @@ -114,7 +114,7 @@ def on_fit_end(self, trainer: Trainer, pl_module: Any) -> None: print(f"\n✓ Model logged to Airtable (record ID: {model_id})") print(f" Model name: {model_name}") print(f" Checkpoint: {checkpoint_path}") - print(f" Dataset ID: {self.dataset_id}") + print(f" Manifest ID: {self.manifest_id}") except Exception as e: print(f"\n✗ Failed to log to Airtable: {e}") # Don't fail training if Airtable logging fails diff --git a/viscy/airtable/datasets.py b/viscy/airtable/datasets.py index b7b294009..a76254256 100644 --- a/viscy/airtable/datasets.py +++ b/viscy/airtable/datasets.py @@ -1,15 +1,51 @@ """FOV-level dataset registry with Airtable.""" import getpass -import json import os -from datetime import datetime from typing import Any import pandas as pd from pyairtable import Api # TODO: update the usage examples in the docstrings +# TODO: update the headers to match the Airtable columns and potentially move to separate file +# TODO: Convert these to Pydantic models, so we can easily dump and load from/to Airtable + +DATASETS_INDEX = [ + "Dataset", + "Well ID", + "FOV", + "Cell type", + "Cell state", + "Cell line", + "Organelle", + "Channel-0", + "Channel-1", + "Channel-2", + "Data path", + "Fluorescence modality", + "OrganelleBox Infectomics", + "FOV_ID", +] + +MODELS_INDEX = [ + "model_name", + "model_family", + "trained_date", + "trained_by", + "tensorboard_log", + "mlflow_run_id", +] + +MANIFESTS_INDEX = [ + "name", + "version", + "datasets", + "project", + "purpose", + "created_by", + "created_time", +] class AirtableDatasets: @@ -58,45 +94,33 @@ def __init__( self.api = Api(api_key) self.base_id = base_id - self.fovs_table = self.api.table(base_id, "FOVs") self.datasets_table = self.api.table(base_id, "Datasets") + self.manifests_table = self.api.table(base_id, "Manifest") self.models_table = self.api.table(base_id, "Models") - def register_fov( + def register_dataset( self, fov_id: str, - plate_name: str, + dataset_name: str, well_id: str, - row: str, - column: str, fov_name: str, - fov_path: str, - quality: str = "Good", - metadata: dict[str, Any] | None = None, + data_path: str, ) -> str: """ - 
Register a single FOV in Airtable.
+        Register a single dataset record (FOV) in Airtable.
 
         Parameters
         ----------
         fov_id : str
             Human-readable identifier (e.g., "RPE1_plate1_B_3_0")
-        plate_name : str
-            Name of the plate this FOV belongs to
+        dataset_name : str
+            Name of the dataset/plate this FOV belongs to
         well_id : str
             Well identifier as row_column (e.g., "B_3")
-        row : str
-            Well row (e.g., "B")
-        column : str
-            Well column (e.g., "3")
         fov_name : str
             FOV index within well (e.g., "0", "1", "2")
-        fov_path : str
+        data_path : str
             Full path to FOV (e.g., "/hpc/data/plate.zarr/B/3/0")
-        quality : str
-            Quality assessment ("Good", "Poor", "Contaminated", etc.)
-        metadata : dict | None
-            Additional metadata (cell_count, timestamp, etc.)
 
         Returns
         -------
         str
             Airtable record ID
         """
         record = {
-            "fov_id": fov_id,
-            "plate_name": plate_name,
-            "well_id": well_id,
-            "row": row,
-            "column": column,
-            "fov_name": fov_name,
-            "fov_path": fov_path,
-            "quality": quality,
+            "FOV_ID": fov_id,
+            "Dataset": dataset_name,
+            "Well ID": well_id,
+            "FOV": fov_name,
+            "Data path": data_path,
         }
 
-        if metadata:
-            # Store as JSON string in notes field
-            record["notes"] = json.dumps(metadata)
-
-        created = self.fovs_table.create(record)
+        created = self.datasets_table.create(record)
         return created["id"]
 
     def create_manifest_from_datasets(
         self,
-        dataset_name: str,
+        manifest_name: str,
         fov_ids: list[str],
         version: str,
         purpose: str = "training",
+        project_name: str | None = None,
         description: str | None = None,
     ) -> str:
         """
-        Create a dataset (tag) from a list of FOV IDs.
+        Create a manifest (collection) from a list of FOV IDs.
 
         Parameters
         ----------
-        dataset_name : str
-            Name for this dataset collection
+        manifest_name : str
+            Name for this manifest
         fov_ids : list[str]
-            List of FOV IDs to include (e.g., ["FOV_001", "FOV_002"])
+            List of FOV_ID values from Datasets table (e.g., ["plate1_B_3_0", "plate1_B_3_1"])
         version : str
             Semantic version (e.g., "0.0.1", "0.1.0", "1.0.0")
             REQUIRED - forces conscious versioning
         purpose : str
-            Purpose of this dataset ("training", "validation", "test")
+            Purpose of this manifest ("training", "validation", "test")
+        project_name : str | None
+            Project name (e.g., OrganelleBox, DynaCLR)
         description : str | None
             Human-readable description
 
         Returns
         -------
         str
-            Airtable dataset record ID
+            Airtable manifest record ID
 
         Examples
         --------
         >>> registry.create_manifest_from_datasets(
-        ...     dataset_name="RPE1_clean_wells",
-        ...     fov_ids=["FOV_001", "FOV_002", "FOV_004"],
+        ...     manifest_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2",
+        ...     project_name="OrganelleBox",
+        ...     fov_ids=["2024_11_07_A549_SEC61_DENV_B1_0", "2024_11_07_A549_SEC61_DENV_B1_1"],
         ...     version="0.0.1",
+        ...     purpose="training",
-        ...     description="High-quality FOVs from wells B3-B4"
+        ...     description="High-quality dataset records from wells B1-B2"
+        ...
) """ # Validate semantic version format @@ -168,255 +191,258 @@ def create_manifest_from_datasets( f"Version must be semantic version format (e.g., '0.0.1', '1.0.0'), got: '{version}'" ) - # Check if dataset with same name + version exists (use DataFrame) - df_datasets = self.list_datasets() - - if len(df_datasets) > 0: - existing = df_datasets[ - (df_datasets["name"] == dataset_name) - & (df_datasets["version"] == version) + # Check if manifest with same name + version exists (use DataFrame) + df_manifests = self.list_manifests() + + # Only check for duplicates if table is not empty and has required columns + if ( + len(df_manifests) > 0 + and "name" in df_manifests.columns + and "version" in df_manifests.columns + ): + existing = df_manifests[ + (df_manifests["name"] == manifest_name) + & (df_manifests["version"] == version) ] if len(existing) > 0: raise ValueError( - f"Dataset '{dataset_name}' version '{version}' already exists. " + f"Manifest '{manifest_name}' version '{version}' already exists. " f"To create a new version, increment the version number (e.g., '0.0.2')." ) # Show existing versions (helpful feedback) - existing_versions = df_datasets[df_datasets["name"] == dataset_name] + existing_versions = df_manifests[df_manifests["name"] == manifest_name] if len(existing_versions) > 0: versions = sorted(existing_versions["version"].tolist()) - print(f"ℹ Dataset '{dataset_name}' existing versions: {versions}") + print(f"ℹ Manifest '{manifest_name}' existing versions: {versions}") print(f" Creating new version: '{version}'") # Get Airtable record IDs for these FOV IDs (ensure unique) - fov_record_ids = [] + dataset_record_ids = [] seen_fov_ids = set() for fov_id in fov_ids: if fov_id in seen_fov_ids: continue # Skip duplicates - formula = f"{{fov_id}}='{fov_id}'" - records = self.fovs_table.all(formula=formula) + formula = f"{{FOV_ID}}='{fov_id}'" + records = self.datasets_table.all(formula=formula) if records: - fov_record_ids.append(records[0]["id"]) + dataset_record_ids.append(records[0]["id"]) seen_fov_ids.add(fov_id) else: - raise ValueError(f"FOV '{fov_id}' not found in FOVs table") + raise ValueError(f"FOV ID '{fov_id}' not found in Datasets table") # Remove any duplicate record IDs (shouldn't happen, but extra safety) - fov_record_ids = list(dict.fromkeys(fov_record_ids)) + dataset_record_ids = list(dict.fromkeys(dataset_record_ids)) - # Create dataset record - dataset_record = { - "name": dataset_name, - "fovs": fov_record_ids, # Linked records (unique) + # Create manifest record + manifest_record = { + "name": manifest_name, + "datasets": dataset_record_ids, # Linked records (unique) "version": version, # Semantic version (required) "purpose": purpose, - "created_date": datetime.now().isoformat(), "created_by": getpass.getuser(), - "num_fovs": len(fov_record_ids), } - + if project_name: + manifest_record["project"] = project_name if description: - dataset_record["description"] = description + manifest_record["description"] = description - created = self.datasets_table.create(dataset_record) + created = self.manifests_table.create(manifest_record) return created["id"] - def create_dataset_from_query( + def create_manifest_from_query( self, - dataset_name: str, + manifest_name: str, version: str, - plate_name: str | None = None, + source_dataset: str | None = None, well_ids: list[str] | None = None, - quality: str | None = None, exclude_fov_ids: list[str] | None = None, **kwargs, ) -> str: """ - Create a dataset by filtering FOVs with pandas. 
+ Create a manifest by filtering dataset records with pandas. Parameters ---------- - dataset_name : str - Name for this dataset + manifest_name : str + Name for this manifest version : str Semantic version (e.g., "0.0.1") - REQUIRED - plate_name : str | None - Filter by plate name + source_dataset : str | None + Filter by source dataset name (from 'Dataset' field) well_ids : list[str] | None Filter by well identifiers (e.g., ["B_3", "B_4"]) - quality : str | None - Filter by quality ("Good", "Poor", etc.) exclude_fov_ids : list[str] | None - FOV IDs to exclude + FOV_ID values to exclude **kwargs Additional arguments for create_manifest_from_datasets Returns ------- str - Airtable dataset record ID + Airtable manifest record ID Examples -------- - >>> # Create dataset from all good-quality FOVs in specific wells - >>> registry.create_dataset_from_query( - ... dataset_name="RPE1_infection_training", + >>> # Create manifest from specific wells in a dataset + >>> registry.create_manifest_from_query( + ... manifest_name="RPE1_infection_training", ... version="0.0.1", - ... plate_name="RPE1_plate1", + ... source_dataset="RPE1_plate1", ... well_ids=["B_3", "B_4"], - ... quality="Good", ... exclude_fov_ids=["RPE1_plate1_B_3_2"] ... ) """ - # Get all FOVs as DataFrame - df = self.list_fovs() + # Get all dataset records as DataFrame + df = self.list_datasets() # Apply filters with pandas - if plate_name: - df = df[df["plate_name"] == plate_name] - - if quality: - df = df[df["quality"] == quality] + if source_dataset: + df = df[df["Dataset"] == source_dataset] if well_ids: - df = df[df["well_id"].isin(well_ids)] + df = df[df["Well ID"].isin(well_ids)] # Exclude specified FOVs if exclude_fov_ids: - df = df[~df["fov_id"].isin(exclude_fov_ids)] + df = df[~df["FOV_ID"].isin(exclude_fov_ids)] - fov_ids = df["fov_id"].tolist() + fov_ids = df["FOV_ID"].tolist() - print(f"Found {len(fov_ids)} FOVs matching criteria") + print(f"Found {len(fov_ids)} dataset records matching criteria") - # Create dataset + # Create manifest return self.create_manifest_from_datasets( - dataset_name=dataset_name, version=version, fov_ids=fov_ids, **kwargs + manifest_name=manifest_name, version=version, fov_ids=fov_ids, **kwargs ) - def get_dataset_fov_paths( - self, dataset_name: str, version: str | None = None + def get_manifest_data_paths( + self, manifest_name: str, version: str | None = None ) -> list[str]: """ - Get list of FOV paths for a dataset. + Get list of data paths for a manifest. Parameters ---------- - dataset_name : str - Dataset name + manifest_name : str + Manifest name version : str | None Specific version (if None, returns latest) Returns ------- list[str] - List of FOV paths + List of data paths Examples -------- - >>> paths = registry.get_dataset_fov_paths("RPE1_infection_v2") + >>> paths = registry.get_manifest_data_paths("RPE1_infection_v2") >>> print(paths) >>> # ['/hpc/data/rpe1.zarr/B/3/0', '/hpc/data/rpe1.zarr/B/3/1', ...] 
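        >>> # A minimal follow-up sketch (assumes the HPC paths are mounted locally):
        >>> from pathlib import Path
        >>> missing = [p for p in paths if not Path(p).exists()]
        >>> assert not missing, f"{len(missing)} FOV paths are unreachable"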
""" - # Get all datasets as DataFrame - df_datasets = self.list_datasets() + # Get all manifests as DataFrame + df_manifests = self.list_manifests() - if len(df_datasets) == 0: - raise ValueError(f"Dataset '{dataset_name}' not found") + if len(df_manifests) == 0 or "name" not in df_manifests.columns: + raise ValueError(f"Manifest '{manifest_name}' not found (table is empty)") # Filter by name - filtered = df_datasets[df_datasets["name"] == dataset_name] + filtered = df_manifests[df_manifests["name"] == manifest_name] if len(filtered) == 0: - raise ValueError(f"Dataset '{dataset_name}' not found") + raise ValueError(f"Manifest '{manifest_name}' not found") # Filter by version if specified, otherwise get latest if version: + if "version" not in df_manifests.columns: + raise ValueError("Version field not found in Manifest table") filtered = filtered[filtered["version"] == version] if len(filtered) == 0: raise ValueError( - f"Dataset '{dataset_name}' version '{version}' not found" + f"Manifest '{manifest_name}' version '{version}' not found" ) else: - # Get latest version (sort by created_date) - filtered = filtered.sort_values("created_date", ascending=False) + # Get latest version (sort by created_time if column exists) + if "created_time" in filtered.columns: + filtered = filtered.sort_values("created_time", ascending=False) - # Get the first (or only) matching dataset - dataset_row = filtered.iloc[0] + # Get the first (or only) matching manifest + manifest_row = filtered.iloc[0] - # Get linked FOV record IDs - fov_record_ids = dataset_row.get("fovs", []) - if not fov_record_ids or len(fov_record_ids) == 0: + # Get linked dataset record IDs + dataset_record_ids = manifest_row.get("datasets", []) + if not dataset_record_ids or len(dataset_record_ids) == 0: return [] - # Fetch FOV paths - fov_paths = [] - for fov_id in fov_record_ids: - fov_record = self.fovs_table.get(fov_id) - fov_paths.append(fov_record["fields"]["fov_path"]) + # Fetch data paths + data_paths = [] + for dataset_id in dataset_record_ids: + dataset_record = self.datasets_table.get(dataset_id) + data_paths.append(dataset_record["fields"]["Data path"]) - return fov_paths + return data_paths - def get_dataset( - self, dataset_name: str, version: str | None = None + def get_manifest( + self, manifest_name: str, version: str | None = None ) -> dict[str, Any]: """ - Get full dataset information including FOV details. + Get full manifest information including data paths. 
Parameters ---------- - dataset_name : str - Dataset name + manifest_name : str + Manifest name version : str | None Specific version Returns ------- dict - Dataset info with FOV paths and metadata + Manifest info with data paths and metadata """ - # Get all datasets as DataFrame - df_datasets = self.list_datasets() + # Get all manifests as DataFrame + df_manifests = self.list_manifests() - if len(df_datasets) == 0: - raise ValueError(f"Dataset '{dataset_name}' not found") + if len(df_manifests) == 0 or "name" not in df_manifests.columns: + raise ValueError(f"Manifest '{manifest_name}' not found (table is empty)") # Filter by name - filtered = df_datasets[df_datasets["name"] == dataset_name] + filtered = df_manifests[df_manifests["name"] == manifest_name] if len(filtered) == 0: - raise ValueError(f"Dataset '{dataset_name}' not found") + raise ValueError(f"Manifest '{manifest_name}' not found") # Filter by version if specified, otherwise get latest if version: + if "version" not in df_manifests.columns: + raise ValueError("Version field not found in Manifest table") filtered = filtered[filtered["version"] == version] if len(filtered) == 0: raise ValueError( - f"Dataset '{dataset_name}' version '{version}' not found" + f"Manifest '{manifest_name}' version '{version}' not found" ) else: - # Get latest version (sort by created_date) - filtered = filtered.sort_values("created_date", ascending=False) + # Get latest version (sort by created_time if column exists) + if "created_time" in filtered.columns: + filtered = filtered.sort_values("created_time", ascending=False) - # Get the first (or only) matching dataset - dataset_row = filtered.iloc[0] - dataset = dataset_row.to_dict() + # Get the first (or only) matching manifest + manifest_row = filtered.iloc[0] + manifest = manifest_row.to_dict() - # Add FOV paths - dataset["fov_paths"] = self.get_dataset_fov_paths(dataset_name, version) + # Add data paths + manifest["data_paths"] = self.get_manifest_data_paths(manifest_name, version) - return dataset + return manifest - def list_datasets( + def list_manifests( self, purpose: str | None = None, as_dataframe: bool = True ) -> pd.DataFrame | list[dict]: """ - List all datasets. + List all manifests. Parameters ---------- @@ -428,22 +454,30 @@ def list_datasets( Returns ------- pd.DataFrame | list[dict] - Dataset records as DataFrame or list of dicts + Manifest records as DataFrame or list of dicts Examples -------- - >>> registry.list_datasets(purpose="training") + >>> registry.list_manifests(purpose="training") >>> # Returns DataFrame with columns: id, name, version, purpose, ... 
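        >>> # Possible pandas follow-up (assumes the default DataFrame return):
        >>> df = registry.list_manifests(purpose="training")
        >>> df.groupby("name").size()  # number of versions per manifest name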
""" - # Fetch all datasets (sorted by most recent first) - records = self.datasets_table.all(sort=["-created_date"]) + # Fetch all manifests (try sorting, but don't fail if field doesn't exist) + try: + records = self.manifests_table.all(sort=["-created_time"]) + except Exception: + # If sort fails (field might not exist), fetch without sorting + records = self.manifests_table.all() + data = [{"id": r["id"], **r["fields"]} for r in records] # Convert to DataFrame or list if as_dataframe: df = pd.DataFrame(data) + # Sort by created_time if column exists + if len(df) > 0 and "created_time" in df.columns: + df = df.sort_values("created_time", ascending=False) # Filter by purpose if specified - if purpose and len(df) > 0: + if purpose and len(df) > 0 and "purpose" in df.columns: df = df[df["purpose"] == purpose] return df else: @@ -452,9 +486,9 @@ def list_datasets( data = [d for d in data if d.get("purpose") == purpose] return data - def list_fovs(self, as_dataframe: bool = True) -> pd.DataFrame | list[dict]: + def list_datasets(self, as_dataframe: bool = True) -> pd.DataFrame | list[dict]: """ - Get all FOVs as a DataFrame (or list of dicts). + Get all dataset records (FOVs) as a DataFrame (or list of dicts). Use pandas for filtering - much simpler and more powerful than building Airtable formulas. @@ -467,30 +501,48 @@ def list_fovs(self, as_dataframe: bool = True) -> pd.DataFrame | list[dict]: Returns ------- pd.DataFrame | list[dict] - All FOV records + All dataset records Examples -------- - >>> # Get all FOVs - >>> df = registry.list_fovs() + >>> # Get all datasets + >>> df = registry.list_datasets() >>> >>> # Filter with pandas (simple and powerful!) - >>> filtered = df[df['plate_name'] == 'RPE1_plate1'] - >>> filtered = df[df['quality'] == 'Good'] - >>> filtered = df[df['row'] == 'B'] - >>> filtered = df[df['row'].isin(['B', 'C'])] - >>> filtered = df[(df['row'] == 'B') & (df['column'] == '3')] - >>> - >>> # Exclude FOVs - >>> filtered = df[~df['fov_id'].isin(['RPE1_plate1_B_3_2'])] + >>> filtered = df[df['Dataset'] == 'RPE1_plate1'] + >>> filtered = df[df['Well ID'].isin(['B_3', 'B_4'])] + >>> filtered = df[~df['FOV_ID'].isin(['RPE1_plate1_B_3_2'])] >>> >>> # Group and analyze - >>> df.groupby('plate_name').size() - >>> df.groupby(['row', 'column']).size() + >>> df.groupby('Dataset').size() + >>> df.groupby('Well ID').size() """ - records = self.fovs_table.all() + records = self.datasets_table.all() data = [{"id": r["id"], **r["fields"]} for r in records] if as_dataframe: return pd.DataFrame(data) return data + + def delete_manifest(self, manifest_id: str) -> bool: + """ + Delete a manifest record from Airtable. + + Parameters + ---------- + manifest_id : str + Airtable record ID of the manifest to delete + + Returns + ------- + bool + True if deletion was successful + + Examples + -------- + >>> manifest_id = registry.create_manifest_from_datasets(...) + >>> registry.delete_manifest(manifest_id) + >>> print(f"Deleted manifest: {manifest_id}") + """ + self.manifests_table.delete(manifest_id) + return True diff --git a/viscy/airtable/manifests.py b/viscy/airtable/manifests.py index 141863e57..835de3d10 100644 --- a/viscy/airtable/manifests.py +++ b/viscy/airtable/manifests.py @@ -7,14 +7,12 @@ from pyairtable import Api -class AirtableDatasetRegistry: +class AirtableManifests: """ - Interface to Airtable for dataset registry. + Interface to Airtable for manifests. 
Airtable acts as source of truth for: - - Dataset paths on HPC - - Dataset versions and metadata - - Links between datasets and trained models + - Dataset manifests Parameters ---------- @@ -50,34 +48,35 @@ def __init__( self.api = Api(api_key) self.base_id = base_id - self.datasets_table = self.api.table(base_id, "Datasets") + self.manifests_table = self.api.table(base_id, "Manifest") self.models_table = self.api.table(base_id, "Models") def get_manifest(self, name: str, version: str | None = None) -> dict[str, Any]: """ - Retrieve dataset record from Airtable. + Retrieve manifest record from Airtable. Parameters ---------- name : str - Dataset name + Manifest name version : str | None - Specific version (e.g., "v2"). If None, returns latest. + Specific version (e.g., "0.0.1"). If None, returns latest. Returns ------- - TODO: typing for the headers in manifest + dict + Manifest record with fields from MANIFESTS_INDEX """ if version: formula = f"AND({{name}}='{name}', {{version}}='{version}')" else: formula = f"{{name}}='{name}'" - records = self.datasets_table.all(formula=formula, sort=["-created_date"]) + records = self.manifests_table.all(formula=formula, sort=["-created_time"]) if not records: raise ValueError( - f"Dataset '{name}' (version={version}) not found in Airtable" + f"Manifest '{name}' (version={version}) not found in Airtable" ) record = records[0] @@ -85,7 +84,7 @@ def get_manifest(self, name: str, version: str | None = None) -> dict[str, Any]: def log_model_training( self, - dataset_id: str, + manifest_id: str, mlflow_run_id: str, model_name: str | None = None, metrics: dict[str, float] | None = None, @@ -93,14 +92,14 @@ def log_model_training( trained_by: str | None = None, ) -> str: """ - Log that a model was trained using a dataset. + Log that a model was trained using a manifest. - Creates entry in Models table and updates Datasets table. + Creates entry in Models table and updates Manifests table. Parameters ---------- - dataset_id : str - Airtable record ID of dataset used + manifest_id : str + Airtable record ID of manifest used mlflow_run_id : str MLflow run ID for experiment tracking model_name : str | None @@ -120,7 +119,7 @@ def log_model_training( # Create model record model_record = { "model_name": model_name or f"model_{datetime.now():%Y%m%d_%H%M%S}", - "dataset": [dataset_id], # Link to dataset + "manifest": [manifest_id], # Link to manifest "mlflow_run_id": mlflow_run_id, "trained_date": datetime.now().isoformat(), } @@ -136,9 +135,9 @@ def log_model_training( created = self.models_table.create(model_record) - # Update dataset record to track usage - dataset = self.datasets_table.get(dataset_id) - models_trained_str = dataset["fields"].get("models_trained", "") + # Update manifest record to track usage + manifest = self.manifests_table.get(manifest_id) + models_trained_str = manifest["fields"].get("models_trained", "") # Handle models_trained as comma-separated string if models_trained_str: @@ -148,16 +147,16 @@ def log_model_training( else: new_models_str = mlflow_run_id - self.datasets_table.update( - dataset_id, + self.manifests_table.update( + manifest_id, {"models_trained": new_models_str, "last_used": datetime.now().isoformat()}, ) return created["id"] - def list_datasets(self, formula: str | None = None) -> list[dict]: + def list_manifests(self, formula: str | None = None) -> list[dict]: """ - List all datasets in registry. + List all manifests in registry. 
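+
+        Records are returned newest first (sorted by created_time).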
Parameters ---------- @@ -167,25 +166,25 @@ def list_datasets(self, formula: str | None = None) -> list[dict]: Returns ------- list[dict] - List of dataset records + List of manifest records """ - records = self.datasets_table.all(formula=formula, sort=["-created_date"]) + records = self.manifests_table.all(formula=formula, sort=["-created_time"]) return [{"id": r["id"], **r["fields"]} for r in records] - def get_models_for_dataset(self, dataset_id: str) -> list[dict]: + def get_models_for_manifest(self, manifest_id: str) -> list[dict]: """ - Get all models trained on a specific dataset. + Get all models trained on a specific manifest. Parameters ---------- - dataset_id : str - Airtable record ID of dataset + manifest_id : str + Airtable record ID of manifest Returns ------- list[dict] List of model records """ - formula = f"FIND('{dataset_id}', ARRAYJOIN({{dataset}}))" + formula = f"FIND('{manifest_id}', ARRAYJOIN({{manifest}}))" records = self.models_table.all(formula=formula, sort=["-trained_date"]) return [{"id": r["id"], **r["fields"]} for r in records] From afac50ba7ba63215de23308ece095021a775299b Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 6 Jan 2026 15:25:31 -0800 Subject: [PATCH 09/18] merging database and manifest into database.py --- .../airtable/filter_n_create_dataset_tag.py | 18 +- viscy/airtable/__init__.py | 0 viscy/airtable/callbacks.py | 6 +- viscy/airtable/{datasets.py => database.py} | 216 ++++++++++++++++-- viscy/airtable/manifests.py | 190 --------------- 5 files changed, 207 insertions(+), 223 deletions(-) create mode 100644 viscy/airtable/__init__.py rename viscy/airtable/{datasets.py => database.py} (71%) delete mode 100644 viscy/airtable/manifests.py diff --git a/examples/airtable/filter_n_create_dataset_tag.py b/examples/airtable/filter_n_create_dataset_tag.py index 840d33e64..9372a7bea 100644 --- a/examples/airtable/filter_n_create_dataset_tag.py +++ b/examples/airtable/filter_n_create_dataset_tag.py @@ -2,11 +2,11 @@ # %% -from viscy.airtable.datasets import AirtableDatasets +from viscy.airtable.database import AirtableManager # BASE_ID = os.getenv("AIRTABLE_BASE_ID") BASE_ID = "app8vqaoWyOwa0sB5" -registry = AirtableDatasets(base_id=BASE_ID) +airtable_db = AirtableManager(base_id=BASE_ID) # %% # EXAMPLE 1: Get all dataset records as DataFrame and explore @@ -14,7 +14,7 @@ print("Getting all dataset records as DataFrame") print("=" * 70) -df_datasets = registry.list_datasets() +df_datasets = airtable_db.list_datasets() print(f"\nTotal dataset records: {len(df_datasets)}") print("\nDataFrame columns:") print(df_datasets.columns.tolist()) @@ -28,7 +28,7 @@ print("=" * 70) # Get all dataset records as DataFrame -df = registry.list_datasets() +df = airtable_db.list_datasets() # Filter with pandas - simple and powerful! 
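# NOTE: the column names below ("Dataset", "Well ID", "FOV_ID") mirror the
# Airtable headers in DATASETS_INDEX; adjust them if your base uses different headers.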
filtered = df[
@@ -44,7 +44,7 @@
 fov_ids = filtered["FOV_ID"].tolist()
 
 try:
-    manifest_id = registry.create_manifest_from_datasets(
+    manifest_id = airtable_db.create_manifest_from_datasets(
         manifest_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2",
         fov_ids=fov_ids,
         version="0.0.1",  # Semantic versioning
@@ -58,7 +58,7 @@
 
 # %%
 # Delete the manifest entry demo
-registry.delete_manifest(manifest_id)
+airtable_db.delete_manifest(manifest_id)
 print(f"Deleted manifest: {manifest_id}")
 
 # %%
@@ -67,7 +67,7 @@
 print("Group by dataset and show summary")
 print("=" * 70)
 
-df_all = registry.list_datasets()
+df_all = airtable_db.list_datasets()
 
 grouped = df_all.groupby("Dataset")
 
@@ -82,7 +82,7 @@
 print("Filter: Multiple specific wells")
 print("=" * 70)
 
-df = registry.list_datasets()
+df = airtable_db.list_datasets()
 
 # Filter for specific wells from a dataset
 filtered = df[
@@ -104,7 +104,7 @@
 print("Summary Statistics")
 print("=" * 70)
 
-df = registry.list_datasets()
+df = airtable_db.list_datasets()
 
 print("\nDataset records per source dataset:")
 print(df.groupby("Dataset").size())
diff --git a/viscy/airtable/__init__.py b/viscy/airtable/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/viscy/airtable/callbacks.py b/viscy/airtable/callbacks.py
index 1eeb411e9..e6093b29f 100644
--- a/viscy/airtable/callbacks.py
+++ b/viscy/airtable/callbacks.py
@@ -6,7 +6,7 @@
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import Callback
 
-from viscy.airtable.manifests import AirtableManifests
+from viscy.airtable.database import AirtableManager
 
 
 class AirtableLoggingCallback(Callback):
@@ -59,7 +59,7 @@ def __init__(
         log_metrics: bool = False,
     ):
         super().__init__()
-        self.registry = AirtableManifests(base_id=base_id)
+        self.airtable_db = AirtableManager(base_id=base_id)
         self.manifest_id = manifest_id
         self.model_name = model_name
         self.log_metrics = log_metrics
@@ -103,7 +103,7 @@ def on_fit_end(self, trainer: Trainer, pl_module: Any) -> None:
 
         # Log to Airtable
         try:
-            model_id = self.registry.log_model_training(
+            model_id = self.airtable_db.log_model_training(
                 manifest_id=self.manifest_id,
                 mlflow_run_id=run_id or "unknown",
                 model_name=model_name,
diff --git a/viscy/airtable/datasets.py b/viscy/airtable/database.py
similarity index 71%
rename from viscy/airtable/datasets.py
rename to viscy/airtable/database.py
index a76254256..d6656be4a 100644
--- a/viscy/airtable/datasets.py
+++ b/viscy/airtable/database.py
@@ -1,7 +1,8 @@
-"""FOV-level dataset registry with Airtable."""
+"""FOV-level dataset database with Airtable."""
 
 import getpass
 import os
+from datetime import datetime
 from typing import Any
 
 import pandas as pd
@@ -48,15 +49,15 @@
 ]
 
 
-class AirtableDatasets:
+class AirtableManager:
     """
-    Interface to Airtable for FOV-level dataset management.
+    Unified interface to Airtable for dataset, manifest, and model management.
 
     Use this to:
    - Register individual FOVs from HCS plates
-    - Create dataset "tags" (collections of FOVs)
-    - Query which FOVs are in each dataset
-    - Generate training configs from dataset tags
+    - Create and manage dataset manifests (collections of FOVs)
+    - Track model training on manifests
+    - Query datasets, manifests, and models
 
     Parameters
     ----------
@@ -67,18 +68,25 @@
 
     Examples
     --------
-    >>> registry = AirtableDatasets(base_id="appXXXXXXXXXXXXXX")
+    >>> airtable_db = AirtableManager(base_id="appXXXXXXXXXXXXXX")
     >>>
-    >>> # Create dataset from FOV selection
-    >>> registry.create_manifest_from_datasets(
-    ... 
dataset_name="RPE1_infection_v2", + >>> # Create manifest from FOV selection + >>> manifest_id = airtable_db.create_manifest_from_datasets( + ... manifest_name="RPE1_infection_v2", ... fov_ids=["FOV_001", "FOV_002", "FOV_004"], - ... version="v2", + ... version="0.0.1", ... purpose="training" ... ) >>> - >>> # Get all FOV paths for a dataset - >>> fov_paths = registry.get_dataset_fov_paths("RPE1_infection_v2") + >>> # Track model training + >>> airtable_db.log_model_training( + ... manifest_id=manifest_id, + ... mlflow_run_id="run_123", + ... model_name="my_model", + ... ) + >>> + >>> # Get all FOV paths for a manifest + >>> fov_paths = airtable_db.get_manifest_data_paths("RPE1_infection_v2") >>> print(fov_paths) >>> # ['/hpc/data/rpe1.zarr/B/3/0', '/hpc/data/rpe1.zarr/B/3/1', ...] """ @@ -173,7 +181,7 @@ def create_manifest_from_datasets( Examples -------- - >>> registry.create_manifest_from_datasets( + >>> airtable_db.create_manifest_from_datasets( ... manifest_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2", ... project_name="OrganelleBox", ... fov_ids=["2024_11_07_A549_SEC61_DENV_B1_0", "2024_11_07_A549_SEC61_DENV_B1_1"], @@ -211,7 +219,6 @@ def create_manifest_from_datasets( f"To create a new version, increment the version number (e.g., '0.0.2')." ) - # Show existing versions (helpful feedback) existing_versions = df_manifests[df_manifests["name"] == manifest_name] if len(existing_versions) > 0: versions = sorted(existing_versions["version"].tolist()) @@ -288,7 +295,7 @@ def create_manifest_from_query( Examples -------- >>> # Create manifest from specific wells in a dataset - >>> registry.create_manifest_from_query( + >>> airtable_db.create_manifest_from_query( ... manifest_name="RPE1_infection_training", ... version="0.0.1", ... source_dataset="RPE1_plate1", @@ -339,7 +346,7 @@ def get_manifest_data_paths( Examples -------- - >>> paths = registry.get_manifest_data_paths("RPE1_infection_v2") + >>> paths = airtable_db.get_manifest_data_paths("RPE1_infection_v2") >>> print(paths) >>> # ['/hpc/data/rpe1.zarr/B/3/0', '/hpc/data/rpe1.zarr/B/3/1', ...] """ @@ -458,7 +465,7 @@ def list_manifests( Examples -------- - >>> registry.list_manifests(purpose="training") + >>> airtable_db.list_manifests(purpose="training") >>> # Returns DataFrame with columns: id, name, version, purpose, ... """ # Fetch all manifests (try sorting, but don't fail if field doesn't exist) @@ -506,7 +513,7 @@ def list_datasets(self, as_dataframe: bool = True) -> pd.DataFrame | list[dict]: Examples -------- >>> # Get all datasets - >>> df = registry.list_datasets() + >>> df = airtable_db.list_datasets() >>> >>> # Filter with pandas (simple and powerful!) >>> filtered = df[df['Dataset'] == 'RPE1_plate1'] @@ -540,9 +547,176 @@ def delete_manifest(self, manifest_id: str) -> bool: Examples -------- - >>> manifest_id = registry.create_manifest_from_datasets(...) - >>> registry.delete_manifest(manifest_id) + >>> manifest_id = airtable_db.create_manifest_from_datasets(...) + >>> airtable_db.delete_manifest(manifest_id) >>> print(f"Deleted manifest: {manifest_id}") """ self.manifests_table.delete(manifest_id) return True + + def log_model_training( + self, + manifest_id: str, + mlflow_run_id: str, + model_name: str | None = None, + metrics: dict[str, float] | None = None, + checkpoint_path: str | None = None, + trained_by: str | None = None, + ) -> str: + """ + Log that a model was trained using a manifest. + + Creates entry in Models table and updates Manifest table. 
+
+        Parameters
+        ----------
+        manifest_id : str
+            Airtable record ID of manifest used
+        mlflow_run_id : str
+            MLflow run ID for experiment tracking
+        model_name : str | None
+            Human-readable model name
+        metrics : dict | None
+            Training metrics (e.g., {"accuracy": 0.89, "f1_score": 0.92})
+        checkpoint_path : str | None
+            Path to saved model checkpoint
+        trained_by : str | None
+            Username of person who trained the model
+
+        Returns
+        -------
+        str
+            Airtable record ID of created model entry
+
+        Examples
+        --------
+        >>> manifest_id = airtable_db.create_manifest_from_datasets(...)
+        >>> model_id = airtable_db.log_model_training(
+        ...     manifest_id=manifest_id,
+        ...     mlflow_run_id="run_abc123",
+        ...     model_name="sec61_model_v1",
+        ...     metrics={"val_loss": 0.15},
+        ...     trained_by="researcher_name"
+        ... )
+        """
+        # Create model record
+        model_record = {
+            "model_name": model_name or f"model_{datetime.now():%Y%m%d_%H%M%S}",
+            "manifest": [manifest_id],  # Link to manifest
+            "mlflow_run_id": mlflow_run_id,
+            "trained_date": datetime.now().isoformat(),
+        }
+
+        if metrics:
+            model_record.update(metrics)
+
+        if checkpoint_path:
+            model_record["checkpoint_path"] = checkpoint_path
+
+        if trained_by:
+            model_record["trained_by"] = trained_by
+
+        created = self.models_table.create(model_record)
+
+        # Update manifest record to track usage
+        manifest = self.manifests_table.get(manifest_id)
+        models_trained_str = manifest["fields"].get("models_trained", "")
+
+        # Handle models_trained as comma-separated string
+        if models_trained_str:
+            models_list = [m.strip() for m in models_trained_str.split(",")]
+            models_list.append(mlflow_run_id)
+            new_models_str = ", ".join(models_list)
+        else:
+            new_models_str = mlflow_run_id
+
+        self.manifests_table.update(
+            manifest_id,
+            {"models_trained": new_models_str, "last_used": datetime.now().isoformat()},
+        )
+
+        return created["id"]
+
+    def get_models_for_manifest(
+        self, manifest_id: str, as_dataframe: bool = True
+    ) -> pd.DataFrame | list[dict]:
+        """
+        Get all models trained on a specific manifest.
+
+        Parameters
+        ----------
+        manifest_id : str
+            Airtable record ID of manifest
+        as_dataframe : bool
+            If True, return pandas DataFrame. If False, return list of dicts.
+
+        Returns
+        -------
+        pd.DataFrame | list[dict]
+            Model records as DataFrame or list of dicts
+
+        Examples
+        --------
+        >>> models_df = airtable_db.get_models_for_manifest(manifest_id)
+        >>> print(models_df[["model_name", "mlflow_run_id", "trained_date"]])
+        """
+        # Get all models as DataFrame
+        records = self.models_table.all()
+        data = [{"id": r["id"], **r["fields"]} for r in records]
+
+        if as_dataframe:
+            df = pd.DataFrame(data)
+            if len(df) == 0:
+                return df
+
+            # Filter by manifest_id using pandas
+            # The 'manifest' field contains a list of linked record IDs
+            df_filtered = df[
+                df["manifest"].apply(
+                    lambda x: manifest_id in x if isinstance(x, list) else False
+                )
+            ]
+
+            # Sort by trained_date if column exists
+            if "trained_date" in df_filtered.columns:
+                df_filtered = df_filtered.sort_values("trained_date", ascending=False)
+
+            return df_filtered
+        else:
+            # Filter list
+            filtered = [d for d in data if manifest_id in d.get("manifest", [])]
+            return filtered
+
+    def list_models(self, as_dataframe: bool = True) -> pd.DataFrame | list[dict]:
+        """
+        List all models in the Models table.
+
+        Parameters
+        ----------
+        as_dataframe : bool
+            If True, return pandas DataFrame. If False, return list of dicts. 
+ + Returns + ------- + pd.DataFrame | list[dict] + All model records + + Examples + -------- + >>> models_df = airtable_db.list_models() + >>> print(models_df.groupby("model_name").size()) + """ + records = self.models_table.all() + data = [{"id": r["id"], **r["fields"]} for r in records] + + if as_dataframe: + df = pd.DataFrame(data) + # Sort by trained_date if column exists + if len(df) > 0 and "trained_date" in df.columns: + df = df.sort_values("trained_date", ascending=False) + return df + return data + + +# Backward compatibility alias +AirtableDatasets = AirtableManager diff --git a/viscy/airtable/manifests.py b/viscy/airtable/manifests.py deleted file mode 100644 index 835de3d10..000000000 --- a/viscy/airtable/manifests.py +++ /dev/null @@ -1,190 +0,0 @@ -"""Dataset registry integration with Airtable for experiment tracking.""" - -import os -from datetime import datetime -from typing import Any - -from pyairtable import Api - - -class AirtableManifests: - """ - Interface to Airtable for manifests. - - Airtable acts as source of truth for: - - Dataset manifests - - Parameters - ---------- - base_id : str - Airtable base ID - api_key : str | None - Airtable API key. If None, reads from AIRTABLE_API_KEY env var. - - Examples - -------- - >>> registry = AirtableDatasetRegistry(base_id="appXXXXXXXXXXXXXX") - >>> - >>> # Get dataset info - >>> dataset = registry.get_manifest("rpe1_fucci_embeddings", version="v2") - >>> print(dataset['hpc_path']) - >>> - >>> # Record that a model was trained with this dataset - >>> registry.log_model_training( - ... dataset_id=dataset['id'], - ... mlflow_run_id="run_123", - ... metrics={"accuracy": 0.89} - ... ) - """ - - def __init__( - self, - base_id: str, - api_key: str | None = None, - ): - api_key = api_key or os.getenv("AIRTABLE_API_KEY") - if not api_key: - raise ValueError("Airtable API key required (set AIRTABLE_API_KEY)") - - self.api = Api(api_key) - self.base_id = base_id - self.manifests_table = self.api.table(base_id, "Manifest") - self.models_table = self.api.table(base_id, "Models") - - def get_manifest(self, name: str, version: str | None = None) -> dict[str, Any]: - """ - Retrieve manifest record from Airtable. - - Parameters - ---------- - name : str - Manifest name - version : str | None - Specific version (e.g., "0.0.1"). If None, returns latest. - - Returns - ------- - dict - Manifest record with fields from MANIFESTS_INDEX - """ - if version: - formula = f"AND({{name}}='{name}', {{version}}='{version}')" - else: - formula = f"{{name}}='{name}'" - - records = self.manifests_table.all(formula=formula, sort=["-created_time"]) - - if not records: - raise ValueError( - f"Manifest '{name}' (version={version}) not found in Airtable" - ) - - record = records[0] - return {"id": record["id"], **record["fields"]} - - def log_model_training( - self, - manifest_id: str, - mlflow_run_id: str, - model_name: str | None = None, - metrics: dict[str, float] | None = None, - checkpoint_path: str | None = None, - trained_by: str | None = None, - ) -> str: - """ - Log that a model was trained using a manifest. - - Creates entry in Models table and updates Manifests table. - - Parameters - ---------- - manifest_id : str - Airtable record ID of manifest used - mlflow_run_id : str - MLflow run ID for experiment tracking - model_name : str | None - Human-readable model name - metrics : dict | None - Training metrics (accuracy, f1_score, etc.) 
- checkpoint_path : str | None - Path to saved model checkpoint - trained_by : str | None - Username of person who trained the model - - Returns - ------- - str - Airtable record ID of created model entry - """ - # Create model record - model_record = { - "model_name": model_name or f"model_{datetime.now():%Y%m%d_%H%M%S}", - "manifest": [manifest_id], # Link to manifest - "mlflow_run_id": mlflow_run_id, - "trained_date": datetime.now().isoformat(), - } - - if metrics: - model_record.update(metrics) - - if checkpoint_path: - model_record["checkpoint_path"] = checkpoint_path - - if trained_by: - model_record["trained_by"] = trained_by - - created = self.models_table.create(model_record) - - # Update manifest record to track usage - manifest = self.manifests_table.get(manifest_id) - models_trained_str = manifest["fields"].get("models_trained", "") - - # Handle models_trained as comma-separated string - if models_trained_str: - models_list = [m.strip() for m in models_trained_str.split(",")] - models_list.append(mlflow_run_id) - new_models_str = ", ".join(models_list) - else: - new_models_str = mlflow_run_id - - self.manifests_table.update( - manifest_id, - {"models_trained": new_models_str, "last_used": datetime.now().isoformat()}, - ) - - return created["id"] - - def list_manifests(self, formula: str | None = None) -> list[dict]: - """ - List all manifests in registry. - - Parameters - ---------- - formula : str | None - Optional Airtable formula for filtering - - Returns - ------- - list[dict] - List of manifest records - """ - records = self.manifests_table.all(formula=formula, sort=["-created_time"]) - return [{"id": r["id"], **r["fields"]} for r in records] - - def get_models_for_manifest(self, manifest_id: str) -> list[dict]: - """ - Get all models trained on a specific manifest. 
- - Parameters - ---------- - manifest_id : str - Airtable record ID of manifest - - Returns - ------- - list[dict] - List of model records - """ - formula = f"FIND('{manifest_id}', ARRAYJOIN({{manifest}}))" - records = self.models_table.all(formula=formula, sort=["-trained_date"]) - return [{"id": r["id"], **r["fields"]} for r in records] From 1487540ff269ae0b3f9da4c44465ae992d0e446a Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 6 Jan 2026 15:27:42 -0800 Subject: [PATCH 10/18] remove backwards compatibility --- viscy/airtable/database.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/viscy/airtable/database.py b/viscy/airtable/database.py index d6656be4a..5545ea96c 100644 --- a/viscy/airtable/database.py +++ b/viscy/airtable/database.py @@ -716,7 +716,3 @@ def list_models(self, as_dataframe: bool = True) -> pd.DataFrame | list[dict]: df = df.sort_values("trained_date", ascending=False) return df return data - - -# Backward compatibility alias -AirtableDatasets = AirtableManager From cc3203d9acf2e653c87cb1484956bef2a656d488 Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Tue, 6 Jan 2026 17:14:04 -0800 Subject: [PATCH 11/18] added get_dataset_paths --- .../airtable/get_dataset_paths_example.py | 47 +++++ viscy/airtable/database.py | 179 ++++++++++++++++++ 2 files changed, 226 insertions(+) create mode 100644 examples/airtable/get_dataset_paths_example.py diff --git a/examples/airtable/get_dataset_paths_example.py b/examples/airtable/get_dataset_paths_example.py new file mode 100644 index 000000000..1846a6a74 --- /dev/null +++ b/examples/airtable/get_dataset_paths_example.py @@ -0,0 +1,47 @@ +"""Example usage of get_dataset_paths with Manifest and ManifestDataset dataclasses.""" + +# %% +from viscy.airtable.database import AirtableManager + +BASE_ID = "app8vqaoWyOwa0sB5" +airtable_db = AirtableManager(base_id=BASE_ID) + +# %% +# Fetch manifest from Airtable +manifest = airtable_db.get_dataset_paths( + manifest_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2", + version="0.0.1", +) + +# %% +# Manifest properties +print("=== Manifest ===") +print(f"manifest.name: {manifest.name}") +print(f"manifest.version: {manifest.version}") +print(f"len(manifest): {len(manifest)} HCS plate(s)") +print(f"manifest.total_fovs: {manifest.total_fovs} FOVs") + +# %% +# Iterate over ManifestDataset objects (one per HCS plate) +print("\n=== ManifestDataset ===") +for ds in manifest: + print(f"ds.data_path: {ds.data_path}") + print(f"ds.tracks_path: {ds.tracks_path}") + print(f"len(ds): {len(ds)} FOVs") + print(f"ds.fov_names: {ds.fov_names[:3]}...") + print(f"ds.fov_paths: {ds.fov_paths[:2]}...") + print(f"ds.exists(): {ds.exists()}") + +# %% +# Validate paths exist (raises FileNotFoundError if not) +manifest.validate() +print("\nAll paths validated successfully!") + + +# %% +# List available manifests +print("=== Available Manifests ===") +df = airtable_db.list_manifests() +print(df[["name", "version", "purpose"]].dropna(subset=["name"]).to_string()) + +# %% diff --git a/viscy/airtable/database.py b/viscy/airtable/database.py index 5545ea96c..2d2b09402 100644 --- a/viscy/airtable/database.py +++ b/viscy/airtable/database.py @@ -2,12 +2,75 @@ import getpass import os +from dataclasses import dataclass from datetime import datetime +from pathlib import Path from typing import Any import pandas as pd +from natsort import natsorted from pyairtable import Api + +@dataclass +class ManifestDataset: + """ + Dataset paths for one HCS plate/zarr store. 
+ + A manifest may contain multiple stores, each returned as a separate ManifestDataset. + """ + + data_path: str + tracks_path: str + fov_names: list[str] + + def __len__(self) -> int: + return len(self.fov_names) + + @property + def fov_paths(self) -> list[str]: + """Full paths to each FOV: {data_path}/{fov_name}.""" + return [f"{self.data_path}/{fov}" for fov in self.fov_names] + + def exists(self) -> bool: + """Check if data_path and tracks_path exist.""" + return Path(self.data_path).exists() and Path(self.tracks_path).exists() + + def validate(self) -> None: + """Raise FileNotFoundError if paths don't exist.""" + if not Path(self.data_path).exists(): + raise FileNotFoundError(f"Data path not found: {self.data_path}") + if not Path(self.tracks_path).exists(): + raise FileNotFoundError(f"Tracks path not found: {self.tracks_path}") + + +@dataclass +class Manifest: + """All datasets for a manifest, potentially across multiple HCS plates.""" + + name: str + version: str + datasets: list[ManifestDataset] + + def __iter__(self): + """Iterate over datasets.""" + return iter(self.datasets) + + def __len__(self): + """Total number of HCS plates.""" + return len(self.datasets) + + @property + def total_fovs(self) -> int: + """Total FOVs across all plates.""" + return sum(len(ds) for ds in self.datasets) + + def validate(self) -> None: + """Validate all dataset paths exist.""" + for ds in self.datasets: + ds.validate() + + # TODO: update the usage examples in the docstrings # TODO: update the headers to match the Airtable columns and potentially move to separate file # TODO: Convert these to Pydantic models, so we can easily dump and load from/to Airtable @@ -716,3 +779,119 @@ def list_models(self, as_dataframe: bool = True) -> pd.DataFrame | list[dict]: df = df.sort_values("trained_date", ascending=False) return df return data + + def get_dataset_paths( + self, + manifest_name: str, + version: str, + ) -> Manifest: + """ + Get zarr store paths and FOV names for a manifest. + + Parameters + ---------- + manifest_name : str + Name of the manifest + version : str + Semantic version of the manifest + + Returns + ------- + Manifest + Manifest object containing list of ManifestDataset (one per HCS plate) + + Examples + -------- + >>> manifest = airtable_db.get_dataset_paths("my_manifest", "0.0.1") + >>> print(f"{manifest.name} v{manifest.version}: {manifest.total_fovs} FOVs") + + >>> # Use with TripletDataModule + >>> for ds in manifest: + ... data_module = TripletDataModule( + ... data_path=ds.data_path, + ... tracks_path=ds.tracks_path, + ... include_fov_names=ds.fov_names, + ... 
) + """ + # Get manifest record IDs + dataset_record_ids = self._get_manifest_dataset_ids(manifest_name, version) + if not dataset_record_ids: + return Manifest(name=manifest_name, version=version, datasets=[]) + + dataset_records = [ + self.datasets_table.get(dataset_id)["fields"] + for dataset_id in dataset_record_ids + ] + + stores: dict[str, list[str]] = {} + for fields in dataset_records: + data_path = fields["Data path"] + fov_name = f"{fields['Well ID']}/{fields['FOV']}" + + if data_path not in stores: + stores[data_path] = [] + stores[data_path].append(fov_name) + + datasets = [ + ManifestDataset( + data_path=data_path, + tracks_path=self._derive_tracks_path(data_path), + fov_names=natsorted(fov_names), + ) + for data_path, fov_names in stores.items() + ] + + return Manifest(name=manifest_name, version=version, datasets=datasets) + + @staticmethod + def _derive_tracks_path(data_path: str) -> str: + """ + Derive tracks path from data path. + + Pattern: + - Data: {base}/2-assemble/{name}.zarr + - Tracks: {base}/1-preprocess/label-free/3-track/{name}_cropped.zarr + """ + # Replace 2-assemble with 1-preprocess/label-free/3-track + tracks_path = data_path.replace( + "/2-assemble/", "/1-preprocess/label-free/3-track/" + ) + # Replace .zarr with _cropped.zarr + if tracks_path.endswith(".zarr"): + tracks_path = tracks_path[:-5] + "_cropped.zarr" + return tracks_path + + def _get_manifest_dataset_ids(self, manifest_name: str, version: str) -> list[str]: + """Get linked dataset record IDs for a manifest.""" + df_manifests = self.list_manifests() + + if len(df_manifests) == 0 or "name" not in df_manifests.columns: + raise ValueError(f"Manifest '{manifest_name}' not found (table is empty)") + + filtered = df_manifests[df_manifests["name"] == manifest_name] + if len(filtered) == 0: + raise ValueError(f"Manifest '{manifest_name}' not found") + + if "version" not in df_manifests.columns: + raise ValueError("Version field not found in Manifest table") + + filtered = filtered[filtered["version"] == version] + if len(filtered) == 0: + raise ValueError( + f"Manifest '{manifest_name}' version '{version}' not found" + ) + + manifest_row = filtered.iloc[0] + dataset_record_ids = manifest_row.get("datasets", []) + + if not dataset_record_ids or len(dataset_record_ids) == 0: + return [] + + return dataset_record_ids + + def update_record( + self, + ): ... # inputs are the table, the record, maybe manifest, use the unique keys instead of the first column + + # TODO: to update the tracks path column + raise NotImplementedError("Not implemented yet") From 7a677cca7ba5f90e5c3c9dad6e5a7b16ee618f40 Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Wed, 7 Jan 2026 10:01:47 -0800 Subject: [PATCH 12/18] indent --- viscy/airtable/database.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/viscy/airtable/database.py b/viscy/airtable/database.py index 2d2b09402..de5af4fdc 100644 --- a/viscy/airtable/database.py +++ b/viscy/airtable/database.py @@ -891,7 +891,6 @@ def _get_manifest_dataset_ids(self, manifest_name: str, version: str) -> list[st def update_record( self, - ): ... 
# inputs are the table, the record, maybe manifest, use the unique keys instead of the first column - - # TODO: to update the tracks path column - raise NotImplementedError("Not implemented yet") + ): + # TODO: to update the tracks path column + raise NotImplementedError("Not implemented yet") From 403e0d3de97d41fe5c311b6cde5a7c0f949f43be Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 7 Jan 2026 13:59:58 -0800 Subject: [PATCH 13/18] manifesttripletdatamodule --- .../airtable/get_dataset_paths_example.py | 88 ++++ viscy/airtable/__init__.py | 15 + viscy/airtable/factory.py | 457 ++++++++++++++++++ 3 files changed, 560 insertions(+) create mode 100644 viscy/airtable/factory.py diff --git a/examples/airtable/get_dataset_paths_example.py b/examples/airtable/get_dataset_paths_example.py index 1846a6a74..fc8bc97ff 100644 --- a/examples/airtable/get_dataset_paths_example.py +++ b/examples/airtable/get_dataset_paths_example.py @@ -45,3 +45,91 @@ print(df[["name", "version", "purpose"]].dropna(subset=["name"]).to_string()) # %% +# ============================================================================= +# Create TripletDataModule from manifest using factory function +# ============================================================================= +from viscy.airtable.factory import create_triplet_datamodule_from_manifest + +# Create data module from manifest +dm = create_triplet_datamodule_from_manifest( + manifest=manifest, + source_channel=["Phase3D"], + z_range=(20, 21), + batch_size=1, + num_workers=1, + initial_yx_patch_size=(160, 160), + final_yx_patch_size=(160, 160), + return_negative=False, + time_interval=1, +) + +# %% +# Setup and inspect the data module +dm.setup("fit") +print("\n=== TripletDataModule from Manifest ===") +print(f"Data module type: {type(dm).__name__}") +print(f"Train samples: {len(dm.train_dataset)}") +print(f"Val samples: {len(dm.val_dataset)}") + +# %% +# ============================================================================= +# Alternative: ManifestTripletDataModule (Lightning Config Compatible) +# ============================================================================= +from viscy.airtable.factory import ManifestTripletDataModule + +# This class is designed for Lightning CLI and config files +# but can also be used directly in Python +dm_class = ManifestTripletDataModule( + base_id=BASE_ID, + manifest_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2", + manifest_version="0.0.1", + source_channel=["Phase3D"], + z_range=(20, 21), + batch_size=1, + num_workers=1, + initial_yx_patch_size=(160, 160), + final_yx_patch_size=(160, 160), + return_negative=False, + time_interval=1, +) + +dm_class.setup("fit") +print("\n=== ManifestTripletDataModule (Class) ===") +print(f"Data module type: {type(dm_class).__name__}") +print(f"Train samples: {len(dm_class.train_dataset)}") +print(f"Val samples: {len(dm_class.val_dataset)}") +print("Note: This class is designed for Lightning config files!") + +# %% Visualize some of the images +import matplotlib.pyplot as plt +import torch + +img_stack = [] +for idx, batch in enumerate(dm.train_dataloader()): + img_stack.append(batch["anchor"][0, 0, 0]) + if idx >= 9: + break +img_stack = torch.stack(img_stack) +# %% +# Make subplot with 10 images +fig, axs = plt.subplots(2, 5, figsize=(10, 4)) +for i in range(10): + axs[i // 5, i % 5].imshow(img_stack[i], cmap="gray") + axs[i // 5, i % 5].axis("off") +plt.show() +# %% +# ============================================================================= +# Summary: When 
to use which approach +# ============================================================================= +print("\n=== Summary ===") +print("Use create_triplet_datamodule_from_manifest() when:") +print(" - Working in Python scripts or notebooks") +print(" - Manifest has multiple HCS plates (auto-combines them)") +print(" - Already have manifest object loaded") +print("") +print("Use ManifestTripletDataModule when:") +print(" - Working with Lightning CLI and config files") +print(" - Training with single-plate manifests") +print(" - Want clean YAML configuration") +print("") +print("See examples/airtable/manifest_config_example.yml for config usage") diff --git a/viscy/airtable/__init__.py b/viscy/airtable/__init__.py index e69de29bb..3343168cb 100644 --- a/viscy/airtable/__init__.py +++ b/viscy/airtable/__init__.py @@ -0,0 +1,15 @@ +"""Airtable integration for dataset management and tracking.""" + +from viscy.airtable.database import AirtableManager, Manifest, ManifestDataset +from viscy.airtable.factory import ( + ManifestTripletDataModule, + create_triplet_datamodule_from_manifest, +) + +__all__ = [ + "AirtableManager", + "Manifest", + "ManifestDataset", + "ManifestTripletDataModule", + "create_triplet_datamodule_from_manifest", +] diff --git a/viscy/airtable/factory.py b/viscy/airtable/factory.py new file mode 100644 index 000000000..df7634c2f --- /dev/null +++ b/viscy/airtable/factory.py @@ -0,0 +1,457 @@ +"""Factory functions for creating data modules from Airtable manifests.""" + +import os +from typing import Literal, Sequence + +from lightning.pytorch import LightningDataModule +from monai.transforms import MapTransform + +from viscy.airtable.database import AirtableManager, Manifest, ManifestDataset +from viscy.data.combined import BatchedConcatDataModule, CachedConcatDataModule +from viscy.data.triplet import TripletDataModule + + +def _extract_wells_from_fov_names(fov_names: list[str]) -> list[str]: + """ + Extract unique well IDs from FOV names. + + Parameters + ---------- + fov_names : list[str] + List of FOV names in format "Row/Column/FOV_idx" (e.g., "B/3/0") + + Returns + ------- + list[str] + Unique well IDs in format "Row/Column" (e.g., ["B/3", "C/4"]) + + Examples + -------- + >>> _extract_wells_from_fov_names(["B/3/0", "B/3/1", "C/4/2"]) + ['B/3', 'C/4'] + """ + wells = set() + for fov_name in fov_names: + # Split "B/3/0" -> ["B", "3", "0"] + parts = fov_name.split("/") + if len(parts) >= 2: + # Extract "B/3" + well_id = f"{parts[0]}/{parts[1]}" + wells.add(well_id) + else: + raise ValueError( + f"Invalid FOV name format: '{fov_name}'. 
" + f"Expected 'Row/Column/FOV_idx' (e.g., 'B/3/0')" + ) + + return sorted(list(wells)) + + +def create_triplet_datamodule_from_manifest( + manifest: Manifest | ManifestDataset, + source_channel: str | Sequence[str], + z_range: tuple[int, int], + *, + initial_yx_patch_size: tuple[int, int] = (512, 512), + final_yx_patch_size: tuple[int, int] = (224, 224), + split_ratio: float = 0.8, + batch_size: int = 16, + num_workers: int = 1, + normalizations: list[MapTransform] | None = None, + augmentations: list[MapTransform] | None = None, + augment_validation: bool = True, + caching: bool = False, + fit_include_wells: list[str] | None = None, + fit_exclude_fovs: list[str] | None = None, + predict_cells: bool = False, + include_fov_names: list[str] | None = None, + include_track_ids: list[int] | None = None, + time_interval: Literal["any"] | int = "any", + return_negative: bool = True, + persistent_workers: bool = False, + prefetch_factor: int | None = None, + pin_memory: bool = False, + z_window_size: int | None = None, + cache_pool_bytes: int = 0, + use_cached_concat: bool = False, +) -> LightningDataModule: + """ + Create TripletDataModule(s) from Airtable manifest. + + Automatically handles single or multiple HCS plates: + - Single plate: Returns TripletDataModule + - Multiple plates: Returns BatchedConcatDataModule or CachedConcatDataModule + + Parameters + ---------- + manifest : Manifest | ManifestDataset + Manifest from AirtableManager.get_dataset_paths() or single ManifestDataset + source_channel : str | Sequence[str] + Input channel name(s) - REQUIRED + z_range : tuple[int, int] + Range of valid z-slices - REQUIRED + initial_yx_patch_size : tuple[int, int] + YX size of initially sampled patch, default (512, 512) + final_yx_patch_size : tuple[int, int] + Output patch size after augmentation, default (224, 224) + split_ratio : float + Train/val split ratio, default 0.8 + batch_size : int + Batch size, default 16 + num_workers : int + Number of dataloader workers, default 1 + normalizations : list[MapTransform] | None + Normalization transforms + augmentations : list[MapTransform] | None + Augmentation transforms + augment_validation : bool + Apply augmentations to validation, default True + caching : bool + Cache dataset, default False + fit_include_wells : list[str] | None + Override manifest FOVs with specific wells (e.g., ["B/3", "C/4"]). 
+ Takes precedence over manifest.fov_names + fit_exclude_fovs : list[str] | None + Exclude specific FOV paths from manifest + predict_cells : bool + Predict on specific cells only, default False + include_fov_names : list[str] | None + FOV names for prediction (when predict_cells=True) + include_track_ids : list[int] | None + Track IDs for prediction (when predict_cells=True) + time_interval : Literal["any"] | int + Time interval for positive sampling, default "any" + return_negative : bool + Return negative samples for triplet loss, default True + persistent_workers : bool + Keep workers alive between epochs, default False + prefetch_factor : int | None + Batches to prefetch per worker + pin_memory : bool + Pin memory for faster GPU transfer, default False + z_window_size : int | None + Final Z window size (inferred from z_range if None) + cache_pool_bytes : int + Tensorstore cache pool size in bytes, default 0 + use_cached_concat : bool + Use CachedConcatDataModule instead of BatchedConcatDataModule + for multi-plate manifests, default False + + Returns + ------- + LightningDataModule + - TripletDataModule for single plate + - BatchedConcatDataModule for multiple plates (default) + - CachedConcatDataModule for multiple plates (if use_cached_concat=True) + + Raises + ------ + ValueError + - If manifest has no datasets + - If paths don't exist (validation fails) + - If fit_include_wells and manifest both specify FOVs (ambiguous) + FileNotFoundError + If data_path or tracks_path don't exist + TypeError + If manifest is not Manifest or ManifestDataset + + Examples + -------- + Basic usage with single-plate manifest: + + >>> from viscy.airtable.database import AirtableManager + >>> from viscy.airtable.factory import create_triplet_datamodule_from_manifest + >>> + >>> airtable_db = AirtableManager(base_id="appXXXXXXXXXXXXXX") + >>> manifest = airtable_db.get_dataset_paths("my_manifest", "0.0.1") + >>> + >>> dm = create_triplet_datamodule_from_manifest( + ... manifest=manifest, + ... source_channel=["Phase3D"], + ... z_range=(0, 5), + ... batch_size=32, + ... num_workers=8, + ... ) + >>> + >>> # Use with PyTorch Lightning + >>> trainer.fit(model, dm) + + Multi-plate manifest with normalization: + + >>> from viscy.transforms import NormalizeSampled + >>> + >>> manifest = airtable_db.get_dataset_paths("multi_plate_manifest", "1.0.0") + >>> print(f"Manifest has {len(manifest)} plates") # e.g., 3 plates + >>> + >>> dm = create_triplet_datamodule_from_manifest( + ... manifest=manifest, + ... source_channel=["Phase3D", "RFP"], + ... z_range=(0, 10), + ... normalizations=[ + ... NormalizeSampled( + ... ["Phase3D"], + ... level="fov_statistics", + ... subtrahend="mean", + ... divisor="std" + ... ) + ... ], + ... batch_size=16, + ... use_cached_concat=False, # Use BatchedConcatDataModule + ... ) + >>> # Returns BatchedConcatDataModule wrapping 3 TripletDataModules + + Override manifest FOVs with specific wells: + + >>> dm = create_triplet_datamodule_from_manifest( + ... manifest=manifest, + ... source_channel=["Phase3D"], + ... z_range=(0, 5), + ... fit_include_wells=["B/3", "B/4"], # Override manifest FOVs + ... batch_size=16, + ... ) + + Using a single ManifestDataset directly: + + >>> ds = manifest.datasets[0] # Single plate + >>> dm = create_triplet_datamodule_from_manifest( + ... manifest=ds, # Pass ManifestDataset directly + ... source_channel=["Phase3D"], + ... z_range=(0, 5), + ... batch_size=16, + ... 
) + + Notes + ----- + - FOV filtering priority: fit_include_wells > manifest.fov_names + - If both specified, raises ValueError to avoid ambiguity + - The factory validates paths before creating data modules + - Multi-plate handling uses BatchedConcatDataModule for efficient batching + - All TripletDataModule parameters can be passed through kwargs + """ + # STEP 1: Normalize input - handle both Manifest and ManifestDataset + if isinstance(manifest, ManifestDataset): + datasets = [manifest] + manifest_name = "single_dataset" + elif isinstance(manifest, Manifest): + if len(manifest.datasets) == 0: + raise ValueError(f"Manifest '{manifest.name}' has no datasets") + datasets = manifest.datasets + manifest_name = manifest.name + else: + raise TypeError( + f"Expected Manifest or ManifestDataset, got {type(manifest).__name__}" + ) + + # STEP 2: Validate all paths exist (fail early) + for i, ds in enumerate(datasets): + try: + ds.validate() + except FileNotFoundError as e: + raise FileNotFoundError(f"Manifest '{manifest_name}' dataset {i}: {e}") + + # STEP 3: Handle FOV filtering logic + # Check for ambiguous FOV specification + has_manifest_fovs = any(len(ds.fov_names) > 0 for ds in datasets) + + if fit_include_wells is not None and has_manifest_fovs: + # Ambiguous: both manifest and user specified FOVs + raise ValueError( + "Cannot specify both 'fit_include_wells' and use manifest FOV filtering. " + "The manifest already specifies FOVs to include. " + "Either:\n" + " 1. Use fit_include_wells=None to respect manifest FOVs, OR\n" + " 2. Create a new manifest without FOV filtering if you want custom wells" + ) + + # STEP 4: Ensure normalizations and augmentations are lists + normalizations = normalizations or [] + augmentations = augmentations or [] + + # STEP 5: Create TripletDataModule for each dataset + data_modules = [] + + for ds in datasets: + # Determine well filtering strategy + if fit_include_wells is not None: + # User override: use explicit wells + include_wells = fit_include_wells + elif len(ds.fov_names) > 0: + # Convert manifest FOV names to well IDs + include_wells = _extract_wells_from_fov_names(ds.fov_names) + else: + # No filtering: use all wells + include_wells = None + + # Create TripletDataModule + dm = TripletDataModule( + data_path=ds.data_path, + tracks_path=ds.tracks_path, + source_channel=source_channel, + z_range=z_range, + initial_yx_patch_size=initial_yx_patch_size, + final_yx_patch_size=final_yx_patch_size, + split_ratio=split_ratio, + batch_size=batch_size, + num_workers=num_workers, + normalizations=normalizations, + augmentations=augmentations, + augment_validation=augment_validation, + caching=caching, + fit_include_wells=include_wells, + fit_exclude_fovs=fit_exclude_fovs, + predict_cells=predict_cells, + include_fov_names=include_fov_names, + include_track_ids=include_track_ids, + time_interval=time_interval, + return_negative=return_negative, + persistent_workers=persistent_workers, + prefetch_factor=prefetch_factor, + pin_memory=pin_memory, + z_window_size=z_window_size, + cache_pool_bytes=cache_pool_bytes, + ) + data_modules.append(dm) + + # STEP 6: Return appropriate data module type + if len(data_modules) == 1: + # Single plate: return TripletDataModule directly + return data_modules[0] + else: + # Multiple plates: wrap in ConcatDataModule + if use_cached_concat: + return CachedConcatDataModule(data_modules=data_modules) + else: + return BatchedConcatDataModule(data_modules=data_modules) + + +class ManifestTripletDataModule(TripletDataModule): + """ + 
TripletDataModule that fetches paths from Airtable manifests. + + This class is designed to work with PyTorch Lightning CLI and config files. + It extends TripletDataModule to accept Airtable manifest parameters instead + of explicit data_path and tracks_path. + + Parameters + ---------- + base_id : str + Airtable base ID + manifest_name : str + Name of the manifest in Airtable + manifest_version : str + Semantic version of the manifest (e.g., "0.0.1") + source_channel : str | Sequence[str] + Input channel name(s) + z_range : tuple[int, int] + Range of valid z-slices + api_key : str | None + Airtable API key (if None, reads from AIRTABLE_API_KEY env var) + **kwargs + All other parameters passed to TripletDataModule.__init__ + + Raises + ------ + ValueError + If manifest has multiple datasets (only single-plate manifests supported) + + Examples + -------- + In a Lightning config file (config.yml): + + ```yaml + data: + class_path: viscy.airtable.factory.ManifestTripletDataModule + init_args: + base_id: "appXXXXXXXXXXXXXX" + manifest_name: "my_manifest" + manifest_version: "0.0.1" + source_channel: [Phase] + z_range: [0, 5] + batch_size: 16 + num_workers: 8 + normalizations: + - class_path: viscy.transforms.NormalizeSampled + init_args: + keys: [Phase] + level: fov_statistics + subtrahend: mean + divisor: std + ``` + + Command line usage: + ```bash + viscy fit -c config.yml + ``` + + Direct usage in Python: + ```python + dm = ManifestTripletDataModule( + base_id="appXXXXXXXXXXXXXX", + manifest_name="my_manifest", + manifest_version="0.0.1", + source_channel=["Phase"], + z_range=(0, 5), + batch_size=16, + ) + trainer.fit(model, dm) + ``` + + Notes + ----- + - Only supports single-plate manifests (use create_triplet_datamodule_from_manifest + for multi-plate support with BatchedConcatDataModule) + - Fetches manifest from Airtable during __init__ + - All TripletDataModule parameters are available + - FOV filtering from manifest is automatically applied via fit_include_wells + """ + + def __init__( + self, + base_id: str, + manifest_name: str, + manifest_version: str, + source_channel: str | Sequence[str], + z_range: tuple[int, int], + api_key: str | None = None, + fit_include_wells: list[str] | None = None, + **kwargs, + ): + # Fetch manifest from Airtable + airtable_db = AirtableManager( + base_id=base_id, api_key=api_key or os.getenv("AIRTABLE_API_KEY") + ) + manifest = airtable_db.get_dataset_paths( + manifest_name=manifest_name, + version=manifest_version, + ) + + # Validate single plate + if len(manifest.datasets) != 1: + raise ValueError( + f"ManifestTripletDataModule only supports single-plate manifests. " + f"Manifest '{manifest_name}' has {len(manifest.datasets)} plates. " + f"Use create_triplet_datamodule_from_manifest() for multi-plate support." 
+ ) + + dataset = manifest.datasets[0] + + # Handle FOV filtering + if fit_include_wells is not None: + # User override: use explicit wells + include_wells = fit_include_wells + elif len(dataset.fov_names) > 0: + # Convert manifest FOV names to well IDs + include_wells = _extract_wells_from_fov_names(dataset.fov_names) + else: + # No filtering: use all wells + include_wells = None + + # Initialize parent TripletDataModule with extracted paths + super().__init__( + data_path=dataset.data_path, + tracks_path=dataset.tracks_path, + source_channel=source_channel, + z_range=z_range, + fit_include_wells=include_wells, + **kwargs, + ) From 7c4e01ec45b7b7ec6ff25547dfd892c65eec5fc0 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 7 Jan 2026 16:40:39 -0800 Subject: [PATCH 14/18] closing the loop to use airtable and start a dynaclr model --- examples/airtable/train_with_wandb.yml | 75 ++++++++++++++++++++++++++ viscy/airtable/__init__.py | 2 + viscy/airtable/callbacks.py | 74 ++++++++++++++++++++++++- viscy/airtable/factory.py | 7 +++ viscy/representation/engine.py | 19 +++++-- 5 files changed, 173 insertions(+), 4 deletions(-) create mode 100644 examples/airtable/train_with_wandb.yml diff --git a/examples/airtable/train_with_wandb.yml b/examples/airtable/train_with_wandb.yml new file mode 100644 index 000000000..fb9332a00 --- /dev/null +++ b/examples/airtable/train_with_wandb.yml @@ -0,0 +1,75 @@ +# Example Lightning config using ManifestTripletDataModule with W&B tracking +# This config fetches dataset paths from Airtable manifests and logs to W&B +# Usage: viscy fit -c examples/airtable/train_with_wandb.yml + +seed_everything: 42 + +trainer: + accelerator: gpu + devices: 1 + max_epochs: 100 + check_val_every_n_epoch: 1 + log_every_n_steps: 50 + default_root_dir: logs/wandb # W&B will create versioned subdirs + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: viscy-experiments + entity: null # Set to your W&B team name or leave null for personal + log_model: false # Don't upload checkpoints to W&B (too large) + save_dir: ./logs # Explicit save directory for config files + name: null # Auto-generate unique run name + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: loss/val + save_top_k: 3 + save_last: true + - class_path: viscy.airtable.callbacks.ManifestWandbCallback + # No init_args needed - automatically logs manifest metadata + +model: + class_path: viscy.representation.engine.ContrastiveModule + init_args: + encoder: + class_path: viscy.representation.contrastive.ContrastiveEncoder + init_args: + backbone: convnext_tiny + in_channels: 1 + in_stack_depth: 5 + stem_kernel_size: [5, 4, 4] + stem_stride: [5, 4, 4] + embedding_dim: 768 + projection_dim: 32 + loss_function: + class_path: pytorch_metric_learning.losses.NTXentLoss + init_args: + temperature: 0.05 + lr: 0.0002 + log_batches_per_epoch: 3 + log_samples_per_batch: 3 + example_input_array_shape: [1, 1, 5, 160, 160] + + +data: + # NEW: Use ManifestTripletDataModule for Airtable manifest integration + class_path: viscy.airtable.factory.ManifestTripletDataModule + init_args: + # Airtable manifest parameters + base_id: "app8vqaoWyOwa0sB5" # Replace with your base ID + manifest_name: "2024_11_07_A549_SEC61_DENV_wells_B1_B2" + manifest_version: "0.0.1" + + # TripletDataModule parameters + source_channel: [Phase3D] + z_range: [10, 15] + z_window_size: 5 + initial_yx_patch_size: [160, 160] + final_yx_patch_size: [160, 160] + batch_size: 16 + 
num_workers: 1 + split_ratio: 0.8 + time_interval: any + + # Optional: Override manifest FOVs with specific wells + fit_include_wells: ["B/1", "B/2"] diff --git a/viscy/airtable/__init__.py b/viscy/airtable/__init__.py index 3343168cb..e4ed3b9ca 100644 --- a/viscy/airtable/__init__.py +++ b/viscy/airtable/__init__.py @@ -1,5 +1,6 @@ """Airtable integration for dataset management and tracking.""" +from viscy.airtable.callbacks import ManifestWandbCallback from viscy.airtable.database import AirtableManager, Manifest, ManifestDataset from viscy.airtable.factory import ( ManifestTripletDataModule, @@ -11,5 +12,6 @@ "Manifest", "ManifestDataset", "ManifestTripletDataModule", + "ManifestWandbCallback", "create_triplet_datamodule_from_manifest", ] diff --git a/viscy/airtable/callbacks.py b/viscy/airtable/callbacks.py index e6093b29f..6f11a34ef 100644 --- a/viscy/airtable/callbacks.py +++ b/viscy/airtable/callbacks.py @@ -6,7 +6,7 @@ from lightning.pytorch import Trainer from lightning.pytorch.callbacks import Callback -from viscy.airtable.datasets import AirtableManager +from viscy.airtable.database import AirtableManager class AirtableLoggingCallback(Callback): @@ -118,3 +118,75 @@ def on_fit_end(self, trainer: Trainer, pl_module: Any) -> None: except Exception as e: print(f"\n✗ Failed to log to Airtable: {e}") # Don't fail training if Airtable logging fails + + +class ManifestWandbCallback(Callback): + """ + Log manifest metadata to Weights & Biases automatically. + + This callback extracts manifest information from ManifestTripletDataModule + and logs it to W&B config for searchability and lineage tracking. + + Examples + -------- + Add to config YAML: + + >>> trainer: + >>> logger: + >>> class_path: lightning.pytorch.loggers.WandbLogger + >>> init_args: + >>> project: viscy-experiments + >>> log_model: false + >>> callbacks: + >>> - class_path: viscy.airtable.callbacks.ManifestWandbCallback + + Or add programmatically: + + >>> from lightning.pytorch.loggers import WandbLogger + >>> logger = WandbLogger(project="viscy-experiments") + >>> callback = ManifestWandbCallback() + >>> trainer = Trainer(logger=logger, callbacks=[callback]) + """ + + def on_train_start(self, trainer: Trainer, pl_module: Any) -> None: + """Log manifest metadata to W&B config at training start.""" + # Import here to avoid requiring wandb as a dependency + try: + from lightning.pytorch.loggers import WandbLogger + except ImportError: + return # Skip if wandb not installed + + # Check if using WandbLogger + if not isinstance(trainer.logger, WandbLogger): + return + + # Check if using ManifestTripletDataModule + from viscy.airtable.factory import ManifestTripletDataModule + + dm = trainer.datamodule + + # Log manifest metadata if using ManifestTripletDataModule + if isinstance(dm, ManifestTripletDataModule): + manifest_config = { + "manifest/name": dm.manifest_name, + "manifest/version": dm.manifest_version, + "manifest/base_id": dm.base_id, + "manifest/data_path": str(dm.data_path), + "manifest/tracks_path": str(dm.tracks_path), + } + trainer.logger.experiment.config.update(manifest_config) + + print("\n✓ Manifest metadata logged to W&B:") + print(f" Manifest: {dm.manifest_name} v{dm.manifest_version}") + print(f" Data path: {dm.data_path}") + print(f" Tracks path: {dm.tracks_path}") + + # Also log data module hyperparameters explicitly + if dm is not None and hasattr(dm, "hparams"): + data_config = {f"data/{k}": v for k, v in dm.hparams.items()} + trainer.logger.experiment.config.update(data_config, 
allow_val_change=True) + + # Log model hyperparameters explicitly + if hasattr(pl_module, "hparams"): + model_config = {f"model/{k}": v for k, v in pl_module.hparams.items()} + trainer.logger.experiment.config.update(model_config, allow_val_change=True) diff --git a/viscy/airtable/factory.py b/viscy/airtable/factory.py index df7634c2f..5327467d2 100644 --- a/viscy/airtable/factory.py +++ b/viscy/airtable/factory.py @@ -435,6 +435,13 @@ def __init__( dataset = manifest.datasets[0] + # Store manifest metadata as instance attributes for callbacks/logging + self.base_id = base_id + self.manifest_name = manifest_name + self.manifest_version = manifest_version + self.data_path = dataset.data_path + self.tracks_path = dataset.tracks_path + # Handle FOV filtering if fit_include_wells is not None: # User override: use explicit wells diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 6c2d243d9..5a87a9abf 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -168,9 +168,22 @@ def _log_metrics( def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): grid = render_images(imgs, cmaps=["gray"] * 3) - self.logger.experiment.add_image( - key, grid, self.current_epoch, dataformats="HWC" - ) + + # Handle different logger types + if hasattr(self.logger, "experiment"): + # Check if TensorBoard logger + if hasattr(self.logger.experiment, "add_image"): + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) + # Check if WandB logger + # FIXME this is just temporary fix to get the samples logged to WandB + elif hasattr(self.logger.experiment, "log"): + import wandb + + self.logger.experiment.log( + {key: wandb.Image(grid), "epoch": self.current_epoch} + ) def _log_step_samples(self, batch_idx, samples, stage: Literal["train", "val"]): """Common method for logging step samples""" From 8ee1899ee7011f7d4cd372157d6a3069effd24df Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 9 Jan 2026 10:16:39 -0800 Subject: [PATCH 15/18] add pydantic for the datasets and demo script to add programmatically new datasets --- examples/airtable/test_pydantic_airtable.py | 84 ++++++++++ viscy/airtable/__init__.py | 10 ++ viscy/airtable/database.py | 149 ++++++++++++----- viscy/airtable/schemas.py | 170 ++++++++++++++++++++ 4 files changed, 370 insertions(+), 43 deletions(-) create mode 100644 examples/airtable/test_pydantic_airtable.py create mode 100644 viscy/airtable/schemas.py diff --git a/examples/airtable/test_pydantic_airtable.py b/examples/airtable/test_pydantic_airtable.py new file mode 100644 index 000000000..5f83f128b --- /dev/null +++ b/examples/airtable/test_pydantic_airtable.py @@ -0,0 +1,84 @@ +"""Test Pydantic dataset registration with Airtable.""" + +from viscy.airtable import AirtableManager, DatasetRecord + +BASE_ID = "app8vqaoWyOwa0sB5" + +print("=" * 70) +print("Testing Pydantic + Airtable Integration") +print("=" * 70) + +airtable_db = AirtableManager(base_id=BASE_ID) + +# %% +# Test 1: Single dataset registration +print("\n[Test 1] Register single dataset") +print("-" * 70) + +dataset = DatasetRecord( + dataset_name="pydantic_test_plate", + well_id="A_1", + fov_name="0", + data_path="/hpc/data/pydantic_test.zarr/A/1/0", + cell_type="A549", + organelle="SEC61B", +) + +try: + record_id = airtable_db.register_dataset(dataset) + print(f"✓ Registered: {record_id}") + print(f" Dataset: {dataset.dataset_name}") + print(f" Well: {dataset.well_id}, FOV: {dataset.fov_name}") + print(f" Path: 
{dataset.data_path}") +except Exception as e: + print(f"✗ Failed: {e}") + +# %% +# Test 2: Bulk dataset registration +print("\n[Test 2] Register multiple datasets") +print("-" * 70) + +datasets = [ + DatasetRecord( + dataset_name="pydantic_test_plate", + well_id=f"B_{well}", + fov_name=str(fov), + data_path=f"/hpc/data/pydantic_test.zarr/B/{well}/{fov}", + cell_type="A549", + ) + for well in range(1, 3) + for fov in range(2) +] + +try: + record_ids = airtable_db.register_datasets(datasets) + print(f"✓ Registered {len(record_ids)} datasets") + for ds, rec_id in zip(datasets, record_ids): + print(f" {ds.dataset_name}_{ds.well_id}_{ds.fov_name} -> {rec_id}") +except Exception as e: + print(f"✗ Failed: {e}") + +# %% +# Test 3: List datasets as Pydantic models +print("\n[Test 3] List datasets as Pydantic models") +print("-" * 70) + +try: + all_datasets = airtable_db.list_datasets(as_pydantic=True) + print(f"Total datasets in Airtable: {len(all_datasets)}") + + # Filter for our test datasets + test_datasets = [ds for ds in all_datasets if ds.fov_id.startswith("pydantic_test")] + print(f"\nTest datasets found: {len(test_datasets)}") + for ds in test_datasets: + print(f" - {ds.fov_id}") + print(f" Dataset: {ds.dataset_name}") + print(f" Cell type: {ds.cell_type}") + print(f" Record ID: {ds.record_id}") + +except Exception as e: + print(f"✗ Failed: {e}") + +print("\n" + "=" * 70) +print("Airtable integration test complete!") +print("=" * 70) diff --git a/viscy/airtable/__init__.py b/viscy/airtable/__init__.py index e4ed3b9ca..04de533d1 100644 --- a/viscy/airtable/__init__.py +++ b/viscy/airtable/__init__.py @@ -6,12 +6,22 @@ ManifestTripletDataModule, create_triplet_datamodule_from_manifest, ) +from viscy.airtable.register_model import ( + list_registered_models, + load_model_from_registry, + register_model, +) +from viscy.airtable.schemas import DatasetRecord __all__ = [ "AirtableManager", + "DatasetRecord", "Manifest", "ManifestDataset", "ManifestTripletDataModule", "ManifestWandbCallback", "create_triplet_datamodule_from_manifest", + "register_model", + "load_model_from_registry", + "list_registered_models", ] diff --git a/viscy/airtable/database.py b/viscy/airtable/database.py index de5af4fdc..be580a471 100644 --- a/viscy/airtable/database.py +++ b/viscy/airtable/database.py @@ -11,6 +11,8 @@ from natsort import natsorted from pyairtable import Api +from viscy.airtable.schemas import DatasetRecord + @dataclass class ManifestDataset: @@ -169,46 +171,78 @@ def __init__( self.manifests_table = self.api.table(base_id, "Manifest") self.models_table = self.api.table(base_id, "Models") - def register_dataset( - self, - fov_id: str, - dataset_name: str, - well_id: str, - fov_name: str, - data_path: str, - ) -> str: + def register_dataset(self, dataset: DatasetRecord) -> str: """ Register a single dataset record (FOV) in Airtable. 
Parameters ---------- - fov_id : str - Human-readable identifier (e.g., "RPE1_plate1_B_3_0") - dataset_name : str - Name of the dataset/plate this FOV belongs to - well_id : str - Well identifier as row_column (e.g., "B_3") - fov_name : str - FOV index within well (e.g., "0", "1", "2") - data_path : str - Full path to FOV (e.g., "/hpc/data/plate.zarr/B/3/0") + dataset : DatasetRecord + Dataset record with FOV metadata Returns ------- str Airtable record ID - """ - record = { - "FOV_ID": fov_id, - "Dataset": dataset_name, - "Well ID": well_id, - "FOV": fov_name, - "Data path": data_path, - } - created = self.datasets_table.create(record) + Examples + -------- + >>> from viscy.airtable.schemas import DatasetRecord + >>> dataset = DatasetRecord( + ... fov_id="plate_B_3_0", + ... dataset_name="plate", + ... well_id="B_3", + ... fov_name="0", + ... data_path="/hpc/data/plate.zarr/B/3/0" + ... ) + >>> record_id = airtable_db.register_dataset(dataset) + """ + record_dict = dataset.to_airtable_dict() + created = self.datasets_table.create(record_dict) return created["id"] + def register_datasets(self, datasets: list[DatasetRecord]) -> list[str]: + """ + Register multiple dataset records (FOVs) in Airtable. + + Parameters + ---------- + datasets : list[DatasetRecord] + List of dataset records to register + + Returns + ------- + list[str] + List of Airtable record IDs + + Examples + -------- + >>> from viscy.airtable.schemas import DatasetRecord + >>> datasets = [ + ... DatasetRecord( + ... fov_id="plate_B_3_0", + ... dataset_name="plate", + ... well_id="B_3", + ... fov_name="0", + ... data_path="/hpc/data/plate.zarr/B/3/0" + ... ), + ... DatasetRecord( + ... fov_id="plate_B_3_1", + ... dataset_name="plate", + ... well_id="B_3", + ... fov_name="1", + ... data_path="/hpc/data/plate.zarr/B/3/1" + ... ), + ... ] + >>> record_ids = airtable_db.register_datasets(datasets) + """ + record_ids = [] + for dataset in datasets: + record_dict = dataset.to_airtable_dict() + created = self.datasets_table.create(record_dict) + record_ids.append(created["id"]) + return record_ids + def create_manifest_from_datasets( self, manifest_name: str, @@ -556,9 +590,14 @@ def list_manifests( data = [d for d in data if d.get("purpose") == purpose] return data - def list_datasets(self, as_dataframe: bool = True) -> pd.DataFrame | list[dict]: + def list_datasets( + self, + as_dataframe: bool = True, + as_pydantic: bool = False, + skip_invalid: bool = True, + ) -> pd.DataFrame | list[dict] | list[DatasetRecord]: """ - Get all dataset records (FOVs) as a DataFrame (or list of dicts). + Get all dataset records (FOVs) as a DataFrame, list of dicts, or Pydantic models. Use pandas for filtering - much simpler and more powerful than building Airtable formulas. @@ -566,16 +605,21 @@ def list_datasets(self, as_dataframe: bool = True) -> pd.DataFrame | list[dict]: Parameters ---------- as_dataframe : bool - If True, return pandas DataFrame. If False, return list of dicts. + If True, return pandas DataFrame. Ignored if as_pydantic is True. + as_pydantic : bool + If True, return list of DatasetRecord objects. Takes precedence over as_dataframe. + skip_invalid : bool + If True and as_pydantic=True, skip records that fail validation instead of raising error. + Default is True to handle legacy/incomplete records gracefully. 
Returns ------- - pd.DataFrame | list[dict] + pd.DataFrame | list[dict] | list[DatasetRecord] All dataset records Examples -------- - >>> # Get all datasets + >>> # Get all datasets as DataFrame >>> df = airtable_db.list_datasets() >>> >>> # Filter with pandas (simple and powerful!) @@ -583,11 +627,30 @@ def list_datasets(self, as_dataframe: bool = True) -> pd.DataFrame | list[dict]: >>> filtered = df[df['Well ID'].isin(['B_3', 'B_4'])] >>> filtered = df[~df['FOV_ID'].isin(['RPE1_plate1_B_3_2'])] >>> + >>> # Get as Pydantic models for type safety + >>> datasets = airtable_db.list_datasets(as_pydantic=True) + >>> for dataset in datasets: + ... print(dataset.fov_id, dataset.data_path) + >>> >>> # Group and analyze >>> df.groupby('Dataset').size() >>> df.groupby('Well ID').size() """ records = self.datasets_table.all() + + if as_pydantic: + parsed_records = [] + for r in records: + try: + parsed_records.append(DatasetRecord.from_airtable_record(r)) + except Exception: + if skip_invalid: + # Skip invalid records silently + continue + else: + raise + return parsed_records + data = [{"id": r["id"], **r["fields"]} for r in records] if as_dataframe: @@ -620,7 +683,7 @@ def delete_manifest(self, manifest_id: str) -> bool: def log_model_training( self, manifest_id: str, - mlflow_run_id: str, + wandb_run_id: str, model_name: str | None = None, metrics: dict[str, float] | None = None, checkpoint_path: str | None = None, @@ -635,12 +698,12 @@ def log_model_training( ---------- manifest_id : str Airtable record ID of manifest used - mlflow_run_id : str - MLflow run ID for experiment tracking + wandb_run_id : str + W&B run ID for experiment tracking model_name : str | None Human-readable model name metrics : dict | None - Training metrics (e.g., {"accuracy": 0.89, "f1_score": 0.92}) + Training metrics (e.g., {"val_loss": 0.15, "dice": 0.92}) checkpoint_path : str | None Path to saved model checkpoint trained_by : str | None @@ -656,17 +719,17 @@ def log_model_training( >>> manifest_id = airtable_db.create_manifest_from_datasets(...) >>> model_id = airtable_db.log_model_training( ... manifest_id=manifest_id, - ... mlflow_run_id="run_abc123", - ... model_name="sec61_model_v1", + ... wandb_run_id="20260107-152420", + ... model_name="contrastive-a549:v1", ... metrics={"val_loss": 0.15}, - ... trained_by="researcher_name" + ... trained_by="eduardo.hirata" ... 
) """ # Create model record model_record = { "model_name": model_name or f"model_{datetime.now():%Y%m%d_%H%M%S}", "manifest": [manifest_id], # Link to manifest - "mlflow_run_id": mlflow_run_id, + "wandb_run_id": wandb_run_id, "trained_date": datetime.now().isoformat(), } @@ -688,10 +751,10 @@ def log_model_training( # Handle models_trained as comma-separated string if models_trained_str: models_list = [m.strip() for m in models_trained_str.split(",")] - models_list.append(mlflow_run_id) + models_list.append(wandb_run_id) new_models_str = ", ".join(models_list) else: - new_models_str = mlflow_run_id + new_models_str = wandb_run_id self.manifests_table.update( manifest_id, @@ -891,6 +954,6 @@ def _get_manifest_dataset_ids(self, manifest_name: str, version: str) -> list[st def update_record( self, - ): + ): # TODO: to update the tracks path column raise NotImplementedError("Not implemented yet") diff --git a/viscy/airtable/schemas.py b/viscy/airtable/schemas.py new file mode 100644 index 000000000..b5340b971 --- /dev/null +++ b/viscy/airtable/schemas.py @@ -0,0 +1,170 @@ +"""Pydantic schemas for Airtable records.""" + +from typing import Any + +from pydantic import BaseModel, Field, field_validator + + +class DatasetRecord(BaseModel): + """ + Pydantic model for a dataset record (FOV) in Airtable. + + This represents a single field of view (FOV) from an HCS plate, + with metadata about the biological sample and data location. + + Note + ---- + FOV_ID is a computed field in Airtable (formula: Dataset + "_" + Well ID + "_" + FOV) + and is automatically generated. You can read it but cannot set it during creation. + """ + + # Required fields for creation (but optional when reading from Airtable due to legacy data) + dataset_name: str = Field( + ..., + description="Name of the dataset/plate this FOV belongs to", + alias="Dataset", + ) + well_id: str = Field( + ..., + description="Well identifier (e.g., 'B_3' or 'B/3')", + alias="Well ID", + ) + fov_name: str = Field( + ..., description="FOV index within well (e.g., '0', '1', '2')", alias="FOV" + ) + data_path: str | None = Field( + None, + description="Full path to FOV data (e.g., '/hpc/data/plate.zarr/B/3/0'). 
Required for new records.", + alias="Data path", + ) + + # Computed field (read-only, auto-generated by Airtable) + fov_id: str | None = Field( + None, + description="Computed FOV identifier (Dataset_WellID_FOV) - auto-generated by Airtable", + alias="FOV_ID", + exclude=True, # Don't include in to_airtable_dict() by default + ) + + # Optional metadata fields + cell_type: str | None = Field( + None, description="Cell type (e.g., 'RPE1', 'A549')", alias="Cell type" + ) + cell_state: str | None = Field( + None, + description="Cell state (e.g., 'healthy', 'infected')", + alias="Cell state", + ) + cell_line: list[str] | None = Field( + None, + description="Cell line identifiers (Airtable array field)", + alias="Cell line", + ) + organelle: str | None = Field( + None, description="Target organelle (e.g., 'SEC61B')", alias="Organelle" + ) + channel_0: str | None = Field( + None, description="Channel 0 name (e.g., 'Phase')", alias="Channel-0" + ) + channel_1: str | None = Field( + None, description="Channel 1 name (e.g., 'GFP')", alias="Channel-1" + ) + channel_2: str | None = Field( + None, description="Channel 2 name (e.g., 'RFP')", alias="Channel-2" + ) + fluorescence_modality: str | None = Field( + None, + description="Fluorescence modality (e.g., 'confocal')", + alias="Fluorescence modality", + ) + organellebox_infectomics: bool | None = Field( + None, + description="Part of OrganelleBox Infectomics dataset", + alias="OrganelleBox Infectomics", + ) + + # Airtable record ID (not used for creation, but for retrieval) + record_id: str | None = Field(None, description="Airtable record ID", alias="id") + + model_config = { + "populate_by_name": True, # Allow both field name and alias + "str_strip_whitespace": True, # Strip whitespace from strings + } + + @field_validator("dataset_name", "well_id", "fov_name") + @classmethod + def no_empty_strings(cls, v: str) -> str: + """Ensure required string fields are not empty.""" + if not v or not v.strip(): + raise ValueError("Field cannot be empty") + return v.strip() + + @field_validator("fov_id") + @classmethod + def validate_fov_id(cls, v: str | None) -> str | None: + """Validate FOV ID if provided (it's computed, so usually None on creation).""" + if v is not None and (not v or not v.strip()): + raise ValueError("FOV ID cannot be empty string if provided") + return v.strip() if v else None + + @field_validator("data_path") + @classmethod + def validate_data_path(cls, v: str | None) -> str | None: + """Validate data path format if provided.""" + if v is None: + return None + if not v.strip(): + raise ValueError("Data path cannot be empty string") + # Don't check if path exists - it might be on a different machine + # Just ensure it's a reasonable path format + if not v.startswith("/"): + raise ValueError("Data path must be an absolute path (start with /)") + return v.strip() + + def to_airtable_dict(self) -> dict[str, Any]: + """ + Convert to Airtable-compatible dictionary. + + Returns + ------- + dict + Dictionary with Airtable field names (aliases), excluding None values + and the record_id field. + """ + # Use by_alias=True to get Airtable field names + # exclude_none=True to skip optional fields that weren't set + return self.model_dump( + by_alias=True, exclude_none=True, exclude={"id", "record_id"} + ) + + @classmethod + def from_airtable_record(cls, record: dict[str, Any]) -> "DatasetRecord": + """ + Create DatasetRecord from Airtable API response. 
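+
+        Unknown Airtable columns are ignored (pydantic's default ``extra``
+        behavior), and records missing required fields raise a
+        ``ValidationError``, which ``list_datasets(skip_invalid=True)``
+        catches and skips.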
+ + Parameters + ---------- + record : dict + Airtable record with 'id' and 'fields' keys + + Returns + ------- + DatasetRecord + Parsed dataset record + """ + return cls(id=record["id"], **record["fields"]) + + def __str__(self) -> str: + """Human-readable string representation.""" + identifier = ( + self.fov_id or f"{self.dataset_name}_{self.well_id}_{self.fov_name}" + ) + path_str = self.data_path or "no path" + return f"Dataset({identifier} @ {path_str})" + + def __repr__(self) -> str: + """Detailed string representation.""" + return ( + f"DatasetRecord(dataset_name={self.dataset_name!r}, well_id={self.well_id!r}, " + f"fov_name={self.fov_name!r}, fov_id={self.fov_id!r})" + ) From d18e294f028d5ffd1e5ff5de32cd27363b4fb885 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 12 Jan 2026 16:24:46 -0800 Subject: [PATCH 16/18] renaming manifest to collection --- .../airtable/filter_n_create_dataset_tag.py | 18 +- .../airtable/get_dataset_paths_example.py | 77 ++-- examples/airtable/train_with_wandb.yml | 20 +- viscy/airtable/__init__.py | 21 +- viscy/airtable/callbacks.py | 56 +-- viscy/airtable/database.py | 334 +++++++++--------- viscy/airtable/factory.py | 168 ++++----- viscy/airtable/schemas.py | 117 +++++- 8 files changed, 461 insertions(+), 350 deletions(-) diff --git a/examples/airtable/filter_n_create_dataset_tag.py b/examples/airtable/filter_n_create_dataset_tag.py index 9372a7bea..14bbc000e 100644 --- a/examples/airtable/filter_n_create_dataset_tag.py +++ b/examples/airtable/filter_n_create_dataset_tag.py @@ -1,4 +1,4 @@ -"""Filter datasets using pandas and create manifest tags.""" +"""Filter datasets using pandas and create collection tags.""" # %% @@ -40,26 +40,26 @@ print("\nBreakdown by well:") print(filtered.groupby("Well ID").size()) -# Create manifest from filtered dataset records +# Create collection from filtered dataset records fov_ids = filtered["FOV_ID"].tolist() try: - manifest_id = airtable_db.create_manifest_from_datasets( - manifest_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2", + collection_id = airtable_db.create_collection_from_datasets( + collection_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2", fov_ids=fov_ids, - version="0.0.1", # Semantic versioning + version="0.0.2", # Semantic versioning purpose="training", description="Dataset records from wells B_3 and B_4", ) - print(f"\n✓ Created manifest: {manifest_id}") + print(f"\n✓ Created collection: {collection_id}") print(f" Contains {len(fov_ids)} dataset records") except ValueError as e: print(f"\n⚠ {e}") # %% -# Delete the manifest entry demo -airtable_db.delete_manifest(manifest_id) -print(f"Deleted manifest: {manifest_id}") +# Delete the collection entry demo +airtable_db.delete_collection(collection_id) +print(f"Deleted collection: {collection_id}") # %% # EXAMPLE 3: Group by dataset and show summary diff --git a/examples/airtable/get_dataset_paths_example.py b/examples/airtable/get_dataset_paths_example.py index fc8bc97ff..067469a32 100644 --- a/examples/airtable/get_dataset_paths_example.py +++ b/examples/airtable/get_dataset_paths_example.py @@ -1,4 +1,4 @@ -"""Example usage of get_dataset_paths with Manifest and ManifestDataset dataclasses.""" +"""Example usage of get_dataset_paths with Collections and CollectionDataset dataclasses.""" # %% from viscy.airtable.database import AirtableManager @@ -7,24 +7,24 @@ airtable_db = AirtableManager(base_id=BASE_ID) # %% -# Fetch manifest from Airtable -manifest = airtable_db.get_dataset_paths( - 
manifest_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2", - version="0.0.1", +# Fetch collection from Airtable +collection = airtable_db.get_dataset_paths( + collection_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2", + version="v1", ) # %% -# Manifest properties -print("=== Manifest ===") -print(f"manifest.name: {manifest.name}") -print(f"manifest.version: {manifest.version}") -print(f"len(manifest): {len(manifest)} HCS plate(s)") -print(f"manifest.total_fovs: {manifest.total_fovs} FOVs") +# Collections properties +print("=== Collections ===") +print(f"collection.name: {collection.name}") +print(f"collection.version: {collection.version}") +print(f"len(collection): {len(collection)} HCS plate(s)") +print(f"collection.total_fovs: {collection.total_fovs} FOVs") # %% -# Iterate over ManifestDataset objects (one per HCS plate) -print("\n=== ManifestDataset ===") -for ds in manifest: +# Iterate over CollectionDataset objects (one per HCS plate) +print("\n=== CollectionDataset ===") +for ds in collection: print(f"ds.data_path: {ds.data_path}") print(f"ds.tracks_path: {ds.tracks_path}") print(f"len(ds): {len(ds)} FOVs") @@ -34,25 +34,25 @@ # %% # Validate paths exist (raises FileNotFoundError if not) -manifest.validate() +collection.validate() print("\nAll paths validated successfully!") # %% -# List available manifests -print("=== Available Manifests ===") -df = airtable_db.list_manifests() +# List available collections +print("=== Available Collections ===") +df = airtable_db.list_collections() print(df[["name", "version", "purpose"]].dropna(subset=["name"]).to_string()) # %% # ============================================================================= -# Create TripletDataModule from manifest using factory function +# Create TripletDataModule from collection using factory function # ============================================================================= -from viscy.airtable.factory import create_triplet_datamodule_from_manifest +from viscy.airtable.factory import create_triplet_datamodule_from_collection -# Create data module from manifest -dm = create_triplet_datamodule_from_manifest( - manifest=manifest, +# Create data module from collection +dm = create_triplet_datamodule_from_collection( + collection=collection, source_channel=["Phase3D"], z_range=(20, 21), batch_size=1, @@ -66,23 +66,23 @@ # %% # Setup and inspect the data module dm.setup("fit") -print("\n=== TripletDataModule from Manifest ===") +print("\n=== TripletDataModule from Collections ===") print(f"Data module type: {type(dm).__name__}") print(f"Train samples: {len(dm.train_dataset)}") print(f"Val samples: {len(dm.val_dataset)}") # %% # ============================================================================= -# Alternative: ManifestTripletDataModule (Lightning Config Compatible) +# Alternative: CollectionTripletDataModule (Lightning Config Compatible) # ============================================================================= -from viscy.airtable.factory import ManifestTripletDataModule +from viscy.airtable.factory import CollectionTripletDataModule # This class is designed for Lightning CLI and config files # but can also be used directly in Python -dm_class = ManifestTripletDataModule( +dm_class = CollectionTripletDataModule( base_id=BASE_ID, - manifest_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2", - manifest_version="0.0.1", + collection_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2", + collection_version="v1", source_channel=["Phase3D"], z_range=(20, 21), batch_size=1, @@ -94,11 +94,10 @@ ) 
dm_class.setup("fit") -print("\n=== ManifestTripletDataModule (Class) ===") +print("\n=== CollectionTripletDataModule (Class) ===") print(f"Data module type: {type(dm_class).__name__}") print(f"Train samples: {len(dm_class.train_dataset)}") print(f"Val samples: {len(dm_class.val_dataset)}") -print("Note: This class is designed for Lightning config files!") # %% Visualize some of the images import matplotlib.pyplot as plt @@ -117,19 +116,3 @@ axs[i // 5, i % 5].imshow(img_stack[i], cmap="gray") axs[i // 5, i % 5].axis("off") plt.show() -# %% -# ============================================================================= -# Summary: When to use which approach -# ============================================================================= -print("\n=== Summary ===") -print("Use create_triplet_datamodule_from_manifest() when:") -print(" - Working in Python scripts or notebooks") -print(" - Manifest has multiple HCS plates (auto-combines them)") -print(" - Already have manifest object loaded") -print("") -print("Use ManifestTripletDataModule when:") -print(" - Working with Lightning CLI and config files") -print(" - Training with single-plate manifests") -print(" - Want clean YAML configuration") -print("") -print("See examples/airtable/manifest_config_example.yml for config usage") diff --git a/examples/airtable/train_with_wandb.yml b/examples/airtable/train_with_wandb.yml index fb9332a00..b69730d3b 100644 --- a/examples/airtable/train_with_wandb.yml +++ b/examples/airtable/train_with_wandb.yml @@ -1,5 +1,5 @@ -# Example Lightning config using ManifestTripletDataModule with W&B tracking -# This config fetches dataset paths from Airtable manifests and logs to W&B +# Example Lightning config using CollectionTripletDataModule with W&B tracking +# This config fetches dataset paths from Airtable collections and logs to W&B # Usage: viscy fit -c examples/airtable/train_with_wandb.yml seed_everything: 42 @@ -25,8 +25,8 @@ trainer: monitor: loss/val save_top_k: 3 save_last: true - - class_path: viscy.airtable.callbacks.ManifestWandbCallback - # No init_args needed - automatically logs manifest metadata + - class_path: viscy.airtable.callbacks.CollectionWandbCallback + # No init_args needed - automatically logs collection metadata model: class_path: viscy.representation.engine.ContrastiveModule @@ -52,13 +52,13 @@ model: data: - # NEW: Use ManifestTripletDataModule for Airtable manifest integration - class_path: viscy.airtable.factory.ManifestTripletDataModule + # NEW: Use CollectionTripletDataModule for Airtable collection integration + class_path: viscy.airtable.factory.CollectionTripletDataModule init_args: - # Airtable manifest parameters + # Airtable collection parameters base_id: "app8vqaoWyOwa0sB5" # Replace with your base ID - manifest_name: "2024_11_07_A549_SEC61_DENV_wells_B1_B2" - manifest_version: "0.0.1" + collection_name: "2024_11_07_A549_SEC61_DENV_wells_B1_B2" + collection_version: "0.0.1" # TripletDataModule parameters source_channel: [Phase3D] @@ -71,5 +71,5 @@ data: split_ratio: 0.8 time_interval: any - # Optional: Override manifest FOVs with specific wells + # Optional: Override collection FOVs with specific wells fit_include_wells: ["B/1", "B/2"] diff --git a/viscy/airtable/__init__.py b/viscy/airtable/__init__.py index 04de533d1..04bc4cf96 100644 --- a/viscy/airtable/__init__.py +++ b/viscy/airtable/__init__.py @@ -1,26 +1,27 @@ """Airtable integration for dataset management and tracking.""" -from viscy.airtable.callbacks import ManifestWandbCallback -from 
viscy.airtable.database import AirtableManager, Manifest, ManifestDataset +from viscy.airtable.callbacks import CollectionWandbCallback +from viscy.airtable.database import AirtableManager, CollectionDataset, Collections from viscy.airtable.factory import ( - ManifestTripletDataModule, - create_triplet_datamodule_from_manifest, + CollectionTripletDataModule, + create_triplet_datamodule_from_collection, ) from viscy.airtable.register_model import ( list_registered_models, load_model_from_registry, register_model, ) -from viscy.airtable.schemas import DatasetRecord +from viscy.airtable.schemas import DatasetRecord, ModelRecord __all__ = [ "AirtableManager", "DatasetRecord", - "Manifest", - "ManifestDataset", - "ManifestTripletDataModule", - "ManifestWandbCallback", - "create_triplet_datamodule_from_manifest", + "ModelRecord", + "Collections", + "CollectionDataset", + "CollectionTripletDataModule", + "CollectionWandbCallback", + "create_triplet_datamodule_from_collection", "register_model", "load_model_from_registry", "list_registered_models", diff --git a/viscy/airtable/callbacks.py b/viscy/airtable/callbacks.py index 6f11a34ef..c9afd0d82 100644 --- a/viscy/airtable/callbacks.py +++ b/viscy/airtable/callbacks.py @@ -17,14 +17,14 @@ class AirtableLoggingCallback(Callback): - Best model checkpoint path - Who trained the model - When it was trained - - Link to the manifest used + - Link to the collection used Parameters ---------- base_id : str Airtable base ID - manifest_id : str - Airtable manifest record ID (from config) + collection_id : str + Airtable collection record ID (from config) model_name : str | None Custom model name. If None, auto-generates from model class and timestamp. log_metrics : bool @@ -40,13 +40,13 @@ class AirtableLoggingCallback(Callback): >>> - class_path: viscy.airtable.callbacks.AirtableLoggingCallback >>> init_args: >>> base_id: "appXXXXXXXXXXXXXX" - >>> manifest_id: "recYYYYYYYYYYYYYY" + >>> collection_id: "recYYYYYYYYYYYYYY" Or add programmatically: >>> callback = AirtableLoggingCallback( >>> base_id="appXXXXXXXXXXXXXX", - >>> manifest_id="recYYYYYYYYYYYYYY" + >>> collection_id="recYYYYYYYYYYYYYY" >>> ) >>> trainer = Trainer(callbacks=[callback]) """ @@ -54,13 +54,13 @@ class AirtableLoggingCallback(Callback): def __init__( self, base_id: str, - manifest_id: str, + collection_id: str, model_name: str | None = None, log_metrics: bool = False, ): super().__init__() self.airtable_db = AirtableManager(base_id=base_id) - self.manifest_id = manifest_id + self.collection_id = collection_id self.model_name = model_name self.log_metrics = log_metrics @@ -104,7 +104,7 @@ def on_fit_end(self, trainer: Trainer, pl_module: Any) -> None: # Log to Airtable try: model_id = self.airtable_db.log_model_training( - manifest_id=self.manifest_id, + collection_id=self.collection_id, mlflow_run_id=run_id or "unknown", model_name=model_name, checkpoint_path=str(checkpoint_path) if checkpoint_path else None, @@ -114,17 +114,17 @@ def on_fit_end(self, trainer: Trainer, pl_module: Any) -> None: print(f"\n✓ Model logged to Airtable (record ID: {model_id})") print(f" Model name: {model_name}") print(f" Checkpoint: {checkpoint_path}") - print(f" Manifest ID: {self.manifest_id}") + print(f" Collections ID: {self.collection_id}") except Exception as e: print(f"\n✗ Failed to log to Airtable: {e}") # Don't fail training if Airtable logging fails -class ManifestWandbCallback(Callback): +class CollectionWandbCallback(Callback): """ - Log manifest metadata to Weights & Biases automatically. 
+ Log collection metadata to Weights & Biases automatically. - This callback extracts manifest information from ManifestTripletDataModule + This callback extracts collection information from CollectionTripletDataModule and logs it to W&B config for searchability and lineage tracking. Examples @@ -138,18 +138,18 @@ class ManifestWandbCallback(Callback): >>> project: viscy-experiments >>> log_model: false >>> callbacks: - >>> - class_path: viscy.airtable.callbacks.ManifestWandbCallback + >>> - class_path: viscy.airtable.callbacks.CollectionWandbCallback Or add programmatically: >>> from lightning.pytorch.loggers import WandbLogger >>> logger = WandbLogger(project="viscy-experiments") - >>> callback = ManifestWandbCallback() + >>> callback = CollectionWandbCallback() >>> trainer = Trainer(logger=logger, callbacks=[callback]) """ def on_train_start(self, trainer: Trainer, pl_module: Any) -> None: - """Log manifest metadata to W&B config at training start.""" + """Log collection metadata to W&B config at training start.""" # Import here to avoid requiring wandb as a dependency try: from lightning.pytorch.loggers import WandbLogger @@ -160,24 +160,24 @@ def on_train_start(self, trainer: Trainer, pl_module: Any) -> None: if not isinstance(trainer.logger, WandbLogger): return - # Check if using ManifestTripletDataModule - from viscy.airtable.factory import ManifestTripletDataModule + # Check if using CollectionTripletDataModule + from viscy.airtable.factory import CollectionTripletDataModule dm = trainer.datamodule - # Log manifest metadata if using ManifestTripletDataModule - if isinstance(dm, ManifestTripletDataModule): - manifest_config = { - "manifest/name": dm.manifest_name, - "manifest/version": dm.manifest_version, - "manifest/base_id": dm.base_id, - "manifest/data_path": str(dm.data_path), - "manifest/tracks_path": str(dm.tracks_path), + # Log collection metadata if using CollectionTripletDataModule + if isinstance(dm, CollectionTripletDataModule): + collection_config = { + "collection/name": dm.collection_name, + "collection/version": dm.collection_version, + "collection/base_id": dm.base_id, + "collection/data_path": str(dm.data_path), + "collection/tracks_path": str(dm.tracks_path), } - trainer.logger.experiment.config.update(manifest_config) + trainer.logger.experiment.config.update(collection_config) - print("\n✓ Manifest metadata logged to W&B:") - print(f" Manifest: {dm.manifest_name} v{dm.manifest_version}") + print("\n✓ Collections metadata logged to W&B:") + print(f" Collections: {dm.collection_name} v{dm.collection_version}") print(f" Data path: {dm.data_path}") print(f" Tracks path: {dm.tracks_path}") diff --git a/viscy/airtable/database.py b/viscy/airtable/database.py index be580a471..105091e0d 100644 --- a/viscy/airtable/database.py +++ b/viscy/airtable/database.py @@ -15,11 +15,11 @@ @dataclass -class ManifestDataset: +class CollectionDataset: """ Dataset paths for one HCS plate/zarr store. - A manifest may contain multiple stores, each returned as a separate ManifestDataset. + A collection may contain multiple stores, each returned as a separate CollectionDataset. 
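+
+    A minimal illustrative doctest (the paths and FOV names here are
+    hypothetical placeholders):
+
+    >>> ds = collection.datasets[0]
+    >>> ds.data_path
+    '/hpc/data/plate.zarr'
+    >>> ds.fov_names
+    ['B/3/0', 'B/3/1']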
""" data_path: str @@ -47,12 +47,12 @@ def validate(self) -> None: @dataclass -class Manifest: - """All datasets for a manifest, potentially across multiple HCS plates.""" +class Collections: + """All datasets for a collection, potentially across multiple HCS plates.""" name: str version: str - datasets: list[ManifestDataset] + datasets: list[CollectionDataset] def __iter__(self): """Iterate over datasets.""" @@ -116,13 +116,13 @@ def validate(self) -> None: class AirtableManager: """ - Unified interface to Airtable for dataset, manifest, and model management. + Unified interface to Airtable for dataset, collection, and model management. Use this to: - Register individual FOVs from HCS plates - - Create and manage dataset manifests (collections of FOVs) - - Track model training on manifests - - Query datasets, manifests, and models + - Create and manage dataset collections (collections of FOVs) + - Track model training on collections + - Query datasets, collections, and models Parameters ---------- @@ -135,9 +135,9 @@ class AirtableManager: -------- >>> airtable_db = AirtableManager(base_id="appXXXXXXXXXXXXXX") >>> - >>> # Create manifest from FOV selection - >>> manifest_id = airtable_db.create_manifest_from_datasets( - ... manifest_name="RPE1_infection_v2", + >>> # Create collection from FOV selection + >>> collection_id = airtable_db.create_collection_from_datasets( + ... collection_name="RPE1_infection_v2", ... fov_ids=["FOV_001", "FOV_002", "FOV_004"], ... version="0.0.1", ... purpose="training" @@ -145,13 +145,13 @@ class AirtableManager: >>> >>> # Track model training >>> airtable_db.log_model_training( - ... manifest_id=manifest_id, + ... collection_id=collection_id, ... mlflow_run_id="run_123", ... model_name="my_model", ... ) >>> - >>> # Get all FOV paths for a manifest - >>> fov_paths = airtable_db.get_manifest_data_paths("RPE1_infection_v2") + >>> # Get all FOV paths for a collection + >>> fov_paths = airtable_db.get_collection_data_paths("RPE1_infection_v2") >>> print(fov_paths) >>> # ['/hpc/data/rpe1.zarr/B/3/0', '/hpc/data/rpe1.zarr/B/3/1', ...] """ @@ -168,7 +168,7 @@ def __init__( self.api = Api(api_key) self.base_id = base_id self.datasets_table = self.api.table(base_id, "Datasets") - self.manifests_table = self.api.table(base_id, "Manifest") + self.collections_table = self.api.table(base_id, "Collections") self.models_table = self.api.table(base_id, "Models") def register_dataset(self, dataset: DatasetRecord) -> str: @@ -243,9 +243,9 @@ def register_datasets(self, datasets: list[DatasetRecord]) -> list[str]: record_ids.append(created["id"]) return record_ids - def create_manifest_from_datasets( + def create_collection_from_datasets( self, - manifest_name: str, + collection_name: str, fov_ids: list[str], version: str, purpose: str = "training", @@ -253,19 +253,19 @@ def create_manifest_from_datasets( description: str | None = None, ) -> str: """ - Create a manifest (collection) from a list of FOV IDs. + Create a collection (collection) from a list of FOV IDs. Parameters ---------- - manifest_name : str - Name for this manifest + collection_name : str + Name for this collection fov_ids : list[str] List of FOV_ID values from Datasets table (e.g., ["plate1_B_3_0", "plate1_B_3_1"]) version : str Semantic version (e.g., "0.0.1", "0.1.0", "1.0.0") purpose : str - Purpose of this manifest ("training", "validation", "test") + Purpose of this collection ("training", "validation", "test") project_name : str | None Project Name (e.g OrganelleBox, DynaCLR, etc.) 
description : str | None @@ -274,12 +274,12 @@ def create_manifest_from_datasets( Returns ------- str - Airtable manifest record ID + Airtable collection record ID Examples -------- - >>> airtable_db.create_manifest_from_datasets( - ... manifest_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2", + >>> airtable_db.create_collection_from_datasets( + ... collection_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2", ... project_name="OrganelleBox", ... fov_ids=["2024_11_07_A549_SEC61_DENV_B1_0", "2024_11_07_A549_SEC61_DENV_B1_1"], ... version="0.0.1", @@ -296,30 +296,34 @@ def create_manifest_from_datasets( f"Version must be semantic version format (e.g., '0.0.1', '1.0.0'), got: '{version}'" ) - # Check if manifest with same name + version exists (use DataFrame) - df_manifests = self.list_manifests() + # Check if collection with same name + version exists (use DataFrame) + df_collections = self.list_collections() # Only check for duplicates if table is not empty and has required columns if ( - len(df_manifests) > 0 - and "name" in df_manifests.columns - and "version" in df_manifests.columns + len(df_collections) > 0 + and "name" in df_collections.columns + and "version" in df_collections.columns ): - existing = df_manifests[ - (df_manifests["name"] == manifest_name) - & (df_manifests["version"] == version) + existing = df_collections[ + (df_collections["name"] == collection_name) + & (df_collections["version"] == version) ] if len(existing) > 0: raise ValueError( - f"Manifest '{manifest_name}' version '{version}' already exists. " + f"Collections '{collection_name}' version '{version}' already exists. " f"To create a new version, increment the version number (e.g., '0.0.2')." ) - existing_versions = df_manifests[df_manifests["name"] == manifest_name] + existing_versions = df_collections[ + df_collections["name"] == collection_name + ] if len(existing_versions) > 0: versions = sorted(existing_versions["version"].tolist()) - print(f"ℹ Manifest '{manifest_name}' existing versions: {versions}") + print( + f"ℹ Collections '{collection_name}' existing versions: {versions}" + ) print(f" Creating new version: '{version}'") # Get Airtable record IDs for these FOV IDs (ensure unique) @@ -341,25 +345,25 @@ def create_manifest_from_datasets( # Remove any duplicate record IDs (shouldn't happen, but extra safety) dataset_record_ids = list(dict.fromkeys(dataset_record_ids)) - # Create manifest record - manifest_record = { - "name": manifest_name, + # Create collection record + collection_record = { + "name": collection_name, "datasets": dataset_record_ids, # Linked records (unique) "version": version, # Semantic version (required) "purpose": purpose, "created_by": getpass.getuser(), } if project_name: - manifest_record["project"] = project_name + collection_record["project"] = project_name if description: - manifest_record["description"] = description + collection_record["description"] = description - created = self.manifests_table.create(manifest_record) + created = self.collections_table.create(collection_record) return created["id"] - def create_manifest_from_query( + def create_collection_from_query( self, - manifest_name: str, + collection_name: str, version: str, source_dataset: str | None = None, well_ids: list[str] | None = None, @@ -367,12 +371,12 @@ def create_manifest_from_query( **kwargs, ) -> str: """ - Create a manifest by filtering dataset records with pandas. + Create a collection by filtering dataset records with pandas. 
Parameters ---------- - manifest_name : str - Name for this manifest + collection_name : str + Name for this collection version : str Semantic version (e.g., "0.0.1") - REQUIRED source_dataset : str | None @@ -382,18 +386,18 @@ def create_manifest_from_query( exclude_fov_ids : list[str] | None FOV_ID values to exclude **kwargs - Additional arguments for create_manifest_from_datasets + Additional arguments for create_collection_from_datasets Returns ------- str - Airtable manifest record ID + Airtable collection record ID Examples -------- - >>> # Create manifest from specific wells in a dataset - >>> airtable_db.create_manifest_from_query( - ... manifest_name="RPE1_infection_training", + >>> # Create collection from specific wells in a dataset + >>> airtable_db.create_collection_from_query( + ... collection_name="RPE1_infection_training", ... version="0.0.1", ... source_dataset="RPE1_plate1", ... well_ids=["B_3", "B_4"], @@ -418,21 +422,21 @@ def create_manifest_from_query( print(f"Found {len(fov_ids)} dataset records matching criteria") - # Create manifest - return self.create_manifest_from_datasets( - manifest_name=manifest_name, version=version, fov_ids=fov_ids, **kwargs + # Create collection + return self.create_collection_from_datasets( + collection_name=collection_name, version=version, fov_ids=fov_ids, **kwargs ) - def get_manifest_data_paths( - self, manifest_name: str, version: str | None = None + def get_collection_data_paths( + self, collection_name: str, version: str | None = None ) -> list[str]: """ - Get list of data paths for a manifest. + Get list of data paths for a collection. Parameters ---------- - manifest_name : str - Manifest name + collection_name : str + Collections name version : str | None Specific version (if None, returns latest) @@ -443,41 +447,43 @@ def get_manifest_data_paths( Examples -------- - >>> paths = airtable_db.get_manifest_data_paths("RPE1_infection_v2") + >>> paths = airtable_db.get_collection_data_paths("RPE1_infection_v2") >>> print(paths) >>> # ['/hpc/data/rpe1.zarr/B/3/0', '/hpc/data/rpe1.zarr/B/3/1', ...] 
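+        >>>
+        >>> # Pin an exact semantic version for reproducible runs
+        >>> # (illustrative: assumes version "0.0.1" of this collection exists)
+        >>> paths = airtable_db.get_collection_data_paths(
+        ...     "RPE1_infection_v2", version="0.0.1"
+        ... )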
""" - # Get all manifests as DataFrame - df_manifests = self.list_manifests() + # Get all collections as DataFrame + df_collections = self.list_collections() - if len(df_manifests) == 0 or "name" not in df_manifests.columns: - raise ValueError(f"Manifest '{manifest_name}' not found (table is empty)") + if len(df_collections) == 0 or "name" not in df_collections.columns: + raise ValueError( + f"Collections '{collection_name}' not found (table is empty)" + ) # Filter by name - filtered = df_manifests[df_manifests["name"] == manifest_name] + filtered = df_collections[df_collections["name"] == collection_name] if len(filtered) == 0: - raise ValueError(f"Manifest '{manifest_name}' not found") + raise ValueError(f"Collections '{collection_name}' not found") # Filter by version if specified, otherwise get latest if version: - if "version" not in df_manifests.columns: - raise ValueError("Version field not found in Manifest table") + if "version" not in df_collections.columns: + raise ValueError("Version field not found in Collections table") filtered = filtered[filtered["version"] == version] if len(filtered) == 0: raise ValueError( - f"Manifest '{manifest_name}' version '{version}' not found" + f"Collections '{collection_name}' version '{version}' not found" ) else: # Get latest version (sort by created_time if column exists) if "created_time" in filtered.columns: filtered = filtered.sort_values("created_time", ascending=False) - # Get the first (or only) matching manifest - manifest_row = filtered.iloc[0] + # Get the first (or only) matching collection + collection_row = filtered.iloc[0] # Get linked dataset record IDs - dataset_record_ids = manifest_row.get("datasets", []) + dataset_record_ids = collection_row.get("datasets", []) if not dataset_record_ids or len(dataset_record_ids) == 0: return [] @@ -489,64 +495,68 @@ def get_manifest_data_paths( return data_paths - def get_manifest( - self, manifest_name: str, version: str | None = None + def get_collection( + self, collection_name: str, version: str | None = None ) -> dict[str, Any]: """ - Get full manifest information including data paths. + Get full collection information including data paths. 
 
         Parameters
         ----------
-        manifest_name : str
-            Manifest name
+        collection_name : str
+            Collection name
         version : str | None
             Specific version
 
         Returns
         -------
         dict
-            Manifest info with data paths and metadata
+            Collection info with data paths and metadata
         """
-        # Get all manifests as DataFrame
-        df_manifests = self.list_manifests()
+        # Get all collections as DataFrame
+        df_collections = self.list_collections()
 
-        if len(df_manifests) == 0 or "name" not in df_manifests.columns:
-            raise ValueError(f"Manifest '{manifest_name}' not found (table is empty)")
+        if len(df_collections) == 0 or "name" not in df_collections.columns:
+            raise ValueError(
+                f"Collection '{collection_name}' not found (table is empty)"
+            )
 
         # Filter by name
-        filtered = df_manifests[df_manifests["name"] == manifest_name]
+        filtered = df_collections[df_collections["name"] == collection_name]
 
         if len(filtered) == 0:
-            raise ValueError(f"Manifest '{manifest_name}' not found")
+            raise ValueError(f"Collection '{collection_name}' not found")
 
         # Filter by version if specified, otherwise get latest
         if version:
-            if "version" not in df_manifests.columns:
-                raise ValueError("Version field not found in Manifest table")
+            if "version" not in df_collections.columns:
+                raise ValueError("Version field not found in Collections table")
             filtered = filtered[filtered["version"] == version]
             if len(filtered) == 0:
                 raise ValueError(
-                    f"Manifest '{manifest_name}' version '{version}' not found"
+                    f"Collection '{collection_name}' version '{version}' not found"
                 )
         else:
             # Get latest version (sort by created_time if column exists)
             if "created_time" in filtered.columns:
                 filtered = filtered.sort_values("created_time", ascending=False)
 
-        # Get the first (or only) matching manifest
-        manifest_row = filtered.iloc[0]
-        manifest = manifest_row.to_dict()
+        # Get the first (or only) matching collection
+        collection_row = filtered.iloc[0]
+        collection = collection_row.to_dict()
 
         # Add data paths
-        manifest["data_paths"] = self.get_manifest_data_paths(manifest_name, version)
+        collection["data_paths"] = self.get_collection_data_paths(
+            collection_name, version
+        )
 
-        return manifest
+        return collection
 
-    def list_manifests(
+    def list_collections(
         self, purpose: str | None = None, as_dataframe: bool = True
    ) -> pd.DataFrame | list[dict]:
         """
-        List all manifests.
+        List all collections.
 
         Parameters
         ----------
@@ -558,19 +568,19 @@ def list_manifests(
         Returns
         -------
         pd.DataFrame | list[dict]
-            Manifest records as DataFrame or list of dicts
+            Collection records as DataFrame or list of dicts
 
         Examples
         --------
-        >>> airtable_db.list_manifests(purpose="training")
+        >>> airtable_db.list_collections(purpose="training")
         >>> # Returns DataFrame with columns: id, name, version, purpose, ...
         """
-        # Fetch all manifests (try sorting, but don't fail if field doesn't exist)
+        # Fetch all collections (try sorting, but don't fail if field doesn't exist)
         try:
-            records = self.manifests_table.all(sort=["-created_time"])
+            records = self.collections_table.all(sort=["-created_time"])
         except Exception:
             # If sort fails (field might not exist), fetch without sorting
-            records = self.manifests_table.all()
+            records = self.collections_table.all()
 
         data = [{"id": r["id"], **r["fields"]} for r in records]
 
@@ -657,14 +667,14 @@ def list_datasets(
             return pd.DataFrame(data)
         return data
 
-    def delete_manifest(self, manifest_id: str) -> bool:
+    def delete_collection(self, collection_id: str) -> bool:
         """
-        Delete a manifest record from Airtable.
+        Delete a collection record from Airtable.
 
         
Parameters ---------- - manifest_id : str - Airtable record ID of the manifest to delete + collection_id : str + Airtable record ID of the collection to delete Returns ------- @@ -673,16 +683,16 @@ def delete_manifest(self, manifest_id: str) -> bool: Examples -------- - >>> manifest_id = airtable_db.create_manifest_from_datasets(...) - >>> airtable_db.delete_manifest(manifest_id) - >>> print(f"Deleted manifest: {manifest_id}") + >>> collection_id = airtable_db.create_collection_from_datasets(...) + >>> airtable_db.delete_collection(collection_id) + >>> print(f"Deleted collection: {collection_id}") """ - self.manifests_table.delete(manifest_id) + self.collections_table.delete(collection_id) return True def log_model_training( self, - manifest_id: str, + collection_id: str, wandb_run_id: str, model_name: str | None = None, metrics: dict[str, float] | None = None, @@ -690,14 +700,14 @@ def log_model_training( trained_by: str | None = None, ) -> str: """ - Log that a model was trained using a manifest. + Log that a model was trained using a collection. - Creates entry in Models table and updates Manifest table. + Creates entry in Models table and updates Collections table. Parameters ---------- - manifest_id : str - Airtable record ID of manifest used + collection_id : str + Airtable record ID of collection used wandb_run_id : str W&B run ID for experiment tracking model_name : str | None @@ -716,9 +726,9 @@ def log_model_training( Examples -------- - >>> manifest_id = airtable_db.create_manifest_from_datasets(...) + >>> collection_id = airtable_db.create_collection_from_datasets(...) >>> model_id = airtable_db.log_model_training( - ... manifest_id=manifest_id, + ... collection_id=collection_id, ... wandb_run_id="20260107-152420", ... model_name="contrastive-a549:v1", ... metrics={"val_loss": 0.15}, @@ -728,7 +738,7 @@ def log_model_training( # Create model record model_record = { "model_name": model_name or f"model_{datetime.now():%Y%m%d_%H%M%S}", - "manifest": [manifest_id], # Link to manifest + "collection": [collection_id], # Link to collection "wandb_run_id": wandb_run_id, "trained_date": datetime.now().isoformat(), } @@ -744,9 +754,9 @@ def log_model_training( created = self.models_table.create(model_record) - # Update manifest record to track usage - manifest = self.manifests_table.get(manifest_id) - models_trained_str = manifest["fields"].get("models_trained", "") + # Update collection record to track usage + collection = self.collections_table.get(collection_id) + models_trained_str = collection["fields"].get("models_trained", "") # Handle models_trained as comma-separated string if models_trained_str: @@ -756,23 +766,23 @@ def log_model_training( else: new_models_str = wandb_run_id - self.manifests_table.update( - manifest_id, + self.collections_table.update( + collection_id, {"models_trained": new_models_str, "last_used": datetime.now().isoformat()}, ) return created["id"] - def get_models_for_manifest( - self, manifest_id: str, as_dataframe: bool = True + def get_models_for_collection( + self, collection_id: str, as_dataframe: bool = True ) -> pd.DataFrame | list[dict]: """ - Get all models trained on a specific manifest. + Get all models trained on a specific collection. Parameters ---------- - manifest_id : str - Airtable record ID of manifest + collection_id : str + Airtable record ID of collection as_dataframe : bool If True, return pandas DataFrame. If False, return list of dicts. 
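
(Editor's aside, not part of the patch: a minimal usage sketch for the method
above. It assumes AIRTABLE_BASE_ID and AIRTABLE_API_KEY are set in the
environment and that the placeholder record ID is replaced with a real
collection record ID.)

```python
import os

from viscy.airtable.database import AirtableManager

airtable_db = AirtableManager(base_id=os.environ["AIRTABLE_BASE_ID"])

# Placeholder record ID; each returned row links back to this collection
# through its "collection" field of linked record IDs.
models_df = airtable_db.get_models_for_collection(collection_id="recXXXXXXXXXXXXXX")
print(models_df[["model_name", "wandb_run_id", "trained_date"]])
```
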
@@ -783,7 +793,7 @@ def get_models_for_manifest( Examples -------- - >>> models_df = airtable_db.get_models_for_manifest(manifest_id) + >>> models_df = airtable_db.get_models_for_collection(collection_id) >>> print(models_df[["model_name", "mlflow_run_id", "trained_date"]]) """ # Get all models as DataFrame @@ -795,11 +805,11 @@ def get_models_for_manifest( if len(df) == 0: return df - # Filter by manifest_id using pandas - # The 'manifest' field contains a list of linked record IDs + # Filter by collection_id using pandas + # The 'collection' field contains a list of linked record IDs df_filtered = df[ - df["manifest"].apply( - lambda x: manifest_id in x if isinstance(x, list) else False + df["collection"].apply( + lambda x: collection_id in x if isinstance(x, list) else False ) ] @@ -810,7 +820,7 @@ def get_models_for_manifest( return df_filtered else: # Filter list - filtered = [d for d in data if manifest_id in d.get("manifest", [])] + filtered = [d for d in data if collection_id in d.get("collection", [])] return filtered def list_models(self, as_dataframe: bool = True) -> pd.DataFrame | list[dict]: @@ -845,41 +855,41 @@ def list_models(self, as_dataframe: bool = True) -> pd.DataFrame | list[dict]: def get_dataset_paths( self, - manifest_name: str, + collection_name: str, version: str, - ) -> Manifest: + ) -> Collections: """ - Get zarr store paths and FOV names for a manifest. + Get zarr store paths and FOV names for a collection. Parameters ---------- - manifest_name : str - Name of the manifest + collection_name : str + Name of the collection version : str - Semantic version of the manifest + Semantic version of the collection Returns ------- - Manifest - Manifest object containing list of ManifestDataset (one per HCS plate) + Collections + Collections object containing list of CollectionDataset (one per HCS plate) Examples -------- - >>> manifest = airtable_db.get_dataset_paths("my_manifest", "0.0.1") - >>> print(f"{manifest.name} v{manifest.version}: {manifest.total_fovs} FOVs") + >>> collection = airtable_db.get_dataset_paths("my_collection", "0.0.1") + >>> print(f"{collection.name} v{collection.version}: {collection.total_fovs} FOVs") >>> # Use with TripletDataModule - >>> for ds in manifest: + >>> for ds in collection: ... data_module = TripletDataModule( ... data_path=ds.data_path, ... tracks_path=ds.tracks_path, ... include_fov_names=ds.fov_names, ... 
) """ - # Get manifest record IDs - dataset_record_ids = self._get_manifest_dataset_ids(manifest_name, version) + # Get collection record IDs + dataset_record_ids = self._get_collection_dataset_ids(collection_name, version) if not dataset_record_ids: - return Manifest(name=manifest_name, version=version, datasets=[]) + return Collections(name=collection_name, version=version, datasets=[]) dataset_records = [ self.datasets_table.get(dataset_id)["fields"] @@ -896,7 +906,7 @@ def get_dataset_paths( stores[data_path].append(fov_name) datasets = [ - ManifestDataset( + CollectionDataset( data_path=data_path, tracks_path=self._derive_tracks_path(data_path), fov_names=natsorted(fov_names), @@ -904,7 +914,7 @@ def get_dataset_paths( for data_path, fov_names in stores.items() ] - return Manifest(name=manifest_name, version=version, datasets=datasets) + return Collections(name=collection_name, version=version, datasets=datasets) @staticmethod def _derive_tracks_path(data_path: str) -> str: @@ -924,28 +934,32 @@ def _derive_tracks_path(data_path: str) -> str: tracks_path = tracks_path[:-5] + "_cropped.zarr" return tracks_path - def _get_manifest_dataset_ids(self, manifest_name: str, version: str) -> list[str]: - """Get linked dataset record IDs for a manifest.""" - df_manifests = self.list_manifests() + def _get_collection_dataset_ids( + self, collection_name: str, version: str + ) -> list[str]: + """Get linked dataset record IDs for a collection.""" + df_collections = self.list_collections() - if len(df_manifests) == 0 or "name" not in df_manifests.columns: - raise ValueError(f"Manifest '{manifest_name}' not found (table is empty)") + if len(df_collections) == 0 or "name" not in df_collections.columns: + raise ValueError( + f"Collections '{collection_name}' not found (table is empty)" + ) - filtered = df_manifests[df_manifests["name"] == manifest_name] + filtered = df_collections[df_collections["name"] == collection_name] if len(filtered) == 0: - raise ValueError(f"Manifest '{manifest_name}' not found") + raise ValueError(f"Collections '{collection_name}' not found") - if "version" not in df_manifests.columns: - raise ValueError("Version field not found in Manifest table") + if "version" not in df_collections.columns: + raise ValueError("Version field not found in Collections table") filtered = filtered[filtered["version"] == version] if len(filtered) == 0: raise ValueError( - f"Manifest '{manifest_name}' version '{version}' not found" + f"Collections '{collection_name}' version '{version}' not found" ) - manifest_row = filtered.iloc[0] - dataset_record_ids = manifest_row.get("datasets", []) + collection_row = filtered.iloc[0] + dataset_record_ids = collection_row.get("datasets", []) if not dataset_record_ids or len(dataset_record_ids) == 0: return [] diff --git a/viscy/airtable/factory.py b/viscy/airtable/factory.py index 5327467d2..db9bf6a07 100644 --- a/viscy/airtable/factory.py +++ b/viscy/airtable/factory.py @@ -1,4 +1,4 @@ -"""Factory functions for creating data modules from Airtable manifests.""" +"""Factory functions for creating data modules from Airtable collections.""" import os from typing import Literal, Sequence @@ -6,7 +6,7 @@ from lightning.pytorch import LightningDataModule from monai.transforms import MapTransform -from viscy.airtable.database import AirtableManager, Manifest, ManifestDataset +from viscy.airtable.database import AirtableManager, CollectionDataset, Collections from viscy.data.combined import BatchedConcatDataModule, CachedConcatDataModule from viscy.data.triplet 
import TripletDataModule @@ -47,8 +47,8 @@ def _extract_wells_from_fov_names(fov_names: list[str]) -> list[str]: return sorted(list(wells)) -def create_triplet_datamodule_from_manifest( - manifest: Manifest | ManifestDataset, +def create_triplet_datamodule_from_collection( + collection: Collections | CollectionDataset, source_channel: str | Sequence[str], z_range: tuple[int, int], *, @@ -76,7 +76,7 @@ def create_triplet_datamodule_from_manifest( use_cached_concat: bool = False, ) -> LightningDataModule: """ - Create TripletDataModule(s) from Airtable manifest. + Create TripletDataModule(s) from Airtable collection. Automatically handles single or multiple HCS plates: - Single plate: Returns TripletDataModule @@ -84,8 +84,8 @@ def create_triplet_datamodule_from_manifest( Parameters ---------- - manifest : Manifest | ManifestDataset - Manifest from AirtableManager.get_dataset_paths() or single ManifestDataset + collection : Collections | CollectionDataset + Collections from AirtableManager.get_dataset_paths() or single CollectionDataset source_channel : str | Sequence[str] Input channel name(s) - REQUIRED z_range : tuple[int, int] @@ -109,10 +109,10 @@ def create_triplet_datamodule_from_manifest( caching : bool Cache dataset, default False fit_include_wells : list[str] | None - Override manifest FOVs with specific wells (e.g., ["B/3", "C/4"]). - Takes precedence over manifest.fov_names + Override collection FOVs with specific wells (e.g., ["B/3", "C/4"]). + Takes precedence over collection.fov_names fit_exclude_fovs : list[str] | None - Exclude specific FOV paths from manifest + Exclude specific FOV paths from collection predict_cells : bool Predict on specific cells only, default False include_fov_names : list[str] | None @@ -135,7 +135,7 @@ def create_triplet_datamodule_from_manifest( Tensorstore cache pool size in bytes, default 0 use_cached_concat : bool Use CachedConcatDataModule instead of BatchedConcatDataModule - for multi-plate manifests, default False + for multi-plate collections, default False Returns ------- @@ -147,26 +147,26 @@ def create_triplet_datamodule_from_manifest( Raises ------ ValueError - - If manifest has no datasets + - If collection has no datasets - If paths don't exist (validation fails) - - If fit_include_wells and manifest both specify FOVs (ambiguous) + - If fit_include_wells and collection both specify FOVs (ambiguous) FileNotFoundError If data_path or tracks_path don't exist TypeError - If manifest is not Manifest or ManifestDataset + If collection is not Collections or CollectionDataset Examples -------- - Basic usage with single-plate manifest: + Basic usage with single-plate collection: >>> from viscy.airtable.database import AirtableManager - >>> from viscy.airtable.factory import create_triplet_datamodule_from_manifest + >>> from viscy.airtable.factory import create_triplet_datamodule_from_collection >>> >>> airtable_db = AirtableManager(base_id="appXXXXXXXXXXXXXX") - >>> manifest = airtable_db.get_dataset_paths("my_manifest", "0.0.1") + >>> collection = airtable_db.get_dataset_paths("my_collection", "0.0.1") >>> - >>> dm = create_triplet_datamodule_from_manifest( - ... manifest=manifest, + >>> dm = create_triplet_datamodule_from_collection( + ... collection=collection, ... source_channel=["Phase3D"], ... z_range=(0, 5), ... 
batch_size=32, @@ -176,15 +176,15 @@ def create_triplet_datamodule_from_manifest( >>> # Use with PyTorch Lightning >>> trainer.fit(model, dm) - Multi-plate manifest with normalization: + Multi-plate collection with normalization: >>> from viscy.transforms import NormalizeSampled >>> - >>> manifest = airtable_db.get_dataset_paths("multi_plate_manifest", "1.0.0") - >>> print(f"Manifest has {len(manifest)} plates") # e.g., 3 plates + >>> collection = airtable_db.get_dataset_paths("multi_plate_collection", "1.0.0") + >>> print(f"Collections has {len(collection)} plates") # e.g., 3 plates >>> - >>> dm = create_triplet_datamodule_from_manifest( - ... manifest=manifest, + >>> dm = create_triplet_datamodule_from_collection( + ... collection=collection, ... source_channel=["Phase3D", "RFP"], ... z_range=(0, 10), ... normalizations=[ @@ -200,21 +200,21 @@ def create_triplet_datamodule_from_manifest( ... ) >>> # Returns BatchedConcatDataModule wrapping 3 TripletDataModules - Override manifest FOVs with specific wells: + Override collection FOVs with specific wells: - >>> dm = create_triplet_datamodule_from_manifest( - ... manifest=manifest, + >>> dm = create_triplet_datamodule_from_collection( + ... collection=collection, ... source_channel=["Phase3D"], ... z_range=(0, 5), - ... fit_include_wells=["B/3", "B/4"], # Override manifest FOVs + ... fit_include_wells=["B/3", "B/4"], # Override collection FOVs ... batch_size=16, ... ) - Using a single ManifestDataset directly: + Using a single CollectionDataset directly: - >>> ds = manifest.datasets[0] # Single plate - >>> dm = create_triplet_datamodule_from_manifest( - ... manifest=ds, # Pass ManifestDataset directly + >>> ds = collection.datasets[0] # Single plate + >>> dm = create_triplet_datamodule_from_collection( + ... collection=ds, # Pass CollectionDataset directly ... source_channel=["Phase3D"], ... z_range=(0, 5), ... 
batch_size=16, @@ -222,24 +222,24 @@ def create_triplet_datamodule_from_manifest( Notes ----- - - FOV filtering priority: fit_include_wells > manifest.fov_names + - FOV filtering priority: fit_include_wells > collection.fov_names - If both specified, raises ValueError to avoid ambiguity - The factory validates paths before creating data modules - Multi-plate handling uses BatchedConcatDataModule for efficient batching - All TripletDataModule parameters can be passed through kwargs """ - # STEP 1: Normalize input - handle both Manifest and ManifestDataset - if isinstance(manifest, ManifestDataset): - datasets = [manifest] - manifest_name = "single_dataset" - elif isinstance(manifest, Manifest): - if len(manifest.datasets) == 0: - raise ValueError(f"Manifest '{manifest.name}' has no datasets") - datasets = manifest.datasets - manifest_name = manifest.name + # STEP 1: Normalize input - handle both Collections and CollectionDataset + if isinstance(collection, CollectionDataset): + datasets = [collection] + collection_name = "single_dataset" + elif isinstance(collection, Collections): + if len(collection.datasets) == 0: + raise ValueError(f"Collections '{collection.name}' has no datasets") + datasets = collection.datasets + collection_name = collection.name else: raise TypeError( - f"Expected Manifest or ManifestDataset, got {type(manifest).__name__}" + f"Expected Collections or CollectionDataset, got {type(collection).__name__}" ) # STEP 2: Validate all paths exist (fail early) @@ -247,20 +247,20 @@ def create_triplet_datamodule_from_manifest( try: ds.validate() except FileNotFoundError as e: - raise FileNotFoundError(f"Manifest '{manifest_name}' dataset {i}: {e}") + raise FileNotFoundError(f"Collections '{collection_name}' dataset {i}: {e}") # STEP 3: Handle FOV filtering logic # Check for ambiguous FOV specification - has_manifest_fovs = any(len(ds.fov_names) > 0 for ds in datasets) + has_collection_fovs = any(len(ds.fov_names) > 0 for ds in datasets) - if fit_include_wells is not None and has_manifest_fovs: - # Ambiguous: both manifest and user specified FOVs + if fit_include_wells is not None and has_collection_fovs: + # Ambiguous: both collection and user specified FOVs raise ValueError( - "Cannot specify both 'fit_include_wells' and use manifest FOV filtering. " - "The manifest already specifies FOVs to include. " + "Cannot specify both 'fit_include_wells' and use collection FOV filtering. " + "The collection already specifies FOVs to include. " "Either:\n" - " 1. Use fit_include_wells=None to respect manifest FOVs, OR\n" - " 2. Create a new manifest without FOV filtering if you want custom wells" + " 1. Use fit_include_wells=None to respect collection FOVs, OR\n" + " 2. Create a new collection without FOV filtering if you want custom wells" ) # STEP 4: Ensure normalizations and augmentations are lists @@ -276,7 +276,7 @@ def create_triplet_datamodule_from_manifest( # User override: use explicit wells include_wells = fit_include_wells elif len(ds.fov_names) > 0: - # Convert manifest FOV names to well IDs + # Convert collection FOV names to well IDs include_wells = _extract_wells_from_fov_names(ds.fov_names) else: # No filtering: use all wells @@ -324,22 +324,22 @@ def create_triplet_datamodule_from_manifest( return BatchedConcatDataModule(data_modules=data_modules) -class ManifestTripletDataModule(TripletDataModule): +class CollectionTripletDataModule(TripletDataModule): """ - TripletDataModule that fetches paths from Airtable manifests. 
+ TripletDataModule that fetches paths from Airtable collections. This class is designed to work with PyTorch Lightning CLI and config files. - It extends TripletDataModule to accept Airtable manifest parameters instead + It extends TripletDataModule to accept Airtable collection parameters instead of explicit data_path and tracks_path. Parameters ---------- base_id : str Airtable base ID - manifest_name : str - Name of the manifest in Airtable - manifest_version : str - Semantic version of the manifest (e.g., "0.0.1") + collection_name : str + Name of the collection in Airtable + collection_version : str + Semantic version of the collection (e.g., "0.0.1") source_channel : str | Sequence[str] Input channel name(s) z_range : tuple[int, int] @@ -352,7 +352,7 @@ class ManifestTripletDataModule(TripletDataModule): Raises ------ ValueError - If manifest has multiple datasets (only single-plate manifests supported) + If collection has multiple datasets (only single-plate collections supported) Examples -------- @@ -360,11 +360,11 @@ class ManifestTripletDataModule(TripletDataModule): ```yaml data: - class_path: viscy.airtable.factory.ManifestTripletDataModule + class_path: viscy.airtable.factory.CollectionTripletDataModule init_args: base_id: "appXXXXXXXXXXXXXX" - manifest_name: "my_manifest" - manifest_version: "0.0.1" + collection_name: "my_collection" + collection_version: "0.0.1" source_channel: [Phase] z_range: [0, 5] batch_size: 16 @@ -385,10 +385,10 @@ class ManifestTripletDataModule(TripletDataModule): Direct usage in Python: ```python - dm = ManifestTripletDataModule( + dm = CollectionTripletDataModule( base_id="appXXXXXXXXXXXXXX", - manifest_name="my_manifest", - manifest_version="0.0.1", + collection_name="my_collection", + collection_version="0.0.1", source_channel=["Phase"], z_range=(0, 5), batch_size=16, @@ -398,47 +398,47 @@ class ManifestTripletDataModule(TripletDataModule): Notes ----- - - Only supports single-plate manifests (use create_triplet_datamodule_from_manifest + - Only supports single-plate collections (use create_triplet_datamodule_from_collection for multi-plate support with BatchedConcatDataModule) - - Fetches manifest from Airtable during __init__ + - Fetches collection from Airtable during __init__ - All TripletDataModule parameters are available - - FOV filtering from manifest is automatically applied via fit_include_wells + - FOV filtering from collection is automatically applied via fit_include_wells """ def __init__( self, base_id: str, - manifest_name: str, - manifest_version: str, + collection_name: str, + collection_version: str, source_channel: str | Sequence[str], z_range: tuple[int, int], api_key: str | None = None, fit_include_wells: list[str] | None = None, **kwargs, ): - # Fetch manifest from Airtable + # Fetch collection from Airtable airtable_db = AirtableManager( base_id=base_id, api_key=api_key or os.getenv("AIRTABLE_API_KEY") ) - manifest = airtable_db.get_dataset_paths( - manifest_name=manifest_name, - version=manifest_version, + collection = airtable_db.get_dataset_paths( + collection_name=collection_name, + version=collection_version, ) # Validate single plate - if len(manifest.datasets) != 1: + if len(collection.datasets) != 1: raise ValueError( - f"ManifestTripletDataModule only supports single-plate manifests. " - f"Manifest '{manifest_name}' has {len(manifest.datasets)} plates. " - f"Use create_triplet_datamodule_from_manifest() for multi-plate support." + f"CollectionTripletDataModule only supports single-plate collections. 
" + f"Collections '{collection_name}' has {len(collection.datasets)} plates. " + f"Use create_triplet_datamodule_from_collection() for multi-plate support." ) - dataset = manifest.datasets[0] + dataset = collection.datasets[0] - # Store manifest metadata as instance attributes for callbacks/logging + # Store collection metadata as instance attributes for callbacks/logging self.base_id = base_id - self.manifest_name = manifest_name - self.manifest_version = manifest_version + self.collection_name = collection_name + self.collection_version = collection_version self.data_path = dataset.data_path self.tracks_path = dataset.tracks_path @@ -447,7 +447,7 @@ def __init__( # User override: use explicit wells include_wells = fit_include_wells elif len(dataset.fov_names) > 0: - # Convert manifest FOV names to well IDs + # Convert collection FOV names to well IDs include_wells = _extract_wells_from_fov_names(dataset.fov_names) else: # No filtering: use all wells diff --git a/viscy/airtable/schemas.py b/viscy/airtable/schemas.py index b5340b971..57bfbe165 100644 --- a/viscy/airtable/schemas.py +++ b/viscy/airtable/schemas.py @@ -1,6 +1,6 @@ -"""Pydantic schemas for Airtable records.""" +"""Pydantic schemas for Airtable records and model registry.""" -from typing import Any +from typing import Any, Literal from pydantic import BaseModel, Field, field_validator @@ -168,3 +168,116 @@ def __repr__(self) -> str: f"DatasetRecord(dataset_name={self.dataset_name!r}, well_id={self.well_id!r}, " f"fov_name={self.fov_name!r}, fov_id={self.fov_id!r})" ) + + +class ModelRecord(BaseModel): + """ + Pydantic model for a registered model artifact. + + This represents a registered model in WandB with metadata about + training, architecture, and file locations. + """ + + # Core identification + model_name: str = Field(..., description="Model identifier") + model_type: Literal["contrastive", "segmentation", "vae", "translation"] = Field( + ..., description="Model category" + ) + version: str = Field(..., description="Version (e.g., 'v1', 'v2')") + + # File locations + checkpoint_path: str = Field(..., description="Absolute path to .ckpt file") + config_artifact_path: str | None = Field( + None, description="Path to downloaded config_fit.yml from artifact" + ) + + # Architecture (for quick reference) + architecture: str | None = Field( + None, description="Model architecture (e.g., '2.5D')" + ) + backbone: str | None = Field(None, description="Backbone (e.g., 'convnext_tiny')") + + # Training lineage + collection_name: str | None = Field(None, description="Training data collection") + wandb_run_id: str | None = Field(None, description="Training run ID") + wandb_artifact_url: str | None = Field(None, description="WandB artifact URL") + + # Metadata + description: str | None = Field(None, description="Model description") + trained_date: str | None = Field(None, description="Training date") + + model_config = { + "populate_by_name": True, + "str_strip_whitespace": True, + } + + @field_validator("checkpoint_path") + @classmethod + def validate_checkpoint_path(cls, v: str) -> str: + """Validate checkpoint path.""" + if not v.startswith("/"): + raise ValueError("Checkpoint path must be absolute") + if not v.endswith(".ckpt"): + raise ValueError("Checkpoint must be .ckpt file") + return v + + @field_validator("version") + @classmethod + def validate_version(cls, v: str) -> str: + """Validate version format.""" + if not v or not v.strip(): + raise ValueError("Version cannot be empty") + return v.strip() + + def 
to_wandb_metadata(self) -> dict[str, Any]: + """ + Convert to WandB artifact metadata dictionary. + + Returns + ------- + dict + Dictionary suitable for wandb.Artifact(metadata=...) + """ + return self.model_dump( + exclude_none=True, + exclude={"config_artifact_path"}, # Don't include temp paths + ) + + @classmethod + def from_wandb_artifact(cls, artifact: Any) -> "ModelRecord": + """ + Create ModelRecord from WandB artifact. + + Parameters + ---------- + artifact : wandb.Artifact + WandB artifact object + + Returns + ------- + ModelRecord + Parsed model record + """ + metadata = artifact.metadata.copy() + + # Add artifact-level fields + metadata["model_name"] = artifact.name + if hasattr(artifact, "version"): + metadata["version"] = artifact.version + if hasattr(artifact, "description") and artifact.description: + metadata["description"] = artifact.description + if hasattr(artifact, "url"): + metadata["wandb_artifact_url"] = artifact.url + + return cls(**metadata) + + def __str__(self) -> str: + """Human-readable string representation.""" + return f"Model({self.model_name}:{self.version} @ {self.checkpoint_path})" + + def __repr__(self) -> str: + """Detailed string representation.""" + return ( + f"ModelRecord(model_name={self.model_name!r}, model_type={self.model_type!r}, " + f"version={self.version!r})" + ) From 1b53df40cf95839edbaf1a607e9c236190071492 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 12 Jan 2026 16:39:04 -0800 Subject: [PATCH 17/18] test airtable database --- tests/airtable/__init__.py | 1 + tests/airtable/test_database.py | 497 ++++++++++++++++++++++++++++++++ 2 files changed, 498 insertions(+) create mode 100644 tests/airtable/__init__.py create mode 100644 tests/airtable/test_database.py diff --git a/tests/airtable/__init__.py b/tests/airtable/__init__.py new file mode 100644 index 000000000..bb1c48247 --- /dev/null +++ b/tests/airtable/__init__.py @@ -0,0 +1 @@ +"""Tests for Airtable integration module.""" diff --git a/tests/airtable/test_database.py b/tests/airtable/test_database.py new file mode 100644 index 000000000..e277fbe0b --- /dev/null +++ b/tests/airtable/test_database.py @@ -0,0 +1,497 @@ +"""Tests for Airtable database module using mocks.""" + +from unittest.mock import Mock, patch + +import pandas as pd +import pytest + +from viscy.airtable.database import AirtableManager, CollectionDataset, Collections +from viscy.airtable.schemas import DatasetRecord + +# ============================================================================ +# Tests for CollectionDataset dataclass +# ============================================================================ + + +def test_collection_dataset_creation(): + """Test creating a CollectionDataset.""" + dataset = CollectionDataset( + data_path="/hpc/data/plate.zarr", + tracks_path="/hpc/tracks/plate.zarr", + fov_names=["B/3/0", "B/3/1", "B/4/0"], + ) + + assert dataset.data_path == "/hpc/data/plate.zarr" + assert dataset.tracks_path == "/hpc/tracks/plate.zarr" + assert len(dataset) == 3 + assert len(dataset.fov_names) == 3 + + +def test_collection_dataset_fov_paths(): + """Test generating FOV paths.""" + dataset = CollectionDataset( + data_path="/hpc/data/plate.zarr", + tracks_path="/hpc/tracks/plate.zarr", + fov_names=["B/3/0", "B/3/1"], + ) + + fov_paths = dataset.fov_paths + assert fov_paths == [ + "/hpc/data/plate.zarr/B/3/0", + "/hpc/data/plate.zarr/B/3/1", + ] + + +def test_collection_dataset_exists(tmp_path): + """Test checking if paths exist.""" + # Create actual directories + data_path = 
tmp_path / "data.zarr" + tracks_path = tmp_path / "tracks.zarr" + data_path.mkdir() + tracks_path.mkdir() + + dataset = CollectionDataset( + data_path=str(data_path), + tracks_path=str(tracks_path), + fov_names=["0", "1"], + ) + + assert dataset.exists() is True + + # Test with non-existent paths + dataset_bad = CollectionDataset( + data_path="/nonexistent/data.zarr", + tracks_path="/nonexistent/tracks.zarr", + fov_names=["0"], + ) + + assert dataset_bad.exists() is False + + +def test_collection_dataset_validate(tmp_path): + """Test validation raises error for missing paths.""" + data_path = tmp_path / "data.zarr" + data_path.mkdir() + + # Missing tracks path + dataset = CollectionDataset( + data_path=str(data_path), + tracks_path="/nonexistent/tracks.zarr", + fov_names=["0"], + ) + + with pytest.raises(FileNotFoundError, match="Tracks path not found"): + dataset.validate() + + +# ============================================================================ +# Tests for Collections dataclass +# ============================================================================ + + +def test_collections_creation(): + """Test creating a Collections object.""" + datasets = [ + CollectionDataset("/data1.zarr", "/tracks1.zarr", ["0", "1"]), + CollectionDataset("/data2.zarr", "/tracks2.zarr", ["0", "1", "2"]), + ] + + collection = Collections( + name="test_collection", + version="0.0.1", + datasets=datasets, + ) + + assert collection.name == "test_collection" + assert collection.version == "0.0.1" + assert len(collection) == 2 + assert collection.total_fovs == 5 # 2 + 3 + + +def test_collections_iteration(): + """Test iterating over datasets in a collection.""" + datasets = [ + CollectionDataset("/data1.zarr", "/tracks1.zarr", ["0"]), + CollectionDataset("/data2.zarr", "/tracks2.zarr", ["0", "1"]), + ] + + collection = Collections("test", "0.0.1", datasets) + + dataset_list = list(collection) + assert len(dataset_list) == 2 + assert dataset_list[0].data_path == "/data1.zarr" + assert dataset_list[1].data_path == "/data2.zarr" + + +def test_collections_validate(tmp_path): + """Test validation checks all dataset paths.""" + # Create valid paths for first dataset + data1 = tmp_path / "data1.zarr" + tracks1 = tmp_path / "tracks1.zarr" + data1.mkdir() + tracks1.mkdir() + + datasets = [ + CollectionDataset(str(data1), str(tracks1), ["0"]), + CollectionDataset("/bad/data.zarr", "/bad/tracks.zarr", ["0"]), # Invalid + ] + + collection = Collections("test", "0.0.1", datasets) + + with pytest.raises(FileNotFoundError, match="Data path not found"): + collection.validate() + + +# ============================================================================ +# Tests for AirtableManager +# ============================================================================ + + +@pytest.fixture +def mock_airtable_api(): + """Create a mock Airtable API.""" + api = Mock() + api.table.return_value = Mock() + return api + + +@pytest.fixture +def airtable_manager(mock_airtable_api): + """Create an AirtableManager with mocked API.""" + with patch("viscy.airtable.database.Api", return_value=mock_airtable_api): + with patch.dict("os.environ", {"AIRTABLE_API_KEY": "test_key"}): + manager = AirtableManager(base_id="test_base_id") + return manager + + +def test_airtable_manager_init_requires_api_key(): + """Test that AirtableManager requires an API key.""" + with patch.dict("os.environ", {}, clear=True): + # Remove AIRTABLE_API_KEY from env + with pytest.raises(ValueError, match="Airtable API key required"): + 
AirtableManager(base_id="test_base") + + +def test_airtable_manager_init_with_explicit_key(): + """Test initializing with explicit API key.""" + with patch("viscy.airtable.database.Api") as mock_api_class: + AirtableManager(base_id="test_base", api_key="explicit_key") + + # Verify Api was called with the explicit key + mock_api_class.assert_called_once_with("explicit_key") + + +def test_airtable_manager_init_with_env_key(): + """Test initializing with API key from environment.""" + with patch("viscy.airtable.database.Api") as mock_api_class: + with patch.dict("os.environ", {"AIRTABLE_API_KEY": "env_key"}): + AirtableManager(base_id="test_base") + + mock_api_class.assert_called_once_with("env_key") + + +def test_register_dataset(airtable_manager): + """Test registering a single dataset.""" + dataset = DatasetRecord( + dataset_name="test_plate", + well_id="B_3", + fov_name="0", + data_path="/hpc/data/test.zarr/B/3/0", + ) + + # Mock the create response + airtable_manager.datasets_table.create.return_value = {"id": "rec123"} + + record_id = airtable_manager.register_dataset(dataset) + + assert record_id == "rec123" + airtable_manager.datasets_table.create.assert_called_once() + + # Verify the data passed to Airtable + call_args = airtable_manager.datasets_table.create.call_args[0][0] + assert call_args["Dataset"] == "test_plate" + assert call_args["Well ID"] == "B_3" + assert call_args["FOV"] == "0" + + +def test_register_datasets_multiple(airtable_manager): + """Test registering multiple datasets.""" + datasets = [ + DatasetRecord( + dataset_name="plate", + well_id="B_3", + fov_name="0", + data_path="/hpc/data/plate.zarr/B/3/0", + ), + DatasetRecord( + dataset_name="plate", + well_id="B_3", + fov_name="1", + data_path="/hpc/data/plate.zarr/B/3/1", + ), + ] + + # Mock create responses + airtable_manager.datasets_table.create.side_effect = [ + {"id": "rec123"}, + {"id": "rec456"}, + ] + + record_ids = airtable_manager.register_datasets(datasets) + + assert record_ids == ["rec123", "rec456"] + assert airtable_manager.datasets_table.create.call_count == 2 + + +def test_create_collection_from_datasets(airtable_manager): + """Test creating a collection from dataset FOV IDs.""" + # Mock list_collections to return empty (no duplicates) + with patch.object( + airtable_manager, "list_collections", return_value=pd.DataFrame() + ): + # Mock the datasets_table.all() response for FOV lookup + def mock_all(formula=None): + if formula == "{FOV_ID}='plate_B_3_0'": + return [{"id": "rec1", "fields": {"FOV_ID": "plate_B_3_0"}}] + elif formula == "{FOV_ID}='plate_B_3_1'": + return [{"id": "rec2", "fields": {"FOV_ID": "plate_B_3_1"}}] + return [] + + airtable_manager.datasets_table.all.side_effect = mock_all + + # Mock the collections_table.create response + airtable_manager.collections_table.create.return_value = {"id": "col123"} + + collection_id = airtable_manager.create_collection_from_datasets( + collection_name="test_collection", + fov_ids=["plate_B_3_0", "plate_B_3_1"], + version="0.0.1", + purpose="training", + ) + + assert collection_id == "col123" + + # Verify the collection was created with correct data + airtable_manager.collections_table.create.assert_called_once() + call_args = airtable_manager.collections_table.create.call_args[0][0] + + assert call_args["name"] == "test_collection" + assert call_args["version"] == "0.0.1" + assert call_args["purpose"] == "training" + assert call_args["datasets"] == ["rec1", "rec2"] # Linked record IDs + + +def 
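test_get_collection_data_paths_unknown_name(airtable_manager):
+    """Editor's sketch (assumed test name): an empty Collections table makes
+    any lookup fail with ValueError, matching the database error strings."""
+    with patch.object(
+        airtable_manager, "list_collections", return_value=pd.DataFrame()
+    ):
+        with pytest.raises(ValueError, match="not found"):
+            airtable_manager.get_collection_data_paths("missing_collection")
+
+
+def 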
test_create_collection_validates_version_format(airtable_manager): + """Test that collection creation validates version format.""" + with pytest.raises(ValueError, match="semantic version format"): + airtable_manager.create_collection_from_datasets( + collection_name="test", + fov_ids=["fov1"], + version="invalid_version", # Should be like "0.0.1" + ) + + +def test_list_datasets_returns_pydantic_models(airtable_manager): + """Test listing datasets returns Pydantic models.""" + # Mock the all() response + airtable_manager.datasets_table.all.return_value = [ + { + "id": "rec123", + "fields": { + "Dataset": "test_plate", + "Well ID": "B_3", + "FOV": "0", + "Data path": "/hpc/data/test.zarr/B/3/0", + }, + } + ] + + datasets = airtable_manager.list_datasets(as_pydantic=True) + + assert len(datasets) == 1 + assert isinstance(datasets[0], DatasetRecord) + assert datasets[0].dataset_name == "test_plate" + assert datasets[0].well_id == "B_3" + + +def test_list_datasets_returns_dicts(airtable_manager): + """Test listing datasets can return raw dictionaries.""" + airtable_manager.datasets_table.all.return_value = [ + { + "id": "rec123", + "fields": {"Dataset": "test_plate"}, + } + ] + + datasets = airtable_manager.list_datasets(as_dataframe=False, as_pydantic=False) + + assert len(datasets) == 1 + assert isinstance(datasets[0], dict) + assert datasets[0]["Dataset"] == "test_plate" + + +def test_get_collection_data_paths(airtable_manager): + """Test getting data paths from a collection.""" + # Mock list_collections to return collection info + collections_df = pd.DataFrame( + [ + { + "id": "col123", + "name": "test_collection", + "version": "0.0.1", + "datasets": ["rec1", "rec2"], + } + ] + ) + + with patch.object( + airtable_manager, "list_collections", return_value=collections_df + ): + # Mock datasets_table.get() for individual record fetches + def mock_get(record_id): + if record_id == "rec1": + return { + "id": "rec1", + "fields": { + "Dataset": "plate", + "Data path": "/hpc/data/plate.zarr", + "Well ID": "B_3", + "FOV": "0", + }, + } + elif record_id == "rec2": + return { + "id": "rec2", + "fields": { + "Dataset": "plate", + "Data path": "/hpc/data/plate.zarr", + "Well ID": "B_3", + "FOV": "1", + }, + } + return None + + airtable_manager.datasets_table.get.side_effect = mock_get + + collection = airtable_manager.get_dataset_paths( + collection_name="test_collection", version="0.0.1" + ) + + assert isinstance(collection, Collections) + assert collection.name == "test_collection" + assert collection.version == "0.0.1" + assert collection.total_fovs == 2 + + +def test_log_model_training(airtable_manager): + """Test logging model training to Models table.""" + airtable_manager.models_table.create.return_value = {"id": "model123"} + + # Mock collections_table.get() for update operation + airtable_manager.collections_table.get.return_value = { + "id": "col123", + "fields": {"models_trained": ""}, + } + + model_id = airtable_manager.log_model_training( + collection_id="col123", + wandb_run_id="run456", + model_name="contrastive-v1", + checkpoint_path="/hpc/models/model.ckpt", + trained_by="test_user", + metrics={"val_loss": 0.15}, + ) + + assert model_id == "model123" + + # Verify the model record was created correctly + airtable_manager.models_table.create.assert_called_once() + call_args = airtable_manager.models_table.create.call_args[0][0] + + assert call_args["collection"] == ["col123"] + assert call_args["wandb_run_id"] == "run456" # Fixed: was mlflow_run_id + assert call_args["model_name"] == 
"contrastive-v1" + assert call_args["checkpoint_path"] == "/hpc/models/model.ckpt" + assert call_args["val_loss"] == 0.15 + + +def test_list_collections(airtable_manager): + """Test listing all collections.""" + airtable_manager.collections_table.all.return_value = [ + { + "id": "col1", + "fields": { + "name": "collection_1", + "version": "0.0.1", + "purpose": "training", + }, + }, + { + "id": "col2", + "fields": { + "name": "collection_2", + "version": "0.0.2", + "purpose": "validation", + }, + }, + ] + + # Test returning DataFrame (default) + collections_df = airtable_manager.list_collections() + + assert isinstance(collections_df, pd.DataFrame) + assert len(collections_df) == 2 + assert collections_df.iloc[0]["name"] == "collection_1" + assert collections_df.iloc[1]["purpose"] == "validation" + + # Test returning list of dicts + collections_list = airtable_manager.list_collections(as_dataframe=False) + + assert isinstance(collections_list, list) + assert len(collections_list) == 2 + assert collections_list[0]["name"] == "collection_1" + assert collections_list[1]["purpose"] == "validation" + + +def test_get_models_for_collection(airtable_manager): + """Test getting all models trained on a specific collection.""" + airtable_manager.models_table.all.return_value = [ + { + "id": "model1", + "fields": { + "model_name": "model-v1", + "collection": ["col123"], + "trained_date": "2026-01-12", + }, + }, + { + "id": "model2", + "fields": { + "model_name": "model-v2", + "collection": ["col123"], + "trained_date": "2026-01-13", + }, + }, + ] + + # Test returning DataFrame (default, sorted by trained_date descending) + models_df = airtable_manager.get_models_for_collection(collection_id="col123") + + assert isinstance(models_df, pd.DataFrame) + assert len(models_df) == 2 + # Latest model comes first due to descending sort + assert models_df.iloc[0]["model_name"] == "model-v2" + assert models_df.iloc[0]["trained_date"] == "2026-01-13" + assert models_df.iloc[1]["model_name"] == "model-v1" + assert models_df.iloc[1]["trained_date"] == "2026-01-12" + + # Test returning list of dicts + models_list = airtable_manager.get_models_for_collection( + collection_id="col123", as_dataframe=False + ) + + assert isinstance(models_list, list) + assert len(models_list) == 2 + # List doesn't get sorted automatically, so order is preserved from input + assert models_list[0]["model_name"] == "model-v1" + assert models_list[1]["model_name"] == "model-v2" From 7ebdb74ffca28aee515bbedd14237ecfd0d904d8 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 22 Jan 2026 14:56:02 -0800 Subject: [PATCH 18/18] model registry example --- examples/airtable/model_registry_example.py | 232 +++++++++ .../register_single_dataset_example.py | 37 ++ viscy/airtable/register_model.py | 439 ++++++++++++++++++ viscy/cli/wandb_utils.py | 60 +++ 4 files changed, 768 insertions(+) create mode 100644 examples/airtable/model_registry_example.py create mode 100644 examples/airtable/register_single_dataset_example.py create mode 100644 viscy/airtable/register_model.py create mode 100644 viscy/cli/wandb_utils.py diff --git a/examples/airtable/model_registry_example.py b/examples/airtable/model_registry_example.py new file mode 100644 index 000000000..f30334e18 --- /dev/null +++ b/examples/airtable/model_registry_example.py @@ -0,0 +1,232 @@ +"""Example usage of W&B Model Registry for curating and loading models. 
+ +This demonstrates Part 2 of the W&B integration: +- Registering trained models as artifacts +- Loading models from the registry +- Querying available models +- Full lineage: Collections → Training Run → Model Artifact +""" + +# %% +# ============================================================================= +# Part 1: After Training - Register a "Blessed" Model +# ============================================================================= + +from viscy.airtable.register_model import register_model + +# After training completes, find the best checkpoint +checkpoint_path = "logs/wandb/run-20260107-154117/checkpoints/epoch=50-step=2550.ckpt" +wandb_run_id = "20260107-154117" # Get from W&B UI or training output + +# Register the model +artifact_url = register_model( + checkpoint_path=checkpoint_path, + model_name="contrastive-a549-sec61", + model_type="contrastive", + version="v1", + aliases=["production", "best"], + wandb_run_id=wandb_run_id, + wandb_project="viscy-model-registry", + shared_dir="/hpc/models/shared", + description="Contrastive model trained on A549 SEC61 DENV wells B1-B2, val_loss=0.152", + metadata={ + "val_loss": 0.152, + "collection_name": "2024_11_07_A549_SEC61_DENV_wells_B1_B2", + "collection_version": "0.0.1", + "backbone": "convnext_tiny", + "embedding_dim": 768, + }, +) + +print("\n✓ Model registered successfully!") +print(f" View in W&B: {artifact_url}") +print(" Checkpoint: /hpc/models/shared/contrastive/contrastive-a549-sec61-v1.ckpt") + +# %% +# ============================================================================= +# Part 2: Discovery - List Available Models +# ============================================================================= + +from viscy.airtable.register_model import list_registered_models + +# List all registered models +all_models = list_registered_models(wandb_project="viscy-model-registry") + +print("\n=== All Registered Models ===") +for model in all_models: + print(f"{model['name']}:{model['version']}") + print(f" Type: {model['model_type']}") + print(f" Aliases: {model['aliases']}") + print(f" Description: {model['description']}") + print(f" Checkpoint: {model['checkpoint_path']}") + print() + +# %% +# Filter by model type +contrastive_models = list_registered_models( + wandb_project="viscy-model-registry", model_type="contrastive" +) + +print("\n=== Contrastive Models ===") +for model in contrastive_models: + metadata = model["metadata"] + print(f"{model['name']}:{model['version']}") + print(f" Val Loss: {metadata.get('val_loss', 'N/A')}") + print(f" Collections: {metadata.get('collection_name', 'N/A')}") + print() + +# %% +# Find production models +production_models = [m for m in all_models if "production" in m["aliases"]] + +print("\n=== Production Models ===") +for model in production_models: + print(f"- {model['name']} ({model['model_type']})") + +# %% +# ============================================================================= +# Part 3: Loading Models - Use Registered Models in Analysis/Inference +# ============================================================================= + +from viscy.airtable.register_model import load_model_from_registry +from viscy.representation.engine import ContrastiveModule + +# Load production model by alias +model = load_model_from_registry( + model_name="contrastive-a549-sec61", + version="production", # Can also use "v1", "latest", "best" + wandb_project="viscy-model-registry", + model_class=ContrastiveModule, +) + +print("\n=== Model Loaded ===") +print(f"Model type: 
{type(model).__name__}") +print(f"Model in eval mode: {not model.training}") + +# %% +# Use model for inference +import torch + +# Create dummy input (replace with real data) +dummy_batch = torch.randn(2, 1, 5, 160, 160) # [batch, channels, z, y, x] + +with torch.no_grad(): + embeddings = model(dummy_batch) + +print("\n✓ Inference successful!") +print(f" Input shape: {dummy_batch.shape}") +print(f" Embedding shape: {embeddings.shape}") + +# %% +# ============================================================================= +# Part 4: Full Lineage - From Collections to Model +# ============================================================================= + +import wandb + +# Get model artifact +api = wandb.Api() +artifact = api.artifact("viscy-model-registry/contrastive-a549-sec61:production") + +print("\n=== Model Lineage ===") +print(f"Model: {artifact.name}:{artifact.version}") +print(f"Description: {artifact.description}") +print() + +# Get training run (lineage) +training_run_id = artifact.metadata.get("training_run_id") +if training_run_id: + # NOTE: You may need to adjust project name + try: + training_run = api.run(f"eduardo-hirata/viscy-experiments/{training_run_id}") + print(f"Training Run: {training_run.name}") + print(f" URL: {training_run.url}") + print(f" Metrics: val_loss={training_run.summary.get('loss/val', 'N/A')}") + print() + except Exception as e: + print(f"Could not fetch training run: {e}") + +# Get collection (from training run config or artifact metadata) +collection_name = artifact.metadata.get("collection_name") +collection_version = artifact.metadata.get("collection_version") + +if collection_name and collection_version: + print(f"Collections: {collection_name} v{collection_version}") + + # You can now fetch the full collection from Airtable + from viscy.airtable.database import AirtableManager + + airtable_db = AirtableManager(base_id="app8vqaoWyOwa0sB5") + collection = airtable_db.get_dataset_paths( + collection_name=collection_name, + version=collection_version, + ) + + print(f" Total FOVs: {collection.total_fovs}") + print(f" Data paths: {[str(ds.data_path) for ds in collection.datasets]}") + +print("\n✓ Full lineage chain:") +print( + " Airtable Collections → W&B Training Run → W&B Model Artifact → Checkpoint File" +) + +# %% +# ============================================================================= +# Part 5: Command Line Usage (Alternative to Python API) +# ============================================================================= + +print("\n=== CLI Usage Examples ===") + +print( + """ +# Register a model after training: +python -m viscy.airtable.register_model \\ + logs/wandb/run-20260107/checkpoints/epoch=50.ckpt \\ + --name contrastive-rpe1 \\ + --type contrastive \\ + --version v2 \\ + --aliases production best \\ + --run-id 20260107-152420 \\ + --description "Best RPE1 model, val_loss=0.145" + +# Checkpoints are copied to: +# /hpc/models/shared/contrastive/contrastive-rpe1-v2.ckpt + +# View in W&B: +# https://wandb.ai/YOUR_ENTITY/viscy-model-registry/artifacts/model +""" +) + +# %% +# ============================================================================= +# Summary: When to use what +# ============================================================================= + +print("\n=== Summary ===") +print( + """ +Part 1 (Automatic - Already Implemented): +- CollectionWandbCallback logs collection metadata to every training run +- All experiments tracked automatically in W&B +- No manual work required + +Part 2 (Manual - This Example): +- 
After training, manually register "blessed" models using register_model() +- Creates W&B artifact with lineage to training run +- Copies checkpoint to shared HPC directory +- Team can discover and load models via W&B UI or Python API + +Workflow: +1. Train model with train_with_wandb.yml (automatic tracking) +2. Review metrics in W&B, identify best model +3. Register best checkpoint as artifact (manual) +4. Team can now load model by name/alias from registry +5. Full lineage: Collections → Training → Model Artifact + +Benefits: +- Discoverability: Find models by collection, performance, date +- Versioning: v1, v2, v3 with aliases (production, latest) +- Lineage: Track which data/config produced which model +- Team collaboration: Shared registry + shared checkpoints +""" +) diff --git a/examples/airtable/register_single_dataset_example.py b/examples/airtable/register_single_dataset_example.py new file mode 100644 index 000000000..dea4c9a4a --- /dev/null +++ b/examples/airtable/register_single_dataset_example.py @@ -0,0 +1,37 @@ +"""Simple example: Register a single dataset to Airtable.""" + +from viscy.airtable import AirtableManager, DatasetRecord + +BASE_ID = "app8vqaoWyOwa0sB5" + +# Create dataset with validation +dataset = DatasetRecord( + dataset_name="2024_11_07_A549_SEC61_DENV", + well_id="B/1", + fov_name="0", + data_path="/hpc/data/2024_11_07_A549_SEC61_DENV.zarr/B/1/0", + cell_type="A549", + organelle="SEC61B", + channel_0="brightfield", + channel_1="nucleus", + channel_2="protein", +) + +print("Dataset to register:") +print(f" {dataset}") +print(f" Dataset: {dataset.dataset_name}") +print(f" Well: {dataset.well_id}, FOV: {dataset.fov_name}") +print(f" Cell type: {dataset.cell_type}") +print(f" Organelle: {dataset.organelle}") +print(f" Path: {dataset.data_path}") + +# Register to Airtable +print("\nRegistering to Airtable...") +airtable_db = AirtableManager(base_id=BASE_ID) +record_id = airtable_db.register_dataset(dataset) + +print("\n✓ Successfully registered!") +print(f" Airtable Record ID: {record_id}") +print( + f" FOV_ID will be auto-generated as: {dataset.dataset_name}_{dataset.well_id}_{dataset.fov_name}" +) diff --git a/viscy/airtable/register_model.py b/viscy/airtable/register_model.py new file mode 100644 index 000000000..d69bc4dd3 --- /dev/null +++ b/viscy/airtable/register_model.py @@ -0,0 +1,439 @@ +"""Register trained models to W&B artifact registry and shared directory.""" + +import shutil +from pathlib import Path +from typing import Any + +import wandb + + +def register_model( + checkpoint_path: str, + model_name: str, + model_type: str, + version: str, + config_path: str | None = None, + aliases: list[str] | None = None, + wandb_run_id: str | None = None, + wandb_project: str = "viscy-model-registry", + shared_dir: str = "/hpc/models/shared", + description: str | None = None, + metadata: dict[str, Any] | None = None, + airtable_base_id: str | None = None, + airtable_collection_id: str | None = None, +) -> str: + """ + Register a trained model to W&B artifacts and copy to shared directory. + + This creates a W&B artifact with references (not uploads) to track model lineage. + The checkpoint file is copied to a shared HPC directory for team access. 
+ + Parameters + ---------- + checkpoint_path : str + Path to Lightning checkpoint (.ckpt file) + model_name : str + Human-readable name (e.g., "contrastive-rpe1") + model_type : str + Model category: contrastive, segmentation, vae, translation + version : str + Semantic version (e.g., "v1", "v2", "v3") + config_path : str | None + Path to training config YAML file (config_fit.yml). This will be stored + in the WandB artifact for later use in prediction. + aliases : list[str] | None + Tags like "production", "best", "latest" + wandb_run_id : str | None + W&B run ID that trained this model (for lineage tracking) + wandb_project : str + W&B project for model registry (default: "viscy-model-registry") + shared_dir : str + Shared checkpoint directory path (default: "/hpc/models/shared") + description : str | None + Model description (e.g., "Trained on RPE1 collection v0.0.1, val_loss=0.15") + metadata : dict[str, Any] | None + Additional metadata to store (metrics, config, etc.) + airtable_base_id : str | None + Airtable base ID to log model to Models table (optional) + airtable_collection_id : str | None + Airtable collection record ID to link this model to (optional) + + Returns + ------- + str + W&B artifact URL + + Examples + -------- + Register a model after training: + + >>> from viscy.airtable.register_model import register_model + >>> artifact_url = register_model( + ... checkpoint_path="logs/wandb/run-20260107/checkpoints/epoch=50.ckpt", + ... config_path="examples/configs/fit_example.yml", # Training config + ... model_name="contrastive-rpe1", + ... model_type="contrastive", + ... version="v2", + ... aliases=["production", "best"], + ... wandb_run_id="20260107-152420", + ... description="Trained on RPE1 collection v0.0.1, best val_loss=0.15", + ... metadata={"val_loss": 0.15, "collection_name": "RPE1_infection", "collection_version": "v0.0.1"}, + ... airtable_base_id="app8vqaoWyOwa0sB5", # Optional: log to Airtable + ... airtable_collection_id="recXXXXXXXXXXXXXX", # Optional: link to collection + ... ) + >>> print(f"Model registered: {artifact_url}") + + CLI usage: + + >>> # From command line + >>> python -m viscy.airtable.register_model \\ + ... logs/wandb/run-20260107/checkpoints/epoch=50.ckpt \\ + ... --name contrastive-rpe1 \\ + ... --type contrastive \\ + ... --version v2 \\ + ... --aliases production best \\ + ... --run-id 20260107-152420 \\ + ... --description "Best RPE1 model" + """ + checkpoint_path = Path(checkpoint_path) + if not checkpoint_path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + # 1. Copy to shared directory + shared_path = Path(shared_dir) / model_type / f"{model_name}-{version}.ckpt" + shared_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(checkpoint_path, shared_path) + print(f"✓ Copied to shared directory: {shared_path}") + + # 2. Initialize W&B (use API for non-training context) + run = wandb.init(project=wandb_project, job_type="register-model") + + # 3. Create artifact with metadata (references only, not uploads) + artifact_metadata = { + "model_type": model_type, + "version": version, + "checkpoint_path": str(shared_path), + "original_checkpoint": str(checkpoint_path), + } + + # Add custom metadata + if metadata: + artifact_metadata.update(metadata) + + artifact = wandb.Artifact( + name=model_name, + type="model", + description=description or f"{model_type} model {version}", + metadata=artifact_metadata, + ) + + # 4. 
Add training config to artifact (if provided) + if config_path: + config_path = Path(config_path) + if config_path.exists(): + artifact.add_file(str(config_path), name="config_fit.yml") + print(f"✓ Added training config to artifact: {config_path.name}") + else: + print(f"⚠ Config file not found: {config_path}") + + # 5. Store checkpoint path in metadata only (no upload/reference) + # The checkpoint stays on HPC, W&B only tracks the metadata + # This is "Option 2" from the plan - references only, not uploads + + # 6. Link to training run (lineage) + if wandb_run_id: + artifact.metadata["training_run_id"] = wandb_run_id + # Create lineage link + try: + training_run = wandb.Api().run( + f"{run.entity}/{wandb_project}/{wandb_run_id}" + ) + artifact.metadata["training_run_url"] = training_run.url + except Exception as e: + print(f"⚠ Could not link to training run: {e}") + + # 7. Log artifact with version and aliases + aliases = aliases or [] + wandb.log_artifact(artifact, aliases=aliases) + + artifact_url = ( + f"https://wandb.ai/{run.entity}/{wandb_project}/artifacts/model/{model_name}" + ) + + wandb.finish() + + print(f"✓ Registered in W&B: {model_name}:{version}") + print(f" Aliases: {aliases}") + print(f" View: {artifact_url}") + + # 8. Optionally log to Airtable Models table + if airtable_base_id and airtable_collection_id: + try: + import getpass + + from viscy.airtable.database import AirtableManager + + airtable_db = AirtableManager(base_id=airtable_base_id) + + # Prepare complete metadata for Airtable (mirrors W&B artifact metadata) + airtable_metadata = { + # Core W&B fields + "model_type": model_type, + "version": version, + "wandb_url": artifact_url, + "original_checkpoint": str(checkpoint_path), + } + + # Add all custom metadata (metrics, architecture, collection lineage, etc.) + if metadata: + airtable_metadata.update(metadata) + + model_id = airtable_db.log_model_training( + collection_id=airtable_collection_id, + wandb_run_id=wandb_run_id or "unknown", + model_name=f"{model_name}:{version}", + checkpoint_path=str(shared_path), + trained_by=getpass.getuser(), + metrics=airtable_metadata, # Pass ALL metadata to Airtable + ) + + print(f"✓ Logged to Airtable Models table (record ID: {model_id})") + print(f" Linked to collection: {airtable_collection_id}") + print(f" Metadata fields passed: {len(airtable_metadata)}") + + except Exception as e: + print(f"⚠ Could not log to Airtable: {e}") + print(" (Fields without matching Airtable columns are ignored)") + + return artifact_url + + +def load_model_from_registry( + model_name: str, + version: str = "latest", + wandb_project: str = "viscy-model-registry", + model_class=None, + **model_kwargs, +): + """ + Load a model from W&B artifact registry. + + This fetches the checkpoint path from W&B metadata and loads the model + from the shared HPC directory (does not download from W&B cloud). 
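+
+    Requires access to the shared HPC filesystem: the checkpoint path stored
+    in the artifact metadata is loaded directly, and a FileNotFoundError is
+    raised if the registered file has been moved or deleted.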
+ + Parameters + ---------- + model_name : str + Model artifact name + version : str + Version or alias ("latest", "production", "v2") + wandb_project : str + W&B project name (default: "viscy-model-registry") + model_class : LightningModule class + Model class to instantiate (e.g., ContrastiveModule) + **model_kwargs + Additional arguments for model initialization + + Returns + ------- + model : LightningModule + Loaded model in eval mode + + Examples + -------- + Load a registered model: + + >>> from viscy.representation.engine import ContrastiveModule + >>> from viscy.airtable.register_model import load_model_from_registry + >>> + >>> model = load_model_from_registry( + ... model_name="contrastive-rpe1", + ... version="production", + ... wandb_project="viscy-model-registry", + ... model_class=ContrastiveModule, + ... ) + >>> model.eval() + >>> embeddings = model(images) + """ + # 1. Get artifact metadata from W&B + api = wandb.Api() + entity = api.default_entity + artifact = api.artifact(f"{entity}/{wandb_project}/{model_name}:{version}") + + # 2. Get checkpoint path from metadata (no download needed) + checkpoint_path = artifact.metadata.get("checkpoint_path") + if not checkpoint_path: + raise ValueError( + f"Artifact {model_name}:{version} missing checkpoint_path metadata" + ) + + checkpoint_path = Path(checkpoint_path) + if not checkpoint_path.exists(): + raise FileNotFoundError( + f"Checkpoint not found: {checkpoint_path}\n" + f"The model is registered but the file is missing from the shared directory." + ) + + print(f"✓ Loading model from: {checkpoint_path}") + print(f" Artifact: {model_name}:{version}") + print(f" Description: {artifact.description}") + + # 3. Load model + if model_class is None: + raise ValueError("Must provide model_class (e.g., ContrastiveModule)") + + model = model_class.load_from_checkpoint(checkpoint_path, **model_kwargs) + model.eval() + + return model + + +def list_registered_models( + wandb_project: str = "viscy-model-registry", + model_type: str | None = None, +) -> list[dict[str, Any]]: + """ + List all registered models in W&B artifact registry. + + Parameters + ---------- + wandb_project : str + W&B project name (default: "viscy-model-registry") + model_type : str | None + Filter by model type (contrastive, segmentation, vae, translation) + + Returns + ------- + list[dict[str, Any]] + List of model metadata dictionaries + + Examples + -------- + >>> from viscy.airtable.register_model import list_registered_models + >>> + >>> # List all models + >>> models = list_registered_models() + >>> for m in models: + ... 
print(f"{m['name']}:{m['version']} - {m['description']}") + >>> + >>> # List only contrastive models + >>> contrastive = list_registered_models(model_type="contrastive") + """ + api = wandb.Api() + + # Get all artifact collections in the project + try: + # Get entity from API + entity = api.default_entity + artifact_type = "model" + collection_name = f"{entity}/{wandb_project}/{artifact_type}" + + # Use the artifacts API (simpler) + artifacts = api.artifacts(artifact_type, collection_name) + + except Exception as e: + # Project might not exist yet or no artifacts + print(f"Warning: Could not list artifacts - {e}") + return [] + + models = [] + try: + for artifact in artifacts: + metadata = artifact.metadata + + # Filter by model type if specified + if model_type and metadata.get("model_type") != model_type: + continue + + models.append( + { + "name": artifact.name, + "version": artifact.version, + "aliases": artifact.aliases, + "description": artifact.description, + "model_type": metadata.get("model_type"), + "checkpoint_path": metadata.get("checkpoint_path"), + "created_at": artifact.created_at, + "metadata": metadata, + } + ) + except Exception as e: + print(f"Warning: Error iterating artifacts - {e}") + + return models + + +# CLI entry point +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Register a trained model to W&B artifact registry" + ) + + # Required arguments + parser.add_argument("checkpoint_path", help="Path to checkpoint file (.ckpt)") + parser.add_argument( + "--name", required=True, help="Model name (e.g., contrastive-rpe1)" + ) + parser.add_argument( + "--type", + required=True, + choices=["contrastive", "segmentation", "vae", "translation"], + help="Model type", + ) + parser.add_argument("--version", required=True, help="Version (e.g., v1, v2)") + + # Optional arguments + parser.add_argument( + "--config", + help="Path to training config YAML file (config_fit.yml)", + ) + parser.add_argument( + "--aliases", + nargs="+", + help="Aliases (e.g., production best latest)", + ) + parser.add_argument( + "--run-id", + help="W&B run ID that trained this model", + ) + parser.add_argument( + "--project", + default="viscy-model-registry", + help="W&B project name (default: viscy-model-registry)", + ) + parser.add_argument( + "--shared-dir", + default="/hpc/models/shared", + help="Shared checkpoint directory (default: /hpc/models/shared)", + ) + parser.add_argument( + "--description", + help="Model description", + ) + parser.add_argument( + "--airtable-base-id", + help="Airtable base ID (optional, for logging to Models table)", + ) + parser.add_argument( + "--airtable-collection-id", + help="Airtable collection record ID (optional, for linking)", + ) + + args = parser.parse_args() + + register_model( + checkpoint_path=args.checkpoint_path, + config_path=args.config, + model_name=args.name, + model_type=args.type, + version=args.version, + aliases=args.aliases, + wandb_run_id=args.run_id, + wandb_project=args.project, + shared_dir=args.shared_dir, + description=args.description, + airtable_base_id=args.airtable_base_id, + airtable_collection_id=args.airtable_collection_id, + ) diff --git a/viscy/cli/wandb_utils.py b/viscy/cli/wandb_utils.py new file mode 100644 index 000000000..24715b041 --- /dev/null +++ b/viscy/cli/wandb_utils.py @@ -0,0 +1,60 @@ +"""WandB utilities for model registry.""" + +import tempfile +from pathlib import Path + +import wandb + +from viscy.airtable.schemas import ModelRecord + + +def download_model_artifact( + 
model_name: str, + version: str = "latest", + wandb_project: str = "viscy-model-registry", + download_dir: str | None = None, +) -> tuple[ModelRecord, Path]: + """ + Download model artifact from WandB. + + Parameters + ---------- + model_name : str + Model name (e.g., 'contrastive-rpe1') + version : str + Version or alias (e.g., 'v1', 'latest', 'production') + wandb_project : str + WandB project name + download_dir : str | None + Where to download artifact files (temp dir if None) + + Returns + ------- + tuple[ModelRecord, Path] + (model_record, path_to_downloaded_config) + """ + api = wandb.Api() + entity = api.default_entity + + # Fetch artifact + artifact = api.artifact(f"{entity}/{wandb_project}/{model_name}:{version}") + + # Parse metadata with Pydantic validation + model_record = ModelRecord.from_wandb_artifact(artifact) + + # Download artifact files (config_fit.yml) + if download_dir is None: + download_dir = tempfile.mkdtemp(prefix=f"viscy_model_{model_name}_") + + artifact_dir = Path(artifact.download(root=download_dir)) + config_path = artifact_dir / "config_fit.yml" + + if not config_path.exists(): + raise FileNotFoundError( + "Training config not found in artifact. " + "Model may have been registered without config." + ) + + model_record.config_artifact_path = str(config_path) + + return model_record, config_path
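+
+
+# Usage sketch (illustrative): fetch a registered model's training config
+# before running prediction. The model name and alias below are placeholders;
+# fields on the returned ModelRecord depend on viscy.airtable.schemas.
+#
+# from viscy.cli.wandb_utils import download_model_artifact
+#
+# record, config_path = download_model_artifact(
+#     model_name="contrastive-rpe1",
+#     version="production",
+# )
+# print(config_path)  # downloaded config_fit.yml, ready to use for prediction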