Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions examples/airtable/example_config_with_airtable.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Example config showing Airtable integration with Lightning training.
# Usage: viscy fit -c examples/airtable/example_config_with_airtable.yml

seed_everything: true

trainer:
  accelerator: gpu
  devices: 1
  max_epochs: 100
  check_val_every_n_epoch: 1
  log_every_n_steps: 50

  # Airtable logging callback: links this training run to a dataset record
  # in the Airtable base.
  callbacks:
    - class_path: viscy.representation.airtable_callback.AirtableLoggingCallback
      init_args:
        base_id: "appXXXXXXXXXXXXXX" # Replace with your Airtable base ID
        dataset_id: "recYYYYYYYYYYYYYY" # Replace with your dataset record ID
        model_name: null # Auto-generate from model class and timestamp
        log_metrics: false # Set to true to store metrics in Airtable (otherwise use TensorBoard)

model:
  class_path: viscy.representation.contrastive.ContrastiveModule
  init_args:
    # Your model config here
    backbone: resnet50
    embedding_len: 256

data:
  class_path: viscy.data.triplet.TripletDataModule
  init_args:
    data_path: /hpc/data/your_plate.zarr
    tracks_path: /hpc/tracks/your_tracks/
    source_channel: [Phase]
    z_range: [0, 5]
    initial_yx_patch_size: [512, 512]
    final_yx_patch_size: [224, 224]
    split_ratio: 0.8
    batch_size: 16
    num_workers: 8

    # FOV selection from the Airtable dataset definition: train only on
    # the listed wells, minus any explicitly excluded FOVs.
    fit_include_wells: ["B3", "B4", "C3"]
    fit_exclude_fovs: []

    # Data augmentation (applied to training samples only)
    augmentations:
      - class_path: viscy.transforms.RandAffined
        init_args:
          keys: [Phase]
          prob: 0.8
          rotate_range: [3.14, 0.0, 0.0]
          scale_range: [0.1, 0.1, 0.1]
      - class_path: viscy.transforms.RandGaussianNoised
        init_args:
          keys: [Phase]
          prob: 0.5
          mean: 0.0
          std: 0.1

# Dataset metadata (for reference, not used by training)
dataset_metadata:
  airtable_id: "recYYYYYYYYYYYYYY"
  name: "RPE1_infection_v2"
  version: "v2"
  description: "RPE1 cells, infection experiment, wells B3-C3"
121 changes: 121 additions & 0 deletions examples/airtable/filter_n_create_dataset_tag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""Filter datasets using pandas and create collection tags."""

# %%

from viscy.airtable.database import AirtableManager

# BASE_ID = os.getenv("AIRTABLE_BASE_ID")
BASE_ID = "app8vqaoWyOwa0sB5"
airtable_db = AirtableManager(base_id=BASE_ID)

# %%
# EXAMPLE 1: Get all dataset records as DataFrame and explore
print("=" * 70)
print("Getting all dataset records as DataFrame")
print("=" * 70)

df_datasets = airtable_db.list_datasets()
print(f"\nTotal dataset records: {len(df_datasets)}")
print("\nDataFrame columns:")
print(df_datasets.columns.tolist())
print("\nFirst few rows:")
print(df_datasets.head())

# %%
# EXAMPLE 2: Filter by dataset and specific wells using pandas
print("\n" + "=" * 70)
print("Filter: Dataset, Wells B_3 and B_4")
print("=" * 70)

# Get all dataset records as DataFrame
df = airtable_db.list_datasets()

# Filter with pandas - simple and powerful!
filtered = df[
(df["Dataset"] == "2024_11_07_A549_SEC61_DENV")
& (df["Well ID"].isin(["B/1", "B/2"]))
]

print(f"\nTotal dataset records after filtering: {len(filtered)}")
print("\nBreakdown by well:")
print(filtered.groupby("Well ID").size())

# Create collection from filtered dataset records
fov_ids = filtered["FOV_ID"].tolist()

try:
collection_id = airtable_db.create_collection_from_datasets(
collection_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2",
fov_ids=fov_ids,
version="0.0.2", # Semantic versioning
purpose="training",
description="Dataset records from wells B_3 and B_4",
)
print(f"\n✓ Created collection: {collection_id}")
print(f" Contains {len(fov_ids)} dataset records")
except ValueError as e:
print(f"\n⚠ {e}")

# %%
# Delete the collection entry demo
airtable_db.delete_collection(collection_id)
print(f"Deleted collection: {collection_id}")

# %%
# EXAMPLE 3: Group by dataset and show summary
print("\n" + "=" * 70)
print("Group by dataset and show summary")
print("=" * 70)

df_all = airtable_db.list_datasets()

grouped = df_all.groupby("Dataset")

for dataset_name, group in grouped:
print(f"\n{dataset_name}:")
print(f" Total records: {len(group)}")
print(f" Wells: {sorted(group['Well ID'].unique())}")

# %%
# EXAMPLE 4: Filter by multiple wells
print("\n" + "=" * 70)
print("Filter: Multiple specific wells")
print("=" * 70)

df = airtable_db.list_datasets()

# Filter for specific wells from a dataset
filtered = df[
(df["Dataset"] == "2024_11_07_A549_SEC61_DENV")
& (df["Well ID"].isin(["B/3", "B/4", "C/3", "C/4"]))
]

print(f"\nDataset records matching criteria: {len(filtered)}")
print("\nBy well:")
print(filtered.groupby("Well ID").size())

print("\nFOV IDs:")
for fov_id in filtered["FOV_ID"]:
print(f" {fov_id}")

# %%
# EXAMPLE 5: Summary statistics
print("\n" + "=" * 70)
print("Summary Statistics")
print("=" * 70)

df = airtable_db.list_datasets()

print("\nDataset records per source dataset:")
print(df.groupby("Dataset").size())

print("\nWells with most dataset records:")
print(df.groupby("Well ID").size().sort_values(ascending=False).head(10))

print("\nTotal unique wells:")
print(f"{df['Well ID'].nunique()} wells")

print("\nTotal unique FOV IDs:")
print(f"{df['FOV_ID'].nunique()} FOV IDs")

# %%
118 changes: 118 additions & 0 deletions examples/airtable/get_dataset_paths_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Example usage of get_dataset_paths with Collections and CollectionDataset dataclasses."""

# %%
from viscy.airtable.database import AirtableManager

BASE_ID = "app8vqaoWyOwa0sB5"
airtable_db = AirtableManager(base_id=BASE_ID)

# %%
# Fetch collection from Airtable
collection = airtable_db.get_dataset_paths(
collection_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2",
version="v1",
)

# %%
# Collections properties
print("=== Collections ===")
print(f"collection.name: {collection.name}")
print(f"collection.version: {collection.version}")
print(f"len(collection): {len(collection)} HCS plate(s)")
print(f"collection.total_fovs: {collection.total_fovs} FOVs")

# %%
# Iterate over CollectionDataset objects (one per HCS plate)
print("\n=== CollectionDataset ===")
for ds in collection:
print(f"ds.data_path: {ds.data_path}")
print(f"ds.tracks_path: {ds.tracks_path}")
print(f"len(ds): {len(ds)} FOVs")
print(f"ds.fov_names: {ds.fov_names[:3]}...")
print(f"ds.fov_paths: {ds.fov_paths[:2]}...")
print(f"ds.exists(): {ds.exists()}")

# %%
# Validate paths exist (raises FileNotFoundError if not)
collection.validate()
print("\nAll paths validated successfully!")


# %%
# List available collections
print("=== Available Collections ===")
df = airtable_db.list_collections()
print(df[["name", "version", "purpose"]].dropna(subset=["name"]).to_string())

# %%
# =============================================================================
# Create TripletDataModule from collection using factory function
# =============================================================================
from viscy.airtable.factory import create_triplet_datamodule_from_collection

# Create data module from collection
dm = create_triplet_datamodule_from_collection(
collection=collection,
source_channel=["Phase3D"],
z_range=(20, 21),
batch_size=1,
num_workers=1,
initial_yx_patch_size=(160, 160),
final_yx_patch_size=(160, 160),
return_negative=False,
time_interval=1,
)

# %%
# Setup and inspect the data module
dm.setup("fit")
print("\n=== TripletDataModule from Collections ===")
print(f"Data module type: {type(dm).__name__}")
print(f"Train samples: {len(dm.train_dataset)}")
print(f"Val samples: {len(dm.val_dataset)}")

# %%
# =============================================================================
# Alternative: CollectionTripletDataModule (Lightning Config Compatible)
# =============================================================================
from viscy.airtable.factory import CollectionTripletDataModule

# This class is designed for Lightning CLI and config files
# but can also be used directly in Python
dm_class = CollectionTripletDataModule(
base_id=BASE_ID,
collection_name="2024_11_07_A549_SEC61_DENV_wells_B1_B2",
collection_version="v1",
source_channel=["Phase3D"],
z_range=(20, 21),
batch_size=1,
num_workers=1,
initial_yx_patch_size=(160, 160),
final_yx_patch_size=(160, 160),
return_negative=False,
time_interval=1,
)

dm_class.setup("fit")
print("\n=== CollectionTripletDataModule (Class) ===")
print(f"Data module type: {type(dm_class).__name__}")
print(f"Train samples: {len(dm_class.train_dataset)}")
print(f"Val samples: {len(dm_class.val_dataset)}")

# %% Visualize some of the images
import matplotlib.pyplot as plt
import torch

img_stack = []
for idx, batch in enumerate(dm.train_dataloader()):
img_stack.append(batch["anchor"][0, 0, 0])
if idx >= 9:
break
img_stack = torch.stack(img_stack)
# %%
# Make subplot with 10 images
fig, axs = plt.subplots(2, 5, figsize=(10, 4))
for i in range(10):
axs[i // 5, i % 5].imshow(img_stack[i], cmap="gray")
axs[i // 5, i % 5].axis("off")
plt.show()
Loading
Loading