-
Notifications
You must be signed in to change notification settings - Fork 16
Add graph-store PPR E2E wiring #655
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: mkolodner-sc/ppr_gs_memory
Are you sure you want to change the base?
Changes from all commits
a49a650
8c1dd36
851ed8b
ab6aecd
ee5806b
98bb3f9
f0e3275
a24e32a
2f35f22
188525f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,87 @@ | ||
| # This config runs homogeneous CORA supervised training and inference in Graph Store mode | ||
| # with PPR sampling. It intentionally reuses the standard graph-store training/inference | ||
| # entrypoints, changing only the sampler args and keeping the loop short for E2E coverage. | ||
| graphMetadata: | ||
| edgeTypes: | ||
| - dstNodeType: paper | ||
| relation: cites | ||
| srcNodeType: paper | ||
| nodeTypes: | ||
| - paper | ||
| datasetConfig: | ||
| dataPreprocessorConfig: | ||
| dataPreprocessorConfigClsPath: gigl.src.mocking.mocking_assets.passthrough_preprocessor_config_for_mocked_assets.PassthroughPreprocessorConfigForMockedAssets | ||
| dataPreprocessorArgs: | ||
| mocked_dataset_name: 'cora_homogeneous_node_anchor_edge_features_user_defined_labels' | ||
| trainerConfig: | ||
| trainerArgs: | ||
| log_every_n_batch: "1" | ||
| num_neighbors: "[10, 10]" | ||
| # Parsed in the graph-store training entrypoint and passed directly as | ||
| # kwargs to PPRSamplerOptions in gigl/distributed/sampler_options.py. | ||
| # Presence of ppr_sampler_options selects PPR; otherwise this example uses | ||
| # k-hop sampling configured by num_neighbors. | ||
| ppr_sampler_options: >- | ||
| { | ||
| "alpha": 0.5, | ||
| "eps": 0.0001, | ||
| "max_ppr_nodes": 20, | ||
| "num_neighbors_per_hop": 100, | ||
| "max_fetch_iterations": 2 | ||
| } | ||
| sampling_workers_per_process: "2" | ||
| main_batch_size: "8" | ||
| random_batch_size: "8" | ||
| num_max_train_batches: "4" | ||
| num_val_batches: "4" | ||
| val_every_n_batch: "1" | ||
| command: python -m examples.link_prediction.graph_store.homogeneous_training | ||
| graphStoreStorageConfig: | ||
| command: python -m examples.link_prediction.graph_store.storage_main | ||
| storageArgs: | ||
| sample_edge_direction: "in" | ||
| splitter_cls_path: "gigl.utils.data_splitters.DistNodeAnchorLinkSplitter" | ||
| splitter_kwargs: >- | ||
| { | ||
| "sampling_direction": "in", | ||
| "should_convert_labels_to_edges": True, | ||
| "num_val": 0.25, | ||
| "num_test": 0.25 | ||
| } | ||
| num_server_sessions: "1" | ||
| inferencerConfig: | ||
| inferencerArgs: | ||
| log_every_n_batch: "1" | ||
| num_neighbors: "[10, 10]" | ||
| # Parsed in the graph-store inference entrypoint and passed directly as | ||
| # kwargs to PPRSamplerOptions in gigl/distributed/sampler_options.py. | ||
| # Presence of ppr_sampler_options selects PPR; otherwise this example uses | ||
| # k-hop sampling configured by num_neighbors. | ||
| ppr_sampler_options: >- | ||
| { | ||
| "alpha": 0.5, | ||
| "eps": 0.0001, | ||
| "max_ppr_nodes": 20, | ||
| "num_neighbors_per_hop": 100, | ||
| "max_fetch_iterations": 2 | ||
| } | ||
| sampling_workers_per_inference_process: "2" | ||
| inferenceBatchSize: 256 | ||
| command: python -m examples.link_prediction.graph_store.homogeneous_inference | ||
| graphStoreStorageConfig: | ||
| command: python -m examples.link_prediction.graph_store.storage_main | ||
| storageArgs: | ||
| sample_edge_direction: "in" | ||
| num_server_sessions: "1" | ||
| sharedConfig: | ||
| shouldSkipInference: false | ||
| shouldSkipModelEvaluation: true | ||
| taskMetadata: | ||
| nodeAnchorBasedLinkPredictionTaskMetadata: | ||
| supervisionEdgeTypes: | ||
| - dstNodeType: paper | ||
| relation: cites | ||
| srcNodeType: paper | ||
| featureFlags: | ||
| should_run_glt_backend: 'True' | ||
| data_preprocessor_num_shards: '2' | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -82,12 +82,13 @@ | |
| """ | ||
|
|
||
| import argparse | ||
| import ast | ||
| import gc | ||
| import os | ||
| import sys | ||
| import time | ||
| from dataclasses import dataclass | ||
| from typing import Union | ||
| from typing import Optional, Union | ||
|
|
||
| import torch | ||
| import torch.multiprocessing as mp | ||
|
|
@@ -101,6 +102,7 @@ | |
| from gigl.common.utils.gcs import GcsUtils | ||
| from gigl.distributed.graph_store.compute import init_compute_process | ||
| from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset | ||
| from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions | ||
| from gigl.distributed.utils import get_graph_store_info | ||
| from gigl.env.distributed import GraphStoreInfo | ||
| from gigl.nn import LinkPredictionGNN | ||
|
|
@@ -114,12 +116,6 @@ | |
|
|
||
| logger = Logger() | ||
|
|
||
| # Default number of inference processes per machine incase one isnt provided in inference args | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In Graph Store mode, the source of truth should be cluster_info.num_processes_per_compute, not a local CPU/GPU heuristic. The previous fallback could make inference spawn a different number of compute processes than storage expected, causing storage rendezvous failures like “only N/M clients joined.” |
||
| # i.e. `local_world_size` is not provided, and we can't infer automatically. | ||
| # If there are GPUs attached to the machine, we automatically infer to setting | ||
| # LOCAL_WORLD_SIZE == # of gpus on the machine. | ||
| DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE = 4 | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class InferenceProcessArgs: | ||
|
|
@@ -143,6 +139,7 @@ class InferenceProcessArgs: | |
| inference_batch_size (int): Batch size to use for inference. | ||
| num_neighbors (Union[list[int], dict[EdgeType, list[int]]]): Fanout for subgraph sampling, | ||
| where the ith item corresponds to the number of items to sample for the ith hop. | ||
| sampler_options (Optional[SamplerOptions]): Sampler variant. None uses k-hop sampling. | ||
| sampling_workers_per_inference_process (int): Number of sampling workers per inference | ||
| process. | ||
| sampling_worker_shared_channel_size (str): Shared-memory buffer size (bytes) allocated for | ||
|
|
@@ -169,6 +166,7 @@ class InferenceProcessArgs: | |
| # Inference configuration | ||
| inference_batch_size: int | ||
| num_neighbors: Union[list[int], dict[EdgeType, list[int]]] | ||
| sampler_options: Optional[SamplerOptions] | ||
| sampling_workers_per_inference_process: int | ||
| sampling_worker_shared_channel_size: str | ||
| log_every_n_batch: int | ||
|
|
@@ -242,6 +240,7 @@ def _inference_process( | |
| # For large-scale settings, consider setting this field to 30-60 seconds to ensure dataloaders | ||
| # don't compete for memory during initialization, causing OOM | ||
| process_start_gap_seconds=0, | ||
| sampler_options=args.sampler_options, | ||
| ) | ||
| # Initialize a LinkPredictionGNN model and load parameters from | ||
| # the saved model. | ||
|
|
@@ -455,25 +454,23 @@ def _run_example_inference( | |
| if arg_local_world_size is not None: | ||
| local_world_size = int(arg_local_world_size) | ||
| logger.info(f"Using local_world_size from inferencer_args: {local_world_size}") | ||
| if torch.cuda.is_available() and local_world_size != torch.cuda.device_count(): | ||
| logger.warning( | ||
| f"local_world_size {local_world_size} does not match the number of GPUs {torch.cuda.device_count()}. " | ||
| "This may lead to unexpected failures with NCCL communication incase GPUs are being used for " | ||
| + "training/inference. Consider setting local_world_size to the number of GPUs." | ||
| ) | ||
| else: | ||
| if torch.cuda.is_available() and torch.cuda.device_count() > 0: | ||
| # If GPUs are available, we set the local_world_size to the number of GPUs | ||
| local_world_size = torch.cuda.device_count() | ||
| logger.info( | ||
| f"Detected {local_world_size} GPUs. Thus, setting local_world_size to {local_world_size}" | ||
| ) | ||
| else: | ||
| # If no GPUs are available, we set the local_world_size to the number of inference processes per machine | ||
| logger.info( | ||
| f"No GPUs detected. Thus, setting local_world_size to `{DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE}`" | ||
| ) | ||
| local_world_size = DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE | ||
| local_world_size = cluster_info.num_processes_per_compute | ||
| logger.info( | ||
| f"Using local_world_size from cluster_info.num_processes_per_compute: {local_world_size}" | ||
| ) | ||
| if local_world_size != cluster_info.num_processes_per_compute: | ||
| raise ValueError( | ||
| f"Graph Store local_world_size={local_world_size} must match " | ||
| f"cluster_info.num_processes_per_compute=" | ||
| f"{cluster_info.num_processes_per_compute}" | ||
| ) | ||
| if torch.cuda.is_available() and local_world_size != torch.cuda.device_count(): | ||
| logger.warning( | ||
| f"local_world_size {local_world_size} does not match the number of GPUs {torch.cuda.device_count()}. " | ||
| "This may lead to unexpected failures with NCCL communication incase GPUs are being used for " | ||
| + "training/inference. Consider setting local_world_size to the number of GPUs." | ||
| ) | ||
|
|
||
| if cluster_info.compute_node_rank == 0: | ||
| gcs_utils = GcsUtils() | ||
|
|
@@ -494,6 +491,10 @@ def _run_example_inference( | |
| # Parses the fanout as a string. For the homogeneous case, the fanouts should be specified | ||
| # as a string of a list of integers, such as "[10, 10]". | ||
| num_neighbors = parse_fanout(inferencer_args.get("num_neighbors", "[10, 10]")) | ||
| sampler_options: Optional[SamplerOptions] = None | ||
| sampler_options_args = inferencer_args.get("ppr_sampler_options") | ||
| if sampler_options_args is not None and sampler_options_args.strip(): | ||
| sampler_options = PPRSamplerOptions(**ast.literal_eval(sampler_options_args)) | ||
|
|
||
| # While the ideal value for `sampling_workers_per_inference_process` has been identified to be | ||
| # between `2` and `4`, this may need some tuning depending on the pipeline. We default this | ||
|
|
@@ -516,6 +517,14 @@ def _run_example_inference( | |
|
|
||
| log_every_n_batch = int(inferencer_args.get("log_every_n_batch", "50")) | ||
|
|
||
| logger.info( | ||
| f"Got inference args local_world_size={local_world_size}, " | ||
| f"num_neighbors={num_neighbors}, sampler_options={sampler_options}, " | ||
| f"sampling_workers_per_inference_process={sampling_workers_per_inference_process}, " | ||
| f"sampling_worker_shared_channel_size={sampling_worker_shared_channel_size}, " | ||
| f"log_every_n_batch={log_every_n_batch}" | ||
| ) | ||
|
|
||
| # When using mp.spawn with `nprocs`, the first argument is implicitly set to be the process number on the current machine. | ||
| inference_args = InferenceProcessArgs( | ||
| local_world_size=local_world_size, | ||
|
|
@@ -528,6 +537,7 @@ def _run_example_inference( | |
| edge_feature_dim=edge_feature_dim, | ||
| inference_batch_size=inference_batch_size, | ||
| num_neighbors=num_neighbors, | ||
| sampler_options=sampler_options, | ||
| sampling_workers_per_inference_process=sampling_workers_per_inference_process, | ||
| sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, | ||
| log_every_n_batch=log_every_n_batch, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.