kmontemayor2-sc
Collaborator

@kmontemayor2-sc kmontemayor2-sc commented Jun 26, 2025

Headline change here is to remove the need to do:

loader = DistABLPLoader(input_nodes=to_homogeneous(dataset.train_node_ids), ...)

As demonstrated in _run_cora_supervised and _run_distributed_neighbor_loader_labeled_homogeneous.

Also:

  • Break out the node input parsing logic into a shared util, neighborloader.resolve_node_sampler_input_from_user_input.
  • Add a NodeSamplerInput type alias to be shared between DistABLPLoader and DistNeighborLoader.

Added unit tests for new util.
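
For readers skimming the thread, here is a minimal sketch of the kind of parsing the description refers to. It is not the PR's implementation: the function name `resolve_input_nodes`, the plain-list stand-in for `torch.Tensor`, the placeholder value of `DEFAULT_HOMOGENEOUS_NODE_TYPE`, and the omission of `None` handling are all assumptions made for illustration.

```python
from collections import abc
from typing import Mapping, Optional, Sequence, Tuple, Union

# Placeholder value (assumption); the real constant is defined in GiGL.
DEFAULT_HOMOGENEOUS_NODE_TYPE = "default_homogeneous_node_type"

NodeType = str
NodeIds = Sequence[int]  # stand-in for torch.Tensor in this sketch
UserInput = Union[NodeIds, Tuple[NodeType, NodeIds], Mapping[NodeType, NodeIds]]


def resolve_input_nodes(
    input_nodes: UserInput,
) -> Tuple[Optional[NodeType], NodeIds, bool]:
    """Hypothetical helper: returns (node_type, node_ids, is_labeled_homogeneous)."""
    if isinstance(input_nodes, abc.Mapping):
        # A mapping is only accepted when it has exactly one entry.
        if len(input_nodes) != 1:
            raise ValueError(
                "If input_nodes is provided as a mapping, it must contain "
                f"exactly one key/value pair. Received: {input_nodes}."
            )
        node_type, node_ids = next(iter(input_nodes.items()))
    elif (
        isinstance(input_nodes, tuple)
        and len(input_nodes) == 2
        and isinstance(input_nodes[0], str)
    ):
        # Explicit (node_type, node_ids) pair.
        node_type, node_ids = input_nodes
    else:
        # Bare node IDs: no node type attached.
        node_type, node_ids = None, input_nodes
    is_labeled_homogeneous = node_type == DEFAULT_HOMOGENEOUS_NODE_TYPE
    return node_type, node_ids, is_labeled_homogeneous


# A single-entry dict can be passed as-is, with no manual unwrapping by the caller.
train_node_ids = {DEFAULT_HOMOGENEOUS_NODE_TYPE: [0, 1, 2]}
node_type, node_ids, is_labeled = resolve_input_nodes(train_node_ids)
```

With a helper along these lines, the caller no longer needs to unwrap a single-entry mapping before constructing the loader.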

@kmontemayor2-sc
Collaborator Author

/unit_test

@kmontemayor2-sc
Collaborator Author

/e2e_test

@kmontemayor2-sc
Collaborator Author

/integration_test

Contributor

GiGL Automation

@ 18:32:23UTC : 🔄 Unit Test started.

Contributor

github-actions bot commented Jun 26, 2025

GiGL Automation

@ 18:32:27UTC : 🔄 E2E Test started.

@ 20:00:36UTC : ✅ Workflow completed successfully.

Contributor

github-actions bot commented Jun 26, 2025

GiGL Automation

@ 18:32:32UTC : 🔄 Integration Test started.

@ 19:20:47UTC : ✅ Workflow completed successfully.

@kmontemayor2-sc
Collaborator Author

/unit_test

Contributor

GiGL Automation

@ 20:14:09UTC : 🔄 Unit Test started.

@kmontemayor2-sc kmontemayor2-sc changed the title from "wip to allow dict inputs" to "Remove the need for to_homogeneous(dataset.train_node_ids) for Labeled homogeneous inputs" Jun 26, 2025
Comment on lines +183 to +189
elif isinstance(input_nodes, abc.Mapping):
    if len(input_nodes) != 1:
        raise ValueError(
            f"If input_nodes is provided as a mapping, it must contain exactly one key/value pair. Received: {input_nodes}. This may happen if you call Loader(node_ids=dataset.node_ids) with a heterogeneous dataset."
        )
    node_type, node_ids = next(iter(input_nodes.items()))
    is_labeled_homogeneous = node_type == DEFAULT_HOMOGENEOUS_NODE_TYPE
Collaborator Author

FYI I'm ok with not adding this change - we can just add the util if there's pushback here.

Collaborator

pushback.
This can happen outside this function for a cleaner interface

Collaborator

@mkolodner-sc mkolodner-sc left a comment

Thanks Kyle!



@dataclass(frozen=True)
class _ResolvedNodeSamplerInput:
Collaborator

@mkolodner-sc mkolodner-sc Jul 1, 2025

I'd prefer not to add another NodeSamplerInput derivative if we can avoid it, for the sake of reducing the complexity of our codebase. Is the value here primarily the is_labeled_homogeneous field? Is there any way we can just have resolve_node_sampler_input_from_user_input return a Tuple[NodeSamplerInput, bool] or even a Tuple[NodeType, torch.Tensor, bool], since it seems like the only place this is used is in the __init__ of the ABLPLoader?

Collaborator Author

Yeah, I wanted to do this to avoid the three-element tuple.

I understand the apprehension here about the new dataclass though - do you think renaming it could help? Something like _ParsedInputs or equivalent?
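
To make the trade-off in this thread concrete, here is a small sketch contrasting a frozen dataclass with a positional three-element tuple. The class name `ParsedInputs` follows the rename floated above; the field names other than `is_labeled_homogeneous`, the `as_tuple` helper, and the list stand-in for `torch.Tensor` are assumptions for illustration only.

```python
from dataclasses import FrozenInstanceError, dataclass
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class ParsedInputs:
    """Hypothetical result type; fields are read by name at the call site."""

    node_type: Optional[str]
    node_ids: Sequence[int]  # stand-in for torch.Tensor in this sketch
    is_labeled_homogeneous: bool


def as_tuple(parsed: ParsedInputs) -> Tuple[Optional[str], Sequence[int], bool]:
    """Tuple alternative: callers must remember the positional order."""
    return parsed.node_type, parsed.node_ids, parsed.is_labeled_homogeneous


parsed = ParsedInputs(node_type="user", node_ids=[1, 2, 3], is_labeled_homogeneous=False)

# Dataclass access is self-describing.
flag_by_name = parsed.is_labeled_homogeneous

# Tuple access depends on position; swapping two names here would go unnoticed.
node_type, node_ids, flag_by_position = as_tuple(parsed)

# frozen=True makes instances immutable after construction.
try:
    parsed.node_type = "item"  # type: ignore[misc]
    mutated = True
except FrozenInstanceError:
    mutated = False
```

The dataclass costs one extra type in the codebase; the tuple costs readability at each unpacking site. Which matters more is the judgment call being discussed here.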

Collaborator

@mkolodner-sc mkolodner-sc left a comment

Thanks! LGTM provided comments are addressed

@kmontemayor2-sc
Collaborator Author

/unit_test

Contributor

github-actions bot commented Jul 1, 2025

GiGL Automation

@ 20:15:15UTC : 🔄 Unit Test started.

@ 20:52:38UTC : ✅ Workflow completed successfully.

input_nodes: Optional[
    Union[
        torch.Tensor,
        tuple[NodeType, torch.Tensor],
Collaborator

You might have to run make format - there are unused imports.

# If a dict is provided, the keys are node types and the values are tensors of node IDs.
# If a dict is provided, the graph must be heterogeneous, and there must be only one key/value pair in the dict.
# We allow dicts to be passed in as a convenience for users who have a heterogeneous graph with only one supervision edge type.
NodeSamplerInput = Optional[
Collaborator

Since this is used as input for multiple methods, let's define it in gigl.typing or gigl.types.
We can probably define all our types there, similar to pyg and torch respectively.
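
A rough sketch of what this suggestion could look like follows. The alias name `InputNodes` is taken from the rename proposed later in the thread; the toy loader classes and the list stand-in for `torch.Tensor` are assumptions, and in a real layout the alias block would live in the shared typing module rather than alongside the loaders.

```python
from typing import Mapping, Optional, Sequence, Tuple, Union

# --- would live in a shared typing module (module name is an assumption) ---
NodeType = str
NodeIds = Sequence[int]  # stand-in for torch.Tensor in this sketch
InputNodes = Optional[
    Union[
        NodeIds,
        Tuple[NodeType, NodeIds],
        # A mapping is allowed only with a single key/value pair.
        Mapping[NodeType, NodeIds],
    ]
]


# --- two hypothetical loaders sharing the same alias in their signatures ---
class ToyABLPLoader:
    def __init__(self, input_nodes: InputNodes = None) -> None:
        self.input_nodes = input_nodes


class ToyNeighborLoader:
    def __init__(self, input_nodes: InputNodes = None) -> None:
        self.input_nodes = input_nodes


loader = ToyABLPLoader(input_nodes={"user": [1, 2, 3]})
```

Keeping the alias in one place means a change to the accepted input forms is made once and picked up by every loader signature.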

Collaborator

Ideally, we aim to match the PyG API as closely as possible here.

"""Resolves the input nodes for a node sampler.
This function takes the user input for input nodes and resolves it to a consistent format.

See the comment above NodeSamplerInput for the allowed inputs.
Collaborator

Suggestion for generating better documentation: this will link to the class directly.

Suggested change
- See the comment above NodeSamplerInput for the allowed inputs.
+ See the comment above :py:obj:`gigl.distributed.utils.neighborloader.NodeSamplerInput` for the allowed inputs.

]


def resolve_node_sampler_input_from_user_input(
Collaborator

similar to renaming suggestion above:

def parse_input_nodes(input_nodes: gigl.typing.InputNodes) -> ...

?


4 participants