A reproducible framework for protein function prediction using sequence and GO-derived embeddings, with controlled data splitting, model sweeps, ensemble evaluation, and benchmark reporting.
Important: protein-embedding-classifier does not generate embeddings. It consumes precomputed embeddings generated by external models (for example ESM3c, ProtT5, Ankh3, and GeOKG) and uses them as feature vectors for supervised classification.
protein-embedding-classifier is a research-oriented pipeline for supervised protein function classification from precomputed embeddings.
- Build train/validation/test/zero-shot partitions from accession-level data.
- Load multiple embedding views (sequence and GO embeddings).
- Train and tune multiple classifiers per embedding view.
- Compare single-model and ensemble strategies under consistent evaluation boundaries.
- Export reproducibility and reporting artifacts suitable for scientific analysis.
Predicting protein function is constrained by sequence diversity, taxonomic drift, and label incompleteness. A robust evaluation workflow must:
- Separate model selection from final reporting.
- Quantify generalization on both in-distribution and strict holdout data.
- Prevent leakage across split boundaries.
Embeddings provide dense, transferable representations that can capture biochemical and evolutionary signal. This repository supports:
- Sequence embedding views from database-backed sources.
- GO embedding views from file-based sources.
- Comparable downstream classifiers over the same accession-level split definitions.
Different model/embedding combinations capture complementary errors. The framework includes soft-voting and majority-voting variants to test whether combining predictors improves robustness relative to best single models.
- Data ingestion
`ProteinLoader` retrieves the accession list + metadata (e.g., organism). `LabelLoader` loads labels from CSV or SQL and aligns them to the accessions available in the DB.
- Embedding loading
`SequenceEmbeddingLoader` loads sequence embeddings by model/layer from the DB. `GOEmbeddingLoader` loads BP/MF/CC embeddings from CSV and concatenates the ontology vectors. `EmbeddingBundle` materializes split-specific matrices for each embedding view.
- Label handling (single vs. multilabel)
`ProblemSpecification.from_labels` infers `binary`, `multiclass`, or `multilabel`. Multilabel targets are binarized with `MultiLabelBinarizer` (see the sketch after this list).
- Split management (`SplitManager`)
`IndependentValidationTrainTestSplit` applies validation, train/test, and zero-shot logic with overlap/coverage checks.
- Training pipeline
`TrainingService` trains one model per `(classifier, embedding)` pair and computes metrics on validation (and optionally test).
- Ensemble logic
`SoftVotingService` handles weighted soft voting and majority variants using persisted model artifacts.
- Benchmark orchestration
`Pipeline.run_benchmark_step` compares the best single model vs. ensemble variants across seeds/ablations and exports summary artifacts.
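A minimal sketch of the label-handling step above, using scikit-learn's `MultiLabelBinarizer`; the inference helper is illustrative, not the exact `ProblemSpecification.from_labels` implementation:

```python
from sklearn.preprocessing import MultiLabelBinarizer

def infer_problem_type(labels):
    """Illustrative stand-in for ProblemSpecification.from_labels."""
    if any(isinstance(y, (list, set, tuple)) for y in labels):
        return "multilabel"
    return "binary" if len(set(labels)) == 2 else "multiclass"

labels = [["GO:0001", "GO:0002"], ["GO:0001"], ["GO:0003"]]
if infer_problem_type(labels) == "multilabel":
    mlb = MultiLabelBinarizer()
    Y = mlb.fit_transform(labels)  # shape: (n_proteins, n_distinct_labels)
```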
[DB + CSV Inputs]
|
v
[ProteinLoader] -----> accession list + metadata
|
+----> [LabelLoader] -----> aligned labels + missing report
|
v
[IndependentValidationTrainTestSplit]
|--> train_ids
|--> val_ids
|--> test_ids
|--> zero_shot_ids (strict holdout)
v
[DatasetBundle]
|
+----> [SequenceEmbeddingLoader (DB)]
+----> [GOEmbeddingLoader (CSV)]
|
v
[EmbeddingService]
|
v
[EmbeddingBundle]
|
+--> [TrainingService / SweepService]
| |
| +--> validation metrics
| +--> final test metrics (optional)
|
+--> [SoftVotingService]
| |
| +--> uniform / validation-weighted / trainable / majority
|
+--> [Benchmark Step]
|
+--> benchmark_summary.csv/json
+--> benchmark_multiseed_summary.csv/json
+--> benchmark_ablation_summary.csv
+--> benchmark_weights_analysis.json
The implemented lifecycle is:
- Dataset loading
  - Build the accession universe and aligned labels.
- Validation split selection
  - Select validation IDs first (`random`, `organism`, or `csv`).
- Train/test splitting
  - Split the remaining IDs using `random` or `cross_validation`.
- Model training
  - Train per classifier and embedding view.
- Validation scoring
  - Compute validation metrics for model ranking and hyperparameter selection.
- Ensemble weight learning (optional)
  - Fit ensemble weights using validation probabilities only.
- Final test evaluation
  - Report test metrics on the held-out test split.
- Zero-shot evaluation (strict holdout)
  - Evaluate separately on zero-shot IDs if present.
- Validation is used for model selection and ensemble weighting.
- Test is used for performance reporting after model selection.
- Zero-shot is never used for training, threshold tuning, or weight learning.
The independent split configuration supports:
- `validation.strategy`: `random | organism | csv`
- `train_test.strategy`: `random | cross_validation`
- `zero_shot.strategy`: `random | organism | csv`
- Select validation split.
- Split remaining IDs into train/test.
- Select zero-shot IDs from full accession universe.
- Remove zero-shot IDs from train/validation/test.
- Enforce partition integrity and coverage.
IndependentValidationTrainTestSplit enforces:
- No overlap between zero-shot and train/validation/test.
- Coverage consistency: all aligned IDs must belong to exactly one of train/val/test/zero-shot.
- Duplicate accession assignments across split values in CSV are rejected.
Pipeline-level benchmark checks also assert that:
- `train ∩ zero_shot = ∅`
- `validation ∩ zero_shot = ∅`
- `test ∩ zero_shot = ∅`
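A minimal sketch of these guarantees with plain Python sets; the repository enforces them inside `IndependentValidationTrainTestSplit` and the benchmark step, so the helper below is illustrative only:

```python
def check_partition(all_ids, train, val, test, zero_shot):
    train, val, test, zero_shot = map(set, (train, val, test, zero_shot))
    # Zero-shot must be disjoint from every learning/selection split.
    for name, split in [("train", train), ("validation", val), ("test", test)]:
        assert not (split & zero_shot), f"zero-shot overlaps {name}"
    # Coverage: every aligned ID belongs to exactly one partition.
    parts = [train, val, test, zero_shot]
    assert set().union(*parts) == set(all_ids), "IDs missing from all partitions"
    assert sum(map(len, parts)) == len(set(all_ids)), "ID assigned to multiple partitions"
```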
The framework evaluates these core variants:
- Best single model
  - Selects the single `(classifier, embedding)` pair with the highest validation F1.
- Uniform soft voting
  - Equal weights across selected embedding models.
- Validation-weighted soft voting
  - Weights proportional to validation performance.
- Trainable-weight soft voting
  - Learns non-negative normalized weights via validation-driven optimization.
Additional majority-voting modes are also implemented (`majority_global`, `majority_by_embedding`, `majority_by_classifier`) for benchmark comparisons.
- Uniform: fixed `1 / n_models`.
- Validation-score-based: macro-F1-derived scores normalized into weights.
- Trainable: optimization over validation probabilities (Dirichlet sampling + an objective on validation F1).
Validation probabilities and labels are the only inputs for ensemble weight fitting.
Zero-shot is treated as deployment-like unseen data. Using it for weighting would contaminate generalization assessment and violate strict holdout principles.
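A minimal sketch of the three weighting schemes, assuming each model exposes validation probabilities of shape `(n_samples, n_classes)`; the function names are illustrative, not the `SoftVotingService` API:

```python
import numpy as np
from sklearn.metrics import f1_score

def soft_vote(probas, weights):
    """Weighted average of per-model probability matrices."""
    return np.tensordot(weights, np.stack(probas), axes=1)

def uniform_weights(n_models):
    return np.full(n_models, 1.0 / n_models)

def validation_weights(val_probas, y_val):
    """Normalize per-model validation macro-F1 scores into weights."""
    scores = np.array(
        [f1_score(y_val, p.argmax(axis=1), average="macro") for p in val_probas]
    )
    return scores / scores.sum()

def trainable_weights(val_probas, y_val, n_trials=500, seed=42):
    """Dirichlet sampling: keep the weight vector with the best validation F1."""
    rng = np.random.default_rng(seed)
    best_w, best_f1 = uniform_weights(len(val_probas)), -1.0
    for w in rng.dirichlet(np.ones(len(val_probas)), size=n_trials):
        f1 = f1_score(y_val, soft_vote(val_probas, w).argmax(axis=1), average="macro")
        if f1 > best_f1:
            best_w, best_f1 = w, f1
    return best_w
```

In all three cases only validation probabilities and labels enter the fit; test and zero-shot splits are scored with frozen weights.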
`Pipeline.run_benchmark_step`:
- Loads best persisted model artifacts from the latest sweep run.
- Rebuilds dataset/embeddings for each configured seed.
- Computes metrics for:
- best single baseline,
- configured ensemble variants.
- Computes deltas relative to best single.
- Optionally aggregates across seeds and ablations.
For each variant and seed:
- Validation metrics: accuracy/precision/recall/F1.
- Test metrics: accuracy/precision/recall/F1.
- Zero-shot metrics (if non-empty split).
- Delta vs. best single (test).
- Delta vs. best single (zero-shot).
- Generalization gap: validation F1 minus test F1.
Under the latest sweep run's `results/` directory:
- `benchmark_summary.csv`
- `benchmark_summary.json`
- `benchmark_multiseed_summary.csv`
- `benchmark_multiseed_summary.json`
- `benchmark_ablation_summary.csv`
- `benchmark_weights_analysis.json`
When multiple seeds are configured, aggregated outputs report mean/std summaries (including zero-shot F1 statistics where available).
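A minimal sketch of this aggregation with pandas, assuming a long-format table of per-seed metrics (the column names and values are illustrative):

```python
import pandas as pd

# One row per (variant, seed); columns mirror the reported metrics.
per_seed = pd.DataFrame({
    "variant": ["best_single", "best_single", "soft_uniform", "soft_uniform"],
    "seed": [1, 2, 1, 2],
    "test_f1": [0.81, 0.79, 0.83, 0.82],
    "zero_shot_f1": [0.70, 0.68, 0.74, 0.71],
})

# Mean/std per variant, in the spirit of benchmark_multiseed_summary.csv.
summary = per_seed.groupby("variant")[["test_f1", "zero_shot_f1"]].agg(["mean", "std"])
print(summary)
```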
Zero-shot evaluation is a strict holdout mechanism designed to probe out-of-distribution behavior.
- Test model behavior on data excluded from all learning and selection stages.
- Complement standard test performance with a stronger generalization stress test.
Real-world deployment often faces proteins from unseen organisms or unseen split strata. Zero-shot simulates this distribution shift.
Zero-shot IDs are removed from train/validation/test and must remain unused for:
- model fitting,
- threshold tuning,
- ensemble weight fitting.
The pipeline includes explicit reproducibility mechanisms:
- Seed handling
  - Benchmark supports per-seed reruns (`seeds` in the benchmark config).
  - Split randomness is re-seeded through a cloned dataset split config.
- Config snapshot
  - Each sweep run stores resolved pipeline/training configs and run metadata.
- Artifact reuse
  - Benchmark and ensemble consume persisted model artifacts from sweep/final-training outputs.
- Cross-validation consistency
  - `train_test.strategy: cross_validation` with explicit `n_splits`, `fold_index`, and `random_state` ensures deterministic fold selection (see the sketch after this list).
- Metadata provenance
  - Run metadata includes the git commit and package versions.
  - Benchmark records artifact hashes for integrity tracking.
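A minimal sketch of deterministic fold selection with scikit-learn, matching the `n_splits` / `fold_index` / `random_state` semantics above (the helper name is illustrative):

```python
from sklearn.model_selection import KFold

def select_fold(ids, n_splits=5, fold_index=0, random_state=42):
    """Same random_state + same fold_index always yields the same partition."""
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    train_idx, test_idx = list(kf.split(ids))[fold_index]
    return [ids[i] for i in train_idx], [ids[i] for i in test_idx]

train_ids, test_ids = select_fold([f"P{i:05d}" for i in range(100)])
```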
Configuration is YAML-based and split into:
- `config/pipeline.yaml` (dataset/split/step orchestration)
- `config/embeddings.yaml` (embedding sources and model toggles)
- `config/training/training_config.yaml` (wandb/final training/reporting)
- `config/model_sweep/*.yaml` (classifier-specific hyperparameter spaces)
dataset:
  db_config_path: config/db.yaml
  label_loader:
    source: file
    file_path: /path/to/dataset.csv
    accession_column: uniprot_id
    label_column: data_class
  artifacts_dir: artifacts
split:
  validation:
    strategy: csv
    csv:
      csv_path: /path/to/dataset.csv
      accession_column: uniprot_id
      split_column: data_split
      validation_values: [val, validation]
  train_test:
    strategy: cross_validation
    cross_validation:
      n_splits: 5
      fold_index: 0
      random_state: 42
  zero_shot:
    strategy: csv
    csv:
      csv_path: /path/to/dataset.csv
      accession_column: uniprot_id
      split_column: data_split
      validation_values: [zs, zero_shot]
training_config_path: config/training/training_config.yaml
experiment:
  main_seed: 42
  global_benchmark:
    n_seeds: 10

Install dependencies:
poetry install

Run classifier sweeps:
poetry run pec --step sweep
# Optional filters:
# poetry run pec --step sweep --classifier xgb
# poetry run pec --step sweep --embedding_name GeOKG

Run benchmark comparison on latest sweep artifacts:
poetry run pec --step benchmark

Run ensemble inference on latest sweep artifacts:
poetry run pec --step ensemble

Run full multi-seed global benchmark orchestration:
poetry run pec --step global_benchmark
# Optional CLI override:
# poetry run pec --step global_benchmark --n_seeds 5

- `--step sweep`: builds dataset + embeddings, runs classifier-specific sweep trials, stores best configs, and optionally performs final retraining with persisted model artifacts.
- `--step benchmark`: loads persisted best models, evaluates single vs. ensemble variants on validation/test/zero-shot, computes deltas, and writes summary artifacts.
- `--step ensemble`: loads selected persisted models, fits ensemble weighting on the validation split, predicts on the test split, and saves ensemble artifacts.
- `--step global_benchmark`: deterministically generates seeds from `experiment.main_seed`, executes `sweep -> ensemble -> benchmark` per seed, and writes results under `results/global_benchmark/executions/run_seed_<seed>/` (see the seed-derivation sketch below).
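A minimal sketch of deriving per-run seeds deterministically from the main seed; the exact derivation used by the pipeline is not documented here, so treat this as an assumption about the general approach:

```python
import random

def derive_seeds(main_seed: int, n_seeds: int) -> list[int]:
    """Same main_seed always yields the same ordered list of child seeds."""
    rng = random.Random(main_seed)  # local RNG: independent of global state
    return [rng.randrange(2**32) for _ in range(n_seeds)]

print(derive_seeds(main_seed=42, n_seeds=5))  # reproducible across runs
```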
The primary output root is defined by `reporting.output_root` (the default resolves to `../../pec_data`).
Typical layout:
pec_data/
  dataset/
    default_dataset.csv
  logs/
    <run_name>.log
  sweep/
    <run_prefix>_<timestamp>/
      configs/
        resolved_pipeline.yaml
        resolved_training.yaml
        run_metadata.json
      models/
        *.pkl / *.pt
        *.metadata.json
        ensemble_model.pkl (if ensemble step)
      reports/
        sweep_results_full.csv
        best_per_classifier.csv
        best_classifier_per_embedding.csv
        final_test_results.csv
      predictions/
        predictions_test.csv
      results/
        benchmark_summary.csv
        benchmark_summary.json
        benchmark_multiseed_summary.csv
        benchmark_multiseed_summary.json
        benchmark_ablation_summary.csv
        benchmark_weights_analysis.json
- Add config entries in `config/embeddings.yaml`.
- If a new source type is required, implement loader logic in `core/embedding_loading.py`.
- Ensure outputs map accession -> vector and integrate into `EmbeddingService.load_embeddings()` (see the sketch below).
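A minimal sketch of the accession -> vector contract for a CSV-backed source; the function name and column layout are illustrative assumptions, not the repository's loader API:

```python
import numpy as np
import pandas as pd

def load_embeddings_csv(path: str, accession_column: str = "uniprot_id") -> dict[str, np.ndarray]:
    """Return one dense vector per accession; non-ID columns are treated as features."""
    df = pd.read_csv(path).set_index(accession_column)
    return {acc: row.to_numpy(dtype=np.float32) for acc, row in df.iterrows()}
```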
- Register constructor logic in `core/training/model_factory.py`.
- Add a sweep config under `config/model_sweep/`.
- Register the default sweep path in `Pipeline._resolve_classifier_sweeps`.
- Implement a new `WeightingStrategy` subclass in `core/ensemble/soft_voting_service.py`.
- Extend the strategy factory wiring (`create_weighting_strategy` and the config mapping path).
- Add a benchmark variant specification if comparative evaluation is needed (a hypothetical subclass sketch follows).
- Implement the strategy under `data/splits/`.
- Wire it through `Pipeline._build_split_strategy`.
- Add the configuration schema in `config/pipeline.yaml` and tests for the overlap/coverage guarantees.
This framework is designed to support paper-grade experimentation:
- Comparative evaluation
- Standardized scoring across multiple embeddings and classifiers.
- Ablation studies
- Benchmark ablations via selection subsets (`embeddings` / `classifiers`).
- Multi-seed experiments
- Seed-level runs with aggregated mean/std reporting.
- Robust reporting
- Explicit separation of validation, test, and zero-shot roles.
- Leakage checks and reproducibility metadata integrated into the run artifacts.
This section documents the formal statistical comparison used for paper-quality benchmarking.
- Single-seed results can be unstable due to split randomness and model stochasticity.
- Multi-seed evaluation estimates performance variability and supports inferential statistics.
- `--step global_benchmark` runs repeated experiments and stores per-seed benchmark outputs under `results/global_benchmark/executions/run_seed_<seed>/`.
dataset
↓
sweep
↓
ensemble
↓
benchmark
↓
global_benchmark (multi-seed)
↓
statistical analysis
After multi-seed aggregation, the pipeline runs statistical analysis automatically and writes:
- `results/global_benchmark/aggregated/model_embedding_benchmark.csv`
- `results/global_benchmark/aggregated/ensemble_strategy_benchmark.csv`
- `results/global_benchmark/aggregated/ranking_tables.csv`
- `results/global_benchmark/predictions/model_predictions/seed_<seed>/*.csv`
- `results/global_benchmark/predictions/ensemble_predictions/seed_<seed>/*.csv`
- `results/global_benchmark/metadata/experiment_seeds.json`
- `results/global_benchmark/metadata/experiment_config_snapshot.yaml`
- `results/global_benchmark/statistics/friedman_results.json`
- `results/global_benchmark/statistics/nemenyi_results.csv`
- `results/global_benchmark/statistics/model_rankings.csv`
- `results/global_benchmark/statistics/critical_difference_diagram.png` (optional, when plotting dependencies are available)
- Input: a complete matrix with rows = seeds and columns = model configurations.
- Metric: benchmark test F1 per configuration per seed.
- Null hypothesis: all compared configurations have equal performance distributions.
Interpretation example:
If the Friedman test p-value < 0.05, we reject the null hypothesis that all models perform equally. We then apply the Nemenyi post-hoc test to determine which model pairs differ significantly.
- Executed when Friedman is significant (`p < 0.05`).
- Performs pairwise comparisons between configurations.
- Reports pairwise p-values and critical-difference-based significance.
- Statistical comparison is based on ranks per seed, not only raw metric magnitudes.
- This reduces sensitivity to absolute scale shifts between runs.
- Lower average rank indicates better overall performance across seeds.
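A minimal sketch of this analysis using `scipy` and the `scikit-posthocs` package, assuming a complete seeds x configurations F1 matrix as described above (configuration names and values are illustrative):

```python
import pandas as pd
import scikit_posthocs as sp
from scipy.stats import friedmanchisquare

# Rows = seeds, columns = model configurations (benchmark test F1 per seed).
f1 = pd.DataFrame(
    {
        "esm_xgb": [0.81, 0.79, 0.80],
        "prott5_svm": [0.77, 0.76, 0.78],
        "ensemble": [0.84, 0.83, 0.85],
    },
    index=[1, 2, 3],  # seeds
)

stat, p_value = friedmanchisquare(*[f1[c] for c in f1.columns])
avg_rank = f1.rank(axis=1, ascending=False).mean()  # lower = better

if p_value < 0.05:  # run the post-hoc test only when Friedman rejects
    nemenyi_p = sp.posthoc_nemenyi_friedman(f1.to_numpy())
    print(nemenyi_p)
print(p_value, avg_rank.sort_values())
```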
- `avg_rank` in `model_rankings.csv`: lower = better overall configuration.
- `significant` in `nemenyi_results.csv`: `True` means the rank difference exceeds the Nemenyi critical difference at the configured alpha.
- `friedman_results.json`: `p_value < 0.05` indicates at least one configuration differs significantly from the others.
- Sequence and GO embeddings are actively loaded by `EmbeddingService`.
- `structurePE` appears in configuration but is not currently integrated into the active embedding loading path.
- Benchmark and ensemble steps operate on persisted sweep/final-training artifacts from the latest run directory.