PyTorch implementation of Graph Matching Networks, adapted specifically for linguistic trees. Builds on Graph Matching with Bi-level Noisy Correspondence (COMMON, ICCV 2023) and Graph Matching Networks for Learning the Similarity of Graph Structured Objects (GMN, ICML 2019).

Tree Matching Networks for NLI


This repository contains an adaptation of Graph Matching Networks (GMN) for linguistic dependency trees, focused on natural language inference (NLI) and semantic similarity tasks. The idea is to represent sentences as dependency trees to capture structural information, then apply graph neural network techniques to learn relationships between sentence pairs.

Overview

This project extends Graph Matching Networks to operate on linguistic trees and includes two main components:

  1. TMN_DataGen: A package for generating and processing dependency trees from raw text
  2. Tree-Matching-Networks: The model implementation for training and inference

Project Structure

.
β”œβ”€β”€ GMN/                   # Original Graph Matching Networks code
β”œβ”€β”€ LinguisticTrees/       # My tree adaptations and training code
β”‚   β”œβ”€β”€ configs/           # Configuration files
β”‚   β”œβ”€β”€ data/              # Data loading and processing
β”‚   β”œβ”€β”€ models/            # Model architecture
β”‚   β”œβ”€β”€ training/          # Training and evaluation code
β”‚   └── experiments/       # Training and evaluation scripts
└── scripts/               # Demo and utility scripts

Installation

  1. First, install TMN_DataGen:

    git clone https://github.com/jlunder00/TMN_DataGen.git
    cd TMN_DataGen
    pip install .
  2. Then, install this repository:

    git clone https://github.com/jlunder00/Tree-Matching-Networks.git
    cd Tree-Matching-Networks
    pip install .

Required External Resources

Before using the models, you'll need:

  1. SpaCy Model: For dependency parsing.

    python -m spacy download en_core_web_sm # or en_core_web_lg/md/trf
  2. Word2Vec Vocabulary: For word boundary correction.

  3. Embedding Cache: Create directory for caching word embeddings:

    mkdir -p /path/to/embedding_cache
    • Set this path in the configuration files
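The embedding cache is a directory of precomputed word embeddings that is reused across runs. As an illustration only (the `EmbeddingCache` class below is hypothetical, not part of this repository; the actual cache format used by TMN_DataGen may differ), a disk-backed cache can be sketched like this:

```python
import hashlib
import os
import pickle

# Hypothetical sketch of a disk-backed embedding cache; the real cache
# layout used by TMN_DataGen may differ.
class EmbeddingCache:
    def __init__(self, cache_dir):
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def _path(self, word):
        # One pickle file per word, named by a stable hash so any
        # string is a safe filename.
        digest = hashlib.md5(word.encode("utf-8")).hexdigest()
        return os.path.join(self.cache_dir, digest + ".pkl")

    def get(self, word, compute_fn):
        path = self._path(word)
        if os.path.exists(path):
            with open(path, "rb") as f:
                return pickle.load(f)  # cache hit: skip the model call
        vec = compute_fn(word)         # cache miss: compute and store
        with open(path, "wb") as f:
            pickle.dump(vec, f)
        return vec
```

The point of the cache is that the (expensive) embedding model is called at most once per word; subsequent lookups read from disk.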

Quick Start

Running the Demo

Try out the model with the demo script:

python -m Tree_Matching_Networks.scripts.demo \
  --checkpoint /path/to/best_entailment_model_checkpoint/checkpoints/best_model.pt \
  --config /path/to/custom/config.yaml  \
  --input input.tsv \
  --spacy_model en_core_web_sm

Note that occasionally the config bundled with a checkpoint may not work in the demo script. In that case, pass a config override pointing to an appropriately configured custom config, or to one of the configs in Tree_Matching_Networks/LinguisticTrees/configs/experiment_configs/.

See Demo Instructions for more details.

Data Processing

To process your own data, use TMN_DataGen:

python -m TMN_DataGen.run process \
  --input_path your_data.jsonl \
  --out_dir processed_data/your_dataset \
  --dataset_type snli \
  --spacy_model en_core_web_sm

See TMN_DataGen README for more details.
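For SNLI-style input, each line of the .jsonl file is one JSON object holding a sentence pair and a label. The field names below follow the public SNLI release (`sentence1`, `sentence2`, `gold_label`); check the TMN_DataGen README for the exact schema it expects:

```python
import json

# One SNLI-style example per line of the .jsonl file. Field names
# follow the public SNLI release; verify against the TMN_DataGen README.
example = {
    "sentence1": "A man inspects the uniform of a figure.",
    "sentence2": "The man is sleeping.",
    "gold_label": "contradiction",
}
with open("your_data.jsonl", "w") as f:
    f.write(json.dumps(example) + "\n")
```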

Training

Train a model on processed data:

python -m Tree_Matching_Networks.LinguisticTrees.experiments.train_aggregative \
  --config Tree_Matching_Networks/LinguisticTrees/configs/experiment_configs/aggregative_config.yaml

See LinguisticTrees README for more configuration options.

Evaluation

Evaluate a trained model:

python -m Tree_Matching_Networks.LinguisticTrees.experiments.eval_aggregated \
  --checkpoint /path/to/checkpoint \
  --output_dir evaluation_results

Key Features

  • Tree-Based Representation: Leverages dependency trees to capture sentence structure
  • Cross-Graph Attention: Compares sentences using graph matching techniques
  • Flexible Model Configuration: Supports different tasks and training approaches
  • Contrastive Learning: Pretrain on large datasets for better transfer
  • Multiple NLP Tasks: Supports entailment, similarity, and binary classification

Model Architecture

My approach adapts Graph Matching Networks to work with linguistic trees:

  1. Text Processing: Convert sentences to dependency trees using SpaCy/DiaParser
  2. Feature Extraction: Embed words and dependency relations
  3. Graph Propagation: Use message passing to capture tree structure
  4. Graph Matching: Apply cross-graph attention to compare tree pairs
  5. Aggregation: Pool sentence trees into text-level representations
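The cross-graph attention in step 4 follows the GMN formulation: each node attends over all nodes of the other tree, and the matching signal is the difference between the node's state and its soft match in the other graph. A minimal numpy sketch, simplified from the GMN paper (dot-product similarity here; the actual model uses a learned similarity and runs this inside message passing):

```python
import numpy as np

def cross_graph_matching(h1, h2):
    """Soft cross-graph attention (simplified from GMN, Li et al. 2019).

    h1: (n1, d) node states of tree 1; h2: (n2, d) node states of tree 2.
    Returns the (n1, d) matching vectors for tree 1's nodes.
    """
    # Similarity between every node pair; the paper allows any
    # similarity function, dot product is the simplest choice.
    s = h1 @ h2.T                                   # (n1, n2)
    # Row-wise softmax: attention of each tree-1 node over tree 2.
    a = np.exp(s - s.max(axis=1, keepdims=True))
    a = a / a.sum(axis=1, keepdims=True)
    # Matching vector: node state minus its soft match in the other tree.
    return h1 - a @ h2                              # (n1, d)
```

When the two trees are identical, each node's soft match is (close to) itself, so the matching vectors shrink toward zero; structurally different trees yield large matching vectors, which is the signal the downstream layers consume.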

License

This project is MIT licensed, as found in the LICENSE file.

Acknowledgments

This project builds upon:

Yujia Li, Chenjie Gu, Thomas Dullien, Oriol Vinyals, Pushmeet Kohli. Graph Matching Networks for Learning the Similarity of Graph Structured Objects. ICML 2019. [paper]

Yijie Lin, Mouxing Yang, Jun Yu, Peng Hu, Changqing Zhang, Xi Peng. Graph Matching with Bi-level Noisy Correspondence. ICCV, 2023. [paper]
