From a49a650949c07fd87cdb33eca0f15a2a6ad10c56 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 17:04:46 +0000 Subject: [PATCH 1/3] Add graph-store PPR E2E wiring --- Makefile | 8 ++ .../e2e_hom_cora_sup_gs_ppr_task_config.yaml | 75 +++++++++++++++++++ .../graph_store/homogeneous_inference.py | 58 +++++++------- .../graph_store/homogeneous_training.py | 29 +++++-- gigl/utils/sampling.py | 41 ++++++++++ tests/e2e_tests/e2e_tests.yaml | 3 + 6 files changed, 182 insertions(+), 32 deletions(-) create mode 100644 examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml diff --git a/Makefile b/Makefile index 93ab75ffc..dab742500 100644 --- a/Makefile +++ b/Makefile @@ -270,6 +270,14 @@ run_hom_cora_sup_gs_e2e_test: --test_spec_uri="tests/e2e_tests/e2e_tests.yaml" \ --test_names="hom_cora_sup_gs_test" +run_hom_cora_sup_gs_ppr_e2e_test: compiled_pipeline_path:=${GIGL_E2E_TEST_COMPILED_PIPELINE_PATH} +run_hom_cora_sup_gs_ppr_e2e_test: compile_gigl_kubeflow_pipeline +run_hom_cora_sup_gs_ppr_e2e_test: + uv run python tests/e2e_tests/e2e_test.py \ + --compiled_pipeline_path=$(compiled_pipeline_path) \ + --test_spec_uri="tests/e2e_tests/e2e_tests.yaml" \ + --test_names="hom_cora_sup_gs_ppr_test" + run_het_dblp_sup_gs_e2e_test: compiled_pipeline_path:=${GIGL_E2E_TEST_COMPILED_PIPELINE_PATH} run_het_dblp_sup_gs_e2e_test: compile_gigl_kubeflow_pipeline run_het_dblp_sup_gs_e2e_test: diff --git a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml new file mode 100644 index 000000000..46c508819 --- /dev/null +++ b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml @@ -0,0 +1,75 @@ +# This config runs homogeneous CORA supervised training and inference in Graph Store mode +# with PPR sampling. It intentionally reuses the standard graph-store training/inference +# entrypoints, changing only the sampler args and keeping the loop short for E2E coverage. +graphMetadata: + edgeTypes: + - dstNodeType: paper + relation: cites + srcNodeType: paper + nodeTypes: + - paper +datasetConfig: + dataPreprocessorConfig: + dataPreprocessorConfigClsPath: gigl.src.mocking.mocking_assets.passthrough_preprocessor_config_for_mocked_assets.PassthroughPreprocessorConfigForMockedAssets + dataPreprocessorArgs: + mocked_dataset_name: 'cora_homogeneous_node_anchor_edge_features_user_defined_labels' +trainerConfig: + trainerArgs: + log_every_n_batch: "1" + num_neighbors: "[10, 10]" + sampler_type: "ppr" + ppr_alpha: "0.5" + ppr_eps: "0.0001" + ppr_max_nodes: "20" + ppr_neighbors_per_hop: "100" + ppr_max_fetch_iterations: "2" + sampling_workers_per_process: "2" + main_batch_size: "8" + random_batch_size: "8" + num_max_train_batches: "4" + num_val_batches: "4" + val_every_n_batch: "1" + command: python -m examples.link_prediction.graph_store.homogeneous_training + graphStoreStorageConfig: + command: python -m examples.link_prediction.graph_store.storage_main + storageArgs: + sample_edge_direction: "in" + splitter_cls_path: "gigl.utils.data_splitters.DistNodeAnchorLinkSplitter" + splitter_kwargs: >- + { + "sampling_direction": "in", + "should_convert_labels_to_edges": True, + "num_val": 0.25, + "num_test": 0.25 + } + num_server_sessions: "1" +inferencerConfig: + inferencerArgs: + log_every_n_batch: "1" + num_neighbors: "[10, 10]" + sampler_type: "ppr" + ppr_alpha: "0.5" + ppr_eps: "0.0001" + ppr_max_nodes: "20" + ppr_neighbors_per_hop: "100" + ppr_max_fetch_iterations: "2" + sampling_workers_per_inference_process: "2" + inferenceBatchSize: 256 + command: python -m examples.link_prediction.graph_store.homogeneous_inference + graphStoreStorageConfig: + command: python -m examples.link_prediction.graph_store.storage_main + storageArgs: + sample_edge_direction: "in" + num_server_sessions: "1" +sharedConfig: + shouldSkipInference: false + shouldSkipModelEvaluation: true +taskMetadata: + nodeAnchorBasedLinkPredictionTaskMetadata: + supervisionEdgeTypes: + - dstNodeType: paper + relation: cites + srcNodeType: paper +featureFlags: + should_run_glt_backend: 'True' + data_preprocessor_num_shards: '2' diff --git a/examples/link_prediction/graph_store/homogeneous_inference.py b/examples/link_prediction/graph_store/homogeneous_inference.py index 34bc2672e..5faa84b72 100644 --- a/examples/link_prediction/graph_store/homogeneous_inference.py +++ b/examples/link_prediction/graph_store/homogeneous_inference.py @@ -87,7 +87,7 @@ import sys import time from dataclasses import dataclass -from typing import Union +from typing import Optional, Union import torch import torch.multiprocessing as mp @@ -101,6 +101,7 @@ from gigl.common.utils.gcs import GcsUtils from gigl.distributed.graph_store.compute import init_compute_process from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.sampler_options import SamplerOptions from gigl.distributed.utils import get_graph_store_info from gigl.env.distributed import GraphStoreInfo from gigl.nn import LinkPredictionGNN @@ -110,16 +111,10 @@ from gigl.src.common.utils.bq import BqUtils from gigl.src.common.utils.model import load_state_dict_from_uri from gigl.src.inference.lib.assets import InferenceAssets -from gigl.utils.sampling import parse_fanout +from gigl.utils.sampling import parse_fanout, parse_sampler_options logger = Logger() -# Default number of inference processes per machine incase one isnt provided in inference args -# i.e. `local_world_size` is not provided, and we can't infer automatically. -# If there are GPUs attached to the machine, we automatically infer to setting -# LOCAL_WORLD_SIZE == # of gpus on the machine. -DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE = 4 - @dataclass(frozen=True) class InferenceProcessArgs: @@ -143,6 +138,7 @@ class InferenceProcessArgs: inference_batch_size (int): Batch size to use for inference. num_neighbors (Union[list[int], dict[EdgeType, list[int]]]): Fanout for subgraph sampling, where the ith item corresponds to the number of items to sample for the ith hop. + sampler_options (Optional[SamplerOptions]): Sampler variant. None uses k-hop sampling. sampling_workers_per_inference_process (int): Number of sampling workers per inference process. sampling_worker_shared_channel_size (str): Shared-memory buffer size (bytes) allocated for @@ -169,6 +165,7 @@ class InferenceProcessArgs: # Inference configuration inference_batch_size: int num_neighbors: Union[list[int], dict[EdgeType, list[int]]] + sampler_options: Optional[SamplerOptions] sampling_workers_per_inference_process: int sampling_worker_shared_channel_size: str log_every_n_batch: int @@ -242,6 +239,7 @@ def _inference_process( # For large-scale settings, consider setting this field to 30-60 seconds to ensure dataloaders # don't compete for memory during initialization, causing OOM process_start_gap_seconds=0, + sampler_options=args.sampler_options, ) # Initialize a LinkPredictionGNN model and load parameters from # the saved model. @@ -455,25 +453,23 @@ def _run_example_inference( if arg_local_world_size is not None: local_world_size = int(arg_local_world_size) logger.info(f"Using local_world_size from inferencer_args: {local_world_size}") - if torch.cuda.is_available() and local_world_size != torch.cuda.device_count(): - logger.warning( - f"local_world_size {local_world_size} does not match the number of GPUs {torch.cuda.device_count()}. " - "This may lead to unexpected failures with NCCL communication incase GPUs are being used for " - + "training/inference. Consider setting local_world_size to the number of GPUs." - ) else: - if torch.cuda.is_available() and torch.cuda.device_count() > 0: - # If GPUs are available, we set the local_world_size to the number of GPUs - local_world_size = torch.cuda.device_count() - logger.info( - f"Detected {local_world_size} GPUs. Thus, setting local_world_size to {local_world_size}" - ) - else: - # If no GPUs are available, we set the local_world_size to the number of inference processes per machine - logger.info( - f"No GPUs detected. Thus, setting local_world_size to `{DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE}`" - ) - local_world_size = DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE + local_world_size = cluster_info.num_processes_per_compute + logger.info( + f"Using local_world_size from cluster_info.num_processes_per_compute: {local_world_size}" + ) + if local_world_size != cluster_info.num_processes_per_compute: + raise ValueError( + f"Graph Store local_world_size={local_world_size} must match " + f"cluster_info.num_processes_per_compute=" + f"{cluster_info.num_processes_per_compute}" + ) + if torch.cuda.is_available() and local_world_size != torch.cuda.device_count(): + logger.warning( + f"local_world_size {local_world_size} does not match the number of GPUs {torch.cuda.device_count()}. " + "This may lead to unexpected failures with NCCL communication incase GPUs are being used for " + + "training/inference. Consider setting local_world_size to the number of GPUs." + ) if cluster_info.compute_node_rank == 0: gcs_utils = GcsUtils() @@ -494,6 +490,7 @@ def _run_example_inference( # Parses the fanout as a string. For the homogeneous case, the fanouts should be specified # as a string of a list of integers, such as "[10, 10]". num_neighbors = parse_fanout(inferencer_args.get("num_neighbors", "[10, 10]")) + sampler_options = parse_sampler_options(inferencer_args) # While the ideal value for `sampling_workers_per_inference_process` has been identified to be # between `2` and `4`, this may need some tuning depending on the pipeline. We default this @@ -516,6 +513,14 @@ def _run_example_inference( log_every_n_batch = int(inferencer_args.get("log_every_n_batch", "50")) + logger.info( + f"Got inference args local_world_size={local_world_size}, " + f"num_neighbors={num_neighbors}, sampler_options={sampler_options}, " + f"sampling_workers_per_inference_process={sampling_workers_per_inference_process}, " + f"sampling_worker_shared_channel_size={sampling_worker_shared_channel_size}, " + f"log_every_n_batch={log_every_n_batch}" + ) + # When using mp.spawn with `nprocs`, the first argument is implicitly set to be the process number on the current machine. inference_args = InferenceProcessArgs( local_world_size=local_world_size, @@ -528,6 +533,7 @@ def _run_example_inference( edge_feature_dim=edge_feature_dim, inference_batch_size=inference_batch_size, num_neighbors=num_neighbors, + sampler_options=sampler_options, sampling_workers_per_inference_process=sampling_workers_per_inference_process, sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, log_every_n_batch=log_every_n_batch, diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index 04340f99a..c7ae356cc 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -143,6 +143,7 @@ shutdown_compute_process, ) from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.sampler_options import SamplerOptions from gigl.distributed.utils import get_available_device, get_graph_store_info from gigl.env.distributed import GraphStoreInfo from gigl.nn import LinkPredictionGNN, RetrievalLoss @@ -158,7 +159,7 @@ from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict from gigl.utils.iterator import InfiniteIterator -from gigl.utils.sampling import parse_fanout +from gigl.utils.sampling import parse_fanout, parse_sampler_options logger = Logger() @@ -191,6 +192,7 @@ def _setup_dataloaders( split: Literal["train", "val", "test"], cluster_info: GraphStoreInfo, num_neighbors: list[int] | dict[EdgeType, list[int]], + sampler_options: Optional[SamplerOptions], sampling_workers_per_process: int, main_batch_size: int, random_batch_size: int, @@ -205,6 +207,7 @@ def _setup_dataloaders( split (Literal["train", "val", "test"]): The current split which we are loading data for. cluster_info (GraphStoreInfo): Cluster topology info for graph store mode. num_neighbors: Fanout for subgraph sampling. + sampler_options (Optional[SamplerOptions]): Sampler variant. None uses k-hop sampling. sampling_workers_per_process (int): Number of sampling workers per training/testing process. main_batch_size (int): Batch size for main dataloader with query and labeled nodes. random_batch_size (int): Batch size for random negative dataloader. @@ -240,6 +243,7 @@ def _setup_dataloaders( channel_size=sampling_worker_shared_channel_size, process_start_gap_seconds=process_start_gap_seconds, shuffle=shuffle, + sampler_options=sampler_options, ) logger.info(f"---Rank {rank} finished setting up main loader for split={split}") @@ -266,6 +270,7 @@ def _setup_dataloaders( channel_size=sampling_worker_shared_channel_size, process_start_gap_seconds=process_start_gap_seconds, shuffle=shuffle, + sampler_options=sampler_options, ) logger.info( @@ -375,6 +380,7 @@ class TrainingProcessArgs: sampling_workers_per_process (int): Number of sampling workers per training/testing process. sampling_worker_shared_channel_size (str): Shared-memory buffer size for the channel during sampling. process_start_gap_seconds (int): Time to sleep between dataloader initializations. + sampler_options (Optional[SamplerOptions]): Sampler variant. None uses k-hop sampling. main_batch_size (int): Batch size for main dataloader. random_batch_size (int): Batch size for random negative dataloader. learning_rate (float): Learning rate for the optimizer. @@ -400,6 +406,7 @@ class TrainingProcessArgs: # Sampling config num_neighbors: list[int] | dict[EdgeType, list[int]] + sampler_options: Optional[SamplerOptions] sampling_workers_per_process: int sampling_worker_shared_channel_size: str process_start_gap_seconds: int @@ -463,6 +470,7 @@ def _training_process( split="train", cluster_info=args.cluster_info, num_neighbors=args.num_neighbors, + sampler_options=args.sampler_options, sampling_workers_per_process=args.sampling_workers_per_process, main_batch_size=args.main_batch_size, random_batch_size=args.random_batch_size, @@ -481,6 +489,7 @@ def _training_process( split="val", cluster_info=args.cluster_info, num_neighbors=args.num_neighbors, + sampler_options=args.sampler_options, sampling_workers_per_process=args.sampling_workers_per_process, main_batch_size=args.main_batch_size, random_batch_size=args.random_batch_size, @@ -637,6 +646,7 @@ def _training_process( split="test", cluster_info=args.cluster_info, num_neighbors=args.num_neighbors, + sampler_options=args.sampler_options, sampling_workers_per_process=args.sampling_workers_per_process, main_batch_size=args.main_batch_size, random_batch_size=args.random_batch_size, @@ -837,13 +847,17 @@ def _run_example_training( # Training Hyperparameters trainer_args = dict(gbml_config_pb_wrapper.trainer_config.trainer_args) - if torch.cuda.is_available(): - default_local_world_size = torch.cuda.device_count() - else: - default_local_world_size = 2 local_world_size = int( - trainer_args.get("local_world_size", str(default_local_world_size)) + trainer_args.get( + "local_world_size", str(cluster_info.num_processes_per_compute) + ) ) + if local_world_size != cluster_info.num_processes_per_compute: + raise ValueError( + f"Graph Store local_world_size={local_world_size} must match " + f"cluster_info.num_processes_per_compute=" + f"{cluster_info.num_processes_per_compute}" + ) if torch.cuda.is_available(): if local_world_size > torch.cuda.device_count(): @@ -853,6 +867,7 @@ def _run_example_training( fanout = trainer_args.get("num_neighbors", "[10, 10]") num_neighbors = parse_fanout(fanout) + sampler_options = parse_sampler_options(trainer_args) sampling_workers_per_process: int = int( trainer_args.get("sampling_workers_per_process", "4") @@ -880,6 +895,7 @@ def _run_example_training( logger.info( f"Got training args local_world_size={local_world_size}, \ num_neighbors={num_neighbors}, \ + sampler_options={sampler_options}, \ sampling_workers_per_process={sampling_workers_per_process}, \ main_batch_size={main_batch_size}, \ random_batch_size={random_batch_size}, \ @@ -931,6 +947,7 @@ def _run_example_training( node_feature_dim=node_feature_dim, edge_feature_dim=edge_feature_dim, num_neighbors=num_neighbors, + sampler_options=sampler_options, sampling_workers_per_process=sampling_workers_per_process, sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, process_start_gap_seconds=process_start_gap_seconds, diff --git a/gigl/utils/sampling.py b/gigl/utils/sampling.py index 5d0ed6a44..e2c6996e5 100644 --- a/gigl/utils/sampling.py +++ b/gigl/utils/sampling.py @@ -1,10 +1,12 @@ import ast +from collections.abc import Mapping from dataclasses import dataclass from typing import Any, Optional, Union import torch from gigl.common.logger import Logger +from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions from gigl.src.common.types.graph_data import EdgeType, NodeType logger = Logger() @@ -88,6 +90,45 @@ def parse_fanout(fanout_str: str) -> Union[list[int], dict[EdgeType, list[int]]] ) +def _parse_optional_int(value: Optional[str]) -> Optional[int]: + if value is None: + return None + normalized = value.strip().lower() + if normalized in {"", "none", "null"}: + return None + return int(value) + + +def parse_sampler_options(args: Mapping[str, str]) -> Optional[SamplerOptions]: + sampler_type = args.get("sampler_type", "khop").strip().lower().replace("-", "_") + if sampler_type == "": + sampler_type = "khop" + + if sampler_type in {"khop", "k_hop", "neighbor", "neighbor_sampler"}: + return None + + if sampler_type != "ppr": + raise ValueError( + f"Unsupported sampler_type={sampler_type}. Expected one of: khop, ppr." + ) + + max_ppr_nodes = args.get("ppr_max_nodes") + if max_ppr_nodes is None: + max_ppr_nodes = args.get("ppr_max_ppr_nodes", "50") + + num_neighbors_per_hop = args.get("ppr_neighbors_per_hop") + if num_neighbors_per_hop is None: + num_neighbors_per_hop = args.get("ppr_num_neighbors_per_hop", "1000") + + return PPRSamplerOptions( + alpha=float(args.get("ppr_alpha", "0.5")), + eps=float(args.get("ppr_eps", "0.0001")), + max_ppr_nodes=int(max_ppr_nodes), + num_neighbors_per_hop=int(num_neighbors_per_hop), + max_fetch_iterations=_parse_optional_int(args.get("ppr_max_fetch_iterations")), + ) + + @dataclass(frozen=True) class ABLPInputNodes: """Represents ABLP (Anchor Based Link Prediction) input for a single storage server. diff --git a/tests/e2e_tests/e2e_tests.yaml b/tests/e2e_tests/e2e_tests.yaml index 61fc4f311..6d09d8213 100644 --- a/tests/e2e_tests/e2e_tests.yaml +++ b/tests/e2e_tests/e2e_tests.yaml @@ -22,6 +22,9 @@ tests: hom_cora_sup_gs_test: task_config_uri: "examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml" resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" + hom_cora_sup_gs_ppr_test: + task_config_uri: "examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml" + resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" het_dblp_sup_gs_test: task_config_uri: "examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml" resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" From 2f35f22983baf28aa2c13c0213604f6feb6b6b1e Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 20:46:38 +0000 Subject: [PATCH 2/3] Configure graph-store PPR sampler options inline --- .../e2e_hom_cora_sup_gs_ppr_task_config.yaml | 28 +++++++------ .../graph_store/homogeneous_inference.py | 10 +++-- .../graph_store/homogeneous_training.py | 10 +++-- gigl/utils/sampling.py | 41 ------------------- 4 files changed, 30 insertions(+), 59 deletions(-) diff --git a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml index 46c508819..ad1ab8a4a 100644 --- a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml +++ b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml @@ -17,12 +17,14 @@ trainerConfig: trainerArgs: log_every_n_batch: "1" num_neighbors: "[10, 10]" - sampler_type: "ppr" - ppr_alpha: "0.5" - ppr_eps: "0.0001" - ppr_max_nodes: "20" - ppr_neighbors_per_hop: "100" - ppr_max_fetch_iterations: "2" + ppr_sampler_options: >- + { + "alpha": 0.5, + "eps": 0.0001, + "max_ppr_nodes": 20, + "num_neighbors_per_hop": 100, + "max_fetch_iterations": 2 + } sampling_workers_per_process: "2" main_batch_size: "8" random_batch_size: "8" @@ -47,12 +49,14 @@ inferencerConfig: inferencerArgs: log_every_n_batch: "1" num_neighbors: "[10, 10]" - sampler_type: "ppr" - ppr_alpha: "0.5" - ppr_eps: "0.0001" - ppr_max_nodes: "20" - ppr_neighbors_per_hop: "100" - ppr_max_fetch_iterations: "2" + ppr_sampler_options: >- + { + "alpha": 0.5, + "eps": 0.0001, + "max_ppr_nodes": 20, + "num_neighbors_per_hop": 100, + "max_fetch_iterations": 2 + } sampling_workers_per_inference_process: "2" inferenceBatchSize: 256 command: python -m examples.link_prediction.graph_store.homogeneous_inference diff --git a/examples/link_prediction/graph_store/homogeneous_inference.py b/examples/link_prediction/graph_store/homogeneous_inference.py index 5faa84b72..16333c75f 100644 --- a/examples/link_prediction/graph_store/homogeneous_inference.py +++ b/examples/link_prediction/graph_store/homogeneous_inference.py @@ -83,6 +83,7 @@ import argparse import gc +import json import os import sys import time @@ -101,7 +102,7 @@ from gigl.common.utils.gcs import GcsUtils from gigl.distributed.graph_store.compute import init_compute_process from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset -from gigl.distributed.sampler_options import SamplerOptions +from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions from gigl.distributed.utils import get_graph_store_info from gigl.env.distributed import GraphStoreInfo from gigl.nn import LinkPredictionGNN @@ -111,7 +112,7 @@ from gigl.src.common.utils.bq import BqUtils from gigl.src.common.utils.model import load_state_dict_from_uri from gigl.src.inference.lib.assets import InferenceAssets -from gigl.utils.sampling import parse_fanout, parse_sampler_options +from gigl.utils.sampling import parse_fanout logger = Logger() @@ -490,7 +491,10 @@ def _run_example_inference( # Parses the fanout as a string. For the homogeneous case, the fanouts should be specified # as a string of a list of integers, such as "[10, 10]". num_neighbors = parse_fanout(inferencer_args.get("num_neighbors", "[10, 10]")) - sampler_options = parse_sampler_options(inferencer_args) + sampler_options: Optional[SamplerOptions] = None + sampler_options_args = inferencer_args.get("ppr_sampler_options") + if sampler_options_args is not None and sampler_options_args.strip(): + sampler_options = PPRSamplerOptions(**json.loads(sampler_options_args)) # While the ideal value for `sampling_workers_per_inference_process` has been identified to be # between `2` and `4`, this may need some tuning depending on the pipeline. We default this diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index c7ae356cc..cdfd6d93b 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -119,6 +119,7 @@ import argparse import gc +import json import os import statistics import sys @@ -143,7 +144,7 @@ shutdown_compute_process, ) from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset -from gigl.distributed.sampler_options import SamplerOptions +from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions from gigl.distributed.utils import get_available_device, get_graph_store_info from gigl.env.distributed import GraphStoreInfo from gigl.nn import LinkPredictionGNN, RetrievalLoss @@ -159,7 +160,7 @@ from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict from gigl.utils.iterator import InfiniteIterator -from gigl.utils.sampling import parse_fanout, parse_sampler_options +from gigl.utils.sampling import parse_fanout logger = Logger() @@ -867,7 +868,10 @@ def _run_example_training( fanout = trainer_args.get("num_neighbors", "[10, 10]") num_neighbors = parse_fanout(fanout) - sampler_options = parse_sampler_options(trainer_args) + sampler_options: Optional[SamplerOptions] = None + sampler_options_args = trainer_args.get("ppr_sampler_options") + if sampler_options_args is not None and sampler_options_args.strip(): + sampler_options = PPRSamplerOptions(**json.loads(sampler_options_args)) sampling_workers_per_process: int = int( trainer_args.get("sampling_workers_per_process", "4") diff --git a/gigl/utils/sampling.py b/gigl/utils/sampling.py index e2c6996e5..5d0ed6a44 100644 --- a/gigl/utils/sampling.py +++ b/gigl/utils/sampling.py @@ -1,12 +1,10 @@ import ast -from collections.abc import Mapping from dataclasses import dataclass from typing import Any, Optional, Union import torch from gigl.common.logger import Logger -from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions from gigl.src.common.types.graph_data import EdgeType, NodeType logger = Logger() @@ -90,45 +88,6 @@ def parse_fanout(fanout_str: str) -> Union[list[int], dict[EdgeType, list[int]]] ) -def _parse_optional_int(value: Optional[str]) -> Optional[int]: - if value is None: - return None - normalized = value.strip().lower() - if normalized in {"", "none", "null"}: - return None - return int(value) - - -def parse_sampler_options(args: Mapping[str, str]) -> Optional[SamplerOptions]: - sampler_type = args.get("sampler_type", "khop").strip().lower().replace("-", "_") - if sampler_type == "": - sampler_type = "khop" - - if sampler_type in {"khop", "k_hop", "neighbor", "neighbor_sampler"}: - return None - - if sampler_type != "ppr": - raise ValueError( - f"Unsupported sampler_type={sampler_type}. Expected one of: khop, ppr." - ) - - max_ppr_nodes = args.get("ppr_max_nodes") - if max_ppr_nodes is None: - max_ppr_nodes = args.get("ppr_max_ppr_nodes", "50") - - num_neighbors_per_hop = args.get("ppr_neighbors_per_hop") - if num_neighbors_per_hop is None: - num_neighbors_per_hop = args.get("ppr_num_neighbors_per_hop", "1000") - - return PPRSamplerOptions( - alpha=float(args.get("ppr_alpha", "0.5")), - eps=float(args.get("ppr_eps", "0.0001")), - max_ppr_nodes=int(max_ppr_nodes), - num_neighbors_per_hop=int(num_neighbors_per_hop), - max_fetch_iterations=_parse_optional_int(args.get("ppr_max_fetch_iterations")), - ) - - @dataclass(frozen=True) class ABLPInputNodes: """Represents ABLP (Anchor Based Link Prediction) input for a single storage server. From 188525fb745d310ba3975f9480e7b658b7fcc069 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 29 May 2026 21:44:34 +0000 Subject: [PATCH 3/3] Clarify graph-store PPR sampler args --- .../configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml | 8 ++++++++ .../link_prediction/graph_store/homogeneous_inference.py | 4 ++-- .../link_prediction/graph_store/homogeneous_training.py | 4 ++-- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml index ad1ab8a4a..1e440dc7f 100644 --- a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml +++ b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_ppr_task_config.yaml @@ -17,6 +17,10 @@ trainerConfig: trainerArgs: log_every_n_batch: "1" num_neighbors: "[10, 10]" + # Parsed in the graph-store training entrypoint and passed directly as + # kwargs to PPRSamplerOptions in gigl/distributed/sampler_options.py. + # Presence of ppr_sampler_options selects PPR; otherwise this example uses + # k-hop sampling configured by num_neighbors. ppr_sampler_options: >- { "alpha": 0.5, @@ -49,6 +53,10 @@ inferencerConfig: inferencerArgs: log_every_n_batch: "1" num_neighbors: "[10, 10]" + # Parsed in the graph-store inference entrypoint and passed directly as + # kwargs to PPRSamplerOptions in gigl/distributed/sampler_options.py. + # Presence of ppr_sampler_options selects PPR; otherwise this example uses + # k-hop sampling configured by num_neighbors. ppr_sampler_options: >- { "alpha": 0.5, diff --git a/examples/link_prediction/graph_store/homogeneous_inference.py b/examples/link_prediction/graph_store/homogeneous_inference.py index 16333c75f..26dbad8e9 100644 --- a/examples/link_prediction/graph_store/homogeneous_inference.py +++ b/examples/link_prediction/graph_store/homogeneous_inference.py @@ -82,8 +82,8 @@ """ import argparse +import ast import gc -import json import os import sys import time @@ -494,7 +494,7 @@ def _run_example_inference( sampler_options: Optional[SamplerOptions] = None sampler_options_args = inferencer_args.get("ppr_sampler_options") if sampler_options_args is not None and sampler_options_args.strip(): - sampler_options = PPRSamplerOptions(**json.loads(sampler_options_args)) + sampler_options = PPRSamplerOptions(**ast.literal_eval(sampler_options_args)) # While the ideal value for `sampling_workers_per_inference_process` has been identified to be # between `2` and `4`, this may need some tuning depending on the pipeline. We default this diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index cdfd6d93b..8f601399f 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -118,8 +118,8 @@ """ import argparse +import ast import gc -import json import os import statistics import sys @@ -871,7 +871,7 @@ def _run_example_training( sampler_options: Optional[SamplerOptions] = None sampler_options_args = trainer_args.get("ppr_sampler_options") if sampler_options_args is not None and sampler_options_args.strip(): - sampler_options = PPRSamplerOptions(**json.loads(sampler_options_args)) + sampler_options = PPRSamplerOptions(**ast.literal_eval(sampler_options_args)) sampling_workers_per_process: int = int( trainer_args.get("sampling_workers_per_process", "4")