diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 47bd30d66ef9..3ec2281b1f62 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -8668,6 +8668,30 @@ def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
   }];
 }
 
+def Torch_AtenPixelUnshuffleOp : Torch_Op<"aten.pixel_unshuffle", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::pixel_unshuffle : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$downscale_factor
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenPixelUnshuffleOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenPixelUnshuffleOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenChannelShuffleOp : Torch_Op<"aten.channel_shuffle", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
index a000b7ab2f98..27feeea627e4 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -158,6 +158,18 @@ LogicalResult getPermutedType(BaseTensorType inType,
                               SmallVector<int64_t> permuteDims,
                               Type &permutedType);
 
+// Extracts a shape, as a vector of int64_t, from a vector of Value.
+SmallVector<int64_t> getIntShapeFromValues(ArrayRef<Value> vals);
+
+// Converts a vector of Value (shape dimensions) into a ValueTensorType.
+// Each `Value` that is a constant integer becomes a static dimension;
+// non-constant values are treated as unknown dimensions (using `kUnknownSize`).
+ValueTensorType getTypeFromShape(ArrayRef<Value> vals, Type inOptionalDType);
+
+// Gets the size of dimension `dimIndex` of the tensor `inValue`.
+Value getDimSize(PatternRewriter &rewriter, Location loc, Value inValue,
+                 uint64_t dimIndex);
+
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 0ccf0d2b68f8..1b712806025a 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7613,6 +7613,56 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %15 = torch.aten.append.t %6, %14 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
 "    return %6 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.pixel_unshuffle\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
+"    %int1 = torch.constant.int 1\n"
+"    %int-3 = torch.constant.int -3\n"
+"    %str = torch.constant.str \"AssertionError: width must be divisible by downscale_factor in pixel_unshuffle\"\n"
+"    %int-1 = torch.constant.int -1\n"
+"    %str_0 = torch.constant.str \"AssertionError: height must be divisible by downscale_factor in pixel_unshuffle\"\n"
+"    %int-2 = torch.constant.int -2\n"
+"    %none = torch.constant.none\n"
+"    %str_1 = torch.constant.str \"AssertionError: input must be at least rank-3 in pixel_unshuffle\"\n"
+"    %int3 = torch.constant.int 3\n"
+"    %int0 = torch.constant.int 0\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %1 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %2 = torch.aten.mul.int %arg1, %arg1 : !torch.int, !torch.int -> !torch.int\n"
+"    %3 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %4 = torch.aten.remainder.int %3, %arg1 : !torch.int, !torch.int -> !torch.int\n"
+"    %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %5 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %6 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %7 = torch.aten.remainder.int %6, %arg1 : !torch.int, !torch.int -> !torch.int\n"
+"    %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %8 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %9 = torch.aten.slice.t %arg0, %int0, %int-3, %int1 : !torch.list<int>, !torch.int, !torch.int, !torch.int -> !torch.list<int>\n"
+"    %10 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %11 = torch.aten.mul.int %10, %2 : !torch.int, !torch.int -> !torch.int\n"
+"    %12 = torch.aten.append.t %9, %11 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"    %13 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %14 = torch.aten.floordiv.int %13, %arg1 : !torch.int, !torch.int -> !torch.int\n"
+"    %15 = torch.aten.append.t %9, %14 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"    %16 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %17 = torch.aten.floordiv.int %16, %arg1 : !torch.int, !torch.int -> !torch.int\n"
+"    %18 = torch.aten.append.t %9, %17 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"    return %9 : !torch.list<int>\n"
+"  }\n"
@\"__torch_mlir_shape_fn.aten.channel_shuffle\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: input must be at least rank-3 in channel_shuffle\"\n" @@ -12380,6 +12430,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pixel_unshuffle\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.channel_shuffle\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 341eb6e95d08..532d474eee8c 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3537,30 +3537,6 @@ class DecomposeAten_LinalgDetOp : public OpRewritePattern { }; } // namespace -namespace { // Start of rearrangement ops utility functions -// Extracts shape as vector of int64_t from vector of Value -SmallVector getIntShapeFromValues(ArrayRef vals) { - SmallVector shape; - shape.reserve(vals.size()); - for (Value v : vals) { - int64_t cst_val; - if (matchPattern(v, m_TorchConstantInt(&cst_val))) { - shape.push_back(cst_val); - } else { - shape.push_back(kUnknownSize); - } - } - return shape; -} - -// Converts a vector of Value (shape dimensions) into a ValueTensorType -ValueTensorType getTypeFromShape(ArrayRef vals, Type inOptionalDType) { - SmallVector intShape = getIntShapeFromValues(vals); - return ValueTensorType::get(vals[0].getContext(), llvm::ArrayRef(intShape), - inOptionalDType); -} -} // namespace - // Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and // prims.collapse operations. // @@ -3610,18 +3586,9 @@ class DecomposeAtenPixelShuffleOp auto nLeadingDims = inRank - 3; - // Get the size of the dimension 'i'. Note the use of 'createOrFold' instead - // of 'create': if the dimension size is known, then the AtenSizeIntOp is - // folded to a ConstantOp. - auto getDimSize = [&](uint64_t i) -> Value { - Value dim = - rewriter.create(loc, rewriter.getI64IntegerAttr(i)); - return rewriter.createOrFold(loc, inValue, dim); - }; - - auto inC = getDimSize(inRank - 3); - auto inH = getDimSize(inRank - 2); - auto inW = getDimSize(inRank - 1); + auto inC = getDimSize(rewriter, loc, inValue, inRank - 3); + auto inH = getDimSize(rewriter, loc, inValue, inRank - 2); + auto inW = getDimSize(rewriter, loc, inValue, inRank - 1); auto factor = op.getUpscaleFactor(); @@ -3710,6 +3677,148 @@ class DecomposeAtenPixelShuffleOp }; } // namespace +// Decompose aten.pixel_unshuffle into: prims.split_dim, aten.permute, and +// prims.collapse operations. +// +// We want to do the exact opposite of aten.pixel_shuffle +// +// 'r' is referred to as the 'downscale factor' or just 'factor' below. +// +// If input is a tensor of shape +// (*leading_dims, C, H*r, W*r), +// +// where leading_dims is of size N, then +// X = pixel_unshuffle(input, downscale_factor) +// +// gets replaced with +// X = input.split_dim(...) # shape (*leading_dims, C, H, r, W*r) +// X = X.split_dim(...) 
+//   X = X.split_dim(...)      # shape (*leading_dims, C, H, r, W, r)
+//   X = X.permute(0, ..., N, N+2, N+4, N+1, N+3)
+//                             # shape (*leading_dims, C, r, r, H, W)
+//   X = X.collapse(...)       # shape (*leading_dims, C, r*r, H, W)
+//   X = X.collapse(...)       # shape (*leading_dims, C*r*r, H, W)
+//
+namespace {
+class DecomposeAtenPixelUnshuffleOp
+    : public OpRewritePattern<AtenPixelUnshuffleOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenPixelUnshuffleOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    Value inValue = op.getSelf();
+    auto inType = cast<BaseTensorType>(inValue.getType());
+    auto maybeSizes = inType.getOptionalSizes();
+    if (!maybeSizes) {
+      return rewriter.notifyMatchFailure(
+          op, "Expected input tensor to have known rank.");
+    }
+    auto inShape = maybeSizes.value();
+    auto inRank = inShape.size();
+
+    // The input tensor must have at least 3 dimensions: (1) the channel
+    // dimension, which grows by a factor of 'factor*factor', (2) the H
+    // dimension, which shrinks by 'factor', and (3) the W dimension, which
+    // also shrinks by 'factor'. The total number of dimensions is 3 + N,
+    // where N is the number of leading dimensions, and N >= 0, so the input
+    // must have rank at least 3.
+    if (inRank < 3)
+      return rewriter.notifyMatchFailure(
+          op, "Expected input tensor to have rank greater than 2.");
+
+    const auto inOptionalDType = inType.getOptionalDtype();
+
+    auto nLeadingDims = inRank - 3;
+
+    auto inC = getDimSize(rewriter, loc, inValue, inRank - 3);
+    auto inH = getDimSize(rewriter, loc, inValue, inRank - 2);
+    auto inW = getDimSize(rewriter, loc, inValue, inRank - 1);
+
+    auto factor = op.getDownscaleFactor();
+
+    Value factorSquared =
+        rewriter.createOrFold<AtenMulIntOp>(loc, factor, factor);
+
+    Value outC = rewriter.createOrFold<AtenMulIntOp>(loc, inC, factorSquared);
+
+    Value outH = rewriter.createOrFold<AtenFloordivIntOp>(loc, inH, factor);
+    Value outW = rewriter.createOrFold<AtenFloordivIntOp>(loc, inW, factor);
+
+    SmallVector<Value> dimensionConstants;
+    dimensionConstants.reserve(inRank + 2);
+    for (unsigned i = 0; i < inRank + 2; ++i) {
+      dimensionConstants.push_back(
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
+    }
+
+    SmallVector<Value> leadingDims;
+    leadingDims.reserve(nLeadingDims);
+    for (unsigned i = 0; i < nLeadingDims; ++i) {
+      Value leadingDimSize = rewriter.createOrFold<AtenSizeIntOp>(
+          loc, inValue, dimensionConstants[i]);
+      leadingDims.push_back(leadingDimSize);
+    }
+
+    SmallVector<Value> prePermuteShape = leadingDims;
+    prePermuteShape.append({inC, outH, factor, outW, factor});
+
+    SmallVector<Value> postPermuteShape = leadingDims;
+    postPermuteShape.append({inC, factor, factor, outH, outW});
+
+    SmallVector<Value> partiallyCollapsedShape = leadingDims;
+    partiallyCollapsedShape.append({inC, factorSquared, outH, outW});
+
+    SmallVector<Value> outShape = leadingDims;
+    outShape.append({outC, outH, outW});
+
+    SmallVector<Value> permutation{dimensionConstants.begin(),
+                                   dimensionConstants.begin() + nLeadingDims};
+    SmallVector<uint64_t> permutationTail{0, 2, 4, 1, 3};
+    for (uint64_t d : permutationTail) {
+      permutation.push_back(dimensionConstants[nLeadingDims + d]);
+    }
+
+    Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
+        permutation);
+
+    SmallVector<Value> heightSplitShape = leadingDims;
+    heightSplitShape.append({inC, outH, factor, inW});
+
+    // Split the height dimension inH -> (outH, factor).
+    auto partiallyExpanded =
+        rewriter
+            .create<PrimsSplitDimOp>(
+                loc, getTypeFromShape(heightSplitShape, inOptionalDType),
+                inValue, dimensionConstants[nLeadingDims + 1], outH)
+            .getResult();
+
+    // Split the width dimension inW -> (outW, factor).
+    auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
+        loc, getTypeFromShape(prePermuteShape, inOptionalDType),
+        partiallyExpanded, dimensionConstants[nLeadingDims + 3], outW);
+
+    // Perform the permutation.
+    auto permuted = rewriter.create<AtenPermuteOp>(
+        loc, getTypeFromShape(postPermuteShape, inOptionalDType), fullyExpanded,
+        permuteDimsOrder);
+
+    // Collapse the two 'factor' dimensions into one.
+    auto partiallyCollapsed = rewriter.create<PrimsCollapseOp>(
+        loc, getTypeFromShape(partiallyCollapsedShape, inOptionalDType),
+        permuted, dimensionConstants[nLeadingDims + 1],
+        dimensionConstants[nLeadingDims + 2]);
+
+    // Collapse the channel and 'factor*factor' dimensions, restoring the
+    // original rank.
+    rewriter.replaceOpWithNewOp<PrimsCollapseOp>(
+        op, op.getType(), partiallyCollapsed, dimensionConstants[nLeadingDims],
+        dimensionConstants[nLeadingDims + 1]);
+
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten.channel_shuffle into: prims.split_dim, aten.permute, and
 // prims.collapse operations.
 //
@@ -3764,23 +3873,14 @@ class DecomposeAtenChannelShuffleOp
 
     auto numOfSpatialDims = inRank - 2;
 
-    // Get the size of the dimension 'i'. Note the use of 'createOrFold'
-    // instead of 'create': if the dimension size is known, then the
-    // AtenSizeIntOp is folded to a ConstantOp.
-    auto getDimSize = [&rewriter, &inValue, loc](uint64_t i) -> Value {
-      Value dim =
-          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
-      return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
-    };
-
     // The channel dimension is always the second dimension. PyTorch errors out
     // if the batch dimension (first dimension) is not present. See comment at
     // the top of this class for details.
-    auto inC = getDimSize(1);
+    auto inC = getDimSize(rewriter, loc, inValue, 1);
     SmallVector<Value> inSpatialDims;
     inSpatialDims.reserve(numOfSpatialDims);
     for (unsigned i = 2; i < (2 + numOfSpatialDims); ++i) {
-      inSpatialDims.push_back(getDimSize(i));
+      inSpatialDims.push_back(getDimSize(rewriter, loc, inValue, i));
     }
 
     auto groups = op.getGroups();
@@ -12859,6 +12959,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenPixelUnshuffleOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenChannelShuffleOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 7deacbae65a1..95f497bf0abf 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -421,6 +421,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
     target.addIllegalOp();
     target.addIllegalOp();
     target.addIllegalOp<AtenPixelShuffleOp>();
+    target.addIllegalOp<AtenPixelUnshuffleOp>();
     target.addIllegalOp<AtenChannelShuffleOp>();
     target.addIllegalOp();
     target.addIllegalOp();
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 2d40a6cf2c19..aadd0b8dd137 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -327,7 +327,8 @@ bool Torch::isViewLikeOp(Operation *op) {
       AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp, AtenMovedimIntOp,
       PrimsViewOfOp, AtenRealOp, AtenImagOp, PrimsSplitDimOp,
       AtenViewAsComplexOp, AtenViewAsRealOp, AtenPixelShuffleOp,
-      AtenChannelShuffleOp, AtenDiagonalOp, AtenUnfoldOp>(op);
+      AtenPixelUnshuffleOp, AtenChannelShuffleOp, AtenDiagonalOp, AtenUnfoldOp>(
+      op);
 }
 
 Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
@@ -709,3 +710,41 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
     return rewriter.getI64Type();
   return inputType;
 }
+
+// Extracts a shape, as a vector of int64_t, from a vector of Value.
+// This function attempts to match each `Value` in the input list to a
+// Torch constant integer. If successful, the constant value is used
+// in the output shape vector. Otherwise, a sentinel value (`kUnknownSize`)
+// is inserted to indicate an unknown dimension.
+SmallVector<int64_t> Torch::getIntShapeFromValues(ArrayRef<Value> vals) {
+  SmallVector<int64_t> shape;
+  shape.reserve(vals.size());
+  for (Value v : vals) {
+    int64_t cst_val;
+    if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
+      shape.push_back(cst_val);
+    } else {
+      shape.push_back(kUnknownSize);
+    }
+  }
+  return shape;
+}
+
+// Converts a vector of Value (shape dimensions) into a ValueTensorType.
+ValueTensorType Torch::getTypeFromShape(ArrayRef<Value> vals,
+                                        Type inOptionalDType) {
+  SmallVector<int64_t> intShape = getIntShapeFromValues(vals);
+  return ValueTensorType::get(vals[0].getContext(), llvm::ArrayRef(intShape),
+                              inOptionalDType);
+}
+
+// Returns the size of dimension `dimIndex` of the tensor `inValue`.
+// Note the use of 'createOrFold' instead of 'create': if the dimension
+// size is statically known, then the AtenSizeIntOp is folded to a
+// ConstantOp.
+Value Torch::getDimSize(PatternRewriter &rewriter, Location loc, Value inValue,
+                        uint64_t dimIndex) {
+  Value dim = rewriter.create<Torch::ConstantIntOp>(
+      loc, rewriter.getI64IntegerAttr(dimIndex));
+  return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
+}
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index bee2c62d7202..b00618f66dea 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -3127,6 +3127,11 @@
     "PixelShuffleModuleSpatiallyDynamic_basic",
     "PixelShuffleModuleSpatiallyStatic_basic",
     "PixelShuffleModuleStaticRank3Int64_basic",
+    "PixelUnshuffleModuleStaticRank5Float32_basic",
+    "PixelUnshuffleModuleStaticRank3Int64_basic",
+    "PixelUnshuffleModuleFullDynamic_basic",
+    "PixelUnshuffleModuleSpatiallyDynamic_basic",
+    "PixelUnshuffleModuleSpatiallyStatic_basic",
     "ChannelShuffleBasic_basic",
     "ChannelShuffleUnitaryGroup_basic",
     "ChannelShuffle1D_basic",
@@ -4738,6 +4743,11 @@
     "PixelShuffleModuleSpatiallyStatic_basic",
     "PixelShuffleModuleStaticRank3Int64_basic",
     "PixelShuffleModuleStaticRank4Float32_basic",
+    "PixelUnshuffleModuleStaticRank5Float32_basic",
+    "PixelUnshuffleModuleStaticRank3Int64_basic",
+    "PixelUnshuffleModuleFullDynamic_basic",
+    "PixelUnshuffleModuleSpatiallyDynamic_basic",
+    "PixelUnshuffleModuleSpatiallyStatic_basic",
     "ChannelShuffleBasic_basic",
     "ChannelShuffleUnitaryGroup_basic",
     "ChannelShuffle1D_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index a1276dd5039a..88e05be7608b 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -839,6 +839,20 @@ def aten〇pixel_shuffle〡shape(self: List[int], upscale_factor: int) -> List[i
     out.append(self[-1] * upscale_factor)
     return out
 
+def aten〇pixel_unshuffle〡shape(self: List[int], downscale_factor: int) -> List[int]:
+
+    assert len(self) >= 3, "input must be at least rank-3 in pixel_unshuffle"
+    downscale_factor_squared = downscale_factor * downscale_factor
+    assert self[-2] % (downscale_factor) == 0, "height must be divisible by downscale_factor in pixel_unshuffle"
+    assert self[-1] % (downscale_factor) == 0, "width must be divisible by downscale_factor in pixel_unshuffle"
+
+    out = self[0:-3]
+    out.append(self[-3] * downscale_factor_squared)
+    out.append(self[-2] // downscale_factor)
+    out.append(self[-1] // downscale_factor)
+    return out
+
+
 def aten〇channel_shuffle〡shape(self: List[int], groups: int) -> List[int]:
     assert len(self) >= 3, "input must be at least rank-3 in channel_shuffle"
     return self
@@ -3069,6 +3083,11 @@ def aten〇pixel_shuffle〡dtype(self_rank_dtype: Tuple[int, int], upscale_facto
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 2, 2)], downscale_factor = 2))
+def aten〇pixel_unshuffle〡dtype(self_rank_dtype: Tuple[int, int], downscale_factor: int) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 4, 4, 5)], groups = 2))
 def aten〇channel_shuffle〡dtype(self_rank_dtype: Tuple[int, int], groups: int) -> int:
     self_rank, self_dtype = self_rank_dtype
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index c0878700e34d..0aa71a52ed4d 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -719,6 +719,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
     emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)", has_folder=True)
     emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)")
+    emit("aten::pixel_unshuffle : (Tensor, int) -> (Tensor)")
     emit("aten::channel_shuffle : (Tensor, int) -> (Tensor)")
     emit("aten::permute : (Tensor, int[]) -> (Tensor)", has_verifier=True)
     emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index f60342c26646..b3099b497df7 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1010,6 +1010,25 @@ def PixelShuffleModuleSpatiallyStatic_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class PixelUnshuffleModuleStaticRank4Float32(torch.nn.Module):
+    # Basic test case for the PixelUnshuffle operation.
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([1, 1, 12, 12], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.pixel_unshuffle(x, 3)
+
+
+@register_test_case(module_factory=lambda: PixelUnshuffleModuleStaticRank4Float32())
+def PixelUnshuffleModuleStaticRank4Float32_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 12, 12))
+
+
+# ==============================================================================
+
+
 class ChannelShuffleBasic(torch.nn.Module):
     # Basic test case for ChannelShuffle operation.
     def __init__(self):
@@ -1035,6 +1054,25 @@ def ChannelShuffleBasic_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class PixelUnshuffleModuleStaticRank5Float32(torch.nn.Module):
+    # Basic test case for the PixelUnshuffle operation with leading dimensions.
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([4, 1, 8, 4, 4], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.pixel_unshuffle(x, 2)
+
+
+@register_test_case(module_factory=lambda: PixelUnshuffleModuleStaticRank5Float32())
+def PixelUnshuffleModuleStaticRank5Float32_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 1, 8, 4, 4))
+
+
+# ==============================================================================
+
+
 class ChannelShuffleUnitaryGroup(torch.nn.Module):
     # Test case where group = 1.
     def __init__(self):
@@ -1060,6 +1098,24 @@ def ChannelShuffleUnitaryGroup_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class PixelUnshuffleModuleStaticRank3Int64(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([1, 8, 8], torch.int64, True)])
+    def forward(self, x):
+        return torch.ops.aten.pixel_unshuffle(x, 2)
+
+
+@register_test_case(module_factory=lambda: PixelUnshuffleModuleStaticRank3Int64())
+def PixelUnshuffleModuleStaticRank3Int64_basic(module, tu: TestUtils):
+    module.forward(tu.randint(1, 8, 8, low=0, high=100))
+
+
+# ==============================================================================
+
+
 class ChannelShuffle1D(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -1084,6 +1140,24 @@ def ChannelShuffle1D_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class PixelUnshuffleModuleFullDynamic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1, -1, -1], torch.int64, True)])
+    def forward(self, x):
+        return torch.ops.aten.pixel_unshuffle(x, 2)
+
+
+@register_test_case(module_factory=lambda: PixelUnshuffleModuleFullDynamic())
+def PixelUnshuffleModuleFullDynamic_basic(module, tu: TestUtils):
+    module.forward(tu.randint(1, 2, 6, 6, low=0, high=100))
+
+
+# ==============================================================================
+
+
 class ChannelShuffle4D(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -1108,6 +1182,24 @@ def ChannelShuffle4D_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class PixelUnshuffleModuleSpatiallyDynamic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([4, 1, 6, -1, -1], torch.int64, True)])
+    def forward(self, x):
+        return torch.ops.aten.pixel_unshuffle(x, 2)
+
+
+@register_test_case(module_factory=lambda: PixelUnshuffleModuleSpatiallyDynamic())
+def PixelUnshuffleModuleSpatiallyDynamic_basic(module, tu: TestUtils):
+    module.forward(tu.randint(4, 1, 6, 4, 6, low=0, high=100))
+
+
+# ==============================================================================
+
+
 class ChannelShuffleTrailingOnes(torch.nn.Module):
     # Test case where ChannelShuffle last dimensions are ones.
     def __init__(self):
@@ -1133,6 +1225,24 @@ def ChannelShuffleTrailingOnes_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class PixelUnshuffleModuleSpatiallyStatic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1, -1, 6, 3], torch.int64, True)])
+    def forward(self, x):
+        return torch.ops.aten.pixel_unshuffle(x, 3)
+
+
+@register_test_case(module_factory=lambda: PixelUnshuffleModuleSpatiallyStatic())
+def PixelUnshuffleModuleSpatiallyStatic_basic(module, tu: TestUtils):
+    module.forward(tu.randint(2, 2, 3, 6, 3, low=0, high=100))
+
+
+# ==============================================================================
+
+
 class ChannelShuffleDynamicDims(torch.nn.Module):
     # Test case for dynamic dimensions in ChannelShuffle operation.
     def __init__(self):
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index f9509449422b..993dd8bf7b56 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -851,6 +851,61 @@ func.func @native_layer_norm_mixed_dtypes(%input: !torch.vtensor<[1,56,56,96],bf
 
 // -----
 
+// CHECK-LABEL: func @pixel_unshuffle_static
+// CHECK-SAME:      %[[ARG0:.*]]: !torch.vtensor<[1,8,4,4],f32>
+// CHECK-DAG: %[[C2:.*]] = torch.constant.int 2
+// CHECK-DAG: %[[C0:.*]] = torch.constant.int 0
+// CHECK-DAG: %[[C1:.*]] = torch.constant.int 1
+// CHECK-DAG: %[[C3:.*]] = torch.constant.int 3
+// CHECK-DAG: %[[C4:.*]] = torch.constant.int 4
+// CHECK-DAG: %[[C5:.*]] = torch.constant.int 5
+// CHECK: %[[PERMLIST:.*]] = torch.prim.ListConstruct %[[C0]], %[[C1]], %[[C3]], %[[C5]], %[[C2]], %[[C4]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[EXPAND1:.*]] = torch.prims.split_dim %[[ARG0]], %[[C2]], %[[C2]] : !torch.vtensor<[1,8,4,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,2,2,4],f32>
+// CHECK: %[[EXPAND2:.*]] = torch.prims.split_dim %[[EXPAND1]], %[[C4]], %[[C2]] : !torch.vtensor<[1,8,2,2,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,2,2,2,2],f32>
+// CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[EXPAND2]], %[[PERMLIST]] : !torch.vtensor<[1,8,2,2,2,2],f32>, !torch.list<int> -> !torch.vtensor<[1,8,2,2,2,2],f32>
+// CHECK: %[[COLLAPSE1:.*]] = torch.prims.collapse %[[PERMUTE]], %[[C2]], %[[C3]] : !torch.vtensor<[1,8,2,2,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,4,2,2],f32>
+// CHECK: %[[COLLAPSE2:.*]] = torch.prims.collapse %[[COLLAPSE1]], %[[C1]], %[[C2]] : !torch.vtensor<[1,8,4,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,32,2,2],f32>
+// CHECK: return %[[COLLAPSE2]] : !torch.vtensor<[1,32,2,2],f32>
+func.func @pixel_unshuffle_static(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtensor<[1,32,2,2],f32> attributes {torch.assume_strict_symbolic_shapes} {
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.pixel_unshuffle %arg0, %int2 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,32,2,2],f32>
+  return %0 : !torch.vtensor<[1,32,2,2],f32>
+}
+
+
+// -----
+
+
+// CHECK-LABEL: func @pixel_unshuffle_fulldynamic
+// CHECK-SAME:      %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>
+// CHECK-DAG: %[[C2:.*]] = torch.constant.int 2
+// CHECK-DAG: %[[C0:.*]] = torch.constant.int 0
+// CHECK-DAG: %[[C1:.*]] = torch.constant.int 1
+// CHECK-DAG: %[[C3:.*]] = torch.constant.int 3
+// CHECK-DAG: %[[C4:.*]] = torch.constant.int 4
+// CHECK-DAG: %[[C5:.*]] = torch.constant.int 5
+// CHECK: %[[INC:.*]] = torch.aten.size.int %[[ARG0]], %[[C1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[INH:.*]] = torch.aten.size.int %[[ARG0]], %[[C2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[INW:.*]] = torch.aten.size.int %[[ARG0]], %[[C3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[OUTC:.*]] = torch.aten.mul.int %[[INC]], %[[C4]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[OUTH:.*]] = torch.aten.floordiv.int %[[INH]], %[[C2]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[OUTW:.*]] = torch.aten.floordiv.int %[[INW]], %[[C2]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[SIZE0:.*]] = torch.aten.size.int %[[ARG0]], %[[C0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[PERMLIST:.*]] = torch.prim.ListConstruct %[[C0]], %[[C1]], %[[C3]], %[[C5]], %[[C2]], %[[C4]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[EXPAND1:.*]] = torch.prims.split_dim %[[ARG0]], %[[C2]], %[[OUTH]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,2,?],f32>
+// CHECK: %[[EXPAND2:.*]] = torch.prims.split_dim %[[EXPAND1]], %[[C4]], %[[OUTW]] : !torch.vtensor<[?,?,?,2,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,2,?,2],f32>
+// CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[EXPAND2]], %[[PERMLIST]] : !torch.vtensor<[?,?,?,2,?,2],f32>, !torch.list<int> -> !torch.vtensor<[?,?,2,2,?,?],f32>
+// CHECK: %[[COLLAPSE1:.*]] = torch.prims.collapse %[[PERMUTE]], %[[C2]], %[[C3]] : !torch.vtensor<[?,?,2,2,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,4,?,?],f32>
+// CHECK: %[[COLLAPSE2:.*]] = torch.prims.collapse %[[COLLAPSE1]], %[[C1]], %[[C2]] : !torch.vtensor<[?,?,4,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK: return %[[COLLAPSE2]] : !torch.vtensor<[?,?,?,?],f32>
+func.func @pixel_unshuffle_fulldynamic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.assume_strict_symbolic_shapes} {
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.pixel_unshuffle %arg0, %int2 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?,?],f32>
+}
+
+
+// -----
+
+
 // CHECK-LABEL: func @channel_shuffle
 func.func @channel_shuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtensor<[1,8,4,4],f32> attributes {torch.assume_strict_symbolic_shapes} {
   %int4 = torch.constant.int 4
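
Reviewer note: the split_dim/permute/collapse sequence emitted by DecomposeAtenPixelUnshuffleOp can be sanity-checked against PyTorch outside of torch-mlir. The sketch below is not part of the patch; the helper name is ours, and it mirrors the two `prims.split_dim` ops with one reshape, the `aten.permute` with the same `(0, ..., N, N+2, N+4, N+1, N+3)` order, and the two `prims.collapse` ops with a final reshape. Since the whole transformation is pure data movement, the result should agree bit-exactly with `torch.pixel_unshuffle`.

import torch

def pixel_unshuffle_decomposed(x: torch.Tensor, r: int) -> torch.Tensor:
    # Split H -> (H//r, r) and W -> (W//r, r), as the two prims.split_dim ops do.
    *lead, c, h, w = x.shape
    assert h % r == 0 and w % r == 0
    n = len(lead)
    x = x.reshape(*lead, c, h // r, r, w // r, r)
    # Move the two 'r' dims next to C: (0, ..., N, N+2, N+4, N+1, N+3).
    perm = list(range(n)) + [n, n + 2, n + 4, n + 1, n + 3]
    x = x.permute(*perm)  # shape (*lead, C, r, r, H//r, W//r)
    # Fold (C, r, r) into C*r*r, as the two prims.collapse ops do.
    return x.reshape(*lead, c * r * r, h // r, w // r)

# Rank-5 input matching PixelUnshuffleModuleStaticRank5Float32.
x = torch.randn(4, 1, 8, 4, 4)
assert torch.equal(pixel_unshuffle_decomposed(x, 2), torch.pixel_unshuffle(x, 2))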