From 8badd120d897b91a2e659a7680304f12c3ea932c Mon Sep 17 00:00:00 2001 From: nithinmanoj10 Date: Mon, 29 Jan 2024 22:59:14 +0530 Subject: [PATCH 1/5] =?UTF-8?q?=F0=9F=8E=89=20Started=20working=20on=20CSR?= =?UTF-8?q?Graph=20abstraction?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changed up the StaticGraph base class. This feature is still in progress --- benchmarking/gat/seastar/train.py | 150 ++++++++------ benchmarking/gcn/seastar/train.py | 90 +++++---- .../static-temporal-tgcn/seastar/train.py | 188 +++++++++++++----- stgraph/graph/static/StaticGraph.py | 94 +++------ stgraph/graph/static/StaticGraph_old.py | 82 ++++++++ stgraph/graph/static/csr/__init__.py | 0 stgraph/graph/static/{ => csr}/csr.cu | 0 stgraph/graph/static/{ => csr}/csr.so | Bin 8 files changed, 379 insertions(+), 225 deletions(-) create mode 100644 stgraph/graph/static/StaticGraph_old.py create mode 100644 stgraph/graph/static/csr/__init__.py rename stgraph/graph/static/{ => csr}/csr.cu (100%) rename stgraph/graph/static/{ => csr}/csr.so (100%) diff --git a/benchmarking/gat/seastar/train.py b/benchmarking/gat/seastar/train.py index aaa52fab..c003f788 100644 --- a/benchmarking/gat/seastar/train.py +++ b/benchmarking/gat/seastar/train.py @@ -13,7 +13,7 @@ import torch import torch.nn.functional as F import pynvml -from stgraph.graph.static.StaticGraph import StaticGraph +from stgraph.graph.static.StaticGraph_old import StaticGraph from stgraph.dataset.CoraDataLoader import CoraDataLoader from utils import EarlyStopping, accuracy import snoop @@ -25,8 +25,8 @@ def train(args): cora = CoraDataLoader(verbose=True) # To account for the initial CUDA Context object for pynvml - tmp = StaticGraph([(0,0)], [1], 1) - + tmp = StaticGraph([(0, 0)], [1], 1) + features = torch.FloatTensor(cora.get_all_features()) labels = torch.LongTensor(cora.get_all_targets()) train_mask = cora.get_train_mask() @@ -49,15 +49,15 @@ def train(args): assert train_mask.shape[0] == num_nodes - print('dataset {}'.format("Cora")) - print('# of edges : {}'.format(num_edges)) - print('# of nodes : {}'.format(num_nodes)) - print('# of features : {}'.format(num_feats)) + print("dataset {}".format("Cora")) + print("# of edges : {}".format(num_edges)) + print("# of nodes : {}".format(num_nodes)) + print("# of features : {}".format(num_feats)) features = torch.FloatTensor(features) labels = torch.LongTensor(labels) - if hasattr(torch, 'BoolTensor'): + if hasattr(torch, "BoolTensor"): train_mask = torch.BoolTensor(train_mask) else: @@ -74,17 +74,19 @@ def train(args): # create model heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads] - model = GAT(g, - args.num_layers, - num_feats, - args.num_hidden, - n_classes, - heads, - F.elu, - args.in_drop, - args.attn_drop, - args.negative_slope, - args.residual) + model = GAT( + g, + args.num_layers, + num_feats, + args.num_hidden, + n_classes, + heads, + F.elu, + args.in_drop, + args.attn_drop, + args.negative_slope, + args.residual, + ) print(model) if args.early_stop: stopper = EarlyStopping(patience=100) @@ -94,7 +96,8 @@ def train(args): # use optimizer optimizer = torch.optim.Adam( - model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + model.parameters(), lr=args.lr, weight_decay=args.weight_decay + ) # initialize graph dur = [] @@ -103,8 +106,8 @@ def train(args): Used_memory = 0 for epoch in range(args.num_epochs): - #print('epoch = ', epoch) - #print('mem0 = {}'.format(mem0)) + # print('epoch = ', epoch) + # print('mem0 = {}'.format(mem0)) torch.cuda.synchronize() tf = time.time() model.train() @@ -120,7 +123,7 @@ def train(args): torch.cuda.synchronize() loss.backward() optimizer.step() - t2 =time.time() + t2 = time.time() run_time_this_epoch = t2 - tf if epoch >= 3: @@ -131,56 +134,77 @@ def train(args): train_acc = accuracy(logits[train_mask], labels[train_mask]) - #log for each step - print('Epoch {:05d} | Time(s) {:.4f} | train_acc {:.6f} | Used_Memory {:.6f} mb'.format( - epoch, run_time_this_epoch, train_acc, (now_mem * 1.0 / (1024**2)) - )) + # log for each step + print( + "Epoch {:05d} | Time(s) {:.4f} | train_acc {:.6f} | Used_Memory {:.6f} mb".format( + epoch, run_time_this_epoch, train_acc, (now_mem * 1.0 / (1024**2)) + ) + ) if args.early_stop: - model.load_state_dict(torch.load('es_checkpoint.pt')) + model.load_state_dict(torch.load("es_checkpoint.pt")) - #OUTPUT we need - avg_run_time = avg_run_time *1. / record_time - Used_memory /= (1024**3) - print('^^^{:6f}^^^{:6f}'.format(Used_memory, avg_run_time)) + # OUTPUT we need + avg_run_time = avg_run_time * 1.0 / record_time + Used_memory /= 1024**3 + print("^^^{:6f}^^^{:6f}".format(Used_memory, avg_run_time)) -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='GAT') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="GAT") # COMMENT IF SNOOP IS TO BE ENABLED snoop.install(enabled=False) - parser.add_argument("--gpu", type=int, default=0, - help="which GPU to use. Set -1 to use CPU.") - parser.add_argument("--num_epochs", type=int, default=200, - help="number of training epochs") - parser.add_argument("--num_heads", type=int, default=8, - help="number of hidden attention heads") - parser.add_argument("--num_out_heads", type=int, default=1, - help="number of output attention heads") - parser.add_argument("--num_layers", type=int, default=1, - help="number of hidden layers") - parser.add_argument("--num_hidden", type=int, default=32, - help="number of hidden units") - parser.add_argument("--residual", action="store_true", default=False, - help="use residual connection") - parser.add_argument("--in_drop", type=float, default=.6, - help="input feature dropout") - parser.add_argument("--attn_drop", type=float, default=.6, - help="attention dropout") - parser.add_argument("--lr", type=float, default=0.005, - help="learning rate") - parser.add_argument('--weight_decay', type=float, default=5e-4, - help="weight decay") - parser.add_argument('--negative_slope', type=float, default=0.2, - help="the negative slope of leaky relu") - parser.add_argument('--early_stop', action='store_true', default=False, - help="indicates whether to use early stop or not") - parser.add_argument('--fastmode', action="store_true", default=False, - help="skip re-evaluate the validation set") + parser.add_argument( + "--gpu", type=int, default=0, help="which GPU to use. Set -1 to use CPU." + ) + parser.add_argument( + "--num_epochs", type=int, default=200, help="number of training epochs" + ) + parser.add_argument( + "--num_heads", type=int, default=8, help="number of hidden attention heads" + ) + parser.add_argument( + "--num_out_heads", type=int, default=1, help="number of output attention heads" + ) + parser.add_argument( + "--num_layers", type=int, default=1, help="number of hidden layers" + ) + parser.add_argument( + "--num_hidden", type=int, default=32, help="number of hidden units" + ) + parser.add_argument( + "--residual", action="store_true", default=False, help="use residual connection" + ) + parser.add_argument( + "--in_drop", type=float, default=0.6, help="input feature dropout" + ) + parser.add_argument( + "--attn_drop", type=float, default=0.6, help="attention dropout" + ) + parser.add_argument("--lr", type=float, default=0.005, help="learning rate") + parser.add_argument("--weight_decay", type=float, default=5e-4, help="weight decay") + parser.add_argument( + "--negative_slope", + type=float, + default=0.2, + help="the negative slope of leaky relu", + ) + parser.add_argument( + "--early_stop", + action="store_true", + default=False, + help="indicates whether to use early stop or not", + ) + parser.add_argument( + "--fastmode", + action="store_true", + default=False, + help="skip re-evaluate the validation set", + ) args = parser.parse_args() print(args) - + train(args) diff --git a/benchmarking/gcn/seastar/train.py b/benchmarking/gcn/seastar/train.py index b426d747..4deb1ce9 100644 --- a/benchmarking/gcn/seastar/train.py +++ b/benchmarking/gcn/seastar/train.py @@ -5,21 +5,21 @@ import pynvml import torch.nn as nn import torch.nn.functional as F -from stgraph.graph.static.StaticGraph import StaticGraph +from stgraph.graph.static.StaticGraph_old import StaticGraph from stgraph.dataset.CoraDataLoader import CoraDataLoader from utils import to_default_device, accuracy from model import GCN -def main(args): +def main(args): cora = CoraDataLoader(verbose=True) # To account for the initial CUDA Context object for pynvml - tmp = StaticGraph([(0,0)], [1], 1) - + tmp = StaticGraph([(0, 0)], [1], 1) + features = torch.FloatTensor(cora.get_all_features()) labels = torch.LongTensor(cora.get_all_targets()) - + train_mask = cora.get_train_mask() test_mask = cora.get_test_mask() @@ -47,7 +47,9 @@ def main(args): # A simple sanity check print("Measuerd Graph Size (pynvml): ", graph_mem, " B", flush=True) - print("Measuerd Graph Size (pynvml): ", (graph_mem)/(1024**2), " MB", flush=True) + print( + "Measuerd Graph Size (pynvml): ", (graph_mem) / (1024**2), " MB", flush=True + ) # normalization degs = torch.from_numpy(g.weighted_in_degrees()).type(torch.int32) @@ -58,23 +60,18 @@ def main(args): num_feats = features.shape[1] n_classes = int(max(labels) - min(labels) + 1) - print("Num Classes: ",n_classes) - - model = GCN(g, - num_feats, - args.num_hidden, - n_classes, - args.num_layers, - F.relu) - + print("Num Classes: ", n_classes) + + model = GCN(g, num_feats, args.num_hidden, n_classes, args.num_layers, F.relu) + if cuda: model.cuda() loss_fcn = torch.nn.CrossEntropyLoss() # use optimizer - optimizer = torch.optim.Adam(model.parameters(), - lr=args.lr, - weight_decay=args.weight_decay) + optimizer = torch.optim.Adam( + model.parameters(), lr=args.lr, weight_decay=args.weight_decay + ) # initialize graph dur = [] @@ -106,40 +103,45 @@ def main(args): dur.append(run_time_this_epoch) train_acc = accuracy(logits[train_mask], labels[train_mask]) - print('Epoch {:05d} | Time(s) {:.4f} | train_acc {:.6f} | Used_Memory {:.6f} mb '.format( - epoch, run_time_this_epoch, train_acc, (now_mem * 1.0 / (1024**2)) - )) + print( + "Epoch {:05d} | Time(s) {:.4f} | train_acc {:.6f} | Used_Memory {:.6f} mb ".format( + epoch, run_time_this_epoch, train_acc, (now_mem * 1.0 / (1024**2)) + ) + ) - Used_memory /= (1024**3) - print('^^^{:6f}^^^{:6f}'.format(Used_memory, np.mean(dur))) + Used_memory /= 1024**3 + print("^^^{:6f}^^^{:6f}".format(Used_memory, np.mean(dur))) -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='GCN') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="GCN") # COMMENT IF SNOOP IS TO BE ENABLED snoop.install(enabled=False) - parser.add_argument("--dropout", type=float, default=0.5, - help="dropout probability") - parser.add_argument("--dataset", type=str, - help="Datset to train your model") - parser.add_argument("--gpu", type=int, default=0, - help="gpu") - parser.add_argument("--lr", type=float, default=1e-2, - help="learning rate") - parser.add_argument("--num_epochs", type=int, default=200, - help="number of training epochs") - parser.add_argument("--num_hidden", type=int, default=16, - help="number of hidden gcn units") - parser.add_argument("--num_layers", type=int, default=1, - help="number of hidden gcn layers") - parser.add_argument("--weight-decay", type=float, default=5e-4, - help="Weight for L2 loss") - parser.add_argument("--self-loop", action='store_true', - help="graph self-loop (default=False)") + parser.add_argument( + "--dropout", type=float, default=0.5, help="dropout probability" + ) + parser.add_argument("--dataset", type=str, help="Datset to train your model") + parser.add_argument("--gpu", type=int, default=0, help="gpu") + parser.add_argument("--lr", type=float, default=1e-2, help="learning rate") + parser.add_argument( + "--num_epochs", type=int, default=200, help="number of training epochs" + ) + parser.add_argument( + "--num_hidden", type=int, default=16, help="number of hidden gcn units" + ) + parser.add_argument( + "--num_layers", type=int, default=1, help="number of hidden gcn layers" + ) + parser.add_argument( + "--weight-decay", type=float, default=5e-4, help="Weight for L2 loss" + ) + parser.add_argument( + "--self-loop", action="store_true", help="graph self-loop (default=False)" + ) parser.set_defaults(self_loop=False) args = parser.parse_args() print(args) - main(args) \ No newline at end of file + main(args) diff --git a/benchmarking/static-temporal-tgcn/seastar/train.py b/benchmarking/static-temporal-tgcn/seastar/train.py index 1328b742..6d71dd92 100644 --- a/benchmarking/static-temporal-tgcn/seastar/train.py +++ b/benchmarking/static-temporal-tgcn/seastar/train.py @@ -9,7 +9,7 @@ import os from model import STGraphTGCN -from stgraph.graph.static.StaticGraph import StaticGraph +from stgraph.graph.static.StaticGraph_old import StaticGraph from stgraph.dataset.WindmillOutputDataLoader import WindmillOutputDataLoader from stgraph.dataset.WikiMathDataLoader import WikiMathDataLoader @@ -23,8 +23,8 @@ from rich import inspect -def main(args): +def main(args): if torch.cuda.is_available(): print("🎉 CUDA is available") else: @@ -32,20 +32,63 @@ def main(args): quit() # Dummy object to account for CUDA context object - Graph = StaticGraph([(0,0)], [1], 1) - + Graph = StaticGraph([(0, 0)], [1], 1) + if args.dataset == "wiki": - dataloader = WikiMathDataLoader('static-temporal', 'wikivital_mathematics', args.feat_size, args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = WikiMathDataLoader( + "static-temporal", + "wikivital_mathematics", + args.feat_size, + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) elif args.dataset == "windmill": - dataloader = WindmillOutputDataLoader('static-temporal', 'windmill_output', args.feat_size, args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = WindmillOutputDataLoader( + "static-temporal", + "windmill_output", + args.feat_size, + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) elif args.dataset == "hungarycp": - dataloader = HungaryCPDataLoader('static-temporal', 'HungaryCP', args.feat_size, args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = HungaryCPDataLoader( + "static-temporal", + "HungaryCP", + args.feat_size, + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) elif args.dataset == "pedalme": - dataloader = PedalMeDataLoader('static-temporal', 'pedalme', args.feat_size, args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = PedalMeDataLoader( + "static-temporal", + "pedalme", + args.feat_size, + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) elif args.dataset == "metrla": - dataloader = METRLADataLoader('static-temporal', 'METRLA', args.feat_size, args.feat_size, args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = METRLADataLoader( + "static-temporal", + "METRLA", + args.feat_size, + args.feat_size, + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) elif args.dataset == "monte": - dataloader = MontevideoBusDataLoader('static-temporal', 'montevideobus', args.feat_size, args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = MontevideoBusDataLoader( + "static-temporal", + "montevideobus", + args.feat_size, + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) else: print("😔 Unrecognized dataset") quit() @@ -53,19 +96,23 @@ def main(args): edge_list = dataloader.get_edges() edge_weight_list = dataloader.get_edge_weights() targets = dataloader.get_all_targets() - + pynvml.nvmlInit() handle = pynvml.nvmlDeviceGetHandleByIndex(0) initial_used_gpu_mem = pynvml.nvmlDeviceGetMemoryInfo(handle).used G = StaticGraph(edge_list, edge_weight_list, dataloader.num_nodes) graph_mem = pynvml.nvmlDeviceGetMemoryInfo(handle).used - initial_used_gpu_mem - edge_weight = to_default_device(torch.unsqueeze(torch.FloatTensor(edge_weight_list), 1)) + edge_weight = to_default_device( + torch.unsqueeze(torch.FloatTensor(edge_weight_list), 1) + ) targets = to_default_device(torch.FloatTensor(np.array(targets))) num_hidden_units = args.num_hidden num_outputs = 1 - model = to_default_device(STGraphTGCN(args.feat_size, num_hidden_units, num_outputs)) + model = to_default_device( + STGraphTGCN(args.feat_size, num_hidden_units, num_outputs) + ) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) # Logging Output @@ -78,23 +125,26 @@ def main(args): backprop_every = args.backprop_every if backprop_every == 0: backprop_every = total_timestamps - + if total_timestamps % backprop_every == 0: - num_iter = int(total_timestamps/backprop_every) + num_iter = int(total_timestamps / backprop_every) else: - num_iter = int(total_timestamps/backprop_every) + 1 + num_iter = int(total_timestamps / backprop_every) + 1 # metrics dur = [] max_gpu = [] - table = BenchmarkTable(f"(STGraph Static-Temporal) TGCN on {dataloader.name} dataset", ["Epoch", "Time(s)", "MSE", "Used GPU Memory (Max MB)"]) - + table = BenchmarkTable( + f"(STGraph Static-Temporal) TGCN on {dataloader.name} dataset", + ["Epoch", "Time(s)", "MSE", "Used GPU Memory (Max MB)"], + ) + # normalization degs = torch.from_numpy(G.in_degrees()).type(torch.int32) norm = torch.pow(degs, -0.5) norm[torch.isinf(norm)] = 0 norm = to_default_device(norm) - G.set_ndata('norm', norm.unsqueeze(1)) + G.set_ndata("norm", norm.unsqueeze(1)) # train print("Training...\n") @@ -103,7 +153,7 @@ def main(args): torch.cuda.synchronize() torch.cuda.reset_peak_memory_stats(0) model.train() - + t0 = time.time() gpu_mem_arr = [] cost_arr = [] @@ -112,20 +162,24 @@ def main(args): optimizer.zero_grad() cost = 0 hidden_state = None - y_hat = torch.randn((dataloader.num_nodes, args.feat_size), device=get_default_device()) + y_hat = torch.randn( + (dataloader.num_nodes, args.feat_size), device=get_default_device() + ) for k in range(backprop_every): t = index * backprop_every + k if t >= total_timestamps: break - y_out, y_hat, hidden_state = model(G, y_hat, edge_weight, hidden_state) - cost = cost + torch.mean((y_out-targets[t])**2) - + y_out, y_hat, hidden_state = model( + G, y_hat, edge_weight, hidden_state + ) + cost = cost + torch.mean((y_out - targets[t]) ** 2) + if cost == 0: break - - cost = cost / (backprop_every+1) + + cost = cost / (backprop_every + 1) cost.backward() optimizer.step() torch.cuda.synchronize() @@ -140,56 +194,82 @@ def main(args): dur.append(run_time_this_epoch) max_gpu.append(max(gpu_mem_arr)) - table.add_row([epoch, "{:.5f}".format(run_time_this_epoch), "{:.4f}".format(sum(cost_arr)/len(cost_arr)), "{:.4f}".format((max(gpu_mem_arr) * 1.0 / (1024**2)))]) + table.add_row( + [ + epoch, + "{:.5f}".format(run_time_this_epoch), + "{:.4f}".format(sum(cost_arr) / len(cost_arr)), + "{:.4f}".format((max(gpu_mem_arr) * 1.0 / (1024**2))), + ] + ) table.display() - print('Average Time taken: {:6f}'.format(np.mean(dur))) + print("Average Time taken: {:6f}".format(np.mean(dur))) return np.mean(dur), (max(max_gpu) * 1.0 / (1024**2)) - + except RuntimeError as e: - if 'out of memory' in str(e): + if "out of memory" in str(e): table.add_row(["OOM", "OOM", "OOM", "OOM"]) table.display() else: print("😔 Something went wrong") return "OOM", "OOM" + def write_results(args, time_taken, max_gpu): cutoff = "whole" if args.cutoff_time < sys.maxsize: cutoff = str(args.cutoff_time) file_name = f"stgraph_{args.dataset}_T{cutoff}_B{args.backprop_every}_H{args.num_hidden}_F{args.feat_size}" - df_data = pd.DataFrame([{'Filename': file_name, 'Time Taken (s)': time_taken, 'Max GPU Usage (MB)': max_gpu}]) - - if os.path.exists('../../results/static-temporal.csv'): - df = pd.read_csv('../../results/static-temporal.csv') + df_data = pd.DataFrame( + [ + { + "Filename": file_name, + "Time Taken (s)": time_taken, + "Max GPU Usage (MB)": max_gpu, + } + ] + ) + + if os.path.exists("../../results/static-temporal.csv"): + df = pd.read_csv("../../results/static-temporal.csv") df = pd.concat([df, df_data]) else: df = df_data - - df.to_csv('../../results/static-temporal.csv', sep=',', index=False, encoding='utf-8') + + df.to_csv( + "../../results/static-temporal.csv", sep=",", index=False, encoding="utf-8" + ) -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='STGraph Static TGCN') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="STGraph Static TGCN") snoop.install(enabled=False) - parser.add_argument("--dataset", type=str, default="wiki", - help="Name of the Dataset (wiki, windmill, hungary_cp, pedalme, metrla, monte)") - parser.add_argument("--backprop-every", type=int, default=0, - help="Feature size of nodes") - parser.add_argument("--feat-size", type=int, default=8, - help="Feature size of nodes") - parser.add_argument("--num-hidden", type=int, default=100, - help="Number of hidden units") - parser.add_argument("--lr", type=float, default=1e-2, - help="learning rate") - parser.add_argument("--cutoff-time", type=int, default=sys.maxsize, - help="learning rate") - parser.add_argument("--num-epochs", type=int, default=1, - help="number of training epochs") + parser.add_argument( + "--dataset", + type=str, + default="wiki", + help="Name of the Dataset (wiki, windmill, hungary_cp, pedalme, metrla, monte)", + ) + parser.add_argument( + "--backprop-every", type=int, default=0, help="Feature size of nodes" + ) + parser.add_argument( + "--feat-size", type=int, default=8, help="Feature size of nodes" + ) + parser.add_argument( + "--num-hidden", type=int, default=100, help="Number of hidden units" + ) + parser.add_argument("--lr", type=float, default=1e-2, help="learning rate") + parser.add_argument( + "--cutoff-time", type=int, default=sys.maxsize, help="learning rate" + ) + parser.add_argument( + "--num-epochs", type=int, default=1, help="number of training epochs" + ) args = parser.parse_args() - + print(args) time_taken, max_gpu = main(args) - write_results(args, time_taken, max_gpu) \ No newline at end of file + write_results(args, time_taken, max_gpu) diff --git a/stgraph/graph/static/StaticGraph.py b/stgraph/graph/static/StaticGraph.py index 74ed7348..06a69af7 100644 --- a/stgraph/graph/static/StaticGraph.py +++ b/stgraph/graph/static/StaticGraph.py @@ -1,82 +1,48 @@ -from abc import ABC, abstractmethod -import copy +from __future__ import annotations -import numpy as np - -from rich.console import Console - -console = Console() +import time +from abc import abstractmethod +from typing import Any from stgraph.graph.STGraphBase import STGraphBase -from stgraph.graph.static.csr import CSR - class StaticGraph(STGraphBase): - def __init__(self, edge_list, edge_weights, num_nodes): + def __init__( + self: StaticGraph, edge_list: list, edge_weights: list, num_nodes: int + ) -> None: super().__init__() self._num_nodes = num_nodes self._num_edges = len(set(edge_list)) - - # console.log("Building forward edge list") - self._prepare_edge_lst_fwd(edge_list) - # console.log("Creating forward graph") - self._forward_graph = CSR(self.fwd_edge_list, edge_weights, self._num_nodes, is_edge_reverse=True) - - # console.log("Building backward edge list") - self._prepare_edge_lst_bwd(self.fwd_edge_list) - # console.log("Creating backward graph") - self._backward_graph = CSR(self.bwd_edge_list, edge_weights, self._num_nodes) - - # console.log("Getting CSR ptrs") - self._get_graph_csr_ptrs() - - def _prepare_edge_lst_fwd(self, edge_list): - edge_list_for_t = edge_list - edge_list_for_t.sort(key = lambda x: (x[1],x[0])) - edge_list_for_t = [(edge_list_for_t[j][0],edge_list_for_t[j][1],j) for j in range(len(edge_list_for_t))] - self.fwd_edge_list = edge_list_for_t - - def _prepare_edge_lst_bwd(self, edge_list): - edge_list_for_t = copy.deepcopy(edge_list) - edge_list_for_t.sort() - self.bwd_edge_list = edge_list_for_t - - def _get_graph_csr_ptrs(self): - self.fwd_row_offset_ptr = self._forward_graph.row_offset_ptr - self.fwd_column_indices_ptr = self._forward_graph.column_indices_ptr - self.fwd_eids_ptr = self._forward_graph.eids_ptr - self.fwd_node_ids_ptr = self._forward_graph.node_ids_ptr + self._ndata = {} - self.bwd_row_offset_ptr = self._backward_graph.row_offset_ptr - self.bwd_column_indices_ptr = self._backward_graph.column_indices_ptr - self.bwd_eids_ptr = self._backward_graph.eids_ptr - self.bwd_node_ids_ptr = self._backward_graph.node_ids_ptr - - def get_num_nodes(self): + def get_num_nodes(self: StaticGraph) -> int: return self._num_nodes - - def get_num_edges(self): + + def get_num_edges(self: StaticGraph) -> int: return self._num_edges - - def get_ndata(self, field): + + def get_ndata(self: StaticGraph, field: str) -> Any: if field in self._ndata: return self._ndata[field] else: return None - def set_ndata(self, field, val): + def set_ndata(self: StaticGraph, field: str, val: Any) -> None: self._ndata[field] = val - - def graph_type(self): - # return "csr" - return "csr_unsorted" - - def in_degrees(self): - return np.array(self._forward_graph.out_degrees, dtype='int32') - - def out_degrees(self): - return np.array(self._forward_graph.in_degrees, dtype='int32') - - def weighted_in_degrees(self): - return np.array(self._forward_graph.weighted_out_degrees, dtype='int32') \ No newline at end of file + + @abstractmethod + def graph_type(self: StaticGraph) -> str: + pass + + @abstractmethod + def in_degrees(self: StaticGraph) -> Any: + pass + + @abstractmethod + def out_degrees(self: StaticGraph) -> Any: + pass + + @abstractmethod + def weighted_in_degrees(self: StaticGraph) -> Any: + pass diff --git a/stgraph/graph/static/StaticGraph_old.py b/stgraph/graph/static/StaticGraph_old.py new file mode 100644 index 00000000..74ed7348 --- /dev/null +++ b/stgraph/graph/static/StaticGraph_old.py @@ -0,0 +1,82 @@ +from abc import ABC, abstractmethod +import copy + +import numpy as np + +from rich.console import Console + +console = Console() + +from stgraph.graph.STGraphBase import STGraphBase + + +from stgraph.graph.static.csr import CSR + +class StaticGraph(STGraphBase): + def __init__(self, edge_list, edge_weights, num_nodes): + super().__init__() + self._num_nodes = num_nodes + self._num_edges = len(set(edge_list)) + + # console.log("Building forward edge list") + self._prepare_edge_lst_fwd(edge_list) + # console.log("Creating forward graph") + self._forward_graph = CSR(self.fwd_edge_list, edge_weights, self._num_nodes, is_edge_reverse=True) + + # console.log("Building backward edge list") + self._prepare_edge_lst_bwd(self.fwd_edge_list) + # console.log("Creating backward graph") + self._backward_graph = CSR(self.bwd_edge_list, edge_weights, self._num_nodes) + + # console.log("Getting CSR ptrs") + self._get_graph_csr_ptrs() + + def _prepare_edge_lst_fwd(self, edge_list): + edge_list_for_t = edge_list + edge_list_for_t.sort(key = lambda x: (x[1],x[0])) + edge_list_for_t = [(edge_list_for_t[j][0],edge_list_for_t[j][1],j) for j in range(len(edge_list_for_t))] + self.fwd_edge_list = edge_list_for_t + + def _prepare_edge_lst_bwd(self, edge_list): + edge_list_for_t = copy.deepcopy(edge_list) + edge_list_for_t.sort() + self.bwd_edge_list = edge_list_for_t + + def _get_graph_csr_ptrs(self): + self.fwd_row_offset_ptr = self._forward_graph.row_offset_ptr + self.fwd_column_indices_ptr = self._forward_graph.column_indices_ptr + self.fwd_eids_ptr = self._forward_graph.eids_ptr + self.fwd_node_ids_ptr = self._forward_graph.node_ids_ptr + + self.bwd_row_offset_ptr = self._backward_graph.row_offset_ptr + self.bwd_column_indices_ptr = self._backward_graph.column_indices_ptr + self.bwd_eids_ptr = self._backward_graph.eids_ptr + self.bwd_node_ids_ptr = self._backward_graph.node_ids_ptr + + def get_num_nodes(self): + return self._num_nodes + + def get_num_edges(self): + return self._num_edges + + def get_ndata(self, field): + if field in self._ndata: + return self._ndata[field] + else: + return None + + def set_ndata(self, field, val): + self._ndata[field] = val + + def graph_type(self): + # return "csr" + return "csr_unsorted" + + def in_degrees(self): + return np.array(self._forward_graph.out_degrees, dtype='int32') + + def out_degrees(self): + return np.array(self._forward_graph.in_degrees, dtype='int32') + + def weighted_in_degrees(self): + return np.array(self._forward_graph.weighted_out_degrees, dtype='int32') \ No newline at end of file diff --git a/stgraph/graph/static/csr/__init__.py b/stgraph/graph/static/csr/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/stgraph/graph/static/csr.cu b/stgraph/graph/static/csr/csr.cu similarity index 100% rename from stgraph/graph/static/csr.cu rename to stgraph/graph/static/csr/csr.cu diff --git a/stgraph/graph/static/csr.so b/stgraph/graph/static/csr/csr.so similarity index 100% rename from stgraph/graph/static/csr.so rename to stgraph/graph/static/csr/csr.so From a6fb1a1d698300243adb04c4492833641638dcbb Mon Sep 17 00:00:00 2001 From: nithinmanoj10 Date: Tue, 30 Jan 2024 21:49:41 +0530 Subject: [PATCH 2/5] =?UTF-8?q?=E2=9E=95=20Created=20CSRGraph.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added the inherited child class CSRGraph --- stgraph/graph/static/StaticGraph.py | 20 ++++---- stgraph/graph/static/csr/CSRGraph.py | 74 ++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 11 deletions(-) create mode 100644 stgraph/graph/static/csr/CSRGraph.py diff --git a/stgraph/graph/static/StaticGraph.py b/stgraph/graph/static/StaticGraph.py index 06a69af7..aa210a24 100644 --- a/stgraph/graph/static/StaticGraph.py +++ b/stgraph/graph/static/StaticGraph.py @@ -8,28 +8,26 @@ class StaticGraph(STGraphBase): - def __init__( - self: StaticGraph, edge_list: list, edge_weights: list, num_nodes: int - ) -> None: + def __init__(self: StaticGraph) -> None: super().__init__() - self._num_nodes = num_nodes - self._num_edges = len(set(edge_list)) - self._ndata = {} + self.num_nodes = 0 + self.num_edges = 0 + self.ndata = {} def get_num_nodes(self: StaticGraph) -> int: - return self._num_nodes + return self.num_nodes def get_num_edges(self: StaticGraph) -> int: - return self._num_edges + return self.num_edges def get_ndata(self: StaticGraph, field: str) -> Any: - if field in self._ndata: - return self._ndata[field] + if field in self.ndata: + return self.ndata[field] else: return None def set_ndata(self: StaticGraph, field: str, val: Any) -> None: - self._ndata[field] = val + self.ndata[field] = val @abstractmethod def graph_type(self: StaticGraph) -> str: diff --git a/stgraph/graph/static/csr/CSRGraph.py b/stgraph/graph/static/csr/CSRGraph.py new file mode 100644 index 00000000..e99bad70 --- /dev/null +++ b/stgraph/graph/static/csr/CSRGraph.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import copy +import numpy as np + +from stgraph.graph.static.StaticGraph import StaticGraph +from stgraph.graph.static.csr.csr import CSR + + +class CSRGraph(StaticGraph): + def __init__( + self: CSRGraph, edge_list: list, edge_weights: list, num_nodes: int + ) -> None: + super().__init__() + self.num_nodes = num_nodes + self.num_edges = len(set(edge_list)) + + self._fwd_edge_list = None + self._bwd_edge_list = None + self._fwd_row_offset_ptr = None + self._fwd_column_indices_ptr = None + self._fwd_eids_ptr = None + self._fwd_node_ids_ptr = None + self._bwd_row_offset_ptr = None + self._bwd_column_indices_ptr = None + self._bwd_eids_ptr = None + self._bwd_node_ids_ptr = None + + self._prepare_edge_lst_fwd(edge_list) + self._forward_graph = CSR( + self.fwd_edge_list, edge_weights, self._num_nodes, is_edge_reverse=True + ) + + self._prepare_edge_lst_bwd(self.fwd_edge_list) + self._backward_graph = CSR(self.bwd_edge_list, edge_weights, self._num_nodes) + + self._set_graph_csr_ptrs() + + def _prepare_edge_lst_fwd(self: CSRGraph, edge_list: list) -> None: + edge_list_for_t = edge_list + edge_list_for_t.sort(key=lambda x: (x[1], x[0])) + edge_list_for_t = [ + (edge_list_for_t[j][0], edge_list_for_t[j][1], j) + for j in range(len(edge_list_for_t)) + ] + self._fwd_edge_list = edge_list_for_t + + def _prepare_edge_lst_bwd(self: CSRGraph, edge_list: list) -> None: + edge_list_for_t = copy.deepcopy(edge_list) + edge_list_for_t.sort() + self._bwd_edge_list = edge_list_for_t + + def _set_graph_csr_ptrs(self: CSRGraph) -> None: + self._fwd_row_offset_ptr = self._forward_graph.row_offset_ptr + self._fwd_column_indices_ptr = self._forward_graph.column_indices_ptr + self._fwd_eids_ptr = self._forward_graph.eids_ptr + self._fwd_node_ids_ptr = self._forward_graph.node_ids_ptr + + self._bwd_row_offset_ptr = self._backward_graph.row_offset_ptr + self._bwd_column_indices_ptr = self._backward_graph.column_indices_ptr + self._bwd_eids_ptr = self._backward_graph.eids_ptr + self._bwd_node_ids_ptr = self._backward_graph.node_ids_ptr + + def graph_type(self: CSRGraph) -> str: + return "csr_unsorted" + + def in_degrees(self: CSRGraph) -> np.ndarray: + return np.array(self._forward_graph.out_degrees, dtype="int32") + + def out_degrees(self: CSRGraph) -> np.ndarray: + return np.array(self._forward_graph.in_degrees, dtype="int32") + + def weighted_in_degrees(self: CSRGraph) -> np.ndarray: + return np.array(self._forward_graph.weighted_out_degrees, dtype="int32") From 1407e6dd989ca6b99e6a5f3179f5a451d1fb8e4a Mon Sep 17 00:00:00 2001 From: nithinmanoj10 Date: Wed, 31 Jan 2024 22:46:39 +0530 Subject: [PATCH 3/5] =?UTF-8?q?=E2=9E=95=20Added=20all=20the=20graph=20typ?= =?UTF-8?q?es=20to=20the=20=5F=5Finit=5F=5F=20file?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added all the graph classes in the __init__ file by importing them there itself. And made some linting changes for the new version of ruff. --- stgraph/dataset/dynamic/england_covid_dataloader.py | 1 + stgraph/dataset/static/cora_dataloader.py | 1 + stgraph/dataset/stgraph_dataset.py | 3 +++ stgraph/dataset/temporal/hungarycp_dataloader.py | 1 + stgraph/dataset/temporal/metrla_dataloader.py | 1 + stgraph/dataset/temporal/montevideobus_dataloader.py | 1 + stgraph/dataset/temporal/pedalme_dataloader.py | 1 + stgraph/dataset/temporal/wikimath_dataloader.py | 1 + stgraph/dataset/temporal/windmilloutput_dataloader.py | 1 + stgraph/graph/__init__.py | 10 +++++++++- 10 files changed, 20 insertions(+), 1 deletion(-) diff --git a/stgraph/dataset/dynamic/england_covid_dataloader.py b/stgraph/dataset/dynamic/england_covid_dataloader.py index 13636fa3..dfe2b84d 100644 --- a/stgraph/dataset/dynamic/england_covid_dataloader.py +++ b/stgraph/dataset/dynamic/england_covid_dataloader.py @@ -53,6 +53,7 @@ class EnglandCovidDataLoader(STGraphDynamicDataset): The name of the dataset. gdata : dict Graph meta data. + """ def __init__( diff --git a/stgraph/dataset/static/cora_dataloader.py b/stgraph/dataset/static/cora_dataloader.py index 5e484f71..7f85b48c 100644 --- a/stgraph/dataset/static/cora_dataloader.py +++ b/stgraph/dataset/static/cora_dataloader.py @@ -61,6 +61,7 @@ class CoraDataLoader(STGraphStaticDataset): The name of the dataset. gdata : dict Graph meta data. + """ def __init__( diff --git a/stgraph/dataset/stgraph_dataset.py b/stgraph/dataset/stgraph_dataset.py index 8e958e46..bef40db3 100644 --- a/stgraph/dataset/stgraph_dataset.py +++ b/stgraph/dataset/stgraph_dataset.py @@ -67,6 +67,7 @@ def __init__(self: STGraphDataset) -> None: _load_dataset() Loads the dataset from cache + """ self.name = "" self.gdata = {} @@ -106,6 +107,7 @@ def _has_dataset_cache(self: STGraphDataset) -> bool: # The dataset is cached, continue cached operations else: # The dataset is not cached, continue load and save operations + """ user_home_dir = os.path.expanduser("~") stgraph_dir = user_home_dir + "/.stgraph" @@ -128,6 +130,7 @@ def _get_cache_file_path(self: STGraphDataset) -> str: ------- str The absolute path of the cached dataset file + """ user_home_dir = os.path.expanduser("~") stgraph_dir = user_home_dir + "/.stgraph" diff --git a/stgraph/dataset/temporal/hungarycp_dataloader.py b/stgraph/dataset/temporal/hungarycp_dataloader.py index e94e356f..8bbf5ada 100644 --- a/stgraph/dataset/temporal/hungarycp_dataloader.py +++ b/stgraph/dataset/temporal/hungarycp_dataloader.py @@ -58,6 +58,7 @@ class HungaryCPDataLoader(STGraphTemporalDataset): The name of the dataset. gdata : dict Graph meta data. + """ def __init__( diff --git a/stgraph/dataset/temporal/metrla_dataloader.py b/stgraph/dataset/temporal/metrla_dataloader.py index a454784d..fef0bc14 100644 --- a/stgraph/dataset/temporal/metrla_dataloader.py +++ b/stgraph/dataset/temporal/metrla_dataloader.py @@ -67,6 +67,7 @@ class METRLADataLoader(STGraphTemporalDataset): The name of the dataset. gdata : dict Graph meta data. + """ def __init__( diff --git a/stgraph/dataset/temporal/montevideobus_dataloader.py b/stgraph/dataset/temporal/montevideobus_dataloader.py index 5ad608d8..e51a0a8c 100644 --- a/stgraph/dataset/temporal/montevideobus_dataloader.py +++ b/stgraph/dataset/temporal/montevideobus_dataloader.py @@ -66,6 +66,7 @@ class MontevideoBusDataLoader(STGraphTemporalDataset): The name of the dataset. gdata : dict Graph meta data. + """ def __init__( diff --git a/stgraph/dataset/temporal/pedalme_dataloader.py b/stgraph/dataset/temporal/pedalme_dataloader.py index dbb3edf1..d3371f43 100644 --- a/stgraph/dataset/temporal/pedalme_dataloader.py +++ b/stgraph/dataset/temporal/pedalme_dataloader.py @@ -58,6 +58,7 @@ class PedalMeDataLoader(STGraphTemporalDataset): The name of the dataset. gdata : dict Graph meta data. + """ def __init__( diff --git a/stgraph/dataset/temporal/wikimath_dataloader.py b/stgraph/dataset/temporal/wikimath_dataloader.py index 35eabf0f..3776b454 100644 --- a/stgraph/dataset/temporal/wikimath_dataloader.py +++ b/stgraph/dataset/temporal/wikimath_dataloader.py @@ -65,6 +65,7 @@ class WikiMathDataLoader(STGraphTemporalDataset): The name of the dataset. gdata : dict Graph meta data. + """ def __init__( diff --git a/stgraph/dataset/temporal/windmilloutput_dataloader.py b/stgraph/dataset/temporal/windmilloutput_dataloader.py index a6382f4b..5221f306 100644 --- a/stgraph/dataset/temporal/windmilloutput_dataloader.py +++ b/stgraph/dataset/temporal/windmilloutput_dataloader.py @@ -82,6 +82,7 @@ class WindmillOutputDataLoader(STGraphTemporalDataset): The name of the dataset. gdata : dict Graph meta data. + """ def __init__( diff --git a/stgraph/graph/__init__.py b/stgraph/graph/__init__.py index a5aa46c6..85d5ab3a 100644 --- a/stgraph/graph/__init__.py +++ b/stgraph/graph/__init__.py @@ -1 +1,9 @@ -'''Graph Abstraction provided by STGraph''' \ No newline at end of file +"""Graph Abstraction provided by STGraph""" + +from stgraph.graph.STGraphBase import STGraphBase +from stgraph.graph.static.StaticGraph import StaticGraph +from stgraph.graph.static.csr.CSRGraph import CSRGraph +from stgraph.graph.dynamic.DynamicGraph import DynamicGraph +from stgraph.graph.dynamic.gpma.GPMAGraph import GPMAGraph +from stgraph.graph.dynamic.pcsr.PCSRGraph import PCSRGraph +from stgraph.graph.dynamic.naive.NaiveGraph import NaiveGraph From d98ec4b64a9484a2a9a5c0a66be54174861b2048 Mon Sep 17 00:00:00 2001 From: nithinmanoj10 Date: Wed, 31 Jan 2024 23:04:12 +0530 Subject: [PATCH 4/5] =?UTF-8?q?=F0=9F=93=9D=20Moved=20CSR=20Readme?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit And modified the script to build the CSR .so file --- stgraph/graph/build_static.sh | 4 ++-- stgraph/graph/static/{ => csr}/README.md | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename stgraph/graph/static/{ => csr}/README.md (100%) diff --git a/stgraph/graph/build_static.sh b/stgraph/graph/build_static.sh index 4976a9f4..19559f14 100644 --- a/stgraph/graph/build_static.sh +++ b/stgraph/graph/build_static.sh @@ -1,6 +1,6 @@ echo "🔨 Building csr" -cd static/ +cd static/csr /usr/local/cuda-11.7/bin/nvcc $(python3 -m pybind11 --includes) -shared -rdc=true --compiler-options '-fPIC' -D__CDPRT_SUPPRESS_SYNC_DEPRECATION_WARNING -o csr.so csr.cu echo "✅ csr build completed" -cd .. +cd ../.. echo "" \ No newline at end of file diff --git a/stgraph/graph/static/README.md b/stgraph/graph/static/csr/README.md similarity index 100% rename from stgraph/graph/static/README.md rename to stgraph/graph/static/csr/README.md From 047d120bb8a82065b6521cad4c78207375be9de9 Mon Sep 17 00:00:00 2001 From: nithinmanoj10 Date: Thu, 1 Feb 2024 22:00:53 +0530 Subject: [PATCH 5/5] =?UTF-8?q?=F0=9F=9A=A7=20GCN=20example=20work=20in=20?= =?UTF-8?q?progress?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Working on the example python script for running a GCN using the new dataloader abstraction and CSRGraph that was added --- examples/gnn/gcn/model.py | 27 +++++++ examples/gnn/gcn/train.py | 35 ++++++++ examples/gnn/gcn/train_old.py | 146 ++++++++++++++++++++++++++++++++++ examples/gnn/gcn/utils.py | 32 ++++++++ 4 files changed, 240 insertions(+) create mode 100644 examples/gnn/gcn/model.py create mode 100644 examples/gnn/gcn/train.py create mode 100644 examples/gnn/gcn/train_old.py create mode 100644 examples/gnn/gcn/utils.py diff --git a/examples/gnn/gcn/model.py b/examples/gnn/gcn/model.py new file mode 100644 index 00000000..f9554d9d --- /dev/null +++ b/examples/gnn/gcn/model.py @@ -0,0 +1,27 @@ +import torch.nn as nn +from stgraph.nn.pytorch.graph_conv import GraphConv + +class GCN(nn.Module): + def __init__(self, + g, + in_feats, + n_hidden, + n_classes, + n_layers, + activation): + super(GCN, self).__init__() + self.g = g + self.layers = nn.ModuleList() + # input layer + self.layers.append(GraphConv(in_feats, n_hidden, activation)) + # hidden layers + for i in range(n_layers - 1): + self.layers.append(GraphConv(n_hidden, n_hidden, activation)) + # output layer + self.layers.append(GraphConv(n_hidden, n_classes, None)) + + def forward(self, g, features): + h = features + for layer in self.layers: + h = layer(g, h) + return h \ No newline at end of file diff --git a/examples/gnn/gcn/train.py b/examples/gnn/gcn/train.py new file mode 100644 index 00000000..0784ada3 --- /dev/null +++ b/examples/gnn/gcn/train.py @@ -0,0 +1,35 @@ +import argparse + +from stgraph.dataset import CoraDataLoader +from torch import FloatTensor, LongTensor, BoolTensor +from torch.cuda import set_device + +from utils import generate_test_mask, generate_train_mask + + +def main(args): + cora = CoraDataLoader(verbose=True) + + features = FloatTensor(cora.get_all_features()) + labels = LongTensor(cora.get_all_targets()) + + train_mask = BoolTensor(generate_train_mask(len(features), 0.6)) + test_mask = BoolTensor(generate_test_mask(len(features), 0.6)) + + if args.gpu < 0: + cuda = False + else: + cuda = True + set_device(args.gpu) + features = features.cuda() + labels = labels.cuda() + train_mask = train_mask.cuda() + test_mask = test_mask.cuda() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="GCN") + + parser.add_argument("--gpu", type=int, default=0, help="Current GPU device number") + +main("hi") diff --git a/examples/gnn/gcn/train_old.py b/examples/gnn/gcn/train_old.py new file mode 100644 index 00000000..70ef96fa --- /dev/null +++ b/examples/gnn/gcn/train_old.py @@ -0,0 +1,146 @@ +import argparse, time +import numpy as np +import torch +import snoop +import pynvml +import torch.nn.functional as F +from stgraph.graph import CSRGraph +from stgraph.dataset import CoraDataLoader +from utils import to_default_device, accuracy +from model import GCN + + +def main(args): + cora = CoraDataLoader(verbose=True) + + # To account for the initial CUDA Context object for pynvml + tmp = CSRGraph([(0, 0)], [1], 1) + + features = torch.FloatTensor(cora.get_all_features()) + labels = torch.LongTensor(cora.get_all_targets()) + + train_mask = cora.get_train_mask() + test_mask = cora.get_test_mask() + + train_mask = torch.BoolTensor(train_mask) + test_mask = torch.BoolTensor(test_mask) + + if args.gpu < 0: + cuda = False + else: + cuda = True + torch.cuda.set_device(args.gpu) + features = features.cuda() + labels = labels.cuda() + train_mask = train_mask.cuda() + test_mask = test_mask.cuda() + + print("Features Shape: ", features.shape, flush=True) + edge_weight = [1 for _ in range(len(cora.get_edges()))] + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + initial_used_gpu_mem = pynvml.nvmlDeviceGetMemoryInfo(handle).used + g = StaticGraph(cora.get_edges(), edge_weight, features.shape[0]) + graph_mem = pynvml.nvmlDeviceGetMemoryInfo(handle).used - initial_used_gpu_mem + + # A simple sanity check + print("Measuerd Graph Size (pynvml): ", graph_mem, " B", flush=True) + print( + "Measuerd Graph Size (pynvml): ", (graph_mem) / (1024**2), " MB", flush=True + ) + + # normalization + degs = torch.from_numpy(g.weighted_in_degrees()).type(torch.int32) + norm = torch.pow(degs, -0.5) + norm[torch.isinf(norm)] = 0 + norm = to_default_device(norm) + g.set_ndata("norm", norm.unsqueeze(1)) + + num_feats = features.shape[1] + n_classes = int(max(labels) - min(labels) + 1) + print("Num Classes: ", n_classes) + + model = GCN(g, num_feats, args.num_hidden, n_classes, args.num_layers, F.relu) + + if cuda: + model.cuda() + loss_fcn = torch.nn.CrossEntropyLoss() + + # use optimizer + optimizer = torch.optim.Adam( + model.parameters(), lr=args.lr, weight_decay=args.weight_decay + ) + + # initialize graph + dur = [] + Used_memory = 0 + + for epoch in range(args.num_epochs): + torch.cuda.reset_peak_memory_stats(0) + model.train() + if cuda: + torch.cuda.synchronize() + t0 = time.time() + + # forward + logits = model(g, features) + loss = loss_fcn(logits[train_mask], labels[train_mask]) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + now_mem = torch.cuda.max_memory_allocated(0) + graph_mem + Used_memory = max(now_mem, Used_memory) + + if cuda: + torch.cuda.synchronize() + + run_time_this_epoch = time.time() - t0 + + if epoch >= 3: + dur.append(run_time_this_epoch) + + train_acc = accuracy(logits[train_mask], labels[train_mask]) + print( + "Epoch {:05d} | Time(s) {:.4f} | train_acc {:.6f} | Used_Memory {:.6f} mb ".format( + epoch, run_time_this_epoch, train_acc, (now_mem * 1.0 / (1024**2)) + ) + ) + + Used_memory /= 1024**3 + print("^^^{:6f}^^^{:6f}".format(Used_memory, np.mean(dur))) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="GCN") + + # COMMENT IF SNOOP IS TO BE ENABLED + snoop.install(enabled=False) + + parser.add_argument( + "--dropout", type=float, default=0.5, help="dropout probability" + ) + parser.add_argument("--dataset", type=str, help="Datset to train your model") + parser.add_argument("--gpu", type=int, default=0, help="gpu") + parser.add_argument("--lr", type=float, default=1e-2, help="learning rate") + parser.add_argument( + "--num_epochs", type=int, default=200, help="number of training epochs" + ) + parser.add_argument( + "--num_hidden", type=int, default=16, help="number of hidden gcn units" + ) + parser.add_argument( + "--num_layers", type=int, default=1, help="number of hidden gcn layers" + ) + parser.add_argument( + "--weight-decay", type=float, default=5e-4, help="Weight for L2 loss" + ) + parser.add_argument( + "--self-loop", action="store_true", help="graph self-loop (default=False)" + ) + parser.set_defaults(self_loop=False) + args = parser.parse_args() + print(args) + + main(args) diff --git a/examples/gnn/gcn/utils.py b/examples/gnn/gcn/utils.py new file mode 100644 index 00000000..0f2d52c9 --- /dev/null +++ b/examples/gnn/gcn/utils.py @@ -0,0 +1,32 @@ +import torch + + +def accuracy(logits, labels): + _, indices = torch.max(logits, dim=1) + correct = torch.sum(indices == labels) + return correct.item() * 1.0 / len(labels) + + +# GPU | CPU +def get_default_device(): + if torch.cuda.is_available(): + return torch.device("cuda:0") + else: + return torch.device("cpu") + + +def to_default_device(data): + if isinstance(data, (list, tuple)): + return [to_default_device(x, get_default_device()) for x in data] + + return data.to(get_default_device(), non_blocking=True) + + +def generate_train_mask(size: int, train_test_split: int) -> list: + cutoff = size * train_test_split + return [1 if i < cutoff else 0 for i in range(size)] + + +def generate_test_mask(size: int, train_test_split: int) -> list: + cutoff = size * train_test_split + return [0 if i < cutoff else 1 for i in range(size)]