[mlir][memref] Fold subview into GPU subgroup MMA load/store ops

This commit adds support for folding memref.subview ops into GPU subgroup
MMA load/store ops.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D146150
Lei Zhang 2023-03-15 17:49:27 +00:00
parent 28a0d0e85a
commit 59e4fbfcd0
4 changed files with 76 additions and 2 deletions
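For context, the new patterns rewrite a subgroup MMA load/store whose memref operand is produced by memref.subview so that it accesses the original memref directly, folding the subview's offsets into the access indices while keeping the op's leadDimension attribute. A minimal before/after sketch of the load case, with hypothetical SSA names, shapes borrowed from the 2D tests added below but unit strides for simplicity (the exact affine maps the pass emits may be structured differently):

// Before: the load goes through a subview of the larger buffer.
%sub = memref.subview %src[%off0, %off1] [64, 32] [1, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[128, 1], offset: ?>>
%m = gpu.subgroup_mma_load_matrix %sub[%i, %j] {leadDimension = 32 : index} : memref<64x32xf32, strided<[128, 1], offset: ?>> -> !gpu.mma_matrix<16x16xf16, "COp">

// After folding: the subview is bypassed and its offsets are composed into the indices.
%i2 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%i)[%off0]
%j2 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%j)[%off1]
%m2 = gpu.subgroup_mma_load_matrix %src[%i2, %j2] {leadDimension = 32 : index} : memref<128x128xf32> -> !gpu.mma_matrix<16x16xf16, "COp">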


@@ -23,6 +23,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
MLIRArithTransforms
MLIRBufferizationDialect
MLIRFuncDialect
MLIRGPUOps
MLIRInferTypeOpInterface
MLIRLoopLikeInterface
MLIRMemRefDialect


@@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -216,6 +217,14 @@ static Value getMemRefOperand(vector::TransferWriteOp op) {
return op.getSource();
}
static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
return op.getSrcMemref();
}
static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
return op.getDstMemref();
}
/// Given the permutation map of the original
/// `vector.transfer_read`/`vector.transfer_write` operations compute the
/// permutation map to use after the subview is folded with it.
@@ -407,6 +416,11 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
transferReadOp.getPadding(),
/*mask=*/Value(), transferReadOp.getInBoundsAttr());
})
.Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
op, op.getType(), subViewOp.getSource(), sourceIndices,
op.getLeadDimension(), op.getTransposeAttr());
})
.Default([](Operation *) { llvm_unreachable("unexpected operation."); });
return success();
}
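A note on the sourceIndices used above: they are the original MMA access indices composed with the subview's offsets and strides (per dimension, source index = subview offset + index * subview stride), so the rewritten op addresses the same elements. A rough sketch for a subview with dynamic offsets and strides of (2, 1); the value names are hypothetical and the exact affine-map form is up to the shared index-resolution helper:

// Per-dimension composition: source index = offset + index * stride.
%row = affine.apply affine_map<(d0)[s0] -> (s0 + d0 * 2)>(%i)[%off0]
%col = affine.apply affine_map<(d0)[s0] -> (s0 + d0)>(%j)[%off1]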
@@ -502,11 +516,11 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
llvm::TypeSwitch<Operation *, void>(storeOp)
.Case([&](AffineStoreOp op) {
rewriter.replaceOpWithNewOp<AffineStoreOp>(
storeOp, storeOp.getValue(), subViewOp.getSource(), sourceIndices);
op, op.getValue(), subViewOp.getSource(), sourceIndices);
})
.Case([&](memref::StoreOp op) {
rewriter.replaceOpWithNewOp<memref::StoreOp>(
storeOp, storeOp.getValue(), subViewOp.getSource(), sourceIndices,
op, op.getValue(), subViewOp.getSource(), sourceIndices,
op.getNontemporal());
})
.Case([&](vector::TransferWriteOp op) {
@@ -516,6 +530,11 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
op.getPermutationMap()),
op.getInBoundsAttr());
})
.Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
op, op.getSrc(), subViewOp.getSource(), sourceIndices,
op.getLeadDimension(), op.getTransposeAttr());
})
.Default([](Operation *) { llvm_unreachable("unexpected operation."); });
return success();
}
@@ -590,9 +609,11 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
patterns.add<LoadOpOfSubViewOpFolder<AffineLoadOp>,
LoadOpOfSubViewOpFolder<memref::LoadOp>,
LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
StoreOpOfSubViewOpFolder<AffineStoreOp>,
StoreOpOfSubViewOpFolder<memref::StoreOp>,
StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
LoadOpOfExpandShapeOpFolder<AffineLoadOp>,
LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
StoreOpOfExpandShapeOpFolder<AffineStoreOp>,


@@ -524,3 +524,54 @@ func.func @fold_store_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index,
memref.store %arg5, %0[%arg3, %arg4] {nontemporal=true}: memref<4x4xf32, strided<[64, 3], offset: ?>>
return
}
// -----
func.func @fold_gpu_subgroup_mma_load_matrix_1d(%src: memref<?xvector<4xf32>>, %offset: index, %i: index) -> !gpu.mma_matrix<16x16xf16, "COp"> {
%subview = memref.subview %src[%offset] [81920] [1] : memref<?xvector<4xf32>> to memref<81920xvector<4xf32>, strided<[1], offset: ?>>
%matrix = gpu.subgroup_mma_load_matrix %subview[%i] {leadDimension = 160 : index} : memref<81920xvector<4xf32>, strided<[1], offset: ?>> -> !gpu.mma_matrix<16x16xf16, "COp">
return %matrix: !gpu.mma_matrix<16x16xf16, "COp">
}
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK: func.func @fold_gpu_subgroup_mma_load_matrix_1d
// CHECK-SAME: (%[[SRC:.+]]: memref<?xvector<4xf32>>, %[[OFFSET:.+]]: index, %[[I:.+]]: index)
// CHECK: %[[APPLY:.+]] = affine.apply #[[MAP]](%[[I]])[%[[OFFSET]]]
// CHECK: %[[LOAD:.+]] = gpu.subgroup_mma_load_matrix %[[SRC]][%[[APPLY]]] {leadDimension = 160 : index} : memref<?xvector<4xf32>> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: return %[[LOAD]]
// -----
func.func @fold_gpu_subgroup_mma_store_matrix_1d(%dst: memref<?xvector<4xf32>>, %offset: index, %i: index, %matrix: !gpu.mma_matrix<16x16xf16, "COp">) {
%subview = memref.subview %dst[%offset] [81920] [1] : memref<?xvector<4xf32>> to memref<81920xvector<4xf32>, strided<[1], offset: ?>>
gpu.subgroup_mma_store_matrix %matrix, %subview[%i] {leadDimension = 160 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<81920xvector<4xf32>, strided<[1], offset: ?>>
return
}
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK: func.func @fold_gpu_subgroup_mma_store_matrix_1d
// CHECK-SAME: (%[[DST:.+]]: memref<?xvector<4xf32>>, %[[OFFSET:.+]]: index, %[[I0:.+]]: index, %[[VAL:.+]]: !gpu.mma_matrix<16x16xf16, "COp">)
// CHECK: %[[APPLY:.+]] = affine.apply #[[MAP]](%[[I0]])[%[[OFFSET]]]
// CHECK: gpu.subgroup_mma_store_matrix %[[VAL]], %[[DST]][%[[APPLY]]] {leadDimension = 160 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<?xvector<4xf32>>
// -----
// CHECK-LABEL: func.func @fold_gpu_subgroup_mma_load_matrix_2d
// CHECK-SAME: %[[SRC:.+]]: memref<128x128xf32>
func.func @fold_gpu_subgroup_mma_load_matrix_2d(%arg0 : memref<128x128xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> !gpu.mma_matrix<16x16xf16, "COp"> {
%subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[256, 1], offset: ?>>
// CHECK: gpu.subgroup_mma_load_matrix %[[SRC]][{{.+}}] {leadDimension = 32 : index} : memref<128x128xf32> -> !gpu.mma_matrix<16x16xf16, "COp">
%matrix = gpu.subgroup_mma_load_matrix %subview[%arg3, %arg4] {leadDimension = 32 : index} : memref<64x32xf32, strided<[256, 1], offset: ?>> -> !gpu.mma_matrix<16x16xf16, "COp">
return %matrix : !gpu.mma_matrix<16x16xf16, "COp">
}
// -----
// CHECK-LABEL: func.func @fold_gpu_subgroup_mma_store_matrix_2d
// CHECK-SAME: %[[DST:.+]]: memref<128x128xf32>
func.func @fold_gpu_subgroup_mma_store_matrix_2d(%arg0 : memref<128x128xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %matrix: !gpu.mma_matrix<16x16xf16, "COp">) {
%subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[256, 1], offset: ?>>
// CHECK: gpu.subgroup_mma_store_matrix %{{.+}}, %[[DST]][{{.+}}] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf32>
gpu.subgroup_mma_store_matrix %matrix, %subview[%arg3, %arg4] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<64x32xf32, strided<[256, 1], offset: ?>>
return
}
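The cases above are exercised by this test file's existing RUN line; assuming the usual setup for fold-memref-alias-ops tests, it is along the lines of:

// RUN: mlir-opt -fold-memref-alias-ops -split-input-file %s | FileCheck %s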


@@ -9954,6 +9954,7 @@ cc_library(
":ControlFlowDialect",
":DialectUtils",
":FuncDialect",
":GPUDialect",
":IR",
":InferTypeOpInterface",
":LoopLikeInterface",