From c3aa526ec7410d7ad2c8b243439ecbbc8c8c76c5 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Thu, 11 Jan 2024 13:53:27 -0800 Subject: [PATCH] Add SIMD NEON Optimization for QT_FP16 in Scalar Quantizer (#3166) Summary: ### Description We have SIMD support for x86 architecture using AVX2 optimization. But, we don't have a similar optimization for ARM architecture in Scalar Quantizer for the Quantization Type `QT_FP16`. This PR adds SIMD support for ARM using NEON optimization. ### Issues resolved https://github.com/facebookresearch/faiss/issues/3014 Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3166 Reviewed By: algoriddle, pemazare Differential Revision: D52510486 Pulled By: mdouze fbshipit-source-id: 2bb360475a0d9e0816c8e05b44d7da7f2e6b28c5 --- faiss/CMakeLists.txt | 1 + faiss/impl/ScalarQuantizer.cpp | 333 ++++++++++++++++++++++++++++++++- faiss/utils/fp16-arm.h | 29 +++ faiss/utils/fp16.h | 2 + 4 files changed, 362 insertions(+), 3 deletions(-) create mode 100644 faiss/utils/fp16-arm.h diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt index db5133d68b..22dac7cb2c 100644 --- a/faiss/CMakeLists.txt +++ b/faiss/CMakeLists.txt @@ -189,6 +189,7 @@ set(FAISS_HEADERS utils/extra_distances.h utils/fp16-fp16c.h utils/fp16-inl.h + utils/fp16-arm.h utils/fp16.h utils/hamming-inl.h utils/hamming.h diff --git a/faiss/impl/ScalarQuantizer.cpp b/faiss/impl/ScalarQuantizer.cpp index fc7b28eff4..07d77d5622 100644 --- a/faiss/impl/ScalarQuantizer.cpp +++ b/faiss/impl/ScalarQuantizer.cpp @@ -91,6 +91,20 @@ struct Codec8bit { return _mm256_fmadd_ps(f8, one_255, half_one_255); } #endif + +#ifdef __aarch64__ + static FAISS_ALWAYS_INLINE float32x4x2_t + decode_8_components(const uint8_t* code, int i) { + float32_t result[8] = {}; + for (size_t j = 0; j < 8; j++) { + result[j] = decode_component(code, i + j); + } + float32x4_t res1 = vld1q_f32(result); + float32x4_t res2 = vld1q_f32(result + 4); + float32x4x2_t res = vzipq_f32(res1, res2); + return vuzpq_f32(res.val[0], res.val[1]); + } +#endif }; struct Codec4bit { @@ -129,6 +143,20 @@ struct Codec4bit { return _mm256_mul_ps(f8, one_255); } #endif + +#ifdef __aarch64__ + static FAISS_ALWAYS_INLINE float32x4x2_t + decode_8_components(const uint8_t* code, int i) { + float32_t result[8] = {}; + for (size_t j = 0; j < 8; j++) { + result[j] = decode_component(code, i + j); + } + float32x4_t res1 = vld1q_f32(result); + float32x4_t res2 = vld1q_f32(result + 4); + float32x4x2_t res = vzipq_f32(res1, res2); + return vuzpq_f32(res.val[0], res.val[1]); + } +#endif }; struct Codec6bit { @@ -228,6 +256,20 @@ struct Codec6bit { } #endif + +#ifdef __aarch64__ + static FAISS_ALWAYS_INLINE float32x4x2_t + decode_8_components(const uint8_t* code, int i) { + float32_t result[8] = {}; + for (size_t j = 0; j < 8; j++) { + result[j] = decode_component(code, i + j); + } + float32x4_t res1 = vld1q_f32(result); + float32x4_t res2 = vld1q_f32(result + 4); + float32x4x2_t res = vzipq_f32(res1, res2); + return vuzpq_f32(res.val[0], res.val[1]); + } +#endif }; /******************************************************************* @@ -293,6 +335,31 @@ struct QuantizerTemplate : QuantizerTemplate { #endif +#ifdef __aarch64__ + +template +struct QuantizerTemplate : QuantizerTemplate { + QuantizerTemplate(size_t d, const std::vector& trained) + : QuantizerTemplate(d, trained) {} + + FAISS_ALWAYS_INLINE float32x4x2_t + reconstruct_8_components(const uint8_t* code, int i) const { + float32x4x2_t xi = Codec::decode_8_components(code, i); + float32x4x2_t res = vzipq_f32( + vfmaq_f32( + vdupq_n_f32(this->vmin), + xi.val[0], + vdupq_n_f32(this->vdiff)), + vfmaq_f32( + vdupq_n_f32(this->vmin), + xi.val[1], + vdupq_n_f32(this->vdiff))); + return vuzpq_f32(res.val[0], res.val[1]); + } +}; + +#endif + template struct QuantizerTemplate : ScalarQuantizer::SQuantizer { const size_t d; @@ -350,6 +417,29 @@ struct QuantizerTemplate : QuantizerTemplate { #endif +#ifdef __aarch64__ + +template +struct QuantizerTemplate : QuantizerTemplate { + QuantizerTemplate(size_t d, const std::vector& trained) + : QuantizerTemplate(d, trained) {} + + FAISS_ALWAYS_INLINE float32x4x2_t + reconstruct_8_components(const uint8_t* code, int i) const { + float32x4x2_t xi = Codec::decode_8_components(code, i); + + float32x4x2_t vmin_8 = vld1q_f32_x2(this->vmin + i); + float32x4x2_t vdiff_8 = vld1q_f32_x2(this->vdiff + i); + + float32x4x2_t res = vzipq_f32( + vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]), + vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1])); + return vuzpq_f32(res.val[0], res.val[1]); + } +}; + +#endif + /******************************************************************* * FP16 quantizer *******************************************************************/ @@ -397,6 +487,23 @@ struct QuantizerFP16<8> : QuantizerFP16<1> { #endif +#ifdef __aarch64__ + +template <> +struct QuantizerFP16<8> : QuantizerFP16<1> { + QuantizerFP16(size_t d, const std::vector& trained) + : QuantizerFP16<1>(d, trained) {} + + FAISS_ALWAYS_INLINE float32x4x2_t + reconstruct_8_components(const uint8_t* code, int i) const { + uint16x4x2_t codei = vld2_u16((const uint16_t*)(code + 2 * i)); + return vzipq_f32( + vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])), + vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1]))); + } +}; +#endif + /******************************************************************* * 8bit_direct quantizer *******************************************************************/ @@ -446,6 +553,28 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> { #endif +#ifdef __aarch64__ + +template <> +struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> { + Quantizer8bitDirect(size_t d, const std::vector& trained) + : Quantizer8bitDirect<1>(d, trained) {} + + FAISS_ALWAYS_INLINE float32x4x2_t + reconstruct_8_components(const uint8_t* code, int i) const { + float32_t result[8] = {}; + for (size_t j = 0; j < 8; j++) { + result[j] = code[i + j]; + } + float32x4_t res1 = vld1q_f32(result); + float32x4_t res2 = vld1q_f32(result + 4); + float32x4x2_t res = vzipq_f32(res1, res2); + return vuzpq_f32(res.val[0], res.val[1]); + } +}; + +#endif + template ScalarQuantizer::SQuantizer* select_quantizer_1( QuantizerType qtype, @@ -728,6 +857,59 @@ struct SimilarityL2<8> { #endif +#ifdef __aarch64__ +template <> +struct SimilarityL2<8> { + static constexpr int simdwidth = 8; + static constexpr MetricType metric_type = METRIC_L2; + + const float *y, *yi; + explicit SimilarityL2(const float* y) : y(y) {} + float32x4x2_t accu8; + + FAISS_ALWAYS_INLINE void begin_8() { + accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)); + yi = y; + } + + FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) { + float32x4x2_t yiv = vld1q_f32_x2(yi); + yi += 8; + + float32x4_t sub0 = vsubq_f32(yiv.val[0], x.val[0]); + float32x4_t sub1 = vsubq_f32(yiv.val[1], x.val[1]); + + float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0); + float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1); + + float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1); + accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]); + } + + FAISS_ALWAYS_INLINE void add_8_components_2( + float32x4x2_t x, + float32x4x2_t y) { + float32x4_t sub0 = vsubq_f32(y.val[0], x.val[0]); + float32x4_t sub1 = vsubq_f32(y.val[1], x.val[1]); + + float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0); + float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1); + + float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1); + accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]); + } + + FAISS_ALWAYS_INLINE float result_8() { + float32x4_t sum_0 = vpaddq_f32(accu8.val[0], accu8.val[0]); + float32x4_t sum_1 = vpaddq_f32(accu8.val[1], accu8.val[1]); + + float32x4_t sum2_0 = vpaddq_f32(sum_0, sum_0); + float32x4_t sum2_1 = vpaddq_f32(sum_1, sum_1); + return vgetq_lane_f32(sum2_0, 0) + vgetq_lane_f32(sum2_1, 0); + } +}; +#endif + template struct SimilarityIP {}; @@ -801,6 +983,56 @@ struct SimilarityIP<8> { }; #endif +#ifdef __aarch64__ + +template <> +struct SimilarityIP<8> { + static constexpr int simdwidth = 8; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + + const float *y, *yi; + + explicit SimilarityIP(const float* y) : y(y) {} + float32x4x2_t accu8; + + FAISS_ALWAYS_INLINE void begin_8() { + accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)); + yi = y; + } + + FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) { + float32x4x2_t yiv = vld1q_f32_x2(yi); + yi += 8; + + float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], yiv.val[0], x.val[0]); + float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], yiv.val[1], x.val[1]); + float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1); + accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]); + } + + FAISS_ALWAYS_INLINE void add_8_components_2( + float32x4x2_t x1, + float32x4x2_t x2) { + float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], x1.val[0], x2.val[0]); + float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], x1.val[1], x2.val[1]); + float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1); + accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]); + } + + FAISS_ALWAYS_INLINE float result_8() { + float32x4x2_t sum_tmp = vzipq_f32( + vpaddq_f32(accu8.val[0], accu8.val[0]), + vpaddq_f32(accu8.val[1], accu8.val[1])); + float32x4x2_t sum = vuzpq_f32(sum_tmp.val[0], sum_tmp.val[1]); + float32x4x2_t sum2_tmp = vzipq_f32( + vpaddq_f32(sum.val[0], sum.val[0]), + vpaddq_f32(sum.val[1], sum.val[1])); + float32x4x2_t sum2 = vuzpq_f32(sum2_tmp.val[0], sum2_tmp.val[1]); + return vgetq_lane_f32(sum2.val[0], 0) + vgetq_lane_f32(sum2.val[1], 0); + } +}; +#endif + /******************************************************************* * DistanceComputer: combines a similarity and a quantizer to do * code-to-vector or code-to-code comparisons @@ -903,6 +1135,53 @@ struct DCTemplate : SQDistanceComputer { #endif +#ifdef __aarch64__ + +template +struct DCTemplate : SQDistanceComputer { + using Sim = Similarity; + + Quantizer quant; + + DCTemplate(size_t d, const std::vector& trained) + : quant(d, trained) {} + float compute_distance(const float* x, const uint8_t* code) const { + Similarity sim(x); + sim.begin_8(); + for (size_t i = 0; i < quant.d; i += 8) { + float32x4x2_t xi = quant.reconstruct_8_components(code, i); + sim.add_8_components(xi); + } + return sim.result_8(); + } + + float compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + Similarity sim(nullptr); + sim.begin_8(); + for (size_t i = 0; i < quant.d; i += 8) { + float32x4x2_t x1 = quant.reconstruct_8_components(code1, i); + float32x4x2_t x2 = quant.reconstruct_8_components(code2, i); + sim.add_8_components_2(x1, x2); + } + return sim.result_8(); + } + + void set_query(const float* x) final { + q = x; + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_distance(q, code); + } +}; +#endif + /******************************************************************* * DistanceComputerByte: computes distances in the integer domain *******************************************************************/ @@ -1019,6 +1298,54 @@ struct DistanceComputerByte : SQDistanceComputer { #endif +#ifdef __aarch64__ + +template +struct DistanceComputerByte : SQDistanceComputer { + using Sim = Similarity; + + int d; + std::vector tmp; + + DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} + + int compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + int accu = 0; + for (int i = 0; i < d; i++) { + if (Sim::metric_type == METRIC_INNER_PRODUCT) { + accu += int(code1[i]) * code2[i]; + } else { + int diff = int(code1[i]) - code2[i]; + accu += diff * diff; + } + } + return accu; + } + + void set_query(const float* x) final { + for (int i = 0; i < d; i++) { + tmp[i] = int(x[i]); + } + } + + int compute_distance(const float* x, const uint8_t* code) { + set_query(x); + return compute_code_distance(tmp.data(), code); + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_code_distance(tmp.data(), code); + } +}; + +#endif + /******************************************************************* * select_distance_computer: runtime selection of template * specialization @@ -1155,7 +1482,7 @@ void ScalarQuantizer::train(size_t n, const float* x) { } ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const { -#ifdef USE_F16C +#if defined(USE_F16C) || defined(__aarch64__) if (d % 8 == 0) { return select_quantizer_1<8>(qtype, d, trained); } else @@ -1186,7 +1513,7 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const { SQDistanceComputer* ScalarQuantizer::get_distance_computer( MetricType metric) const { FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT); -#ifdef USE_F16C +#if defined(USE_F16C) || defined(__aarch64__) if (d % 8 == 0) { if (metric == METRIC_L2) { return select_distance_computer>(qtype, d, trained); @@ -1522,7 +1849,7 @@ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner( bool store_pairs, const IDSelector* sel, bool by_residual) const { -#ifdef USE_F16C +#if defined(USE_F16C) || defined(__aarch64__) if (d % 8 == 0) { return sel0_InvertedListScanner<8>( mt, this, quantizer, store_pairs, sel, by_residual); diff --git a/faiss/utils/fp16-arm.h b/faiss/utils/fp16-arm.h new file mode 100644 index 0000000000..79c885b058 --- /dev/null +++ b/faiss/utils/fp16-arm.h @@ -0,0 +1,29 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace faiss { + +inline uint16_t encode_fp16(float x) { + float32x4_t fx4 = vdupq_n_f32(x); + float16x4_t f16x4 = vcvt_f16_f32(fx4); + uint16x4_t ui16x4 = vreinterpret_u16_f16(f16x4); + return vduph_lane_u16(ui16x4, 3); +} + +inline float decode_fp16(uint16_t x) { + uint16x4_t ui16x4 = vdup_n_u16(x); + float16x4_t f16x4 = vreinterpret_f16_u16(ui16x4); + float32x4_t fx4 = vcvt_f32_f16(f16x4); + return vdups_laneq_f32(fx4, 3); +} + +} // namespace faiss diff --git a/faiss/utils/fp16.h b/faiss/utils/fp16.h index 90691d8ffe..43e05dc3d3 100644 --- a/faiss/utils/fp16.h +++ b/faiss/utils/fp16.h @@ -13,6 +13,8 @@ #if defined(__F16C__) #include +#elif defined(__aarch64__) +#include #else #include #endif