[mlir] Argument and result attribute handling during inlining.

The revision adds the handleArgument and handleResult handlers that
allow users of the inlining interface to implement argument and result
conversions that take argument and result attributes into account. The
motivating use cases for this revision are taken from the LLVM dialect
inliner, which has to copy arguments that are marked as byval and that
also has to consider zeroext / signext when converting integers.

All type conversions are currently handled by the
materializeCallConversion hook. It runs before isLegalToInline and
supports only the introduction of a single cast operation since it may
have to rollback. The new handlers run shortly before and after
inlining and cannot fail. As a result, they can introduce more complex
ir such as copying a struct argument. At the moment, the new hooks
cannot be used to perform type conversions since all type conversions
have to be done using the materializeCallConversion. A follow up
revision will either relax this constraint or drop
materializeCallConversion in favor of the new and more flexible
handlers.

The revision also extends the CallableOpInterface to provide access
to the argument and result attributes if available.

Reviewed By: rriddle, Dinistro

Differential Revision: https://reviews.llvm.org/D145582
This commit is contained in:
Tobias Gysi 2023-03-22 08:38:55 +01:00
parent 9c16eef1ec
commit f809eb4db2
19 changed files with 459 additions and 1 deletions

View File

@ -731,6 +731,8 @@ interface section goes as follows:
* `CallableOpInterface` - Used to represent the target callee of call.
- `Region * getCallableRegion()`
- `ArrayRef<Type> getCallableResults()`
- `ArrayAttr getCallableArgAttrs()`
- `ArrayAttr getCallableResAttrs()`
##### RegionKindInterfaces

View File

@ -169,6 +169,18 @@ Region *FuncOp::getCallableRegion() { return &getBody(); }
/// executed.
ArrayRef<Type> FuncOp::getCallableResults() { return getType().getResults(); }
/// Returns the argument attributes for all callable region arguments or
/// null if there are none.
ArrayAttr FuncOp::getCallableArgAttrs() {
return getArgAttrs().value_or(nullptr);
}
/// Returns the result attributes for all callable region results or
/// null if there are none.
ArrayAttr FuncOp::getCallableResAttrs() {
return getResAttrs().value_or(nullptr);
}
// ....
/// Return the callee of the generic call operation, this is required by the

View File

@ -307,6 +307,18 @@ llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
return getFunctionType().getResults();
}
/// Returns the argument attributes for all callable region arguments or
/// null if there are none.
ArrayAttr FuncOp::getCallableArgAttrs() {
return getArgAttrs().value_or(nullptr);
}
/// Returns the result attributes for all callable region results or
/// null if there are none.
ArrayAttr FuncOp::getCallableResAttrs() {
return getResAttrs().value_or(nullptr);
}
//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//

View File

@ -307,6 +307,18 @@ llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
return getFunctionType().getResults();
}
/// Returns the argument attributes for all callable region arguments or
/// null if there are none.
ArrayAttr FuncOp::getCallableArgAttrs() {
return getArgAttrs().value_or(nullptr);
}
/// Returns the result attributes for all callable region results or
/// null if there are none.
ArrayAttr FuncOp::getCallableResAttrs() {
return getResAttrs().value_or(nullptr);
}
//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//

View File

@ -307,6 +307,18 @@ llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
return getFunctionType().getResults();
}
/// Returns the argument attributes for all callable region arguments or
/// null if there are none.
ArrayAttr FuncOp::getCallableArgAttrs() {
return getArgAttrs().value_or(nullptr);
}
/// Returns the result attributes for all callable region results or
/// null if there are none.
ArrayAttr FuncOp::getCallableResAttrs() {
return getResAttrs().value_or(nullptr);
}
//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//

View File

@ -336,6 +336,18 @@ llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
return getFunctionType().getResults();
}
/// Returns the argument attributes for all callable region arguments or
/// null if there are none.
ArrayAttr FuncOp::getCallableArgAttrs() {
return getArgAttrs().value_or(nullptr);
}
/// Returns the result attributes for all callable region results or
/// null if there are none.
ArrayAttr FuncOp::getCallableResAttrs() {
return getResAttrs().value_or(nullptr);
}
//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//

View File

@ -168,6 +168,18 @@ def Async_FuncOp : Async_Op<"func",
ArrayRef<Type> getCallableResults() { return getFunctionType()
.getResults(); }
/// Returns the argument attributes for all callable region arguments or
/// null if there are none.
::mlir::ArrayAttr getCallableArgAttrs() {
return getArgAttrs().value_or(nullptr);
}
/// Returns the result attributes for all callable region results or
/// null if there are none.
::mlir::ArrayAttr getCallableResAttrs() {
return getResAttrs().value_or(nullptr);
}
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//

View File

@ -299,6 +299,18 @@ def FuncOp : Func_Op<"func", [
/// executed.
ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
/// Returns the argument attributes for all callable region arguments or
/// null if there are none.
::mlir::ArrayAttr getCallableArgAttrs() {
return getArgAttrs().value_or(nullptr);
}
/// Returns the result attributes for all callable region results or
/// null if there are none.
::mlir::ArrayAttr getCallableResAttrs() {
return getResAttrs().value_or(nullptr);
}
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//

View File

@ -1583,6 +1583,10 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return getFunctionType().getReturnTypes(); }
//===------------------------------------------------------------------===//
// CallableOpInterface
//===------------------------------------------------------------------===//
/// Returns the callable region, which is the function body. If the function
/// is external, returns null.
Region *getCallableRegion();
@ -1596,6 +1600,17 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
return getFunctionType().getReturnTypes();
}
/// Returns the argument attributes for all callable region arguments or
/// null if there are none.
::mlir::ArrayAttr getCallableArgAttrs() {
return getArgAttrs().value_or(nullptr);
}
/// Returns the result attributes for all callable region results or
/// null if there are none.
::mlir::ArrayAttr getCallableResAttrs() {
return getResAttrs().value_or(nullptr);
}
}];
let hasCustomAssemblyFormat = 1;

View File

@ -73,6 +73,18 @@ def MLProgram_FuncOp : MLProgram_Op<"func", [
/// executed.
ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
/// Returns the argument attributes for all callable region arguments or
/// null if there are none.
::mlir::ArrayAttr getCallableArgAttrs() {
return getArgAttrs().value_or(nullptr);
}
/// Returns the result attributes for all callable region results or
/// null if there are none.
::mlir::ArrayAttr getCallableResAttrs() {
return getResAttrs().value_or(nullptr);
}
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//
@ -422,6 +434,18 @@ def MLProgram_SubgraphOp : MLProgram_Op<"subgraph", [
/// executed.
ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
/// Returns the argument attributes for all callable region arguments or
/// null if there are none.
::mlir::ArrayAttr getCallableArgAttrs() {
return getArgAttrs().value_or(nullptr);
}
/// Returns the result attributes for all callable region results or
/// null if there are none.
::mlir::ArrayAttr getCallableResAttrs() {
return getResAttrs().value_or(nullptr);
}
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//

View File

@ -1149,6 +1149,18 @@ def Shape_FuncOp : Shape_Op<"func",
return getFunctionType().getResults();
}
/// Returns the argument attributes for all callable region arguments or
/// null if there are none.
::mlir::ArrayAttr getCallableArgAttrs() {
return getArgAttrs().value_or(nullptr);
}
/// Returns the result attributes for all callable region results or
/// null if there are none.
::mlir::ArrayAttr getCallableResAttrs() {
return getResAttrs().value_or(nullptr);
}
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//

View File

@ -394,6 +394,12 @@ def NamedSequenceOp : TransformDialectOp<"named_sequence",
::llvm::ArrayRef<::mlir::Type> getCallableResults() {
return getFunctionType().getResults();
}
::mlir::ArrayAttr getCallableArgAttrs() {
return getArgAttrs().value_or(nullptr);
}
::mlir::ArrayAttr getCallableResAttrs() {
return getResAttrs().value_or(nullptr);
}
}];
}

View File

@ -84,6 +84,18 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
}],
"::llvm::ArrayRef<::mlir::Type>", "getCallableResults"
>,
InterfaceMethod<[{
Returns the argument attributes for all callable region arguments or
null if there are none.
}],
"::mlir::ArrayAttr", "getCallableArgAttrs"
>,
InterfaceMethod<[{
Returns the result attributes for all callable region results or null
if there are none.
}],
"::mlir::ArrayAttr", "getCallableResAttrs"
>
];
}

View File

@ -13,6 +13,7 @@
#ifndef MLIR_TRANSFORMS_INLININGUTILS_H
#define MLIR_TRANSFORMS_INLININGUTILS_H
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Region.h"
@ -141,6 +142,40 @@ public:
return nullptr;
}
/// Hook to transform the call arguments before using them to replace the
/// callee arguments. It returns the transformation result or `argument`
/// itself if the hook did not change anything. The type of the returned value
/// has to match `targetType`, and the `argumentAttrs` dictionary is non-null
/// even if no attribute is present. The hook is called after converting the
/// callsite argument types using the materializeCallConversion callback, and
/// right before inlining the callee region. Any operations created using the
/// provided `builder` are inserted right before the inlined callee region.
/// Example use cases are the insertion of copies for by value arguments, or
/// integer conversions that require signedness information.
virtual Value handleArgument(OpBuilder &builder, Operation *call,
Operation *callable, Value argument,
Type targetType,
DictionaryAttr argumentAttrs) const {
return argument;
}
/// Hook to transform the callee results before using them to replace the call
/// results. It returns the transformation result or the `result` itself if
/// the hook did not change anything. The type of the returned values has to
/// match `targetType`, and the `resultAttrs` dictionary is non-null even if
/// no attribute is present. The hook is called right before handling
/// terminators, and obtains the callee result before converting its type
/// using the `materializeCallConversion` callback. Any operations created
/// using the provided `builder` are inserted right after the inlined callee
/// region. Example use cases are the insertion of copies for by value results
/// or integer conversions that require signedness information.
/// NOTE: This hook is invoked after inlining the `callable` region.
virtual Value handleResult(OpBuilder &builder, Operation *call,
Operation *callable, Value result, Type targetType,
DictionaryAttr resultAttrs) const {
return result;
}
/// Process a set of blocks that have been inlined for a call. This callback
/// is invoked before inlined terminator operations have been processed.
virtual void processInlinedCallBlocks(
@ -183,6 +218,15 @@ public:
virtual void handleTerminator(Operation *op, Block *newDest) const;
virtual void handleTerminator(Operation *op,
ArrayRef<Value> valuesToRepl) const;
virtual Value handleArgument(OpBuilder &builder, Operation *call,
Operation *callable, Value argument,
Type targetType,
DictionaryAttr argumentAttrs) const;
virtual Value handleResult(OpBuilder &builder, Operation *call,
Operation *callable, Value result, Type targetType,
DictionaryAttr resultAttrs) const;
virtual void processInlinedCallBlocks(
Operation *call, iterator_range<Region::iterator> inlinedBlocks) const;
};

View File

@ -2469,6 +2469,16 @@ ArrayRef<Type> spirv::FuncOp::getCallableResults() {
return getFunctionType().getResults();
}
// CallableOpInterface
::mlir::ArrayAttr spirv::FuncOp::getCallableArgAttrs() {
return getArgAttrs().value_or(nullptr);
}
// CallableOpInterface
::mlir::ArrayAttr spirv::FuncOp::getCallableResAttrs() {
return getResAttrs().value_or(nullptr);
}
//===----------------------------------------------------------------------===//
// spirv.FunctionCall
//===----------------------------------------------------------------------===//

View File

@ -103,6 +103,26 @@ void InlinerInterface::handleTerminator(Operation *op,
handler->handleTerminator(op, valuesToRepl);
}
Value InlinerInterface::handleArgument(OpBuilder &builder, Operation *call,
Operation *callable, Value argument,
Type targetType,
DictionaryAttr argumentAttrs) const {
auto *handler = getInterfaceFor(callable);
assert(handler && "expected valid dialect handler");
return handler->handleArgument(builder, call, callable, argument, targetType,
argumentAttrs);
}
Value InlinerInterface::handleResult(OpBuilder &builder, Operation *call,
Operation *callable, Value result,
Type targetType,
DictionaryAttr resultAttrs) const {
auto *handler = getInterfaceFor(callable);
assert(handler && "expected valid dialect handler");
return handler->handleResult(builder, call, callable, result, targetType,
resultAttrs);
}
void InlinerInterface::processInlinedCallBlocks(
Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
auto *handler = getInterfaceFor(call);
@ -141,6 +161,71 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src,
// Inline Methods
//===----------------------------------------------------------------------===//
static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder,
CallOpInterface call,
CallableOpInterface callable,
IRMapping &mapper) {
// Unpack the argument attributes if there are any.
SmallVector<DictionaryAttr> argAttrs(
callable.getCallableRegion()->getNumArguments(),
builder.getDictionaryAttr({}));
if (ArrayAttr arrayAttr = callable.getCallableArgAttrs()) {
assert(arrayAttr.size() == argAttrs.size());
for (auto [idx, attr] : llvm::enumerate(arrayAttr))
argAttrs[idx] = cast<DictionaryAttr>(attr);
}
// Run the argument attribute handler for the given argument and attribute.
for (auto [blockArg, argAttr] :
llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) {
Value newArgument = interface.handleArgument(builder, call, callable,
mapper.lookup(blockArg),
blockArg.getType(), argAttr);
assert(newArgument.getType() == blockArg.getType() &&
"expected the handled argument type to match the target type");
// Update the mapping to point the new argument returned by the handler.
mapper.map(blockArg, newArgument);
}
}
static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
CallOpInterface call, CallableOpInterface callable,
ValueRange results) {
// Unpack the result attributes if there are any.
SmallVector<DictionaryAttr> resAttrs(results.size(),
builder.getDictionaryAttr({}));
if (ArrayAttr arrayAttr = callable.getCallableResAttrs()) {
assert(arrayAttr.size() == resAttrs.size());
for (auto [idx, attr] : llvm::enumerate(arrayAttr))
resAttrs[idx] = cast<DictionaryAttr>(attr);
}
// Run the result attribute handler for the given result and attribute.
SmallVector<DictionaryAttr> resultAttributes;
for (auto [result, resAttr] : llvm::zip(results, resAttrs)) {
// Store the original result users before running the handler.
DenseSet<Operation *> resultUsers;
for (Operation *user : result.getUsers())
resultUsers.insert(user);
// TODO: Use the type of the call result to replace once the hook can be
// used for type conversions. At the moment, all type conversions have to be
// done using materializeCallConversion.
Type targetType = result.getType();
Value newResult = interface.handleResult(builder, call, callable, result,
targetType, resAttr);
assert(newResult.getType() == targetType &&
"expected the handled result type to match the target type");
// Replace the result uses except for the ones introduce by the handler.
result.replaceUsesWithIf(newResult, [&](OpOperand &operand) {
return resultUsers.count(operand.getOwner());
});
}
}
static LogicalResult
inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
Block::iterator inlinePoint, IRMapping &mapper,
@ -166,6 +251,12 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
mapper))
return failure();
// Run the argument attribute handler before inlining the callable region.
OpBuilder builder(inlineBlock, inlinePoint);
auto callable = dyn_cast<CallableOpInterface>(src->getParentOp());
if (call && callable)
handleArgumentImpl(interface, builder, call, callable, mapper);
// Check to see if the region is being cloned, or moved inline. In either
// case, move the new blocks after the 'insertBlock' to improve IR
// readability.
@ -199,8 +290,14 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
// Handle the case where only a single block was inlined.
if (std::next(newBlocks.begin()) == newBlocks.end()) {
// Run the result attribute handler on the terminator operands.
Operation *firstBlockTerminator = firstNewBlock->getTerminator();
builder.setInsertionPoint(firstBlockTerminator);
if (call && callable)
handleResultImpl(interface, builder, call, callable,
firstBlockTerminator->getOperands());
// Have the interface handle the terminator of this block.
auto *firstBlockTerminator = firstNewBlock->getTerminator();
interface.handleTerminator(firstBlockTerminator,
llvm::to_vector<6>(resultsToReplace));
firstBlockTerminator->erase();
@ -218,6 +315,12 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
resultToRepl.value().getLoc()));
}
// Run the result attribute handler on the post insertion block arguments.
builder.setInsertionPointToStart(postInsertBlock);
if (call && callable)
handleResultImpl(interface, builder, call, callable,
postInsertBlock->getArguments());
/// Handle the terminators for each of the new blocks.
for (auto &newBlock : newBlocks)
interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);

View File

@ -226,3 +226,40 @@ func.func @func_with_block_args_location_callee2(%arg0 : i32) {
call @func_with_block_args_location(%arg0) : (i32) -> ()
return
}
// Check that we can handle argument and result attributes.
test.conversion_func_op @handle_attr_callee_fn_multi_arg(%arg0 : i16, %arg1 : i16 {"test.handle_argument"}) -> (i16 {"test.handle_result"}, i16) {
%0 = arith.addi %arg0, %arg1 : i16
%1 = arith.subi %arg0, %arg1 : i16
"test.return"(%0, %1) : (i16, i16) -> ()
}
test.conversion_func_op @handle_attr_callee_fn(%arg0 : i32 {"test.handle_argument"}) -> (i32 {"test.handle_result"}) {
"test.return"(%arg0) : (i32) -> ()
}
// CHECK-LABEL: func @inline_handle_attr_call
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
func.func @inline_handle_attr_call(%arg0 : i16, %arg1 : i16) -> (i16, i16) {
// CHECK: %[[CHANGE_INPUT:.*]] = "test.type_changer"(%[[ARG1]]) : (i16) -> i16
// CHECK: %[[SUM:.*]] = arith.addi %[[ARG0]], %[[CHANGE_INPUT]]
// CHECK: %[[DIFF:.*]] = arith.subi %[[ARG0]], %[[CHANGE_INPUT]]
// CHECK: %[[CHANGE_RESULT:.*]] = "test.type_changer"(%[[SUM]]) : (i16) -> i16
// CHECK-NEXT: return %[[CHANGE_RESULT]], %[[DIFF]]
%res0, %res1 = "test.conversion_call_op"(%arg0, %arg1) { callee=@handle_attr_callee_fn_multi_arg } : (i16, i16) -> (i16, i16)
return %res0, %res1 : i16, i16
}
// CHECK-LABEL: func @inline_convert_and_handle_attr_call
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
func.func @inline_convert_and_handle_attr_call(%arg0 : i16) -> (i16) {
// CHECK: %[[CAST_INPUT:.*]] = "test.cast"(%[[ARG0]]) : (i16) -> i32
// CHECK: %[[CHANGE_INPUT:.*]] = "test.type_changer"(%[[CAST_INPUT]]) : (i32) -> i32
// CHECK: %[[CHANGE_RESULT:.*]] = "test.type_changer"(%[[CHANGE_INPUT]]) : (i32) -> i32
// CHECK: %[[CAST_RESULT:.*]] = "test.cast"(%[[CHANGE_RESULT]]) : (i32) -> i16
// CHECK: return %[[CAST_RESULT]]
%res = "test.conversion_call_op"(%arg0) { callee=@handle_attr_callee_fn } : (i16) -> (i16)
return %res : i16
}

View File

@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
@ -354,6 +355,24 @@ struct TestInlinerInterface : public DialectInlinerInterface {
return builder.create<TestCastOp>(conversionLoc, resultType, input);
}
Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
Value argument, Type targetType,
DictionaryAttr argumentAttrs) const final {
if (!argumentAttrs.contains("test.handle_argument"))
return argument;
return builder.create<TestTypeChangerOp>(call->getLoc(), targetType,
argument);
}
Value handleResult(OpBuilder &builder, Operation *call, Operation *callable,
Value result, Type targetType,
DictionaryAttr resultAttrs) const final {
if (!resultAttrs.contains("test.handle_result"))
return result;
return builder.create<TestTypeChangerOp>(call->getLoc(), targetType,
result);
}
void processInlinedCallBlocks(
Operation *call,
iterator_range<Region::iterator> inlinedBlocks) const final {
@ -650,6 +669,29 @@ LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}
//===----------------------------------------------------------------------===//
// ConversionFuncOp
//===----------------------------------------------------------------------===//
ParseResult ConversionFuncOp::parse(OpAsmParser &parser,
OperationState &result) {
auto buildFuncType =
[](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
function_interface_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false,
getFunctionTypeAttrName(result.name), buildFuncType,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
void ConversionFuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
}
//===----------------------------------------------------------------------===//
// TestFoldToCallOp
//===----------------------------------------------------------------------===//

View File

@ -14,6 +14,7 @@ include "TestInterfaces.td"
include "mlir/Dialect/DLTI/DLTIBase.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/PatternBase.td"
@ -482,6 +483,66 @@ def ConversionCallOp : TEST_Op<"conversion_call_op",
}];
}
def ConversionFuncOp : TEST_Op<"conversion_func_op", [CallableOpInterface,
FunctionOpInterface]> {
let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs,
OptionalAttr<StrAttr>:$sym_visibility);
let regions = (region AnyRegion:$body);
let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
// CallableOpInterface
//===------------------------------------------------------------------===//
/// Returns the region on the current operation that is callable. This may
/// return null in the case of an external callable object, e.g. an external
/// function.
::mlir::Region *getCallableRegion() {
return isExternal() ? nullptr : &getBody();
}
/// Returns the results types that the callable region produces when
/// executed.
::mlir::ArrayRef<::mlir::Type> getCallableResults() {
return getFunctionType().getResults();
}
/// Returns the argument attributes for all callable region arguments or
/// null if there are none.
::mlir::ArrayAttr getCallableArgAttrs() {
return getArgAttrs().value_or(nullptr);
}
/// Returns the result attributes for all callable region results or
/// null if there are none.
::mlir::ArrayAttr getCallableResAttrs() {
return getResAttrs().value_or(nullptr);
}
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//
/// Returns the argument types of this async function.
::mlir::ArrayRef<::mlir::Type> getArgumentTypes() {
return getFunctionType().getInputs();
}
/// Returns the result types of this async function.
::mlir::ArrayRef<::mlir::Type> getResultTypes() {
return getFunctionType().getResults();
}
/// Returns the number of results of this async function
unsigned getNumResults() {return getResultTypes().size();}
}];
let hasCustomAssemblyFormat = 1;
}
def FunctionalRegionOp : TEST_Op<"functional_region_op",
[CallableOpInterface]> {
let regions = (region AnyRegion:$body);
@ -492,6 +553,12 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op",
::llvm::ArrayRef<::mlir::Type> getCallableResults() {
return getType().cast<::mlir::FunctionType>().getResults();
}
::mlir::ArrayAttr getCallableArgAttrs() {
return nullptr;
}
::mlir::ArrayAttr getCallableResAttrs() {
return nullptr;
}
}];
}