[mlir] Argument and result attribute handling during inlining.
The revision adds the handleArgument and handleResult handlers that allow users of the inlining interface to implement argument and result conversions that take argument and result attributes into account. The motivating use cases for this revision are taken from the LLVM dialect inliner, which has to copy arguments that are marked as byval and that also has to consider zeroext / signext when converting integers. All type conversions are currently handled by the materializeCallConversion hook. It runs before isLegalToInline and supports only the introduction of a single cast operation since it may have to rollback. The new handlers run shortly before and after inlining and cannot fail. As a result, they can introduce more complex ir such as copying a struct argument. At the moment, the new hooks cannot be used to perform type conversions since all type conversions have to be done using the materializeCallConversion. A follow up revision will either relax this constraint or drop materializeCallConversion in favor of the new and more flexible handlers. The revision also extends the CallableOpInterface to provide access to the argument and result attributes if available. Reviewed By: rriddle, Dinistro Differential Revision: https://reviews.llvm.org/D145582
This commit is contained in:
parent
9c16eef1ec
commit
f809eb4db2
|
@ -731,6 +731,8 @@ interface section goes as follows:
|
|||
* `CallableOpInterface` - Used to represent the target callee of call.
|
||||
- `Region * getCallableRegion()`
|
||||
- `ArrayRef<Type> getCallableResults()`
|
||||
- `ArrayAttr getCallableArgAttrs()`
|
||||
- `ArrayAttr getCallableResAttrs()`
|
||||
|
||||
##### RegionKindInterfaces
|
||||
|
||||
|
|
|
@ -169,6 +169,18 @@ Region *FuncOp::getCallableRegion() { return &getBody(); }
|
|||
/// executed.
|
||||
ArrayRef<Type> FuncOp::getCallableResults() { return getType().getResults(); }
|
||||
|
||||
/// Returns the argument attributes for all callable region arguments or
|
||||
/// null if there are none.
|
||||
ArrayAttr FuncOp::getCallableArgAttrs() {
|
||||
return getArgAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
/// Returns the result attributes for all callable region results or
|
||||
/// null if there are none.
|
||||
ArrayAttr FuncOp::getCallableResAttrs() {
|
||||
return getResAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
// ....
|
||||
|
||||
/// Return the callee of the generic call operation, this is required by the
|
||||
|
|
|
@ -307,6 +307,18 @@ llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
|
|||
return getFunctionType().getResults();
|
||||
}
|
||||
|
||||
/// Returns the argument attributes for all callable region arguments or
|
||||
/// null if there are none.
|
||||
ArrayAttr FuncOp::getCallableArgAttrs() {
|
||||
return getArgAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
/// Returns the result attributes for all callable region results or
|
||||
/// null if there are none.
|
||||
ArrayAttr FuncOp::getCallableResAttrs() {
|
||||
return getResAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GenericCallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -307,6 +307,18 @@ llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
|
|||
return getFunctionType().getResults();
|
||||
}
|
||||
|
||||
/// Returns the argument attributes for all callable region arguments or
|
||||
/// null if there are none.
|
||||
ArrayAttr FuncOp::getCallableArgAttrs() {
|
||||
return getArgAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
/// Returns the result attributes for all callable region results or
|
||||
/// null if there are none.
|
||||
ArrayAttr FuncOp::getCallableResAttrs() {
|
||||
return getResAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GenericCallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -307,6 +307,18 @@ llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
|
|||
return getFunctionType().getResults();
|
||||
}
|
||||
|
||||
/// Returns the argument attributes for all callable region arguments or
|
||||
/// null if there are none.
|
||||
ArrayAttr FuncOp::getCallableArgAttrs() {
|
||||
return getArgAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
/// Returns the result attributes for all callable region results or
|
||||
/// null if there are none.
|
||||
ArrayAttr FuncOp::getCallableResAttrs() {
|
||||
return getResAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GenericCallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -336,6 +336,18 @@ llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
|
|||
return getFunctionType().getResults();
|
||||
}
|
||||
|
||||
/// Returns the argument attributes for all callable region arguments or
|
||||
/// null if there are none.
|
||||
ArrayAttr FuncOp::getCallableArgAttrs() {
|
||||
return getArgAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
/// Returns the result attributes for all callable region results or
|
||||
/// null if there are none.
|
||||
ArrayAttr FuncOp::getCallableResAttrs() {
|
||||
return getResAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GenericCallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -168,6 +168,18 @@ def Async_FuncOp : Async_Op<"func",
|
|||
ArrayRef<Type> getCallableResults() { return getFunctionType()
|
||||
.getResults(); }
|
||||
|
||||
/// Returns the argument attributes for all callable region arguments or
|
||||
/// null if there are none.
|
||||
::mlir::ArrayAttr getCallableArgAttrs() {
|
||||
return getArgAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
/// Returns the result attributes for all callable region results or
|
||||
/// null if there are none.
|
||||
::mlir::ArrayAttr getCallableResAttrs() {
|
||||
return getResAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
//===------------------------------------------------------------------===//
|
||||
// FunctionOpInterface Methods
|
||||
//===------------------------------------------------------------------===//
|
||||
|
|
|
@ -299,6 +299,18 @@ def FuncOp : Func_Op<"func", [
|
|||
/// executed.
|
||||
ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
|
||||
|
||||
/// Returns the argument attributes for all callable region arguments or
|
||||
/// null if there are none.
|
||||
::mlir::ArrayAttr getCallableArgAttrs() {
|
||||
return getArgAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
/// Returns the result attributes for all callable region results or
|
||||
/// null if there are none.
|
||||
::mlir::ArrayAttr getCallableResAttrs() {
|
||||
return getResAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
//===------------------------------------------------------------------===//
|
||||
// FunctionOpInterface Methods
|
||||
//===------------------------------------------------------------------===//
|
||||
|
|
|
@ -1583,6 +1583,10 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
|
|||
/// Returns the result types of this function.
|
||||
ArrayRef<Type> getResultTypes() { return getFunctionType().getReturnTypes(); }
|
||||
|
||||
//===------------------------------------------------------------------===//
|
||||
// CallableOpInterface
|
||||
//===------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the callable region, which is the function body. If the function
|
||||
/// is external, returns null.
|
||||
Region *getCallableRegion();
|
||||
|
@ -1596,6 +1600,17 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
|
|||
return getFunctionType().getReturnTypes();
|
||||
}
|
||||
|
||||
/// Returns the argument attributes for all callable region arguments or
|
||||
/// null if there are none.
|
||||
::mlir::ArrayAttr getCallableArgAttrs() {
|
||||
return getArgAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
/// Returns the result attributes for all callable region results or
|
||||
/// null if there are none.
|
||||
::mlir::ArrayAttr getCallableResAttrs() {
|
||||
return getResAttrs().value_or(nullptr);
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
|
|
@ -73,6 +73,18 @@ def MLProgram_FuncOp : MLProgram_Op<"func", [
|
|||
/// executed.
|
||||
ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
|
||||
|
||||
/// Returns the argument attributes for all callable region arguments or
|
||||
/// null if there are none.
|
||||
::mlir::ArrayAttr getCallableArgAttrs() {
|
||||
return getArgAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
/// Returns the result attributes for all callable region results or
|
||||
/// null if there are none.
|
||||
::mlir::ArrayAttr getCallableResAttrs() {
|
||||
return getResAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
//===------------------------------------------------------------------===//
|
||||
// FunctionOpInterface Methods
|
||||
//===------------------------------------------------------------------===//
|
||||
|
@ -422,6 +434,18 @@ def MLProgram_SubgraphOp : MLProgram_Op<"subgraph", [
|
|||
/// executed.
|
||||
ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
|
||||
|
||||
/// Returns the argument attributes for all callable region arguments or
|
||||
/// null if there are none.
|
||||
::mlir::ArrayAttr getCallableArgAttrs() {
|
||||
return getArgAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
/// Returns the result attributes for all callable region results or
|
||||
/// null if there are none.
|
||||
::mlir::ArrayAttr getCallableResAttrs() {
|
||||
return getResAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
//===------------------------------------------------------------------===//
|
||||
// FunctionOpInterface Methods
|
||||
//===------------------------------------------------------------------===//
|
||||
|
|
|
@ -1149,6 +1149,18 @@ def Shape_FuncOp : Shape_Op<"func",
|
|||
return getFunctionType().getResults();
|
||||
}
|
||||
|
||||
/// Returns the argument attributes for all callable region arguments or
|
||||
/// null if there are none.
|
||||
::mlir::ArrayAttr getCallableArgAttrs() {
|
||||
return getArgAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
/// Returns the result attributes for all callable region results or
|
||||
/// null if there are none.
|
||||
::mlir::ArrayAttr getCallableResAttrs() {
|
||||
return getResAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
//===------------------------------------------------------------------===//
|
||||
// FunctionOpInterface Methods
|
||||
//===------------------------------------------------------------------===//
|
||||
|
|
|
@ -394,6 +394,12 @@ def NamedSequenceOp : TransformDialectOp<"named_sequence",
|
|||
::llvm::ArrayRef<::mlir::Type> getCallableResults() {
|
||||
return getFunctionType().getResults();
|
||||
}
|
||||
::mlir::ArrayAttr getCallableArgAttrs() {
|
||||
return getArgAttrs().value_or(nullptr);
|
||||
}
|
||||
::mlir::ArrayAttr getCallableResAttrs() {
|
||||
return getResAttrs().value_or(nullptr);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -84,6 +84,18 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
|
|||
}],
|
||||
"::llvm::ArrayRef<::mlir::Type>", "getCallableResults"
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
Returns the argument attributes for all callable region arguments or
|
||||
null if there are none.
|
||||
}],
|
||||
"::mlir::ArrayAttr", "getCallableArgAttrs"
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
Returns the result attributes for all callable region results or null
|
||||
if there are none.
|
||||
}],
|
||||
"::mlir::ArrayAttr", "getCallableResAttrs"
|
||||
>
|
||||
];
|
||||
}
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#ifndef MLIR_TRANSFORMS_INLININGUTILS_H
|
||||
#define MLIR_TRANSFORMS_INLININGUTILS_H
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/DialectInterface.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/Region.h"
|
||||
|
@ -141,6 +142,40 @@ public:
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
/// Hook to transform the call arguments before using them to replace the
|
||||
/// callee arguments. It returns the transformation result or `argument`
|
||||
/// itself if the hook did not change anything. The type of the returned value
|
||||
/// has to match `targetType`, and the `argumentAttrs` dictionary is non-null
|
||||
/// even if no attribute is present. The hook is called after converting the
|
||||
/// callsite argument types using the materializeCallConversion callback, and
|
||||
/// right before inlining the callee region. Any operations created using the
|
||||
/// provided `builder` are inserted right before the inlined callee region.
|
||||
/// Example use cases are the insertion of copies for by value arguments, or
|
||||
/// integer conversions that require signedness information.
|
||||
virtual Value handleArgument(OpBuilder &builder, Operation *call,
|
||||
Operation *callable, Value argument,
|
||||
Type targetType,
|
||||
DictionaryAttr argumentAttrs) const {
|
||||
return argument;
|
||||
}
|
||||
|
||||
/// Hook to transform the callee results before using them to replace the call
|
||||
/// results. It returns the transformation result or the `result` itself if
|
||||
/// the hook did not change anything. The type of the returned values has to
|
||||
/// match `targetType`, and the `resultAttrs` dictionary is non-null even if
|
||||
/// no attribute is present. The hook is called right before handling
|
||||
/// terminators, and obtains the callee result before converting its type
|
||||
/// using the `materializeCallConversion` callback. Any operations created
|
||||
/// using the provided `builder` are inserted right after the inlined callee
|
||||
/// region. Example use cases are the insertion of copies for by value results
|
||||
/// or integer conversions that require signedness information.
|
||||
/// NOTE: This hook is invoked after inlining the `callable` region.
|
||||
virtual Value handleResult(OpBuilder &builder, Operation *call,
|
||||
Operation *callable, Value result, Type targetType,
|
||||
DictionaryAttr resultAttrs) const {
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Process a set of blocks that have been inlined for a call. This callback
|
||||
/// is invoked before inlined terminator operations have been processed.
|
||||
virtual void processInlinedCallBlocks(
|
||||
|
@ -183,6 +218,15 @@ public:
|
|||
virtual void handleTerminator(Operation *op, Block *newDest) const;
|
||||
virtual void handleTerminator(Operation *op,
|
||||
ArrayRef<Value> valuesToRepl) const;
|
||||
|
||||
virtual Value handleArgument(OpBuilder &builder, Operation *call,
|
||||
Operation *callable, Value argument,
|
||||
Type targetType,
|
||||
DictionaryAttr argumentAttrs) const;
|
||||
virtual Value handleResult(OpBuilder &builder, Operation *call,
|
||||
Operation *callable, Value result, Type targetType,
|
||||
DictionaryAttr resultAttrs) const;
|
||||
|
||||
virtual void processInlinedCallBlocks(
|
||||
Operation *call, iterator_range<Region::iterator> inlinedBlocks) const;
|
||||
};
|
||||
|
|
|
@ -2469,6 +2469,16 @@ ArrayRef<Type> spirv::FuncOp::getCallableResults() {
|
|||
return getFunctionType().getResults();
|
||||
}
|
||||
|
||||
// CallableOpInterface
|
||||
::mlir::ArrayAttr spirv::FuncOp::getCallableArgAttrs() {
|
||||
return getArgAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
// CallableOpInterface
|
||||
::mlir::ArrayAttr spirv::FuncOp::getCallableResAttrs() {
|
||||
return getResAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.FunctionCall
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -103,6 +103,26 @@ void InlinerInterface::handleTerminator(Operation *op,
|
|||
handler->handleTerminator(op, valuesToRepl);
|
||||
}
|
||||
|
||||
Value InlinerInterface::handleArgument(OpBuilder &builder, Operation *call,
|
||||
Operation *callable, Value argument,
|
||||
Type targetType,
|
||||
DictionaryAttr argumentAttrs) const {
|
||||
auto *handler = getInterfaceFor(callable);
|
||||
assert(handler && "expected valid dialect handler");
|
||||
return handler->handleArgument(builder, call, callable, argument, targetType,
|
||||
argumentAttrs);
|
||||
}
|
||||
|
||||
Value InlinerInterface::handleResult(OpBuilder &builder, Operation *call,
|
||||
Operation *callable, Value result,
|
||||
Type targetType,
|
||||
DictionaryAttr resultAttrs) const {
|
||||
auto *handler = getInterfaceFor(callable);
|
||||
assert(handler && "expected valid dialect handler");
|
||||
return handler->handleResult(builder, call, callable, result, targetType,
|
||||
resultAttrs);
|
||||
}
|
||||
|
||||
void InlinerInterface::processInlinedCallBlocks(
|
||||
Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
|
||||
auto *handler = getInterfaceFor(call);
|
||||
|
@ -141,6 +161,71 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src,
|
|||
// Inline Methods
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder,
|
||||
CallOpInterface call,
|
||||
CallableOpInterface callable,
|
||||
IRMapping &mapper) {
|
||||
// Unpack the argument attributes if there are any.
|
||||
SmallVector<DictionaryAttr> argAttrs(
|
||||
callable.getCallableRegion()->getNumArguments(),
|
||||
builder.getDictionaryAttr({}));
|
||||
if (ArrayAttr arrayAttr = callable.getCallableArgAttrs()) {
|
||||
assert(arrayAttr.size() == argAttrs.size());
|
||||
for (auto [idx, attr] : llvm::enumerate(arrayAttr))
|
||||
argAttrs[idx] = cast<DictionaryAttr>(attr);
|
||||
}
|
||||
|
||||
// Run the argument attribute handler for the given argument and attribute.
|
||||
for (auto [blockArg, argAttr] :
|
||||
llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) {
|
||||
Value newArgument = interface.handleArgument(builder, call, callable,
|
||||
mapper.lookup(blockArg),
|
||||
blockArg.getType(), argAttr);
|
||||
assert(newArgument.getType() == blockArg.getType() &&
|
||||
"expected the handled argument type to match the target type");
|
||||
|
||||
// Update the mapping to point the new argument returned by the handler.
|
||||
mapper.map(blockArg, newArgument);
|
||||
}
|
||||
}
|
||||
|
||||
static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
|
||||
CallOpInterface call, CallableOpInterface callable,
|
||||
ValueRange results) {
|
||||
// Unpack the result attributes if there are any.
|
||||
SmallVector<DictionaryAttr> resAttrs(results.size(),
|
||||
builder.getDictionaryAttr({}));
|
||||
if (ArrayAttr arrayAttr = callable.getCallableResAttrs()) {
|
||||
assert(arrayAttr.size() == resAttrs.size());
|
||||
for (auto [idx, attr] : llvm::enumerate(arrayAttr))
|
||||
resAttrs[idx] = cast<DictionaryAttr>(attr);
|
||||
}
|
||||
|
||||
// Run the result attribute handler for the given result and attribute.
|
||||
SmallVector<DictionaryAttr> resultAttributes;
|
||||
for (auto [result, resAttr] : llvm::zip(results, resAttrs)) {
|
||||
// Store the original result users before running the handler.
|
||||
DenseSet<Operation *> resultUsers;
|
||||
for (Operation *user : result.getUsers())
|
||||
resultUsers.insert(user);
|
||||
|
||||
// TODO: Use the type of the call result to replace once the hook can be
|
||||
// used for type conversions. At the moment, all type conversions have to be
|
||||
// done using materializeCallConversion.
|
||||
Type targetType = result.getType();
|
||||
|
||||
Value newResult = interface.handleResult(builder, call, callable, result,
|
||||
targetType, resAttr);
|
||||
assert(newResult.getType() == targetType &&
|
||||
"expected the handled result type to match the target type");
|
||||
|
||||
// Replace the result uses except for the ones introduce by the handler.
|
||||
result.replaceUsesWithIf(newResult, [&](OpOperand &operand) {
|
||||
return resultUsers.count(operand.getOwner());
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
|
||||
Block::iterator inlinePoint, IRMapping &mapper,
|
||||
|
@ -166,6 +251,12 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
|
|||
mapper))
|
||||
return failure();
|
||||
|
||||
// Run the argument attribute handler before inlining the callable region.
|
||||
OpBuilder builder(inlineBlock, inlinePoint);
|
||||
auto callable = dyn_cast<CallableOpInterface>(src->getParentOp());
|
||||
if (call && callable)
|
||||
handleArgumentImpl(interface, builder, call, callable, mapper);
|
||||
|
||||
// Check to see if the region is being cloned, or moved inline. In either
|
||||
// case, move the new blocks after the 'insertBlock' to improve IR
|
||||
// readability.
|
||||
|
@ -199,8 +290,14 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
|
|||
|
||||
// Handle the case where only a single block was inlined.
|
||||
if (std::next(newBlocks.begin()) == newBlocks.end()) {
|
||||
// Run the result attribute handler on the terminator operands.
|
||||
Operation *firstBlockTerminator = firstNewBlock->getTerminator();
|
||||
builder.setInsertionPoint(firstBlockTerminator);
|
||||
if (call && callable)
|
||||
handleResultImpl(interface, builder, call, callable,
|
||||
firstBlockTerminator->getOperands());
|
||||
|
||||
// Have the interface handle the terminator of this block.
|
||||
auto *firstBlockTerminator = firstNewBlock->getTerminator();
|
||||
interface.handleTerminator(firstBlockTerminator,
|
||||
llvm::to_vector<6>(resultsToReplace));
|
||||
firstBlockTerminator->erase();
|
||||
|
@ -218,6 +315,12 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
|
|||
resultToRepl.value().getLoc()));
|
||||
}
|
||||
|
||||
// Run the result attribute handler on the post insertion block arguments.
|
||||
builder.setInsertionPointToStart(postInsertBlock);
|
||||
if (call && callable)
|
||||
handleResultImpl(interface, builder, call, callable,
|
||||
postInsertBlock->getArguments());
|
||||
|
||||
/// Handle the terminators for each of the new blocks.
|
||||
for (auto &newBlock : newBlocks)
|
||||
interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
|
||||
|
|
|
@ -226,3 +226,40 @@ func.func @func_with_block_args_location_callee2(%arg0 : i32) {
|
|||
call @func_with_block_args_location(%arg0) : (i32) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// Check that we can handle argument and result attributes.
|
||||
test.conversion_func_op @handle_attr_callee_fn_multi_arg(%arg0 : i16, %arg1 : i16 {"test.handle_argument"}) -> (i16 {"test.handle_result"}, i16) {
|
||||
%0 = arith.addi %arg0, %arg1 : i16
|
||||
%1 = arith.subi %arg0, %arg1 : i16
|
||||
"test.return"(%0, %1) : (i16, i16) -> ()
|
||||
}
|
||||
test.conversion_func_op @handle_attr_callee_fn(%arg0 : i32 {"test.handle_argument"}) -> (i32 {"test.handle_result"}) {
|
||||
"test.return"(%arg0) : (i32) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @inline_handle_attr_call
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
|
||||
func.func @inline_handle_attr_call(%arg0 : i16, %arg1 : i16) -> (i16, i16) {
|
||||
|
||||
// CHECK: %[[CHANGE_INPUT:.*]] = "test.type_changer"(%[[ARG1]]) : (i16) -> i16
|
||||
// CHECK: %[[SUM:.*]] = arith.addi %[[ARG0]], %[[CHANGE_INPUT]]
|
||||
// CHECK: %[[DIFF:.*]] = arith.subi %[[ARG0]], %[[CHANGE_INPUT]]
|
||||
// CHECK: %[[CHANGE_RESULT:.*]] = "test.type_changer"(%[[SUM]]) : (i16) -> i16
|
||||
// CHECK-NEXT: return %[[CHANGE_RESULT]], %[[DIFF]]
|
||||
%res0, %res1 = "test.conversion_call_op"(%arg0, %arg1) { callee=@handle_attr_callee_fn_multi_arg } : (i16, i16) -> (i16, i16)
|
||||
return %res0, %res1 : i16, i16
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @inline_convert_and_handle_attr_call
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
|
||||
func.func @inline_convert_and_handle_attr_call(%arg0 : i16) -> (i16) {
|
||||
|
||||
// CHECK: %[[CAST_INPUT:.*]] = "test.cast"(%[[ARG0]]) : (i16) -> i32
|
||||
// CHECK: %[[CHANGE_INPUT:.*]] = "test.type_changer"(%[[CAST_INPUT]]) : (i32) -> i32
|
||||
// CHECK: %[[CHANGE_RESULT:.*]] = "test.type_changer"(%[[CHANGE_INPUT]]) : (i32) -> i32
|
||||
// CHECK: %[[CAST_RESULT:.*]] = "test.cast"(%[[CHANGE_RESULT]]) : (i32) -> i16
|
||||
// CHECK: return %[[CAST_RESULT]]
|
||||
%res = "test.conversion_call_op"(%arg0) { callee=@handle_attr_callee_fn } : (i16) -> (i16)
|
||||
return %res : i16
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/ExtensibleDialect.h"
|
||||
#include "mlir/IR/FunctionImplementation.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
@ -354,6 +355,24 @@ struct TestInlinerInterface : public DialectInlinerInterface {
|
|||
return builder.create<TestCastOp>(conversionLoc, resultType, input);
|
||||
}
|
||||
|
||||
Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
|
||||
Value argument, Type targetType,
|
||||
DictionaryAttr argumentAttrs) const final {
|
||||
if (!argumentAttrs.contains("test.handle_argument"))
|
||||
return argument;
|
||||
return builder.create<TestTypeChangerOp>(call->getLoc(), targetType,
|
||||
argument);
|
||||
}
|
||||
|
||||
Value handleResult(OpBuilder &builder, Operation *call, Operation *callable,
|
||||
Value result, Type targetType,
|
||||
DictionaryAttr resultAttrs) const final {
|
||||
if (!resultAttrs.contains("test.handle_result"))
|
||||
return result;
|
||||
return builder.create<TestTypeChangerOp>(call->getLoc(), targetType,
|
||||
result);
|
||||
}
|
||||
|
||||
void processInlinedCallBlocks(
|
||||
Operation *call,
|
||||
iterator_range<Region::iterator> inlinedBlocks) const final {
|
||||
|
@ -650,6 +669,29 @@ LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConversionFuncOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ParseResult ConversionFuncOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
auto buildFuncType =
|
||||
[](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
|
||||
function_interface_impl::VariadicFlag,
|
||||
std::string &) { return builder.getFunctionType(argTypes, results); };
|
||||
|
||||
return function_interface_impl::parseFunctionOp(
|
||||
parser, result, /*allowVariadic=*/false,
|
||||
getFunctionTypeAttrName(result.name), buildFuncType,
|
||||
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
|
||||
}
|
||||
|
||||
void ConversionFuncOp::print(OpAsmPrinter &p) {
|
||||
function_interface_impl::printFunctionOp(
|
||||
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
|
||||
getArgAttrsAttrName(), getResAttrsAttrName());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TestFoldToCallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -14,6 +14,7 @@ include "TestInterfaces.td"
|
|||
include "mlir/Dialect/DLTI/DLTIBase.td"
|
||||
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
|
||||
include "mlir/IR/EnumAttr.td"
|
||||
include "mlir/IR/FunctionInterfaces.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
include "mlir/IR/PatternBase.td"
|
||||
|
@ -482,6 +483,66 @@ def ConversionCallOp : TEST_Op<"conversion_call_op",
|
|||
}];
|
||||
}
|
||||
|
||||
def ConversionFuncOp : TEST_Op<"conversion_func_op", [CallableOpInterface,
|
||||
FunctionOpInterface]> {
|
||||
let arguments = (ins SymbolNameAttr:$sym_name,
|
||||
TypeAttrOf<FunctionType>:$function_type,
|
||||
OptionalAttr<DictArrayAttr>:$arg_attrs,
|
||||
OptionalAttr<DictArrayAttr>:$res_attrs,
|
||||
OptionalAttr<StrAttr>:$sym_visibility);
|
||||
let regions = (region AnyRegion:$body);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
//===------------------------------------------------------------------===//
|
||||
// CallableOpInterface
|
||||
//===------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the region on the current operation that is callable. This may
|
||||
/// return null in the case of an external callable object, e.g. an external
|
||||
/// function.
|
||||
::mlir::Region *getCallableRegion() {
|
||||
return isExternal() ? nullptr : &getBody();
|
||||
}
|
||||
|
||||
/// Returns the results types that the callable region produces when
|
||||
/// executed.
|
||||
::mlir::ArrayRef<::mlir::Type> getCallableResults() {
|
||||
return getFunctionType().getResults();
|
||||
}
|
||||
|
||||
/// Returns the argument attributes for all callable region arguments or
|
||||
/// null if there are none.
|
||||
::mlir::ArrayAttr getCallableArgAttrs() {
|
||||
return getArgAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
/// Returns the result attributes for all callable region results or
|
||||
/// null if there are none.
|
||||
::mlir::ArrayAttr getCallableResAttrs() {
|
||||
return getResAttrs().value_or(nullptr);
|
||||
}
|
||||
|
||||
//===------------------------------------------------------------------===//
|
||||
// FunctionOpInterface Methods
|
||||
//===------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the argument types of this async function.
|
||||
::mlir::ArrayRef<::mlir::Type> getArgumentTypes() {
|
||||
return getFunctionType().getInputs();
|
||||
}
|
||||
|
||||
/// Returns the result types of this async function.
|
||||
::mlir::ArrayRef<::mlir::Type> getResultTypes() {
|
||||
return getFunctionType().getResults();
|
||||
}
|
||||
|
||||
/// Returns the number of results of this async function
|
||||
unsigned getNumResults() {return getResultTypes().size();}
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def FunctionalRegionOp : TEST_Op<"functional_region_op",
|
||||
[CallableOpInterface]> {
|
||||
let regions = (region AnyRegion:$body);
|
||||
|
@ -492,6 +553,12 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op",
|
|||
::llvm::ArrayRef<::mlir::Type> getCallableResults() {
|
||||
return getType().cast<::mlir::FunctionType>().getResults();
|
||||
}
|
||||
::mlir::ArrayAttr getCallableArgAttrs() {
|
||||
return nullptr;
|
||||
}
|
||||
::mlir::ArrayAttr getCallableResAttrs() {
|
||||
return nullptr;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user