Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ cc_library(
],
":linux-aarch64": [
"-fopenmp",
"-march=armv9-a+sve2+fp16",
"-march=armv9-a+sve2+fp16+bf16",
],
"//conditions:default": [],
}),
Expand Down
25 changes: 25 additions & 0 deletions include/fbgemm/FbgemmEmbedding.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,31 @@ void compressed_indices_remap_avx512(
float* out_weights);
#endif

// Specialization for uint8_t* input on aarch64 called by GenerateEmbeddingSpMDM
template <
typename IndexType,
typename OffsetType,
typename OutType,
bool NoBag,
bool EnablePrefetching>
FBGEMM_API bool EmbeddingSpMDM8Bit_Sve(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const uint8_t* input,
const IndexType* indices,
const OffsetType* offsets_or_lengths,
const float* weights, // optional, can be null for non-weighted sum
const bool normalize_by_lengths,
OutType* out,
const bool is_weight_positional,
const bool use_offsets,
const int64_t output_stride,
const int64_t input_stride,
const bool scale_bias_last,
const bool is_bf16_out);

} // namespace internal

template <typename IndexType>
Expand Down
2 changes: 2 additions & 0 deletions include/fbgemm/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#if defined(__aarch64__) && __ARM_FEATURE_SVE && \
__has_include(<arm_neon_sve_bridge.h>)
#define HAVE_SVE 1
#include <arm_neon_sve_bridge.h> // @manual
#include <arm_sve.h>
#else
#define HAVE_SVE 0
#endif
Expand Down
65 changes: 65 additions & 0 deletions src/EmbeddingSpMDM.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <tuple>
#include "./CodeCache.h" // @manual
#include "./EmbeddingSpMDMAutovec.h" // @manual
#include "./EmbeddingSpMDMSve.h"
#include "./MaskAvx2.h" // @manual
#include "./RefImplementations.h" // @manual
#include "fbgemm/FbgemmEmbedding.h"
Expand Down Expand Up @@ -1128,6 +1129,70 @@ typename EmbeddingSpMDMKernelSignature<inType, indxType, offsetType, outType>::
}
#endif // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64

#if HAVE_SVE
if constexpr (std::is_same<inType, uint8_t>::value) {
if (no_bag) {
return [=](int64_t output_size,
int64_t index_size,
int64_t data_size,
const uint8_t* input_u8,
const indxType* indices,
const offsetType* offsets_or_lengths,
const float*
weights, // optional, can be null for non-weighted sum
outType* out) {
return internal::
EmbeddingSpMDM8Bit_Sve<indxType, offsetType, outType, true, true>(
block_size,
output_size,
index_size,
data_size,
input_u8,
indices,
offsets_or_lengths,
weights,
normalize_by_lengths,
out,
is_weight_positional,
use_offsets,
output_stride,
input_stride,
scale_bias_last,
is_bf16_out);
};
} else {
return [=](int64_t output_size,
int64_t index_size,
int64_t data_size,
const uint8_t* input_u8,
const indxType* indices,
const offsetType* offsets_or_lengths,
const float* weights, // optional, can be null for
// non-weighted sum
outType* out) {
return internal::
EmbeddingSpMDM8Bit_Sve<indxType, offsetType, outType, false, true>(
block_size,
output_size,
index_size,
data_size,
input_u8,
indices,
offsets_or_lengths,
weights,
normalize_by_lengths,
out,
is_weight_positional,
use_offsets,
output_stride,
input_stride,
scale_bias_last,
is_bf16_out);
};
};
}
#endif

#ifdef FBGEMM_AUTOVEC_AVAILABLE
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
Expand Down
Loading
Loading