diff --git a/include/imex/Conversion/CMakeLists.txt b/include/imex/Conversion/CMakeLists.txt
index db25c8fc0..a396bfaa3 100644
--- a/include/imex/Conversion/CMakeLists.txt
+++ b/include/imex/Conversion/CMakeLists.txt
@@ -9,3 +9,4 @@ add_subdirectory(DistToStandard)
 add_subdirectory(DropRegions)
 add_subdirectory(XeTileToXeGPU)
 add_subdirectory(XeGPUToVC)
+add_subdirectory(VectorToXeGPU)
diff --git a/include/imex/Conversion/Passes.h b/include/imex/Conversion/Passes.h
index 416a09de5..978e1bfa9 100644
--- a/include/imex/Conversion/Passes.h
+++ b/include/imex/Conversion/Passes.h
@@ -26,6 +26,7 @@
 #include
 #include
 #include
+#include <imex/Conversion/VectorToXeGPU/VectorToXeGPU.h>
 
 namespace imex {
diff --git a/include/imex/Conversion/Passes.td b/include/imex/Conversion/Passes.td
index 9f14e023f..78533156e 100644
--- a/include/imex/Conversion/Passes.td
+++ b/include/imex/Conversion/Passes.td
@@ -441,4 +441,34 @@ def ConvertXeGPUToVC : Pass<"convert-xegpu-to-vc", "::mlir::gpu::GPUModuleOp"> {
   let constructor = "imex::createConvertXeGPUToVCPass()";
 }
 
+//===----------------------------------------------------------------------===//
+// VectorToXeGPU
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToXeGPU: Pass<"convert-vector-to-xegpu", "::mlir::ModuleOp"> {
+  let summary = "Convert from the Vector dialect to the XeGPU dialect.";
+  let description = [{
+    Convert Vector dialect operations into XeGPU dialect operations: lower `vector.transfer_read` and `vector.transfer_write` to `xegpu.load_nd` and `xegpu.store_nd`, creating the required tensor descriptors along the way.
+
+    #### Input IR
+
+    %3 = vector.transfer_read %arg1[%0, %2], %arg2 : memref<512x640xf32>, vector<2x32xf32>
+    %4 = arith.cmpf ugt, %3, %arg3 : vector<2x32xf32>
+    %5 = arith.select %4, %3, %arg3 : vector<2x32xi1>, vector<2x32xf32>
+    vector.transfer_write %5, %arg4[%0, %2] : vector<2x32xf32>, memref<512x640xf32>
+
+    #### Output IR
+
+    %desc = xegpu.create_nd_tdesc %arg1[%0, %2] {mode = vc} : memref<512x640xf32> -> !xegpu.tensor_desc<2x32xf32>
+    %3 = xegpu.load_nd %desc {mode = vc} : !xegpu.tensor_desc<2x32xf32> -> vector<2x32xf32>
+    %4 = arith.cmpf ugt, %3, %arg3 : vector<2x32xf32>
+    %5 = arith.select %4, %3, %arg3 : vector<2x32xi1>, vector<2x32xf32>
+    %desc2 = xegpu.create_nd_tdesc %arg4[%0, %2] {mode = vc} : memref<512x640xf32> -> !xegpu.tensor_desc<2x32xf32>
+    xegpu.store_nd %5, %desc2 {mode = vc} : vector<2x32xf32>, !xegpu.tensor_desc<2x32xf32>
+  }];
+  let constructor = "::imex::createConvertVectorToXeGPUPass()";
+  let dependentDialects = ["::mlir::xegpu::XeGPUDialect"];
+  let options = [];
+}
+
 #endif // _IMEX_CONVERSION_PASSES_TD_INCLUDED_
diff --git a/include/imex/Conversion/VectorToXeGPU/CMakeLists.txt b/include/imex/Conversion/VectorToXeGPU/CMakeLists.txt
new file mode 100644
index 000000000..e69de29bb
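Note that for 1-D vectors the conversion additionally inserts `vector.shape_cast` ops, since the pass only builds 2-D tensor descriptors. A sketch of the 1-D read case, assembled from the vector_1d.mlir test added below (descriptor attributes elided):

    %3 = vector.transfer_read %arg1[%0, %2], %arg2 : memref<512x640xf32>, vector<32xf32>

becomes

    %desc = xegpu.create_nd_tdesc %arg1[%0, %2] : memref<512x640xf32> -> !xegpu.tensor_desc<1x32xf32>
    %load = xegpu.load_nd %desc : !xegpu.tensor_desc<1x32xf32> -> vector<1x32xf32>
    %3 = vector.shape_cast %load : vector<1x32xf32> to vector<32xf32>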
diff --git a/include/imex/Conversion/VectorToXeGPU/VectorToXeGPU.h b/include/imex/Conversion/VectorToXeGPU/VectorToXeGPU.h
new file mode 100644
index 000000000..58cb880e6
--- /dev/null
+++ b/include/imex/Conversion/VectorToXeGPU/VectorToXeGPU.h
@@ -0,0 +1,37 @@
+//===- VectorToXeGPU.h - VectorToXeGPU conversion ---------------*- C++ -*-===//
+//
+// Copyright 2022 Intel Corporation
+// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines the VectorToXeGPU conversion, converting the Vector
+/// dialect to the XeGPU dialect.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef _VectorToXeGPU_H_INCLUDED_
+#define _VectorToXeGPU_H_INCLUDED_
+
+#include
+#include
+
+namespace mlir {
+class LLVMTypeConverter;
+class MLIRContext;
+class ModuleOp;
+template <typename T>
+class OperationPass;
+class RewritePatternSet;
+} // namespace mlir
+
+namespace imex {
+/// Create a pass to convert the Vector dialect to the XeGPU dialect.
+std::unique_ptr<::mlir::OperationPass<::mlir::ModuleOp>> createConvertVectorToXeGPUPass();
+
+} // namespace imex
+
+#endif // _VectorToXeGPU_H_INCLUDED_
diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt
index a989a7991..43ed3a43c 100644
--- a/lib/Conversion/CMakeLists.txt
+++ b/lib/Conversion/CMakeLists.txt
@@ -7,3 +7,4 @@ add_subdirectory(GPUXToLLVM)
 add_subdirectory(XeGPUToSPIRV)
 add_subdirectory(XeTileToXeGPU)
 add_subdirectory(XeGPUToVC)
+add_subdirectory(VectorToXeGPU)
diff --git a/lib/Conversion/VectorToXeGPU/CMakeLists.txt b/lib/Conversion/VectorToXeGPU/CMakeLists.txt
new file mode 100644
index 000000000..dc5e84c93
--- /dev/null
+++ b/lib/Conversion/VectorToXeGPU/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_imex_conversion_library(IMEXVectorToXeGPU
+  VectorToXeGPU.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/imex/Conversion/VectorToXeGPU
+
+  DEPENDS
+  IMEXConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRXeGPUDialect
+)
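With the library wired into the build, the new pass can be exercised standalone, the same way the lit tests at the end of this patch drive it (`input.mlir` is a placeholder name):

    imex-opt --split-input-file --convert-vector-to-xegpu input.mlir -o - | FileCheck input.mlir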
diff --git a/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
new file mode 100644
index 000000000..002985b39
--- /dev/null
+++ b/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -0,0 +1,224 @@
+//===- VectorToXeGPU.cpp - VectorToXeGPU conversion -------------*- C++ -*-===//
+//
+// Copyright 2022 Intel Corporation
+// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file implements the VectorToXeGPU conversion, converting the Vector
+/// dialect to the XeGPU dialect.
+///
+//===----------------------------------------------------------------------===//
+
+#include
+#include
+#include
+#include
+
+#include
+
+#include "../PassDetail.h"
+#include "imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+
+using namespace mlir;
+
+namespace imex {
+
+namespace {
+
+class MyPatternRewriter : public PatternRewriter {
+public:
+  MyPatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx) {}
+
+  /// Override the necessary PatternRewriter hooks here.
+};
+
+struct MyTarget : public ConversionTarget {
+  MyTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
+    /// Mark `vector.transfer_read` as illegal.
+    addIllegalOp<vector::TransferReadOp>(); //, vector::TransferWriteOp
+  }
+};
+
+// *******************************
+// ***** Individual patterns *****
+// *******************************
+
+// Goal: vector.transfer_read -> xegpu.create_nd_tdesc + xegpu.load_nd
+// E.g. translate
+//   %3 = vector.transfer_read %arg1[%0, %2], %arg2
+//        : memref<512x640xf32>, vector<32xf32>
+// to
+//   %desc = xegpu.create_nd_tdesc %arg1[%0, %2] {mode = vc}
+//           : memref<512x640xf32> -> !xegpu.tensor_desc<1x32xf32>
+//   %4 = xegpu.load_nd %desc {mode = vc}
+//        : !xegpu.tensor_desc<1x32xf32> -> vector<1x32xf32>
+//   %3' = vector.shape_cast %4 : vector<1x32xf32> to vector<32xf32>
+
+struct TransferReadOpConverter
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::TransferReadOp read,
+                                PatternRewriter &rewriter) const override {
+    auto ctx = read->getContext();
+    auto resultTile = read.getResult();
+    auto resTileType = resultTile.getType();
+    auto resTileShape = resTileType.getShape();
+    auto rank = resTileType.getRank();
+    auto source = read.getSource();
+
+    // Use an owning vector here; assigning an initializer list to an
+    // ArrayRef would leave it pointing at a dead temporary.
+    SmallVector<int64_t> loadShape;
+    if (rank == 1)
+      loadShape = {1, resTileShape[0]};
+    else
+      loadShape.assign(resTileShape.begin(), resTileShape.end());
+    auto loadType = VectorType::get(loadShape, resTileType.getElementType());
+    auto tDescTy =
+        xegpu::TensorDescType::get(loadShape, resTileType.getElementType());
+    mlir::SmallVector<mlir::Value> tDescOffsets{read->getOperand(1),
+                                                read->getOperand(2)};
+    rewriter.setInsertionPoint(read);
+    mlir::Value desc;
+    if (auto MemRefTypedSource =
+            mlir::dyn_cast<mlir::TypedValue<mlir::MemRefType>>(source)) {
+      desc = rewriter.create<xegpu::CreateNdDescOp>(
+          read.getLoc(), tDescTy, MemRefTypedSource, tDescOffsets);
+    } else {
+      return mlir::failure();
+    }
+
+    mlir::IntegerAttr vnniAxisAttr;
+    mlir::DenseI64ArrayAttr transposeAttr;
+    mlir::IntegerAttr transposeBitWidthAttr;
+    auto CACHED = mlir::xegpu::CachePolicy::CACHED;
+    auto L1 = mlir::xegpu::CachePolicyAttr::get(ctx, CACHED);
+    auto L2 = mlir::xegpu::CachePolicyAttr::get(ctx, CACHED);
+    auto L3 = mlir::xegpu::CachePolicyAttr::get(ctx, CACHED);
+    Operation *payload = rewriter.create<xegpu::LoadNdOp>(
+        read.getLoc(), loadType, desc, vnniAxisAttr, transposeAttr,
+        transposeBitWidthAttr, L1, L2, L3);
+
+    if (rank == 1) {
+      // XeGPU currently doesn't support 1-D vector loads, so the data is
+      // loaded as a 2-D vector and cast back to the expected 1-D shape.
+      auto cast = rewriter.create<vector::ShapeCastOp>(
+          read.getLoc(), resTileType, payload->getResults());
+      if (auto map = read.getPermutationMap(); map.isSingleConstant()) {
+        // Broadcast-style map, e.g. (d0, d1) -> (0): splat the constant
+        // element across the result vector with a shuffle.
+        SmallVector<int64_t> mask(resTileShape[0],
+                                  map.getSingleConstantResult());
+        payload =
+            rewriter.create<vector::ShuffleOp>(read.getLoc(), cast, cast, mask);
+      } else {
+        AffineExpr d0, d1;
+        bindDims(read.getContext(), d0, d1);
+        auto mp = AffineMap::get(map.getNumDims(), 0, {d1}, read.getContext());
+        // Only the minor-identity map (d0, d1) -> (d1) is supported here.
+        if (map != mp) {
+          // Unsupported permutation map
+          return ::mlir::failure();
+        }
+        payload = cast;
+      }
+    }
+    rewriter.replaceOp(read, payload->getResults());
+
+    return ::mlir::success();
+  }
+};
+
+// Goal: vector.transfer_write -> vector.shape_cast (1-D case)
+//                                + xegpu.create_nd_tdesc + xegpu.store_nd
+// E.g. translate
+//   vector.transfer_write %5, %arg4[%0, %2]
+//     : vector<32xf32>, memref<512x640xf32>
+// to
+//   %cast = vector.shape_cast %5 : vector<32xf32> to vector<1x32xf32>
+//   %desc2 = xegpu.create_nd_tdesc %arg4[%0, %2] {mode = vc}
+//            : memref<512x640xf32> -> !xegpu.tensor_desc<1x32xf32>
+//   xegpu.store_nd %cast, %desc2 {mode = vc}
+//     : vector<1x32xf32>, !xegpu.tensor_desc<1x32xf32>
+
+struct TransferWriteOpConverter
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+                                PatternRewriter &rewriter) const override {
+    auto ctx = write->getContext();
+    auto resultTile = write->getOperand(0); // the stored vector, e.g. %5
+    auto source = write.getSource();        // e.g. memref<512x640xi32>
+    auto resTileType = dyn_cast<VectorType>(resultTile.getType());
+    auto resTileShape = resTileType.getShape();
+    auto rank = resTileType.getRank();
+    auto intermediateType =
+        VectorType::get({1, resTileShape[0]}, resTileType.getElementType());
+
+    // Same ownership caveat as in TransferReadOpConverter.
+    SmallVector<int64_t> storeShape;
+    if (rank == 1)
+      storeShape = {1, resTileShape[0]};
+    else
+      storeShape.assign(resTileShape.begin(), resTileShape.end());
+    auto tDescTy =
+        xegpu::TensorDescType::get(storeShape, resTileType.getElementType());
+    mlir::SmallVector<mlir::Value> tDescOffsets{write->getOperand(2),
+                                                write->getOperand(3)};
+    rewriter.setInsertionPoint(write);
+    mlir::Value payload = write.getOperand(0);
+    if (rank == 1) {
+      payload = rewriter.create<vector::ShapeCastOp>(
+          write.getLoc(), intermediateType, write->getOperand(0));
+    }
+    mlir::Value desc;
+    if (auto MemRefTypedSource =
+            mlir::dyn_cast<mlir::TypedValue<mlir::MemRefType>>(source)) {
+      desc = rewriter.create<xegpu::CreateNdDescOp>(
+          write.getLoc(), tDescTy /*resultTy*/, MemRefTypedSource /*source*/,
+          tDescOffsets /*offsets*/);
+    } else {
+      return mlir::failure();
+    }
+
+    auto WRITE_BACK = mlir::xegpu::CachePolicy::WRITE_BACK;
+    auto L1 = mlir::xegpu::CachePolicyAttr::get(ctx, WRITE_BACK);
+    auto L2 = mlir::xegpu::CachePolicyAttr::get(ctx, WRITE_BACK);
+    auto L3 = mlir::xegpu::CachePolicyAttr::get(ctx, WRITE_BACK);
+    rewriter.create<xegpu::StoreNdOp>(write.getLoc(), payload, desc, L1, L2,
+                                      L3);
+    rewriter.eraseOp(write);
+
+    return ::mlir::success();
+  }
+};
+
+// *******************************
+// ***** Pass infrastructure *****
+// *******************************
+
+// Full pass: convert Vector to XeGPU.
+struct ConvertVectorToXeGPUPass
+    : public ::imex::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
+  ConvertVectorToXeGPUPass() = default;
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    mlir::RewritePatternSet patterns(ctx);
+
+    patterns.insert<TransferReadOpConverter, TransferWriteOpConverter>(ctx);
+
+    (void)mlir::applyPatternsAndFoldGreedily(getOperation(),
+                                             std::move(patterns));
+  }
+};
+
+} // namespace
+
+/// Create a pass that converts Vector to XeGPU.
+std::unique_ptr<::mlir::OperationPass<::mlir::ModuleOp>>
+createConvertVectorToXeGPUPass() {
+  return std::make_unique<ConvertVectorToXeGPUPass>();
+}
+
+} // namespace imex
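For reads with a broadcast permutation map, the `isSingleConstant` branch in TransferReadOpConverter materializes the broadcast as a `vector.shuffle`. A sketch for the `(d0, d1) -> (0)` reads in the GEMM test below; the lit tests do not assert this exact output, so attribute details are elided:

    %5 = vector.transfer_read %arg3[%2, %arg4], %c0_i32 {permutation_map = affine_map<(d0, d1) -> (0)>} : memref<3x3xi32>, vector<16xi32>

becomes

    %desc = xegpu.create_nd_tdesc %arg3[%2, %arg4] : memref<3x3xi32> -> !xegpu.tensor_desc<1x16xi32>
    %load = xegpu.load_nd %desc : !xegpu.tensor_desc<1x16xi32> -> vector<1x16xi32>
    %cast = vector.shape_cast %load : vector<1x16xi32> to vector<16xi32>
    %5 = vector.shuffle %cast, %cast [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<16xi32>, vector<16xi32>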
diff --git a/lib/Utils/XeCommon.cpp b/lib/Utils/XeCommon.cpp
index bfc4b8ab3..f02413ec9 100644
--- a/lib/Utils/XeCommon.cpp
+++ b/lib/Utils/XeCommon.cpp
@@ -112,6 +112,8 @@ encodeVectorType(mlir::ConversionPatternRewriter &rewriter,
   } else if (elemType == rewriter.getBF16Type()) {
     str += "i32";
     elemType = rewriter.getI32Type();
+  } else if (elemType == rewriter.getI32Type()) {
+    str += "i32";
   } else
     assert(0 && "add more support");
   auto newType = mlir::VectorType::get(size, elemType);
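The `encodeVectorType` change above adds i32 element support to the VC intrinsic type encoding; presumably this is what the i32 GEMM test below exercises once its vector arithmetic, such as the following, reaches the XeGPUToVC lowering in the pipeline:

    %8 = arith.muli %5, %6 : vector<16xi32>
    %9 = arith.addi %7, %8 : vector<16xi32>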
diff --git a/test/Conversion/VectorToXeGPU/gemm_3x3.mlir b/test/Conversion/VectorToXeGPU/gemm_3x3.mlir
new file mode 100644
index 000000000..a161d6d7a
--- /dev/null
+++ b/test/Conversion/VectorToXeGPU/gemm_3x3.mlir
@@ -0,0 +1,114 @@
+// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/vector-to-llvm.pp \
+// RUN:   --runner imex-cpu-runner -e main \
+// RUN:   --entry-point-result=void \
+// RUN:   --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
+// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/vector-to-llvm.pp \
+// RUN:   --runner imex-cpu-runner -e main \
+// RUN:   --entry-point-result=void \
+// RUN:   --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
+#map = affine_map<(d0, d1) -> (0)>
+module @gemm attributes {gpu.container_module} {
+  memref.global "private" constant @__constant_3x3xi32_1 : memref<3x3xi32> = dense<1>
+  memref.global "private" constant @__constant_3x3xi32_0 : memref<3x3xi32> = dense<[[10, 11, 12], [13, 14, 15], [16, 17, 18]]>
+  memref.global "private" constant @__constant_3x3xi32 : memref<3x3xi32> = dense<[[1, 1, 1], [1, 1, 2], [3, 3, 3]]>
+  func.func @main() {
+    %0 = memref.get_global @__constant_3x3xi32 : memref<3x3xi32>
+    %1 = memref.get_global @__constant_3x3xi32_0 : memref<3x3xi32>
+    %2 = memref.get_global @__constant_3x3xi32_1 : memref<3x3xi32>
+    %3 = call @test(%0, %1, %2) : (memref<3x3xi32>, memref<3x3xi32>, memref<3x3xi32>) -> memref<3x3xi32>
+    %cast = memref.cast %3 : memref<3x3xi32> to memref<*xi32>
+    call @printMemrefI32(%cast) : (memref<*xi32>) -> ()
+    // CHECK: Unranked Memref base@ = {{(0x)?[0-9a-f]*}}
+    // CHECK-NEXT: [40, 43, 46]
+    // CHECK-NEXT: [56, 60, 64]
+    // CHECK-NEXT: [118, 127, 136]
+    return
+  }
+  func.func private @printMemrefI32(memref<*xi32>)
+  func.func @test(%arg0: memref<3x3xi32>, %arg1: memref<3x3xi32>, %arg2: memref<3x3xi32>) -> memref<3x3xi32> {
+    %c1 = arith.constant 1 : index
+    %c16 = arith.constant 16 : index
+    %c3 = arith.constant 3 : index
+    %c0 = arith.constant 0 : index
+    %memref = gpu.alloc host_shared () : memref<3x3xi32>
+    memref.copy %arg2, %memref : memref<3x3xi32> to memref<3x3xi32>
+    %memref_0 = gpu.alloc host_shared () : memref<3x3xi32>
+    memref.copy %arg1, %memref_0 : memref<3x3xi32> to memref<3x3xi32>
+    %memref_1 = gpu.alloc host_shared () : memref<3x3xi32>
+    memref.copy %arg0, %memref_1 : memref<3x3xi32> to memref<3x3xi32>
+    %memref_2 = gpu.alloc host_shared () : memref<3x3xi32>
+    gpu.launch_func @test_kernel::@test_kernel blocks in (%c3, %c1, %c1) threads in (%c1, %c1, %c1) args(%c0 : index, %c16 : index, %c0 : index, %memref_2 : memref<3x3xi32>)
+    %memref_3 = gpu.alloc host_shared () : memref<3x3xi32>
+    memref.copy %memref_2, %memref_3 : memref<3x3xi32> to memref<3x3xi32>
+    gpu.launch_func @test_kernel_0::@test_kernel blocks in (%c3, %c1, %c1) threads in (%c1, %c1, %c1) args(%c0 : index, %c16 : index, %c0 : index, %memref_1 : memref<3x3xi32>, %c0 : index, %memref_0 : memref<3x3xi32>, %memref_3 : memref<3x3xi32>)
+    %memref_4 = gpu.alloc host_shared () : memref<3x3xi32>
+    gpu.launch_func @test_kernel_1::@test_kernel blocks in (%c3, %c1, %c1) threads in (%c1, %c1, %c1) args(%c0 : index, %c16 : index, %c0 : index, %memref_3 : memref<3x3xi32>, %memref : memref<3x3xi32>, %memref_4 : memref<3x3xi32>)
+    gpu.dealloc %memref_2 : memref<3x3xi32>
+    gpu.dealloc %memref_3 : memref<3x3xi32>
+    %alloc = memref.alloc() : memref<3x3xi32>
+    memref.copy %memref_4, %alloc : memref<3x3xi32> to memref<3x3xi32>
+    gpu.dealloc %memref_4 : memref<3x3xi32>
+    gpu.dealloc %memref_1 : memref<3x3xi32>
+    gpu.dealloc %memref_0 : memref<3x3xi32>
+    gpu.dealloc %memref : memref<3x3xi32>
+    return %alloc : memref<3x3xi32>
+  }
+  gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} {
+    gpu.func @test_kernel(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<3x3xi32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+      %cst = arith.constant dense<0> : vector<16xi32>
+      %0 = gpu.block_id x
+      %1 = gpu.thread_id x
+      %2 = arith.addi %arg0, %0 : index
+      %3 = arith.muli %arg1, %1 : index
+      %4 = arith.addi %arg2, %3 : index
+      vector.transfer_write %cst, %arg3[%2, %4] : vector<16xi32>, memref<3x3xi32>
+      gpu.return
+    }
+  }
+  gpu.module @test_kernel_0 attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} {
+    gpu.func @test_kernel(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<3x3xi32>, %arg4: index, %arg5: memref<3x3xi32>, %arg6: memref<3x3xi32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+      %c2 = arith.constant 2 : index
+      %c1 = arith.constant 1 : index
+      %c0_i32 = arith.constant 0 : i32
+      %0 = gpu.block_id x
+      %1 = gpu.thread_id x
+      %2 = arith.addi %arg0, %0 : index
+      %3 = arith.muli %arg1, %1 : index
+      %4 = arith.addi %arg2, %3 : index
+      %5 = vector.transfer_read %arg3[%2, %arg4], %c0_i32 {permutation_map = #map} : memref<3x3xi32>, vector<16xi32>
+      %6 = vector.transfer_read %arg5[%arg4, %4], %c0_i32 : memref<3x3xi32>, vector<16xi32>
+      %7 = vector.transfer_read %arg6[%2, %4], %c0_i32 : memref<3x3xi32>, vector<16xi32>
+      %8 = arith.muli %5, %6 : vector<16xi32>
+      %9 = arith.addi %7, %8 : vector<16xi32>
+      vector.transfer_write %9, %arg6[%2, %4] : vector<16xi32>, memref<3x3xi32>
+      %10 = vector.transfer_read %arg3[%2, %c1], %c0_i32 {permutation_map = #map} : memref<3x3xi32>, vector<16xi32>
+      %11 = vector.transfer_read %arg5[%c1, %4], %c0_i32 : memref<3x3xi32>, vector<16xi32>
+      %12 = vector.transfer_read %arg6[%2, %4], %c0_i32 : memref<3x3xi32>, vector<16xi32>
+      %13 = arith.muli %10, %11 : vector<16xi32>
+      %14 = arith.addi %12, %13 : vector<16xi32>
+      vector.transfer_write %14, %arg6[%2, %4] : vector<16xi32>, memref<3x3xi32>
+      %15 = vector.transfer_read %arg3[%2, %c2], %c0_i32 {permutation_map = #map} : memref<3x3xi32>, vector<16xi32>
+      %16 = vector.transfer_read %arg5[%c2, %4], %c0_i32 : memref<3x3xi32>, vector<16xi32>
+      %17 = vector.transfer_read %arg6[%2, %4], %c0_i32 : memref<3x3xi32>, vector<16xi32>
+      %18 = arith.muli %15, %16 : vector<16xi32>
+      %19 = arith.addi %17, %18 : vector<16xi32>
+      vector.transfer_write %19, %arg6[%2, %4] : vector<16xi32>, memref<3x3xi32>
+      gpu.return
+    }
+  }
+  gpu.module @test_kernel_1 attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} {
+    gpu.func @test_kernel(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<3x3xi32>, %arg4: memref<3x3xi32>, %arg5: memref<3x3xi32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+      %c0_i32 = arith.constant 0 : i32
+      %0 = gpu.block_id x
+      %1 = gpu.thread_id x
+      %2 = arith.addi %arg0, %0 : index
+      %3 = arith.muli %arg1, %1 : index
+      %4 = arith.addi %arg2, %3 : index
+      %5 = vector.transfer_read %arg3[%2, %4], %c0_i32 : memref<3x3xi32>, vector<16xi32>
+      %6 = vector.transfer_read %arg4[%2, %4], %c0_i32 : memref<3x3xi32>, vector<16xi32>
+      %7 = arith.addi %5, %6 : vector<16xi32>
+      vector.transfer_write %7, %arg5[%2, %4] : vector<16xi32>, memref<3x3xi32>
+      gpu.return
+    }
+  }
+}
\ No newline at end of file
diff --git a/test/Conversion/VectorToXeGPU/vector-to-llvm.pp b/test/Conversion/VectorToXeGPU/vector-to-llvm.pp
new file mode 100644
index 000000000..9a8698589
--- /dev/null
+++ b/test/Conversion/VectorToXeGPU/vector-to-llvm.pp
@@ -0,0 +1,18 @@
+builtin.module(
+    convert-vector-to-xegpu
+    imex-convert-gpu-to-spirv{enable-vc-intrinsic=true}
+    spirv.module(spirv-lower-abi-attrs
+                 spirv-update-vce)
+    func.func(llvm-request-c-wrappers)
+    serialize-spirv
+    convert-gpu-to-gpux
+    convert-scf-to-cf
+    convert-cf-to-llvm
+    convert-arith-to-llvm
+    convert-func-to-llvm
+    convert-math-to-llvm
+    convert-gpux-to-llvm
+    expand-strided-metadata
+    lower-affine
+    finalize-memref-to-llvm
+    reconcile-unrealized-casts)
diff --git a/test/Conversion/VectorToXeGPU/vector_1d.mlir b/test/Conversion/VectorToXeGPU/vector_1d.mlir
new file mode 100644
index 000000000..909f38b3a
--- /dev/null
+++ b/test/Conversion/VectorToXeGPU/vector_1d.mlir
@@ -0,0 +1,25 @@
+// RUN: imex-opt --split-input-file --convert-vector-to-xegpu %s -verify-diagnostics -o - | FileCheck %s
+
+gpu.module @forward_kernel {
+  gpu.func @forward_kernel(%arg0: index, %arg1: memref<512x640xf32>, %arg2: f32, %arg3: vector<32xf32>, %arg4: memref<512x640xf32>) {
+    // CHECK: %[[COL:.*]] = gpu.block_id x
+    %0 = gpu.block_id x
+    %1 = gpu.thread_id x
+    // CHECK: %[[ROW:.*]] = arith.muli %{{.*}}, %arg0 : index
+    %2 = arith.muli %1, %arg0 : index
+    // CHECK: %[[TDESC0:.*]] = xegpu.create_nd_tdesc %arg1[%[[COL]], %[[ROW]]] : memref<512x640xf32>
+    // CHECK-NEXT: %[[LOAD0:.*]] = xegpu.load_nd %[[TDESC0]] {{.*}} -> vector<1x32xf32>
+    // CHECK-NEXT: %[[CAST0:.*]] = vector.shape_cast %[[LOAD0]] : vector<1x32xf32> to vector<32xf32>
+    %3 = vector.transfer_read %arg1[%0, %2], %arg2 : memref<512x640xf32>, vector<32xf32>
+    // CHECK: %[[CMP:.*]] = arith.cmpf ugt, %[[CAST0]], %arg3 : vector<32xf32>
+    %4 = arith.cmpf ugt, %3, %arg3 : vector<32xf32>
+    // CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %[[CAST0]], %arg3 : vector<32xi1>, vector<32xf32>
+    %5 = arith.select %4, %3, %arg3 : vector<32xi1>, vector<32xf32>
+    // CHECK: %[[CAST1:.*]] = vector.shape_cast %[[SELECT]] : vector<32xf32> to vector<1x32xf32>
+    // CHECK: %[[TDESC1:.*]] = xegpu.create_nd_tdesc %arg4[%[[COL]], %[[ROW]]] : memref<512x640xf32>
+    // CHECK-NEXT: xegpu.store_nd %[[CAST1]], %[[TDESC1]] {{.*}} : vector<1x32xf32>
+    vector.transfer_write %5, %arg4[%0, %2] : vector<32xf32>, memref<512x640xf32>
+    gpu.return
+  }
+}
diff --git a/test/Conversion/VectorToXeGPU/vector_nd.mlir b/test/Conversion/VectorToXeGPU/vector_nd.mlir
new file mode 100644
index 000000000..8ce048611
--- /dev/null
+++ b/test/Conversion/VectorToXeGPU/vector_nd.mlir
@@ -0,0 +1,24 @@
+// RUN: imex-opt --split-input-file --convert-vector-to-xegpu %s -verify-diagnostics -o - | FileCheck %s
+
+gpu.module @forward_kernel {
+  gpu.func @forward_kernel(%arg0: index, %arg1: memref<512x640xf32>, %arg2: f32, %arg3: vector<2x32xf32>, %arg4: memref<512x640xf32>) {
+    // CHECK: %[[COL:.*]] = gpu.block_id x
+    %0 = gpu.block_id x
+    %1 = gpu.thread_id x
+    // CHECK: %[[ROW:.*]] = arith.muli %{{.*}}, %arg0 : index
+    %2 = arith.muli %1, %arg0 : index
+    // CHECK: %[[TDESC0:.*]] = xegpu.create_nd_tdesc %arg1[%[[COL]], %[[ROW]]] : memref<512x640xf32>
+    // CHECK-NEXT: %[[LOAD0:.*]] = xegpu.load_nd %[[TDESC0]] {{.*}} -> vector<2x32xf32>
+    %3 = vector.transfer_read %arg1[%0, %2], %arg2 : memref<512x640xf32>, vector<2x32xf32>
+    // CHECK: %[[CMP:.*]] = arith.cmpf ugt, %[[LOAD0]], %arg3 : vector<2x32xf32>
+    %4 = arith.cmpf ugt, %3, %arg3 : vector<2x32xf32>
+    // CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LOAD0]], %arg3 : vector<2x32xi1>, vector<2x32xf32>
+    %5 = arith.select %4, %3, %arg3 : vector<2x32xi1>, vector<2x32xf32>
+    // CHECK: %[[TDESC1:.*]] = xegpu.create_nd_tdesc %arg4[%[[COL]], %[[ROW]]] : memref<512x640xf32>
+    // CHECK-NEXT: xegpu.store_nd %[[SELECT]], %[[TDESC1]] {{.*}} : vector<2x32xf32>
+    vector.transfer_write %5, %arg4[%0, %2] : vector<2x32xf32>, memref<512x640xf32>
+    gpu.return
+  }
+}