[mlir][memref] Bufferize memref.tensor_store op
This change adds the BufferizableOpInterface implementation for memref.tensor_store. Differential Revision: https://reviews.llvm.org/D144080
This commit is contained in:
parent
01581e28ad
commit
c645eb0d03
|
@ -0,0 +1,21 @@
|
||||||
|
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_DIALECT_MEMREF_BUFFERIZABLEOPINTERFACEIMPL_H
|
||||||
|
#define MLIR_DIALECT_MEMREF_BUFFERIZABLEOPINTERFACEIMPL_H
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
class DialectRegistry;
|
||||||
|
|
||||||
|
namespace memref {
|
||||||
|
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
|
||||||
|
} // namespace memref
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_DIALECT_MEMREF_BUFFERIZABLEOPINTERFACEIMPL_H
|
|
@ -45,6 +45,7 @@
|
||||||
#include "mlir/Dialect/Math/IR/Math.h"
|
#include "mlir/Dialect/Math/IR/Math.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
|
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
|
||||||
|
#include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
|
||||||
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
|
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
|
||||||
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
|
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
|
||||||
#include "mlir/Dialect/OpenACC/OpenACC.h"
|
#include "mlir/Dialect/OpenACC/OpenACC.h"
|
||||||
|
@ -131,6 +132,7 @@ inline void registerAllDialects(DialectRegistry ®istry) {
|
||||||
registry);
|
registry);
|
||||||
linalg::registerBufferizableOpInterfaceExternalModels(registry);
|
linalg::registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
linalg::registerTilingInterfaceExternalModels(registry);
|
linalg::registerTilingInterfaceExternalModels(registry);
|
||||||
|
memref::registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
|
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
|
||||||
scf::registerBufferizableOpInterfaceExternalModels(registry);
|
scf::registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
shape::registerBufferizableOpInterfaceExternalModels(registry);
|
shape::registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
|
|
|
@ -0,0 +1,63 @@
|
||||||
|
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/IR/Dialect.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::bufferization;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/// Bufferization of memref.tensor_store. Replace with memref.copy.
|
||||||
|
struct TensorStoreOpInterface
|
||||||
|
: public BufferizableOpInterface::ExternalModel<TensorStoreOpInterface,
|
||||||
|
memref::TensorStoreOp> {
|
||||||
|
AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
|
||||||
|
const AnalysisState &state) const {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
|
const AnalysisState &state) const {
|
||||||
|
assert(opOperand.getOperandNumber() == 0 && "expected src operand");
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||||
|
const AnalysisState &state) const {
|
||||||
|
// The memref operand is written but not the tensor operand.
|
||||||
|
assert(opOperand.getOperandNumber() == 0 && "expected src operand");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
|
const BufferizationOptions &options) const {
|
||||||
|
auto tensorStoreOp = cast<memref::TensorStoreOp>(op);
|
||||||
|
auto srcBuffer = getBuffer(rewriter, tensorStoreOp.getTensor(), options);
|
||||||
|
if (failed(srcBuffer))
|
||||||
|
return failure();
|
||||||
|
if (failed(options.createMemCpy(rewriter, op->getLoc(), *srcBuffer,
|
||||||
|
tensorStoreOp.getMemref())))
|
||||||
|
return failure();
|
||||||
|
rewriter.eraseOp(tensorStoreOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void mlir::memref::registerBufferizableOpInterfaceExternalModels(
|
||||||
|
DialectRegistry ®istry) {
|
||||||
|
registry.addExtension(+[](MLIRContext *ctx, MemRefDialect *dialect) {
|
||||||
|
TensorStoreOp::attachInterface<TensorStoreOpInterface>(*ctx);
|
||||||
|
});
|
||||||
|
}
|
|
@ -1,4 +1,5 @@
|
||||||
add_mlir_dialect_library(MLIRMemRefTransforms
|
add_mlir_dialect_library(MLIRMemRefTransforms
|
||||||
|
BufferizableOpInterfaceImpl.cpp
|
||||||
ComposeSubView.cpp
|
ComposeSubView.cpp
|
||||||
ExpandOps.cpp
|
ExpandOps.cpp
|
||||||
ExpandStridedMetadata.cpp
|
ExpandStridedMetadata.cpp
|
||||||
|
@ -20,6 +21,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
|
||||||
MLIRAffineUtils
|
MLIRAffineUtils
|
||||||
MLIRArithDialect
|
MLIRArithDialect
|
||||||
MLIRArithTransforms
|
MLIRArithTransforms
|
||||||
|
MLIRBufferizationDialect
|
||||||
MLIRFuncDialect
|
MLIRFuncDialect
|
||||||
MLIRInferTypeOpInterface
|
MLIRInferTypeOpInterface
|
||||||
MLIRLoopLikeInterface
|
MLIRLoopLikeInterface
|
||||||
|
|
|
@ -38,6 +38,34 @@ transform.sequence failures(propagate) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @tensor_pad_constant(
|
||||||
|
// CHECK-SAME: %[[t:.*]]: tensor<?x10xindex>
|
||||||
|
// CHECK: %[[src:.*]] = bufferization.to_memref %[[t]]
|
||||||
|
// CHECK: %[[alloc:.*]] = memref.alloc
|
||||||
|
// CHECK: %[[subview:.*]] = memref.subview %[[alloc]]
|
||||||
|
// CHECK: memref.copy %[[src]], %[[subview]]
|
||||||
|
// CHECK: bufferization.to_tensor %[[alloc]] restrict writable
|
||||||
|
func.func @tensor_pad_constant(%t: tensor<?x10xindex>, %l2: index, %h1: index,
|
||||||
|
%h2: index) -> tensor<?x?xindex> {
|
||||||
|
%0 = tensor.pad %t low[5, %l2] high[%h1, %h2] {
|
||||||
|
^bb0(%arg0: index, %arg1: index):
|
||||||
|
%c = arith.constant 50 : index
|
||||||
|
tensor.yield %c : index
|
||||||
|
} : tensor<?x10xindex> to tensor<?x?xindex>
|
||||||
|
return %0 : tensor<?x?xindex>
|
||||||
|
}
|
||||||
|
|
||||||
|
transform.sequence failures(propagate) {
|
||||||
|
^bb1(%arg1: !pdl.operation):
|
||||||
|
%0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!pdl.operation) -> !pdl.operation
|
||||||
|
%1 = transform.get_result %0[0] : (!pdl.operation) -> !transform.any_value
|
||||||
|
%2 = transform.structured.bufferize_to_allocation %1
|
||||||
|
// Make sure that One-Shot Bufferize can bufferize the rest.
|
||||||
|
transform.bufferization.one_shot_bufferize %arg1
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @materialization_of_bbarg(
|
// CHECK-LABEL: func @materialization_of_bbarg(
|
||||||
// CHECK-SAME: %[[t:.*]]: tensor<?x10xindex>
|
// CHECK-SAME: %[[t:.*]]: tensor<?x10xindex>
|
||||||
// CHECK: %[[c0:.*]] = arith.constant 0 : index
|
// CHECK: %[[c0:.*]] = arith.constant 0 : index
|
||||||
|
@ -59,3 +87,26 @@ transform.sequence failures(propagate) {
|
||||||
%1 = test_produce_value_handle_to_argument_of_parent_block %0, 0 : (!pdl.operation) -> !transform.any_value
|
%1 = test_produce_value_handle_to_argument_of_parent_block %0, 0 : (!pdl.operation) -> !transform.any_value
|
||||||
%2 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
|
%2 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @materialization_of_bbarg(
|
||||||
|
// CHECK-SAME: %[[t:.*]]: tensor<?x10xindex>
|
||||||
|
// CHECK: %[[m:.*]] = bufferization.to_memref %[[t]]
|
||||||
|
// CHECK: %[[alloc:.*]] = memref.alloc(%{{.*}}) : memref<?x10xindex, 4>
|
||||||
|
// CHECK: memref.copy %[[m]], %[[alloc]]
|
||||||
|
// CHECK: %[[r:.*]] = memref.load %[[alloc]]
|
||||||
|
// CHECK: return %[[r]]
|
||||||
|
func.func @materialization_of_bbarg(%t: tensor<?x10xindex>, %idx: index) -> index {
|
||||||
|
%r = tensor.extract %t[%idx, %idx] : tensor<?x10xindex>
|
||||||
|
return %r : index
|
||||||
|
}
|
||||||
|
|
||||||
|
transform.sequence failures(propagate) {
|
||||||
|
^bb1(%arg1: !pdl.operation):
|
||||||
|
%0 = transform.structured.match ops{["tensor.extract"]} in %arg1 : (!pdl.operation) -> !pdl.operation
|
||||||
|
%1 = test_produce_value_handle_to_argument_of_parent_block %0, 0 : (!pdl.operation) -> !transform.any_value
|
||||||
|
%2 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
|
||||||
|
// Make sure that One-Shot Bufferize can bufferize the rest.
|
||||||
|
transform.bufferization.one_shot_bufferize %arg1
|
||||||
|
}
|
||||||
|
|
11
mlir/test/Dialect/MemRef/bufferize.mlir
Normal file
11
mlir/test/Dialect/MemRef/bufferize.mlir
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
// RUN: mlir-opt -one-shot-bufferize %s | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @tensor_store(
|
||||||
|
// CHECK-SAME: %[[t:.*]]: tensor<?xf32>, %[[m:.*]]: memref<?xf32>
|
||||||
|
// CHECK: %[[src:.*]] = bufferization.to_memref %[[t]]
|
||||||
|
// CHECK: memref.copy %[[src]], %[[m]]
|
||||||
|
// CHECK: return
|
||||||
|
func.func @tensor_store(%t: tensor<?xf32>, %m: memref<?xf32>) {
|
||||||
|
memref.tensor_store %t, %m : memref<?xf32>
|
||||||
|
return
|
||||||
|
}
|
|
@ -9904,6 +9904,7 @@ cc_library(
|
||||||
":ArithDialect",
|
":ArithDialect",
|
||||||
":ArithTransforms",
|
":ArithTransforms",
|
||||||
":ArithUtils",
|
":ArithUtils",
|
||||||
|
":BufferizationDialect",
|
||||||
":ControlFlowDialect",
|
":ControlFlowDialect",
|
||||||
":DialectUtils",
|
":DialectUtils",
|
||||||
":FuncDialect",
|
":FuncDialect",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user