patchOTDA

What is this?

The aim of this package is to facilitate the integration of patch clamp electrophysiology datasets. Patch clamp recordings are sensitive to a variety of extraneous variables (for example temperature, solution, and brain region), which makes datasets recorded under different conditions difficult to compare directly.

patchOTDA is a small Python package that wraps several optimal-transport-based domain adaptation packages. It aims to help intermediate users integrate two datasets through a simple object-oriented interface.

End users are encouraged to check out the streamlit app: https://patchotda.streamlit.app/. This app allows you to integrate your dataset with a reference dataset from the Allen Institute.

Quickstart

Install

The package is not currently available on PyPI, but it can be installed directly from the Git repository:

pip install git+https://github.com/smestern/pypatchOTDA.git

This should install the package and its dependencies. To use the SKADA and FUGW transporters, you will need to install these additional dependencies manually:

pip install git+https://github.com/scikit-adaptation/skada
pip install unbalancedgw
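
Before selecting a SKADA or FUGW transporter, it can help to confirm that the optional packages are importable. The sketch below uses only the standard library. The import names skada and unbalancedgw are assumed to match the pip packages above, and the helper name check_optional_dependencies is illustrative, not part of patchOTDA.

```python
from importlib.util import find_spec

# Assumed import names for the optional dependencies listed above.
OPTIONAL_MODULES = ("skada", "unbalancedgw")


def check_optional_dependencies(modules=OPTIONAL_MODULES):
    """Return {module_name: bool} telling whether each module can be found."""
    return {name: find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, available in check_optional_dependencies().items():
        status = "found" if available else "missing"
        print(f"{name}: {status}")
```

A missing entry only means the corresponding transporters are unavailable; the rest of the package should be unaffected.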

Basic usage

The basic patchOTDA object is a wrapper around the POT domain adaptation methods. It adds some error handling and data manipulation so that users can plug in their own data easily. To begin, instantiate the PatchClampOTDA object:

import patchOTDA.domainAdapt as pOTDA
from patchOTDA.datasets import MMS_DATA
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
import matplotlib.pyplot as plt
pOTDA.timeout = 10
da = pOTDA.PatchClampOTDA()

# Load data
Xs = MMS_DATA['CTKE_M1']['ephys']
Xt = MMS_DATA['VISp_Viewer']['ephys']
print(MMS_DATA.keys())
#make sure the features are the same
Xs = Xs.loc[:,MMS_DATA['joint_feats']].to_numpy()
Xt = Xt.loc[:,MMS_DATA['joint_feats']].to_numpy()

plt.plot(Xt[:,0],Xt[:,1],'o', c='k', label='VISp_Viewer')
plt.plot(Xs[:,0],Xs[:,1],'o', c='r', label='CTKE_M1')
plt.xlabel(MMS_DATA['joint_feats'][0])
plt.ylabel(MMS_DATA['joint_feats'][1])
plt.legend()
dict_keys(['CTKE_M1', 'VISp_Viewer', 'Query1', 'Query2', 'Query3', 'joint_feats'])

[Figure: VISp_Viewer (black) and CTKE_M1 (red) plotted on the first two joint features]

Preprocessing

Scaling is a necessary step, as the OT solvers can fail to converge when feature scales are large.
In most cases it's good practice to scale and impute (replace NaNs) before applying a domain adaptation method.
In standard ML practice, preprocessing steps are often fit to each dataset independently.
That also applies here. Sometimes, however, we want to fit the scaler on one dataset and apply that same scaler to both (for example, when using a linear OT mapping), so that the difference between the datasets is preserved and can be learned by the transporter.

# Preprocess data
scaler = StandardScaler()
imputer = SimpleImputer(strategy='mean')
Xs = imputer.fit_transform(Xs)
Xt = imputer.transform(Xt)
Xs = scaler.fit_transform(Xs)
Xt = scaler.transform(Xt)
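
To illustrate the difference between the two scaling strategies, here is a minimal NumPy sketch on synthetic data. It z-scores by hand instead of calling StandardScaler, and the helper names and toy arrays are assumptions made for illustration only. The block above corresponds to the shared strategy: the scaler is fit on Xs and then applied to Xt.

```python
import numpy as np


def zscore_fit(X):
    """Column-wise mean and standard deviation, ignoring NaNs."""
    return np.nanmean(X, axis=0), np.nanstd(X, axis=0)


def zscore_apply(X, mean, std):
    """Standardize X with previously computed statistics."""
    return (X - mean) / std


rng = np.random.default_rng(0)
# Synthetic stand-ins for two datasets whose features differ by an offset
Xs_demo = rng.normal(loc=0.0, scale=1.0, size=(500, 2))
Xt_demo = rng.normal(loc=2.0, scale=1.0, size=(500, 2))

# Shared scaler: fit on the source, reuse the statistics for the target
mean_s, std_s = zscore_fit(Xs_demo)
Xt_shared = zscore_apply(Xt_demo, mean_s, std_s)

# Independent scalers: fit separate statistics on the target
mean_t, std_t = zscore_fit(Xt_demo)
Xt_independent = zscore_apply(Xt_demo, mean_t, std_t)

print("target mean, shared scaler:     ", Xt_shared.mean(axis=0))
print("target mean, independent scaler:", Xt_independent.mean(axis=0))
```

With the shared scaler the target keeps its offset relative to the source, so a transporter still has a shift to learn. With independent scalers both datasets are centred at zero and that offset is removed before transport.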

Transforming your data

The transporter follows the fit / transform / fit_transform conventions used by the POT package:

# Fit model
da.fit(Xs, Xt)
#transport
Xs_transp = da.transform(Xs, Xt)
plt.plot(Xt[:,0],Xt[:,1],'o', c='k', label='VISp')
plt.plot(Xs_transp[:,0],Xs_transp[:,1],'x', c='b', label='CTKE_M1 (transp)')
plt.plot(Xs[:,0],Xs[:,1],'o', c='r', label='CTKE_M1 (original)', alpha=0.1)
plt.xlabel(MMS_DATA['joint_feats'][0])
plt.ylabel(MMS_DATA['joint_feats'][1])
plt.legend()
/home/smestern/miniconda3/envs/patchotda/lib/python3.11/site-packages/ot/lp/_network_simplex.py:332: UserWarning: numItermax reached before optimality. Try to increase numItermax.
  result_code_string = check_result(result_code)

[Figure: VISp (black), transported CTKE_M1 (blue crosses), and original CTKE_M1 (faint red) plotted on the first two joint features]
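
For intuition about what a fit / transform transporter does, here is a toy one-dimensional sketch in plain NumPy. It is not how patchOTDA or POT compute transport: it relies on the fact that, for two equally sized samples with uniform weights and a convex cost, the optimal 1-D plan pairs the i-th smallest source value with the i-th smallest target value. The class name Sorted1DTransport is hypothetical.

```python
import numpy as np


class Sorted1DTransport:
    """Toy 1-D transporter (illustrative only, not part of patchOTDA or POT)."""

    def fit(self, Xs, Xt):
        Xs = np.asarray(Xs, dtype=float).ravel()
        Xt = np.asarray(Xt, dtype=float).ravel()
        if Xs.shape != Xt.shape:
            raise ValueError("this toy example needs equally sized samples")
        # Pair the i-th smallest source value with the i-th smallest target value
        self.xs_sorted_ = np.sort(Xs)
        self.xt_sorted_ = np.sort(Xt)
        return self

    def transform(self, Xs):
        Xs = np.asarray(Xs, dtype=float).ravel()
        # Piecewise-linear monotone map through the matched pairs
        return np.interp(Xs, self.xs_sorted_, self.xt_sorted_)

    def fit_transform(self, Xs, Xt):
        return self.fit(Xs, Xt).transform(Xs)


rng = np.random.default_rng(0)
source = rng.normal(loc=0.0, scale=1.0, size=200)
target = rng.normal(loc=3.0, scale=0.5, size=200)

moved = Sorted1DTransport().fit_transform(source, target)
print("source mean/std:", source.mean(), source.std())
print("moved  mean/std:", moved.mean(), moved.std())
```

The transported sample takes on the values of the target sample while keeping the rank order of the source points, which is the 1-D analogue of moving one dataset onto another.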

Tuning the transporter

patchOTDA includes automatic hyperparameter tuning powered by nevergrad. The tune() method searches over regularization parameters (and optionally across different OT methods) to find the best transport plan for your data.

Two tuning strategies are available:

| Method | Description | When to use |
| --- | --- | --- |
| 'unidirectional' | Transports Xs→Xt and scores the result with an error function (gw_dist for unsupervised, rf_clf_dist for supervised). | Default. Works well when you have a clear source→target direction. |
| 'bidirectional' | Round-trip reconstruction: Xt→Xs→Xt. Measures how well the transport preserves structure. Requires flexible_transporter=True. | When you want the tuner to also search across OT methods. |
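
To make the round-trip idea concrete, here is a minimal NumPy sketch of a reconstruction score. It is an assumption about the general technique, not patchOTDA's implementation: the two toy mappers (an affine moment-matching map and a lossy nearest-sample map) and all function names are hypothetical stand-ins for real transporters.

```python
import numpy as np


def moment_match(X, X_from, X_to):
    """Affine map sending the column means/stds of X_from onto those of X_to."""
    mu_from, sd_from = X_from.mean(axis=0), X_from.std(axis=0)
    mu_to, sd_to = X_to.mean(axis=0), X_to.std(axis=0)
    return (X - mu_from) / sd_from * sd_to + mu_to


def nearest_sample(X, X_from, X_to):
    """Lossy map: replace each row of X by its nearest row in X_to."""
    d2 = ((X[:, None, :] - X_to[None, :, :]) ** 2).sum(axis=2)
    return X_to[d2.argmin(axis=1)]


def round_trip_error(Xs, Xt, mapper):
    """Mean squared error after mapping Xt -> source domain -> target domain."""
    Xt_in_source = mapper(Xt, Xt, Xs)
    Xt_back = mapper(Xt_in_source, Xs, Xt)
    return float(np.mean((Xt_back - Xt) ** 2))


rng = np.random.default_rng(0)
Xs_demo = rng.normal(0.0, 1.0, size=(50, 2))
Xt_demo = rng.normal(2.0, 0.5, size=(200, 2))

err_affine = round_trip_error(Xs_demo, Xt_demo, moment_match)
err_nearest = round_trip_error(Xs_demo, Xt_demo, nearest_sample)
print("affine round-trip error: ", err_affine)
print("nearest round-trip error:", err_nearest)
```

The invertible affine map reconstructs the target almost exactly, while the nearest-sample map collapses many target points onto the same source point and cannot recover them. A round-trip score rewards maps of the first kind.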

Key parameters:

  • n_iter – Total number of hyperparameter evaluations.
  • n_jobs – Parallelism (default -1 uses all cores).
  • supervised – If True and labels Ys/Yt are provided, uses a classifier-based error metric.
  • error_func – Custom error function with signature (Xs, Xt, Ys, Yt) → float. Overrides the default.
# Unsupervised tuning with a fixed transporter
# tune() searches over regularization parameters for the default EMDLaplaceTransport
da_tuned = pOTDA.PatchClampOTDA()
da_tuned.tune(Xs, Xt, n_iter=10, n_jobs=2, method='unidirectional', verbose=True)

# After tuning, fit and transform with the best parameters
Xs_tuned = da_tuned.fit_transform(Xs, Xt)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Before tuning (default params)
axes[0].plot(Xt[:,0], Xt[:,1], 'o', c='k', label='VISp_Viewer')
axes[0].plot(Xs[:,0], Xs[:,1], 'o', c='r', alpha=0.3, label='CTKE_M1 (original)')
axes[0].plot(Xs_transp[:,0], Xs_transp[:,1], 'x', c='b', label='CTKE_M1 (default)')
axes[0].set_title('Default parameters')
axes[0].set_xlabel(MMS_DATA['joint_feats'][0])
axes[0].set_ylabel(MMS_DATA['joint_feats'][1])
axes[0].legend()

# After tuning
axes[1].plot(Xt[:,0], Xt[:,1], 'o', c='k', label='VISp_Viewer')
axes[1].plot(Xs[:,0], Xs[:,1], 'o', c='r', alpha=0.3, label='CTKE_M1 (original)')
axes[1].plot(Xs_transp[:,0], Xs_transp[:,1], 'x', c='b', label='CTKE_M1 (default)')
axes[1].plot(Xs_tuned[:,0], Xs_tuned[:,1], 'x', c='g', label='CTKE_M1 (tuned)')
axes[1].set_title('Tuned parameters')
axes[1].set_xlabel(MMS_DATA['joint_feats'][0])
axes[1].set_ylabel(MMS_DATA['joint_feats'][1])
axes[1].legend()

plt.tight_layout()
print(f"Best parameters found: {da_tuned.best_}")
INFO:root:Setting verbosity to INFO
INFO:patchOTDA.domainAdapt:Using default error function for unsupervised learning
INFO:patchOTDA.domainAdapt:Computing Gromov-Wasserstein distance for error function
INFO:patchOTDA.domainAdapt:Gromov-Wasserstein distance: 0.0037214547623536285
INFO:patchOTDA.domainAdapt:round 1/5 (2/10 evals) best score: 0.0037214547623536285
INFO:patchOTDA.domainAdapt:Computing Gromov-Wasserstein distance for error function
INFO:patchOTDA.domainAdapt:Computing Gromov-Wasserstein distance for error function
INFO:patchOTDA.domainAdapt:Gromov-Wasserstein distance: 0.0037214547623536285
INFO:patchOTDA.domainAdapt:Gromov-Wasserstein distance: 0.0004574917283814664
INFO:patchOTDA.domainAdapt:round 2/5 (4/10 evals) best score: 0.0004574917283814664
INFO:patchOTDA.domainAdapt:Computing Gromov-Wasserstein distance for error function
INFO:patchOTDA.domainAdapt:Computing Gromov-Wasserstein distance for error function
INFO:patchOTDA.domainAdapt:Computing Gromov-Wasserstein distance for error function
INFO:patchOTDA.domainAdapt:Gromov-Wasserstein distance: 0.0004343028676499454
INFO:patchOTDA.domainAdapt:Gromov-Wasserstein distance: 0.014632278916336902
INFO:patchOTDA.domainAdapt:round 3/5 (6/10 evals) best score: 0.0004343028676499454
INFO:patchOTDA.domainAdapt:Computing Gromov-Wasserstein distance for error function
INFO:patchOTDA.domainAdapt:Gromov-Wasserstein distance: 0.001137250131825023
INFO:patchOTDA.domainAdapt:Computing Gromov-Wasserstein distance for error function
INFO:patchOTDA.domainAdapt:Gromov-Wasserstein distance: 0.0035514806021804185
INFO:patchOTDA.domainAdapt:Gromov-Wasserstein distance: 0.000476606028069165
INFO:patchOTDA.domainAdapt:round 4/5 (8/10 evals) best score: 0.000476606028069165
INFO:patchOTDA.domainAdapt:Computing Gromov-Wasserstein distance for error function
INFO:patchOTDA.domainAdapt:Computing Gromov-Wasserstein distance for error function
INFO:patchOTDA.domainAdapt:Gromov-Wasserstein distance: 0.00047924405361193155
INFO:patchOTDA.domainAdapt:Gromov-Wasserstein distance: 0.00047924405361193155
INFO:patchOTDA.domainAdapt:round 5/5 (10/10 evals) best score: 0.00047924405361193155
INFO:patchOTDA.domainAdapt:best kwargs:
INFO:patchOTDA.domainAdapt:{'reg_type': 1.5013107289081747e-09, 'reg_lap': 4.328761281083057e-07, 'reg_src': 2.0309176209047394e-08, 'metric': 'sqeuclidean', 'norm': 'median', 'similarity_param': 842.8947368421053, 'transporter': <class 'ot.da.EMDLaplaceTransport'>}
/home/smestern/miniconda3/envs/patchotda/lib/python3.11/site-packages/ot/lp/_network_simplex.py:332: UserWarning: numItermax reached before optimality. Try to increase numItermax.
  result_code_string = check_result(result_code)


Best parameters found: {'reg_type': 1.5013107289081747e-09, 'reg_lap': 4.328761281083057e-07, 'reg_src': 2.0309176209047394e-08, 'metric': 'sqeuclidean', 'norm': 'median', 'similarity_param': 842.8947368421053}

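[Figure: two panels, 'Default parameters' and 'Tuned parameters', comparing original, default-transported, and tuned CTKE_M1 points against VISp_Viewer]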
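
The loop that tune() runs can be pictured as "sample parameters, transport, score, keep the best". The sketch below shows that loop with a plain log-uniform random search standing in for nevergrad, and a toy mean-shift transporter standing in for an OT method. Every name in it is hypothetical, and it assumes the error function receives the transported source as its first argument.

```python
import numpy as np


def shrunk_mean_shift(Xs, Xt, reg):
    """Toy transporter: shift Xs toward the mean of Xt; larger reg, smaller shift."""
    delta = Xt.mean(axis=0) - Xs.mean(axis=0)
    return Xs + delta / (1.0 + reg)


def mean_gap(Xs_transported, Xt, Ys=None, Yt=None):
    """Error function with the (Xs, Xt, Ys, Yt) -> float signature; lower is better."""
    return float(np.abs(Xs_transported.mean(axis=0) - Xt.mean(axis=0)).mean())


def random_search(Xs, Xt, error_func, n_iter=10, seed=0):
    """Evaluate n_iter log-uniformly sampled reg values and keep the best one."""
    rng = np.random.default_rng(seed)
    best = {"reg": None, "score": np.inf}
    for _ in range(n_iter):
        reg = 10.0 ** rng.uniform(-9, 1)  # sample the exponent uniformly
        score = error_func(shrunk_mean_shift(Xs, Xt, reg), Xt, None, None)
        if score < best["score"]:
            best = {"reg": reg, "score": score}
    return best


rng = np.random.default_rng(1)
Xs_demo = rng.normal(0.0, 1.0, size=(100, 3))
Xt_demo = rng.normal(1.0, 1.0, size=(100, 3))
best = random_search(Xs_demo, Xt_demo, mean_gap, n_iter=10)
print(best)
```

Here n_iter plays the same role as in tune(): it caps the total number of evaluations, so a larger value explores more of the parameter range at a proportionally higher cost.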

Flexible transporter tuning

Setting flexible_transporter=True lets the tuner search across multiple OT methods (e.g., Sinkhorn, EMD, EMDLaplace) in addition to their hyperparameters. This is useful when you're not sure which transport method is best for your data.

# Flexible tuning — search across OT methods AND their parameters
da_flex = pOTDA.PatchClampOTDA(flexible_transporter=True)
da_flex.tune(Xs, Xt, n_iter=10, n_jobs=2, method='unidirectional', verbose=True)

Xs_flex = da_flex.fit_transform(Xs, Xt)

plt.figure(figsize=(6, 5))
plt.plot(Xt[:,0], Xt[:,1], 'o', c='k', label='VISp_Viewer')
plt.plot(Xs[:,0], Xs[:,1], 'o', c='r', alpha=0.3, label='CTKE_M1 (original)')
plt.plot(Xs_flex[:,0], Xs_flex[:,1], 'x', c='g', label='CTKE_M1 (flex tuned)')
plt.title('Flexible transporter tuning')
plt.xlabel(MMS_DATA['joint_feats'][0])
plt.ylabel(MMS_DATA['joint_feats'][1])
plt.legend()

print(f"Best parameters found: {da_flex.best_}")
print(f"Selected transporter: {da_flex.inittransporter}")
INFO:patchOTDA.domainAdapt:Initialized PatchClampOTDA with transporter: <class 'ot.da.EMDLaplaceTransport'>
INFO:root:Setting verbosity to INFO
INFO:patchOTDA.domainAdapt:Using default error function for unsupervised learning
INFO:patchOTDA.domainAdapt:Computing Gromov-Wasserstein distance for error function
INFO:patchOTDA.domainAdapt:Gromov-Wasserstein distance: 0.0004627458159726853
INFO:patchOTDA.domainAdapt:round 1/5 (2/10 evals) best score: 0.0004627458159726853
INFO:patchOTDA.domainAdapt:Computing Gromov-Wasserstein distance for error function
INFO:patchOTDA.domainAdapt:Gromov-Wasserstein distance: 0.00372145476235363
INFO:patchOTDA.domainAdapt:round 2/5 (4/10 evals) best score: 0.00372145476235363
INFO:patchOTDA.domainAdapt:Computing Gromov-Wasserstein distance for error function
INFO:patchOTDA.domainAdapt:Gromov-Wasserstein distance: 0.0037214547623536363
INFO:patchOTDA.domainAdapt:Computing Gromov-Wasserstein distance for error function
INFO:patchOTDA.domainAdapt:round 3/5 (6/10 evals) best score: 0.0037214547623536363
INFO:patchOTDA.domainAdapt:Computing Gromov-Wasserstein distance for error function
INFO:patchOTDA.domainAdapt:Computing Gromov-Wasserstein distance for error function
INFO:patchOTDA.domainAdapt:Gromov-Wasserstein distance: 0.0037214547623536363
INFO:patchOTDA.domainAdapt:Gromov-Wasserstein distance: 0.000479244053611932
INFO:patchOTDA.domainAdapt:round 4/5 (8/10 evals) best score: 0.000479244053611932
INFO:patchOTDA.domainAdapt:Computing Gromov-Wasserstein distance for error function
INFO:patchOTDA.domainAdapt:Computing Gromov-Wasserstein distance for error function
INFO:patchOTDA.domainAdapt:Gromov-Wasserstein distance: 0.0037214547623536363
INFO:patchOTDA.domainAdapt:Gromov-Wasserstein distance: 0.000479244053611932
INFO:patchOTDA.domainAdapt:round 5/5 (10/10 evals) best score: 0.000479244053611932
INFO:patchOTDA.domainAdapt:best kwargs:
INFO:patchOTDA.domainAdapt:{'reg_type': 6.074424725526647e-08, 'reg_lap': 3.1187165404930406e-07, 'reg_src': 0.004369678030757057, 'metric': 'sqeuclidean', 'norm': 'max', 'similarity_param': 148.01980198019803, 'transporter': <class 'ot.da.EMDLaplaceTransport'>}
/home/smestern/miniconda3/envs/patchotda/lib/python3.11/site-packages/ot/da.py:1366: RuntimeWarning: invalid value encountered in divide
  transp_Xs_ = nx.dot(K, self.xt_) / nx.sum(K, axis=1)[:, None]
/home/smestern/miniconda3/envs/patchotda/lib/python3.11/site-packages/ot/da.py:1366: RuntimeWarning: invalid value encountered in divide
  transp_Xs_ = nx.dot(K, self.xt_) / nx.sum(K, axis=1)[:, None]
INFO:patchOTDA.domainAdapt:Computing Gromov-Wasserstein distance for error function
/home/smestern/miniconda3/envs/patchotda/lib/python3.11/site-packages/ot/lp/_network_simplex.py:332: UserWarning: Problem unbounded
  result_code_string = check_result(result_code)
/home/smestern/miniconda3/envs/patchotda/lib/python3.11/site-packages/ot/lp/_network_simplex.py:332: UserWarning: numItermax reached before optimality. Try to increase numItermax.
  result_code_string = check_result(result_code)


Best parameters found: {'reg_type': 6.074424725526647e-08, 'reg_lap': 3.1187165404930406e-07, 'reg_src': 0.004369678030757057, 'metric': 'sqeuclidean', 'norm': 'max', 'similarity_param': 148.01980198019803}
Selected transporter: <class 'ot.da.EMDLaplaceTransport'>

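[Figure: 'Flexible transporter tuning', showing VISp_Viewer, original CTKE_M1, and flex-tuned CTKE_M1 on the first two joint features]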

Custom error functions

You can also pass your own error function to tune(). The function must accept (Xs, Xt, Ys, Yt) and return a scalar distance (lower is better). Several built-in metrics are available under pOTDA.metrics:

| Metric | Type | Description |
| --- | --- | --- |
| gw_dist | Unsupervised | Gromov-Wasserstein distance between transported and target distributions |
| rf_clf_dist | Supervised | 1 − balanced accuracy of a Random Forest trained on Xt predicting Ys |
| normalized_mse | Unsupervised | Normalized mean squared error |

# Example: tuning with a custom error function
import numpy as np

def my_custom_error(Xs, Xt, Ys, Yt):
    """Custom error: mean absolute difference of column-wise standard deviations."""
    if np.all(np.isnan(Xs)) or np.all(Xs == 0):
        return 9e5  # penalty for degenerate solutions
    return np.mean(np.abs(np.nanstd(Xs, axis=0) - np.nanstd(Xt, axis=0)))

da_custom = pOTDA.PatchClampOTDA()
da_custom.tune(Xs, Xt, n_iter=10, n_jobs=2, method='unidirectional',
               error_func=my_custom_error, verbose=True)

Xs_custom = da_custom.fit_transform(Xs, Xt)
print(f"Best parameters: {da_custom.best_}")
INFO:patchOTDA.domainAdapt:Initialized PatchClampOTDA with transporter: <class 'ot.da.EMDLaplaceTransport'>
INFO:root:Setting verbosity to INFO
INFO:patchOTDA.domainAdapt:Using user provided error function
INFO:patchOTDA.domainAdapt:round 1/5 (2/10 evals) best score: 0.08340754943063107
INFO:patchOTDA.domainAdapt:round 2/5 (4/10 evals) best score: 0.04395830422290997
INFO:patchOTDA.domainAdapt:round 3/5 (6/10 evals) best score: 0.04614006066698104
INFO:patchOTDA.domainAdapt:Computing Gromov-Wasserstein distance for error function
INFO:patchOTDA.domainAdapt:round 4/5 (8/10 evals) best score: 0.0827881833917348
INFO:patchOTDA.domainAdapt:round 5/5 (10/10 evals) best score: 0.05331896717248962
INFO:patchOTDA.domainAdapt:best kwargs:
INFO:patchOTDA.domainAdapt:{'reg_type': 1.9684194472866142e-05, 'reg_lap': 5.336699231206318e-08, 'reg_src': 5.878016072274918e-10, 'metric': 'sqeuclidean', 'norm': 'max', 'similarity_param': 325.4210526315789, 'transporter': <class 'ot.da.EMDLaplaceTransport'>}
/home/smestern/miniconda3/envs/patchotda/lib/python3.11/site-packages/ot/lp/_network_simplex.py:332: UserWarning: Problem unbounded
  result_code_string = check_result(result_code)
/home/smestern/miniconda3/envs/patchotda/lib/python3.11/site-packages/ot/lp/_network_simplex.py:332: UserWarning: numItermax reached before optimality. Try to increase numItermax.
  result_code_string = check_result(result_code)


Best parameters: {'reg_type': 1.9684194472866142e-05, 'reg_lap': 5.336699231206318e-08, 'reg_src': 5.878016072274918e-10, 'metric': 'sqeuclidean', 'norm': 'max', 'similarity_param': 325.4210526315789}
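
When labels are available, a custom error function can also make use of Ys and Yt. The sketch below compares per-class centroids. It assumes Ys and Yt are 1-D label arrays aligned with the rows of Xs and Xt, and that Xs is the transported source; the function name and the demo data are hypothetical.

```python
import numpy as np


def class_centroid_error(Xs, Xt, Ys, Yt):
    """Mean distance between per-class centroids of Xs and Xt (lower is better)."""
    if Ys is None or Yt is None:
        raise ValueError("this metric needs labels for both datasets")
    Ys, Yt = np.asarray(Ys), np.asarray(Yt)
    shared = np.intersect1d(Ys, Yt)  # labels present in both datasets
    if shared.size == 0 or np.any(np.isnan(Xs)):
        return 9e5  # penalty for degenerate solutions, as in the example above
    gaps = [
        np.linalg.norm(Xs[Ys == label].mean(axis=0) - Xt[Yt == label].mean(axis=0))
        for label in shared
    ]
    return float(np.mean(gaps))


rng = np.random.default_rng(0)
Xt_demo = rng.normal(size=(60, 2))
Yt_demo = np.repeat(["A", "B", "C"], 20)
# Source is the target shifted by one unit along the first feature
Xs_demo = Xt_demo + np.array([1.0, 0.0])
Ys_demo = Yt_demo.copy()

print(class_centroid_error(Xs_demo, Xt_demo, Ys_demo, Yt_demo))
```

A metric like this penalizes transports that align the overall distributions but mix up cell classes, which an unsupervised distance cannot detect.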
