protein-embedding-classifier

A reproducible framework for protein function prediction using sequence and GO-derived embeddings, with controlled data splitting, model sweeps, ensemble evaluation, and benchmark reporting.

Important: protein-embedding-classifier does not generate embeddings. It consumes precomputed embeddings generated by external models (for example ESM3c, ProtT5, Ankh3, and GeOKG) and uses them as feature vectors for supervised classification.


1. Project Overview

protein-embedding-classifier is a research-oriented pipeline for supervised protein function classification from precomputed embeddings.

High-level system purpose

  • Build train/validation/test/zero-shot partitions from accession-level data.
  • Load multiple embedding views (sequence and GO embeddings).
  • Train and tune multiple classifiers per embedding view.
  • Compare single-model and ensemble strategies under consistent evaluation boundaries.
  • Export reproducibility and reporting artifacts suitable for scientific analysis.

Biological motivation

Predicting protein function is constrained by sequence diversity, taxonomic drift, and label incompleteness. A robust evaluation workflow must:

  • Separate model selection from final reporting.
  • Quantify generalization on both in-distribution and strict holdout data.
  • Prevent leakage across split boundaries.

Why embedding-based models

Embeddings provide dense, transferable representations that can capture biochemical and evolutionary signal. This repository supports:

  • Sequence embedding views from database-backed sources.
  • GO embedding views from file-based sources.
  • Comparable downstream classifiers over the same accession-level split definitions.

Why ensemble methods

Different model/embedding combinations tend to make complementary errors. The framework includes soft-voting and majority-voting variants to test whether combining predictors improves robustness relative to the best single models.


2. System Architecture

Components

  • Data ingestion
    • ProteinLoader retrieves accession list + metadata (e.g., organism).
    • LabelLoader loads labels from CSV or SQL and aligns labels to available DB accessions.
  • Embedding loading
    • SequenceEmbeddingLoader loads sequence embeddings by model/layer from DB.
    • GOEmbeddingLoader loads BP/MF/CC embeddings from CSV and concatenates ontology vectors.
    • EmbeddingBundle materializes split-specific matrices for each embedding view.
  • Label handling (single vs multilabel)
    • ProblemSpecification.from_labels infers binary, multiclass, or multilabel.
    • Multilabel targets are binarized with MultiLabelBinarizer (see the sketch after this component list).
  • SplitManager
    • IndependentValidationTrainTestSplit applies validation, train/test, and zero-shot logic with overlap/coverage checks.
  • Training pipeline
    • TrainingService trains per (classifier, embedding) and computes metrics on validation (and optionally test).
  • Ensemble logic
    • SoftVotingService handles weighted soft voting and majority variants using persisted model artifacts.
  • Benchmark orchestration
    • Pipeline.run_benchmark_step compares best single vs ensemble variants across seeds/ablations and exports summary artifacts.
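
As a minimal sketch of the label-handling step, assuming labels arrive as lists of GO terms per accession (the variable names and values below are illustrative, not the repository's API):

from sklearn.preprocessing import MultiLabelBinarizer

# Example accession -> GO-term mapping (illustrative values only).
labels = {
    "P12345": ["GO:0003824", "GO:0008152"],
    "Q67890": ["GO:0008152"],
    "A11111": ["GO:0005515", "GO:0003824"],
}

accessions = sorted(labels)
mlb = MultiLabelBinarizer()
Y = mlb.fit_transform([labels[acc] for acc in accessions])

# Y is an (n_accessions, n_classes) binary indicator matrix;
# mlb.classes_ fixes the column order used by downstream multilabel classifiers.
print(mlb.classes_)
print(Y)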

Text architecture diagram

[DB + CSV Inputs]
      |
      v
[ProteinLoader] -----> accession list + metadata
      |
      +----> [LabelLoader] -----> aligned labels + missing report
      |
      v
[IndependentValidationTrainTestSplit]
      |--> train_ids
      |--> val_ids
      |--> test_ids
      |--> zero_shot_ids (strict holdout)
      v
[DatasetBundle]
      |
      +----> [SequenceEmbeddingLoader (DB)]
      +----> [GOEmbeddingLoader (CSV)]
                 |
                 v
            [EmbeddingService]
                 |
                 v
            [EmbeddingBundle]
                 |
                 +--> [TrainingService / SweepService]
                 |       |
                 |       +--> validation metrics
                 |       +--> final test metrics (optional)
                 |
                 +--> [SoftVotingService]
                 |       |
                 |       +--> uniform / validation-weighted / trainable / majority
                 |
                 +--> [Benchmark Step]
                         |
                         +--> benchmark_summary.csv/json
                         +--> benchmark_multiseed_summary.csv/json
                         +--> benchmark_ablation_summary.csv
                         +--> benchmark_weights_analysis.json

3. Experimental Design Logic

The implemented lifecycle is:

  1. Dataset loading
    • Build accession universe and aligned labels.
  2. Validation split selection
    • Select validation IDs first (random, organism, or csv).
  3. Train/Test splitting
    • Split the remaining IDs using random or cross_validation.
  4. Model training
    • Train per classifier and embedding view.
  5. Validation scoring
    • Compute validation metrics for model ranking and hyperparameter selection.
  6. Ensemble weight learning (optional)
    • Fit ensemble weights using validation probabilities only.
  7. Final test evaluation
    • Report test metrics on held-out test split.
  8. Zero-shot evaluation (strict holdout)
    • Evaluate separately on zero-shot IDs if present.

Evaluation boundaries

  • Validation is used for model selection and ensemble weighting.
  • Test is used for performance reporting after model selection.
  • Zero-shot is never used for training, threshold tuning, or weight learning.

4. Split Strategy System

The independent split configuration supports:

  • validation.strategy: random | organism | csv
  • train_test.strategy: random | cross_validation
  • zero_shot.strategy: random | organism | csv

Order of application

  1. Select validation split.
  2. Split remaining IDs into train/test.
  3. Select zero-shot IDs from the full accession universe.
  4. Remove zero-shot IDs from train/validation/test.
  5. Enforce partition integrity and coverage.

Leakage safeguards philosophy

IndependentValidationTrainTestSplit enforces:

  • No overlap between zero-shot and train/validation/test.
  • Coverage consistency: all aligned IDs must belong to exactly one of train/val/test/zero-shot.
  • CSV-based splits reject accessions assigned to more than one split value.

Pipeline-level benchmark checks also assert that:

  • train ∩ zero_shot = ∅
  • validation ∩ zero_shot = ∅
  • test ∩ zero_shot = ∅
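
A self-contained sketch of these disjointness and coverage checks over accession ID sets (the set contents are illustrative; the real checks live in IndependentValidationTrainTestSplit and the benchmark step):

# Illustrative split-integrity checks over accession ID sets.
train_ids = {"P1", "P2", "P3"}
val_ids = {"P4"}
test_ids = {"P5"}
zero_shot_ids = {"P6"}
all_ids = {"P1", "P2", "P3", "P4", "P5", "P6"}

# No overlap between zero-shot and any learning/selection split.
for name, ids in [("train", train_ids), ("validation", val_ids), ("test", test_ids)]:
    assert not (ids & zero_shot_ids), f"zero-shot leakage into {name} split"

# Coverage: every aligned accession belongs to exactly one partition.
partitions = [train_ids, val_ids, test_ids, zero_shot_ids]
assert set().union(*partitions) == all_ids, "some accessions are unassigned"
assert sum(len(p) for p in partitions) == len(all_ids), "duplicate assignments across partitions"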

5. Ensemble Decision System

The framework evaluates these core variants:

  1. Best single model
    • Selects the single (classifier, embedding) pair with the highest validation F1.
  2. Uniform soft voting
    • Equal weights across selected embedding models.
  3. Validation-weighted soft voting
    • Weights proportional to validation performance.
  4. Trainable weight soft voting
    • Learns non-negative normalized weights via validation-driven optimization.

Additional majority-voting modes are also implemented (majority_global, majority_by_embedding, majority_by_classifier) for benchmark comparisons.

How weights are computed

  • Uniform: fixed 1 / n_models.
  • Validation-score-based: macro-F1-derived scores normalized into weights.
  • Trainable: optimization over validation probabilities (Dirichlet sampling + objective on validation F1).
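
A compact sketch of the three schemes over per-model validation probabilities, scored with macro F1 (the data below is synthetic and the search loop is a simplification of the repository's optimizer):

import numpy as np
from sklearn.metrics import f1_score

rng = np.random.default_rng(42)
n_models, n_samples = 3, 200

# Synthetic validation data: per-model positive-class probabilities and true labels.
y_val = rng.integers(0, 2, size=n_samples)
val_probs = np.clip(y_val + rng.normal(0, 0.4, size=(n_models, n_samples)), 0, 1)

# 1. Uniform soft voting: fixed 1 / n_models.
w_uniform = np.full(n_models, 1.0 / n_models)

# 2. Validation-score-based: macro F1 per model, normalized into weights.
scores = np.array([
    f1_score(y_val, (p >= 0.5).astype(int), average="macro") for p in val_probs
])
w_validation = scores / scores.sum()

# 3. Trainable: sample candidate weights from a Dirichlet and keep the best
#    validation macro F1 of the weighted soft vote.
best_w, best_f1 = w_uniform, -1.0
for w in rng.dirichlet(np.ones(n_models), size=500):
    y_pred = ((w @ val_probs) >= 0.5).astype(int)
    f1 = f1_score(y_val, y_pred, average="macro")
    if f1 > best_f1:
        best_w, best_f1 = w, f1

print(w_uniform, w_validation, best_w, best_f1)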

When validation is used

Validation probabilities and labels are the only inputs for ensemble weight fitting.

Why zero-shot is excluded from weighting

Zero-shot is treated as deployment-like unseen data. Using it for weighting would contaminate generalization assessment and violate strict holdout principles.


6. Benchmark Step

What --step benchmark does

Pipeline.run_benchmark_step:

  • Loads best persisted model artifacts from the latest sweep run.
  • Rebuilds dataset/embeddings for each configured seed.
  • Computes metrics for:
    • best single baseline,
    • configured ensemble variants.
  • Computes deltas relative to best single.
  • Optionally aggregates across seeds and ablations.

What is computed

For each variant and seed:

  • Validation metrics: accuracy/precision/recall/F1.
  • Test metrics: accuracy/precision/recall/F1.
  • Zero-shot metrics (if non-empty split).
  • Delta vs Best Single (Test).
  • Delta vs Best Single (Zero-Shot).
  • Generalization gap: validation F1 minus test F1.
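
As a minimal illustration of the delta and gap arithmetic (all values made up):

# Illustrative benchmark bookkeeping for one seed.
best_single = {"val_f1": 0.842, "test_f1": 0.815, "zero_shot_f1": 0.701}
ensemble = {"val_f1": 0.851, "test_f1": 0.828, "zero_shot_f1": 0.719}

delta_test = ensemble["test_f1"] - best_single["test_f1"]            # Delta vs Best Single (Test)
delta_zero_shot = ensemble["zero_shot_f1"] - best_single["zero_shot_f1"]
generalization_gap = ensemble["val_f1"] - ensemble["test_f1"]         # validation F1 minus test F1

print(f"{delta_test:+.3f} {delta_zero_shot:+.3f} {generalization_gap:+.3f}")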

Exported outputs

Under the latest sweep run results/ directory:

  • benchmark_summary.csv
  • benchmark_summary.json
  • benchmark_multiseed_summary.csv
  • benchmark_multiseed_summary.json
  • benchmark_ablation_summary.csv
  • benchmark_weights_analysis.json

Multi-seed aggregation

When multiple seeds are configured, aggregated outputs report mean/std summaries (including zero-shot F1 statistics where available).
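
A sketch of this aggregation with pandas; the column names are illustrative and need not match benchmark_multiseed_summary.csv exactly:

import pandas as pd

# Illustrative per-seed benchmark rows.
rows = pd.DataFrame({
    "variant": ["best_single", "best_single", "uniform_soft_voting", "uniform_soft_voting"],
    "seed": [42, 43, 42, 43],
    "test_f1": [0.815, 0.808, 0.828, 0.821],
    "zero_shot_f1": [0.701, 0.695, 0.719, 0.707],
})

# Mean/std per variant across seeds.
summary = (
    rows.groupby("variant")[["test_f1", "zero_shot_f1"]]
    .agg(["mean", "std"])
    .round(4)
)
print(summary)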


7. Zero-Shot Evaluation Philosophy

Zero-shot evaluation is a strict holdout mechanism designed to probe out-of-distribution behavior.

Purpose

  • Test model behavior on data excluded from all learning and selection stages.
  • Complement standard test performance with a stronger generalization stress test.

Why it simulates deployment

Real-world deployment often faces proteins from unseen organisms or unseen split strata. Zero-shot simulates this distribution shift.

Isolation requirement

Zero-shot IDs are removed from train/validation/test and must remain unused for:

  • model fitting,
  • threshold tuning,
  • ensemble weight fitting.

8. Reproducibility Guarantees

The pipeline includes explicit reproducibility mechanisms:

  • Seed handling
    • Benchmark supports per-seed reruns (seeds in benchmark config).
    • Split randomness is re-seeded through a cloned dataset split configuration.
  • Config snapshot
    • Sweep run stores resolved pipeline/training configs and run metadata.
  • Artifact reuse
    • Benchmark and ensemble consume persisted model artifacts from sweep/final training outputs.
  • Cross-validation consistency
    • train_test.strategy: cross_validation with explicit n_splits, fold_index, random_state ensures deterministic fold selection (see the sketch after this list).
  • Metadata provenance
    • Run metadata includes git commit and package versions.
    • Benchmark records artifact hashes for integrity tracking.
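
A small sketch of how an explicit n_splits / fold_index / random_state triple pins down one fold, shown here with scikit-learn's KFold (the repository may wrap this differently):

from sklearn.model_selection import KFold

accessions = [f"ACC{i:03d}" for i in range(20)]

n_splits, fold_index, random_state = 5, 0, 42
kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)

# Identical parameters always reproduce the same fold assignment.
train_idx, test_idx = list(kf.split(accessions))[fold_index]
train_ids = [accessions[i] for i in train_idx]
test_ids = [accessions[i] for i in test_idx]
print(len(train_ids), len(test_ids))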

9. Configuration System

Configuration is YAML-based and split into:

  • config/pipeline.yaml (dataset/split/step orchestration)
  • config/embeddings.yaml (embedding sources and model toggles)
  • config/training/training_config.yaml (wandb/final training/reporting)
  • config/model_sweep/*.yaml (classifier-specific hyperparameter spaces)

Example configuration snippet

dataset:
  db_config_path: config/db.yaml
  label_loader:
    source: file
    file_path: /path/to/dataset.csv
    accession_column: uniprot_id
    label_column: data_class
    artifacts_dir: artifacts
  split:
    validation:
      strategy: csv
      csv:
        csv_path: /path/to/dataset.csv
        accession_column: uniprot_id
        split_column: data_split
        validation_values: [val, validation]
    train_test:
      strategy: cross_validation
      cross_validation:
        n_splits: 5
        fold_index: 0
        random_state: 42
    zero_shot:
      strategy: csv
      csv:
        csv_path: /path/to/dataset.csv
        accession_column: uniprot_id
        split_column: data_split
        validation_values: [zs, zero_shot]

training_config_path: config/training/training_config.yaml

experiment:
  main_seed: 42
  global_benchmark:
    n_seeds: 10
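
For ad-hoc inspection, such a file can be loaded with PyYAML, assuming it is installed (the pipeline resolves configuration internally; this is only a convenience sketch):

import yaml

with open("config/pipeline.yaml") as fh:
    cfg = yaml.safe_load(fh)

# Navigate the nested keys shown above.
split_cfg = cfg["dataset"]["split"]
print(split_cfg["validation"]["strategy"])          # e.g. "csv"
print(split_cfg["train_test"]["cross_validation"])  # n_splits / fold_index / random_state
print(cfg["experiment"]["global_benchmark"]["n_seeds"])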

10. CLI Usage

Install dependencies:

poetry install

Run classifier sweeps:

poetry run pec --step sweep
# Optional filters:
# poetry run pec --step sweep --classifier xgb
# poetry run pec --step sweep --embedding_name GeOKG

Run benchmark comparison on latest sweep artifacts:

poetry run pec --step benchmark

Run ensemble inference on latest sweep artifacts:

poetry run pec --step ensemble

Run full multi-seed global benchmark orchestration:

poetry run pec --step global_benchmark
# Optional CLI override:
# poetry run pec --step global_benchmark --n_seeds 5

What each step does internally

  • --step sweep
    • Builds the dataset and embeddings, runs classifier-specific sweep trials, stores the best configurations, and optionally performs final retraining whose model artifacts are persisted.
  • --step benchmark
    • Loads persisted best models, evaluates single vs ensemble variants on validation/test/zero-shot, computes deltas, writes summary artifacts.
  • --step ensemble
    • Loads selected persisted models, fits ensemble weighting on validation split, predicts on test split, and saves ensemble artifacts.
  • --step global_benchmark
    • Deterministically generates seeds from experiment.main_seed, executes sweep -> ensemble -> benchmark per seed, and writes results under results/global_benchmark/executions/run_seed_<seed>/.
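
One way such seed derivation can be made deterministic, shown for illustration only (the repository's exact scheme is not documented here):

import numpy as np

main_seed, n_seeds = 42, 10

# Derive per-run seeds reproducibly from the main seed.
rng = np.random.default_rng(main_seed)
seeds = rng.integers(0, 2**31 - 1, size=n_seeds).tolist()
print(seeds)  # identical list whenever main_seed and n_seeds are unchanged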

11. Output Artifacts

Primary output root is defined by reporting.output_root (default resolves to ../../pec_data).

Typical layout:

pec_data/
  dataset/
    default_dataset.csv
  logs/
    <run_name>.log
  sweep/
    <run_prefix>_<timestamp>/
      configs/
        resolved_pipeline.yaml
        resolved_training.yaml
        run_metadata.json
      models/
        *.pkl / *.pt
        *.metadata.json
        ensemble_model.pkl (if ensemble step)
      reports/
        sweep_results_full.csv
        best_per_classifier.csv
        best_classifier_per_embedding.csv
        final_test_results.csv
      predictions/
        predictions_test.csv
      results/
        benchmark_summary.csv
        benchmark_summary.json
        benchmark_multiseed_summary.csv
        benchmark_multiseed_summary.json
        benchmark_ablation_summary.csv
        benchmark_weights_analysis.json

12. Extending the System

Add a new embedding

  1. Add config entries in config/embeddings.yaml.
  2. If a new source type is required, implement loader logic in core/embedding_loading.py.
  3. Ensure outputs map accession -> vector and integrate into EmbeddingService.load_embeddings().
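
A hedged sketch of the expected loader contract, i.e. a mapping from accession to a fixed-length vector; the function name and CSV layout below are illustrative, and the real interface lives in core/embedding_loading.py:

import numpy as np
import pandas as pd

def load_my_embedding(csv_path: str) -> dict[str, np.ndarray]:
    """Illustrative loader: one row per accession, remaining columns form the vector."""
    df = pd.read_csv(csv_path, index_col=0)
    return {acc: row.to_numpy(dtype=np.float32) for acc, row in df.iterrows()}

# EmbeddingService.load_embeddings() would then stack vectors per split, e.g.:
# X_train = np.vstack([emb[acc] for acc in train_ids])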

Add a new classifier

  1. Register constructor logic in core/training/model_factory.py.
  2. Add sweep config under config/model_sweep/.
  3. Register default sweep path in Pipeline._resolve_classifier_sweeps.

Add a new ensemble strategy

  1. Implement a new WeightingStrategy subclass in core/ensemble/soft_voting_service.py.
  2. Extend strategy factory wiring (create_weighting_strategy and config mapping path).
  3. Add benchmark variant specification if comparative evaluation is needed.
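
A shape-only sketch of a new strategy; the class and method names below are assumed for illustration and should be aligned with the actual WeightingStrategy interface in core/ensemble/soft_voting_service.py:

import numpy as np
from sklearn.metrics import f1_score

class InverseErrorWeighting:  # hypothetical strategy; subclass the real WeightingStrategy
    """Weights each model by 1 / (1 - validation macro F1), normalized to sum to 1.

    The `compute_weights` name and signature are assumed here; mirror the
    existing strategies in core/ensemble/soft_voting_service.py.
    """

    def compute_weights(self, val_probs: np.ndarray, y_val: np.ndarray) -> np.ndarray:
        errors = np.array([
            1.0 - f1_score(y_val, (p >= 0.5).astype(int), average="macro")
            for p in val_probs
        ])
        inv = 1.0 / np.clip(errors, 1e-6, None)
        return inv / inv.sum()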

Add a new split strategy

  1. Implement strategy under data/splits/.
  2. Wire it through Pipeline._build_split_strategy.
  3. Add configuration schema in config/pipeline.yaml and tests for overlap/coverage guarantees.

13. Scientific Use Case

This framework is designed to support paper-grade experimentation:

  • Comparative evaluation
    • Standardized scoring across multiple embeddings and classifiers.
  • Ablation studies
    • Benchmark ablations via selection subsets (embeddings / classifiers).
  • Multi-seed experiments
    • Seed-level runs with aggregated mean/std reporting.
  • Robust reporting
    • Explicit separation of validation, test, and zero-shot roles.
    • Leakage checks and reproducibility metadata integrated into the run artifacts.

14. Statistical Model Comparison

This section documents the formal statistical comparison used for paper-quality benchmarking.

Why multiple seeds are used

  • Single-seed results can be unstable due to split randomness and model stochasticity.
  • Multi-seed evaluation estimates performance variability and supports inferential statistics.
  • --step global_benchmark runs repeated experiments and stores per-seed benchmark outputs under:
    • results/global_benchmark/executions/run_seed_<seed>/

Global benchmark workflow

dataset
  ↓
sweep
  ↓
ensemble
  ↓
benchmark
  ↓
global_benchmark (multi-seed)
  ↓
statistical analysis

Statistical artifacts

After multi-seed aggregation, the pipeline runs statistical analysis automatically and writes:

  • results/global_benchmark/aggregated/model_embedding_benchmark.csv
  • results/global_benchmark/aggregated/ensemble_strategy_benchmark.csv
  • results/global_benchmark/aggregated/ranking_tables.csv
  • results/global_benchmark/predictions/model_predictions/seed_<seed>/*.csv
  • results/global_benchmark/predictions/ensemble_predictions/seed_<seed>/*.csv
  • results/global_benchmark/metadata/experiment_seeds.json
  • results/global_benchmark/metadata/experiment_config_snapshot.yaml
  • results/global_benchmark/statistics/friedman_results.json
  • results/global_benchmark/statistics/nemenyi_results.csv
  • results/global_benchmark/statistics/model_rankings.csv
  • results/global_benchmark/statistics/critical_difference_diagram.png (optional, when plotting dependencies are available)

What the Friedman test evaluates

  • Input: a complete matrix with rows=seeds and columns=model configurations.
  • Metric: benchmark test F1 per configuration per seed.
  • Null hypothesis: all compared configurations have equal performance distributions.

Interpretation example:

If the Friedman test p-value < 0.05, we reject the null hypothesis that all models perform equally. We then apply the Nemenyi post-hoc test to determine which model pairs differ significantly.

What the Nemenyi post-hoc test evaluates

  • Executed when Friedman is significant (p < 0.05).
  • Performs pairwise comparisons between configurations.
  • Reports pairwise p-values and critical-difference-based significance.

Why average rank is used

  • Statistical comparison is based on per-seed ranks rather than raw metric magnitudes.
  • This reduces sensitivity to absolute scale shifts between runs.
  • Lower average rank indicates better overall performance across seeds.
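
A self-contained sketch of the procedure using SciPy (and scikit-posthocs for the Nemenyi step, treated here as an optional dependency); the F1 values are synthetic:

import pandas as pd
from scipy.stats import friedmanchisquare

# Illustrative test-F1 matrix: rows = seeds, columns = model configurations.
f1 = pd.DataFrame(
    {
        "xgb+GeOKG":  [0.82, 0.81, 0.83, 0.80, 0.82],
        "mlp+ProtT5": [0.79, 0.80, 0.78, 0.79, 0.80],
        "ensemble":   [0.84, 0.83, 0.85, 0.82, 0.84],
    },
    index=[f"seed_{s}" for s in range(5)],
)

# Friedman test over the complete seeds x configurations matrix.
stat, p_value = friedmanchisquare(*[f1[c] for c in f1.columns])
print(f"Friedman chi2={stat:.3f}, p={p_value:.4f}")

# Average rank per configuration (rank 1 = best F1 within each seed).
avg_rank = f1.rank(axis=1, ascending=False).mean().sort_values()
print(avg_rank)

# Nemenyi post-hoc pairwise comparison (only meaningful when p_value < 0.05).
try:
    import scikit_posthocs as sp
    print(sp.posthoc_nemenyi_friedman(f1.values))  # column order matches f1.columns
except ImportError:
    pass  # scikit-posthocs is optional in this sketch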

Rank and significance interpretation

  • avg_rank in model_rankings.csv:
    • lower = better overall configuration.
  • significant in nemenyi_results.csv:
    • True means the rank difference exceeds the Nemenyi critical difference at the configured alpha.
  • friedman_results.json:
    • p_value < 0.05 indicates at least one configuration differs significantly from the others.

Notes on Current Implementation Scope

  • Sequence and GO embeddings are actively loaded by EmbeddingService.
  • structurePE appears in configuration but is not currently integrated into the active embedding loading path.
  • Benchmark and ensemble steps operate on persisted sweep/final-training artifacts from the latest run directory.
