[mlir][memref] Add runtime verification for memref::CastOp
Verify unranked -> ranked casts and casts of dynamic sizes/offset/strides to static ones.

Differential Revision: https://reviews.llvm.org/D138671
Parent: bd87b84a02
Commit: 5eee80ce5e
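For illustration, the check that the -generate-runtime-verification pass emits for a cast of a dynamic size to a static one is roughly equivalent to the following sketch (a hedged example, not taken verbatim from this commit; value names are illustrative and the assert message is abbreviated):

// Sketch of the verification IR inserted before a dynamic -> static size cast.
func.func @sketch(%m: memref<?xf32>) -> memref<10xf32> {
  %c0 = arith.constant 0 : index
  %c10 = arith.constant 10 : index
  // Compare the runtime size of dim 0 against the static result size.
  %dim = memref.dim %m, %c0 : memref<?xf32>
  %ok = arith.cmpi eq, %dim, %c10 : index
  cf.assert %ok, "ERROR: Runtime op verification failed ... size mismatch of dim 0 ..."
  %0 = memref.cast %m : memref<?xf32> to memref<10xf32>
  return %0 : memref<10xf32>
}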
@@ -14,9 +14,125 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"

using namespace mlir;

/// Generate an error message string for the given op and the specified error.
static std::string generateErrorMessage(Operation *op, const std::string &msg) {
  std::string buffer;
  llvm::raw_string_ostream stream(buffer);
  OpPrintingFlags flags;
  stream << "ERROR: Runtime op verification failed\n";
  op->print(stream, flags);
  stream << "\n^ " << msg;
  stream << "\nLocation: ";
  op->getLoc().print(stream);
  return stream.str();
}

namespace mlir {
namespace memref {
namespace {
struct CastOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
                                                         CastOp> {
  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
                                   Location loc) const {
    auto castOp = cast<CastOp>(op);
    auto srcType = castOp.getSource().getType().cast<BaseMemRefType>();

    // Nothing to check if the result is an unranked memref.
    auto resultType = castOp.getType().dyn_cast<MemRefType>();
    if (!resultType)
      return;

    if (srcType.isa<UnrankedMemRefType>()) {
      // Check rank.
      Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
      Value resultRank =
          builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
      Value isSameRank = builder.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, srcRank, resultRank);
      builder.create<cf::AssertOp>(loc, isSameRank,
                                   generateErrorMessage(op, "rank mismatch"));
    }

    // Get source offset and strides. We do not have an op to get offsets and
    // strides from unranked memrefs, so cast the source to a type with fully
    // dynamic layout, from which we can then extract the offset and strides.
    // (Rank was already verified.)
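    // For example (illustrative): for a result type of memref<10x20xf32>, the
    // source is first cast to memref<?x?xf32, strided<[?, ?], offset: ?>>
    // before its metadata is extracted.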
    int64_t dynamicOffset = ShapedType::kDynamic;
    SmallVector<int64_t> dynamicShape(resultType.getRank(),
                                      ShapedType::kDynamic);
    auto stridedLayout = StridedLayoutAttr::get(builder.getContext(),
                                                dynamicOffset, dynamicShape);
    auto dynStridesType =
        MemRefType::get(dynamicShape, resultType.getElementType(),
                        stridedLayout, resultType.getMemorySpace());
    Value helperCast =
        builder.create<CastOp>(loc, dynStridesType, castOp.getSource());
    auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, helperCast);
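    // extract_strided_metadata returns (base buffer, offset, sizes[0..rank),
    // strides[0..rank)): result #1 is the offset and result #(2 + rank + i)
    // is the stride of dimension i.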

    // Check dimension sizes.
    for (const auto &it : llvm::enumerate(resultType.getShape())) {
      // Static dim size -> static/dynamic dim size does not need verification.
      if (auto rankedSrcType = srcType.dyn_cast<MemRefType>())
        if (!rankedSrcType.isDynamicDim(it.index()))
          continue;

      // Static/dynamic dim size -> dynamic dim size does not need verification.
      if (resultType.isDynamicDim(it.index()))
        continue;

      Value srcDimSz =
          builder.create<DimOp>(loc, castOp.getSource(), it.index());
      Value resultDimSz =
          builder.create<arith::ConstantIndexOp>(loc, it.value());
      Value isSameSz = builder.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
      builder.create<cf::AssertOp>(
          loc, isSameSz,
          generateErrorMessage(op, "size mismatch of dim " +
                                       std::to_string(it.index())));
    }

    // Get result offset and strides.
    int64_t resultOffset;
    SmallVector<int64_t> resultStrides;
    if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
      return;

    // Check offset.
    if (resultOffset != ShapedType::kDynamic) {
      // Static/dynamic offset -> dynamic offset does not need verification.
      Value srcOffset = metadataOp.getResult(1);
      Value resultOffsetVal =
          builder.create<arith::ConstantIndexOp>(loc, resultOffset);
      Value isSameOffset = builder.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
      builder.create<cf::AssertOp>(loc, isSameOffset,
                                   generateErrorMessage(op, "offset mismatch"));
    }

    // Check strides.
    for (const auto &it : llvm::enumerate(resultStrides)) {
      // Static/dynamic stride -> dynamic stride does not need verification.
      if (it.value() == ShapedType::kDynamic)
        continue;

      Value srcStride =
          metadataOp.getResult(2 + resultType.getRank() + it.index());
      Value resultStrideVal =
          builder.create<arith::ConstantIndexOp>(loc, it.value());
      Value isSameStride = builder.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
      builder.create<cf::AssertOp>(
          loc, isSameStride,
          generateErrorMessage(op, "stride mismatch of dim " +
                                       std::to_string(it.index())));
    }
  }
};

struct ExpandShapeOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                         ExpandShapeOp> {
@@ -53,7 +169,8 @@ struct ExpandShapeOpInterface
           builder.create<arith::ConstantIndexOp>(loc, 0));
       builder.create<cf::AssertOp>(
           loc, isModZero,
-          "static result dims in reassoc group do not divide src dim evenly");
+          generateErrorMessage(op, "static result dims in reassoc group do not "
+                                   "divide src dim evenly"));
     }
   }
 };
@@ -64,6 +181,7 @@ struct ExpandShapeOpInterface
 void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+    CastOp::attachInterface<CastOpInterface>(*ctx);
     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);

     // Load additional dialects of which ops may get created.
@@ -7,7 +7,7 @@
 // CHECK-DAG: %[[dim:.*]] = memref.dim %[[m]], %[[c0]]
 // CHECK: %[[mod:.*]] = arith.remsi %[[dim]], %[[c5]]
 // CHECK: %[[cmpi:.*]] = arith.cmpi eq, %[[mod]], %[[c0]]
-// CHECK: cf.assert %[[cmpi]], "static result dims in reassoc group do not divide src dim evenly"
+// CHECK: cf.assert %[[cmpi]], "ERROR: Runtime op verification failed
 func.func @expand_shape(%m: memref<?xf32>) -> memref<?x5xf32> {
   %0 = memref.expand_shape %m [[0, 1]] : memref<?xf32> into memref<?x5xf32>
   return %0 : memref<?x5xf32>
@@ -0,0 +1,69 @@
// RUN: mlir-opt %s -generate-runtime-verification -convert-memref-to-llvm \
// RUN:     -test-cf-assert \
// RUN:     -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN:     -shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext 2>&1 | \
// RUN: FileCheck %s

func.func @cast_to_static_dim(%m: memref<?xf32>) -> memref<10xf32> {
  %0 = memref.cast %m : memref<?xf32> to memref<10xf32>
  return %0 : memref<10xf32>
}

func.func @cast_to_ranked(%m: memref<*xf32>) -> memref<f32> {
  %0 = memref.cast %m : memref<*xf32> to memref<f32>
  return %0 : memref<f32>
}

func.func @cast_to_static_strides(%m: memref<?xf32, strided<[?], offset: ?>>)
    -> memref<?xf32, strided<[9], offset: 5>> {
  %0 = memref.cast %m : memref<?xf32, strided<[?], offset: ?>>
      to memref<?xf32, strided<[9], offset: 5>>
  return %0 : memref<?xf32, strided<[9], offset: 5>>
}

func.func @valid_cast(%m: memref<*xf32>) -> memref<?xf32> {
  %0 = memref.cast %m : memref<*xf32> to memref<?xf32>
  return %0 : memref<?xf32>
}

func.func @main() {
  // All casts inside the called functions are invalid at runtime, except for
  // the last one.
  %alloc = memref.alloc() : memref<5xf32>

  // CHECK: ERROR: Runtime op verification failed
  // CHECK-NEXT: memref.cast %{{.*}} : memref<?xf32> to memref<10xf32>
  // CHECK-NEXT: ^ size mismatch of dim 0
  // CHECK-NEXT: Location: loc({{.*}})
  %1 = memref.cast %alloc : memref<5xf32> to memref<?xf32>
  func.call @cast_to_static_dim(%1) : (memref<?xf32>) -> (memref<10xf32>)

  // CHECK-NEXT: ERROR: Runtime op verification failed
  // CHECK-NEXT: memref.cast %{{.*}} : memref<*xf32> to memref<f32>
  // CHECK-NEXT: ^ rank mismatch
  // CHECK-NEXT: Location: loc({{.*}})
  %3 = memref.cast %alloc : memref<5xf32> to memref<*xf32>
  func.call @cast_to_ranked(%3) : (memref<*xf32>) -> (memref<f32>)

  // CHECK-NEXT: ERROR: Runtime op verification failed
  // CHECK-NEXT: memref.cast %{{.*}} : memref<?xf32, strided<[?], offset: ?>> to memref<?xf32, strided<[9], offset: 5>>
  // CHECK-NEXT: ^ offset mismatch
  // CHECK-NEXT: Location: loc({{.*}})

  // CHECK-NEXT: ERROR: Runtime op verification failed
  // CHECK-NEXT: memref.cast %{{.*}} : memref<?xf32, strided<[?], offset: ?>> to memref<?xf32, strided<[9], offset: 5>>
  // CHECK-NEXT: ^ stride mismatch of dim 0
  // CHECK-NEXT: Location: loc({{.*}})
  %4 = memref.cast %alloc
      : memref<5xf32> to memref<?xf32, strided<[?], offset: ?>>
  func.call @cast_to_static_strides(%4)
      : (memref<?xf32, strided<[?], offset: ?>>)
      -> (memref<?xf32, strided<[9], offset: 5>>)

  // A last cast that actually succeeds.
  // CHECK-NOT: ERROR: Runtime op verification failed
  func.call @valid_cast(%3) : (memref<*xf32>) -> (memref<?xf32>)

  return
}