Skip to content

Remove the need for to_homogeneous(dataset.train_node_ids) for Labeled homogeneous inputs #115

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 33 additions & 43 deletions python/gigl/distributed/dist_ablp_neighborloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
from gigl.distributed.distributed_neighborloader import DEFAULT_NUM_CPU_THREADS
from gigl.distributed.sampler import ABLPNodeSamplerInput
from gigl.distributed.utils.neighborloader import (
NodeSamplerInput,
labeled_to_homogeneous,
patch_fanout_for_sampling,
resolve_node_sampler_input_from_user_input,
shard_nodes_by_process,
strip_label_edges,
)
Expand All @@ -45,12 +47,7 @@ def __init__(
self,
dataset: DistLinkPredictionDataset,
num_neighbors: Union[list[int], dict[EdgeType, list[int]]],
input_nodes: Optional[
Union[
torch.Tensor,
tuple[NodeType, torch.Tensor],
]
] = None,
input_nodes: NodeSamplerInput = None,
# TODO(kmonte): Support multiple supervision edge types.
supervision_edge_type: Optional[EdgeType] = None,
num_workers: int = 1,
Expand Down Expand Up @@ -118,11 +115,9 @@ def __init__(
context (DistributedContext): Distributed context information of the current process.
local_process_rank (int): The local rank of the current process within a node.
local_process_world_size (int): The total number of processes within a node.
input_nodes (Optional[torch.Tensor, tuple[NodeType, torch.Tensor]]):
input_nodes (NodeSamplerInput):
Indices of seed nodes to start sampling from.
If set to `None` for homogeneous settings, all nodes will be considered.
In heterogeneous graphs, this flag must be passed in as a tuple that holds
the node type and node indices. (default: `None`)
See documentation for `gigl.distributed.utils.neighborloader.NodeSamplerInput` for more details.
num_workers (int): How many workers to use (subprocesses to spawn) for
distributed neighbor sampling of the current process. (default: ``1``).
batch_size (int, optional): how many samples per batch to load
Expand Down Expand Up @@ -235,50 +230,41 @@ def __init__(
f"The dataset must be heterogeneous for ABLP. Recieved dataset with graph of type: {type(dataset.graph)}"
)
self._is_input_heterogeneous: bool = False
if isinstance(input_nodes, tuple):
if supervision_edge_type is None:
raise ValueError(
"When using heterogeneous ABLP, you must provide supervision_edge_types."
)
self._is_input_heterogeneous = True
anchor_node_type, anchor_node_ids = input_nodes
# TODO (mkolodner-sc): We currently assume supervision edges are directed outward, revisit in future if
# this assumption is no longer valid and/or is too opinionated
assert (
supervision_edge_type[0] == anchor_node_type
), f"Label EdgeType are currently expected to be provided in outward edge direction as tuple (`anchor_node_type`,`relation`,`supervision_node_type`), \
got supervision edge type {supervision_edge_type} with anchor node type {anchor_node_type}"
supervision_node_type = supervision_edge_type[2]
if dataset.edge_dir == "in":
supervision_edge_type = reverse_edge_type(supervision_edge_type)
(
anchor_node_type,
anchor_node_ids,
self._is_labeled_homogeneous,
) = resolve_node_sampler_input_from_user_input(
input_nodes=input_nodes,
dataset_nodes=dataset.node_ids,
)

elif isinstance(input_nodes, torch.Tensor):
if (
anchor_node_type is None
or anchor_node_type == DEFAULT_HOMOGENEOUS_NODE_TYPE
):
if supervision_edge_type is not None:
raise ValueError(
f"Expected supervision edge type to be None for homogeneous input nodes, got {supervision_edge_type}"
)
anchor_node_ids = input_nodes
anchor_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
supervision_edge_type = DEFAULT_HOMOGENEOUS_EDGE_TYPE
supervision_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
elif input_nodes is None:
if dataset.node_ids is None:
raise ValueError(
"Dataset must have node ids if input_nodes are not provided."
)
if isinstance(dataset.node_ids, abc.Mapping):
else:
if supervision_edge_type is None:
raise ValueError(
f"input_nodes must be provided for heterogeneous datasets, received node_ids of type: {dataset.node_ids.keys()}"
"When using heterogeneous ABLP, you must provide supervision_edge_type."
)
if supervision_edge_type is not None:
self._is_input_heterogeneous = True
# TODO (mkolodner-sc): We currently assume supervision edges are directed outward, revisit in future if
# this assumption is no longer valid and/or is too opinionated
if supervision_edge_type[0] != anchor_node_type:
raise ValueError(
f"Expected supervision edge type to be None for homogeneous input nodes, got {supervision_edge_type}"
f"Label EdgeType are currently expected to be provided in outward edge direction as tuple (`anchor_node_type`,`relation`,`supervision_node_type`), \
got supervision edge type {supervision_edge_type} with anchor node type {anchor_node_type}"
)

anchor_node_ids = dataset.node_ids
anchor_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
supervision_edge_type = DEFAULT_HOMOGENEOUS_EDGE_TYPE
supervision_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
supervision_node_type = supervision_edge_type[2]
if dataset.edge_dir == "in":
supervision_edge_type = reverse_edge_type(supervision_edge_type)

missing_edge_types = set([supervision_edge_type]) - set(dataset.graph.keys())
if missing_edge_types:
Expand Down Expand Up @@ -542,6 +528,10 @@ def _set_labels(
local_node_to_global_node: torch.Tensor
# shape [N], where N is the number of nodes in the subgraph, and local_node_to_global_node[i] gives the global node id for local node id `i`
if isinstance(data, HeteroData):
if self._supervision_edge_type is None:
raise ValueError(
"When using heterogeneous ABLP, you must provide supervision_edge_type."
)
supervision_node_type = (
self._supervision_edge_type[0]
if self.edge_dir == "in"
Expand Down
74 changes: 28 additions & 46 deletions python/gigl/distributed/distributed_neighborloader.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from collections import Counter, abc
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Union

import torch
from graphlearn_torch.channel import SampleMessage
from graphlearn_torch.distributed import DistLoader, MpDistSamplingWorkerOptions
from graphlearn_torch.sampler import NodeSamplerInput, SamplingConfig, SamplingType
from graphlearn_torch.sampler import NodeSamplerInput as GLTNodeSamplerInput
from graphlearn_torch.sampler import SamplingConfig, SamplingType
from torch_geometric.data import Data, HeteroData
from torch_geometric.typing import EdgeType

Expand All @@ -14,18 +15,14 @@
from gigl.distributed.dist_context import DistributedContext
from gigl.distributed.dist_link_prediction_dataset import DistLinkPredictionDataset
from gigl.distributed.utils.neighborloader import (
NodeSamplerInput,
labeled_to_homogeneous,
patch_fanout_for_sampling,
resolve_node_sampler_input_from_user_input,
shard_nodes_by_process,
strip_label_edges,
)
from gigl.src.common.types.graph_data import (
NodeType, # TODO (mkolodner-sc): Change to use torch_geometric.typing
)
from gigl.types.graph import (
DEFAULT_HOMOGENEOUS_EDGE_TYPE,
DEFAULT_HOMOGENEOUS_NODE_TYPE,
)
from gigl.types.graph import DEFAULT_HOMOGENEOUS_EDGE_TYPE

logger = Logger()

Expand All @@ -38,9 +35,7 @@ def __init__(
self,
dataset: DistLinkPredictionDataset,
num_neighbors: Union[List[int], Dict[EdgeType, List[int]]],
input_nodes: Optional[
Union[torch.Tensor, Tuple[NodeType, torch.Tensor]]
] = None,
input_nodes: NodeSamplerInput = None,
num_workers: int = 1,
batch_size: int = 1,
context: Optional[DistributedContext] = None, # TODO: (svij) Deprecate this
Expand Down Expand Up @@ -70,12 +65,9 @@ def __init__(
context (deprecated - will be removed soon) (DistributedContext): Distributed context information of the current process.
local_process_rank (deprecated - will be removed soon) (int): Required if context provided. The local rank of the current process within a node.
local_process_world_size (deprecated - will be removed soon)(int): Required if context provided. The total number of processes within a node.
input_nodes (torch.Tensor or Tuple[str, torch.Tensor]): The
indices of seed nodes to start sampling from.
It is of type `torch.LongTensor` for homogeneous graphs.
If set to `None` for homogeneous settings, all nodes will be considered.
In heterogeneous graphs, this flag must be passed in as a tuple that holds
the node type and node indices. (default: `None`)
input_nodes (NodeSamplerInput):
Indices of seed nodes to start sampling from.
See documentation for `gigl.distributed.utils.neighborloader.NodeSamplerInput` for more details.
num_workers (int): How many workers to use (subprocesses to spawn) for
distributed neighbor sampling of the current process. (default: ``1``).
batch_size (int, optional): how many samples per batch to load
Expand Down Expand Up @@ -222,38 +214,28 @@ def __init__(
)

# Determines if the node ids passed in are heterogeneous or homogeneous.
self._is_labeled_heterogeneous = False
if isinstance(input_nodes, torch.Tensor):
node_ids = input_nodes

# If the dataset is heterogeneous, we may be in the "labeled homogeneous" setting,
# if so, then we should use DEFAULT_HOMOGENEOUS_NODE_TYPE.
if isinstance(dataset.node_ids, abc.Mapping):
if (
len(dataset.node_ids) == 1
and DEFAULT_HOMOGENEOUS_NODE_TYPE in dataset.node_ids
):
node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
self._is_labeled_heterogeneous = True
num_neighbors = patch_fanout_for_sampling(
dataset.get_edge_types(), num_neighbors
)
else:
raise ValueError(
f"For heterogeneous datasets, input_nodes must be a tuple of (node_type, node_ids) OR if it is a labeled homogeneous dataset, input_nodes may be a torch.Tensor. Received node types: {dataset.node_ids.keys()}"
)
else:
node_type = None
else:
node_type, node_ids = input_nodes

(
anchor_node_type,
anchor_node_ids,
self._is_labeled_homogeneous,
) = resolve_node_sampler_input_from_user_input(
input_nodes=input_nodes,
dataset_nodes=dataset.node_ids,
)
if self._is_labeled_homogeneous:
# If the dataset is labeled homogeneous, we need to patch the fanout for sampling.
num_neighbors = patch_fanout_for_sampling(
dataset.get_edge_types(), num_neighbors
)
curr_process_nodes = shard_nodes_by_process(
input_nodes=node_ids,
input_nodes=anchor_node_ids,
local_process_rank=local_rank,
local_process_world_size=local_world_size,
)

input_data = NodeSamplerInput(node=curr_process_nodes, input_type=node_type)
input_data = GLTNodeSamplerInput(
node=curr_process_nodes, input_type=anchor_node_type
)

# Sets up processes and torch device for initializing the GLT DistNeighborLoader, setting up RPC and worker groups to minimize
# the memory overhead and CPU contention.
Expand Down Expand Up @@ -343,6 +325,6 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]:
if isinstance(data, HeteroData):
data = strip_label_edges(data)

if self._is_labeled_heterogeneous:
if self._is_labeled_homogeneous:
data = labeled_to_homogeneous(DEFAULT_HOMOGENEOUS_EDGE_TYPE, data)
return data
87 changes: 85 additions & 2 deletions python/gigl/distributed/utils/neighborloader.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Utils for Neighbor loaders."""
from collections import abc
from copy import deepcopy
from typing import Union
from typing import Optional, Union

import torch
from torch_geometric.data import Data, HeteroData
from torch_geometric.typing import EdgeType

from gigl.common.logger import Logger
from gigl.types.graph import is_label_edge_type
from gigl.src.common.types.graph_data import NodeType
from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE, is_label_edge_type

logger = Logger()

Expand Down Expand Up @@ -118,3 +120,84 @@ def strip_label_edges(data: HeteroData) -> HeteroData:
del data.num_sampled_edges[edge_type]

return data


# Allowed inputs for node samplers.
# If None is provided, then all nodes in the graph will be sampled,
# and the graph must be homogeneous.
# If a single tensor is provided, it is assumed to be a tensor of node IDs,
# and the graph must be homogeneous or labeled homogeneous.
# If a tuple is provided, the first element is the node type and the second element is the tensor of node IDs.
# If a dict is provided, the keys are node types and the values are tensors of node IDs.
# If a dict is provided, the graph must be heterogeneous, and there must be only one key/value pair in the dict.
# We allow dicts to be passed in as a convenience for users who have a heterogeneous graph with only one supervision edge type.
NodeSamplerInput = Optional[
    Union[
        torch.Tensor, tuple[NodeType, torch.Tensor], abc.Mapping[NodeType, torch.Tensor]
    ]
]


def resolve_node_sampler_input_from_user_input(
    input_nodes: NodeSamplerInput,
    dataset_nodes: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]],
) -> tuple[Optional[NodeType], torch.Tensor, bool]:
    """Resolve user-provided seed nodes into a consistent ``(type, ids, flag)`` form.

    Accepted shapes of ``input_nodes``:

    - ``torch.Tensor``: node IDs. The dataset must be homogeneous, or "labeled
      homogeneous" (``dataset_nodes`` is a mapping whose only key is
      ``DEFAULT_HOMOGENEOUS_NODE_TYPE``).
    - Mapping with exactly one ``{node_type: node_ids}`` entry.
    - ``(node_type, node_ids)`` tuple.
    - ``None``: fall back to ``dataset_nodes``, which must then be a tensor.

    Args:
        input_nodes (NodeSamplerInput): The input nodes provided by the user.
        dataset_nodes (Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]]):
            The nodes in the dataset.

    Returns:
        tuple[Optional[NodeType], torch.Tensor, bool]: A tuple containing:
            - node_type (Optional[NodeType]): The type of the nodes, or ``None``
              when the input resolves against a homogeneous dataset.
            - node_ids (torch.Tensor): The tensor of node IDs.
            - is_labeled_homogeneous (bool): Whether the input was resolved as a
              labeled homogeneous graph.

    Raises:
        ValueError: If a tensor is given for a heterogeneous dataset that is not
            labeled homogeneous, if a mapping does not hold exactly one entry,
            or if ``input_nodes`` is ``None`` and ``dataset_nodes`` is ``None``
            or a mapping.
        TypeError: If ``input_nodes`` (or, when ``input_nodes`` is ``None``,
            ``dataset_nodes``) is of an unsupported type.
    """
    if isinstance(input_nodes, torch.Tensor):
        # Check against abc.Mapping (not just dict) so that any mapping-typed
        # dataset is treated as heterogeneous, matching the input_nodes check below.
        if not isinstance(dataset_nodes, abc.Mapping):
            return None, input_nodes, False
        # If the dataset is heterogeneous, we may be in the "labeled homogeneous" setting,
        # if so, then we should use DEFAULT_HOMOGENEOUS_NODE_TYPE.
        if len(dataset_nodes) == 1 and DEFAULT_HOMOGENEOUS_NODE_TYPE in dataset_nodes:
            return DEFAULT_HOMOGENEOUS_NODE_TYPE, input_nodes, True
        raise ValueError(
            f"For heterogeneous datasets, input_nodes must be a tuple of (node_type, node_ids) OR if it is a labeled homogeneous dataset, input_nodes may be a torch.Tensor. Received node types: {dataset_nodes.keys()}"
        )

    if isinstance(input_nodes, abc.Mapping):
        # NOTE(review): the PR author flagged this mapping convenience as optional;
        # it may be replaced by a standalone util if reviewers push back.
        if len(input_nodes) != 1:
            # Report only the node types: the values are (potentially huge) ID tensors.
            raise ValueError(
                f"If input_nodes is provided as a mapping, it must contain exactly one key/value pair. Received node types: {list(input_nodes.keys())}. This may happen if you call Loader(input_nodes=dataset.node_ids) with a heterogeneous dataset."
            )
        node_type, node_ids = next(iter(input_nodes.items()))
        return node_type, node_ids, node_type == DEFAULT_HOMOGENEOUS_NODE_TYPE

    if isinstance(input_nodes, tuple):
        node_type, node_ids = input_nodes
        return node_type, node_ids, False

    if input_nodes is None:
        if dataset_nodes is None:
            raise ValueError("If input_nodes is None, the dataset must have node ids.")
        if isinstance(dataset_nodes, torch.Tensor):
            return None, dataset_nodes, False
        if isinstance(dataset_nodes, abc.Mapping):
            raise ValueError(
                f"Input nodes must be provided for a heterogeneous graph. Received node types: {list(dataset_nodes.keys())}"
            )
        # Fail loudly instead of falling through with unbound locals.
        raise TypeError(
            f"Unsupported type for dataset node ids: {type(dataset_nodes)}. Expected a torch.Tensor or a mapping of node type to torch.Tensor."
        )

    # Fail loudly instead of falling through with unbound locals.
    raise TypeError(
        f"Unsupported type for input_nodes: {type(input_nodes)}. Expected None, a torch.Tensor, a (node_type, node_ids) tuple, or a single-entry mapping."
    )
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _run_distributed_neighbor_loader_labeled_homogeneous(
assert isinstance(dataset.node_ids, abc.Mapping)
loader = DistNeighborLoader(
dataset=dataset,
input_nodes=to_homogeneous(dataset.node_ids),
input_nodes=dataset.node_ids,
num_neighbors=[2, 2],
context=context,
local_process_rank=0,
Expand Down Expand Up @@ -253,7 +253,7 @@ def _run_cora_supervised(
loader = DistABLPLoader(
dataset=dataset,
num_neighbors=[2, 2],
input_nodes=to_homogeneous(dataset.train_node_ids),
input_nodes=dataset.train_node_ids,
pin_memory_device=torch.device("cpu"),
)
count = 0
Expand Down
Loading