Skip to content
12 changes: 1 addition & 11 deletions mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,11 @@ def BitVectorAttr : AttrDef<SMTDialect, "BitVector", [
present).
}];

let parameters = (ins "llvm::APInt":$value);
let parameters = (ins APIntParameter<"">:$value);

let hasCustomAssemblyFormat = true;
let genVerifyDecl = true;

// We need to manually define the storage class because the generated one is
// buggy (because the APInt asserts matching bitwidth in the `==` operator and
// the generated storage uses that directly.
// Alternatively: add a type parameter to redundantly store the bitwidth of
// of the attribute type, it it's in the order before the 'value' it will be
// checked before the APInt equality (this is the reason it works for the
// builtin integer attribute), but would be more fragile (and we'd store
// duplicate data).
let genStorageClass = false;

let builders = [
AttrBuilder<(ins "llvm::StringRef":$value)>,
AttrBuilder<(ins "uint64_t":$value, "unsigned":$width)>,
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/BuiltinAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ def Builtin_IntegerAttr : Builtin_Attr<"Integer", "integer",
false // A bool, i.e. i1, value.
```
}];
let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APInt":$value);
let parameters = (ins AttributeSelfTypeParameter<"">:$type, APIntParameter<"">:$value);
let builders = [
AttrBuilderWithInferredContext<(ins "Type":$type,
"const APInt &":$value), [{
Expand Down
6 changes: 5 additions & 1 deletion mlir/include/mlir/TableGen/AttrOrTypeDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ class AttrOrTypeParameter {
/// If specified, get the custom allocator code for this parameter.
std::optional<StringRef> getAllocator() const;

/// If specified, get the custom comparator code for this parameter.
/// Return true if user defined comparator is specified.
bool hasCustomComparator() const;

/// Get the custom comparator code for this parameter or fallback to the
/// default.
StringRef getComparator() const;

/// Get the C++ type of this parameter.
Expand Down
36 changes: 0 additions & 36 deletions mlir/lib/Dialect/SMT/IR/SMTAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,42 +21,6 @@ using namespace mlir::smt;
// BitVectorAttr
//===----------------------------------------------------------------------===//

namespace mlir {
namespace smt {
namespace detail {
struct BitVectorAttrStorage : public mlir::AttributeStorage {
using KeyTy = APInt;
BitVectorAttrStorage(APInt value) : value(std::move(value)) {}

KeyTy getAsKey() const { return value; }

// NOTE: the implementation of this operator is the reason we need to define
// the storage manually. The auto-generated version would just do the direct
// equality check of the APInt, but that asserts the bitwidth of both to be
// the same, leading to a crash. This implementation, therefore, checks for
// matching bit-width beforehand.
bool operator==(const KeyTy &key) const {
return (value.getBitWidth() == key.getBitWidth() && value == key);
}

static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
}

static BitVectorAttrStorage *
construct(mlir::AttributeStorageAllocator &allocator, KeyTy &&key) {
return new (allocator.allocate<BitVectorAttrStorage>())
BitVectorAttrStorage(std::move(key));
}

APInt value;
};
} // namespace detail
} // namespace smt
} // namespace mlir

APInt BitVectorAttr::getValue() const { return getImpl()->value; }

LogicalResult BitVectorAttr::verify(
function_ref<InFlightDiagnostic()> emitError,
APInt value) { // NOLINT(performance-unnecessary-value-param)
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/TableGen/AttrOrTypeDef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,10 @@ std::optional<StringRef> AttrOrTypeParameter::getAllocator() const {
return getDefValue<StringInit>("allocator");
}

bool AttrOrTypeParameter::hasCustomComparator() const {
return getDefValue<StringInit>("comparator").has_value();
}

StringRef AttrOrTypeParameter::getComparator() const {
return getDefValue<StringInit>("comparator").value_or("$_lhs == $_rhs");
}
Expand Down
17 changes: 17 additions & 0 deletions mlir/test/mlir-tblgen/apint-param-warn.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: mlir-tblgen -gen-attrdef-decls -I %S/../../include %s 2>&1 | FileCheck %s

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

def Test_Dialect: Dialect {
let name = "TestDialect";
let cppNamespace = "::test";
}

def RawAPIntAttr : AttrDef<Test_Dialect, "RawAPInt"> {
let mnemonic = "raw_ap_int";
let parameters = (ins "APInt":$value);
let hasCustomAssemblyFormat = 1;
}

// CHECK: apint-param-warn.td:11:5: warning: Using a raw APInt parameter
12 changes: 11 additions & 1 deletion mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,8 +678,18 @@ void DefGen::emitStorageClass() {
emitConstruct();
// Emit the storage class members as public, at the very end of the struct.
storageCls->finalize();
for (auto &param : params)
for (auto &param : params) {
if (param.getCppType().contains("APInt") && !param.hasCustomComparator()) {
PrintWarning(
def.getLoc(),
"Using a raw APInt parameter without a custom comparator is "
"not supported because an assert in the equality operator is "
"triggered when the two APInts have different bit widths. This can "
"lead to unexpected crashes. Use an `APIntParameter` or "
"provide a custom comparator.");
}
storageCls->declare<Field>(param.getCppType(), param.getName());
}
}

//===----------------------------------------------------------------------===//
Expand Down
Loading