Provide CUDA support for the binarize() operation on sparse matrices #601
base: master
Conversation
```julia
# here just to allow CUDA extension to overload this function with correct type casting
binarize(x, T::DataType) = binarize(x)
```
Suggested change:

```julia
binarize(x, T::DataType) = map(y -> y > 0 ? one(T) : zero(T), x)
```
or something similar
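For context, here is a minimal sketch of how such a generic fallback would behave on a CPU sparse matrix; the `binarize_generic` name and the example data are illustrative assumptions, not code from this PR:

```julia
using SparseArrays

# Illustrative stand-in for the suggested generic method: map each entry to
# one(T) if it is positive and zero(T) otherwise. Since zero maps to zero,
# SparseArrays' specialized `map` keeps the sparsity pattern.
binarize_generic(x, T::DataType) = map(y -> y > 0 ? one(T) : zero(T), x)

A = sparse([1, 2, 3], [2, 3, 1], [0.5, 2.0, 3.0], 3, 3)
B = binarize_generic(A, Bool)     # SparseMatrixCSC{Bool} with `true` at stored entries
F = binarize_generic(A, Float32)  # same pattern, stored values are 1.0f0
```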
The segmentation fault in the GNNGraphs CUDA CI should be investigated.
Co-authored-by: Carlo Lucibello <[email protected]>
```julia
function GNNGraphs.binarize(Mat::CUSPARSE.CuSparseMatrixCSC)
    bin_vals = fill!(similar(nonzeros(Mat), Bool), true)
    return CUSPARSE.CuSparseMatrixCSC(Mat.colPtr, rowvals(Mat), bin_vals, size(Mat))
end
```
This could be removed in favor of a single method with signature `binarize(Mat::CUSPARSE.CuSparseMatrixCSC, T::DataType = Bool)`. Same for the main method in GNNGraphs, I guess.
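As a rough sketch, the single default-argument method could look like the code below inside the CUDA package extension. This is an assumption based on the suggestion above, not the code this PR ships:

```julia
using SparseArrays: nonzeros, rowvals
using CUDA
using CUDA: CUSPARSE
using GNNGraphs

# One method covering both the Bool default and an explicit eltype:
# keep the sparsity pattern (colPtr / rowVal) and replace every stored
# value with one(T).
function GNNGraphs.binarize(Mat::CUSPARSE.CuSparseMatrixCSC, T::DataType = Bool)
    bin_vals = fill!(similar(nonzeros(Mat), T), one(T))
    return CUSPARSE.CuSparseMatrixCSC(Mat.colPtr, rowvals(Mat), bin_vals, size(Mat))
end
```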
Made binarize() work on CUDA with sparse adjacency matrices of type CUSPARSE.CuSparseMatrixCSC.
Also added a specialization of binarize() that takes an eltype as an additional argument and directly creates a binarized adjacency matrix of that type, avoiding potentially costly subsequent conversions.
However, note that for now this implementation won't be called during the forward pass, as the CUDA propagate specialization falls back on the default gather/scatter approach, which doesn't use the adjacency matrix.
All tests on CPU and GPU pass, except for one broken test in "ChebConv GPU", which is also broken on master.
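A hedged usage sketch of the new behavior; the data construction is purely illustrative, and it assumes GNNGraphs plus the CUDA extension from this PR are loaded:

```julia
using SparseArrays, CUDA
using CUDA: CUSPARSE
using GNNGraphs

A    = sprand(Float32, 10, 10, 0.3)      # random sparse adjacency matrix on the CPU
Agpu = CUSPARSE.CuSparseMatrixCSC(A)     # move it to the GPU in CSC format

B  = GNNGraphs.binarize(Agpu)            # Bool-valued matrix, same sparsity pattern
Bf = GNNGraphs.binarize(Agpu, Float32)   # binarized directly to Float32, skipping a
                                         # later eltype conversion
```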