diff --git a/python/gigl/distributed/dist_ablp_neighborloader.py b/python/gigl/distributed/dist_ablp_neighborloader.py index 952ace8f..717110d3 100644 --- a/python/gigl/distributed/dist_ablp_neighborloader.py +++ b/python/gigl/distributed/dist_ablp_neighborloader.py @@ -294,6 +294,8 @@ def __init__( supervision_node_type=supervision_node_type, ) + sampler_input.share_memory() + sampling_config = SamplingConfig( sampling_type=SamplingType.NODE, num_neighbors=num_neighbors, @@ -442,47 +444,9 @@ def _set_labels( Returns: Union[Data, HeteroData]: torch_geometric HeteroData/Data object with the filtered edge fields and labels set as properties of the instance """ - local_node_to_global_node: torch.Tensor - # shape [N], where N is the number of nodes in the subgraph, and local_node_to_global_node[i] gives the global node id for local node id `i` - if isinstance(data, HeteroData): - supervision_node_type = ( - self._supervision_edge_type[0] - if self.edge_dir == "in" - else self._supervision_edge_type[2] - ) - local_node_to_global_node = data[supervision_node_type].node - else: - local_node_to_global_node = data.node - - output_positive_labels: dict[int, torch.Tensor] = {} - output_negative_labels: dict[int, torch.Tensor] = {} - - for local_anchor_node_id in range(positive_labels.size(0)): - positive_mask = ( - local_node_to_global_node.unsqueeze(1) - == positive_labels[local_anchor_node_id] - ) # shape [N, P], where N is the number of nodes and P is the number of positive labels for the current anchor node - - # Gets the indexes of the items in local_node_to_global_node which match any of the positive labels for the current anchor node - output_positive_labels[local_anchor_node_id] = torch.nonzero(positive_mask)[ - :, 0 - ].to(self.to_device) - # Shape [X], where X is the number of indexes in the original local_node_to_global_node which match a node in the positive labels for the current anchor node - - if negative_labels is not None: - negative_mask = ( - local_node_to_global_node.unsqueeze(1) - == negative_labels[local_anchor_node_id] - ) # shape [N, M], where N is the number of nodes and M is the number of negative labels for the current anchor node - - # Gets the indexes of the items in local_node_to_global_node which match any of the negative labels for the current anchor node - output_negative_labels[local_anchor_node_id] = torch.nonzero( - negative_mask - )[:, 0].to(self.to_device) - # Shape [X], where X is the number of indexes in the original local_node_to_global_node which match a node in the negative labels for the current anchor node - data.y_positive = output_positive_labels + data.y_positive = positive_labels if negative_labels is not None: - data.y_negative = output_negative_labels + data.y_negative = negative_labels data = remove_labeled_edge_types(data) diff --git a/python/gigl/distributed/dist_neighbor_sampler.py b/python/gigl/distributed/dist_neighbor_sampler.py index 677a2058..a1b15e5b 100644 --- a/python/gigl/distributed/dist_neighbor_sampler.py +++ b/python/gigl/distributed/dist_neighbor_sampler.py @@ -62,19 +62,19 @@ async def _sample_from_nodes( combined_seeds = (input_seeds, positive_seeds, negative_seeds) else: combined_seeds = (input_seeds, positive_seeds) - input_nodes = {input_type: torch.cat(combined_seeds, dim=0)} + input_nodes = {input_type: torch.unique(torch.cat(combined_seeds, dim=0))} # Otherwise, they need to be passed as two separate node types to the inducer.init_node() function. else: if negative_seeds is None: input_nodes = { input_type: input_seeds, - supervision_node_type: positive_seeds, + supervision_node_type: torch.unique(positive_seeds), } else: input_nodes = { input_type: input_seeds, - supervision_node_type: torch.cat( - (positive_seeds, negative_seeds), dim=0 + supervision_node_type: torch.unique( + torch.cat((positive_seeds, negative_seeds), dim=0) ), } output: NeighborOutput diff --git a/python/gigl/distributed/sampler.py b/python/gigl/distributed/sampler.py index 09b03a03..a0c44e8a 100644 --- a/python/gigl/distributed/sampler.py +++ b/python/gigl/distributed/sampler.py @@ -51,3 +51,17 @@ def __getitem__(self, index: Union[torch.Tensor, Any]) -> "ABLPNodeSamplerInput" else None, supervision_node_type=self.supervision_node_type, ) + + def share_memory(self): + self.node.share_memory_() + self.positive_labels.share_memory_() + if self.negative_labels is not None: + self.negative_labels.share_memory_() + return self + + def to(self, device: torch.device): + self.node = self.node.to(device) + self.positive_labels = self.positive_labels.to(device) + if self.negative_labels is not None: + self.negative_labels = self.negative_labels.to(device) + return self