Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1183,6 +1183,7 @@ cc_library(
":chlo_ops",
":chlo_rewriters_inc_gen",
":stablehlo_aggressive_simplification_inc_gen",
":stablehlo_broadcast_lowering",
":stablehlo_create_compatibility_expander_inc_gen",
":stablehlo_create_complex_math_expander_inc_gen",
":stablehlo_legalize_deprecated_ops_inc_gen",
Expand Down Expand Up @@ -1922,6 +1923,24 @@ cc_library(
],
)

cc_test(
name = "chlo_builder_test",
srcs = ["stablehlo/integrations/cpp/builder/ChloBuilderTest.cpp"],
deps = [
":attr_type_builder_util",
":chlo_builder",
":func_builder",
":mlir_builder",
":register",
":stablehlo_builder",
":stablehlo_ops",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//third-party/unittest:gmock",
"@llvm-project//third-party/unittest:gtest",
],
)

gentbl_cc_library(
name = "func_builder_inc",
tbl_outs = {
Expand Down
9 changes: 9 additions & 0 deletions stablehlo/dialect/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
Expand Down Expand Up @@ -781,6 +782,14 @@ bool isValidQuantizedDimension(Type type) {
numScales == rankedType.getDimSize(quantDim));
}

bool isBoundedDynamic(Type type) {
RankedTensorType rankedType = dyn_cast<RankedTensorType>(type);
if (!rankedType) return false;
auto boundedAttr =
mlir::dyn_cast_if_present<BoundedAttrInterface>(rankedType.getEncoding());
return boundedAttr != nullptr;
}

bool hasSingleBoundedDimension(Type type) {
RankedTensorType rankedType = dyn_cast<RankedTensorType>(type);
auto boundedAttr =
Expand Down
51 changes: 27 additions & 24 deletions stablehlo/dialect/Base.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ bool isValidStablehloQuantizedElementType(Type elementType);
// mentioned in the StableHLO specification.
bool isValidQuantizedDimension(Type type);

// Returns true if the given type is a bounded dynamic tensor.
bool isBoundedDynamic(Type type);

// Returns true if the given type has a single bounded dimension.
bool hasSingleBoundedDimension(Type type);

Expand Down Expand Up @@ -135,19 +138,19 @@ FailureOr<Type> inferMostSpecificType(std::optional<Location> location,

LogicalResult inferMostSpecificTypeComponents(
std::optional<Location> location, TypeRange inputTypes,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes);
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);

// Matches a constant with integer value into int64_t.
LogicalResult matchInt(Value value, int64_t &result);
LogicalResult matchInt(Value value, int64_t& result);

// Matches a constant tensor with integer values into a 1-dimensional vector.
// Doesn't preserve the bitness or the signedness of the underlying values,
// extracting them into int64_t.
LogicalResult matchInts(Value value, SmallVector<int64_t> &result);
LogicalResult matchInts(Value value, SmallVector<int64_t>& result);

// Matches a constant tensor with integer values into a 1-dimensional vector.
// Preserves the bitness and the signedness of the underlying values.
LogicalResult matchInts(Value value, SmallVector<APSInt> &result);
LogicalResult matchInts(Value value, SmallVector<APSInt>& result);

// Matches a constant tensor with integer values.
// Unlike the functions above, it doesn't return these values - it just checks
Expand All @@ -166,8 +169,8 @@ LogicalResult matchInts(Value value);
//
// and returns %4 as the shape value.
LogicalResult deriveShapeFromOperand(
OpBuilder *builder, Operation *op, Value operand,
SmallVectorImpl<Value> *reifiedReturnShapes);
OpBuilder* builder, Operation* op, Value operand,
SmallVectorImpl<Value>* reifiedReturnShapes);

// Type derivation function that returns a tensor type with a new element type.
ShapedType getSameShapeTensorType(ShapedType shapedType, Type elementType);
Expand Down Expand Up @@ -199,23 +202,23 @@ Attribute boundsToEncoding(Attribute prototype, ArrayRef<int64_t> bounds);
// If the attribute is valid but not all shape operands are constants,
// returns failure.
LogicalResult getShapeRefinements(
std::optional<Location> location, Operation *operation,
SmallVector<ShapedTypeComponents> &refinements);
std::optional<Location> location, Operation* operation,
SmallVector<ShapedTypeComponents>& refinements);

// For each type in `types`, recursively flatten tuple types into `result`.
// Result is populated via in-order traversal of tuple types in `types`, i.e.:
// * Flattenings of individual types from `types` follow one another in the
// same order as `types`.
// * Same for flattenings of element types of tuple types.
void flattenTupleTypes(TypeRange types, SmallVector<Type> &result);
void flattenTupleTypes(TypeRange types, SmallVector<Type>& result);

// Does the inverse of `flattenTupleTypes` - takes `types` and recursively
// unflattens it, creating tuple types as needed to exactly match the structure
// of `prototype`.
// Fails if the number of elements in flattened prototype is different from
// the number of elements in types.
LogicalResult unflattenTupleTypes(TypeRange prototype, TypeRange types,
SmallVector<Type> &result);
SmallVector<Type>& result);

ShapedType createShapedType(ShapedTypeComponents components);

Expand All @@ -224,7 +227,7 @@ ShapedType createShapedType(ShapedTypeComponents components);
// prettyprinting logic between them.
class HloDialectInterface : public DialectInterface::Base<HloDialectInterface> {
public:
HloDialectInterface(Dialect *dialect) : Base(dialect) {}
HloDialectInterface(Dialect* dialect) : Base(dialect) {}

// Creates a TokenType type, specific to this dialect.
// See docs for the particular type in the corresponding dialect.
Expand Down Expand Up @@ -283,8 +286,8 @@ namespace bytecode {
// Note this may cause issues if enums use an int64_t and have a large value.
// All enums in StableHLO and CHLO currently use uint32_t.
template <typename EnumTypeAttr, typename SymbolizeFn>
EnumTypeAttr readEnumAttribute(DialectBytecodeReader &reader,
MLIRContext *context, SymbolizeFn symbolizeFn) {
EnumTypeAttr readEnumAttribute(DialectBytecodeReader& reader,
MLIRContext* context, SymbolizeFn symbolizeFn) {
uint64_t code;
if (failed(reader.readVarInt(code))) return EnumTypeAttr();

Expand All @@ -295,7 +298,7 @@ EnumTypeAttr readEnumAttribute(DialectBytecodeReader &reader,
}

template <typename EnumType, typename EnumTypeAttr>
void writeEnumAttribute(EnumTypeAttr val, DialectBytecodeWriter &writer) {
void writeEnumAttribute(EnumTypeAttr val, DialectBytecodeWriter& writer) {
static_assert(
std::is_same<typename std::underlying_type<EnumType>::type,
uint32_t>::value,
Expand All @@ -311,7 +314,7 @@ void writeEnumAttribute(EnumTypeAttr val, DialectBytecodeWriter &writer) {
// shape operands. The last `count` operands are assumed to be shape operands.
// To be speculatable, such an op must have only static inputs and constant
// shape operands.
mlir::Speculation::Speculatability getShapedSpeculatability(Operation *op,
mlir::Speculation::Speculatability getShapedSpeculatability(Operation* op,
int64_t shapeCount);

// Applies `fn` to `type` if it is not a `tuple` type. Otherwise, applies `fn`
Expand All @@ -334,7 +337,7 @@ class PairwiseSameOperandAndResultType
: public mlir::OpTrait::TraitBase<ConcreteType,
PairwiseSameOperandAndResultType> {
public:
static LogicalResult verifyTrait(Operation *op) {
static LogicalResult verifyTrait(Operation* op) {
const int numOperands = op->getNumOperands();
const int numResults = op->getNumResults();
if (numOperands != numResults) {
Expand All @@ -358,7 +361,7 @@ class PairwiseSameOperandAndResultElementType
: public mlir::OpTrait::TraitBase<ConcreteType,
PairwiseSameOperandAndResultElementType> {
public:
static LogicalResult verifyTrait(Operation *op) {
static LogicalResult verifyTrait(Operation* op) {
const int numOperands = op->getNumOperands();
const int numResults = op->getNumResults();
if (numOperands != numResults) {
Expand All @@ -383,7 +386,7 @@ class CompatibleOperandsAndResultElementType
: public mlir::OpTrait::TraitBase<ConcreteType,
CompatibleOperandsAndResultElementType> {
public:
static LogicalResult verifyTrait(Operation *op) {
static LogicalResult verifyTrait(Operation* op) {
Type expected;
if (op->getNumResults() != 0) expected = op->getResult(0).getType();
if (op->getNumOperands() != 0) expected = op->getOperand(0).getType();
Expand All @@ -408,7 +411,7 @@ class CompatibleOperandsElementType
: public mlir::OpTrait::TraitBase<ConcreteType,
CompatibleOperandsElementType> {
public:
static LogicalResult verifyTrait(Operation *op) {
static LogicalResult verifyTrait(Operation* op) {
if (failed(mlir::OpTrait::impl::verifyAtLeastNOperands(op, 1)))
return failure();

Expand All @@ -431,7 +434,7 @@ class CompatibleOperandsAndResultType
: public mlir::OpTrait::TraitBase<ConcreteType,
CompatibleOperandsAndResultType> {
public:
static LogicalResult verifyTrait(Operation *op) {
static LogicalResult verifyTrait(Operation* op) {
Type expected;
if (op->getNumResults() != 0) expected = op->getResult(0).getType();
if (op->getNumOperands() != 0) expected = op->getOperand(0).getType();
Expand All @@ -451,10 +454,10 @@ class CompatibleOperandsAndResultType
}

static LogicalResult inferReturnTypes(
MLIRContext * /*context*/, std::optional<Location> location,
MLIRContext* /*context*/, std::optional<Location> location,
ValueRange operands, DictionaryAttr /*attributes*/,
OpaqueProperties /*properties*/, RegionRange /*regions*/,
SmallVectorImpl<Type> &inferredReturnTypes) {
SmallVectorImpl<Type>& inferredReturnTypes) {
// TODO(b/231358795): Review the use of InferTypeOpInterface for ops that
// support quantization or sparsity.
if (operands.empty())
Expand All @@ -473,10 +476,10 @@ class CompatibleOperandsAndResultType
// It needs to be paired with INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS
// (see examples in StablehloOps.cpp).
static LogicalResult inferReturnTypeComponentsFromOperands(
MLIRContext *context, std::optional<Location> location,
MLIRContext* context, std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
SmallVector<Type> inferredReturnTypes;
if (failed(inferReturnTypes(context, location, operands.getValues(),
attributes, properties, regions,
Expand Down
9 changes: 6 additions & 3 deletions stablehlo/dialect/ChloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,11 +365,14 @@ LogicalResult ConstantLikeOp::inferReturnTypeComponents(
Type elementType = op.getValue().getType();
Type operandType = op.getOperand().getType();
if (isa<UnrankedTensorType>(operandType)) {
// TODO(b/326463552): Remove unranked dynamism from CHLO.
inferredReturnShapes.emplace_back(elementType);
} else {
const auto& shape = cast<RankedTensorType>(operandType).getShape();
inferredReturnShapes.emplace_back(shape, elementType);
return success();
}
auto rankedType = cast<RankedTensorType>(operandType);
const auto& shape = rankedType.getShape();
Attribute encoding = rankedType.getEncoding();
inferredReturnShapes.emplace_back(shape, elementType, encoding);
return success();
}

Expand Down
1 change: 1 addition & 0 deletions stablehlo/integrations/cpp/builder/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ if (TARGET llvm_gtest)
set_target_properties(check-stablehlo-ci PROPERTIES FOLDER "Tests")
add_unittest(check-stablehlo-ci "unittests"
MlirBuilderTest.cpp
ChloBuilderTest.cpp
StablehloBuilderTest.cpp
AttrTypeBuilderUtilTest.cpp
)
Expand Down
10 changes: 10 additions & 0 deletions stablehlo/integrations/cpp/builder/ChloBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,15 @@ namespace chlo {

#include "stablehlo/integrations/cpp/builder/ChloBuilder.cpp.inc"

/////////////////
// MANUAL APIs
/////////////////

MlirOp ConstantLike(MlirOp input, DenseElementsAttr val) {
MlirBuilder& builder = input.getBuilder();
auto splat_val = val.getSplatValue<TypedAttr>();
return builder.create<chlo::ConstantLikeOp>(splat_val, input.getValue());
}

} // namespace chlo
} // namespace mlir
7 changes: 7 additions & 0 deletions stablehlo/integrations/cpp/builder/ChloBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <cstdint>

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/integrations/cpp/builder/MlirBuilder.h"

Expand All @@ -31,6 +32,12 @@ namespace chlo {

#include "stablehlo/integrations/cpp/builder/ChloBuilder.h.inc"

/////////////////
// MANUAL APIs
/////////////////

MlirOp ConstantLike(MlirOp input, DenseElementsAttr val);

} // namespace chlo
} // namespace mlir

Expand Down
Loading
Loading