Remove the need for to_homogeneous(dataset.train_node_ids) for Labeled homogeneous inputs #115
/unit_test
/e2e_test
/integration_test
GiGL Automation @ 18:32:23UTC: 🔄
GiGL Automation @ 18:32:27UTC: 🔄 @ 20:00:36UTC: ✅ Workflow completed successfully.
GiGL Automation @ 18:32:32UTC: 🔄 @ 19:20:47UTC: ✅ Workflow completed successfully.
/unit_test
GiGL Automation @ 20:14:09UTC: 🔄
elif isinstance(input_nodes, abc.Mapping):
    if len(input_nodes) != 1:
        raise ValueError(
            f"If input_nodes is provided as a mapping, it must contain exactly one key/value pair. Received: {input_nodes}. This may happen if you call Loader(node_ids=dataset.node_ids) with a heterogeneous dataset."
        )
    node_type, node_ids = next(iter(input_nodes.items()))
    is_labeled_homogeneous = node_type == DEFAULT_HOMOGENEOUS_NODE_TYPE
FYI I'm ok with not adding this change - we can just add the util if there's pushback here.
pushback.
This can happen outside this function for a cleaner interface
Thanks Kyle!
@dataclass(frozen=True)
class _ResolvedNodeSamplerInput:
I'd prefer not to add another NodeSamplerInput derivative if we can avoid it, for the sake of reducing the complexity of our codebase. Is the value here primarily the is_labeled_homogeneous field? Is there any way we can just have resolve_node_sampler_input_from_user_input return a Tuple[NodeSamplerInput, bool] or even a Tuple[NodeType, torch.Tensor, bool], since it seems like the only place this is used is in the __init__ of the ABLPLoader?
Yeah, I wanted to do this to avoid the three-element tuple.
I understand the apprehension here about the new dataclass, though. Do you think renaming it could help? Something like _ParsedInputs or equivalent?
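For reference, a minimal sketch of the two shapes being weighed here. This is not the code in the PR: the field names other than is_labeled_homogeneous are hypothetical, and a plain list stands in for torch.Tensor so the snippet runs without torch.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple

# Stand-in for torch.Tensor so the sketch runs without torch (assumption).
NodeIds = Sequence[int]


@dataclass(frozen=True)
class _ParsedInputs:
    # node_type and node_ids are hypothetical field names;
    # is_labeled_homogeneous is the field named in the review thread.
    node_type: Optional[str]
    node_ids: NodeIds
    is_labeled_homogeneous: bool


def as_tuple(parsed: _ParsedInputs) -> Tuple[Optional[str], NodeIds, bool]:
    """The three-element tuple alternative: same data, positional access only."""
    return (parsed.node_type, parsed.node_ids, parsed.is_labeled_homogeneous)


parsed = _ParsedInputs(node_type="user", node_ids=[0, 1, 2], is_labeled_homogeneous=False)

# Dataclass: call sites read fields by name.
flag_by_name = parsed.is_labeled_homogeneous

# Tuple: call sites must remember the element order.
node_type, node_ids, flag_by_position = as_tuple(parsed)
```

The dataclass costs one extra private type; the tuple costs readability at the single call site. Either carries the same three values.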
Thanks! LGTM provided comments are addressed
/unit_test
GiGL Automation @ 20:15:15UTC: 🔄 @ 20:52:38UTC: ✅ Workflow completed successfully.
input_nodes: Optional[
    Union[
        torch.Tensor,
        tuple[NodeType, torch.Tensor],
You might have to run make format; there are unused imports.
# If a dict is provided, the keys are node types and the values are tensors of node IDs.
# If a dict is provided, the graph must be heterogeneous, and there must be only one key/value pair in the dict.
# We allow dicts to be passed in as a convenience for users who have a heterogeneous graph with only one supervision edge type.
NodeSamplerInput = Optional[
Since this is used as input for multiple methods, let's define it in gigl.typing or gigl.types.
We can probably define all our types there, similar to pyg and torch respectively.
Should we also call this InputNodes instead?
Ideally, we should work towards matching the pyg API as closely as possible here.
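A rough sketch of what a shared alias could look like if it lived in one typing module and were imported by both loaders. Everything here is an assumption for illustration: the alias name InputNodes follows the suggestion above, the third union member (a mapping) is inferred from the abc.Mapping branch in the diff, plain list and str stand in for torch.Tensor and NodeType, and the two loader classes are placeholders rather than DistNeighborLoader or DistABLPLoader themselves.

```python
from typing import Mapping, Optional, Tuple, Union

# Stand-ins so the sketch runs without torch or gigl installed (assumptions).
Tensor = list   # in place of torch.Tensor
NodeType = str  # in place of GiGL's NodeType

# Hypothetical shared alias, e.g. defined once in a gigl.typing module.
InputNodes = Optional[
    Union[
        Tensor,
        Tuple[NodeType, Tensor],
        Mapping[NodeType, Tensor],
    ]
]


class NeighborLoaderSketch:
    """Placeholder loader; only shows the shared annotation."""

    def __init__(self, input_nodes: InputNodes = None) -> None:
        self.input_nodes = input_nodes


class ABLPLoaderSketch:
    """Second placeholder loader reusing the same alias."""

    def __init__(self, input_nodes: InputNodes = None) -> None:
        self.input_nodes = input_nodes


homogeneous = NeighborLoaderSketch([0, 1, 2])
single_entry_mapping = ABLPLoaderSketch({"user": [0, 1, 2]})
```

Defining the alias once keeps the two signatures from drifting apart, which is the point of moving it out of the loader utilities.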
    """Resolves the input nodes for a node sampler.
    This function takes the user input for input nodes and resolves it to a consistent format.

    See the comment above NodeSamplerInput for the allowed inputs.
Suggestion for generating better documentation, i.e. this will link to the class directly:
- See the comment above NodeSamplerInput for the allowed inputs.
+ See the comment above :py:obj:`gigl.distributed.utils.neighborloader.NodeSamplerInput` for the allowed inputs.
]


def resolve_node_sampler_input_from_user_input(
Similar to the renaming suggestion above:
def parse_input_nodes(input_nodes: gigl.typing.InputNodes) -> ...
?
The headline change here is to remove the need to do to_homogeneous(dataset.train_node_ids) for labeled homogeneous inputs, as demonstrated in _run_cora_supervised [1] and _run_distributed_neighbor_loader_labeled_homogeneous [2].
Also:
- neighborloader.resolve_node_sampler_input_from_user_input
- NodeSamplerInput type alias to be shared between DistABLPLoader and DistNeighborLoader.
- Added unit tests for new util.
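To make the headline change concrete, here is a simplified sketch of the resolution step, based on the mapping branch quoted in the review. It is not the actual resolve_node_sampler_input_from_user_input: the function name resolve_input_nodes, the tuple return shape, the handling of the tuple and bare-tensor cases, and the placeholder value of DEFAULT_HOMOGENEOUS_NODE_TYPE are all assumptions, and a plain list stands in for torch.Tensor.

```python
from collections import abc
from typing import Optional, Tuple

NodeIds = list  # stand-in for torch.Tensor (assumption)
NodeType = str  # stand-in for GiGL's NodeType (assumption)

# Placeholder value; the real constant lives in GiGL.
DEFAULT_HOMOGENEOUS_NODE_TYPE = "default_homogeneous_node_type"


def resolve_input_nodes(input_nodes) -> Tuple[Optional[NodeType], NodeIds, bool]:
    """Return (node_type, node_ids, is_labeled_homogeneous)."""
    if isinstance(input_nodes, abc.Mapping):
        # Mirrors the check in the diff: exactly one key/value pair is allowed.
        if len(input_nodes) != 1:
            raise ValueError(
                "If input_nodes is provided as a mapping, it must contain "
                f"exactly one key/value pair. Received: {input_nodes}."
            )
        node_type, node_ids = next(iter(input_nodes.items()))
        return node_type, node_ids, node_type == DEFAULT_HOMOGENEOUS_NODE_TYPE
    if isinstance(input_nodes, tuple):
        # (node_type, node_ids) pair; treated like a one-entry mapping here.
        node_type, node_ids = input_nodes
        return node_type, node_ids, node_type == DEFAULT_HOMOGENEOUS_NODE_TYPE
    # Bare node IDs: a plain homogeneous input with no node type.
    return None, input_nodes, False


# A labeled homogeneous dataset exposes its IDs as a one-entry mapping, so the
# caller can pass it straight through instead of converting it first.
train_node_ids = {DEFAULT_HOMOGENEOUS_NODE_TYPE: [0, 1, 2]}
node_type, node_ids, is_labeled_homogeneous = resolve_input_nodes(train_node_ids)
```

With this shape, the caller no longer unwraps the mapping by hand; the loader detects the default node type and records that the input was labeled homogeneous.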