Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions impl/ascend/functions/unique.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,8 @@ diopiError_t diopiUnique(diopiContextHandle_t ctx, diopiTensorHandle_t* out, dio
makeTensor(ctx, outTmpAt, {inputAt.numel()}, inputAt.dtype());
}

// allocate temp inverse tensor
diopiTensorHandle_t inverseTmp = nullptr;
AscendTensor inverseTmpAt(inverseTmp);
bool returnInverse = (indices != nullptr) ? true : false;
std::vector<int64_t> zeroShape = {0};
if (returnInverse || returnCounts) {
makeTensor(ctx, inverseTmpAt, inputAt.shape(), diopi_dtype_int64);
} else {
makeTensor(ctx, inverseTmpAt, zeroShape, diopi_dtype_int64);
}

// allocate temp counts tensor
diopiTensorHandle_t countsTmp = nullptr;
Expand All @@ -48,8 +40,23 @@ diopiError_t diopiUnique(diopiContextHandle_t ctx, diopiTensorHandle_t* out, dio
}

// call aclnnUnique2
auto params = ::impl::ascend::aclnn_adaptor::convertParams(input, sorted, returnInverse, returnCounts, outTmpAt, inverseTmpAt, countsTmpAt).params();
DIOPI_ASECND_CALL_ACLNN_TYPE_SYNC(aclnnUnique2, ctx, params);
std::tuple<aclTensor*, bool, bool, bool, aclTensor*, aclTensor*, aclTensor*> params;
if (returnInverse) {
params = ::impl::ascend::aclnn_adaptor::convertParams(input, sorted, returnInverse, returnCounts, outTmpAt, indices, countsTmpAt).params();
DIOPI_ASECND_CALL_ACLNN_TYPE_SYNC(aclnnUnique2, ctx, params);
} else {
// allocate temp inverse tensor
diopiTensorHandle_t inverseTmp = nullptr;
AscendTensor inverseTmpAt(inverseTmp);
makeTensor(ctx, inverseTmpAt, zeroShape, diopi_dtype_int64);
if (returnCounts) {
makeTensor(ctx, inverseTmpAt, inputAt.shape(), diopi_dtype_int64);
} else {
makeTensor(ctx, inverseTmpAt, zeroShape, diopi_dtype_int64);
}
params = ::impl::ascend::aclnn_adaptor::convertParams(input, sorted, returnInverse, returnCounts, outTmpAt, inverseTmpAt, countsTmpAt).params();
DIOPI_ASECND_CALL_ACLNN_TYPE_SYNC(aclnnUnique2, ctx, params);
}

// get true outShape by aclGetViewShape
int64_t* viewDims = nullptr;
Expand All @@ -65,11 +72,6 @@ diopiError_t diopiUnique(diopiContextHandle_t ctx, diopiTensorHandle_t* out, dio
AscendTensor outReshapeAt = reshape(ctx, outTmpAt, {viewDims, viewDims + viewDimNum});
*out = const_cast<diopiTensorHandle_t>(outReshapeAt.tensorHandle());

// fill indices tensor
if (returnInverse) {
indices = const_cast<diopiTensorHandle_t>(inverseTmpAt.tensorHandle());
}

// fill counts tensor
if (returnCounts) {
// get counts tensor shape, counts tensor is the 7th tensor in aclnnUnique2, index = 6
Expand Down
Loading