-
Notifications
You must be signed in to change notification settings - Fork 6
Remove the need for to_homogeneous(dataset.train_node_ids)
for Labeled homogeneous inputs
#115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
64fac2d
463f60f
d88276d
df28cfd
f8bf0bd
2a59b95
bf5f1c7
e43a93a
a80eea9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,13 +1,15 @@ | ||||||
"""Utils for Neighbor loaders.""" | ||||||
from collections import abc | ||||||
from copy import deepcopy | ||||||
from typing import Union | ||||||
from typing import Optional, Union | ||||||
|
||||||
import torch | ||||||
from torch_geometric.data import Data, HeteroData | ||||||
from torch_geometric.typing import EdgeType | ||||||
|
||||||
from gigl.common.logger import Logger | ||||||
from gigl.types.graph import is_label_edge_type | ||||||
from gigl.src.common.types.graph_data import NodeType | ||||||
from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE, is_label_edge_type | ||||||
|
||||||
logger = Logger() | ||||||
|
||||||
|
@@ -118,3 +120,84 @@ def strip_label_edges(data: HeteroData) -> HeteroData: | |||||
del data.num_sampled_edges[edge_type] | ||||||
|
||||||
return data | ||||||
|
||||||
|
||||||
# Allowed inputs for node samplers. | ||||||
# If None is provded, then all nodes in the graph will be sampled. | ||||||
# And the graph must be homogeneous. | ||||||
# If a single tensor is provided, it is assumed to be a tensor of node IDs. | ||||||
# And the graph must be homogeneous, or labled homogeneous. | ||||||
# If a tuple is provided, the first element is the node type and the second element is the tensor of node IDs. | ||||||
# If a dict is provided, the keys are node types and the values are tensors of node IDs. | ||||||
# If a dict is provided, the graph must be heterogeneous, and there must be only one key/value pair in the dict. | ||||||
# We allow dicts to be passed in as a convenenience for users who have a heterogeneous graph with only one supervision edge type. | ||||||
NodeSamplerInput = Optional[ | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this is used as input for multiple methods, lets define it in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we also call this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ideally we aim to work towards supporting the pyg api as close as possible here. |
||||||
Union[ | ||||||
torch.Tensor, tuple[NodeType, torch.Tensor], abc.Mapping[NodeType, torch.Tensor] | ||||||
] | ||||||
] | ||||||
|
||||||
|
||||||
def resolve_node_sampler_input_from_user_input( | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. similar to renaming suggestion above: def parse_input_nodes(input_nodes: gigl.typing.InputNodes) -> ... ? |
||||||
input_nodes: NodeSamplerInput, | ||||||
dataset_nodes: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]], | ||||||
) -> tuple[Optional[NodeType], torch.Tensor, bool]: | ||||||
"""Resolves the input nodes for a node sampler. | ||||||
This function takes the user input for input nodes and resolves it to a consistent format. | ||||||
|
||||||
See the comment above NodeSamplerInput for the allowed inputs. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suggestion for generating better documentation i.e. this will link to the class directly.
Suggested change
|
||||||
|
||||||
Args: | ||||||
input_nodes (NodeSamplerInput): The input nodes provided by the user. | ||||||
dataset_nodes (Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]): The nodes in the dataset. | ||||||
|
||||||
Returns: | ||||||
tuple[NodeType, torch.Tensor, bool]: A tuple containing: | ||||||
- node_type (NodeType): The type of the nodes. | ||||||
- node_ids (torch.Tensor): The tensor of node IDs. | ||||||
- is_labeled_homogeneous (bool): Whether the dataset is a labeled homogeneous graph. | ||||||
""" | ||||||
is_labeled_homoogeneous = False | ||||||
if isinstance(input_nodes, torch.Tensor): | ||||||
node_ids = input_nodes | ||||||
|
||||||
# If the dataset is heterogeneous, we may be in the "labeled homogeneous" setting, | ||||||
# if so, then we should use DEFAULT_HOMOGENEOUS_NODE_TYPE. | ||||||
if isinstance(dataset_nodes, dict): | ||||||
if ( | ||||||
len(dataset_nodes) == 1 | ||||||
and DEFAULT_HOMOGENEOUS_NODE_TYPE in dataset_nodes | ||||||
): | ||||||
node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE | ||||||
is_labeled_homoogeneous = True | ||||||
else: | ||||||
raise ValueError( | ||||||
f"For heterogeneous datasets, input_nodes must be a tuple of (node_type, node_ids) OR if it is a labeled homogeneous dataset, input_nodes may be a torch.Tensor. Received node types: {dataset_nodes.keys()}" | ||||||
) | ||||||
else: | ||||||
node_type = None | ||||||
elif isinstance(input_nodes, abc.Mapping): | ||||||
if len(input_nodes) != 1: | ||||||
raise ValueError( | ||||||
f"If input_nodes is provided as a mapping, it must contain exactly one key/value pair. Received: {input_nodes}. This may happen if you call Loader(node_ids=dataset.node_ids) with a heterogeneous dataset." | ||||||
) | ||||||
node_type, node_ids = next(iter(input_nodes.items())) | ||||||
is_labeled_homoogeneous = node_type == DEFAULT_HOMOGENEOUS_NODE_TYPE | ||||||
Comment on lines
+179
to
+185
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FYI I'm ok with not adding this change - we can just add the util if there's pushback here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. pushback. |
||||||
elif isinstance(input_nodes, tuple): | ||||||
node_type, node_ids = input_nodes | ||||||
elif input_nodes is None: | ||||||
if dataset_nodes is None: | ||||||
raise ValueError("If input_nodes is None, the dataset must have node ids.") | ||||||
if isinstance(dataset_nodes, torch.Tensor): | ||||||
node_type = None | ||||||
node_ids = dataset_nodes | ||||||
elif isinstance(dataset_nodes, dict): | ||||||
raise ValueError( | ||||||
f"Input nodes must be provided for a heterogeneous graph. Received: {dataset_nodes}" | ||||||
) | ||||||
|
||||||
return ( | ||||||
node_type, | ||||||
node_ids, | ||||||
is_labeled_homoogeneous, | ||||||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you might have to run
make format
- unused imports