[mlir] Move SymbolOpInterfaces "classof" check to a proper "extraClassOf" interface field

SymbolOpInterface overrides the base classof to provide support
for optionally implementing the interface. This is currently placed
in the extraClassDeclarations, but that is kind of awkard given that
it requires underlying knowledge of how the base classof is implemented.
This commit adds a proper "extraClassOf" field to allow interfaces to
implement this, which abstracts away the default classof logic.

Differential Revision: https://reviews.llvm.org/D140197
This commit is contained in:
River Riddle 2022-12-16 01:16:15 -08:00
parent 3e731af912
commit 5cdc2bbc75
12 changed files with 70 additions and 52 deletions

View File

@ -2048,6 +2048,13 @@ class Interface<string name> {
// An optional code block containing extra declarations to place in both
// the interface and trait declaration.
code extraSharedClassDeclaration = "";
// An optional code block for adding additional "classof" logic. This can
// be used to better enable "optional" interfaces, where an entity only
// implements the interface if some dynamic characteristic holds.
// `$_attr`/`$_op`/`$_type` may be used to refer to an instance of the
// entity being checked.
code extraClassOf = "";
}
// AttrInterface represents an interface registered to an attribute.

View File

@ -174,28 +174,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
return success();
}];
let extraClassDeclaration = [{
/// Convenience version of `getNameAttr` that returns a StringRef.
StringRef getName() {
return getNameAttr().getValue();
}
/// Convenience version of `setName` that take a StringRef.
void setName(StringRef name) {
setName(StringAttr::get(this->getContext(), name));
}
/// Custom classof that handles the case where the symbol is optional.
static bool classof(Operation *op) {
auto *opConcept = getInterfaceFor(op);
if (!opConcept)
return false;
return !opConcept->isOptionalSymbol(opConcept, op) ||
op->getAttr(::mlir::SymbolTable::getSymbolAttrName());
}
}];
let extraTraitClassDeclaration = [{
let extraSharedClassDeclaration = [{
using Visibility = mlir::SymbolTable::Visibility;
/// Convenience version of `getNameAttr` that returns a StringRef.
@ -208,6 +187,11 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
setName(StringAttr::get($_op->getContext(), name));
}
}];
// Add additional classof checks to properly handle "optional" symbols.
let extraClassOf = [{
return $_op->hasAttr(::mlir::SymbolTable::getSymbolAttrName());
}];
}
//===----------------------------------------------------------------------===//

View File

@ -110,6 +110,12 @@ public:
"expected value to provide interface instance");
}
/// Constructor for a known concept.
Interface(ValueT t, Concept *conceptImpl)
: BaseType(t), conceptImpl(conceptImpl) {
assert(!t || ConcreteType::getInterfaceFor(t) == conceptImpl);
}
/// Constructor for DenseMapInfo's empty key and tombstone key.
Interface(ValueT t, std::nullptr_t) : BaseType(t), conceptImpl(nullptr) {}

View File

@ -44,7 +44,6 @@ public:
None,
Custom, // For custom placeholders
Builder, // For the $_builder placeholder
Op, // For the $_op placeholder
Self, // For the $_self placeholder
};
@ -58,7 +57,6 @@ public:
// Setters for builtin placeholders
FmtContext &withBuilder(Twine subst);
FmtContext &withOp(Twine subst);
FmtContext &withSelf(Twine subst);
std::optional<StringRef> getSubstFor(PHKind placeholder) const;

View File

@ -95,6 +95,9 @@ public:
// trait classes.
std::optional<StringRef> getExtraSharedClassDeclaration() const;
// Return the extra classof method code.
std::optional<StringRef> getExtraClassOf() const;
// Return the verify method body if it has one.
std::optional<StringRef> getVerify() const;

View File

@ -190,7 +190,7 @@ void StaticVerifierFunctionEmitter::emitConstraints(
const ConstraintMap &constraints, StringRef selfName,
const char *const codeTemplate) {
FmtContext ctx;
ctx.withOp("*op").withSelf(selfName);
ctx.addSubst("_op", "*op").withSelf(selfName);
for (auto &it : constraints) {
os << formatv(codeTemplate, it.second,
tgfmt(it.first.getConditionTemplate(), &ctx),
@ -216,7 +216,7 @@ void StaticVerifierFunctionEmitter::emitRegionConstraints() {
void StaticVerifierFunctionEmitter::emitPatternConstraints() {
FmtContext ctx;
ctx.withOp("*op").withBuilder("rewriter").withSelf("type");
ctx.addSubst("_op", "*op").withBuilder("rewriter").withSelf("type");
for (auto &it : typeConstraints) {
os << formatv(patternAttrOrTypeConstraintCode, it.second,
tgfmt(it.first.getConditionTemplate(), &ctx),
@ -240,9 +240,9 @@ void StaticVerifierFunctionEmitter::emitPatternConstraints() {
/// because ops use cached identifiers.
static bool canUniqueAttrConstraint(Attribute attr) {
FmtContext ctx;
auto test =
tgfmt(attr.getConditionTemplate(), &ctx.withSelf("attr").withOp("*op"))
.str();
auto test = tgfmt(attr.getConditionTemplate(),
&ctx.withSelf("attr").addSubst("_op", "*op"))
.str();
return !StringRef(test).contains("<no-subst-found>");
}

View File

@ -38,11 +38,6 @@ FmtContext &FmtContext::withBuilder(Twine subst) {
return *this;
}
FmtContext &FmtContext::withOp(Twine subst) {
builtinSubstMap[PHKind::Op] = subst.str();
return *this;
}
FmtContext &FmtContext::withSelf(Twine subst) {
builtinSubstMap[PHKind::Self] = subst.str();
return *this;
@ -69,7 +64,6 @@ std::optional<StringRef> FmtContext::getSubstFor(StringRef placeholder) const {
FmtContext::PHKind FmtContext::getPlaceHolderKind(StringRef str) {
return StringSwitch<FmtContext::PHKind>(str)
.Case("_builder", FmtContext::PHKind::Builder)
.Case("_op", FmtContext::PHKind::Op)
.Case("_self", FmtContext::PHKind::Self)
.Case("", FmtContext::PHKind::None)
.Default(FmtContext::PHKind::Custom);

View File

@ -116,6 +116,11 @@ std::optional<StringRef> Interface::getExtraSharedClassDeclaration() const {
return value.empty() ? std::optional<StringRef>() : value;
}
std::optional<StringRef> Interface::getExtraClassOf() const {
auto value = def->getValueAsString("extraClassOf");
return value.empty() ? std::optional<StringRef>() : value;
}
// Return the body for this method if it has one.
std::optional<StringRef> Interface::getVerify() const {
// Only OpInterface supports the verify method.

View File

@ -4,6 +4,17 @@
include "mlir/IR/OpBase.td"
def ExtraClassOfInterface : OpInterface<"ExtraClassOfInterface"> {
let extraClassOf = "return $_op->someOtherMethod();";
}
// DECL: class ExtraClassOfInterface
// DECL: static bool classof(::mlir::Operation * base) {
// DECL-NEXT: if (!getInterfaceFor(base))
// DECL-NEXT: return false;
// DECL-NEXT: return base->someOtherMethod();
// DECL-NEXT: }
def ExtraShardDeclsInterface : OpInterface<"ExtraShardDeclsInterface"> {
let extraSharedClassDeclaration = [{
bool sharedMethodDeclaration() {

View File

@ -819,7 +819,7 @@ OpEmitter::OpEmitter(const Operator &op,
formatExtraDefinitions(op)),
staticVerifierEmitter(staticVerifierEmitter),
emitHelper(op, /*emitForOp=*/true) {
verifyCtx.withOp("(*this->getOperation())");
verifyCtx.addSubst("_op", "(*this->getOperation())");
verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");
genTraits();

View File

@ -108,6 +108,8 @@ protected:
StringRef interfaceBaseType;
/// The name of the typename for the value template.
StringRef valueTemplate;
/// The name of the substituion variable for the value.
StringRef substVar;
/// The format context to use for methods.
tblgen::FmtContext nonStaticMethodFmt;
tblgen::FmtContext traitMethodFmt;
@ -121,11 +123,12 @@ struct AttrInterfaceGenerator : public InterfaceGenerator {
valueType = "::mlir::Attribute";
interfaceBaseType = "AttributeInterface";
valueTemplate = "ConcreteAttr";
substVar = "_attr";
StringRef castCode = "(tablegen_opaque_val.cast<ConcreteAttr>())";
nonStaticMethodFmt.addSubst("_attr", castCode).withSelf(castCode);
traitMethodFmt.addSubst("_attr",
nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
traitMethodFmt.addSubst(substVar,
"(*static_cast<const ConcreteAttr *>(this))");
extraDeclsFmt.addSubst("_attr", "(*this)");
extraDeclsFmt.addSubst(substVar, "(*this)");
}
};
/// A specialized generator for operation interfaces.
@ -135,12 +138,13 @@ struct OpInterfaceGenerator : public InterfaceGenerator {
valueType = "::mlir::Operation *";
interfaceBaseType = "OpInterface";
valueTemplate = "ConcreteOp";
substVar = "_op";
StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))";
nonStaticMethodFmt.addSubst("_this", "impl")
.withOp(castCode)
.addSubst(substVar, castCode)
.withSelf(castCode);
traitMethodFmt.withOp("(*static_cast<ConcreteOp *>(this))");
extraDeclsFmt.withOp("(*this)");
traitMethodFmt.addSubst(substVar, "(*static_cast<ConcreteOp *>(this))");
extraDeclsFmt.addSubst(substVar, "(*this)");
}
};
/// A specialized generator for type interfaces.
@ -150,11 +154,12 @@ struct TypeInterfaceGenerator : public InterfaceGenerator {
valueType = "::mlir::Type";
interfaceBaseType = "TypeInterface";
valueTemplate = "ConcreteType";
substVar = "_type";
StringRef castCode = "(tablegen_opaque_val.cast<ConcreteType>())";
nonStaticMethodFmt.addSubst("_type", castCode).withSelf(castCode);
traitMethodFmt.addSubst("_type",
nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
traitMethodFmt.addSubst(substVar,
"(*static_cast<const ConcreteType *>(this))");
extraDeclsFmt.addSubst("_type", "(*this)");
extraDeclsFmt.addSubst(substVar, "(*this)");
}
};
} // namespace
@ -434,7 +439,7 @@ void InterfaceGenerator::emitTraitDecl(const Interface &interface,
assert(isa<OpInterface>(interface) && "only OpInterface supports 'verify'");
tblgen::FmtContext verifyCtx;
verifyCtx.withOp("op");
verifyCtx.addSubst("_op", "op");
os << llvm::formatv(
" static ::mlir::LogicalResult {0}(::mlir::Operation *op) ",
(interface.verifyWithRegions() ? "verifyRegionTrait"
@ -506,6 +511,17 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
interface.getExtraSharedClassDeclaration())
os << tblgen::tgfmt(*extraDecls, &extraDeclsFmt);
// Emit classof code if necessary.
if (std::optional<StringRef> extraClassOf = interface.getExtraClassOf()) {
auto extraClassOfFmt = tblgen::FmtContext();
extraClassOfFmt.addSubst(substVar, "base");
os << " static bool classof(" << valueType << " base) {\n"
<< " if (!getInterfaceFor(base))\n"
" return false;\n"
<< " " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt)
<< "\n }\n";
}
os << "};\n";
os << "namespace detail {\n";

View File

@ -105,12 +105,6 @@ TEST(FormatTest, PlaceHolderFmtStrWithBuilder) {
EXPECT_THAT(result, StrEq("bbb"));
}
TEST(FormatTest, PlaceHolderFmtStrWithOp) {
FmtContext ctx;
std::string result = std::string(tgfmt("$_op", &ctx.withOp("ooo")));
EXPECT_THAT(result, StrEq("ooo"));
}
TEST(FormatTest, PlaceHolderMissingCtx) {
std::string result = std::string(tgfmt("$_op", nullptr));
EXPECT_THAT(result, StrEq("$_op<no-subst-found>"));