From 56fa6e4839f0df6cf01763e3e008c6d4970d8b0f Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Mon, 9 Feb 2026 06:53:40 -0800 Subject: [PATCH] Internal change plus add U8 type, check MatPtrT type at compile time PiperOrigin-RevId: 867582875 --- compression/types.h | 14 +++++++++----- util/mat.h | 2 ++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/compression/types.h b/compression/types.h index 6e6129d5..dc22f4ca 100644 --- a/compression/types.h +++ b/compression/types.h @@ -229,12 +229,14 @@ enum class Type { kU32, kU64, kI8, - kU16 + kU16, + kU8, }; // These are used in `ModelConfig.Specifier`, hence the strings will not // change, though new ones may be added. -static constexpr const char* kTypeStrings[] = { - "unknown", "f32", "bf16", "sfp", "nuq", "f64", "u32", "u64", "i8", "u16"}; +static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp", + "nuq", "f64", "u32", "u64", + "i8", "u16", "u8"}; static constexpr size_t kNumTypes = sizeof(kTypeStrings) / sizeof(kTypeStrings[0]); static constexpr size_t kTypeBits[] = { @@ -248,6 +250,7 @@ static constexpr size_t kTypeBits[] = { 8 * sizeof(uint64_t), 8 * sizeof(I8Stream), 8 * sizeof(uint16_t), + 8 * sizeof(uint8_t), }; static inline bool EnumValid(Type type) { @@ -256,7 +259,7 @@ static inline bool EnumValid(Type type) { // Returns a Type enum for the type of the template parameter. template -Type TypeEnum() { +constexpr Type TypeEnum() { using Packed = hwy::RemoveCvRef; if constexpr (hwy::IsSame()) { return Type::kF32; @@ -276,8 +279,9 @@ Type TypeEnum() { return Type::kI8; } else if constexpr (hwy::IsSame()) { return Type::kU16; + } else if constexpr (hwy::IsSame()) { + return Type::kU8; } else { - HWY_DASSERT(false); return Type::kUnknown; } } diff --git a/util/mat.h b/util/mat.h index 08300461..25f2cb2c 100644 --- a/util/mat.h +++ b/util/mat.h @@ -291,6 +291,8 @@ template class MatPtrT : public MatPtr { public: using T = MatT; + static_assert(TypeEnum() != Type::kUnknown, + "Must only use with supported MatT."); // Default constructor for use with uninitialized views. MatPtrT() = default;