diff --git a/examples/airtable/example_config_with_airtable.yml b/examples/airtable/example_config_with_airtable.yml new file mode 100644 index 000000000..60dbe558c --- /dev/null +++ b/examples/airtable/example_config_with_airtable.yml @@ -0,0 +1,66 @@ +# Example config showing Airtable integration with Lightning training +# Usage: viscy fit -c examples/airtable/example_config_with_airtable.yml + +seed_everything: true + +trainer: + accelerator: gpu + devices: 1 + max_epochs: 100 + check_val_every_n_epoch: 1 + log_every_n_steps: 50 + + # Add Airtable logging callback + callbacks: + - class_path: viscy.representation.airtable_callback.AirtableLoggingCallback + init_args: + base_id: "appXXXXXXXXXXXXXX" # Replace with your Airtable base ID + dataset_id: "recYYYYYYYYYYYYYY" # Replace with your dataset record ID + model_name: null # Auto-generate from model class and timestamp + log_metrics: false # Set to true to store metrics in Airtable (otherwise use TensorBoard) + +model: + class_path: viscy.representation.contrastive.ContrastiveModule + init_args: + # Your model config here + backbone: resnet50 + embedding_len: 256 + +data: + class_path: viscy.data.triplet.TripletDataModule + init_args: + data_path: /hpc/data/your_plate.zarr + tracks_path: /hpc/tracks/your_tracks/ + source_channel: [Phase] + z_range: [0, 5] + initial_yx_patch_size: [512, 512] + final_yx_patch_size: [224, 224] + split_ratio: 0.8 + batch_size: 16 + num_workers: 8 + + # FOV selection from Airtable dataset definition + fit_include_wells: ["B3", "B4", "C3"] + fit_exclude_fovs: [] + + # Data augmentation + augmentations: + - class_path: viscy.transforms.RandAffined + init_args: + keys: [Phase] + prob: 0.8 + rotate_range: [3.14, 0.0, 0.0] + scale_range: [0.1, 0.1, 0.1] + - class_path: viscy.transforms.RandGaussianNoised + init_args: + keys: [Phase] + prob: 0.5 + mean: 0.0 + std: 0.1 + +# Dataset metadata (for reference, not used by training) +dataset_metadata: + airtable_id: "recYYYYYYYYYYYYYY" + name: "RPE1_infection_v2" + version: "v2" + description: "RPE1 cells, infection experiment, wells B3-C3" diff --git a/examples/airtable/filter_n_create_dataset_tag.py b/examples/airtable/filter_n_create_dataset_tag.py new file mode 100644 index 000000000..14bbc000e --- /dev/null +++ b/examples/airtable/filter_n_create_dataset_tag.py @@ -0,0 +1,121 @@ +"""Filter datasets using pandas and create collection tags.""" + +# %% + +from viscy.airtable.database import AirtableManager + +# BASE_ID = os.getenv("AIRTABLE_BASE_ID") +BASE_ID = "app8vqaoWyOwa0sB5" +airtable_db = AirtableManager(base_id=BASE_ID) + +# %% +# EXAMPLE 1: Get all dataset records as DataFrame and explore +print("=" * 70) +print("Getting all dataset records as DataFrame") +print("=" * 70) + +df_datasets = airtable_db.list_datasets() +print(f"\nTotal dataset records: {len(df_datasets)}") +print("\nDataFrame columns:") +print(df_datasets.columns.tolist()) +print("\nFirst few rows:") +print(df_datasets.head()) + +# %% +# EXAMPLE 2: Filter by dataset and specific wells using pandas +print("\n" + "=" * 70) +print("Filter: Dataset, Wells B_3 and B_4") +print("=" * 70) + +# Get all dataset records as DataFrame +df = airtable_db.list_datasets() + +# Filter with pandas - simple and powerful! 
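+# Any pandas indexing works on this DataFrame. For instance, a hedged variant that
+# selects every FOV in plate row B without listing wells (assuming the "Well ID"
+# column keeps the "row/column" format used below) would be:
+#   row_b = df[
+#       (df["Dataset"] == "2024_11_07_A549_SEC61_DENV")
+#       & (df["Well ID"].str.startswith("B/"))
+#   ]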
+filtered = df[ + (df["Dataset"] == "2024_11_07_A549_SEC61_DENV") + & (df["Well ID"].isin(["B/1", "B/2"])) +] + +print(f"\nTotal dataset records after filtering: {len(filtered)}") +print("\nBreakdown by well:") +print(filtered.groupby("Well ID").size()) + +# Create collection from filtered dataset records +fov_ids = filtered["FOV_ID"].tolist() + +try: + collection_id = airtable_db.create_collection_from_datasets( + collection_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2", + fov_ids=fov_ids, + version="0.0.2", # Semantic versioning + purpose="training", + description="Dataset records from wells B_3 and B_4", + ) + print(f"\n✓ Created collection: {collection_id}") + print(f" Contains {len(fov_ids)} dataset records") +except ValueError as e: + print(f"\n⚠ {e}") + +# %% +# Delete the collection entry demo +airtable_db.delete_collection(collection_id) +print(f"Deleted collection: {collection_id}") + +# %% +# EXAMPLE 3: Group by dataset and show summary +print("\n" + "=" * 70) +print("Group by dataset and show summary") +print("=" * 70) + +df_all = airtable_db.list_datasets() + +grouped = df_all.groupby("Dataset") + +for dataset_name, group in grouped: + print(f"\n{dataset_name}:") + print(f" Total records: {len(group)}") + print(f" Wells: {sorted(group['Well ID'].unique())}") + +# %% +# EXAMPLE 4: Filter by multiple wells +print("\n" + "=" * 70) +print("Filter: Multiple specific wells") +print("=" * 70) + +df = airtable_db.list_datasets() + +# Filter for specific wells from a dataset +filtered = df[ + (df["Dataset"] == "2024_11_07_A549_SEC61_DENV") + & (df["Well ID"].isin(["B/3", "B/4", "C/3", "C/4"])) +] + +print(f"\nDataset records matching criteria: {len(filtered)}") +print("\nBy well:") +print(filtered.groupby("Well ID").size()) + +print("\nFOV IDs:") +for fov_id in filtered["FOV_ID"]: + print(f" {fov_id}") + +# %% +# EXAMPLE 5: Summary statistics +print("\n" + "=" * 70) +print("Summary Statistics") +print("=" * 70) + +df = airtable_db.list_datasets() + +print("\nDataset records per source dataset:") +print(df.groupby("Dataset").size()) + +print("\nWells with most dataset records:") +print(df.groupby("Well ID").size().sort_values(ascending=False).head(10)) + +print("\nTotal unique wells:") +print(f"{df['Well ID'].nunique()} wells") + +print("\nTotal unique FOV IDs:") +print(f"{df['FOV_ID'].nunique()} FOV IDs") + +# %% diff --git a/examples/airtable/get_dataset_paths_example.py b/examples/airtable/get_dataset_paths_example.py new file mode 100644 index 000000000..067469a32 --- /dev/null +++ b/examples/airtable/get_dataset_paths_example.py @@ -0,0 +1,118 @@ +"""Example usage of get_dataset_paths with Collections and CollectionDataset dataclasses.""" + +# %% +from viscy.airtable.database import AirtableManager + +BASE_ID = "app8vqaoWyOwa0sB5" +airtable_db = AirtableManager(base_id=BASE_ID) + +# %% +# Fetch collection from Airtable +collection = airtable_db.get_dataset_paths( + collection_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2", + version="v1", +) + +# %% +# Collections properties +print("=== Collections ===") +print(f"collection.name: {collection.name}") +print(f"collection.version: {collection.version}") +print(f"len(collection): {len(collection)} HCS plate(s)") +print(f"collection.total_fovs: {collection.total_fovs} FOVs") + +# %% +# Iterate over CollectionDataset objects (one per HCS plate) +print("\n=== CollectionDataset ===") +for ds in collection: + print(f"ds.data_path: {ds.data_path}") + print(f"ds.tracks_path: {ds.tracks_path}") + print(f"len(ds): {len(ds)} FOVs") + 
print(f"ds.fov_names: {ds.fov_names[:3]}...") + print(f"ds.fov_paths: {ds.fov_paths[:2]}...") + print(f"ds.exists(): {ds.exists()}") + +# %% +# Validate paths exist (raises FileNotFoundError if not) +collection.validate() +print("\nAll paths validated successfully!") + + +# %% +# List available collections +print("=== Available Collections ===") +df = airtable_db.list_collections() +print(df[["name", "version", "purpose"]].dropna(subset=["name"]).to_string()) + +# %% +# ============================================================================= +# Create TripletDataModule from collection using factory function +# ============================================================================= +from viscy.airtable.factory import create_triplet_datamodule_from_collection + +# Create data module from collection +dm = create_triplet_datamodule_from_collection( + collection=collection, + source_channel=["Phase3D"], + z_range=(20, 21), + batch_size=1, + num_workers=1, + initial_yx_patch_size=(160, 160), + final_yx_patch_size=(160, 160), + return_negative=False, + time_interval=1, +) + +# %% +# Setup and inspect the data module +dm.setup("fit") +print("\n=== TripletDataModule from Collections ===") +print(f"Data module type: {type(dm).__name__}") +print(f"Train samples: {len(dm.train_dataset)}") +print(f"Val samples: {len(dm.val_dataset)}") + +# %% +# ============================================================================= +# Alternative: CollectionTripletDataModule (Lightning Config Compatible) +# ============================================================================= +from viscy.airtable.factory import CollectionTripletDataModule + +# This class is designed for Lightning CLI and config files +# but can also be used directly in Python +dm_class = CollectionTripletDataModule( + base_id=BASE_ID, + collection_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2", + collection_version="v1", + source_channel=["Phase3D"], + z_range=(20, 21), + batch_size=1, + num_workers=1, + initial_yx_patch_size=(160, 160), + final_yx_patch_size=(160, 160), + return_negative=False, + time_interval=1, +) + +dm_class.setup("fit") +print("\n=== CollectionTripletDataModule (Class) ===") +print(f"Data module type: {type(dm_class).__name__}") +print(f"Train samples: {len(dm_class.train_dataset)}") +print(f"Val samples: {len(dm_class.val_dataset)}") + +# %% Visualize some of the images +import matplotlib.pyplot as plt +import torch + +img_stack = [] +for idx, batch in enumerate(dm.train_dataloader()): + img_stack.append(batch["anchor"][0, 0, 0]) + if idx >= 9: + break +img_stack = torch.stack(img_stack) +# %% +# Make subplot with 10 images +fig, axs = plt.subplots(2, 5, figsize=(10, 4)) +for i in range(10): + axs[i // 5, i % 5].imshow(img_stack[i], cmap="gray") + axs[i // 5, i % 5].axis("off") +plt.show() diff --git a/examples/airtable/model_registry_example.py b/examples/airtable/model_registry_example.py new file mode 100644 index 000000000..f30334e18 --- /dev/null +++ b/examples/airtable/model_registry_example.py @@ -0,0 +1,232 @@ +"""Example usage of W&B Model Registry for curating and loading models. 
+ +This demonstrates Part 2 of the W&B integration: +- Registering trained models as artifacts +- Loading models from the registry +- Querying available models +- Full lineage: Collections → Training Run → Model Artifact +""" + +# %% +# ============================================================================= +# Part 1: After Training - Register a "Blessed" Model +# ============================================================================= + +from viscy.airtable.register_model import register_model + +# After training completes, find the best checkpoint +checkpoint_path = "logs/wandb/run-20260107-154117/checkpoints/epoch=50-step=2550.ckpt" +wandb_run_id = "20260107-154117" # Get from W&B UI or training output + +# Register the model +artifact_url = register_model( + checkpoint_path=checkpoint_path, + model_name="contrastive-a549-sec61", + model_type="contrastive", + version="v1", + aliases=["production", "best"], + wandb_run_id=wandb_run_id, + wandb_project="viscy-model-registry", + shared_dir="/hpc/models/shared", + description="Contrastive model trained on A549 SEC61 DENV wells B1-B2, val_loss=0.152", + metadata={ + "val_loss": 0.152, + "collection_name": "2024_11_07_A549_SEC61_DENV_wells_B1_B2", + "collection_version": "0.0.1", + "backbone": "convnext_tiny", + "embedding_dim": 768, + }, +) + +print("\n✓ Model registered successfully!") +print(f" View in W&B: {artifact_url}") +print(" Checkpoint: /hpc/models/shared/contrastive/contrastive-a549-sec61-v1.ckpt") + +# %% +# ============================================================================= +# Part 2: Discovery - List Available Models +# ============================================================================= + +from viscy.airtable.register_model import list_registered_models + +# List all registered models +all_models = list_registered_models(wandb_project="viscy-model-registry") + +print("\n=== All Registered Models ===") +for model in all_models: + print(f"{model['name']}:{model['version']}") + print(f" Type: {model['model_type']}") + print(f" Aliases: {model['aliases']}") + print(f" Description: {model['description']}") + print(f" Checkpoint: {model['checkpoint_path']}") + print() + +# %% +# Filter by model type +contrastive_models = list_registered_models( + wandb_project="viscy-model-registry", model_type="contrastive" +) + +print("\n=== Contrastive Models ===") +for model in contrastive_models: + metadata = model["metadata"] + print(f"{model['name']}:{model['version']}") + print(f" Val Loss: {metadata.get('val_loss', 'N/A')}") + print(f" Collections: {metadata.get('collection_name', 'N/A')}") + print() + +# %% +# Find production models +production_models = [m for m in all_models if "production" in m["aliases"]] + +print("\n=== Production Models ===") +for model in production_models: + print(f"- {model['name']} ({model['model_type']})") + +# %% +# ============================================================================= +# Part 3: Loading Models - Use Registered Models in Analysis/Inference +# ============================================================================= + +from viscy.airtable.register_model import load_model_from_registry +from viscy.representation.engine import ContrastiveModule + +# Load production model by alias +model = load_model_from_registry( + model_name="contrastive-a549-sec61", + version="production", # Can also use "v1", "latest", "best" + wandb_project="viscy-model-registry", + model_class=ContrastiveModule, +) + +print("\n=== Model Loaded ===") +print(f"Model type: 
{type(model).__name__}") +print(f"Model in eval mode: {not model.training}") + +# %% +# Use model for inference +import torch + +# Create dummy input (replace with real data) +dummy_batch = torch.randn(2, 1, 5, 160, 160) # [batch, channels, z, y, x] + +with torch.no_grad(): + embeddings = model(dummy_batch) + +print("\n✓ Inference successful!") +print(f" Input shape: {dummy_batch.shape}") +print(f" Embedding shape: {embeddings.shape}") + +# %% +# ============================================================================= +# Part 4: Full Lineage - From Collections to Model +# ============================================================================= + +import wandb + +# Get model artifact +api = wandb.Api() +artifact = api.artifact("viscy-model-registry/contrastive-a549-sec61:production") + +print("\n=== Model Lineage ===") +print(f"Model: {artifact.name}:{artifact.version}") +print(f"Description: {artifact.description}") +print() + +# Get training run (lineage) +training_run_id = artifact.metadata.get("training_run_id") +if training_run_id: + # NOTE: You may need to adjust project name + try: + training_run = api.run(f"eduardo-hirata/viscy-experiments/{training_run_id}") + print(f"Training Run: {training_run.name}") + print(f" URL: {training_run.url}") + print(f" Metrics: val_loss={training_run.summary.get('loss/val', 'N/A')}") + print() + except Exception as e: + print(f"Could not fetch training run: {e}") + +# Get collection (from training run config or artifact metadata) +collection_name = artifact.metadata.get("collection_name") +collection_version = artifact.metadata.get("collection_version") + +if collection_name and collection_version: + print(f"Collections: {collection_name} v{collection_version}") + + # You can now fetch the full collection from Airtable + from viscy.airtable.database import AirtableManager + + airtable_db = AirtableManager(base_id="app8vqaoWyOwa0sB5") + collection = airtable_db.get_dataset_paths( + collection_name=collection_name, + version=collection_version, + ) + + print(f" Total FOVs: {collection.total_fovs}") + print(f" Data paths: {[str(ds.data_path) for ds in collection.datasets]}") + +print("\n✓ Full lineage chain:") +print( + " Airtable Collections → W&B Training Run → W&B Model Artifact → Checkpoint File" +) + +# %% +# ============================================================================= +# Part 5: Command Line Usage (Alternative to Python API) +# ============================================================================= + +print("\n=== CLI Usage Examples ===") + +print( + """ +# Register a model after training: +python -m viscy.airtable.register_model \\ + logs/wandb/run-20260107/checkpoints/epoch=50.ckpt \\ + --name contrastive-rpe1 \\ + --type contrastive \\ + --version v2 \\ + --aliases production best \\ + --run-id 20260107-152420 \\ + --description "Best RPE1 model, val_loss=0.145" + +# Checkpoints are copied to: +# /hpc/models/shared/contrastive/contrastive-rpe1-v2.ckpt + +# View in W&B: +# https://wandb.ai/YOUR_ENTITY/viscy-model-registry/artifacts/model +""" +) + +# %% +# ============================================================================= +# Summary: When to use what +# ============================================================================= + +print("\n=== Summary ===") +print( + """ +Part 1 (Automatic - Already Implemented): +- CollectionWandbCallback logs collection metadata to every training run +- All experiments tracked automatically in W&B +- No manual work required + +Part 2 (Manual - This Example): +- 
After training, manually register "blessed" models using register_model() +- Creates W&B artifact with lineage to training run +- Copies checkpoint to shared HPC directory +- Team can discover and load models via W&B UI or Python API + +Workflow: +1. Train model with train_with_wandb.yml (automatic tracking) +2. Review metrics in W&B, identify best model +3. Register best checkpoint as artifact (manual) +4. Team can now load model by name/alias from registry +5. Full lineage: Collections → Training → Model Artifact + +Benefits: +- Discoverability: Find models by collection, performance, date +- Versioning: v1, v2, v3 with aliases (production, latest) +- Lineage: Track which data/config produced which model +- Team collaboration: Shared registry + shared checkpoints +""" +) diff --git a/examples/airtable/register_single_dataset_example.py b/examples/airtable/register_single_dataset_example.py new file mode 100644 index 000000000..dea4c9a4a --- /dev/null +++ b/examples/airtable/register_single_dataset_example.py @@ -0,0 +1,37 @@ +"""Simple example: Register a single dataset to Airtable.""" + +from viscy.airtable import AirtableManager, DatasetRecord + +BASE_ID = "app8vqaoWyOwa0sB5" + +# Create dataset with validation +dataset = DatasetRecord( + dataset_name="2024_11_07_A549_SEC61_DENV", + well_id="B/1", + fov_name="0", + data_path="/hpc/data/2024_11_07_A549_SEC61_DENV.zarr/B/1/0", + cell_type="A549", + organelle="SEC61B", + channel_0="brightfield", + channel_1="nucleus", + channel_2="protein", +) + +print("Dataset to register:") +print(f" {dataset}") +print(f" Dataset: {dataset.dataset_name}") +print(f" Well: {dataset.well_id}, FOV: {dataset.fov_name}") +print(f" Cell type: {dataset.cell_type}") +print(f" Organelle: {dataset.organelle}") +print(f" Path: {dataset.data_path}") + +# Register to Airtable +print("\nRegistering to Airtable...") +airtable_db = AirtableManager(base_id=BASE_ID) +record_id = airtable_db.register_dataset(dataset) + +print("\n✓ Successfully registered!") +print(f" Airtable Record ID: {record_id}") +print( + f" FOV_ID will be auto-generated as: {dataset.dataset_name}_{dataset.well_id}_{dataset.fov_name}" +) diff --git a/examples/airtable/sklearn_wrapper.py b/examples/airtable/sklearn_wrapper.py new file mode 100644 index 000000000..5e4ba4993 --- /dev/null +++ b/examples/airtable/sklearn_wrapper.py @@ -0,0 +1,174 @@ +import matplotlib.pyplot as plt +import torch +from lightning.pytorch import LightningModule +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import ( + ConfusionMatrixDisplay, + accuracy_score, + classification_report, + confusion_matrix, + f1_score, +) + + +class SklearnLogisticRegressionModule(LightningModule): + """ + Wrap sklearn LogisticRegression in Lightning for experiment tracking. + + This module collects features/labels during training_step, then trains + the sklearn model at the end of each epoch. This pattern allows us to + use Lightning's logging infrastructure while using sklearn's optimized + solvers. 
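+
+    Each batch must be a (features, labels) pair of tensors; no gradients flow
+    through this module, so configure_optimizers returns None.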
+ + Parameters + ---------- + input_dim : int + Feature dimension + lr : float + Inverse regularization strength (C parameter for LogisticRegression) + solver : str + Solver algorithm + max_iter : int + Maximum iterations for solver + class_weight : str | None + Class weighting strategy + """ + + def __init__( + self, + input_dim: int = 768, + lr: float = 1.0, + solver: str = "lbfgs", + max_iter: int = 1000, + class_weight: str | None = "balanced", + ): + super().__init__() + self.save_hyperparameters() + + # Sklearn model (trained incrementally per epoch) + self.model = LogisticRegression( + C=lr, + solver=solver, + max_iter=max_iter, + class_weight=class_weight, + random_state=42, + ) + + # Storage for batch features/labels + self.train_features = [] + self.train_labels = [] + self.val_features = [] + self.val_labels = [] + + # Store example input for Lightning compatibility + self.example_input_array = torch.rand(2, input_dim) + + def forward(self, x): + """Sklearn doesn't have gradients, so forward just returns input.""" + return x + + def training_step(self, batch, batch_idx): + """Collect features and labels for end-of-epoch training.""" + features, labels = batch + self.train_features.append(features.cpu()) + self.train_labels.append(labels.cpu()) + return None # No loss to backprop + + def on_train_epoch_end(self): + """Train sklearn model on collected features.""" + X_train = torch.cat(self.train_features, dim=0).numpy() + y_train = torch.cat(self.train_labels, dim=0).numpy() + + # Train sklearn model + self.model.fit(X_train, y_train) + + # Compute training metrics + y_pred = self.model.predict(X_train) + train_acc = accuracy_score(y_train, y_pred) + train_f1 = f1_score(y_train, y_pred, average="binary") + + # Log metrics (viscy convention: metric/category/stage) + self.log_dict( + { + "metric/accuracy/train": train_acc, + "metric/f1_score/train": train_f1, + }, + on_step=False, + on_epoch=True, + ) + + print(f"\n Training - Accuracy: {train_acc:.3f}, F1: {train_f1:.3f}") + + # Clear storage + self.train_features.clear() + self.train_labels.clear() + + def validation_step(self, batch, batch_idx): + """Collect validation features and labels.""" + features, labels = batch + self.val_features.append(features.cpu()) + self.val_labels.append(labels.cpu()) + return None + + def on_validation_epoch_end(self): + """Evaluate sklearn model on validation set.""" + # Skip if model not fitted yet (happens during sanity check) + if not hasattr(self.model, "classes_"): + self.val_features.clear() + self.val_labels.clear() + return + + X_val = torch.cat(self.val_features, dim=0).numpy() + y_val = torch.cat(self.val_labels, dim=0).numpy() + + # Predict + y_pred = self.model.predict(X_val) + + # Compute metrics + val_acc = accuracy_score(y_val, y_pred) + val_f1 = f1_score(y_val, y_pred, average="binary") + + # Log metrics + self.log_dict( + { + "metric/accuracy/val": val_acc, + "metric/f1_score/val": val_f1, + }, + on_step=False, + on_epoch=True, + ) + + print(f" Validation - Accuracy: {val_acc:.3f}, F1: {val_f1:.3f}") + + # Log confusion matrix + cm = confusion_matrix(y_val, y_pred) + if hasattr(self.logger, "experiment"): + fig, ax = plt.subplots(figsize=(6, 6)) + ConfusionMatrixDisplay(cm, display_labels=["Uninfected", "Infected"]).plot( + ax=ax, cmap="Blues" + ) + ax.set_title(f"Confusion Matrix - Epoch {self.current_epoch}") + + # Save and log figure + fig_path = f"/tmp/confusion_matrix_epoch_{self.current_epoch}.png" + fig.savefig(fig_path, dpi=100, bbox_inches="tight") + 
self.logger.experiment.log_artifact( + self.logger.run_id, fig_path, artifact_path="plots" + ) + plt.close(fig) + + # Print classification report + print( + "\n" + + classification_report( + y_val, y_pred, target_names=["Uninfected", "Infected"] + ) + ) + + # Clear storage + self.val_features.clear() + self.val_labels.clear() + + def configure_optimizers(self): + """No optimizer needed for sklearn.""" + return None diff --git a/examples/airtable/test_airtable_connection.py b/examples/airtable/test_airtable_connection.py new file mode 100644 index 000000000..f5797cc65 --- /dev/null +++ b/examples/airtable/test_airtable_connection.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +"""Test Airtable connection and setup.""" + +# %% +import os + +from pyairtable import Api + +print("=" * 70) +print("Testing Airtable Connection") +print("=" * 70) + +# Check environment variables +# TODO: Add these ENVIRONMENT VARIABLES TO BASHRC or export them in the node +api_key = os.getenv("AIRTABLE_API_KEY") +base_id = os.getenv("AIRTABLE_BASE_ID") + +print("\n1. Environment Variables:") +print( + f" AIRTABLE_API_KEY: {'✓ Set' if api_key else '✗ Not set'} ({api_key[:10] if api_key else 'N/A'}...)" +) +print( + f" AIRTABLE_BASE_ID: {'✓ Set' if base_id else '✗ Not set'} ({base_id if base_id else 'N/A'})" +) + +if not api_key or not base_id: + print("\n❌ ERROR: Environment variables not set!") + print("\nRun these commands in your shell:") + print(' export AIRTABLE_API_KEY="patXXXXXXXXXXXXXX"') + print(' export AIRTABLE_BASE_ID="appXXXXXXXXXXXXXX"') + print("\nOr add them to your ~/.bashrc") + exit(1) + +# Test API connection +print("\n2. Testing API Connection...") +try: + api = Api(api_key) + print(" ✓ API initialized") +except Exception as e: + print(f" ✗ Failed to initialize API: {e}") + exit(1) + +# Test Datasets table +print("\n3. Testing Datasets Table...") +try: + models_table = api.table(base_id, "Models") + records = models_table.all() + print(" ✓ Connected to Datasets table") + print(f" ✓ Found {len(records)} record(s)") + + if records: + print("\n Existing datasets:") + for record in records: + fields = record["fields"] + name = fields.get("name", "N/A") + version = fields.get("version", "N/A") + print(f" - {name} (v{version})") + else: + print(" (Table is empty - this is OK for first run)") + +except Exception as e: + print(f" ✗ Failed to access Datasets table: {e}") + print("\n Make sure you created a table named 'Datasets' (case-sensitive)") + exit(1) + +# Test Models table +print("\n4. Testing Models Table...") +try: + models_table = api.table(base_id, "Models") + records = models_table.all() + print(" ✓ Connected to Models table") + print(f" ✓ Found {len(records)} record(s)") + + if records: + print("\n Existing models:") + for record in records: + fields = record["fields"] + name = fields.get("model_name", "N/A") + acc = fields.get("test_accuracy", "N/A") + print(f" - {name} (accuracy: {acc})") + else: + print(" (Table is empty - this is OK for first run)") + +except Exception as e: + print(f" ✗ Failed to access Models table: {e}") + print("\n Make sure you created a table named 'Models' (case-sensitive)") + exit(1) + +# Test creating a dummy record +print("\n5. 
Testing Write Permissions...") +try: + test_record = models_table.create( + { + "model_name": "connection_test", + "model_family": "DynaCLR", + } + ) + print(f" ✓ Successfully created test record (ID: {test_record['id']})") + + # Clean up + models_table.delete(test_record["id"]) + print(" ✓ Successfully deleted test record") + +except Exception as e: + print(f" ✗ Failed to write to Datasets table: {e}") + print("\n Check your API token has 'data.records:write' scope") + exit(1) + +print("\n" + "=" * 70) +print("✅ SUCCESS: Airtable is configured correctly!") +# %% diff --git a/examples/airtable/test_pydantic_airtable.py b/examples/airtable/test_pydantic_airtable.py new file mode 100644 index 000000000..5f83f128b --- /dev/null +++ b/examples/airtable/test_pydantic_airtable.py @@ -0,0 +1,84 @@ +"""Test Pydantic dataset registration with Airtable.""" + +from viscy.airtable import AirtableManager, DatasetRecord + +BASE_ID = "app8vqaoWyOwa0sB5" + +print("=" * 70) +print("Testing Pydantic + Airtable Integration") +print("=" * 70) + +airtable_db = AirtableManager(base_id=BASE_ID) + +# %% +# Test 1: Single dataset registration +print("\n[Test 1] Register single dataset") +print("-" * 70) + +dataset = DatasetRecord( + dataset_name="pydantic_test_plate", + well_id="A_1", + fov_name="0", + data_path="/hpc/data/pydantic_test.zarr/A/1/0", + cell_type="A549", + organelle="SEC61B", +) + +try: + record_id = airtable_db.register_dataset(dataset) + print(f"✓ Registered: {record_id}") + print(f" Dataset: {dataset.dataset_name}") + print(f" Well: {dataset.well_id}, FOV: {dataset.fov_name}") + print(f" Path: {dataset.data_path}") +except Exception as e: + print(f"✗ Failed: {e}") + +# %% +# Test 2: Bulk dataset registration +print("\n[Test 2] Register multiple datasets") +print("-" * 70) + +datasets = [ + DatasetRecord( + dataset_name="pydantic_test_plate", + well_id=f"B_{well}", + fov_name=str(fov), + data_path=f"/hpc/data/pydantic_test.zarr/B/{well}/{fov}", + cell_type="A549", + ) + for well in range(1, 3) + for fov in range(2) +] + +try: + record_ids = airtable_db.register_datasets(datasets) + print(f"✓ Registered {len(record_ids)} datasets") + for ds, rec_id in zip(datasets, record_ids): + print(f" {ds.dataset_name}_{ds.well_id}_{ds.fov_name} -> {rec_id}") +except Exception as e: + print(f"✗ Failed: {e}") + +# %% +# Test 3: List datasets as Pydantic models +print("\n[Test 3] List datasets as Pydantic models") +print("-" * 70) + +try: + all_datasets = airtable_db.list_datasets(as_pydantic=True) + print(f"Total datasets in Airtable: {len(all_datasets)}") + + # Filter for our test datasets + test_datasets = [ds for ds in all_datasets if ds.fov_id.startswith("pydantic_test")] + print(f"\nTest datasets found: {len(test_datasets)}") + for ds in test_datasets: + print(f" - {ds.fov_id}") + print(f" Dataset: {ds.dataset_name}") + print(f" Cell type: {ds.cell_type}") + print(f" Record ID: {ds.record_id}") + +except Exception as e: + print(f"✗ Failed: {e}") + +print("\n" + "=" * 70) +print("Airtable integration test complete!") +print("=" * 70) diff --git a/examples/airtable/train_with_wandb.yml b/examples/airtable/train_with_wandb.yml new file mode 100644 index 000000000..b69730d3b --- /dev/null +++ b/examples/airtable/train_with_wandb.yml @@ -0,0 +1,75 @@ +# Example Lightning config using CollectionTripletDataModule with W&B tracking +# This config fetches dataset paths from Airtable collections and logs to W&B +# Usage: viscy fit -c examples/airtable/train_with_wandb.yml + +seed_everything: 42 + +trainer: + 
accelerator: gpu + devices: 1 + max_epochs: 100 + check_val_every_n_epoch: 1 + log_every_n_steps: 50 + default_root_dir: logs/wandb # W&B will create versioned subdirs + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: viscy-experiments + entity: null # Set to your W&B team name or leave null for personal + log_model: false # Don't upload checkpoints to W&B (too large) + save_dir: ./logs # Explicit save directory for config files + name: null # Auto-generate unique run name + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: loss/val + save_top_k: 3 + save_last: true + - class_path: viscy.airtable.callbacks.CollectionWandbCallback + # No init_args needed - automatically logs collection metadata + +model: + class_path: viscy.representation.engine.ContrastiveModule + init_args: + encoder: + class_path: viscy.representation.contrastive.ContrastiveEncoder + init_args: + backbone: convnext_tiny + in_channels: 1 + in_stack_depth: 5 + stem_kernel_size: [5, 4, 4] + stem_stride: [5, 4, 4] + embedding_dim: 768 + projection_dim: 32 + loss_function: + class_path: pytorch_metric_learning.losses.NTXentLoss + init_args: + temperature: 0.05 + lr: 0.0002 + log_batches_per_epoch: 3 + log_samples_per_batch: 3 + example_input_array_shape: [1, 1, 5, 160, 160] + + +data: + # NEW: Use CollectionTripletDataModule for Airtable collection integration + class_path: viscy.airtable.factory.CollectionTripletDataModule + init_args: + # Airtable collection parameters + base_id: "app8vqaoWyOwa0sB5" # Replace with your base ID + collection_name: "2024_11_07_A549_SEC61_DENV_wells_B1_B2" + collection_version: "0.0.1" + + # TripletDataModule parameters + source_channel: [Phase3D] + z_range: [10, 15] + z_window_size: 5 + initial_yx_patch_size: [160, 160] + final_yx_patch_size: [160, 160] + batch_size: 16 + num_workers: 1 + split_ratio: 0.8 + time_interval: any + + # Optional: Override collection FOVs with specific wells + fit_include_wells: ["B/1", "B/2"] diff --git a/pyproject.toml b/pyproject.toml index 4dd3c254e..2e43579e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,9 @@ optional-dependencies.metrics = [ "cellpose>=3.0.10", "imbalanced-learn>=0.12", "mahotas", + "mlflow", "ptflops>=0.7", + "pyairtable", "scikit-learn>=1.1.3", "torchmetrics[detection]>=1.6.3", "umap-learn", @@ -62,14 +64,14 @@ optional-dependencies.phate = [ optional-dependencies.visual = [ "cmap", "dash", + "gradio>=5.49.1", + "graphviz", "ipykernel", "nbformat", "plotly", "seaborn", "torchview", - "gradio>=5.49.1", - ] scripts.viscy = "viscy.cli:main" @@ -87,4 +89,4 @@ lint.per-file-ignores."applications/**/*.py" = [ "E402" ] lint.per-file-ignores."examples/**/*.ipynb" = [ "E402" ] lint.per-file-ignores."examples/**/*.py" = [ "E402", "F821" ] lint.per-file-ignores."tests/**/*.py" = [ "E402" ] -lint.isort.known-first-party = [ "viscy" ] \ No newline at end of file +lint.isort.known-first-party = [ "viscy" ] diff --git a/tests/airtable/__init__.py b/tests/airtable/__init__.py new file mode 100644 index 000000000..bb1c48247 --- /dev/null +++ b/tests/airtable/__init__.py @@ -0,0 +1 @@ +"""Tests for Airtable integration module.""" diff --git a/tests/airtable/test_database.py b/tests/airtable/test_database.py new file mode 100644 index 000000000..e277fbe0b --- /dev/null +++ b/tests/airtable/test_database.py @@ -0,0 +1,497 @@ +"""Tests for Airtable database module using mocks.""" + +from unittest.mock import Mock, patch + +import pandas as pd +import pytest + +from 
viscy.airtable.database import AirtableManager, CollectionDataset, Collections +from viscy.airtable.schemas import DatasetRecord + +# ============================================================================ +# Tests for CollectionDataset dataclass +# ============================================================================ + + +def test_collection_dataset_creation(): + """Test creating a CollectionDataset.""" + dataset = CollectionDataset( + data_path="/hpc/data/plate.zarr", + tracks_path="/hpc/tracks/plate.zarr", + fov_names=["B/3/0", "B/3/1", "B/4/0"], + ) + + assert dataset.data_path == "/hpc/data/plate.zarr" + assert dataset.tracks_path == "/hpc/tracks/plate.zarr" + assert len(dataset) == 3 + assert len(dataset.fov_names) == 3 + + +def test_collection_dataset_fov_paths(): + """Test generating FOV paths.""" + dataset = CollectionDataset( + data_path="/hpc/data/plate.zarr", + tracks_path="/hpc/tracks/plate.zarr", + fov_names=["B/3/0", "B/3/1"], + ) + + fov_paths = dataset.fov_paths + assert fov_paths == [ + "/hpc/data/plate.zarr/B/3/0", + "/hpc/data/plate.zarr/B/3/1", + ] + + +def test_collection_dataset_exists(tmp_path): + """Test checking if paths exist.""" + # Create actual directories + data_path = tmp_path / "data.zarr" + tracks_path = tmp_path / "tracks.zarr" + data_path.mkdir() + tracks_path.mkdir() + + dataset = CollectionDataset( + data_path=str(data_path), + tracks_path=str(tracks_path), + fov_names=["0", "1"], + ) + + assert dataset.exists() is True + + # Test with non-existent paths + dataset_bad = CollectionDataset( + data_path="/nonexistent/data.zarr", + tracks_path="/nonexistent/tracks.zarr", + fov_names=["0"], + ) + + assert dataset_bad.exists() is False + + +def test_collection_dataset_validate(tmp_path): + """Test validation raises error for missing paths.""" + data_path = tmp_path / "data.zarr" + data_path.mkdir() + + # Missing tracks path + dataset = CollectionDataset( + data_path=str(data_path), + tracks_path="/nonexistent/tracks.zarr", + fov_names=["0"], + ) + + with pytest.raises(FileNotFoundError, match="Tracks path not found"): + dataset.validate() + + +# ============================================================================ +# Tests for Collections dataclass +# ============================================================================ + + +def test_collections_creation(): + """Test creating a Collections object.""" + datasets = [ + CollectionDataset("/data1.zarr", "/tracks1.zarr", ["0", "1"]), + CollectionDataset("/data2.zarr", "/tracks2.zarr", ["0", "1", "2"]), + ] + + collection = Collections( + name="test_collection", + version="0.0.1", + datasets=datasets, + ) + + assert collection.name == "test_collection" + assert collection.version == "0.0.1" + assert len(collection) == 2 + assert collection.total_fovs == 5 # 2 + 3 + + +def test_collections_iteration(): + """Test iterating over datasets in a collection.""" + datasets = [ + CollectionDataset("/data1.zarr", "/tracks1.zarr", ["0"]), + CollectionDataset("/data2.zarr", "/tracks2.zarr", ["0", "1"]), + ] + + collection = Collections("test", "0.0.1", datasets) + + dataset_list = list(collection) + assert len(dataset_list) == 2 + assert dataset_list[0].data_path == "/data1.zarr" + assert dataset_list[1].data_path == "/data2.zarr" + + +def test_collections_validate(tmp_path): + """Test validation checks all dataset paths.""" + # Create valid paths for first dataset + data1 = tmp_path / "data1.zarr" + tracks1 = tmp_path / "tracks1.zarr" + data1.mkdir() + tracks1.mkdir() + + datasets = [ + 
CollectionDataset(str(data1), str(tracks1), ["0"]), + CollectionDataset("/bad/data.zarr", "/bad/tracks.zarr", ["0"]), # Invalid + ] + + collection = Collections("test", "0.0.1", datasets) + + with pytest.raises(FileNotFoundError, match="Data path not found"): + collection.validate() + + +# ============================================================================ +# Tests for AirtableManager +# ============================================================================ + + +@pytest.fixture +def mock_airtable_api(): + """Create a mock Airtable API.""" + api = Mock() + api.table.return_value = Mock() + return api + + +@pytest.fixture +def airtable_manager(mock_airtable_api): + """Create an AirtableManager with mocked API.""" + with patch("viscy.airtable.database.Api", return_value=mock_airtable_api): + with patch.dict("os.environ", {"AIRTABLE_API_KEY": "test_key"}): + manager = AirtableManager(base_id="test_base_id") + return manager + + +def test_airtable_manager_init_requires_api_key(): + """Test that AirtableManager requires an API key.""" + with patch.dict("os.environ", {}, clear=True): + # Remove AIRTABLE_API_KEY from env + with pytest.raises(ValueError, match="Airtable API key required"): + AirtableManager(base_id="test_base") + + +def test_airtable_manager_init_with_explicit_key(): + """Test initializing with explicit API key.""" + with patch("viscy.airtable.database.Api") as mock_api_class: + AirtableManager(base_id="test_base", api_key="explicit_key") + + # Verify Api was called with the explicit key + mock_api_class.assert_called_once_with("explicit_key") + + +def test_airtable_manager_init_with_env_key(): + """Test initializing with API key from environment.""" + with patch("viscy.airtable.database.Api") as mock_api_class: + with patch.dict("os.environ", {"AIRTABLE_API_KEY": "env_key"}): + AirtableManager(base_id="test_base") + + mock_api_class.assert_called_once_with("env_key") + + +def test_register_dataset(airtable_manager): + """Test registering a single dataset.""" + dataset = DatasetRecord( + dataset_name="test_plate", + well_id="B_3", + fov_name="0", + data_path="/hpc/data/test.zarr/B/3/0", + ) + + # Mock the create response + airtable_manager.datasets_table.create.return_value = {"id": "rec123"} + + record_id = airtable_manager.register_dataset(dataset) + + assert record_id == "rec123" + airtable_manager.datasets_table.create.assert_called_once() + + # Verify the data passed to Airtable + call_args = airtable_manager.datasets_table.create.call_args[0][0] + assert call_args["Dataset"] == "test_plate" + assert call_args["Well ID"] == "B_3" + assert call_args["FOV"] == "0" + + +def test_register_datasets_multiple(airtable_manager): + """Test registering multiple datasets.""" + datasets = [ + DatasetRecord( + dataset_name="plate", + well_id="B_3", + fov_name="0", + data_path="/hpc/data/plate.zarr/B/3/0", + ), + DatasetRecord( + dataset_name="plate", + well_id="B_3", + fov_name="1", + data_path="/hpc/data/plate.zarr/B/3/1", + ), + ] + + # Mock create responses + airtable_manager.datasets_table.create.side_effect = [ + {"id": "rec123"}, + {"id": "rec456"}, + ] + + record_ids = airtable_manager.register_datasets(datasets) + + assert record_ids == ["rec123", "rec456"] + assert airtable_manager.datasets_table.create.call_count == 2 + + +def test_create_collection_from_datasets(airtable_manager): + """Test creating a collection from dataset FOV IDs.""" + # Mock list_collections to return empty (no duplicates) + with patch.object( + airtable_manager, "list_collections", 
return_value=pd.DataFrame() + ): + # Mock the datasets_table.all() response for FOV lookup + def mock_all(formula=None): + if formula == "{FOV_ID}='plate_B_3_0'": + return [{"id": "rec1", "fields": {"FOV_ID": "plate_B_3_0"}}] + elif formula == "{FOV_ID}='plate_B_3_1'": + return [{"id": "rec2", "fields": {"FOV_ID": "plate_B_3_1"}}] + return [] + + airtable_manager.datasets_table.all.side_effect = mock_all + + # Mock the collections_table.create response + airtable_manager.collections_table.create.return_value = {"id": "col123"} + + collection_id = airtable_manager.create_collection_from_datasets( + collection_name="test_collection", + fov_ids=["plate_B_3_0", "plate_B_3_1"], + version="0.0.1", + purpose="training", + ) + + assert collection_id == "col123" + + # Verify the collection was created with correct data + airtable_manager.collections_table.create.assert_called_once() + call_args = airtable_manager.collections_table.create.call_args[0][0] + + assert call_args["name"] == "test_collection" + assert call_args["version"] == "0.0.1" + assert call_args["purpose"] == "training" + assert call_args["datasets"] == ["rec1", "rec2"] # Linked record IDs + + +def test_create_collection_validates_version_format(airtable_manager): + """Test that collection creation validates version format.""" + with pytest.raises(ValueError, match="semantic version format"): + airtable_manager.create_collection_from_datasets( + collection_name="test", + fov_ids=["fov1"], + version="invalid_version", # Should be like "0.0.1" + ) + + +def test_list_datasets_returns_pydantic_models(airtable_manager): + """Test listing datasets returns Pydantic models.""" + # Mock the all() response + airtable_manager.datasets_table.all.return_value = [ + { + "id": "rec123", + "fields": { + "Dataset": "test_plate", + "Well ID": "B_3", + "FOV": "0", + "Data path": "/hpc/data/test.zarr/B/3/0", + }, + } + ] + + datasets = airtable_manager.list_datasets(as_pydantic=True) + + assert len(datasets) == 1 + assert isinstance(datasets[0], DatasetRecord) + assert datasets[0].dataset_name == "test_plate" + assert datasets[0].well_id == "B_3" + + +def test_list_datasets_returns_dicts(airtable_manager): + """Test listing datasets can return raw dictionaries.""" + airtable_manager.datasets_table.all.return_value = [ + { + "id": "rec123", + "fields": {"Dataset": "test_plate"}, + } + ] + + datasets = airtable_manager.list_datasets(as_dataframe=False, as_pydantic=False) + + assert len(datasets) == 1 + assert isinstance(datasets[0], dict) + assert datasets[0]["Dataset"] == "test_plate" + + +def test_get_collection_data_paths(airtable_manager): + """Test getting data paths from a collection.""" + # Mock list_collections to return collection info + collections_df = pd.DataFrame( + [ + { + "id": "col123", + "name": "test_collection", + "version": "0.0.1", + "datasets": ["rec1", "rec2"], + } + ] + ) + + with patch.object( + airtable_manager, "list_collections", return_value=collections_df + ): + # Mock datasets_table.get() for individual record fetches + def mock_get(record_id): + if record_id == "rec1": + return { + "id": "rec1", + "fields": { + "Dataset": "plate", + "Data path": "/hpc/data/plate.zarr", + "Well ID": "B_3", + "FOV": "0", + }, + } + elif record_id == "rec2": + return { + "id": "rec2", + "fields": { + "Dataset": "plate", + "Data path": "/hpc/data/plate.zarr", + "Well ID": "B_3", + "FOV": "1", + }, + } + return None + + airtable_manager.datasets_table.get.side_effect = mock_get + + collection = airtable_manager.get_dataset_paths( + 
collection_name="test_collection", version="0.0.1" + ) + + assert isinstance(collection, Collections) + assert collection.name == "test_collection" + assert collection.version == "0.0.1" + assert collection.total_fovs == 2 + + +def test_log_model_training(airtable_manager): + """Test logging model training to Models table.""" + airtable_manager.models_table.create.return_value = {"id": "model123"} + + # Mock collections_table.get() for update operation + airtable_manager.collections_table.get.return_value = { + "id": "col123", + "fields": {"models_trained": ""}, + } + + model_id = airtable_manager.log_model_training( + collection_id="col123", + wandb_run_id="run456", + model_name="contrastive-v1", + checkpoint_path="/hpc/models/model.ckpt", + trained_by="test_user", + metrics={"val_loss": 0.15}, + ) + + assert model_id == "model123" + + # Verify the model record was created correctly + airtable_manager.models_table.create.assert_called_once() + call_args = airtable_manager.models_table.create.call_args[0][0] + + assert call_args["collection"] == ["col123"] + assert call_args["wandb_run_id"] == "run456" # Fixed: was mlflow_run_id + assert call_args["model_name"] == "contrastive-v1" + assert call_args["checkpoint_path"] == "/hpc/models/model.ckpt" + assert call_args["val_loss"] == 0.15 + + +def test_list_collections(airtable_manager): + """Test listing all collections.""" + airtable_manager.collections_table.all.return_value = [ + { + "id": "col1", + "fields": { + "name": "collection_1", + "version": "0.0.1", + "purpose": "training", + }, + }, + { + "id": "col2", + "fields": { + "name": "collection_2", + "version": "0.0.2", + "purpose": "validation", + }, + }, + ] + + # Test returning DataFrame (default) + collections_df = airtable_manager.list_collections() + + assert isinstance(collections_df, pd.DataFrame) + assert len(collections_df) == 2 + assert collections_df.iloc[0]["name"] == "collection_1" + assert collections_df.iloc[1]["purpose"] == "validation" + + # Test returning list of dicts + collections_list = airtable_manager.list_collections(as_dataframe=False) + + assert isinstance(collections_list, list) + assert len(collections_list) == 2 + assert collections_list[0]["name"] == "collection_1" + assert collections_list[1]["purpose"] == "validation" + + +def test_get_models_for_collection(airtable_manager): + """Test getting all models trained on a specific collection.""" + airtable_manager.models_table.all.return_value = [ + { + "id": "model1", + "fields": { + "model_name": "model-v1", + "collection": ["col123"], + "trained_date": "2026-01-12", + }, + }, + { + "id": "model2", + "fields": { + "model_name": "model-v2", + "collection": ["col123"], + "trained_date": "2026-01-13", + }, + }, + ] + + # Test returning DataFrame (default, sorted by trained_date descending) + models_df = airtable_manager.get_models_for_collection(collection_id="col123") + + assert isinstance(models_df, pd.DataFrame) + assert len(models_df) == 2 + # Latest model comes first due to descending sort + assert models_df.iloc[0]["model_name"] == "model-v2" + assert models_df.iloc[0]["trained_date"] == "2026-01-13" + assert models_df.iloc[1]["model_name"] == "model-v1" + assert models_df.iloc[1]["trained_date"] == "2026-01-12" + + # Test returning list of dicts + models_list = airtable_manager.get_models_for_collection( + collection_id="col123", as_dataframe=False + ) + + assert isinstance(models_list, list) + assert len(models_list) == 2 + # List doesn't get sorted automatically, so order is preserved from input + 
assert models_list[0]["model_name"] == "model-v1" + assert models_list[1]["model_name"] == "model-v2" diff --git a/viscy/airtable/__init__.py b/viscy/airtable/__init__.py new file mode 100644 index 000000000..04bc4cf96 --- /dev/null +++ b/viscy/airtable/__init__.py @@ -0,0 +1,28 @@ +"""Airtable integration for dataset management and tracking.""" + +from viscy.airtable.callbacks import CollectionWandbCallback +from viscy.airtable.database import AirtableManager, CollectionDataset, Collections +from viscy.airtable.factory import ( + CollectionTripletDataModule, + create_triplet_datamodule_from_collection, +) +from viscy.airtable.register_model import ( + list_registered_models, + load_model_from_registry, + register_model, +) +from viscy.airtable.schemas import DatasetRecord, ModelRecord + +__all__ = [ + "AirtableManager", + "DatasetRecord", + "ModelRecord", + "Collections", + "CollectionDataset", + "CollectionTripletDataModule", + "CollectionWandbCallback", + "create_triplet_datamodule_from_collection", + "register_model", + "load_model_from_registry", + "list_registered_models", +] diff --git a/viscy/airtable/callbacks.py b/viscy/airtable/callbacks.py new file mode 100644 index 000000000..c9afd0d82 --- /dev/null +++ b/viscy/airtable/callbacks.py @@ -0,0 +1,192 @@ +"""Lightning callback to log training results to Airtable.""" + +import getpass +from typing import Any + +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import Callback + +from viscy.airtable.database import AirtableManager + + +class AirtableLoggingCallback(Callback): + """ + Log model training to Airtable after training completes. + + This callback automatically records: + - Best model checkpoint path + - Who trained the model + - When it was trained + - Link to the collection used + + Parameters + ---------- + base_id : str + Airtable base ID + collection_id : str + Airtable collection record ID (from config) + model_name : str | None + Custom model name. If None, auto-generates from model class and timestamp. + log_metrics : bool + Whether to log metrics to Airtable (default: False). + If False, metrics should be viewed in TensorBoard. 
+ + Examples + -------- + Add to config YAML: + + >>> trainer: + >>> callbacks: + >>> - class_path: viscy.airtable.callbacks.AirtableLoggingCallback + >>> init_args: + >>> base_id: "appXXXXXXXXXXXXXX" + >>> collection_id: "recYYYYYYYYYYYYYY" + + Or add programmatically: + + >>> callback = AirtableLoggingCallback( + >>> base_id="appXXXXXXXXXXXXXX", + >>> collection_id="recYYYYYYYYYYYYYY" + >>> ) + >>> trainer = Trainer(callbacks=[callback]) + """ + + def __init__( + self, + base_id: str, + collection_id: str, + model_name: str | None = None, + log_metrics: bool = False, + ): + super().__init__() + self.airtable_db = AirtableManager(base_id=base_id) + self.collection_id = collection_id + self.model_name = model_name + self.log_metrics = log_metrics + + def on_fit_end(self, trainer: Trainer, pl_module: Any) -> None: + """Log model to Airtable after training completes.""" + # Get best checkpoint path + checkpoint_path = None + if trainer.checkpoint_callback: + checkpoint_path = trainer.checkpoint_callback.best_model_path + if not checkpoint_path: # Fallback to last checkpoint + checkpoint_path = trainer.checkpoint_callback.last_model_path + + # Generate model name + if self.model_name: + model_name = self.model_name + else: + model_class = pl_module.__class__.__name__ + logger_version = trainer.logger.version if trainer.logger else "unknown" + model_name = f"{model_class}_{logger_version}" + + # Optionally collect metrics + metrics = None + if self.log_metrics and trainer.callback_metrics: + metrics = {} + for key, value in trainer.callback_metrics.items(): + # Only log test metrics or validation metrics + if "test" in key or "val" in key: + try: + metrics[key] = float(value) + except (TypeError, ValueError): + pass # Skip non-numeric metrics + + # Get logger run ID (works with TensorBoard or MLflow) + run_id = None + if trainer.logger: + if hasattr(trainer.logger, "run_id"): + run_id = trainer.logger.run_id # MLflow + elif hasattr(trainer.logger, "version"): + run_id = str(trainer.logger.version) # TensorBoard + + # Log to Airtable + try: + model_id = self.airtable_db.log_model_training( + collection_id=self.collection_id, + mlflow_run_id=run_id or "unknown", + model_name=model_name, + checkpoint_path=str(checkpoint_path) if checkpoint_path else None, + trained_by=getpass.getuser(), + metrics=metrics, + ) + print(f"\n✓ Model logged to Airtable (record ID: {model_id})") + print(f" Model name: {model_name}") + print(f" Checkpoint: {checkpoint_path}") + print(f" Collections ID: {self.collection_id}") + except Exception as e: + print(f"\n✗ Failed to log to Airtable: {e}") + # Don't fail training if Airtable logging fails + + +class CollectionWandbCallback(Callback): + """ + Log collection metadata to Weights & Biases automatically. + + This callback extracts collection information from CollectionTripletDataModule + and logs it to W&B config for searchability and lineage tracking. 
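+
+    When the data module is a CollectionTripletDataModule, the run config gains
+    collection/name, collection/version, collection/base_id, collection/data_path
+    and collection/tracks_path; data module and model hyperparameters are also
+    mirrored under data/* and model/* keys.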
+ + Examples + -------- + Add to config YAML: + + >>> trainer: + >>> logger: + >>> class_path: lightning.pytorch.loggers.WandbLogger + >>> init_args: + >>> project: viscy-experiments + >>> log_model: false + >>> callbacks: + >>> - class_path: viscy.airtable.callbacks.CollectionWandbCallback + + Or add programmatically: + + >>> from lightning.pytorch.loggers import WandbLogger + >>> logger = WandbLogger(project="viscy-experiments") + >>> callback = CollectionWandbCallback() + >>> trainer = Trainer(logger=logger, callbacks=[callback]) + """ + + def on_train_start(self, trainer: Trainer, pl_module: Any) -> None: + """Log collection metadata to W&B config at training start.""" + # Import here to avoid requiring wandb as a dependency + try: + from lightning.pytorch.loggers import WandbLogger + except ImportError: + return # Skip if wandb not installed + + # Check if using WandbLogger + if not isinstance(trainer.logger, WandbLogger): + return + + # Check if using CollectionTripletDataModule + from viscy.airtable.factory import CollectionTripletDataModule + + dm = trainer.datamodule + + # Log collection metadata if using CollectionTripletDataModule + if isinstance(dm, CollectionTripletDataModule): + collection_config = { + "collection/name": dm.collection_name, + "collection/version": dm.collection_version, + "collection/base_id": dm.base_id, + "collection/data_path": str(dm.data_path), + "collection/tracks_path": str(dm.tracks_path), + } + trainer.logger.experiment.config.update(collection_config) + + print("\n✓ Collections metadata logged to W&B:") + print(f" Collections: {dm.collection_name} v{dm.collection_version}") + print(f" Data path: {dm.data_path}") + print(f" Tracks path: {dm.tracks_path}") + + # Also log data module hyperparameters explicitly + if dm is not None and hasattr(dm, "hparams"): + data_config = {f"data/{k}": v for k, v in dm.hparams.items()} + trainer.logger.experiment.config.update(data_config, allow_val_change=True) + + # Log model hyperparameters explicitly + if hasattr(pl_module, "hparams"): + model_config = {f"model/{k}": v for k, v in pl_module.hparams.items()} + trainer.logger.experiment.config.update(model_config, allow_val_change=True) diff --git a/viscy/airtable/database.py b/viscy/airtable/database.py new file mode 100644 index 000000000..105091e0d --- /dev/null +++ b/viscy/airtable/database.py @@ -0,0 +1,973 @@ +"""FOV-level dataset airtable_db with Airtable.""" + +import getpass +import os +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Any + +import pandas as pd +from natsort import natsorted +from pyairtable import Api + +from viscy.airtable.schemas import DatasetRecord + + +@dataclass +class CollectionDataset: + """ + Dataset paths for one HCS plate/zarr store. + + A collection may contain multiple stores, each returned as a separate CollectionDataset. 
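+
+    Examples
+    --------
+    Building one store's entry by hand (paths here are illustrative):
+
+    >>> ds = CollectionDataset(
+    ...     data_path="/hpc/data/plate.zarr",
+    ...     tracks_path="/hpc/tracks/plate.zarr",
+    ...     fov_names=["B/3/0", "B/3/1"],
+    ... )
+    >>> ds.fov_paths
+    ['/hpc/data/plate.zarr/B/3/0', '/hpc/data/plate.zarr/B/3/1']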
+ """ + + data_path: str + tracks_path: str + fov_names: list[str] + + def __len__(self) -> int: + return len(self.fov_names) + + @property + def fov_paths(self) -> list[str]: + """Full paths to each FOV: {data_path}/{fov_name}.""" + return [f"{self.data_path}/{fov}" for fov in self.fov_names] + + def exists(self) -> bool: + """Check if data_path and tracks_path exist.""" + return Path(self.data_path).exists() and Path(self.tracks_path).exists() + + def validate(self) -> None: + """Raise FileNotFoundError if paths don't exist.""" + if not Path(self.data_path).exists(): + raise FileNotFoundError(f"Data path not found: {self.data_path}") + if not Path(self.tracks_path).exists(): + raise FileNotFoundError(f"Tracks path not found: {self.tracks_path}") + + +@dataclass +class Collections: + """All datasets for a collection, potentially across multiple HCS plates.""" + + name: str + version: str + datasets: list[CollectionDataset] + + def __iter__(self): + """Iterate over datasets.""" + return iter(self.datasets) + + def __len__(self): + """Total number of HCS plates.""" + return len(self.datasets) + + @property + def total_fovs(self) -> int: + """Total FOVs across all plates.""" + return sum(len(ds) for ds in self.datasets) + + def validate(self) -> None: + """Validate all dataset paths exist.""" + for ds in self.datasets: + ds.validate() + + +# TODO: update the usage examples in the docstrings +# TODO: update the headers to match the Airtable columns and potentially move to separate file +# TODO: Convert these to Pydantic models, so we can easily dump and load from/to Airtable + +DATASETS_INDEX = [ + "Dataset", + "Well ID", + "FOV", + "Cell type", + "Cell state", + "Cell line", + "Organelle", + "Channel-0", + "Channel-1", + "Channel-2", + "Data path", + "Fluorescence modality", + "OrganelleBox Infectomics", + "FOV_ID", +] + +MODELS_INDEX = [ + "model_name", + "model_family", + "trained_date", + "trained_by", + "tensorboard_log", + "mlflow_run_id", +] + +MANIFESTS_INDEX = [ + "name", + "version", + "datasets", + "project", + "purpose", + "created_by", + "created_time", +] + + +class AirtableManager: + """ + Unified interface to Airtable for dataset, collection, and model management. + + Use this to: + - Register individual FOVs from HCS plates + - Create and manage dataset collections (collections of FOVs) + - Track model training on collections + - Query datasets, collections, and models + + Parameters + ---------- + base_id : str + Airtable base ID + api_key : str | None + Airtable API key. If None, reads from AIRTABLE_API_KEY env var. + + Examples + -------- + >>> airtable_db = AirtableManager(base_id="appXXXXXXXXXXXXXX") + >>> + >>> # Create collection from FOV selection + >>> collection_id = airtable_db.create_collection_from_datasets( + ... collection_name="RPE1_infection_v2", + ... fov_ids=["FOV_001", "FOV_002", "FOV_004"], + ... version="0.0.1", + ... purpose="training" + ... ) + >>> + >>> # Track model training + >>> airtable_db.log_model_training( + ... collection_id=collection_id, + ... mlflow_run_id="run_123", + ... model_name="my_model", + ... ) + >>> + >>> # Get all FOV paths for a collection + >>> fov_paths = airtable_db.get_collection_data_paths("RPE1_infection_v2") + >>> print(fov_paths) + >>> # ['/hpc/data/rpe1.zarr/B/3/0', '/hpc/data/rpe1.zarr/B/3/1', ...] 
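+    >>>
+    >>> # Resolve a versioned collection into CollectionDataset path objects
+    >>> collection = airtable_db.get_dataset_paths(
+    ...     collection_name="RPE1_infection_v2", version="0.0.1"
+    ... )
+    >>> for ds in collection:
+    ...     print(ds.data_path, len(ds))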
+ """ + + def __init__( + self, + base_id: str, + api_key: str | None = None, + ): + api_key = api_key or os.getenv("AIRTABLE_API_KEY") + if not api_key: + raise ValueError("Airtable API key required (set AIRTABLE_API_KEY)") + + self.api = Api(api_key) + self.base_id = base_id + self.datasets_table = self.api.table(base_id, "Datasets") + self.collections_table = self.api.table(base_id, "Collections") + self.models_table = self.api.table(base_id, "Models") + + def register_dataset(self, dataset: DatasetRecord) -> str: + """ + Register a single dataset record (FOV) in Airtable. + + Parameters + ---------- + dataset : DatasetRecord + Dataset record with FOV metadata + + Returns + ------- + str + Airtable record ID + + Examples + -------- + >>> from viscy.airtable.schemas import DatasetRecord + >>> dataset = DatasetRecord( + ... fov_id="plate_B_3_0", + ... dataset_name="plate", + ... well_id="B_3", + ... fov_name="0", + ... data_path="/hpc/data/plate.zarr/B/3/0" + ... ) + >>> record_id = airtable_db.register_dataset(dataset) + """ + record_dict = dataset.to_airtable_dict() + created = self.datasets_table.create(record_dict) + return created["id"] + + def register_datasets(self, datasets: list[DatasetRecord]) -> list[str]: + """ + Register multiple dataset records (FOVs) in Airtable. + + Parameters + ---------- + datasets : list[DatasetRecord] + List of dataset records to register + + Returns + ------- + list[str] + List of Airtable record IDs + + Examples + -------- + >>> from viscy.airtable.schemas import DatasetRecord + >>> datasets = [ + ... DatasetRecord( + ... fov_id="plate_B_3_0", + ... dataset_name="plate", + ... well_id="B_3", + ... fov_name="0", + ... data_path="/hpc/data/plate.zarr/B/3/0" + ... ), + ... DatasetRecord( + ... fov_id="plate_B_3_1", + ... dataset_name="plate", + ... well_id="B_3", + ... fov_name="1", + ... data_path="/hpc/data/plate.zarr/B/3/1" + ... ), + ... ] + >>> record_ids = airtable_db.register_datasets(datasets) + """ + record_ids = [] + for dataset in datasets: + record_dict = dataset.to_airtable_dict() + created = self.datasets_table.create(record_dict) + record_ids.append(created["id"]) + return record_ids + + def create_collection_from_datasets( + self, + collection_name: str, + fov_ids: list[str], + version: str, + purpose: str = "training", + project_name: str | None = None, + description: str | None = None, + ) -> str: + """ + Create a collection (collection) from a list of FOV IDs. + + Parameters + ---------- + collection_name : str + Name for this collection + + fov_ids : list[str] + List of FOV_ID values from Datasets table (e.g., ["plate1_B_3_0", "plate1_B_3_1"]) + version : str + Semantic version (e.g., "0.0.1", "0.1.0", "1.0.0") + purpose : str + Purpose of this collection ("training", "validation", "test") + project_name : str | None + Project Name (e.g OrganelleBox, DynaCLR, etc.) + description : str | None + Human-readable description + + Returns + ------- + str + Airtable collection record ID + + Examples + -------- + >>> airtable_db.create_collection_from_datasets( + ... collection_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2", + ... project_name="OrganelleBox", + ... fov_ids=["2024_11_07_A549_SEC61_DENV_B1_0", "2024_11_07_A549_SEC61_DENV_B1_1"], + ... version="0.0.1", + ... purpose="training", + ... project_name="OrganelleBox", + ... description="High-quality dataset records from wells B3-B4" + ... 
) + """ + # Validate semantic version format + import re + + if not re.match(r"^\d+\.\d+\.\d+$", version): + raise ValueError( + f"Version must be semantic version format (e.g., '0.0.1', '1.0.0'), got: '{version}'" + ) + + # Check if collection with same name + version exists (use DataFrame) + df_collections = self.list_collections() + + # Only check for duplicates if table is not empty and has required columns + if ( + len(df_collections) > 0 + and "name" in df_collections.columns + and "version" in df_collections.columns + ): + existing = df_collections[ + (df_collections["name"] == collection_name) + & (df_collections["version"] == version) + ] + + if len(existing) > 0: + raise ValueError( + f"Collections '{collection_name}' version '{version}' already exists. " + f"To create a new version, increment the version number (e.g., '0.0.2')." + ) + + existing_versions = df_collections[ + df_collections["name"] == collection_name + ] + if len(existing_versions) > 0: + versions = sorted(existing_versions["version"].tolist()) + print( + f"ℹ Collections '{collection_name}' existing versions: {versions}" + ) + print(f" Creating new version: '{version}'") + + # Get Airtable record IDs for these FOV IDs (ensure unique) + dataset_record_ids = [] + seen_fov_ids = set() + + for fov_id in fov_ids: + if fov_id in seen_fov_ids: + continue # Skip duplicates + + formula = f"{{FOV_ID}}='{fov_id}'" + records = self.datasets_table.all(formula=formula) + if records: + dataset_record_ids.append(records[0]["id"]) + seen_fov_ids.add(fov_id) + else: + raise ValueError(f"FOV ID '{fov_id}' not found in Datasets table") + + # Remove any duplicate record IDs (shouldn't happen, but extra safety) + dataset_record_ids = list(dict.fromkeys(dataset_record_ids)) + + # Create collection record + collection_record = { + "name": collection_name, + "datasets": dataset_record_ids, # Linked records (unique) + "version": version, # Semantic version (required) + "purpose": purpose, + "created_by": getpass.getuser(), + } + if project_name: + collection_record["project"] = project_name + if description: + collection_record["description"] = description + + created = self.collections_table.create(collection_record) + return created["id"] + + def create_collection_from_query( + self, + collection_name: str, + version: str, + source_dataset: str | None = None, + well_ids: list[str] | None = None, + exclude_fov_ids: list[str] | None = None, + **kwargs, + ) -> str: + """ + Create a collection by filtering dataset records with pandas. + + Parameters + ---------- + collection_name : str + Name for this collection + version : str + Semantic version (e.g., "0.0.1") - REQUIRED + source_dataset : str | None + Filter by source dataset name (from 'Dataset' field) + well_ids : list[str] | None + Filter by well identifiers (e.g., ["B_3", "B_4"]) + exclude_fov_ids : list[str] | None + FOV_ID values to exclude + **kwargs + Additional arguments for create_collection_from_datasets + + Returns + ------- + str + Airtable collection record ID + + Examples + -------- + >>> # Create collection from specific wells in a dataset + >>> airtable_db.create_collection_from_query( + ... collection_name="RPE1_infection_training", + ... version="0.0.1", + ... source_dataset="RPE1_plate1", + ... well_ids=["B_3", "B_4"], + ... exclude_fov_ids=["RPE1_plate1_B_3_2"] + ... 
) + """ + # Get all dataset records as DataFrame + df = self.list_datasets() + + # Apply filters with pandas + if source_dataset: + df = df[df["Dataset"] == source_dataset] + + if well_ids: + df = df[df["Well ID"].isin(well_ids)] + + # Exclude specified FOVs + if exclude_fov_ids: + df = df[~df["FOV_ID"].isin(exclude_fov_ids)] + + fov_ids = df["FOV_ID"].tolist() + + print(f"Found {len(fov_ids)} dataset records matching criteria") + + # Create collection + return self.create_collection_from_datasets( + collection_name=collection_name, version=version, fov_ids=fov_ids, **kwargs + ) + + def get_collection_data_paths( + self, collection_name: str, version: str | None = None + ) -> list[str]: + """ + Get list of data paths for a collection. + + Parameters + ---------- + collection_name : str + Collections name + version : str | None + Specific version (if None, returns latest) + + Returns + ------- + list[str] + List of data paths + + Examples + -------- + >>> paths = airtable_db.get_collection_data_paths("RPE1_infection_v2") + >>> print(paths) + >>> # ['/hpc/data/rpe1.zarr/B/3/0', '/hpc/data/rpe1.zarr/B/3/1', ...] + """ + # Get all collections as DataFrame + df_collections = self.list_collections() + + if len(df_collections) == 0 or "name" not in df_collections.columns: + raise ValueError( + f"Collections '{collection_name}' not found (table is empty)" + ) + + # Filter by name + filtered = df_collections[df_collections["name"] == collection_name] + + if len(filtered) == 0: + raise ValueError(f"Collections '{collection_name}' not found") + + # Filter by version if specified, otherwise get latest + if version: + if "version" not in df_collections.columns: + raise ValueError("Version field not found in Collections table") + filtered = filtered[filtered["version"] == version] + if len(filtered) == 0: + raise ValueError( + f"Collections '{collection_name}' version '{version}' not found" + ) + else: + # Get latest version (sort by created_time if column exists) + if "created_time" in filtered.columns: + filtered = filtered.sort_values("created_time", ascending=False) + + # Get the first (or only) matching collection + collection_row = filtered.iloc[0] + + # Get linked dataset record IDs + dataset_record_ids = collection_row.get("datasets", []) + if not dataset_record_ids or len(dataset_record_ids) == 0: + return [] + + # Fetch data paths + data_paths = [] + for dataset_id in dataset_record_ids: + dataset_record = self.datasets_table.get(dataset_id) + data_paths.append(dataset_record["fields"]["Data path"]) + + return data_paths + + def get_collection( + self, collection_name: str, version: str | None = None + ) -> dict[str, Any]: + """ + Get full collection information including data paths. 
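+
+        The returned dictionary mirrors the matching row of the Collections
+        table (name, version, purpose, linked dataset record IDs, etc.), with
+        an added "data_paths" key resolved from the linked Datasets records.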
+ + Parameters + ---------- + collection_name : str + Collections name + version : str | None + Specific version + + Returns + ------- + dict + Collections info with data paths and metadata + """ + # Get all collections as DataFrame + df_collections = self.list_collections() + + if len(df_collections) == 0 or "name" not in df_collections.columns: + raise ValueError( + f"Collections '{collection_name}' not found (table is empty)" + ) + + # Filter by name + filtered = df_collections[df_collections["name"] == collection_name] + + if len(filtered) == 0: + raise ValueError(f"Collections '{collection_name}' not found") + + # Filter by version if specified, otherwise get latest + if version: + if "version" not in df_collections.columns: + raise ValueError("Version field not found in Collections table") + filtered = filtered[filtered["version"] == version] + if len(filtered) == 0: + raise ValueError( + f"Collections '{collection_name}' version '{version}' not found" + ) + else: + # Get latest version (sort by created_time if column exists) + if "created_time" in filtered.columns: + filtered = filtered.sort_values("created_time", ascending=False) + + # Get the first (or only) matching collection + collection_row = filtered.iloc[0] + collection = collection_row.to_dict() + + # Add data paths + collection["data_paths"] = self.get_collection_data_paths( + collection_name, version + ) + + return collection + + def list_collections( + self, purpose: str | None = None, as_dataframe: bool = True + ) -> pd.DataFrame | list[dict]: + """ + List all collections. + + Parameters + ---------- + purpose : str | None + Filter by purpose ("training", "validation", "test") + as_dataframe : bool + If True, return pandas DataFrame. If False, return list of dicts. + + Returns + ------- + pd.DataFrame | list[dict] + Collections records as DataFrame or list of dicts + + Examples + -------- + >>> airtable_db.list_collections(purpose="training") + >>> # Returns DataFrame with columns: id, name, version, purpose, ... + """ + # Fetch all collections (try sorting, but don't fail if field doesn't exist) + try: + records = self.collections_table.all(sort=["-created_time"]) + except Exception: + # If sort fails (field might not exist), fetch without sorting + records = self.collections_table.all() + + data = [{"id": r["id"], **r["fields"]} for r in records] + + # Convert to DataFrame or list + if as_dataframe: + df = pd.DataFrame(data) + # Sort by created_time if column exists + if len(df) > 0 and "created_time" in df.columns: + df = df.sort_values("created_time", ascending=False) + # Filter by purpose if specified + if purpose and len(df) > 0 and "purpose" in df.columns: + df = df[df["purpose"] == purpose] + return df + else: + # Filter list if purpose specified + if purpose: + data = [d for d in data if d.get("purpose") == purpose] + return data + + def list_datasets( + self, + as_dataframe: bool = True, + as_pydantic: bool = False, + skip_invalid: bool = True, + ) -> pd.DataFrame | list[dict] | list[DatasetRecord]: + """ + Get all dataset records (FOVs) as a DataFrame, list of dicts, or Pydantic models. + + Use pandas for filtering - much simpler and more powerful than + building Airtable formulas. + + Parameters + ---------- + as_dataframe : bool + If True, return pandas DataFrame. Ignored if as_pydantic is True. + as_pydantic : bool + If True, return list of DatasetRecord objects. Takes precedence over as_dataframe. 
+ skip_invalid : bool + If True and as_pydantic=True, skip records that fail validation instead of raising error. + Default is True to handle legacy/incomplete records gracefully. + + Returns + ------- + pd.DataFrame | list[dict] | list[DatasetRecord] + All dataset records + + Examples + -------- + >>> # Get all datasets as DataFrame + >>> df = airtable_db.list_datasets() + >>> + >>> # Filter with pandas (simple and powerful!) + >>> filtered = df[df['Dataset'] == 'RPE1_plate1'] + >>> filtered = df[df['Well ID'].isin(['B_3', 'B_4'])] + >>> filtered = df[~df['FOV_ID'].isin(['RPE1_plate1_B_3_2'])] + >>> + >>> # Get as Pydantic models for type safety + >>> datasets = airtable_db.list_datasets(as_pydantic=True) + >>> for dataset in datasets: + ... print(dataset.fov_id, dataset.data_path) + >>> + >>> # Group and analyze + >>> df.groupby('Dataset').size() + >>> df.groupby('Well ID').size() + """ + records = self.datasets_table.all() + + if as_pydantic: + parsed_records = [] + for r in records: + try: + parsed_records.append(DatasetRecord.from_airtable_record(r)) + except Exception: + if skip_invalid: + # Skip invalid records silently + continue + else: + raise + return parsed_records + + data = [{"id": r["id"], **r["fields"]} for r in records] + + if as_dataframe: + return pd.DataFrame(data) + return data + + def delete_collection(self, collection_id: str) -> bool: + """ + Delete a collection record from Airtable. + + Parameters + ---------- + collection_id : str + Airtable record ID of the collection to delete + + Returns + ------- + bool + True if deletion was successful + + Examples + -------- + >>> collection_id = airtable_db.create_collection_from_datasets(...) + >>> airtable_db.delete_collection(collection_id) + >>> print(f"Deleted collection: {collection_id}") + """ + self.collections_table.delete(collection_id) + return True + + def log_model_training( + self, + collection_id: str, + wandb_run_id: str, + model_name: str | None = None, + metrics: dict[str, float] | None = None, + checkpoint_path: str | None = None, + trained_by: str | None = None, + ) -> str: + """ + Log that a model was trained using a collection. + + Creates entry in Models table and updates Collections table. + + Parameters + ---------- + collection_id : str + Airtable record ID of collection used + wandb_run_id : str + W&B run ID for experiment tracking + model_name : str | None + Human-readable model name + metrics : dict | None + Training metrics (e.g., {"val_loss": 0.15, "dice": 0.92}) + checkpoint_path : str | None + Path to saved model checkpoint + trained_by : str | None + Username of person who trained the model + + Returns + ------- + str + Airtable record ID of created model entry + + Examples + -------- + >>> collection_id = airtable_db.create_collection_from_datasets(...) + >>> model_id = airtable_db.log_model_training( + ... collection_id=collection_id, + ... wandb_run_id="20260107-152420", + ... model_name="contrastive-a549:v1", + ... metrics={"val_loss": 0.15}, + ... trained_by="eduardo.hirata" + ... 
) + """ + # Create model record + model_record = { + "model_name": model_name or f"model_{datetime.now():%Y%m%d_%H%M%S}", + "collection": [collection_id], # Link to collection + "wandb_run_id": wandb_run_id, + "trained_date": datetime.now().isoformat(), + } + + if metrics: + model_record.update(metrics) + + if checkpoint_path: + model_record["checkpoint_path"] = checkpoint_path + + if trained_by: + model_record["trained_by"] = trained_by + + created = self.models_table.create(model_record) + + # Update collection record to track usage + collection = self.collections_table.get(collection_id) + models_trained_str = collection["fields"].get("models_trained", "") + + # Handle models_trained as comma-separated string + if models_trained_str: + models_list = [m.strip() for m in models_trained_str.split(",")] + models_list.append(wandb_run_id) + new_models_str = ", ".join(models_list) + else: + new_models_str = wandb_run_id + + self.collections_table.update( + collection_id, + {"models_trained": new_models_str, "last_used": datetime.now().isoformat()}, + ) + + return created["id"] + + def get_models_for_collection( + self, collection_id: str, as_dataframe: bool = True + ) -> pd.DataFrame | list[dict]: + """ + Get all models trained on a specific collection. + + Parameters + ---------- + collection_id : str + Airtable record ID of collection + as_dataframe : bool + If True, return pandas DataFrame. If False, return list of dicts. + + Returns + ------- + pd.DataFrame | list[dict] + Model records as DataFrame or list of dicts + + Examples + -------- + >>> models_df = airtable_db.get_models_for_collection(collection_id) + >>> print(models_df[["model_name", "mlflow_run_id", "trained_date"]]) + """ + # Get all models as DataFrame + records = self.models_table.all() + data = [{"id": r["id"], **r["fields"]} for r in records] + + if as_dataframe: + df = pd.DataFrame(data) + if len(df) == 0: + return df + + # Filter by collection_id using pandas + # The 'collection' field contains a list of linked record IDs + df_filtered = df[ + df["collection"].apply( + lambda x: collection_id in x if isinstance(x, list) else False + ) + ] + + # Sort by trained_date if column exists + if "trained_date" in df_filtered.columns: + df_filtered = df_filtered.sort_values("trained_date", ascending=False) + + return df_filtered + else: + # Filter list + filtered = [d for d in data if collection_id in d.get("collection", [])] + return filtered + + def list_models(self, as_dataframe: bool = True) -> pd.DataFrame | list[dict]: + """ + List all models in the airtable_db. + + Parameters + ---------- + as_dataframe : bool + If True, return pandas DataFrame. If False, return list of dicts. + + Returns + ------- + pd.DataFrame | list[dict] + All model records + + Examples + -------- + >>> models_df = airtable_db.list_models() + >>> print(models_df.groupby("model_name").size()) + """ + records = self.models_table.all() + data = [{"id": r["id"], **r["fields"]} for r in records] + + if as_dataframe: + df = pd.DataFrame(data) + # Sort by trained_date if column exists + if len(df) > 0 and "trained_date" in df.columns: + df = df.sort_values("trained_date", ascending=False) + return df + return data + + def get_dataset_paths( + self, + collection_name: str, + version: str, + ) -> Collections: + """ + Get zarr store paths and FOV names for a collection. 
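+
+        FOVs are grouped by their parent zarr store, so a collection spanning
+        multiple HCS plates yields one CollectionDataset per plate. The tracks
+        path for each store is derived from its data path by the directory
+        convention in _derive_tracks_path.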
+ + Parameters + ---------- + collection_name : str + Name of the collection + version : str + Semantic version of the collection + + Returns + ------- + Collections + Collections object containing list of CollectionDataset (one per HCS plate) + + Examples + -------- + >>> collection = airtable_db.get_dataset_paths("my_collection", "0.0.1") + >>> print(f"{collection.name} v{collection.version}: {collection.total_fovs} FOVs") + + >>> # Use with TripletDataModule + >>> for ds in collection: + ... data_module = TripletDataModule( + ... data_path=ds.data_path, + ... tracks_path=ds.tracks_path, + ... include_fov_names=ds.fov_names, + ... ) + """ + # Get collection record IDs + dataset_record_ids = self._get_collection_dataset_ids(collection_name, version) + if not dataset_record_ids: + return Collections(name=collection_name, version=version, datasets=[]) + + dataset_records = [ + self.datasets_table.get(dataset_id)["fields"] + for dataset_id in dataset_record_ids + ] + + stores: dict[str, list[str]] = {} + for fields in dataset_records: + data_path = fields["Data path"] + fov_name = f"{fields['Well ID']}/{fields['FOV']}" + + if data_path not in stores: + stores[data_path] = [] + stores[data_path].append(fov_name) + + datasets = [ + CollectionDataset( + data_path=data_path, + tracks_path=self._derive_tracks_path(data_path), + fov_names=natsorted(fov_names), + ) + for data_path, fov_names in stores.items() + ] + + return Collections(name=collection_name, version=version, datasets=datasets) + + @staticmethod + def _derive_tracks_path(data_path: str) -> str: + """ + Derive tracks path from data path. + + Pattern: + - Data: {base}/2-assemble/{name}.zarr + - Tracks: {base}/1-preprocess/label-free/3-track/{name}_cropped.zarr + """ + # Replace 2-assemble with 1-preprocess/label-free/3-track + tracks_path = data_path.replace( + "/2-assemble/", "/1-preprocess/label-free/3-track/" + ) + # Replace .zarr with _cropped.zarr + if tracks_path.endswith(".zarr"): + tracks_path = tracks_path[:-5] + "_cropped.zarr" + return tracks_path + + def _get_collection_dataset_ids( + self, collection_name: str, version: str + ) -> list[str]: + """Get linked dataset record IDs for a collection.""" + df_collections = self.list_collections() + + if len(df_collections) == 0 or "name" not in df_collections.columns: + raise ValueError( + f"Collections '{collection_name}' not found (table is empty)" + ) + + filtered = df_collections[df_collections["name"] == collection_name] + if len(filtered) == 0: + raise ValueError(f"Collections '{collection_name}' not found") + + if "version" not in df_collections.columns: + raise ValueError("Version field not found in Collections table") + + filtered = filtered[filtered["version"] == version] + if len(filtered) == 0: + raise ValueError( + f"Collections '{collection_name}' version '{version}' not found" + ) + + collection_row = filtered.iloc[0] + dataset_record_ids = collection_row.get("datasets", []) + + if not dataset_record_ids or len(dataset_record_ids) == 0: + return [] + + return dataset_record_ids + + def update_record( + self, + ): + # TODO: to update the tracks path column + raise NotImplementedError("Not implemented yet") diff --git a/viscy/airtable/factory.py b/viscy/airtable/factory.py new file mode 100644 index 000000000..db9bf6a07 --- /dev/null +++ b/viscy/airtable/factory.py @@ -0,0 +1,464 @@ +"""Factory functions for creating data modules from Airtable collections.""" + +import os +from typing import Literal, Sequence + +from lightning.pytorch import LightningDataModule +from 
monai.transforms import MapTransform + +from viscy.airtable.database import AirtableManager, CollectionDataset, Collections +from viscy.data.combined import BatchedConcatDataModule, CachedConcatDataModule +from viscy.data.triplet import TripletDataModule + + +def _extract_wells_from_fov_names(fov_names: list[str]) -> list[str]: + """ + Extract unique well IDs from FOV names. + + Parameters + ---------- + fov_names : list[str] + List of FOV names in format "Row/Column/FOV_idx" (e.g., "B/3/0") + + Returns + ------- + list[str] + Unique well IDs in format "Row/Column" (e.g., ["B/3", "C/4"]) + + Examples + -------- + >>> _extract_wells_from_fov_names(["B/3/0", "B/3/1", "C/4/2"]) + ['B/3', 'C/4'] + """ + wells = set() + for fov_name in fov_names: + # Split "B/3/0" -> ["B", "3", "0"] + parts = fov_name.split("/") + if len(parts) >= 2: + # Extract "B/3" + well_id = f"{parts[0]}/{parts[1]}" + wells.add(well_id) + else: + raise ValueError( + f"Invalid FOV name format: '{fov_name}'. " + f"Expected 'Row/Column/FOV_idx' (e.g., 'B/3/0')" + ) + + return sorted(list(wells)) + + +def create_triplet_datamodule_from_collection( + collection: Collections | CollectionDataset, + source_channel: str | Sequence[str], + z_range: tuple[int, int], + *, + initial_yx_patch_size: tuple[int, int] = (512, 512), + final_yx_patch_size: tuple[int, int] = (224, 224), + split_ratio: float = 0.8, + batch_size: int = 16, + num_workers: int = 1, + normalizations: list[MapTransform] | None = None, + augmentations: list[MapTransform] | None = None, + augment_validation: bool = True, + caching: bool = False, + fit_include_wells: list[str] | None = None, + fit_exclude_fovs: list[str] | None = None, + predict_cells: bool = False, + include_fov_names: list[str] | None = None, + include_track_ids: list[int] | None = None, + time_interval: Literal["any"] | int = "any", + return_negative: bool = True, + persistent_workers: bool = False, + prefetch_factor: int | None = None, + pin_memory: bool = False, + z_window_size: int | None = None, + cache_pool_bytes: int = 0, + use_cached_concat: bool = False, +) -> LightningDataModule: + """ + Create TripletDataModule(s) from Airtable collection. + + Automatically handles single or multiple HCS plates: + - Single plate: Returns TripletDataModule + - Multiple plates: Returns BatchedConcatDataModule or CachedConcatDataModule + + Parameters + ---------- + collection : Collections | CollectionDataset + Collections from AirtableManager.get_dataset_paths() or single CollectionDataset + source_channel : str | Sequence[str] + Input channel name(s) - REQUIRED + z_range : tuple[int, int] + Range of valid z-slices - REQUIRED + initial_yx_patch_size : tuple[int, int] + YX size of initially sampled patch, default (512, 512) + final_yx_patch_size : tuple[int, int] + Output patch size after augmentation, default (224, 224) + split_ratio : float + Train/val split ratio, default 0.8 + batch_size : int + Batch size, default 16 + num_workers : int + Number of dataloader workers, default 1 + normalizations : list[MapTransform] | None + Normalization transforms + augmentations : list[MapTransform] | None + Augmentation transforms + augment_validation : bool + Apply augmentations to validation, default True + caching : bool + Cache dataset, default False + fit_include_wells : list[str] | None + Override collection FOVs with specific wells (e.g., ["B/3", "C/4"]). 
+ Takes precedence over collection.fov_names + fit_exclude_fovs : list[str] | None + Exclude specific FOV paths from collection + predict_cells : bool + Predict on specific cells only, default False + include_fov_names : list[str] | None + FOV names for prediction (when predict_cells=True) + include_track_ids : list[int] | None + Track IDs for prediction (when predict_cells=True) + time_interval : Literal["any"] | int + Time interval for positive sampling, default "any" + return_negative : bool + Return negative samples for triplet loss, default True + persistent_workers : bool + Keep workers alive between epochs, default False + prefetch_factor : int | None + Batches to prefetch per worker + pin_memory : bool + Pin memory for faster GPU transfer, default False + z_window_size : int | None + Final Z window size (inferred from z_range if None) + cache_pool_bytes : int + Tensorstore cache pool size in bytes, default 0 + use_cached_concat : bool + Use CachedConcatDataModule instead of BatchedConcatDataModule + for multi-plate collections, default False + + Returns + ------- + LightningDataModule + - TripletDataModule for single plate + - BatchedConcatDataModule for multiple plates (default) + - CachedConcatDataModule for multiple plates (if use_cached_concat=True) + + Raises + ------ + ValueError + - If collection has no datasets + - If paths don't exist (validation fails) + - If fit_include_wells and collection both specify FOVs (ambiguous) + FileNotFoundError + If data_path or tracks_path don't exist + TypeError + If collection is not Collections or CollectionDataset + + Examples + -------- + Basic usage with single-plate collection: + + >>> from viscy.airtable.database import AirtableManager + >>> from viscy.airtable.factory import create_triplet_datamodule_from_collection + >>> + >>> airtable_db = AirtableManager(base_id="appXXXXXXXXXXXXXX") + >>> collection = airtable_db.get_dataset_paths("my_collection", "0.0.1") + >>> + >>> dm = create_triplet_datamodule_from_collection( + ... collection=collection, + ... source_channel=["Phase3D"], + ... z_range=(0, 5), + ... batch_size=32, + ... num_workers=8, + ... ) + >>> + >>> # Use with PyTorch Lightning + >>> trainer.fit(model, dm) + + Multi-plate collection with normalization: + + >>> from viscy.transforms import NormalizeSampled + >>> + >>> collection = airtable_db.get_dataset_paths("multi_plate_collection", "1.0.0") + >>> print(f"Collections has {len(collection)} plates") # e.g., 3 plates + >>> + >>> dm = create_triplet_datamodule_from_collection( + ... collection=collection, + ... source_channel=["Phase3D", "RFP"], + ... z_range=(0, 10), + ... normalizations=[ + ... NormalizeSampled( + ... ["Phase3D"], + ... level="fov_statistics", + ... subtrahend="mean", + ... divisor="std" + ... ) + ... ], + ... batch_size=16, + ... use_cached_concat=False, # Use BatchedConcatDataModule + ... ) + >>> # Returns BatchedConcatDataModule wrapping 3 TripletDataModules + + Override collection FOVs with specific wells: + + >>> dm = create_triplet_datamodule_from_collection( + ... collection=collection, + ... source_channel=["Phase3D"], + ... z_range=(0, 5), + ... fit_include_wells=["B/3", "B/4"], # Override collection FOVs + ... batch_size=16, + ... ) + + Using a single CollectionDataset directly: + + >>> ds = collection.datasets[0] # Single plate + >>> dm = create_triplet_datamodule_from_collection( + ... collection=ds, # Pass CollectionDataset directly + ... source_channel=["Phase3D"], + ... z_range=(0, 5), + ... batch_size=16, + ... 
) + + Notes + ----- + - FOV filtering priority: fit_include_wells > collection.fov_names + - If both specified, raises ValueError to avoid ambiguity + - The factory validates paths before creating data modules + - Multi-plate handling uses BatchedConcatDataModule for efficient batching + - All TripletDataModule parameters can be passed through kwargs + """ + # STEP 1: Normalize input - handle both Collections and CollectionDataset + if isinstance(collection, CollectionDataset): + datasets = [collection] + collection_name = "single_dataset" + elif isinstance(collection, Collections): + if len(collection.datasets) == 0: + raise ValueError(f"Collections '{collection.name}' has no datasets") + datasets = collection.datasets + collection_name = collection.name + else: + raise TypeError( + f"Expected Collections or CollectionDataset, got {type(collection).__name__}" + ) + + # STEP 2: Validate all paths exist (fail early) + for i, ds in enumerate(datasets): + try: + ds.validate() + except FileNotFoundError as e: + raise FileNotFoundError(f"Collections '{collection_name}' dataset {i}: {e}") + + # STEP 3: Handle FOV filtering logic + # Check for ambiguous FOV specification + has_collection_fovs = any(len(ds.fov_names) > 0 for ds in datasets) + + if fit_include_wells is not None and has_collection_fovs: + # Ambiguous: both collection and user specified FOVs + raise ValueError( + "Cannot specify both 'fit_include_wells' and use collection FOV filtering. " + "The collection already specifies FOVs to include. " + "Either:\n" + " 1. Use fit_include_wells=None to respect collection FOVs, OR\n" + " 2. Create a new collection without FOV filtering if you want custom wells" + ) + + # STEP 4: Ensure normalizations and augmentations are lists + normalizations = normalizations or [] + augmentations = augmentations or [] + + # STEP 5: Create TripletDataModule for each dataset + data_modules = [] + + for ds in datasets: + # Determine well filtering strategy + if fit_include_wells is not None: + # User override: use explicit wells + include_wells = fit_include_wells + elif len(ds.fov_names) > 0: + # Convert collection FOV names to well IDs + include_wells = _extract_wells_from_fov_names(ds.fov_names) + else: + # No filtering: use all wells + include_wells = None + + # Create TripletDataModule + dm = TripletDataModule( + data_path=ds.data_path, + tracks_path=ds.tracks_path, + source_channel=source_channel, + z_range=z_range, + initial_yx_patch_size=initial_yx_patch_size, + final_yx_patch_size=final_yx_patch_size, + split_ratio=split_ratio, + batch_size=batch_size, + num_workers=num_workers, + normalizations=normalizations, + augmentations=augmentations, + augment_validation=augment_validation, + caching=caching, + fit_include_wells=include_wells, + fit_exclude_fovs=fit_exclude_fovs, + predict_cells=predict_cells, + include_fov_names=include_fov_names, + include_track_ids=include_track_ids, + time_interval=time_interval, + return_negative=return_negative, + persistent_workers=persistent_workers, + prefetch_factor=prefetch_factor, + pin_memory=pin_memory, + z_window_size=z_window_size, + cache_pool_bytes=cache_pool_bytes, + ) + data_modules.append(dm) + + # STEP 6: Return appropriate data module type + if len(data_modules) == 1: + # Single plate: return TripletDataModule directly + return data_modules[0] + else: + # Multiple plates: wrap in ConcatDataModule + if use_cached_concat: + return CachedConcatDataModule(data_modules=data_modules) + else: + return BatchedConcatDataModule(data_modules=data_modules) + + +class 
CollectionTripletDataModule(TripletDataModule): + """ + TripletDataModule that fetches paths from Airtable collections. + + This class is designed to work with PyTorch Lightning CLI and config files. + It extends TripletDataModule to accept Airtable collection parameters instead + of explicit data_path and tracks_path. + + Parameters + ---------- + base_id : str + Airtable base ID + collection_name : str + Name of the collection in Airtable + collection_version : str + Semantic version of the collection (e.g., "0.0.1") + source_channel : str | Sequence[str] + Input channel name(s) + z_range : tuple[int, int] + Range of valid z-slices + api_key : str | None + Airtable API key (if None, reads from AIRTABLE_API_KEY env var) + **kwargs + All other parameters passed to TripletDataModule.__init__ + + Raises + ------ + ValueError + If collection has multiple datasets (only single-plate collections supported) + + Examples + -------- + In a Lightning config file (config.yml): + + ```yaml + data: + class_path: viscy.airtable.factory.CollectionTripletDataModule + init_args: + base_id: "appXXXXXXXXXXXXXX" + collection_name: "my_collection" + collection_version: "0.0.1" + source_channel: [Phase] + z_range: [0, 5] + batch_size: 16 + num_workers: 8 + normalizations: + - class_path: viscy.transforms.NormalizeSampled + init_args: + keys: [Phase] + level: fov_statistics + subtrahend: mean + divisor: std + ``` + + Command line usage: + ```bash + viscy fit -c config.yml + ``` + + Direct usage in Python: + ```python + dm = CollectionTripletDataModule( + base_id="appXXXXXXXXXXXXXX", + collection_name="my_collection", + collection_version="0.0.1", + source_channel=["Phase"], + z_range=(0, 5), + batch_size=16, + ) + trainer.fit(model, dm) + ``` + + Notes + ----- + - Only supports single-plate collections (use create_triplet_datamodule_from_collection + for multi-plate support with BatchedConcatDataModule) + - Fetches collection from Airtable during __init__ + - All TripletDataModule parameters are available + - FOV filtering from collection is automatically applied via fit_include_wells + """ + + def __init__( + self, + base_id: str, + collection_name: str, + collection_version: str, + source_channel: str | Sequence[str], + z_range: tuple[int, int], + api_key: str | None = None, + fit_include_wells: list[str] | None = None, + **kwargs, + ): + # Fetch collection from Airtable + airtable_db = AirtableManager( + base_id=base_id, api_key=api_key or os.getenv("AIRTABLE_API_KEY") + ) + collection = airtable_db.get_dataset_paths( + collection_name=collection_name, + version=collection_version, + ) + + # Validate single plate + if len(collection.datasets) != 1: + raise ValueError( + f"CollectionTripletDataModule only supports single-plate collections. " + f"Collections '{collection_name}' has {len(collection.datasets)} plates. " + f"Use create_triplet_datamodule_from_collection() for multi-plate support." 
+ ) + + dataset = collection.datasets[0] + + # Store collection metadata as instance attributes for callbacks/logging + self.base_id = base_id + self.collection_name = collection_name + self.collection_version = collection_version + self.data_path = dataset.data_path + self.tracks_path = dataset.tracks_path + + # Handle FOV filtering + if fit_include_wells is not None: + # User override: use explicit wells + include_wells = fit_include_wells + elif len(dataset.fov_names) > 0: + # Convert collection FOV names to well IDs + include_wells = _extract_wells_from_fov_names(dataset.fov_names) + else: + # No filtering: use all wells + include_wells = None + + # Initialize parent TripletDataModule with extracted paths + super().__init__( + data_path=dataset.data_path, + tracks_path=dataset.tracks_path, + source_channel=source_channel, + z_range=z_range, + fit_include_wells=include_wells, + **kwargs, + ) diff --git a/viscy/airtable/register_model.py b/viscy/airtable/register_model.py new file mode 100644 index 000000000..d69bc4dd3 --- /dev/null +++ b/viscy/airtable/register_model.py @@ -0,0 +1,439 @@ +"""Register trained models to W&B artifact registry and shared directory.""" + +import shutil +from pathlib import Path +from typing import Any + +import wandb + + +def register_model( + checkpoint_path: str, + model_name: str, + model_type: str, + version: str, + config_path: str | None = None, + aliases: list[str] | None = None, + wandb_run_id: str | None = None, + wandb_project: str = "viscy-model-registry", + shared_dir: str = "/hpc/models/shared", + description: str | None = None, + metadata: dict[str, Any] | None = None, + airtable_base_id: str | None = None, + airtable_collection_id: str | None = None, +) -> str: + """ + Register a trained model to W&B artifacts and copy to shared directory. + + This creates a W&B artifact with references (not uploads) to track model lineage. + The checkpoint file is copied to a shared HPC directory for team access. + + Parameters + ---------- + checkpoint_path : str + Path to Lightning checkpoint (.ckpt file) + model_name : str + Human-readable name (e.g., "contrastive-rpe1") + model_type : str + Model category: contrastive, segmentation, vae, translation + version : str + Semantic version (e.g., "v1", "v2", "v3") + config_path : str | None + Path to training config YAML file (config_fit.yml). This will be stored + in the WandB artifact for later use in prediction. + aliases : list[str] | None + Tags like "production", "best", "latest" + wandb_run_id : str | None + W&B run ID that trained this model (for lineage tracking) + wandb_project : str + W&B project for model registry (default: "viscy-model-registry") + shared_dir : str + Shared checkpoint directory path (default: "/hpc/models/shared") + description : str | None + Model description (e.g., "Trained on RPE1 collection v0.0.1, val_loss=0.15") + metadata : dict[str, Any] | None + Additional metadata to store (metrics, config, etc.) + airtable_base_id : str | None + Airtable base ID to log model to Models table (optional) + airtable_collection_id : str | None + Airtable collection record ID to link this model to (optional) + + Returns + ------- + str + W&B artifact URL + + Examples + -------- + Register a model after training: + + >>> from viscy.airtable.register_model import register_model + >>> artifact_url = register_model( + ... checkpoint_path="logs/wandb/run-20260107/checkpoints/epoch=50.ckpt", + ... config_path="examples/configs/fit_example.yml", # Training config + ... 
model_name="contrastive-rpe1", + ... model_type="contrastive", + ... version="v2", + ... aliases=["production", "best"], + ... wandb_run_id="20260107-152420", + ... description="Trained on RPE1 collection v0.0.1, best val_loss=0.15", + ... metadata={"val_loss": 0.15, "collection_name": "RPE1_infection", "collection_version": "v0.0.1"}, + ... airtable_base_id="app8vqaoWyOwa0sB5", # Optional: log to Airtable + ... airtable_collection_id="recXXXXXXXXXXXXXX", # Optional: link to collection + ... ) + >>> print(f"Model registered: {artifact_url}") + + CLI usage: + + >>> # From command line + >>> python -m viscy.airtable.register_model \\ + ... logs/wandb/run-20260107/checkpoints/epoch=50.ckpt \\ + ... --name contrastive-rpe1 \\ + ... --type contrastive \\ + ... --version v2 \\ + ... --aliases production best \\ + ... --run-id 20260107-152420 \\ + ... --description "Best RPE1 model" + """ + checkpoint_path = Path(checkpoint_path) + if not checkpoint_path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + # 1. Copy to shared directory + shared_path = Path(shared_dir) / model_type / f"{model_name}-{version}.ckpt" + shared_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(checkpoint_path, shared_path) + print(f"✓ Copied to shared directory: {shared_path}") + + # 2. Initialize W&B (use API for non-training context) + run = wandb.init(project=wandb_project, job_type="register-model") + + # 3. Create artifact with metadata (references only, not uploads) + artifact_metadata = { + "model_type": model_type, + "version": version, + "checkpoint_path": str(shared_path), + "original_checkpoint": str(checkpoint_path), + } + + # Add custom metadata + if metadata: + artifact_metadata.update(metadata) + + artifact = wandb.Artifact( + name=model_name, + type="model", + description=description or f"{model_type} model {version}", + metadata=artifact_metadata, + ) + + # 4. Add training config to artifact (if provided) + if config_path: + config_path = Path(config_path) + if config_path.exists(): + artifact.add_file(str(config_path), name="config_fit.yml") + print(f"✓ Added training config to artifact: {config_path.name}") + else: + print(f"⚠ Config file not found: {config_path}") + + # 5. Store checkpoint path in metadata only (no upload/reference) + # The checkpoint stays on HPC, W&B only tracks the metadata + # This is "Option 2" from the plan - references only, not uploads + + # 6. Link to training run (lineage) + if wandb_run_id: + artifact.metadata["training_run_id"] = wandb_run_id + # Create lineage link + try: + training_run = wandb.Api().run( + f"{run.entity}/{wandb_project}/{wandb_run_id}" + ) + artifact.metadata["training_run_url"] = training_run.url + except Exception as e: + print(f"⚠ Could not link to training run: {e}") + + # 7. Log artifact with version and aliases + aliases = aliases or [] + wandb.log_artifact(artifact, aliases=aliases) + + artifact_url = ( + f"https://wandb.ai/{run.entity}/{wandb_project}/artifacts/model/{model_name}" + ) + + wandb.finish() + + print(f"✓ Registered in W&B: {model_name}:{version}") + print(f" Aliases: {aliases}") + print(f" View: {artifact_url}") + + # 8. 
Optionally log to Airtable Models table + if airtable_base_id and airtable_collection_id: + try: + import getpass + + from viscy.airtable.database import AirtableManager + + airtable_db = AirtableManager(base_id=airtable_base_id) + + # Prepare complete metadata for Airtable (mirrors W&B artifact metadata) + airtable_metadata = { + # Core W&B fields + "model_type": model_type, + "version": version, + "wandb_url": artifact_url, + "original_checkpoint": str(checkpoint_path), + } + + # Add all custom metadata (metrics, architecture, collection lineage, etc.) + if metadata: + airtable_metadata.update(metadata) + + model_id = airtable_db.log_model_training( + collection_id=airtable_collection_id, + wandb_run_id=wandb_run_id or "unknown", + model_name=f"{model_name}:{version}", + checkpoint_path=str(shared_path), + trained_by=getpass.getuser(), + metrics=airtable_metadata, # Pass ALL metadata to Airtable + ) + + print(f"✓ Logged to Airtable Models table (record ID: {model_id})") + print(f" Linked to collection: {airtable_collection_id}") + print(f" Metadata fields passed: {len(airtable_metadata)}") + + except Exception as e: + print(f"⚠ Could not log to Airtable: {e}") + print(" (Fields without matching Airtable columns are ignored)") + + return artifact_url + + +def load_model_from_registry( + model_name: str, + version: str = "latest", + wandb_project: str = "viscy-model-registry", + model_class=None, + **model_kwargs, +): + """ + Load a model from W&B artifact registry. + + This fetches the checkpoint path from W&B metadata and loads the model + from the shared HPC directory (does not download from W&B cloud). + + Parameters + ---------- + model_name : str + Model artifact name + version : str + Version or alias ("latest", "production", "v2") + wandb_project : str + W&B project name (default: "viscy-model-registry") + model_class : LightningModule class + Model class to instantiate (e.g., ContrastiveModule) + **model_kwargs + Additional arguments for model initialization + + Returns + ------- + model : LightningModule + Loaded model in eval mode + + Examples + -------- + Load a registered model: + + >>> from viscy.representation.engine import ContrastiveModule + >>> from viscy.airtable.register_model import load_model_from_registry + >>> + >>> model = load_model_from_registry( + ... model_name="contrastive-rpe1", + ... version="production", + ... wandb_project="viscy-model-registry", + ... model_class=ContrastiveModule, + ... ) + >>> model.eval() + >>> embeddings = model(images) + """ + # 1. Get artifact metadata from W&B + api = wandb.Api() + entity = api.default_entity + artifact = api.artifact(f"{entity}/{wandb_project}/{model_name}:{version}") + + # 2. Get checkpoint path from metadata (no download needed) + checkpoint_path = artifact.metadata.get("checkpoint_path") + if not checkpoint_path: + raise ValueError( + f"Artifact {model_name}:{version} missing checkpoint_path metadata" + ) + + checkpoint_path = Path(checkpoint_path) + if not checkpoint_path.exists(): + raise FileNotFoundError( + f"Checkpoint not found: {checkpoint_path}\n" + f"The model is registered but the file is missing from the shared directory." + ) + + print(f"✓ Loading model from: {checkpoint_path}") + print(f" Artifact: {model_name}:{version}") + print(f" Description: {artifact.description}") + + # 3. 
Load model + if model_class is None: + raise ValueError("Must provide model_class (e.g., ContrastiveModule)") + + model = model_class.load_from_checkpoint(checkpoint_path, **model_kwargs) + model.eval() + + return model + + +def list_registered_models( + wandb_project: str = "viscy-model-registry", + model_type: str | None = None, +) -> list[dict[str, Any]]: + """ + List all registered models in W&B artifact registry. + + Parameters + ---------- + wandb_project : str + W&B project name (default: "viscy-model-registry") + model_type : str | None + Filter by model type (contrastive, segmentation, vae, translation) + + Returns + ------- + list[dict[str, Any]] + List of model metadata dictionaries + + Examples + -------- + >>> from viscy.airtable.register_model import list_registered_models + >>> + >>> # List all models + >>> models = list_registered_models() + >>> for m in models: + ... print(f"{m['name']}:{m['version']} - {m['description']}") + >>> + >>> # List only contrastive models + >>> contrastive = list_registered_models(model_type="contrastive") + """ + api = wandb.Api() + + # Get all artifact collections in the project + try: + # Get entity from API + entity = api.default_entity + artifact_type = "model" + collection_name = f"{entity}/{wandb_project}/{artifact_type}" + + # Use the artifacts API (simpler) + artifacts = api.artifacts(artifact_type, collection_name) + + except Exception as e: + # Project might not exist yet or no artifacts + print(f"Warning: Could not list artifacts - {e}") + return [] + + models = [] + try: + for artifact in artifacts: + metadata = artifact.metadata + + # Filter by model type if specified + if model_type and metadata.get("model_type") != model_type: + continue + + models.append( + { + "name": artifact.name, + "version": artifact.version, + "aliases": artifact.aliases, + "description": artifact.description, + "model_type": metadata.get("model_type"), + "checkpoint_path": metadata.get("checkpoint_path"), + "created_at": artifact.created_at, + "metadata": metadata, + } + ) + except Exception as e: + print(f"Warning: Error iterating artifacts - {e}") + + return models + + +# CLI entry point +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Register a trained model to W&B artifact registry" + ) + + # Required arguments + parser.add_argument("checkpoint_path", help="Path to checkpoint file (.ckpt)") + parser.add_argument( + "--name", required=True, help="Model name (e.g., contrastive-rpe1)" + ) + parser.add_argument( + "--type", + required=True, + choices=["contrastive", "segmentation", "vae", "translation"], + help="Model type", + ) + parser.add_argument("--version", required=True, help="Version (e.g., v1, v2)") + + # Optional arguments + parser.add_argument( + "--config", + help="Path to training config YAML file (config_fit.yml)", + ) + parser.add_argument( + "--aliases", + nargs="+", + help="Aliases (e.g., production best latest)", + ) + parser.add_argument( + "--run-id", + help="W&B run ID that trained this model", + ) + parser.add_argument( + "--project", + default="viscy-model-registry", + help="W&B project name (default: viscy-model-registry)", + ) + parser.add_argument( + "--shared-dir", + default="/hpc/models/shared", + help="Shared checkpoint directory (default: /hpc/models/shared)", + ) + parser.add_argument( + "--description", + help="Model description", + ) + parser.add_argument( + "--airtable-base-id", + help="Airtable base ID (optional, for logging to Models table)", + ) + parser.add_argument( 
+ "--airtable-collection-id", + help="Airtable collection record ID (optional, for linking)", + ) + + args = parser.parse_args() + + register_model( + checkpoint_path=args.checkpoint_path, + config_path=args.config, + model_name=args.name, + model_type=args.type, + version=args.version, + aliases=args.aliases, + wandb_run_id=args.run_id, + wandb_project=args.project, + shared_dir=args.shared_dir, + description=args.description, + airtable_base_id=args.airtable_base_id, + airtable_collection_id=args.airtable_collection_id, + ) diff --git a/viscy/airtable/schemas.py b/viscy/airtable/schemas.py new file mode 100644 index 000000000..57bfbe165 --- /dev/null +++ b/viscy/airtable/schemas.py @@ -0,0 +1,283 @@ +"""Pydantic schemas for Airtable records and model registry.""" + +from typing import Any, Literal + +from pydantic import BaseModel, Field, field_validator + + +class DatasetRecord(BaseModel): + """ + Pydantic model for a dataset record (FOV) in Airtable. + + This represents a single field of view (FOV) from an HCS plate, + with metadata about the biological sample and data location. + + Note + ---- + FOV_ID is a computed field in Airtable (formula: Dataset + "_" + Well ID + "_" + FOV) + and is automatically generated. You can read it but cannot set it during creation. + """ + + # Required fields for creation (but optional when reading from Airtable due to legacy data) + dataset_name: str = Field( + ..., + description="Name of the dataset/plate this FOV belongs to", + alias="Dataset", + ) + well_id: str = Field( + ..., + description="Well identifier (e.g., 'B_3' or 'B/3')", + alias="Well ID", + ) + fov_name: str = Field( + ..., description="FOV index within well (e.g., '0', '1', '2')", alias="FOV" + ) + data_path: str | None = Field( + None, + description="Full path to FOV data (e.g., '/hpc/data/plate.zarr/B/3/0'). 
Required for new records.", + alias="Data path", + ) + + # Computed field (read-only, auto-generated by Airtable) + fov_id: str | None = Field( + None, + description="Computed FOV identifier (Dataset_WellID_FOV) - auto-generated by Airtable", + alias="FOV_ID", + exclude=True, # Don't include in to_airtable_dict() by default + ) + + # Optional metadata fields + cell_type: str | None = Field( + None, description="Cell type (e.g., 'RPE1', 'A549')", alias="Cell type" + ) + cell_state: str | None = Field( + None, + description="Cell state (e.g., 'healthy', 'infected')", + alias="Cell state", + ) + cell_line: list[str] | None = Field( + None, + description="Cell line identifiers (Airtable array field)", + alias="Cell line", + ) + organelle: str | None = Field( + None, description="Target organelle (e.g., 'SEC61B')", alias="Organelle" + ) + channel_0: str | None = Field( + None, description="Channel 0 name (e.g., 'Phase')", alias="Channel-0" + ) + channel_1: str | None = Field( + None, description="Channel 1 name (e.g., 'GFP')", alias="Channel-1" + ) + channel_2: str | None = Field( + None, description="Channel 2 name (e.g., 'RFP')", alias="Channel-2" + ) + fluorescence_modality: str | None = Field( + None, + description="Fluorescence modality (e.g., 'confocal')", + alias="Fluorescence modality", + ) + organellebox_infectomics: bool | None = Field( + None, + description="Part of OrganelleBox Infectomics dataset", + alias="OrganelleBox Infectomics", + ) + + # Airtable record ID (not used for creation, but for retrieval) + record_id: str | None = Field(None, description="Airtable record ID", alias="id") + + model_config = { + "populate_by_name": True, # Allow both field name and alias + "str_strip_whitespace": True, # Strip whitespace from strings + } + + @field_validator("dataset_name", "well_id", "fov_name") + @classmethod + def no_empty_strings(cls, v: str) -> str: + """Ensure required string fields are not empty.""" + if not v or not v.strip(): + raise ValueError("Field cannot be empty") + return v.strip() + + @field_validator("fov_id") + @classmethod + def validate_fov_id(cls, v: str | None) -> str | None: + """Validate FOV ID if provided (it's computed, so usually None on creation).""" + if v is not None and (not v or not v.strip()): + raise ValueError("FOV ID cannot be empty string if provided") + return v.strip() if v else None + + @field_validator("data_path") + @classmethod + def validate_data_path(cls, v: str | None) -> str | None: + """Validate data path format if provided.""" + if v is None: + return None + if not v.strip(): + raise ValueError("Data path cannot be empty string") + # Don't check if path exists - it might be on a different machine + # Just ensure it's a reasonable path format + if not v.startswith("/"): + raise ValueError("Data path must be an absolute path (start with /)") + return v.strip() + + def to_airtable_dict(self) -> dict[str, Any]: + """ + Convert to Airtable-compatible dictionary. + + Returns + ------- + dict + Dictionary with Airtable field names (aliases), excluding None values + and the record_id field. + """ + # Use by_alias=True to get Airtable field names + # exclude_none=True to skip optional fields that weren't set + return self.model_dump( + by_alias=True, exclude_none=True, exclude={"id", "record_id"} + ) + + @classmethod + def from_airtable_record(cls, record: dict[str, Any]) -> "DatasetRecord": + """ + Create DatasetRecord from Airtable API response. 
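+
+        The record's "fields" mapping is expanded into keyword arguments, so
+        Airtable column names are matched against the field aliases defined
+        on this model.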
+ + Parameters + ---------- + record : dict + Airtable record with 'id' and 'fields' keys + + Returns + ------- + DatasetRecord + Parsed dataset record + """ + return cls(id=record["id"], **record["fields"]) + + def __str__(self) -> str: + """Human-readable string representation.""" + identifier = ( + self.fov_id or f"{self.dataset_name}_{self.well_id}_{self.fov_name}" + ) + path_str = self.data_path or "no path" + return f"Dataset({identifier} @ {path_str})" + + def __repr__(self) -> str: + """Detailed string representation.""" + return ( + f"DatasetRecord(dataset_name={self.dataset_name!r}, well_id={self.well_id!r}, " + f"fov_name={self.fov_name!r}, fov_id={self.fov_id!r})" + ) + + +class ModelRecord(BaseModel): + """ + Pydantic model for a registered model artifact. + + This represents a registered model in WandB with metadata about + training, architecture, and file locations. + """ + + # Core identification + model_name: str = Field(..., description="Model identifier") + model_type: Literal["contrastive", "segmentation", "vae", "translation"] = Field( + ..., description="Model category" + ) + version: str = Field(..., description="Version (e.g., 'v1', 'v2')") + + # File locations + checkpoint_path: str = Field(..., description="Absolute path to .ckpt file") + config_artifact_path: str | None = Field( + None, description="Path to downloaded config_fit.yml from artifact" + ) + + # Architecture (for quick reference) + architecture: str | None = Field( + None, description="Model architecture (e.g., '2.5D')" + ) + backbone: str | None = Field(None, description="Backbone (e.g., 'convnext_tiny')") + + # Training lineage + collection_name: str | None = Field(None, description="Training data collection") + wandb_run_id: str | None = Field(None, description="Training run ID") + wandb_artifact_url: str | None = Field(None, description="WandB artifact URL") + + # Metadata + description: str | None = Field(None, description="Model description") + trained_date: str | None = Field(None, description="Training date") + + model_config = { + "populate_by_name": True, + "str_strip_whitespace": True, + } + + @field_validator("checkpoint_path") + @classmethod + def validate_checkpoint_path(cls, v: str) -> str: + """Validate checkpoint path.""" + if not v.startswith("/"): + raise ValueError("Checkpoint path must be absolute") + if not v.endswith(".ckpt"): + raise ValueError("Checkpoint must be .ckpt file") + return v + + @field_validator("version") + @classmethod + def validate_version(cls, v: str) -> str: + """Validate version format.""" + if not v or not v.strip(): + raise ValueError("Version cannot be empty") + return v.strip() + + def to_wandb_metadata(self) -> dict[str, Any]: + """ + Convert to WandB artifact metadata dictionary. + + Returns + ------- + dict + Dictionary suitable for wandb.Artifact(metadata=...) + """ + return self.model_dump( + exclude_none=True, + exclude={"config_artifact_path"}, # Don't include temp paths + ) + + @classmethod + def from_wandb_artifact(cls, artifact: Any) -> "ModelRecord": + """ + Create ModelRecord from WandB artifact. 
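+
+        The artifact metadata is copied and augmented with the artifact's
+        name, version, description, and URL before being validated into a
+        ModelRecord.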
+ + Parameters + ---------- + artifact : wandb.Artifact + WandB artifact object + + Returns + ------- + ModelRecord + Parsed model record + """ + metadata = artifact.metadata.copy() + + # Add artifact-level fields + metadata["model_name"] = artifact.name + if hasattr(artifact, "version"): + metadata["version"] = artifact.version + if hasattr(artifact, "description") and artifact.description: + metadata["description"] = artifact.description + if hasattr(artifact, "url"): + metadata["wandb_artifact_url"] = artifact.url + + return cls(**metadata) + + def __str__(self) -> str: + """Human-readable string representation.""" + return f"Model({self.model_name}:{self.version} @ {self.checkpoint_path})" + + def __repr__(self) -> str: + """Detailed string representation.""" + return ( + f"ModelRecord(model_name={self.model_name!r}, model_type={self.model_type!r}, " + f"version={self.version!r})" + ) diff --git a/viscy/cli/wandb_utils.py b/viscy/cli/wandb_utils.py new file mode 100644 index 000000000..24715b041 --- /dev/null +++ b/viscy/cli/wandb_utils.py @@ -0,0 +1,60 @@ +"""WandB utilities for model registry.""" + +import tempfile +from pathlib import Path + +import wandb + +from viscy.airtable.schemas import ModelRecord + + +def download_model_artifact( + model_name: str, + version: str = "latest", + wandb_project: str = "viscy-model-registry", + download_dir: str | None = None, +) -> tuple[ModelRecord, Path]: + """ + Download model artifact from WandB. + + Parameters + ---------- + model_name : str + Model name (e.g., 'contrastive-rpe1') + version : str + Version or alias (e.g., 'v1', 'latest', 'production') + wandb_project : str + WandB project name + download_dir : str | None + Where to download artifact files (temp dir if None) + + Returns + ------- + tuple[ModelRecord, Path] + (model_record, path_to_downloaded_config) + """ + api = wandb.Api() + entity = api.default_entity + + # Fetch artifact + artifact = api.artifact(f"{entity}/{wandb_project}/{model_name}:{version}") + + # Parse metadata with Pydantic validation + model_record = ModelRecord.from_wandb_artifact(artifact) + + # Download artifact files (config_fit.yml) + if download_dir is None: + download_dir = tempfile.mkdtemp(prefix=f"viscy_model_{model_name}_") + + artifact_dir = Path(artifact.download(root=download_dir)) + config_path = artifact_dir / "config_fit.yml" + + if not config_path.exists(): + raise FileNotFoundError( + "Training config not found in artifact. " + "Model may have been registered without config." 
+ ) + + model_record.config_artifact_path = str(config_path) + + return model_record, config_path diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 6c2d243d9..5a87a9abf 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -168,9 +168,22 @@ def _log_metrics( def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): grid = render_images(imgs, cmaps=["gray"] * 3) - self.logger.experiment.add_image( - key, grid, self.current_epoch, dataformats="HWC" - ) + + # Handle different logger types + if hasattr(self.logger, "experiment"): + # Check if TensorBoard logger + if hasattr(self.logger.experiment, "add_image"): + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) + # Check if WandB logger + # FIXME this is just temporary fix to get the samples logged to WandB + elif hasattr(self.logger.experiment, "log"): + import wandb + + self.logger.experiment.log( + {key: wandb.Image(grid), "epoch": self.current_epoch} + ) def _log_step_samples(self, batch_idx, samples, stage: Literal["train", "val"]): """Common method for logging step samples"""
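Usage sketch (not part of the diff): constructing a DatasetRecord and round-tripping it through the helpers defined above. The import path mirrors the `from viscy.airtable.schemas import ModelRecord` line in wandb_utils.py; all field values below are placeholders.

from viscy.airtable.schemas import DatasetRecord

# Construct by Python field name; populate_by_name=True also accepts the Airtable aliases.
record = DatasetRecord(
    dataset_name="2024_11_07_A549_SEC61_DENV",  # placeholder values
    well_id="B/1",
    fov_name="0",
    data_path="/hpc/data/your_plate.zarr",
)

# Only fields that were set are serialized, keyed by their Airtable aliases;
# record_id and the computed fov_id are excluded.
fields = record.to_airtable_dict()

# Rebuild a record from an Airtable API response shaped like {"id": ..., "fields": {...}}.
roundtrip = DatasetRecord.from_airtable_record(
    {"id": "recXXXXXXXXXXXXXX", "fields": fields}  # placeholder record ID
)
print(roundtrip)  # Dataset(2024_11_07_A549_SEC61_DENV_B/1_0 @ /hpc/data/your_plate.zarr)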
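A similar sketch for the wandb_utils helper: the model name and version alias are placeholders, and a logged-in wandb session with access to the registry project is assumed.

from viscy.cli.wandb_utils import download_model_artifact

# Requires wandb credentials and access to the default "viscy-model-registry" project.
model_record, config_path = download_model_artifact(
    model_name="contrastive-rpe1",  # placeholder model name
    version="latest",
)
print(model_record)                  # Model(<name>:<version> @ <checkpoint_path>)
print(model_record.checkpoint_path)  # absolute path to the registered .ckpt file
print(config_path)                   # config_fit.yml downloaded from the artifact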
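Finally, a minimal sketch of how the logger dispatch added to _log_samples plays out, assuming the standard Lightning logger classes; the project name and save_dir are placeholders. With a TensorBoardLogger attached, sample grids still go through add_image; with a WandbLogger, they are logged as wandb.Image entries.

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

# Both loggers expose `.experiment`; the hasattr checks in _log_samples pick
# add_image for TensorBoard and run.log(...) for WandB.
tb_logger = TensorBoardLogger(save_dir="lightning_logs")    # placeholder directory
wandb_logger = WandbLogger(project="viscy-representation")  # placeholder project

trainer = Trainer(logger=wandb_logger)  # or logger=tb_logger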