From 3cf7f22498254c60067322a28cce2268768bba0b Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Wed, 15 Feb 2023 16:52:49 +0000 Subject: [PATCH] [mlir][vectorToGPU] Fix type used when folding transpose into read op Pick the right result type when folding transpose op into a read Differential Revision: https://reviews.llvm.org/D144113 --- .../Conversion/VectorToGPU/VectorToGPU.cpp | 4 +-- .../VectorToGPU/vector-to-mma-ops.mlir | 30 +++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index d8703597b281..ece17510e136 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -444,13 +444,13 @@ struct CombineTransferReadOpTranspose final PatternRewriter &rewriter) const override { // Look through integer extend ops. Value source = op.getVector(); - auto resultType = op.getVectorType(); + Type resultType = op.getType(); Operation *extOp; if ((extOp = source.getDefiningOp()) || (extOp = source.getDefiningOp())) { source = extOp->getOperand(0); resultType = - VectorType::get(resultType.getShape(), + VectorType::get(resultType.cast().getShape(), source.getType().cast().getElementType()); } diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir index 0ba9eb40483b..08f7e12cf55d 100644 --- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir +++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir @@ -408,3 +408,33 @@ func.func @matmul_mixed_signedness_int8(%arg0: memref<16x16xi8>, %arg1: memref<1 vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32> return } + +// ----- + +#map0 = affine_map<(d0, d1) -> (d1, d0)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> +#map4 = affine_map<(d0) -> (d0, 0)> +#map5 = affine_map<(d0, d1) -> (d0, d1)> + +// CHECK-LABEL: func @matmul_mixed_signedness_int8 +// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 32 : index} : memref<16x32xi8> -> !gpu.mma_matrix<16x32xui8, "AOp"> +// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 32 : index} : memref<16x32xi8> -> !gpu.mma_matrix<32x16xsi8, "BOp"> +// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi32> -> !gpu.mma_matrix<16x16xi32, "COp"> +// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x32xui8, "AOp">, !gpu.mma_matrix<32x16xsi8, "BOp"> -> !gpu.mma_matrix<16x16xi32, "COp"> +// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xi32, "COp">, memref<16x16xi32> +func.func @matmul_mixed_signedness_int8(%arg0: memref<16x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<16x16xi32>) { + %cst_0 = arith.constant dense<0> : vector<16x16xi8> + %c0 = arith.constant 0 : index + %cst_i8 = arith.constant 0 : i8 + %cst_i32 = arith.constant 0 : i32 + %Ar = vector.transfer_read %arg0[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<16x32xi8>, vector<16x32xi8> + %Br = vector.transfer_read %arg1[%c0, %c0], %cst_i8 {permutation_map = #map0, in_bounds = [true, true]} : memref<16x32xi8>, vector<16x32xi8> + %C = vector.transfer_read %arg2[%c0, %c0], %cst_i32 {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32> + %Ae = arith.extui %Ar : vector<16x32xi8> to vector<16x32xi32> + %Be = arith.extsi %Br : vector<16x32xi8> to vector<16x32xi32> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %Ae, %Be, %C : vector<16x32xi32>, vector<16x32xi32> into vector<16x16xi32> + vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32> + return +} \ No newline at end of file