
Conversation

@kmontemayor2-sc
Collaborator

Scope of work done

Add server-side util so we can remotely fetch the training input.

Since this is kind of a big PR, I'm not adding the client-side equivalent in this one :P

Again, this is server-side code, and it's really meant to be called by users.

Where is the documentation for this feature?: N/A

Did you add automated tests or write a test plan?

Updated Changelog.md? NO

Ready for code review?: NO

Collaborator
@mkolodner-sc left a comment

Thanks Kyle! Did a pass here and left some comments/questions



def get_training_input(
    split: Union[Literal["train", "val", "test"], str],
Collaborator

Is this sufficient? What if we want "all" nodes (i.e. dataset.node_ids)?

Collaborator Author

That wouldn't be "training input", would it? We can add another function to do that in the future (get_all_nodes?)

Collaborator

It could be for the random negative loader for link prediction training, which would need some dataset.node_ids or equivalent.

We can add another function to do that in the future (get_all_nodes?)

We wouldn't need a whole different function; we could just specify some 'all' split, and if it's that, we use _dataset.node_ids.
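
For illustration, roughly what that 'all' handling could look like (the helper name and the split/node-id plumbing are assumptions for this sketch, not the PR's code):

from typing import Literal, Union

import torch


def _resolve_split_nodes(
    split: Union[Literal["train", "val", "test", "all"], str],
    split_node_ids: dict[str, torch.Tensor],
    all_node_ids: torch.Tensor,
) -> torch.Tensor:
    # Hypothetical helper: an "all" split returns every node id (the
    # equivalent of dataset.node_ids); any other value looks up the
    # named split as before.
    if split == "all":
        return all_node_ids
    return split_node_ids[split]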

Collaborator Author

Hmmmm, I do think that the tuple[Tensor, Tensor, Tensor | None] is important for the ABLP.

I guess I could rename this to get_ablp_input or something? Would that address your concerns?

Collaborator
@mkolodner-sc Jan 21, 2026

Discussed offline: renaming will be fine here for this function, and in a follow-up we will refactor the get_node_ids_on_rank utility so that it can be used for splitting, making it extensible for the SNC use case, and so it can be called in this function to reduce duplication. Can we add a TODO here in the meantime?

Collaborator Author

Sure, added TODOs :)
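
For reference, a sketch of where the thread landed; the exact signature and TODO wording here are illustrative, not the actual diff:

from typing import Literal, Optional, Union

import torch


def get_ablp_input(
    split: Union[Literal["train", "val", "test"], str],
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Sketch of the renamed entry point (tuple layout per the discussion above)."""
    # TODO: refactor get_node_ids_on_rank so it can handle splitting (and the
    # SNC use case) and call it here to reduce duplication with that utility.
    raise NotImplementedError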




f"Anchor nodes must be a torch.Tensor or a dict[NodeType, torch.Tensor], got {type(anchors)}"
)

anchors_for_rank = shard_nodes_by_process(anchor_nodes, rank, world_size)
Collaborator

I remember we had a discussion at some point about potentially moving this up to be user-facing instead of done under the hood by a GiGL utility/class. Would be curious to get your thoughts on this decision here as well.

Collaborator Author

Hmmmm, we did, and I think we agreed we should do less of this, though we still do it in GS mode.

I went with this approach in GS mode so we minimize network traffic.

Since this is how we are currently doing this for RemoteDistDataset and friends, should we keep this approach for now and re-visit in the future, for all node fetching?

Adding some get_for_rank bool flag is probably sufficient here? WDYT about that?
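
A minimal sketch of the get_for_rank idea, assuming a helper that stands in for shard_nodes_by_process so the snippet is self-contained:

import torch


def select_anchor_nodes(
    anchor_nodes: torch.Tensor,
    rank: int,
    world_size: int,
    get_for_rank: bool = True,
) -> torch.Tensor:
    """If get_for_rank, keep today's under-the-hood sharding; else return all anchors."""
    if get_for_rank:
        # Stand-in for shard_nodes_by_process(anchor_nodes, rank, world_size):
        # a simple strided shard so the sketch stays runnable on its own.
        return anchor_nodes[rank::world_size]
    # Otherwise hand everything back and let the client decide what to fetch.
    return anchor_nodes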

Collaborator

I'm fine either way if you'd prefer to do this now or in a follow-up. If it's a follow-up, let's just make sure we add a TODO here.

Collaborator

I had the same qq.
My understanding was that there was some agreement that the client should have the complicated logic of knowing what data to fetch from where.

Now whether the user prompts the client code to fetch exactly what they want, or uses this sort of utility on the client to implicitly fetch some split of nodes, the logic lives inside the client and the server is dumb and just fetches the data requested.

Collaborator

I see the storage layer as a sort of "db" specific to graph sampling.

Complex code to decide what data to provide to which client brings forward some smells; specifically, I think it:

  • Breaks determinism (the operations are implicit and outside of the client's control)
  • Makes queries non-portable, i.e. we are stuck defining queries inside the server vs. client-controlled
  • Couples the query patterns to the runtime topology, which may change in the future, and the coupling will induce extra eng effort in the future to circumvent
  • Potentially breaks the ability to do replica/retry strategies (if needed in the future), as the client doesn't know what data to expect, and alignment between the two will be difficult unless the client hosts the logic.

If we do decide to go this route, I think we should prove out that it is for some sane reason, i.e. actually prove out the network traffic argument. Would love elaboration here.

Collaborator
@mkolodner-sc left a comment

Thanks Kyle for the work!

  class TestRemoteDataset(unittest.TestCase):
      def setUp(self) -> None:
-         """Reset the global dataset before each test."""
+         """Reset the global dataset and initialize process group before each test."""
Collaborator
@svij-sc Jan 23, 2026

and initialize process group before each test.

This comment doesn't seem accurate, i.e. should we be initializing here in this method?
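
To make the question concrete, per-test initialization would roughly look like this (a sketch with a single-process gloo group, not the actual test code):

import unittest

import torch.distributed as dist


class TestRemoteDataset(unittest.TestCase):
    def setUp(self) -> None:
        """Reset the global dataset and initialize process group before each test."""
        # Hypothetical per-test init; the open question above is whether this
        # belongs in setUp at all (vs. once per class or per module).
        if not dist.is_initialized():
            dist.init_process_group(
                backend="gloo",
                init_method="tcp://127.0.0.1:29500",
                rank=0,
                world_size=1,
            )

    def tearDown(self) -> None:
        # Tear the group down so each test starts from a clean state.
        if dist.is_initialized():
            dist.destroy_process_group()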

)


def destroy_test_process_group() -> None:
Collaborator

nit: Lazy abstraction into a function, especially since it's only used in one place.

The split ratios are calculated as:
- num_val = len(val_user_ids) / total_users
- num_test = len(test_user_ids) / total_users
Collaborator

I am a little confused reading this; isn't the user already passing in val_user_ids? Why do we need to recompute the val split if one is being provided?

Creates a dataset with:
- USER nodes: [0, 1, 2, 3, 4]
- STORY nodes: [0, 1, 2, 3, 4]
Collaborator

mega nit: STORY --> ITEM (a more generic term for OSS)


# Set up edge partition books and edge indices
edge_partition_book = {
    _USER_TO_STORY: torch.zeros(5, dtype=torch.int64),
Collaborator

Magic number 5. Is this the number of edges?

Similar magic numbers below re: the edge index.

Creates a dataset with:
- USER nodes: [0, 1, 2, 3, 4]
- STORY nodes: [0, 1, 2, 3, 4]
Collaborator

Should we make these arguments, i.e. user_node_ids, item_node_ids?
Or at least default arguments/constants that can be referenced below to make the function more modular?
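
For example, something along these lines; the names, values, and the placeholder edge-type constant are illustrative, taken from the snippet above rather than the real fixture:

import torch

# Hypothetical fixture constants so the sizes below are named, not magic.
_USER_NODE_IDS = torch.arange(5)  # USER nodes: [0, 1, 2, 3, 4]
_ITEM_NODE_IDS = torch.arange(5)  # STORY/ITEM nodes: [0, 1, 2, 3, 4]
_NUM_EDGES = 5  # assuming the magic 5 above really is the edge count

# Placeholder for the real edge-type constant used in the test module.
_USER_TO_STORY = ("user", "to", "story")

edge_partition_book = {
    _USER_TO_STORY: torch.zeros(_NUM_EDGES, dtype=torch.int64),
}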

f"Anchor nodes must be a torch.Tensor or a dict[NodeType, torch.Tensor], got {type(anchors)}"
)

anchors_for_rank = shard_nodes_by_process(anchor_nodes, rank, world_size)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had the same qq.
My understanding was that there was some agreement where the client should have complicated logic of knowing what data to fetch from where.

Now whether the user prompts the client code to fetch exactly what they want, or uses this sort of utility on the client to implicitly fetch some split of nodes, the logic lives inside the client and the server is dumb and just fetches the data requested.

f"Anchor nodes must be a torch.Tensor or a dict[NodeType, torch.Tensor], got {type(anchors)}"
)

anchors_for_rank = shard_nodes_by_process(anchor_nodes, rank, world_size)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see storage layer as a sort of a "db" specific to graph sampling.

If complex code to decide what data to provide to what client brings forward some smells, speicifcally I think it:

  • Breaks determinism (The operations are implicit, and outside of control from client)
  • Makes queries non-portable i.e. we are stuck with defining queries inside of the server vs client controlled
  • Couples the query patterns to runtime topology; which may change in the future and the coupling will induce extra eng effort int the future to circumvent
  • Potentially break ability to do replica/retry strategies (if needed in future) as the client doesn't know what data to expect and alignment between the two will be difficult unless client hosts logic.

If we do decide to go for this route I think we should prove out that it is for some sane reason i.e. actually proving out the network traffic argument. Would love elaboration here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants