[mlir] Move SymbolOpInterfaces "classof" check to a proper "extraClassOf" interface field
SymbolOpInterface overrides the base classof to provide support for optionally implementing the interface. This is currently placed in the extraClassDeclarations, but that is kind of awkard given that it requires underlying knowledge of how the base classof is implemented. This commit adds a proper "extraClassOf" field to allow interfaces to implement this, which abstracts away the default classof logic. Differential Revision: https://reviews.llvm.org/D140197
This commit is contained in:
parent
3e731af912
commit
5cdc2bbc75
|
@ -2048,6 +2048,13 @@ class Interface<string name> {
|
|||
// An optional code block containing extra declarations to place in both
|
||||
// the interface and trait declaration.
|
||||
code extraSharedClassDeclaration = "";
|
||||
|
||||
// An optional code block for adding additional "classof" logic. This can
|
||||
// be used to better enable "optional" interfaces, where an entity only
|
||||
// implements the interface if some dynamic characteristic holds.
|
||||
// `$_attr`/`$_op`/`$_type` may be used to refer to an instance of the
|
||||
// entity being checked.
|
||||
code extraClassOf = "";
|
||||
}
|
||||
|
||||
// AttrInterface represents an interface registered to an attribute.
|
||||
|
|
|
@ -174,28 +174,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
|
|||
return success();
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Convenience version of `getNameAttr` that returns a StringRef.
|
||||
StringRef getName() {
|
||||
return getNameAttr().getValue();
|
||||
}
|
||||
|
||||
/// Convenience version of `setName` that take a StringRef.
|
||||
void setName(StringRef name) {
|
||||
setName(StringAttr::get(this->getContext(), name));
|
||||
}
|
||||
|
||||
/// Custom classof that handles the case where the symbol is optional.
|
||||
static bool classof(Operation *op) {
|
||||
auto *opConcept = getInterfaceFor(op);
|
||||
if (!opConcept)
|
||||
return false;
|
||||
return !opConcept->isOptionalSymbol(opConcept, op) ||
|
||||
op->getAttr(::mlir::SymbolTable::getSymbolAttrName());
|
||||
}
|
||||
}];
|
||||
|
||||
let extraTraitClassDeclaration = [{
|
||||
let extraSharedClassDeclaration = [{
|
||||
using Visibility = mlir::SymbolTable::Visibility;
|
||||
|
||||
/// Convenience version of `getNameAttr` that returns a StringRef.
|
||||
|
@ -208,6 +187,11 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
|
|||
setName(StringAttr::get($_op->getContext(), name));
|
||||
}
|
||||
}];
|
||||
|
||||
// Add additional classof checks to properly handle "optional" symbols.
|
||||
let extraClassOf = [{
|
||||
return $_op->hasAttr(::mlir::SymbolTable::getSymbolAttrName());
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -110,6 +110,12 @@ public:
|
|||
"expected value to provide interface instance");
|
||||
}
|
||||
|
||||
/// Constructor for a known concept.
|
||||
Interface(ValueT t, Concept *conceptImpl)
|
||||
: BaseType(t), conceptImpl(conceptImpl) {
|
||||
assert(!t || ConcreteType::getInterfaceFor(t) == conceptImpl);
|
||||
}
|
||||
|
||||
/// Constructor for DenseMapInfo's empty key and tombstone key.
|
||||
Interface(ValueT t, std::nullptr_t) : BaseType(t), conceptImpl(nullptr) {}
|
||||
|
||||
|
|
|
@ -44,7 +44,6 @@ public:
|
|||
None,
|
||||
Custom, // For custom placeholders
|
||||
Builder, // For the $_builder placeholder
|
||||
Op, // For the $_op placeholder
|
||||
Self, // For the $_self placeholder
|
||||
};
|
||||
|
||||
|
@ -58,7 +57,6 @@ public:
|
|||
|
||||
// Setters for builtin placeholders
|
||||
FmtContext &withBuilder(Twine subst);
|
||||
FmtContext &withOp(Twine subst);
|
||||
FmtContext &withSelf(Twine subst);
|
||||
|
||||
std::optional<StringRef> getSubstFor(PHKind placeholder) const;
|
||||
|
|
|
@ -95,6 +95,9 @@ public:
|
|||
// trait classes.
|
||||
std::optional<StringRef> getExtraSharedClassDeclaration() const;
|
||||
|
||||
// Return the extra classof method code.
|
||||
std::optional<StringRef> getExtraClassOf() const;
|
||||
|
||||
// Return the verify method body if it has one.
|
||||
std::optional<StringRef> getVerify() const;
|
||||
|
||||
|
|
|
@ -190,7 +190,7 @@ void StaticVerifierFunctionEmitter::emitConstraints(
|
|||
const ConstraintMap &constraints, StringRef selfName,
|
||||
const char *const codeTemplate) {
|
||||
FmtContext ctx;
|
||||
ctx.withOp("*op").withSelf(selfName);
|
||||
ctx.addSubst("_op", "*op").withSelf(selfName);
|
||||
for (auto &it : constraints) {
|
||||
os << formatv(codeTemplate, it.second,
|
||||
tgfmt(it.first.getConditionTemplate(), &ctx),
|
||||
|
@ -216,7 +216,7 @@ void StaticVerifierFunctionEmitter::emitRegionConstraints() {
|
|||
|
||||
void StaticVerifierFunctionEmitter::emitPatternConstraints() {
|
||||
FmtContext ctx;
|
||||
ctx.withOp("*op").withBuilder("rewriter").withSelf("type");
|
||||
ctx.addSubst("_op", "*op").withBuilder("rewriter").withSelf("type");
|
||||
for (auto &it : typeConstraints) {
|
||||
os << formatv(patternAttrOrTypeConstraintCode, it.second,
|
||||
tgfmt(it.first.getConditionTemplate(), &ctx),
|
||||
|
@ -240,9 +240,9 @@ void StaticVerifierFunctionEmitter::emitPatternConstraints() {
|
|||
/// because ops use cached identifiers.
|
||||
static bool canUniqueAttrConstraint(Attribute attr) {
|
||||
FmtContext ctx;
|
||||
auto test =
|
||||
tgfmt(attr.getConditionTemplate(), &ctx.withSelf("attr").withOp("*op"))
|
||||
.str();
|
||||
auto test = tgfmt(attr.getConditionTemplate(),
|
||||
&ctx.withSelf("attr").addSubst("_op", "*op"))
|
||||
.str();
|
||||
return !StringRef(test).contains("<no-subst-found>");
|
||||
}
|
||||
|
||||
|
|
|
@ -38,11 +38,6 @@ FmtContext &FmtContext::withBuilder(Twine subst) {
|
|||
return *this;
|
||||
}
|
||||
|
||||
FmtContext &FmtContext::withOp(Twine subst) {
|
||||
builtinSubstMap[PHKind::Op] = subst.str();
|
||||
return *this;
|
||||
}
|
||||
|
||||
FmtContext &FmtContext::withSelf(Twine subst) {
|
||||
builtinSubstMap[PHKind::Self] = subst.str();
|
||||
return *this;
|
||||
|
@ -69,7 +64,6 @@ std::optional<StringRef> FmtContext::getSubstFor(StringRef placeholder) const {
|
|||
FmtContext::PHKind FmtContext::getPlaceHolderKind(StringRef str) {
|
||||
return StringSwitch<FmtContext::PHKind>(str)
|
||||
.Case("_builder", FmtContext::PHKind::Builder)
|
||||
.Case("_op", FmtContext::PHKind::Op)
|
||||
.Case("_self", FmtContext::PHKind::Self)
|
||||
.Case("", FmtContext::PHKind::None)
|
||||
.Default(FmtContext::PHKind::Custom);
|
||||
|
|
|
@ -116,6 +116,11 @@ std::optional<StringRef> Interface::getExtraSharedClassDeclaration() const {
|
|||
return value.empty() ? std::optional<StringRef>() : value;
|
||||
}
|
||||
|
||||
std::optional<StringRef> Interface::getExtraClassOf() const {
|
||||
auto value = def->getValueAsString("extraClassOf");
|
||||
return value.empty() ? std::optional<StringRef>() : value;
|
||||
}
|
||||
|
||||
// Return the body for this method if it has one.
|
||||
std::optional<StringRef> Interface::getVerify() const {
|
||||
// Only OpInterface supports the verify method.
|
||||
|
|
|
@ -4,6 +4,17 @@
|
|||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def ExtraClassOfInterface : OpInterface<"ExtraClassOfInterface"> {
|
||||
let extraClassOf = "return $_op->someOtherMethod();";
|
||||
}
|
||||
|
||||
// DECL: class ExtraClassOfInterface
|
||||
// DECL: static bool classof(::mlir::Operation * base) {
|
||||
// DECL-NEXT: if (!getInterfaceFor(base))
|
||||
// DECL-NEXT: return false;
|
||||
// DECL-NEXT: return base->someOtherMethod();
|
||||
// DECL-NEXT: }
|
||||
|
||||
def ExtraShardDeclsInterface : OpInterface<"ExtraShardDeclsInterface"> {
|
||||
let extraSharedClassDeclaration = [{
|
||||
bool sharedMethodDeclaration() {
|
||||
|
|
|
@ -819,7 +819,7 @@ OpEmitter::OpEmitter(const Operator &op,
|
|||
formatExtraDefinitions(op)),
|
||||
staticVerifierEmitter(staticVerifierEmitter),
|
||||
emitHelper(op, /*emitForOp=*/true) {
|
||||
verifyCtx.withOp("(*this->getOperation())");
|
||||
verifyCtx.addSubst("_op", "(*this->getOperation())");
|
||||
verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");
|
||||
|
||||
genTraits();
|
||||
|
|
|
@ -108,6 +108,8 @@ protected:
|
|||
StringRef interfaceBaseType;
|
||||
/// The name of the typename for the value template.
|
||||
StringRef valueTemplate;
|
||||
/// The name of the substituion variable for the value.
|
||||
StringRef substVar;
|
||||
/// The format context to use for methods.
|
||||
tblgen::FmtContext nonStaticMethodFmt;
|
||||
tblgen::FmtContext traitMethodFmt;
|
||||
|
@ -121,11 +123,12 @@ struct AttrInterfaceGenerator : public InterfaceGenerator {
|
|||
valueType = "::mlir::Attribute";
|
||||
interfaceBaseType = "AttributeInterface";
|
||||
valueTemplate = "ConcreteAttr";
|
||||
substVar = "_attr";
|
||||
StringRef castCode = "(tablegen_opaque_val.cast<ConcreteAttr>())";
|
||||
nonStaticMethodFmt.addSubst("_attr", castCode).withSelf(castCode);
|
||||
traitMethodFmt.addSubst("_attr",
|
||||
nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
|
||||
traitMethodFmt.addSubst(substVar,
|
||||
"(*static_cast<const ConcreteAttr *>(this))");
|
||||
extraDeclsFmt.addSubst("_attr", "(*this)");
|
||||
extraDeclsFmt.addSubst(substVar, "(*this)");
|
||||
}
|
||||
};
|
||||
/// A specialized generator for operation interfaces.
|
||||
|
@ -135,12 +138,13 @@ struct OpInterfaceGenerator : public InterfaceGenerator {
|
|||
valueType = "::mlir::Operation *";
|
||||
interfaceBaseType = "OpInterface";
|
||||
valueTemplate = "ConcreteOp";
|
||||
substVar = "_op";
|
||||
StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))";
|
||||
nonStaticMethodFmt.addSubst("_this", "impl")
|
||||
.withOp(castCode)
|
||||
.addSubst(substVar, castCode)
|
||||
.withSelf(castCode);
|
||||
traitMethodFmt.withOp("(*static_cast<ConcreteOp *>(this))");
|
||||
extraDeclsFmt.withOp("(*this)");
|
||||
traitMethodFmt.addSubst(substVar, "(*static_cast<ConcreteOp *>(this))");
|
||||
extraDeclsFmt.addSubst(substVar, "(*this)");
|
||||
}
|
||||
};
|
||||
/// A specialized generator for type interfaces.
|
||||
|
@ -150,11 +154,12 @@ struct TypeInterfaceGenerator : public InterfaceGenerator {
|
|||
valueType = "::mlir::Type";
|
||||
interfaceBaseType = "TypeInterface";
|
||||
valueTemplate = "ConcreteType";
|
||||
substVar = "_type";
|
||||
StringRef castCode = "(tablegen_opaque_val.cast<ConcreteType>())";
|
||||
nonStaticMethodFmt.addSubst("_type", castCode).withSelf(castCode);
|
||||
traitMethodFmt.addSubst("_type",
|
||||
nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
|
||||
traitMethodFmt.addSubst(substVar,
|
||||
"(*static_cast<const ConcreteType *>(this))");
|
||||
extraDeclsFmt.addSubst("_type", "(*this)");
|
||||
extraDeclsFmt.addSubst(substVar, "(*this)");
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
@ -434,7 +439,7 @@ void InterfaceGenerator::emitTraitDecl(const Interface &interface,
|
|||
assert(isa<OpInterface>(interface) && "only OpInterface supports 'verify'");
|
||||
|
||||
tblgen::FmtContext verifyCtx;
|
||||
verifyCtx.withOp("op");
|
||||
verifyCtx.addSubst("_op", "op");
|
||||
os << llvm::formatv(
|
||||
" static ::mlir::LogicalResult {0}(::mlir::Operation *op) ",
|
||||
(interface.verifyWithRegions() ? "verifyRegionTrait"
|
||||
|
@ -506,6 +511,17 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
|
|||
interface.getExtraSharedClassDeclaration())
|
||||
os << tblgen::tgfmt(*extraDecls, &extraDeclsFmt);
|
||||
|
||||
// Emit classof code if necessary.
|
||||
if (std::optional<StringRef> extraClassOf = interface.getExtraClassOf()) {
|
||||
auto extraClassOfFmt = tblgen::FmtContext();
|
||||
extraClassOfFmt.addSubst(substVar, "base");
|
||||
os << " static bool classof(" << valueType << " base) {\n"
|
||||
<< " if (!getInterfaceFor(base))\n"
|
||||
" return false;\n"
|
||||
<< " " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt)
|
||||
<< "\n }\n";
|
||||
}
|
||||
|
||||
os << "};\n";
|
||||
|
||||
os << "namespace detail {\n";
|
||||
|
|
|
@ -105,12 +105,6 @@ TEST(FormatTest, PlaceHolderFmtStrWithBuilder) {
|
|||
EXPECT_THAT(result, StrEq("bbb"));
|
||||
}
|
||||
|
||||
TEST(FormatTest, PlaceHolderFmtStrWithOp) {
|
||||
FmtContext ctx;
|
||||
std::string result = std::string(tgfmt("$_op", &ctx.withOp("ooo")));
|
||||
EXPECT_THAT(result, StrEq("ooo"));
|
||||
}
|
||||
|
||||
TEST(FormatTest, PlaceHolderMissingCtx) {
|
||||
std::string result = std::string(tgfmt("$_op", nullptr));
|
||||
EXPECT_THAT(result, StrEq("$_op<no-subst-found>"));
|
||||
|
|
Loading…
Reference in New Issue
Block a user