[mlir][memref] Fold subview into GPU subgroup MMA load/store ops
This commit adds support for folding a memref.subview into GPU subgroup MMA load/store ops. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D146150
This commit is contained in:
parent
28a0d0e85a
commit
59e4fbfcd0
|
@ -23,6 +23,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
|
|||
MLIRArithTransforms
|
||||
MLIRBufferizationDialect
|
||||
MLIRFuncDialect
|
||||
MLIRGPUOps
|
||||
MLIRInferTypeOpInterface
|
||||
MLIRLoopLikeInterface
|
||||
MLIRMemRefDialect
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
|
@ -216,6 +217,14 @@ static Value getMemRefOperand(vector::TransferWriteOp op) {
|
|||
return op.getSource();
|
||||
}
|
||||
|
||||
/// Returns the memref that a gpu.subgroup_mma_load_matrix op reads from.
static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  Value source = op.getSrcMemref();
  return source;
}
|
||||
|
||||
/// Returns the memref that a gpu.subgroup_mma_store_matrix op writes to.
static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  Value destination = op.getDstMemref();
  return destination;
}
|
||||
|
||||
/// Given the permutation map of the original
|
||||
/// `vector.transfer_read`/`vector.transfer_write` operations compute the
|
||||
/// permutation map to use after the subview is folded with it.
|
||||
|
@ -407,6 +416,11 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
|
|||
transferReadOp.getPadding(),
|
||||
/*mask=*/Value(), transferReadOp.getInBoundsAttr());
|
||||
})
|
||||
.Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
|
||||
rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
|
||||
op, op.getType(), subViewOp.getSource(), sourceIndices,
|
||||
op.getLeadDimension(), op.getTransposeAttr());
|
||||
})
|
||||
.Default([](Operation *) { llvm_unreachable("unexpected operation."); });
|
||||
return success();
|
||||
}
|
||||
|
@ -502,11 +516,11 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
|
|||
llvm::TypeSwitch<Operation *, void>(storeOp)
|
||||
.Case([&](AffineStoreOp op) {
|
||||
rewriter.replaceOpWithNewOp<AffineStoreOp>(
|
||||
storeOp, storeOp.getValue(), subViewOp.getSource(), sourceIndices);
|
||||
op, op.getValue(), subViewOp.getSource(), sourceIndices);
|
||||
})
|
||||
.Case([&](memref::StoreOp op) {
|
||||
rewriter.replaceOpWithNewOp<memref::StoreOp>(
|
||||
storeOp, storeOp.getValue(), subViewOp.getSource(), sourceIndices,
|
||||
op, op.getValue(), subViewOp.getSource(), sourceIndices,
|
||||
op.getNontemporal());
|
||||
})
|
||||
.Case([&](vector::TransferWriteOp op) {
|
||||
|
@ -516,6 +530,11 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
|
|||
op.getPermutationMap()),
|
||||
op.getInBoundsAttr());
|
||||
})
|
||||
.Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
|
||||
rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
|
||||
op, op.getSrc(), subViewOp.getSource(), sourceIndices,
|
||||
op.getLeadDimension(), op.getTransposeAttr());
|
||||
})
|
||||
.Default([](Operation *) { llvm_unreachable("unexpected operation."); });
|
||||
return success();
|
||||
}
|
||||
|
@ -590,9 +609,11 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
|
|||
patterns.add<LoadOpOfSubViewOpFolder<AffineLoadOp>,
|
||||
LoadOpOfSubViewOpFolder<memref::LoadOp>,
|
||||
LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
|
||||
LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
|
||||
StoreOpOfSubViewOpFolder<AffineStoreOp>,
|
||||
StoreOpOfSubViewOpFolder<memref::StoreOp>,
|
||||
StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
|
||||
StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
|
||||
LoadOpOfExpandShapeOpFolder<AffineLoadOp>,
|
||||
LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
|
||||
StoreOpOfExpandShapeOpFolder<AffineStoreOp>,
|
||||
|
|
|
@ -524,3 +524,54 @@ func.func @fold_store_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index,
|
|||
memref.store %arg5, %0[%arg3, %arg4] {nontemporal=true}: memref<4x4xf32, strided<[64, 3], offset: ?>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Folding a rank-1 memref.subview into gpu.subgroup_mma_load_matrix: the
// subview's dynamic offset is combined with the load index (via affine.apply)
// so the load reads directly from the original source memref.
func.func @fold_gpu_subgroup_mma_load_matrix_1d(%src: memref<?xvector<4xf32>>, %offset: index, %i: index) -> !gpu.mma_matrix<16x16xf16, "COp"> {
  %subview = memref.subview %src[%offset] [81920] [1] : memref<?xvector<4xf32>> to memref<81920xvector<4xf32>, strided<[1], offset: ?>>
  %matrix = gpu.subgroup_mma_load_matrix %subview[%i] {leadDimension = 160 : index} : memref<81920xvector<4xf32>, strided<[1], offset: ?>> -> !gpu.mma_matrix<16x16xf16, "COp">
  return %matrix: !gpu.mma_matrix<16x16xf16, "COp">
}

// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK: func.func @fold_gpu_subgroup_mma_load_matrix_1d
// CHECK-SAME: (%[[SRC:.+]]: memref<?xvector<4xf32>>, %[[OFFSET:.+]]: index, %[[I:.+]]: index)
// CHECK: %[[APPLY:.+]] = affine.apply #[[MAP]](%[[I]])[%[[OFFSET]]]
// CHECK: %[[LOAD:.+]] = gpu.subgroup_mma_load_matrix %[[SRC]][%[[APPLY]]] {leadDimension = 160 : index} : memref<?xvector<4xf32>> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: return %[[LOAD]]
|
||||
|
||||
// -----
|
||||
|
||||
// Folding a rank-1 memref.subview into gpu.subgroup_mma_store_matrix: the
// subview's dynamic offset is combined with the store index (via affine.apply)
// so the store writes directly into the original destination memref.
func.func @fold_gpu_subgroup_mma_store_matrix_1d(%dst: memref<?xvector<4xf32>>, %offset: index, %i: index, %matrix: !gpu.mma_matrix<16x16xf16, "COp">) {
  %subview = memref.subview %dst[%offset] [81920] [1] : memref<?xvector<4xf32>> to memref<81920xvector<4xf32>, strided<[1], offset: ?>>
  gpu.subgroup_mma_store_matrix %matrix, %subview[%i] {leadDimension = 160 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<81920xvector<4xf32>, strided<[1], offset: ?>>
  return
}

// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK: func.func @fold_gpu_subgroup_mma_store_matrix_1d
// CHECK-SAME: (%[[DST:.+]]: memref<?xvector<4xf32>>, %[[OFFSET:.+]]: index, %[[I0:.+]]: index, %[[VAL:.+]]: !gpu.mma_matrix<16x16xf16, "COp">)
// CHECK: %[[APPLY:.+]] = affine.apply #[[MAP]](%[[I0]])[%[[OFFSET]]]
// CHECK: gpu.subgroup_mma_store_matrix %[[VAL]], %[[DST]][%[[APPLY]]] {leadDimension = 160 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<?xvector<4xf32>>
|
||||
|
||||
// -----
|
||||
|
||||
// Folding a rank-2 memref.subview into gpu.subgroup_mma_load_matrix.
// CHECK-LABEL: func.func @fold_gpu_subgroup_mma_load_matrix_2d
// CHECK-SAME: %[[SRC:.+]]: memref<128x128xf32>
func.func @fold_gpu_subgroup_mma_load_matrix_2d(%arg0 : memref<128x128xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> !gpu.mma_matrix<16x16xf16, "COp"> {
  %subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[64, 1], offset: ?>>
  // CHECK: gpu.subgroup_mma_load_matrix %[[SRC]][{{.+}}] {leadDimension = 32 : index} : memref<128x128xf32> -> !gpu.mma_matrix<16x16xf16, "COp">
  %matrix = gpu.subgroup_mma_load_matrix %subview[%arg3, %arg4] {leadDimension = 32 : index} : memref<64x32xf32, strided<[64, 1], offset: ?>> -> !gpu.mma_matrix<16x16xf16, "COp">
  return %matrix : !gpu.mma_matrix<16x16xf16, "COp">
}
|
||||
|
||||
// -----
|
||||
|
||||
// Folding a rank-2 memref.subview into gpu.subgroup_mma_store_matrix.
// NOTE: renamed from the copy-pasted `..._load_matrix_2d` — this test exercises
// the STORE op, and the old name collided with the previous test's.
// CHECK-LABEL: func.func @fold_gpu_subgroup_mma_store_matrix_2d
// CHECK-SAME: %[[DST:.+]]: memref<128x128xf32>
func.func @fold_gpu_subgroup_mma_store_matrix_2d(%arg0 : memref<128x128xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %matrix: !gpu.mma_matrix<16x16xf16, "COp">) {
  %subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[64, 1], offset: ?>>
  // CHECK: gpu.subgroup_mma_store_matrix %{{.+}}, %[[DST]][{{.+}}] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf32>
  gpu.subgroup_mma_store_matrix %matrix, %subview[%arg3, %arg4] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<64x32xf32, strided<[64, 1], offset: ?>>
  return
}
|
||||
|
|
|
@ -9954,6 +9954,7 @@ cc_library(
|
|||
":ControlFlowDialect",
|
||||
":DialectUtils",
|
||||
":FuncDialect",
|
||||
":GPUDialect",
|
||||
":IR",
|
||||
":InferTypeOpInterface",
|
||||
":LoopLikeInterface",
|
||||
|
|
Loading…
Reference in New Issue
Block a user