Skip to content

Commit c428422

Browse files
committed
clang-format
1 parent fdf40a9 commit c428422

File tree

2 files changed

+27
-28
lines changed

2 files changed

+27
-28
lines changed

stablehlo/dialect/Base.h

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -138,19 +138,19 @@ FailureOr<Type> inferMostSpecificType(std::optional<Location> location,
138138

139139
LogicalResult inferMostSpecificTypeComponents(
140140
std::optional<Location> location, TypeRange inputTypes,
141-
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes);
141+
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);
142142

143143
// Matches a constant with integer value into int64_t.
144-
LogicalResult matchInt(Value value, int64_t &result);
144+
LogicalResult matchInt(Value value, int64_t& result);
145145

146146
// Matches a constant tensor with integer values into a 1-dimensional vector.
147147
// Doesn't preserve the bitness or the signedness of the underlying values,
148148
// extracting them into int64_t.
149-
LogicalResult matchInts(Value value, SmallVector<int64_t> &result);
149+
LogicalResult matchInts(Value value, SmallVector<int64_t>& result);
150150

151151
// Matches a constant tensor with integer values into a 1-dimensional vector.
152152
// Preserves the bitness and the signedness of the underlying values.
153-
LogicalResult matchInts(Value value, SmallVector<APSInt> &result);
153+
LogicalResult matchInts(Value value, SmallVector<APSInt>& result);
154154

155155
// Matches a constant tensor with integer values.
156156
// Unlike the functions above, it doesn't return these values - it just checks
@@ -169,8 +169,8 @@ LogicalResult matchInts(Value value);
169169
//
170170
// and returns %4 as the shape value.
171171
LogicalResult deriveShapeFromOperand(
172-
OpBuilder *builder, Operation *op, Value operand,
173-
SmallVectorImpl<Value> *reifiedReturnShapes);
172+
OpBuilder* builder, Operation* op, Value operand,
173+
SmallVectorImpl<Value>* reifiedReturnShapes);
174174

175175
// Type derivation function that returns a tensor type with a new element type.
176176
ShapedType getSameShapeTensorType(ShapedType shapedType, Type elementType);
@@ -202,23 +202,23 @@ Attribute boundsToEncoding(Attribute prototype, ArrayRef<int64_t> bounds);
202202
// If the attribute is valid but not all shape operands are constants,
203203
// returns failure.
204204
LogicalResult getShapeRefinements(
205-
std::optional<Location> location, Operation *operation,
206-
SmallVector<ShapedTypeComponents> &refinements);
205+
std::optional<Location> location, Operation* operation,
206+
SmallVector<ShapedTypeComponents>& refinements);
207207

208208
// For each type in `types`, recursively flatten tuple types into `result`.
209209
// Result is populated via in-order traversal of tuple types in `types`, i.e.:
210210
// * Flattenings of individual types from `types` follow one another in the
211211
// same order as `types`.
212212
// * Same for flattenings of element types of tuple types.
213-
void flattenTupleTypes(TypeRange types, SmallVector<Type> &result);
213+
void flattenTupleTypes(TypeRange types, SmallVector<Type>& result);
214214

215215
// Does the inverse of `flattenTupleTypes` - takes `types` and recursively
216216
// unflattens it, creating tuple types as needed to exactly match the structure
217217
// of `prototype`.
218218
// Fails if the number of elements in flattened prototype is different from
219219
// the number of elements in types.
220220
LogicalResult unflattenTupleTypes(TypeRange prototype, TypeRange types,
221-
SmallVector<Type> &result);
221+
SmallVector<Type>& result);
222222

223223
ShapedType createShapedType(ShapedTypeComponents components);
224224

@@ -227,7 +227,7 @@ ShapedType createShapedType(ShapedTypeComponents components);
227227
// prettyprinting logic between them.
228228
class HloDialectInterface : public DialectInterface::Base<HloDialectInterface> {
229229
public:
230-
HloDialectInterface(Dialect *dialect) : Base(dialect) {}
230+
HloDialectInterface(Dialect* dialect) : Base(dialect) {}
231231

232232
// Creates a TokenType type, specific to this dialect.
233233
// See docs for the particular type in the corresponding dialect.
@@ -286,8 +286,8 @@ namespace bytecode {
286286
// Note this may cause issues if enums use an int64_t and have a large value.
287287
// All enums in StableHLO and CHLO currently use uint32_t.
288288
template <typename EnumTypeAttr, typename SymbolizeFn>
289-
EnumTypeAttr readEnumAttribute(DialectBytecodeReader &reader,
290-
MLIRContext *context, SymbolizeFn symbolizeFn) {
289+
EnumTypeAttr readEnumAttribute(DialectBytecodeReader& reader,
290+
MLIRContext* context, SymbolizeFn symbolizeFn) {
291291
uint64_t code;
292292
if (failed(reader.readVarInt(code))) return EnumTypeAttr();
293293

@@ -298,7 +298,7 @@ EnumTypeAttr readEnumAttribute(DialectBytecodeReader &reader,
298298
}
299299

300300
template <typename EnumType, typename EnumTypeAttr>
301-
void writeEnumAttribute(EnumTypeAttr val, DialectBytecodeWriter &writer) {
301+
void writeEnumAttribute(EnumTypeAttr val, DialectBytecodeWriter& writer) {
302302
static_assert(
303303
std::is_same<typename std::underlying_type<EnumType>::type,
304304
uint32_t>::value,
@@ -314,7 +314,7 @@ void writeEnumAttribute(EnumTypeAttr val, DialectBytecodeWriter &writer) {
314314
// shape operands. The last `count` operands are assumed to be shape operands.
315315
// To be speculatable, such an op must have only static inputs and constant
316316
// shape operands.
317-
mlir::Speculation::Speculatability getShapedSpeculatability(Operation *op,
317+
mlir::Speculation::Speculatability getShapedSpeculatability(Operation* op,
318318
int64_t shapeCount);
319319

320320
// Applies `fn` to `type` if it is not a `tuple` type. Otherwise, applies `fn`
@@ -337,7 +337,7 @@ class PairwiseSameOperandAndResultType
337337
: public mlir::OpTrait::TraitBase<ConcreteType,
338338
PairwiseSameOperandAndResultType> {
339339
public:
340-
static LogicalResult verifyTrait(Operation *op) {
340+
static LogicalResult verifyTrait(Operation* op) {
341341
const int numOperands = op->getNumOperands();
342342
const int numResults = op->getNumResults();
343343
if (numOperands != numResults) {
@@ -361,7 +361,7 @@ class PairwiseSameOperandAndResultElementType
361361
: public mlir::OpTrait::TraitBase<ConcreteType,
362362
PairwiseSameOperandAndResultElementType> {
363363
public:
364-
static LogicalResult verifyTrait(Operation *op) {
364+
static LogicalResult verifyTrait(Operation* op) {
365365
const int numOperands = op->getNumOperands();
366366
const int numResults = op->getNumResults();
367367
if (numOperands != numResults) {
@@ -386,7 +386,7 @@ class CompatibleOperandsAndResultElementType
386386
: public mlir::OpTrait::TraitBase<ConcreteType,
387387
CompatibleOperandsAndResultElementType> {
388388
public:
389-
static LogicalResult verifyTrait(Operation *op) {
389+
static LogicalResult verifyTrait(Operation* op) {
390390
Type expected;
391391
if (op->getNumResults() != 0) expected = op->getResult(0).getType();
392392
if (op->getNumOperands() != 0) expected = op->getOperand(0).getType();
@@ -411,7 +411,7 @@ class CompatibleOperandsElementType
411411
: public mlir::OpTrait::TraitBase<ConcreteType,
412412
CompatibleOperandsElementType> {
413413
public:
414-
static LogicalResult verifyTrait(Operation *op) {
414+
static LogicalResult verifyTrait(Operation* op) {
415415
if (failed(mlir::OpTrait::impl::verifyAtLeastNOperands(op, 1)))
416416
return failure();
417417

@@ -434,7 +434,7 @@ class CompatibleOperandsAndResultType
434434
: public mlir::OpTrait::TraitBase<ConcreteType,
435435
CompatibleOperandsAndResultType> {
436436
public:
437-
static LogicalResult verifyTrait(Operation *op) {
437+
static LogicalResult verifyTrait(Operation* op) {
438438
Type expected;
439439
if (op->getNumResults() != 0) expected = op->getResult(0).getType();
440440
if (op->getNumOperands() != 0) expected = op->getOperand(0).getType();
@@ -454,10 +454,10 @@ class CompatibleOperandsAndResultType
454454
}
455455

456456
static LogicalResult inferReturnTypes(
457-
MLIRContext * /*context*/, std::optional<Location> location,
457+
MLIRContext* /*context*/, std::optional<Location> location,
458458
ValueRange operands, DictionaryAttr /*attributes*/,
459459
OpaqueProperties /*properties*/, RegionRange /*regions*/,
460-
SmallVectorImpl<Type> &inferredReturnTypes) {
460+
SmallVectorImpl<Type>& inferredReturnTypes) {
461461
// TODO(b/231358795): Review the use of InferTypeOpInterface for ops that
462462
// support quantization or sparsity.
463463
if (operands.empty())
@@ -476,10 +476,10 @@ class CompatibleOperandsAndResultType
476476
// It needs to be paired with INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS
477477
// (see examples in StablehloOps.cpp).
478478
static LogicalResult inferReturnTypeComponentsFromOperands(
479-
MLIRContext *context, std::optional<Location> location,
479+
MLIRContext* context, std::optional<Location> location,
480480
ValueShapeRange operands, DictionaryAttr attributes,
481481
OpaqueProperties properties, RegionRange regions,
482-
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
482+
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
483483
SmallVector<Type> inferredReturnTypes;
484484
if (failed(inferReturnTypes(context, location, operands.getValues(),
485485
attributes, properties, regions,

stablehlo/transforms/StablehloBroadcastLowering.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ mlir::RankedTensorType getRankedTensorType(const Dimensions& dims,
157157
return mlir::RankedTensorType::get(shape, element_type, encoding);
158158
}
159159

160-
161160
FailureOr<Dimensions> getNumpyBroadcastShape(OpBuilder& builder,
162161
ArrayRef<Value> ops) {
163162
if (ops.empty())
@@ -240,8 +239,8 @@ FailureOr<Value> numpyBroadcastIfNeeded(OpBuilder& builder, Value input,
240239
Dimensions inputShape = std::move(*inputShapeOrFail);
241240

242241
// Construct broadcast dimensions.
243-
auto broadcastDimensions = llvm::to_vector(
244-
llvm::seq<int64_t>(outputRank - inputRank, outputRank));
242+
auto broadcastDimensions =
243+
llvm::to_vector(llvm::seq<int64_t>(outputRank - inputRank, outputRank));
245244

246245
// Construct the result type of the broadcast
247246
// - If input is static and target shape is static, use static shape.
@@ -288,7 +287,7 @@ FailureOr<Value> numpyBroadcastIfNeeded(OpBuilder& builder, Value input,
288287
auto dimSize = stablehlo::GetDimensionSizeOp::create(
289288
builder, loc, boundOp, shape[i].boundOpDim);
290289
bcastOp = stablehlo::SetDimensionSizeOp::create(builder, loc, bcastOp,
291-
dimSize, i);
290+
dimSize, i);
292291
}
293292
}
294293
return bcastOp;

0 commit comments

Comments (0)