diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md index 6bb507013863..b51adec4fc4f 100644 --- a/mlir/docs/Interfaces.md +++ b/mlir/docs/Interfaces.md @@ -731,6 +731,8 @@ interface section goes as follows: * `CallableOpInterface` - Used to represent the target callee of call. - `Region * getCallableRegion()` - `ArrayRef getCallableResults()` + - `ArrayAttr getCallableArgAttrs()` + - `ArrayAttr getCallableResAttrs()` ##### RegionKindInterfaces diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md index 77a52163774f..f462274fa592 100644 --- a/mlir/docs/Tutorials/Toy/Ch-4.md +++ b/mlir/docs/Tutorials/Toy/Ch-4.md @@ -169,6 +169,18 @@ Region *FuncOp::getCallableRegion() { return &getBody(); } /// executed. ArrayRef FuncOp::getCallableResults() { return getType().getResults(); } +/// Returns the argument attributes for all callable region arguments or +/// null if there are none. +ArrayAttr FuncOp::getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); +} + +/// Returns the result attributes for all callable region results or +/// null if there are none. +ArrayAttr FuncOp::getCallableResAttrs() { + return getResAttrs().value_or(nullptr); +} + // .... /// Return the callee of the generic call operation, this is required by the diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp index 17a42d69c8f4..f5258eb5cff1 100644 --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -307,6 +307,18 @@ llvm::ArrayRef FuncOp::getCallableResults() { return getFunctionType().getResults(); } +/// Returns the argument attributes for all callable region arguments or +/// null if there are none. +ArrayAttr FuncOp::getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); +} + +/// Returns the result attributes for all callable region results or +/// null if there are none. +ArrayAttr FuncOp::getCallableResAttrs() { + return getResAttrs().value_or(nullptr); +} + //===----------------------------------------------------------------------===// // GenericCallOp //===----------------------------------------------------------------------===// diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp index 77ceb636e17f..a959969c0449 100644 --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -307,6 +307,18 @@ llvm::ArrayRef FuncOp::getCallableResults() { return getFunctionType().getResults(); } +/// Returns the argument attributes for all callable region arguments or +/// null if there are none. +ArrayAttr FuncOp::getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); +} + +/// Returns the result attributes for all callable region results or +/// null if there are none. +ArrayAttr FuncOp::getCallableResAttrs() { + return getResAttrs().value_or(nullptr); +} + //===----------------------------------------------------------------------===// // GenericCallOp //===----------------------------------------------------------------------===// diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp index 77ceb636e17f..a959969c0449 100644 --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -307,6 +307,18 @@ llvm::ArrayRef FuncOp::getCallableResults() { return getFunctionType().getResults(); } +/// Returns the argument attributes for all callable region arguments or +/// null if there are none. +ArrayAttr FuncOp::getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); +} + +/// Returns the result attributes for all callable region results or +/// null if there are none. +ArrayAttr FuncOp::getCallableResAttrs() { + return getResAttrs().value_or(nullptr); +} + //===----------------------------------------------------------------------===// // GenericCallOp //===----------------------------------------------------------------------===// diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp index 188b94fc2dfe..d332411b63bb 100644 --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -336,6 +336,18 @@ llvm::ArrayRef FuncOp::getCallableResults() { return getFunctionType().getResults(); } +/// Returns the argument attributes for all callable region arguments or +/// null if there are none. +ArrayAttr FuncOp::getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); +} + +/// Returns the result attributes for all callable region results or +/// null if there are none. +ArrayAttr FuncOp::getCallableResAttrs() { + return getResAttrs().value_or(nullptr); +} + //===----------------------------------------------------------------------===// // GenericCallOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td index 2cf5ee810b7a..30147b8b6a30 100644 --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -168,6 +168,18 @@ def Async_FuncOp : Async_Op<"func", ArrayRef getCallableResults() { return getFunctionType() .getResults(); } + /// Returns the argument attributes for all callable region arguments or + /// null if there are none. + ::mlir::ArrayAttr getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); + } + + /// Returns the result attributes for all callable region results or + /// null if there are none. + ::mlir::ArrayAttr getCallableResAttrs() { + return getResAttrs().value_or(nullptr); + } + //===------------------------------------------------------------------===// // FunctionOpInterface Methods //===------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td index 45ec8a9e0b7e..1a06d6533b2d 100644 --- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td +++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td @@ -299,6 +299,18 @@ def FuncOp : Func_Op<"func", [ /// executed. ArrayRef getCallableResults() { return getFunctionType().getResults(); } + /// Returns the argument attributes for all callable region arguments or + /// null if there are none. + ::mlir::ArrayAttr getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); + } + + /// Returns the result attributes for all callable region results or + /// null if there are none. + ::mlir::ArrayAttr getCallableResAttrs() { + return getResAttrs().value_or(nullptr); + } + //===------------------------------------------------------------------===// // FunctionOpInterface Methods //===------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index c2bb2f34a463..1bbc32f3d291 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1583,6 +1583,10 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [ /// Returns the result types of this function. ArrayRef getResultTypes() { return getFunctionType().getReturnTypes(); } + //===------------------------------------------------------------------===// + // CallableOpInterface + //===------------------------------------------------------------------===// + /// Returns the callable region, which is the function body. If the function /// is external, returns null. Region *getCallableRegion(); @@ -1596,6 +1600,17 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [ return getFunctionType().getReturnTypes(); } + /// Returns the argument attributes for all callable region arguments or + /// null if there are none. + ::mlir::ArrayAttr getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); + } + + /// Returns the result attributes for all callable region results or + /// null if there are none. + ::mlir::ArrayAttr getCallableResAttrs() { + return getResAttrs().value_or(nullptr); + } }]; let hasCustomAssemblyFormat = 1; diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td index db6c7733130c..7984b9744513 100644 --- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td +++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td @@ -73,6 +73,18 @@ def MLProgram_FuncOp : MLProgram_Op<"func", [ /// executed. ArrayRef getCallableResults() { return getFunctionType().getResults(); } + /// Returns the argument attributes for all callable region arguments or + /// null if there are none. + ::mlir::ArrayAttr getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); + } + + /// Returns the result attributes for all callable region results or + /// null if there are none. + ::mlir::ArrayAttr getCallableResAttrs() { + return getResAttrs().value_or(nullptr); + } + //===------------------------------------------------------------------===// // FunctionOpInterface Methods //===------------------------------------------------------------------===// @@ -422,6 +434,18 @@ def MLProgram_SubgraphOp : MLProgram_Op<"subgraph", [ /// executed. ArrayRef getCallableResults() { return getFunctionType().getResults(); } + /// Returns the argument attributes for all callable region arguments or + /// null if there are none. + ::mlir::ArrayAttr getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); + } + + /// Returns the result attributes for all callable region results or + /// null if there are none. + ::mlir::ArrayAttr getCallableResAttrs() { + return getResAttrs().value_or(nullptr); + } + //===------------------------------------------------------------------===// // FunctionOpInterface Methods //===------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index ae84b07acab2..47918b46dddc 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -1149,6 +1149,18 @@ def Shape_FuncOp : Shape_Op<"func", return getFunctionType().getResults(); } + /// Returns the argument attributes for all callable region arguments or + /// null if there are none. + ::mlir::ArrayAttr getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); + } + + /// Returns the result attributes for all callable region results or + /// null if there are none. + ::mlir::ArrayAttr getCallableResAttrs() { + return getResAttrs().value_or(nullptr); + } + //===------------------------------------------------------------------===// // FunctionOpInterface Methods //===------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index 3ffc3f71433c..46dea7454635 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -394,6 +394,12 @@ def NamedSequenceOp : TransformDialectOp<"named_sequence", ::llvm::ArrayRef<::mlir::Type> getCallableResults() { return getFunctionType().getResults(); } + ::mlir::ArrayAttr getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); + } + ::mlir::ArrayAttr getCallableResAttrs() { + return getResAttrs().value_or(nullptr); + } }]; } diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td index 96540675f833..cd37222cbc27 100644 --- a/mlir/include/mlir/Interfaces/CallInterfaces.td +++ b/mlir/include/mlir/Interfaces/CallInterfaces.td @@ -84,6 +84,18 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> { }], "::llvm::ArrayRef<::mlir::Type>", "getCallableResults" >, + InterfaceMethod<[{ + Returns the argument attributes for all callable region arguments or + null if there are none. + }], + "::mlir::ArrayAttr", "getCallableArgAttrs" + >, + InterfaceMethod<[{ + Returns the result attributes for all callable region results or null + if there are none. + }], + "::mlir::ArrayAttr", "getCallableResAttrs" + > ]; } diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h index 241983ef8c3d..63aba6a08e39 100644 --- a/mlir/include/mlir/Transforms/InliningUtils.h +++ b/mlir/include/mlir/Transforms/InliningUtils.h @@ -13,6 +13,7 @@ #ifndef MLIR_TRANSFORMS_INLININGUTILS_H #define MLIR_TRANSFORMS_INLININGUTILS_H +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/DialectInterface.h" #include "mlir/IR/Location.h" #include "mlir/IR/Region.h" @@ -141,6 +142,40 @@ public: return nullptr; } + /// Hook to transform the call arguments before using them to replace the + /// callee arguments. It returns the transformation result or `argument` + /// itself if the hook did not change anything. The type of the returned value + /// has to match `targetType`, and the `argumentAttrs` dictionary is non-null + /// even if no attribute is present. The hook is called after converting the + /// callsite argument types using the materializeCallConversion callback, and + /// right before inlining the callee region. Any operations created using the + /// provided `builder` are inserted right before the inlined callee region. + /// Example use cases are the insertion of copies for by value arguments, or + /// integer conversions that require signedness information. + virtual Value handleArgument(OpBuilder &builder, Operation *call, + Operation *callable, Value argument, + Type targetType, + DictionaryAttr argumentAttrs) const { + return argument; + } + + /// Hook to transform the callee results before using them to replace the call + /// results. It returns the transformation result or the `result` itself if + /// the hook did not change anything. The type of the returned values has to + /// match `targetType`, and the `resultAttrs` dictionary is non-null even if + /// no attribute is present. The hook is called right before handling + /// terminators, and obtains the callee result before converting its type + /// using the `materializeCallConversion` callback. Any operations created + /// using the provided `builder` are inserted right after the inlined callee + /// region. Example use cases are the insertion of copies for by value results + /// or integer conversions that require signedness information. + /// NOTE: This hook is invoked after inlining the `callable` region. + virtual Value handleResult(OpBuilder &builder, Operation *call, + Operation *callable, Value result, Type targetType, + DictionaryAttr resultAttrs) const { + return result; + } + /// Process a set of blocks that have been inlined for a call. This callback /// is invoked before inlined terminator operations have been processed. virtual void processInlinedCallBlocks( @@ -183,6 +218,15 @@ public: virtual void handleTerminator(Operation *op, Block *newDest) const; virtual void handleTerminator(Operation *op, ArrayRef valuesToRepl) const; + + virtual Value handleArgument(OpBuilder &builder, Operation *call, + Operation *callable, Value argument, + Type targetType, + DictionaryAttr argumentAttrs) const; + virtual Value handleResult(OpBuilder &builder, Operation *call, + Operation *callable, Value result, Type targetType, + DictionaryAttr resultAttrs) const; + virtual void processInlinedCallBlocks( Operation *call, iterator_range inlinedBlocks) const; }; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index f6865b410709..bb3ad91ce620 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -2469,6 +2469,16 @@ ArrayRef spirv::FuncOp::getCallableResults() { return getFunctionType().getResults(); } +// CallableOpInterface +::mlir::ArrayAttr spirv::FuncOp::getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); +} + +// CallableOpInterface +::mlir::ArrayAttr spirv::FuncOp::getCallableResAttrs() { + return getResAttrs().value_or(nullptr); +} + //===----------------------------------------------------------------------===// // spirv.FunctionCall //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index f9dc69caea47..8856fd59abf9 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -103,6 +103,26 @@ void InlinerInterface::handleTerminator(Operation *op, handler->handleTerminator(op, valuesToRepl); } +Value InlinerInterface::handleArgument(OpBuilder &builder, Operation *call, + Operation *callable, Value argument, + Type targetType, + DictionaryAttr argumentAttrs) const { + auto *handler = getInterfaceFor(callable); + assert(handler && "expected valid dialect handler"); + return handler->handleArgument(builder, call, callable, argument, targetType, + argumentAttrs); +} + +Value InlinerInterface::handleResult(OpBuilder &builder, Operation *call, + Operation *callable, Value result, + Type targetType, + DictionaryAttr resultAttrs) const { + auto *handler = getInterfaceFor(callable); + assert(handler && "expected valid dialect handler"); + return handler->handleResult(builder, call, callable, result, targetType, + resultAttrs); +} + void InlinerInterface::processInlinedCallBlocks( Operation *call, iterator_range inlinedBlocks) const { auto *handler = getInterfaceFor(call); @@ -141,6 +161,71 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src, // Inline Methods //===----------------------------------------------------------------------===// +static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder, + CallOpInterface call, + CallableOpInterface callable, + IRMapping &mapper) { + // Unpack the argument attributes if there are any. + SmallVector argAttrs( + callable.getCallableRegion()->getNumArguments(), + builder.getDictionaryAttr({})); + if (ArrayAttr arrayAttr = callable.getCallableArgAttrs()) { + assert(arrayAttr.size() == argAttrs.size()); + for (auto [idx, attr] : llvm::enumerate(arrayAttr)) + argAttrs[idx] = cast(attr); + } + + // Run the argument attribute handler for the given argument and attribute. + for (auto [blockArg, argAttr] : + llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) { + Value newArgument = interface.handleArgument(builder, call, callable, + mapper.lookup(blockArg), + blockArg.getType(), argAttr); + assert(newArgument.getType() == blockArg.getType() && + "expected the handled argument type to match the target type"); + + // Update the mapping to point the new argument returned by the handler. + mapper.map(blockArg, newArgument); + } +} + +static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder, + CallOpInterface call, CallableOpInterface callable, + ValueRange results) { + // Unpack the result attributes if there are any. + SmallVector resAttrs(results.size(), + builder.getDictionaryAttr({})); + if (ArrayAttr arrayAttr = callable.getCallableResAttrs()) { + assert(arrayAttr.size() == resAttrs.size()); + for (auto [idx, attr] : llvm::enumerate(arrayAttr)) + resAttrs[idx] = cast(attr); + } + + // Run the result attribute handler for the given result and attribute. + SmallVector resultAttributes; + for (auto [result, resAttr] : llvm::zip(results, resAttrs)) { + // Store the original result users before running the handler. + DenseSet resultUsers; + for (Operation *user : result.getUsers()) + resultUsers.insert(user); + + // TODO: Use the type of the call result to replace once the hook can be + // used for type conversions. At the moment, all type conversions have to be + // done using materializeCallConversion. + Type targetType = result.getType(); + + Value newResult = interface.handleResult(builder, call, callable, result, + targetType, resAttr); + assert(newResult.getType() == targetType && + "expected the handled result type to match the target type"); + + // Replace the result uses except for the ones introduce by the handler. + result.replaceUsesWithIf(newResult, [&](OpOperand &operand) { + return resultUsers.count(operand.getOwner()); + }); + } +} + static LogicalResult inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, Block::iterator inlinePoint, IRMapping &mapper, @@ -166,6 +251,12 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, mapper)) return failure(); + // Run the argument attribute handler before inlining the callable region. + OpBuilder builder(inlineBlock, inlinePoint); + auto callable = dyn_cast(src->getParentOp()); + if (call && callable) + handleArgumentImpl(interface, builder, call, callable, mapper); + // Check to see if the region is being cloned, or moved inline. In either // case, move the new blocks after the 'insertBlock' to improve IR // readability. @@ -199,8 +290,14 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, // Handle the case where only a single block was inlined. if (std::next(newBlocks.begin()) == newBlocks.end()) { + // Run the result attribute handler on the terminator operands. + Operation *firstBlockTerminator = firstNewBlock->getTerminator(); + builder.setInsertionPoint(firstBlockTerminator); + if (call && callable) + handleResultImpl(interface, builder, call, callable, + firstBlockTerminator->getOperands()); + // Have the interface handle the terminator of this block. - auto *firstBlockTerminator = firstNewBlock->getTerminator(); interface.handleTerminator(firstBlockTerminator, llvm::to_vector<6>(resultsToReplace)); firstBlockTerminator->erase(); @@ -218,6 +315,12 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, resultToRepl.value().getLoc())); } + // Run the result attribute handler on the post insertion block arguments. + builder.setInsertionPointToStart(postInsertBlock); + if (call && callable) + handleResultImpl(interface, builder, call, callable, + postInsertBlock->getArguments()); + /// Handle the terminators for each of the new blocks. for (auto &newBlock : newBlocks) interface.handleTerminator(newBlock.getTerminator(), postInsertBlock); diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir index b102c210f056..f7eaa478cdbb 100644 --- a/mlir/test/Transforms/inlining.mlir +++ b/mlir/test/Transforms/inlining.mlir @@ -226,3 +226,40 @@ func.func @func_with_block_args_location_callee2(%arg0 : i32) { call @func_with_block_args_location(%arg0) : (i32) -> () return } + +// Check that we can handle argument and result attributes. +test.conversion_func_op @handle_attr_callee_fn_multi_arg(%arg0 : i16, %arg1 : i16 {"test.handle_argument"}) -> (i16 {"test.handle_result"}, i16) { + %0 = arith.addi %arg0, %arg1 : i16 + %1 = arith.subi %arg0, %arg1 : i16 + "test.return"(%0, %1) : (i16, i16) -> () +} +test.conversion_func_op @handle_attr_callee_fn(%arg0 : i32 {"test.handle_argument"}) -> (i32 {"test.handle_result"}) { + "test.return"(%arg0) : (i32) -> () +} + +// CHECK-LABEL: func @inline_handle_attr_call +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +func.func @inline_handle_attr_call(%arg0 : i16, %arg1 : i16) -> (i16, i16) { + + // CHECK: %[[CHANGE_INPUT:.*]] = "test.type_changer"(%[[ARG1]]) : (i16) -> i16 + // CHECK: %[[SUM:.*]] = arith.addi %[[ARG0]], %[[CHANGE_INPUT]] + // CHECK: %[[DIFF:.*]] = arith.subi %[[ARG0]], %[[CHANGE_INPUT]] + // CHECK: %[[CHANGE_RESULT:.*]] = "test.type_changer"(%[[SUM]]) : (i16) -> i16 + // CHECK-NEXT: return %[[CHANGE_RESULT]], %[[DIFF]] + %res0, %res1 = "test.conversion_call_op"(%arg0, %arg1) { callee=@handle_attr_callee_fn_multi_arg } : (i16, i16) -> (i16, i16) + return %res0, %res1 : i16, i16 +} + +// CHECK-LABEL: func @inline_convert_and_handle_attr_call +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +func.func @inline_convert_and_handle_attr_call(%arg0 : i16) -> (i16) { + + // CHECK: %[[CAST_INPUT:.*]] = "test.cast"(%[[ARG0]]) : (i16) -> i32 + // CHECK: %[[CHANGE_INPUT:.*]] = "test.type_changer"(%[[CAST_INPUT]]) : (i32) -> i32 + // CHECK: %[[CHANGE_RESULT:.*]] = "test.type_changer"(%[[CHANGE_INPUT]]) : (i32) -> i32 + // CHECK: %[[CAST_RESULT:.*]] = "test.cast"(%[[CHANGE_RESULT]]) : (i32) -> i16 + // CHECK: return %[[CAST_RESULT]] + %res = "test.conversion_call_op"(%arg0) { callee=@handle_attr_callee_fn } : (i16) -> (i16) + return %res : i16 +} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 97c77b0eb489..36e2b9882be4 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/ExtensibleDialect.h" +#include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" @@ -354,6 +355,24 @@ struct TestInlinerInterface : public DialectInlinerInterface { return builder.create(conversionLoc, resultType, input); } + Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, + Value argument, Type targetType, + DictionaryAttr argumentAttrs) const final { + if (!argumentAttrs.contains("test.handle_argument")) + return argument; + return builder.create(call->getLoc(), targetType, + argument); + } + + Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, + Value result, Type targetType, + DictionaryAttr resultAttrs) const final { + if (!resultAttrs.contains("test.handle_result")) + return result; + return builder.create(call->getLoc(), targetType, + result); + } + void processInlinedCallBlocks( Operation *call, iterator_range inlinedBlocks) const final { @@ -650,6 +669,29 @@ LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } +//===----------------------------------------------------------------------===// +// ConversionFuncOp +//===----------------------------------------------------------------------===// + +ParseResult ConversionFuncOp::parse(OpAsmParser &parser, + OperationState &result) { + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void ConversionFuncOp::print(OpAsmPrinter &p) { + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + //===----------------------------------------------------------------------===// // TestFoldToCallOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 3f642b8a87ea..e747d4bddfd7 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -14,6 +14,7 @@ include "TestInterfaces.td" include "mlir/Dialect/DLTI/DLTIBase.td" include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" include "mlir/IR/EnumAttr.td" +include "mlir/IR/FunctionInterfaces.td" include "mlir/IR/OpBase.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/PatternBase.td" @@ -482,6 +483,66 @@ def ConversionCallOp : TEST_Op<"conversion_call_op", }]; } +def ConversionFuncOp : TEST_Op<"conversion_func_op", [CallableOpInterface, + FunctionOpInterface]> { + let arguments = (ins SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, + OptionalAttr:$sym_visibility); + let regions = (region AnyRegion:$body); + + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // CallableOpInterface + //===------------------------------------------------------------------===// + + /// Returns the region on the current operation that is callable. This may + /// return null in the case of an external callable object, e.g. an external + /// function. + ::mlir::Region *getCallableRegion() { + return isExternal() ? nullptr : &getBody(); + } + + /// Returns the results types that the callable region produces when + /// executed. + ::mlir::ArrayRef<::mlir::Type> getCallableResults() { + return getFunctionType().getResults(); + } + + /// Returns the argument attributes for all callable region arguments or + /// null if there are none. + ::mlir::ArrayAttr getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); + } + + /// Returns the result attributes for all callable region results or + /// null if there are none. + ::mlir::ArrayAttr getCallableResAttrs() { + return getResAttrs().value_or(nullptr); + } + + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this async function. + ::mlir::ArrayRef<::mlir::Type> getArgumentTypes() { + return getFunctionType().getInputs(); + } + + /// Returns the result types of this async function. + ::mlir::ArrayRef<::mlir::Type> getResultTypes() { + return getFunctionType().getResults(); + } + + /// Returns the number of results of this async function + unsigned getNumResults() {return getResultTypes().size();} + }]; + + let hasCustomAssemblyFormat = 1; +} + def FunctionalRegionOp : TEST_Op<"functional_region_op", [CallableOpInterface]> { let regions = (region AnyRegion:$body); @@ -492,6 +553,12 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op", ::llvm::ArrayRef<::mlir::Type> getCallableResults() { return getType().cast<::mlir::FunctionType>().getResults(); } + ::mlir::ArrayAttr getCallableArgAttrs() { + return nullptr; + } + ::mlir::ArrayAttr getCallableResAttrs() { + return nullptr; + } }]; }