Skip to content

Commit 6371233

Browse files
Nicoshevfacebook-github-bot
authored andcommitted
Add aarch64-specific EmbeddingSpMDM8Bit
Summary: Add SVE variant of EmbeddingSpMDM8Bit Differential Revision: D88195927
1 parent 4e03118 commit 6371233

File tree

6 files changed

+741
-6
lines changed

6 files changed

+741
-6
lines changed

BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ cc_library(
152152
],
153153
":linux-aarch64": [
154154
"-fopenmp",
155-
"-march=armv9-a+sve2+fp16",
155+
"-march=armv9-a+sve2+fp16+bf16",
156156
],
157157
"//conditions:default": [],
158158
}),

include/fbgemm/FbgemmEmbedding.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,31 @@ void compressed_indices_remap_avx512(
362362
float* out_weights);
363363
#endif
364364

365+
// Specialization for uint8_t* input on aarch64 called by GenerateEmbeddingSpMDM
366+
template <
367+
typename IndexType,
368+
typename OffsetType,
369+
typename OutType,
370+
bool NoBag,
371+
bool EnablePrefetching>
372+
FBGEMM_API bool EmbeddingSpMDM8Bit_Sve(
373+
const int64_t block_size,
374+
const int64_t output_size,
375+
const int64_t index_size,
376+
const int64_t data_size,
377+
const uint8_t* input,
378+
const IndexType* indices,
379+
const OffsetType* offsets_or_lengths,
380+
const float* weights, // optional, can be null for non-weighted sum
381+
const bool normalize_by_lengths,
382+
OutType* out,
383+
const bool is_weight_positional,
384+
const bool use_offsets,
385+
const int64_t output_stride,
386+
const int64_t input_stride,
387+
const bool scale_bias_last,
388+
const bool is_bf16_out);
389+
365390
} // namespace internal
366391

367392
template <typename IndexType>

include/fbgemm/Utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
#if defined(__aarch64__) && __ARM_FEATURE_SVE && \
2525
__has_include(<arm_neon_sve_bridge.h>)
2626
#define HAVE_SVE 1
27+
#include <arm_neon_sve_bridge.h> // @manual
28+
#include <arm_sve.h>
2729
#else
2830
#define HAVE_SVE 0
2931
#endif

src/EmbeddingSpMDM.cc

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <tuple>
1919
#include "./CodeCache.h" // @manual
2020
#include "./EmbeddingSpMDMAutovec.h" // @manual
21+
#include "./EmbeddingSpMDMSve.h"
2122
#include "./MaskAvx2.h" // @manual
2223
#include "./RefImplementations.h" // @manual
2324
#include "fbgemm/FbgemmEmbedding.h"
@@ -1128,6 +1129,70 @@ typename EmbeddingSpMDMKernelSignature<inType, indxType, offsetType, outType>::
11281129
}
11291130
#endif // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
11301131

1132+
#if HAVE_SVE
1133+
if constexpr (std::is_same<inType, uint8_t>::value) {
1134+
if (no_bag) {
1135+
return [=](int64_t output_size,
1136+
int64_t index_size,
1137+
int64_t data_size,
1138+
const uint8_t* input_u8,
1139+
const indxType* indices,
1140+
const offsetType* offsets_or_lengths,
1141+
const float*
1142+
weights, // optional, can be null for non-weighted sum
1143+
outType* out) {
1144+
return internal::
1145+
EmbeddingSpMDM8Bit_Sve<indxType, offsetType, outType, true, true>(
1146+
block_size,
1147+
output_size,
1148+
index_size,
1149+
data_size,
1150+
input_u8,
1151+
indices,
1152+
offsets_or_lengths,
1153+
weights,
1154+
normalize_by_lengths,
1155+
out,
1156+
is_weight_positional,
1157+
use_offsets,
1158+
output_stride,
1159+
input_stride,
1160+
scale_bias_last,
1161+
is_bf16_out);
1162+
};
1163+
} else {
1164+
return [=](int64_t output_size,
1165+
int64_t index_size,
1166+
int64_t data_size,
1167+
const uint8_t* input_u8,
1168+
const indxType* indices,
1169+
const offsetType* offsets_or_lengths,
1170+
const float* weights, // optional, can be null for
1171+
// non-weighted sum
1172+
outType* out) {
1173+
return internal::
1174+
EmbeddingSpMDM8Bit_Sve<indxType, offsetType, outType, false, true>(
1175+
block_size,
1176+
output_size,
1177+
index_size,
1178+
data_size,
1179+
input_u8,
1180+
indices,
1181+
offsets_or_lengths,
1182+
weights,
1183+
normalize_by_lengths,
1184+
out,
1185+
is_weight_positional,
1186+
use_offsets,
1187+
output_stride,
1188+
input_stride,
1189+
scale_bias_last,
1190+
is_bf16_out);
1191+
};
1192+
};
1193+
}
1194+
#endif
1195+
11311196
#ifdef FBGEMM_AUTOVEC_AVAILABLE
11321197
if (!cpuinfo_initialize()) {
11331198
throw std::runtime_error("Failed to initialize cpuinfo!");

0 commit comments

Comments
 (0)