Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,14 @@ run_hom_cora_sup_gs_e2e_test:
--test_spec_uri="tests/e2e_tests/e2e_tests.yaml" \
--test_names="hom_cora_sup_gs_test"

run_hom_cora_sup_gs_ppr_e2e_test: compiled_pipeline_path:=${GIGL_E2E_TEST_COMPILED_PIPELINE_PATH}
run_hom_cora_sup_gs_ppr_e2e_test: compile_gigl_kubeflow_pipeline
run_hom_cora_sup_gs_ppr_e2e_test:
uv run python tests/e2e_tests/e2e_test.py \
--compiled_pipeline_path=$(compiled_pipeline_path) \
--test_spec_uri="tests/e2e_tests/e2e_tests.yaml" \
--test_names="hom_cora_sup_gs_ppr_test"

run_het_dblp_sup_gs_e2e_test: compiled_pipeline_path:=${GIGL_E2E_TEST_COMPILED_PIPELINE_PATH}
run_het_dblp_sup_gs_e2e_test: compile_gigl_kubeflow_pipeline
run_het_dblp_sup_gs_e2e_test:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# This config runs homogeneous CORA supervised training and inference in Graph Store mode
# with PPR sampling. It intentionally reuses the standard graph-store training/inference
# entrypoints, changing only the sampler args and keeping the loop short for E2E coverage.
graphMetadata:
edgeTypes:
- dstNodeType: paper
relation: cites
srcNodeType: paper
nodeTypes:
- paper
datasetConfig:
dataPreprocessorConfig:
dataPreprocessorConfigClsPath: gigl.src.mocking.mocking_assets.passthrough_preprocessor_config_for_mocked_assets.PassthroughPreprocessorConfigForMockedAssets
dataPreprocessorArgs:
mocked_dataset_name: 'cora_homogeneous_node_anchor_edge_features_user_defined_labels'
trainerConfig:
trainerArgs:
log_every_n_batch: "1"
num_neighbors: "[10, 10]"
sampler_type: "ppr"
ppr_alpha: "0.5"
ppr_eps: "0.0001"
ppr_max_nodes: "20"
ppr_neighbors_per_hop: "100"
ppr_max_fetch_iterations: "2"
sampling_workers_per_process: "2"
main_batch_size: "8"
random_batch_size: "8"
num_max_train_batches: "4"
num_val_batches: "4"
val_every_n_batch: "1"
command: python -m examples.link_prediction.graph_store.homogeneous_training
graphStoreStorageConfig:
command: python -m examples.link_prediction.graph_store.storage_main
storageArgs:
sample_edge_direction: "in"
splitter_cls_path: "gigl.utils.data_splitters.DistNodeAnchorLinkSplitter"
splitter_kwargs: >-
{
"sampling_direction": "in",
"should_convert_labels_to_edges": True,
"num_val": 0.25,
"num_test": 0.25
}
num_server_sessions: "1"
inferencerConfig:
inferencerArgs:
log_every_n_batch: "1"
num_neighbors: "[10, 10]"
sampler_type: "ppr"
ppr_alpha: "0.5"
ppr_eps: "0.0001"
ppr_max_nodes: "20"
ppr_neighbors_per_hop: "100"
ppr_max_fetch_iterations: "2"
sampling_workers_per_inference_process: "2"
inferenceBatchSize: 256
command: python -m examples.link_prediction.graph_store.homogeneous_inference
graphStoreStorageConfig:
command: python -m examples.link_prediction.graph_store.storage_main
storageArgs:
sample_edge_direction: "in"
num_server_sessions: "1"
sharedConfig:
shouldSkipInference: false
shouldSkipModelEvaluation: true
taskMetadata:
nodeAnchorBasedLinkPredictionTaskMetadata:
supervisionEdgeTypes:
- dstNodeType: paper
relation: cites
srcNodeType: paper
featureFlags:
should_run_glt_backend: 'True'
data_preprocessor_num_shards: '2'
58 changes: 32 additions & 26 deletions examples/link_prediction/graph_store/homogeneous_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
import sys
import time
from dataclasses import dataclass
from typing import Union
from typing import Optional, Union

import torch
import torch.multiprocessing as mp
Expand All @@ -101,6 +101,7 @@
from gigl.common.utils.gcs import GcsUtils
from gigl.distributed.graph_store.compute import init_compute_process
from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset
from gigl.distributed.sampler_options import SamplerOptions
from gigl.distributed.utils import get_graph_store_info
from gigl.env.distributed import GraphStoreInfo
from gigl.nn import LinkPredictionGNN
Expand All @@ -110,16 +111,10 @@
from gigl.src.common.utils.bq import BqUtils
from gigl.src.common.utils.model import load_state_dict_from_uri
from gigl.src.inference.lib.assets import InferenceAssets
from gigl.utils.sampling import parse_fanout
from gigl.utils.sampling import parse_fanout, parse_sampler_options

logger = Logger()

# Default number of inference processes per machine incase one isnt provided in inference args
# i.e. `local_world_size` is not provided, and we can't infer automatically.
# If there are GPUs attached to the machine, we automatically infer to setting
# LOCAL_WORLD_SIZE == # of gpus on the machine.
DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE = 4


@dataclass(frozen=True)
class InferenceProcessArgs:
Expand All @@ -143,6 +138,7 @@ class InferenceProcessArgs:
inference_batch_size (int): Batch size to use for inference.
num_neighbors (Union[list[int], dict[EdgeType, list[int]]]): Fanout for subgraph sampling,
where the ith item corresponds to the number of items to sample for the ith hop.
sampler_options (Optional[SamplerOptions]): Sampler variant. None uses k-hop sampling.
sampling_workers_per_inference_process (int): Number of sampling workers per inference
process.
sampling_worker_shared_channel_size (str): Shared-memory buffer size (bytes) allocated for
Expand All @@ -169,6 +165,7 @@ class InferenceProcessArgs:
# Inference configuration
inference_batch_size: int
num_neighbors: Union[list[int], dict[EdgeType, list[int]]]
sampler_options: Optional[SamplerOptions]
sampling_workers_per_inference_process: int
sampling_worker_shared_channel_size: str
log_every_n_batch: int
Expand Down Expand Up @@ -242,6 +239,7 @@ def _inference_process(
# For large-scale settings, consider setting this field to 30-60 seconds to ensure dataloaders
# don't compete for memory during initialization, causing OOM
process_start_gap_seconds=0,
sampler_options=args.sampler_options,
)
# Initialize a LinkPredictionGNN model and load parameters from
# the saved model.
Expand Down Expand Up @@ -455,25 +453,23 @@ def _run_example_inference(
if arg_local_world_size is not None:
local_world_size = int(arg_local_world_size)
logger.info(f"Using local_world_size from inferencer_args: {local_world_size}")
if torch.cuda.is_available() and local_world_size != torch.cuda.device_count():
logger.warning(
f"local_world_size {local_world_size} does not match the number of GPUs {torch.cuda.device_count()}. "
"This may lead to unexpected failures with NCCL communication incase GPUs are being used for "
+ "training/inference. Consider setting local_world_size to the number of GPUs."
)
else:
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
# If GPUs are available, we set the local_world_size to the number of GPUs
local_world_size = torch.cuda.device_count()
logger.info(
f"Detected {local_world_size} GPUs. Thus, setting local_world_size to {local_world_size}"
)
else:
# If no GPUs are available, we set the local_world_size to the number of inference processes per machine
logger.info(
f"No GPUs detected. Thus, setting local_world_size to `{DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE}`"
)
local_world_size = DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE
local_world_size = cluster_info.num_processes_per_compute
logger.info(
f"Using local_world_size from cluster_info.num_processes_per_compute: {local_world_size}"
)
if local_world_size != cluster_info.num_processes_per_compute:
raise ValueError(
f"Graph Store local_world_size={local_world_size} must match "
f"cluster_info.num_processes_per_compute="
f"{cluster_info.num_processes_per_compute}"
)
if torch.cuda.is_available() and local_world_size != torch.cuda.device_count():
logger.warning(
f"local_world_size {local_world_size} does not match the number of GPUs {torch.cuda.device_count()}. "
"This may lead to unexpected failures with NCCL communication incase GPUs are being used for "
+ "training/inference. Consider setting local_world_size to the number of GPUs."
)

if cluster_info.compute_node_rank == 0:
gcs_utils = GcsUtils()
Expand All @@ -494,6 +490,7 @@ def _run_example_inference(
# Parses the fanout as a string. For the homogeneous case, the fanouts should be specified
# as a string of a list of integers, such as "[10, 10]".
num_neighbors = parse_fanout(inferencer_args.get("num_neighbors", "[10, 10]"))
sampler_options = parse_sampler_options(inferencer_args)

# While the ideal value for `sampling_workers_per_inference_process` has been identified to be
# between `2` and `4`, this may need some tuning depending on the pipeline. We default this
Expand All @@ -516,6 +513,14 @@ def _run_example_inference(

log_every_n_batch = int(inferencer_args.get("log_every_n_batch", "50"))

logger.info(
f"Got inference args local_world_size={local_world_size}, "
f"num_neighbors={num_neighbors}, sampler_options={sampler_options}, "
f"sampling_workers_per_inference_process={sampling_workers_per_inference_process}, "
f"sampling_worker_shared_channel_size={sampling_worker_shared_channel_size}, "
f"log_every_n_batch={log_every_n_batch}"
)

# When using mp.spawn with `nprocs`, the first argument is implicitly set to be the process number on the current machine.
inference_args = InferenceProcessArgs(
local_world_size=local_world_size,
Expand All @@ -528,6 +533,7 @@ def _run_example_inference(
edge_feature_dim=edge_feature_dim,
inference_batch_size=inference_batch_size,
num_neighbors=num_neighbors,
sampler_options=sampler_options,
sampling_workers_per_inference_process=sampling_workers_per_inference_process,
sampling_worker_shared_channel_size=sampling_worker_shared_channel_size,
log_every_n_batch=log_every_n_batch,
Expand Down
29 changes: 23 additions & 6 deletions examples/link_prediction/graph_store/homogeneous_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@
shutdown_compute_process,
)
from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset
from gigl.distributed.sampler_options import SamplerOptions
from gigl.distributed.utils import get_available_device, get_graph_store_info
from gigl.env.distributed import GraphStoreInfo
from gigl.nn import LinkPredictionGNN, RetrievalLoss
Expand All @@ -158,7 +159,7 @@
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict
from gigl.utils.iterator import InfiniteIterator
from gigl.utils.sampling import parse_fanout
from gigl.utils.sampling import parse_fanout, parse_sampler_options

logger = Logger()

Expand Down Expand Up @@ -191,6 +192,7 @@ def _setup_dataloaders(
split: Literal["train", "val", "test"],
cluster_info: GraphStoreInfo,
num_neighbors: list[int] | dict[EdgeType, list[int]],
sampler_options: Optional[SamplerOptions],
sampling_workers_per_process: int,
main_batch_size: int,
random_batch_size: int,
Expand All @@ -205,6 +207,7 @@ def _setup_dataloaders(
split (Literal["train", "val", "test"]): The current split which we are loading data for.
cluster_info (GraphStoreInfo): Cluster topology info for graph store mode.
num_neighbors: Fanout for subgraph sampling.
sampler_options (Optional[SamplerOptions]): Sampler variant. None uses k-hop sampling.
sampling_workers_per_process (int): Number of sampling workers per training/testing process.
main_batch_size (int): Batch size for main dataloader with query and labeled nodes.
random_batch_size (int): Batch size for random negative dataloader.
Expand Down Expand Up @@ -240,6 +243,7 @@ def _setup_dataloaders(
channel_size=sampling_worker_shared_channel_size,
process_start_gap_seconds=process_start_gap_seconds,
shuffle=shuffle,
sampler_options=sampler_options,
)

logger.info(f"---Rank {rank} finished setting up main loader for split={split}")
Expand All @@ -266,6 +270,7 @@ def _setup_dataloaders(
channel_size=sampling_worker_shared_channel_size,
process_start_gap_seconds=process_start_gap_seconds,
shuffle=shuffle,
sampler_options=sampler_options,
)

logger.info(
Expand Down Expand Up @@ -375,6 +380,7 @@ class TrainingProcessArgs:
sampling_workers_per_process (int): Number of sampling workers per training/testing process.
sampling_worker_shared_channel_size (str): Shared-memory buffer size for the channel during sampling.
process_start_gap_seconds (int): Time to sleep between dataloader initializations.
sampler_options (Optional[SamplerOptions]): Sampler variant. None uses k-hop sampling.
main_batch_size (int): Batch size for main dataloader.
random_batch_size (int): Batch size for random negative dataloader.
learning_rate (float): Learning rate for the optimizer.
Expand All @@ -400,6 +406,7 @@ class TrainingProcessArgs:

# Sampling config
num_neighbors: list[int] | dict[EdgeType, list[int]]
sampler_options: Optional[SamplerOptions]
sampling_workers_per_process: int
sampling_worker_shared_channel_size: str
process_start_gap_seconds: int
Expand Down Expand Up @@ -463,6 +470,7 @@ def _training_process(
split="train",
cluster_info=args.cluster_info,
num_neighbors=args.num_neighbors,
sampler_options=args.sampler_options,
sampling_workers_per_process=args.sampling_workers_per_process,
main_batch_size=args.main_batch_size,
random_batch_size=args.random_batch_size,
Expand All @@ -481,6 +489,7 @@ def _training_process(
split="val",
cluster_info=args.cluster_info,
num_neighbors=args.num_neighbors,
sampler_options=args.sampler_options,
sampling_workers_per_process=args.sampling_workers_per_process,
main_batch_size=args.main_batch_size,
random_batch_size=args.random_batch_size,
Expand Down Expand Up @@ -637,6 +646,7 @@ def _training_process(
split="test",
cluster_info=args.cluster_info,
num_neighbors=args.num_neighbors,
sampler_options=args.sampler_options,
sampling_workers_per_process=args.sampling_workers_per_process,
main_batch_size=args.main_batch_size,
random_batch_size=args.random_batch_size,
Expand Down Expand Up @@ -837,13 +847,17 @@ def _run_example_training(
# Training Hyperparameters
trainer_args = dict(gbml_config_pb_wrapper.trainer_config.trainer_args)

if torch.cuda.is_available():
default_local_world_size = torch.cuda.device_count()
else:
default_local_world_size = 2
local_world_size = int(
trainer_args.get("local_world_size", str(default_local_world_size))
trainer_args.get(
"local_world_size", str(cluster_info.num_processes_per_compute)
)
)
if local_world_size != cluster_info.num_processes_per_compute:
raise ValueError(
f"Graph Store local_world_size={local_world_size} must match "
f"cluster_info.num_processes_per_compute="
f"{cluster_info.num_processes_per_compute}"
)

if torch.cuda.is_available():
if local_world_size > torch.cuda.device_count():
Expand All @@ -853,6 +867,7 @@ def _run_example_training(

fanout = trainer_args.get("num_neighbors", "[10, 10]")
num_neighbors = parse_fanout(fanout)
sampler_options = parse_sampler_options(trainer_args)

sampling_workers_per_process: int = int(
trainer_args.get("sampling_workers_per_process", "4")
Expand Down Expand Up @@ -880,6 +895,7 @@ def _run_example_training(
logger.info(
f"Got training args local_world_size={local_world_size}, \
num_neighbors={num_neighbors}, \
sampler_options={sampler_options}, \
sampling_workers_per_process={sampling_workers_per_process}, \
main_batch_size={main_batch_size}, \
random_batch_size={random_batch_size}, \
Expand Down Expand Up @@ -931,6 +947,7 @@ def _run_example_training(
node_feature_dim=node_feature_dim,
edge_feature_dim=edge_feature_dim,
num_neighbors=num_neighbors,
sampler_options=sampler_options,
sampling_workers_per_process=sampling_workers_per_process,
sampling_worker_shared_channel_size=sampling_worker_shared_channel_size,
process_start_gap_seconds=process_start_gap_seconds,
Expand Down
10 changes: 0 additions & 10 deletions gigl/distributed/base_dist_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,16 +466,6 @@ def create_mp_producer(
channel = BaseDistLoader.create_colocated_channel(worker_options)
if isinstance(sampler_options, PPRSamplerOptions):
degree_tensors = dataset.degree_tensor
if isinstance(degree_tensors, dict):
logger.info(
f"Pre-computed degree tensors for PPR sampling across "
f"{len(degree_tensors)} edge types."
)
else:
logger.info(
f"Pre-computed degree tensor for PPR sampling with "
f"{degree_tensors.size(0)} nodes."
)
else:
degree_tensors = None
return DistSamplingProducer(
Expand Down
Loading