// Copyright 2021 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Arm SVE[2] vectors (length not known at compile time).
// External include guard in highway.h - see comment there.

#include <arm_sve.h>

#include "hwy/ops/shared-inl.h"

// Arm C215 declares that SVE vector lengths will always be a power of two.
// We default to relying on this, which makes some operations more efficient.
// You can still opt into fixups by setting this to 0 (unsupported).
#ifndef HWY_SVE_IS_POW2
#define HWY_SVE_IS_POW2 1
#endif

#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128
#define HWY_SVE_HAVE_2 1
#else
#define HWY_SVE_HAVE_2 0
#endif

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

template <class V>
struct DFromV_t {};  // specialized in macros
template <class V>
using DFromV = typename DFromV_t<RemoveConst<V>>::type;

template <class V>
using TFromV = TFromD<DFromV<V>>;

// ================================================== MACROS

// Generate specializations and function definitions using X macros. Although
// harder to read and debug, writing everything manually is too bulky.

namespace detail {  // for code folding

// Args: BASE, CHAR, BITS, HALF, NAME, OP

// Unsigned:
#define HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) X_MACRO(uint, u, 8, 8, NAME, OP)
#define HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) X_MACRO(uint, u, 16, 8, NAME, OP)
#define HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \
  X_MACRO(uint, u, 32, 16, NAME, OP)
#define HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) \
  X_MACRO(uint, u, 64, 32, NAME, OP)

// Signed:
#define HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) X_MACRO(int, s, 8, 8, NAME, OP)
#define HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) X_MACRO(int, s, 16, 8, NAME, OP)
#define HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) X_MACRO(int, s, 32, 16, NAME, OP)
#define HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) X_MACRO(int, s, 64, 32, NAME, OP)

// Float:
#define HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) \
  X_MACRO(float, f, 16, 16, NAME, OP)
#define HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \
  X_MACRO(float, f, 32, 16, NAME, OP)
#define HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) \
  X_MACRO(float, f, 64, 32, NAME, OP)

#if HWY_SVE_HAVE_BFLOAT16
#define HWY_SVE_FOREACH_BF16(X_MACRO, NAME, OP) \
  X_MACRO(bfloat, bf, 16, 16, NAME, OP)
#else
#define HWY_SVE_FOREACH_BF16(X_MACRO, NAME, OP)
#endif

// For all element sizes:
#define HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP)

// HWY_SVE_FOREACH_F does not include HWY_SVE_FOREACH_BF16 because SVE lacks
// bf16 overloads for some intrinsics (especially less-common arithmetic).
#define HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP)

// Commonly used type categories for a given element size:
#define HWY_SVE_FOREACH_UI08(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_UI16(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_UIF3264(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP)          \
  HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP)          \
  HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP)           \
  HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP)

// Commonly used type categories:
#define HWY_SVE_FOREACH_UI(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_I(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH_IF(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_I(X_MACRO, NAME, OP)        \
  HWY_SVE_FOREACH_F(X_MACRO, NAME, OP)

#define HWY_SVE_FOREACH(X_MACRO, NAME, OP) \
  HWY_SVE_FOREACH_U(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_I(X_MACRO, NAME, OP)     \
  HWY_SVE_FOREACH_F(X_MACRO, NAME, OP)

// Assemble types for use in x-macros
#define HWY_SVE_T(BASE, BITS) BASE##BITS##_t
#define HWY_SVE_D(BASE, BITS, N, POW2) Simd<HWY_SVE_T(BASE, BITS), N, POW2>
#define HWY_SVE_V(BASE, BITS) sv##BASE##BITS##_t
#define HWY_SVE_TUPLE(BASE, BITS, MUL) sv##BASE##BITS##x##MUL##_t

}  // namespace detail

#define HWY_SPECIALIZE(BASE, CHAR, BITS, HALF, NAME, OP) \
  template <>                                            \
  struct DFromV_t<HWY_SVE_V(BASE, BITS)> {               \
    using type = ScalableTag<HWY_SVE_T(BASE, BITS)>;     \
  };

HWY_SVE_FOREACH(HWY_SPECIALIZE, _, _)
HWY_SVE_FOREACH_BF16(HWY_SPECIALIZE, _, _)
#undef HWY_SPECIALIZE

// Note: _x (don't-care value for inactive lanes) avoids additional MOVPRFX
// instructions, and we anyway only use it when the predicate is ptrue.

// vector = f(vector), e.g. Not
#define HWY_SVE_RETV_ARGPV(BASE, CHAR, BITS, HALF, NAME, OP)    \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v);   \
  }
#define HWY_SVE_RETV_ARGV(BASE, CHAR, BITS, HALF, NAME, OP)     \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
    return sv##OP##_##CHAR##BITS(v);                            \
  }

// vector = f(vector, scalar), e.g. detail::AddN
#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, HALF, NAME, OP)    \
  HWY_API HWY_SVE_V(BASE, BITS)                                  \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) {   \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \
  }
#define HWY_SVE_RETV_ARGVN(BASE, CHAR, BITS, HALF, NAME, OP)   \
  HWY_API HWY_SVE_V(BASE, BITS)                                \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS(a, b);                        \
  }

// vector = f(vector, vector), e.g. Add
#define HWY_SVE_RETV_ARGPVV(BASE, CHAR, BITS, HALF, NAME, OP)    \
  HWY_API HWY_SVE_V(BASE, BITS)                                  \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) {   \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \
  }
#define HWY_SVE_RETV_ARGVV(BASE, CHAR, BITS, HALF, NAME, OP)   \
  HWY_API HWY_SVE_V(BASE, BITS)                                \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS(a, b);                        \
  }

#define HWY_SVE_RETV_ARGVVV(BASE, CHAR, BITS, HALF, NAME, OP) \
  HWY_API HWY_SVE_V(BASE, BITS)                               \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b,  \
           HWY_SVE_V(BASE, BITS) c) {                         \
    return sv##OP##_##CHAR##BITS(a, b, c);                    \
  }

// ------------------------------ Lanes

namespace detail {

// Returns actual lanes of a hardware vector without rounding to a power of two.
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_INLINE size_t AllHardwareLanes() {
  return svcntb_pat(SV_ALL);
}
template <typename T, HWY_IF_T_SIZE(T, 2)>
HWY_INLINE size_t AllHardwareLanes() {
  return svcnth_pat(SV_ALL);
}
template <typename T, HWY_IF_T_SIZE(T, 4)>
HWY_INLINE size_t AllHardwareLanes() {
  return svcntw_pat(SV_ALL);
}
template <typename T, HWY_IF_T_SIZE(T, 8)>
HWY_INLINE size_t AllHardwareLanes() {
  return svcntd_pat(SV_ALL);
}

// All-true mask from a macro

#if HWY_SVE_IS_POW2
#define HWY_SVE_ALL_PTRUE(BITS) svptrue_b##BITS()
#define HWY_SVE_PTRUE(BITS) svptrue_b##BITS()
#else
#define HWY_SVE_ALL_PTRUE(BITS) svptrue_pat_b##BITS(SV_ALL)
#define HWY_SVE_PTRUE(BITS) svptrue_pat_b##BITS(SV_POW2)
#endif  // HWY_SVE_IS_POW2

}  // namespace detail

#if HWY_HAVE_SCALABLE

// Returns actual number of lanes after capping by N and shifting. May return 0
// (e.g. for "1/8th" of a u32x4 - would be 1 for 1/8th of u32x8).
template <typename T, size_t N, int kPow2>
HWY_API size_t Lanes(Simd<T, N, kPow2> d) {
  const size_t actual = detail::AllHardwareLanes<T>();
  constexpr size_t kMaxLanes = MaxLanes(d);
  constexpr int kClampedPow2 = HWY_MIN(kPow2, 0);
  // Common case of full vectors: avoid any extra instructions.
  if (detail::IsFull(d)) return actual;
  return HWY_MIN(detail::ScaleByPower(actual, kClampedPow2), kMaxLanes);
}

#endif  // HWY_HAVE_SCALABLE
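
// Example (illustrative, not part of this header): on SVE the lane count is a
// runtime value, so loop bounds and buffer sizes are derived from Lanes()
// rather than from a compile-time constant. Assuming the usual Highway tags:
//
//   const ScalableTag<float> d;       // full vector of f32
//   const size_t N = Lanes(d);        // e.g. 4 with 128-bit, 16 with 512-bit
//   const ScalableTag<float, -1> dh;  // half-length tag; Lanes(dh) == N / 2
//                                     //   (given power-of-two vector lengths)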

// ================================================== MASK INIT

// One mask bit per byte; only the one belonging to the lowest byte is valid.

// ------------------------------ FirstN
#define HWY_SVE_FIRSTN(BASE, CHAR, BITS, HALF, NAME, OP)                       \
  template <size_t N, int kPow2>                                               \
  HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, size_t count) {     \
    const size_t limit = detail::IsFull(d) ? count : HWY_MIN(Lanes(d), count); \
    return sv##OP##_b##BITS##_u32(uint32_t{0}, static_cast<uint32_t>(limit));  \
  }
HWY_SVE_FOREACH(HWY_SVE_FIRSTN, FirstN, whilelt)
HWY_SVE_FOREACH_BF16(HWY_SVE_FIRSTN, FirstN, whilelt)

#undef HWY_SVE_FIRSTN

template <class D>
using MFromD = decltype(FirstN(D(), 0));

#if !HWY_HAVE_FLOAT16
template <class D, HWY_IF_F16_D(D)>
MFromD<RebindToUnsigned<D>> FirstN(D /* tag */, size_t count) {
  return FirstN(RebindToUnsigned<D>(), count);
}
#endif  // !HWY_HAVE_FLOAT16

#if !HWY_SVE_HAVE_BFLOAT16
template <class D, HWY_IF_BF16_D(D)>
MFromD<RebindToUnsigned<D>> FirstN(D /* tag */, size_t count) {
  return FirstN(RebindToUnsigned<D>(), count);
}
#endif  // !HWY_SVE_HAVE_BFLOAT16

namespace detail {

#define HWY_SVE_WRAP_PTRUE(BASE, CHAR, BITS, HALF, NAME, OP)            \
  template <size_t N, int kPow2>                                        \
  HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) {      \
    return HWY_SVE_PTRUE(BITS);                                         \
  }                                                                     \
  template <size_t N, int kPow2>                                        \
  HWY_API svbool_t All##NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \
    return HWY_SVE_ALL_PTRUE(BITS);                                     \
  }

HWY_SVE_FOREACH(HWY_SVE_WRAP_PTRUE, PTrue, ptrue)  // return all-true
HWY_SVE_FOREACH_BF16(HWY_SVE_WRAP_PTRUE, PTrue, ptrue)
#undef HWY_SVE_WRAP_PTRUE

HWY_API svbool_t PFalse() { return svpfalse_b(); }

// Returns all-true if d is HWY_FULL or FirstN(N) after capping N.
//
// This is used in functions that load/store memory; other functions (e.g.
// arithmetic) can ignore d and use PTrue instead.
template <class D>
svbool_t MakeMask(D d) {
  return IsFull(d) ? PTrue(d) : FirstN(d, Lanes(d));
}

}  // namespace detail

// ================================================== INIT

// ------------------------------ Set
// vector = f(d, scalar), e.g. Set
#define HWY_SVE_SET(BASE, CHAR, BITS, HALF, NAME, OP)                         \
  template <size_t N, int kPow2>                                              \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \
                                     HWY_SVE_T(BASE, BITS) arg) {             \
    return sv##OP##_##CHAR##BITS(arg);                                        \
  }

HWY_SVE_FOREACH(HWY_SVE_SET, Set, dup_n)
HWY_SVE_FOREACH_BF16(HWY_SVE_SET, Set, dup_n)
#if !HWY_SVE_HAVE_BFLOAT16
// Required for Zero and VFromD
template <size_t N, int kPow2>
svuint16_t Set(Simd<bfloat16_t, N, kPow2> d, bfloat16_t arg) {
  return Set(RebindToUnsigned<decltype(d)>(), arg.bits);
}
#endif  // HWY_SVE_HAVE_BFLOAT16
#undef HWY_SVE_SET

template <class D>
using VFromD = decltype(Set(D(), TFromD<D>()));

using VBF16 = VFromD<ScalableTag<bfloat16_t>>;

// ------------------------------ Zero

template <class D>
VFromD<D> Zero(D d) {
  // Cast to support bfloat16_t.
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, Set(du, 0));
}

// ------------------------------ Undefined

#define HWY_SVE_UNDEFINED(BASE, CHAR, BITS, HALF, NAME, OP) \
  template <size_t N, int kPow2>                            \
  HWY_API HWY_SVE_V(BASE, BITS)                             \
      NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) {       \
    return sv##OP##_##CHAR##BITS();                         \
  }

HWY_SVE_FOREACH(HWY_SVE_UNDEFINED, Undefined, undef)

// ------------------------------ BitCast

namespace detail {

// u8: no change
#define HWY_SVE_CAST_NOP(BASE, CHAR, BITS, HALF, NAME, OP)                \
  HWY_API HWY_SVE_V(BASE, BITS) BitCastToByte(HWY_SVE_V(BASE, BITS) v) {  \
    return v;                                                             \
  }                                                                       \
  template <size_t N, int kPow2>                                          \
  HWY_API HWY_SVE_V(BASE, BITS) BitCastFromByte(                          \
      HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \
    return v;                                                             \
  }

// All other types
#define HWY_SVE_CAST(BASE, CHAR, BITS, HALF, NAME, OP)                        \
  HWY_INLINE svuint8_t BitCastToByte(HWY_SVE_V(BASE, BITS) v) {               \
    return sv##OP##_u8_##CHAR##BITS(v);                                       \
  }                                                                           \
  template <size_t N, int kPow2>                                              \
  HWY_INLINE HWY_SVE_V(BASE, BITS)                                            \
      BitCastFromByte(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svuint8_t v) { \
    return sv##OP##_##CHAR##BITS##_u8(v);                                     \
  }

HWY_SVE_FOREACH_U08(HWY_SVE_CAST_NOP, _, _)
HWY_SVE_FOREACH_I08(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_UI16(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_UI32(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_UI64(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_F(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_BF16(HWY_SVE_CAST, _, reinterpret)

#undef HWY_SVE_CAST_NOP
#undef HWY_SVE_CAST

#if !HWY_SVE_HAVE_BFLOAT16
template <size_t N, int kPow2>
HWY_INLINE VBF16 BitCastFromByte(Simd<bfloat16_t, N, kPow2> /* d */,
                                 svuint8_t v) {
  return BitCastFromByte(Simd<uint16_t, N, kPow2>(), v);
}
#endif  // !HWY_SVE_HAVE_BFLOAT16

}  // namespace detail

template <class D, class FromV>
HWY_API VFromD<D> BitCast(D d, FromV v) {
  return detail::BitCastFromByte(d, detail::BitCastToByte(v));
}
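
// Example (illustrative): BitCast reinterprets the same register as another
// lane type, which is how the float overloads of And/Or/Xor below operate on
// unsigned lanes. For instance, clearing the sign bit of f32 lanes:
//
//   const ScalableTag<float> df;
//   const RebindToUnsigned<decltype(df)> du;               // u32 lanes
//   const svfloat32_t v = Set(df, -1.5f);
//   const svuint32_t sign = Set(du, 0x80000000u);
//   const svfloat32_t abs = BitCast(df, AndNot(sign, BitCast(du, v)));
//   // abs is +1.5f in every lane; AndNot(a, b) computes ~a & b.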

// ------------------------------ Tuple

// tuples = f(d, v..), e.g. Create2
#define HWY_SVE_CREATE(BASE, CHAR, BITS, HALF, NAME, OP)                 \
  template <size_t N, int kPow2>                                         \
  HWY_API HWY_SVE_TUPLE(BASE, BITS, 2)                                   \
      NAME##2(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */,                   \
              HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1) {      \
    return sv##OP##2_##CHAR##BITS(v0, v1);                               \
  }                                                                      \
  template <size_t N, int kPow2>                                         \
  HWY_API HWY_SVE_TUPLE(BASE, BITS, 3) NAME##3(                          \
      HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v0, \
      HWY_SVE_V(BASE, BITS) v1, HWY_SVE_V(BASE, BITS) v2) {              \
    return sv##OP##3_##CHAR##BITS(v0, v1, v2);                           \
  }                                                                      \
  template <size_t N, int kPow2>                                         \
  HWY_API HWY_SVE_TUPLE(BASE, BITS, 4)                                   \
      NAME##4(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */,                   \
              HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1,        \
              HWY_SVE_V(BASE, BITS) v2, HWY_SVE_V(BASE, BITS) v3) {      \
    return sv##OP##4_##CHAR##BITS(v0, v1, v2, v3);                       \
  }

HWY_SVE_FOREACH(HWY_SVE_CREATE, Create, create)
HWY_SVE_FOREACH_BF16(HWY_SVE_CREATE, Create, create)
#undef HWY_SVE_CREATE

template <class D>
using Vec2 = decltype(Create2(D(), Zero(D()), Zero(D())));
template <class D>
using Vec3 = decltype(Create3(D(), Zero(D()), Zero(D()), Zero(D())));
template <class D>
using Vec4 = decltype(Create4(D(), Zero(D()), Zero(D()), Zero(D()), Zero(D())));

#define HWY_SVE_GET(BASE, CHAR, BITS, HALF, NAME, OP)                         \
  template <size_t kIndex>                                                    \
  HWY_API HWY_SVE_V(BASE, BITS) NAME##2(HWY_SVE_TUPLE(BASE, BITS, 2) tuple) { \
    return sv##OP##2_##CHAR##BITS(tuple, kIndex);                             \
  }                                                                           \
  template <size_t kIndex>                                                    \
  HWY_API HWY_SVE_V(BASE, BITS) NAME##3(HWY_SVE_TUPLE(BASE, BITS, 3) tuple) { \
    return sv##OP##3_##CHAR##BITS(tuple, kIndex);                             \
  }                                                                           \
  template <size_t kIndex>                                                    \
  HWY_API HWY_SVE_V(BASE, BITS) NAME##4(HWY_SVE_TUPLE(BASE, BITS, 4) tuple) { \
    return sv##OP##4_##CHAR##BITS(tuple, kIndex);                             \
  }

HWY_SVE_FOREACH(HWY_SVE_GET, Get, get)
HWY_SVE_FOREACH_BF16(HWY_SVE_GET, Get, get)
#undef HWY_SVE_GET

#define HWY_SVE_SET(BASE, CHAR, BITS, HALF, NAME, OP)                          \
  template <size_t kIndex>                                                     \
  HWY_API HWY_SVE_TUPLE(BASE, BITS, 2)                                         \
      NAME##2(HWY_SVE_TUPLE(BASE, BITS, 2) tuple, HWY_SVE_V(BASE, BITS) vec) { \
    return sv##OP##2_##CHAR##BITS(tuple, kIndex, vec);                         \
  }                                                                            \
  template <size_t kIndex>                                                     \
  HWY_API HWY_SVE_TUPLE(BASE, BITS, 3)                                         \
      NAME##3(HWY_SVE_TUPLE(BASE, BITS, 3) tuple, HWY_SVE_V(BASE, BITS) vec) { \
    return sv##OP##3_##CHAR##BITS(tuple, kIndex, vec);                         \
  }                                                                            \
  template <size_t kIndex>                                                     \
  HWY_API HWY_SVE_TUPLE(BASE, BITS, 4)                                         \
      NAME##4(HWY_SVE_TUPLE(BASE, BITS, 4) tuple, HWY_SVE_V(BASE, BITS) vec) { \
    return sv##OP##4_##CHAR##BITS(tuple, kIndex, vec);                         \
  }

HWY_SVE_FOREACH(HWY_SVE_SET, Set, set)
HWY_SVE_FOREACH_BF16(HWY_SVE_SET, Set, set)
#undef HWY_SVE_SET

// ------------------------------ ResizeBitCast

// Same as BitCast on SVE
template <class D, class FromV>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  return BitCast(d, v);
}

// ================================================== LOGICAL

// detail::*N() functions accept a scalar argument to avoid extra Set().

// ------------------------------ Not
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPV, Not, not )  // NOLINT

// ------------------------------ And

namespace detail {
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, AndN, and_n)
}  // namespace detail

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, And, and)

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V And(const V a, const V b) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, And(BitCast(du, a), BitCast(du, b)));
}

// ------------------------------ Or

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Or, orr)

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V Or(const V a, const V b) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, Or(BitCast(du, a), BitCast(du, b)));
}

// ------------------------------ Xor

namespace detail {
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, XorN, eor_n)
}  // namespace detail

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Xor, eor)

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V Xor(const V a, const V b) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, Xor(BitCast(du, a), BitCast(du, b)));
}

// ------------------------------ AndNot

namespace detail {
#define HWY_SVE_RETV_ARGPVN_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \
  HWY_API HWY_SVE_V(BASE, BITS)                                    \
      NAME(HWY_SVE_T(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) {     \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a);   \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN_SWAP, AndNotN, bic_n)
#undef HWY_SVE_RETV_ARGPVN_SWAP
}  // namespace detail

#define HWY_SVE_RETV_ARGPVV_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \
  HWY_API HWY_SVE_V(BASE, BITS)                                    \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) {     \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a);   \
  }
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV_SWAP, AndNot, bic)
#undef HWY_SVE_RETV_ARGPVV_SWAP

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V AndNot(const V a, const V b) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, AndNot(BitCast(du, a), BitCast(du, b)));
}

// ------------------------------ Xor3

#if HWY_SVE_HAVE_2

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVVV, Xor3, eor3)

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V Xor3(const V x1, const V x2, const V x3) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, Xor3(BitCast(du, x1), BitCast(du, x2), BitCast(du, x3)));
}

#else
template <class V>
HWY_API V Xor3(V x1, V x2, V x3) {
  return Xor(x1, Xor(x2, x3));
}
#endif

// ------------------------------ Or3
template <class V>
HWY_API V Or3(V o1, V o2, V o3) {
  return Or(o1, Or(o2, o3));
}

// ------------------------------ OrAnd
template <class V>
HWY_API V OrAnd(const V o, const V a1, const V a2) {
  return Or(o, And(a1, a2));
}

// ------------------------------ PopulationCount

#ifdef HWY_NATIVE_POPCNT
#undef HWY_NATIVE_POPCNT
#else
#define HWY_NATIVE_POPCNT
#endif

// Need to return original type instead of unsigned.
#define HWY_SVE_POPCNT(BASE, CHAR, BITS, HALF, NAME, OP)               \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) {        \
    return BitCast(DFromV<decltype(v)>(),                              \
                   sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v)); \
  }
HWY_SVE_FOREACH_UI(HWY_SVE_POPCNT, PopulationCount, cnt)
#undef HWY_SVE_POPCNT

// ================================================== SIGN

// ------------------------------ Neg
HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Neg, neg)

HWY_API VBF16 Neg(VBF16 v) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;
  using TU = TFromD<decltype(du)>;
  return BitCast(d, Xor(BitCast(du, v), Set(du, SignMask<TU>())));
}

// ------------------------------ Abs
HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Abs, abs)

// ================================================== ARITHMETIC

// Per-target flags to prevent generic_ops-inl.h defining Add etc.
#ifdef HWY_NATIVE_OPERATOR_REPLACEMENTS
#undef HWY_NATIVE_OPERATOR_REPLACEMENTS
#else
#define HWY_NATIVE_OPERATOR_REPLACEMENTS
#endif

// ------------------------------ Add

namespace detail {
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN, AddN, add_n)
}  // namespace detail

HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Add, add)

// ------------------------------ Sub

namespace detail {
// Can't use HWY_SVE_RETV_ARGPVN because caller wants to specify pg.
#define HWY_SVE_RETV_ARGPVN_MASK(BASE, CHAR, BITS, HALF, NAME, OP)          \
  HWY_API HWY_SVE_V(BASE, BITS)                                             \
      NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS##_z(pg, a, b);                             \
  }

HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN_MASK, SubN, sub_n)
#undef HWY_SVE_RETV_ARGPVN_MASK
}  // namespace detail

HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Sub, sub)

// ------------------------------ SumsOf8
HWY_API svuint64_t SumsOf8(const svuint8_t v) {
  const ScalableTag<uint32_t> du32;
  const ScalableTag<uint64_t> du64;
  const svbool_t pg = detail::PTrue(du64);

  const svuint32_t sums_of_4 = svdot_n_u32(Zero(du32), v, 1);
  // Compute pairwise sum of u32 and extend to u64.
  // TODO(janwas): on SVE2, we can instead use svaddp.
  const svuint64_t hi = svlsr_n_u64_x(pg, BitCast(du64, sums_of_4), 32);
  // Isolate the lower 32 bits (to be added to the upper 32 and zero-extended)
  const svuint64_t lo = svextw_u64_x(pg, BitCast(du64, sums_of_4));
  return Add(hi, lo);
}
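
// Scalar reference for SumsOf8 (illustrative only): output lane i (u64) is the
// sum of the eight u8 input lanes it covers:
//
//   uint64_t sum = 0;
//   for (int j = 0; j < 8; ++j) sum += bytes[8 * i + j];
//
// The vector version gets per-4-byte partial sums from svdot with an all-ones
// operand, then adds each adjacent pair of u32 sums within a u64 lane.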

// ------------------------------ SaturatedAdd

#ifdef HWY_NATIVE_I32_SATURATED_ADDSUB
#undef HWY_NATIVE_I32_SATURATED_ADDSUB
#else
#define HWY_NATIVE_I32_SATURATED_ADDSUB
#endif

#ifdef HWY_NATIVE_U32_SATURATED_ADDSUB
#undef HWY_NATIVE_U32_SATURATED_ADDSUB
#else
#define HWY_NATIVE_U32_SATURATED_ADDSUB
#endif

#ifdef HWY_NATIVE_I64_SATURATED_ADDSUB
#undef HWY_NATIVE_I64_SATURATED_ADDSUB
#else
#define HWY_NATIVE_I64_SATURATED_ADDSUB
#endif

#ifdef HWY_NATIVE_U64_SATURATED_ADDSUB
#undef HWY_NATIVE_U64_SATURATED_ADDSUB
#else
#define HWY_NATIVE_U64_SATURATED_ADDSUB
#endif

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVV, SaturatedAdd, qadd)

// ------------------------------ SaturatedSub

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVV, SaturatedSub, qsub)

// ------------------------------ AbsDiff
#ifdef HWY_NATIVE_INTEGER_ABS_DIFF
#undef HWY_NATIVE_INTEGER_ABS_DIFF
#else
#define HWY_NATIVE_INTEGER_ABS_DIFF
#endif

HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, AbsDiff, abd)

// ------------------------------ ShiftLeft[Same]

#define HWY_SVE_SHIFT_N(BASE, CHAR, BITS, HALF, NAME, OP)               \
  template <int kBits>                                                  \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) {         \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, kBits);    \
  }                                                                     \
  HWY_API HWY_SVE_V(BASE, BITS)                                         \
      NAME##Same(HWY_SVE_V(BASE, BITS) v, HWY_SVE_T(uint, BITS) bits) { \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, bits);     \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_N, ShiftLeft, lsl_n)

// ------------------------------ ShiftRight[Same]

HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_N, ShiftRight, lsr_n)
HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_N, ShiftRight, asr_n)

#undef HWY_SVE_SHIFT_N

// ------------------------------ RotateRight

// TODO(janwas): svxar on SVE2
template <int kBits, class V>
HWY_API V RotateRight(const V v) {
  constexpr size_t kSizeInBits = sizeof(TFromV<V>) * 8;
  static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count");
  if (kBits == 0) return v;
  return Or(ShiftRight<kBits>(v),
            ShiftLeft<HWY_MIN(kSizeInBits - 1, kSizeInBits - kBits)>(v));
}

// ------------------------------ Shl/r

#define HWY_SVE_SHIFT(BASE, CHAR, BITS, HALF, NAME, OP)           \
  HWY_API HWY_SVE_V(BASE, BITS)                                   \
      NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(BASE, BITS) bits) { \
    const RebindToUnsigned<DFromV<decltype(v)>> du;               \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v,      \
                                     BitCast(du, bits));          \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT, Shl, lsl)

HWY_SVE_FOREACH_U(HWY_SVE_SHIFT, Shr, lsr)
HWY_SVE_FOREACH_I(HWY_SVE_SHIFT, Shr, asr)

#undef HWY_SVE_SHIFT

// ------------------------------ Min/Max

HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Min, min)
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Max, max)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Min, minnm)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Max, maxnm)

namespace detail {
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MinN, min_n)
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MaxN, max_n)
}  // namespace detail

// ------------------------------ Mul

// Per-target flags to prevent generic_ops-inl.h defining 8/64-bit operator*.
#ifdef HWY_NATIVE_MUL_8
#undef HWY_NATIVE_MUL_8
#else
#define HWY_NATIVE_MUL_8
#endif
#ifdef HWY_NATIVE_MUL_64
#undef HWY_NATIVE_MUL_64
#else
#define HWY_NATIVE_MUL_64
#endif

HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Mul, mul)

// ------------------------------ MulHigh
HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPVV, MulHigh, mulh)
// Not part of API, used internally:
HWY_SVE_FOREACH_UI08(HWY_SVE_RETV_ARGPVV, MulHigh, mulh)
HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPVV, MulHigh, mulh)
HWY_SVE_FOREACH_U64(HWY_SVE_RETV_ARGPVV, MulHigh, mulh)

// ------------------------------ MulFixedPoint15
HWY_API svint16_t MulFixedPoint15(svint16_t a, svint16_t b) {
#if HWY_SVE_HAVE_2
  return svqrdmulh_s16(a, b);
#else
  const DFromV<decltype(a)> d;
  const RebindToUnsigned<decltype(d)> du;

  const svuint16_t lo = BitCast(du, Mul(a, b));
  const svint16_t hi = MulHigh(a, b);
  // We want (lo + 0x4000) >> 15, but that can overflow, and if it does we must
  // carry that into the result. Instead isolate the top two bits because only
  // they can influence the result.
  const svuint16_t lo_top2 = ShiftRight<14>(lo);
  // Bits 11: add 2, 10: add 1, 01: add 1, 00: add 0.
  const svuint16_t rounding = ShiftRight<1>(detail::AddN(lo_top2, 1));
  return Add(Add(hi, hi), BitCast(d, rounding));
#endif
}
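
// Scalar reference for MulFixedPoint15 (illustrative): a rounded Q15 product,
//
//   int16_t MulFixedPoint15(int16_t a, int16_t b) {
//     return static_cast<int16_t>((int32_t{a} * b + 16384) >> 15);
//   }
//
// with the caveat that the SVE2 path (svqrdmulh) saturates the single overflow
// case a == b == -32768 to 32767 instead of wrapping.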

// ------------------------------ Div
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Div, div)

// ------------------------------ ApproximateReciprocal
#ifdef HWY_NATIVE_F64_APPROX_RECIP
#undef HWY_NATIVE_F64_APPROX_RECIP
#else
#define HWY_NATIVE_F64_APPROX_RECIP
#endif

HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGV, ApproximateReciprocal, recpe)

// ------------------------------ Sqrt
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Sqrt, sqrt)

// ------------------------------ ApproximateReciprocalSqrt
#ifdef HWY_NATIVE_F64_APPROX_RSQRT
#undef HWY_NATIVE_F64_APPROX_RSQRT
#else
#define HWY_NATIVE_F64_APPROX_RSQRT
#endif

HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGV, ApproximateReciprocalSqrt, rsqrte)

// ------------------------------ MulAdd

// Per-target flag to prevent generic_ops-inl.h from defining int MulAdd.
#ifdef HWY_NATIVE_INT_FMA
#undef HWY_NATIVE_INT_FMA
#else
#define HWY_NATIVE_INT_FMA
#endif

#define HWY_SVE_FMA(BASE, CHAR, BITS, HALF, NAME, OP)                   \
  HWY_API HWY_SVE_V(BASE, BITS)                                         \
      NAME(HWY_SVE_V(BASE, BITS) mul, HWY_SVE_V(BASE, BITS) x,          \
           HWY_SVE_V(BASE, BITS) add) {                                 \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), x, mul, add); \
  }

HWY_SVE_FOREACH(HWY_SVE_FMA, MulAdd, mad)

// ------------------------------ NegMulAdd
HWY_SVE_FOREACH(HWY_SVE_FMA, NegMulAdd, msb)

// ------------------------------ MulSub
HWY_SVE_FOREACH_F(HWY_SVE_FMA, MulSub, nmsb)

// ------------------------------ NegMulSub
HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulSub, nmad)

#undef HWY_SVE_FMA
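
// Example (illustrative): the Highway convention is MulAdd(mul, x, add) ==
// mul * x + add; the macro above maps this onto the sv*mad/msb intrinsics.
// Typical uses, for vectors a, b, c of the same type:
//
//   const auto fma  = MulAdd(a, b, c);     // a * b + c
//   const auto fnma = NegMulAdd(a, b, c);  // c - a * b
//   const auto fms  = MulSub(a, b, c);     // a * b - c  (float types only here)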

// ------------------------------ Round etc.

HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Round, rintn)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Floor, rintm)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Ceil, rintp)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Trunc, rintz)

// ================================================== MASK

// ------------------------------ RebindMask
template <class D, typename MFrom>
HWY_API svbool_t RebindMask(const D /*d*/, const MFrom mask) {
  return mask;
}

// ------------------------------ Mask logical

HWY_API svbool_t Not(svbool_t m) {
  // We don't know the lane type, so assume 8-bit. For larger types, this will
  // de-canonicalize the predicate, i.e. set bits to 1 even though they do not
  // correspond to the lowest byte in the lane. Arm says such bits are ignored.
  return svnot_b_z(HWY_SVE_PTRUE(8), m);
}
HWY_API svbool_t And(svbool_t a, svbool_t b) {
  return svand_b_z(b, b, a);  // same order as AndNot for consistency
}
HWY_API svbool_t AndNot(svbool_t a, svbool_t b) {
  return svbic_b_z(b, b, a);  // reversed order like NEON
}
HWY_API svbool_t Or(svbool_t a, svbool_t b) {
  return svsel_b(a, a, b);  // a ? true : b
}
HWY_API svbool_t Xor(svbool_t a, svbool_t b) {
  return svsel_b(a, svnand_b_z(a, a, b), b);  // a ? !(a & b) : b.
}

HWY_API svbool_t ExclusiveNeither(svbool_t a, svbool_t b) {
  return svnor_b_z(HWY_SVE_PTRUE(8), a, b);  // !a && !b, undefined if a && b.
}

// ------------------------------ CountTrue

#define HWY_SVE_COUNT_TRUE(BASE, CHAR, BITS, HALF, NAME, OP)           \
  template <size_t N, int kPow2>                                       \
  HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, svbool_t m) { \
    return sv##OP##_b##BITS(detail::MakeMask(d), m);                   \
  }

HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE, CountTrue, cntp)
#undef HWY_SVE_COUNT_TRUE

// For 16-bit Compress: full vector, not limited to SV_POW2.
namespace detail {

#define HWY_SVE_COUNT_TRUE_FULL(BASE, CHAR, BITS, HALF, NAME, OP)            \
  template <size_t N, int kPow2>                                             \
  HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svbool_t m) { \
    return sv##OP##_b##BITS(svptrue_b##BITS(), m);                           \
  }

HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE_FULL, CountTrueFull, cntp)
#undef HWY_SVE_COUNT_TRUE_FULL

}  // namespace detail

// ------------------------------ AllFalse
template <class D>
HWY_API bool AllFalse(D d, svbool_t m) {
  return !svptest_any(detail::MakeMask(d), m);
}

// ------------------------------ AllTrue
template <class D>
HWY_API bool AllTrue(D d, svbool_t m) {
  return CountTrue(d, m) == Lanes(d);
}

// ------------------------------ FindFirstTrue
template <class D>
HWY_API intptr_t FindFirstTrue(D d, svbool_t m) {
  return AllFalse(d, m) ? intptr_t{-1}
                        : static_cast<intptr_t>(
                              CountTrue(d, svbrkb_b_z(detail::MakeMask(d), m)));
}

// ------------------------------ FindKnownFirstTrue
template <class D>
HWY_API size_t FindKnownFirstTrue(D d, svbool_t m) {
  return CountTrue(d, svbrkb_b_z(detail::MakeMask(d), m));
}
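
// Example (illustrative): counting and locating active lanes, e.g. for a
// vector v of int32_t with descriptor d:
//
//   const svbool_t m = Lt(v, Zero(d));           // lanes holding negatives
//   const size_t count = CountTrue(d, m);        // how many are active
//   const intptr_t first = FindFirstTrue(d, m);  // index of first, or -1
//   const svbool_t low3 = FirstN(d, 3);          // mask covering lanes 0..2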

// ------------------------------ IfThenElse
#define HWY_SVE_IF_THEN_ELSE(BASE, CHAR, BITS, HALF, NAME, OP)                \
  HWY_API HWY_SVE_V(BASE, BITS)                                               \
      NAME(svbool_t m, HWY_SVE_V(BASE, BITS) yes, HWY_SVE_V(BASE, BITS) no) { \
    return sv##OP##_##CHAR##BITS(m, yes, no);                                 \
  }

HWY_SVE_FOREACH(HWY_SVE_IF_THEN_ELSE, IfThenElse, sel)
#undef HWY_SVE_IF_THEN_ELSE

// ------------------------------ IfThenElseZero
template <class V>
HWY_API V IfThenElseZero(const svbool_t mask, const V yes) {
  return IfThenElse(mask, yes, Zero(DFromV<V>()));
}

// ------------------------------ IfThenZeroElse
template <class V>
HWY_API V IfThenZeroElse(const svbool_t mask, const V no) {
  return IfThenElse(mask, Zero(DFromV<V>()), no);
}

// ------------------------------ Additional mask logical operations
HWY_API svbool_t SetBeforeFirst(svbool_t m) {
  // We don't know the lane type, so assume 8-bit. For larger types, this will
  // de-canonicalize the predicate, i.e. set bits to 1 even though they do not
  // correspond to the lowest byte in the lane. Arm says such bits are ignored.
  return svbrkb_b_z(HWY_SVE_PTRUE(8), m);
}

HWY_API svbool_t SetAtOrBeforeFirst(svbool_t m) {
  // We don't know the lane type, so assume 8-bit. For larger types, this will
  // de-canonicalize the predicate, i.e. set bits to 1 even though they do not
  // correspond to the lowest byte in the lane. Arm says such bits are ignored.
  return svbrka_b_z(HWY_SVE_PTRUE(8), m);
}

HWY_API svbool_t SetOnlyFirst(svbool_t m) { return svbrka_b_z(m, m); }

HWY_API svbool_t SetAtOrAfterFirst(svbool_t m) {
  return Not(SetBeforeFirst(m));
}

// ================================================== COMPARE

// mask = f(vector, vector)
#define HWY_SVE_COMPARE(BASE, CHAR, BITS, HALF, NAME, OP)                   \
  HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b);                \
  }
#define HWY_SVE_COMPARE_N(BASE, CHAR, BITS, HALF, NAME, OP)                 \
  HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b);                \
  }

// ------------------------------ Eq
HWY_SVE_FOREACH(HWY_SVE_COMPARE, Eq, cmpeq)
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, EqN, cmpeq_n)
}  // namespace detail

// ------------------------------ Ne
HWY_SVE_FOREACH(HWY_SVE_COMPARE, Ne, cmpne)
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, NeN, cmpne_n)
}  // namespace detail

// ------------------------------ Lt
HWY_SVE_FOREACH(HWY_SVE_COMPARE, Lt, cmplt)
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, LtN, cmplt_n)
}  // namespace detail

// ------------------------------ Le
HWY_SVE_FOREACH(HWY_SVE_COMPARE, Le, cmple)
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, LeN, cmple_n)
}  // namespace detail

// ------------------------------ Gt/Ge (swapped order)
template <class V>
HWY_API svbool_t Gt(const V a, const V b) {
  return Lt(b, a);
}
template <class V>
HWY_API svbool_t Ge(const V a, const V b) {
  return Le(b, a);
}
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, GeN, cmpge_n)
HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, GtN, cmpgt_n)
}  // namespace detail

#undef HWY_SVE_COMPARE
#undef HWY_SVE_COMPARE_N

// ------------------------------ TestBit
template <class V>
HWY_API svbool_t TestBit(const V a, const V bit) {
  return detail::NeN(And(a, bit), 0);
}

// ------------------------------ MaskFromVec (Ne)
template <class V>
HWY_API svbool_t MaskFromVec(const V v) {
  return detail::NeN(v, static_cast<TFromV<V>>(0));
}

// ------------------------------ VecFromMask
template <class D>
HWY_API VFromD<D> VecFromMask(const D d, svbool_t mask) {
  const RebindToSigned<D> di;
  // This generates MOV imm, whereas svdup_n_s8_z generates MOV scalar, which
  // requires an extra instruction plus M0 pipeline.
  return BitCast(d, IfThenElseZero(mask, Set(di, -1)));
}

// ------------------------------ IfVecThenElse (MaskFromVec, IfThenElse)

#if HWY_SVE_HAVE_2

#define HWY_SVE_IF_VEC(BASE, CHAR, BITS, HALF, NAME, OP)          \
  HWY_API HWY_SVE_V(BASE, BITS)                                   \
      NAME(HWY_SVE_V(BASE, BITS) mask, HWY_SVE_V(BASE, BITS) yes, \
           HWY_SVE_V(BASE, BITS) no) {                            \
    return sv##OP##_##CHAR##BITS(yes, no, mask);                  \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_IF_VEC, IfVecThenElse, bsl)
#undef HWY_SVE_IF_VEC

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V IfVecThenElse(const V mask, const V yes, const V no) {
  const DFromV<V> d;
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(
      d, IfVecThenElse(BitCast(du, mask), BitCast(du, yes), BitCast(du, no)));
}

#else

template <class V>
HWY_API V IfVecThenElse(const V mask, const V yes, const V no) {
  return Or(And(mask, yes), AndNot(mask, no));
}

#endif  // HWY_SVE_HAVE_2

// ------------------------------ BitwiseIfThenElse

#ifdef HWY_NATIVE_BITWISE_IF_THEN_ELSE
#undef HWY_NATIVE_BITWISE_IF_THEN_ELSE
#else
#define HWY_NATIVE_BITWISE_IF_THEN_ELSE
#endif

template <class V>
HWY_API V BitwiseIfThenElse(V mask, V yes, V no) {
  return IfVecThenElse(mask, yes, no);
}

// ------------------------------ CopySign (BitwiseIfThenElse)
template <class V>
HWY_API V CopySign(const V magn, const V sign) {
  const DFromV<decltype(magn)> d;
  return BitwiseIfThenElse(SignBit(d), sign, magn);
}

// ------------------------------ CopySignToAbs
template <class V>
HWY_API V CopySignToAbs(const V abs, const V sign) {
#if HWY_SVE_HAVE_2  // CopySign is more efficient than OrAnd
  return CopySign(abs, sign);
#else
  const DFromV<V> d;
  return OrAnd(abs, SignBit(d), sign);
#endif
}

// ------------------------------ Floating-point classification (Ne)

template <class V>
HWY_API svbool_t IsNaN(const V v) {
  return Ne(v, v);  // could also use cmpuo
}

template <class V>
HWY_API svbool_t IsInf(const V v) {
  using T = TFromV<V>;
  const DFromV<decltype(v)> d;
  const RebindToSigned<decltype(d)> di;
  const VFromD<decltype(di)> vi = BitCast(di, v);
  // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0.
  return RebindMask(d, detail::EqN(Add(vi, vi), hwy::MaxExponentTimes2<T>()));
}

// Returns whether normal/subnormal/zero.
template <class V>
HWY_API svbool_t IsFinite(const V v) {
  using T = TFromV<V>;
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;
  const RebindToSigned<decltype(d)> di;  // cheaper than unsigned comparison
  const VFromD<decltype(du)> vu = BitCast(du, v);
  // 'Shift left' to clear the sign bit, then right so we can compare with the
  // max exponent (cannot compare with MaxExponentTimes2 directly because it is
  // negative and non-negative floats would be greater).
  const VFromD<decltype(di)> exp =
      BitCast(di, ShiftRight<hwy::MantissaBits<T>() + 1>(Add(vu, vu)));
  return RebindMask(d, detail::LtN(exp, hwy::MaxExponentField<T>()));
}
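
// Bit-level sketch of the classification trick above (illustrative, scalar
// f32): doubling the value shifts out the sign bit, leaving the exponent in
// the top 8 bits.
//
//   uint32_t bits;
//   memcpy(&bits, &value, sizeof(bits));          // reinterpret f32 as u32
//   const uint32_t shifted = bits + bits;         // drop the sign bit
//   const bool inf = (shifted == 0xFF000000u);    // max exponent, zero mantissa
//   const bool finite = (shifted >> 24) < 0xFFu;  // exponent below the maximum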

// ================================================== MEMORY

// ------------------------------ Load/MaskedLoad/LoadDup128/Store/Stream

#define HWY_SVE_LOAD(BASE, CHAR, BITS, HALF, NAME, OP)     \
  template <size_t N, int kPow2>                           \
  HWY_API HWY_SVE_V(BASE, BITS)                            \
      NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d,              \
           const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \
    return sv##OP##_##CHAR##BITS(detail::MakeMask(d), p);  \
  }

#define HWY_SVE_MASKED_LOAD(BASE, CHAR, BITS, HALF, NAME, OP)   \
  template <size_t N, int kPow2>                                \
  HWY_API HWY_SVE_V(BASE, BITS)                                 \
      NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \
           const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) {      \
    return sv##OP##_##CHAR##BITS(m, p);                         \
  }

#define HWY_SVE_LOAD_DUP128(BASE, CHAR, BITS, HALF, NAME, OP) \
  template <size_t N, int kPow2>                              \
  HWY_API HWY_SVE_V(BASE, BITS)                               \
      NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */,           \
           const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) {    \
    /* All-true predicate to load all 128 bits. */            \
    return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(8), p);        \
  }

#define HWY_SVE_STORE(BASE, CHAR, BITS, HALF, NAME, OP)       \
  template <size_t N, int kPow2>                              \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v,                  \
                    HWY_SVE_D(BASE, BITS, N, kPow2) d,        \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \
    sv##OP##_##CHAR##BITS(detail::MakeMask(d), p, v);         \
  }

#define HWY_SVE_BLENDED_STORE(BASE, CHAR, BITS, HALF, NAME, OP) \
  template <size_t N, int kPow2>                                \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, svbool_t m,        \
                    HWY_SVE_D(BASE, BITS, N, kPow2) /* d */,    \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) {   \
    sv##OP##_##CHAR##BITS(m, p, v);                             \
  }

HWY_SVE_FOREACH(HWY_SVE_LOAD, Load, ld1)
HWY_SVE_FOREACH(HWY_SVE_MASKED_LOAD, MaskedLoad, ld1)
HWY_SVE_FOREACH(HWY_SVE_STORE, Store, st1)
HWY_SVE_FOREACH(HWY_SVE_STORE, Stream, stnt1)
HWY_SVE_FOREACH(HWY_SVE_BLENDED_STORE, BlendedStore, st1)

HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD, Load, ld1)
HWY_SVE_FOREACH_BF16(HWY_SVE_MASKED_LOAD, MaskedLoad, ld1)
HWY_SVE_FOREACH_BF16(HWY_SVE_STORE, Store, st1)
HWY_SVE_FOREACH_BF16(HWY_SVE_STORE, Stream, stnt1)
HWY_SVE_FOREACH_BF16(HWY_SVE_BLENDED_STORE, BlendedStore, st1)

#if HWY_TARGET != HWY_SVE2_128
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_LOAD_DUP128, LoadDupFull128, ld1rq)
}  // namespace detail
#endif  // HWY_TARGET != HWY_SVE2_128

#undef HWY_SVE_LOAD
#undef HWY_SVE_MASKED_LOAD
#undef HWY_SVE_LOAD_DUP128
#undef HWY_SVE_STORE
#undef HWY_SVE_BLENDED_STORE
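
// Example (illustrative sketch of the usual caller-side idiom; the function
// name and loop are hypothetical, not part of this header): process full
// vectors, then handle the remainder with FirstN, MaskedLoad and BlendedStore.
//
//   void MulAddTo(const float* a, const float* b, float* out, size_t n) {
//     const ScalableTag<float> d;
//     const size_t N = Lanes(d);
//     size_t i = 0;
//     for (; i + N <= n; i += N) {
//       Store(MulAdd(Load(d, a + i), Load(d, b + i), Load(d, out + i)), d,
//             out + i);
//     }
//     if (i < n) {  // remainder: < N lanes left
//       const svbool_t m = FirstN(d, n - i);
//       const auto r = MulAdd(MaskedLoad(m, d, a + i), MaskedLoad(m, d, b + i),
//                             MaskedLoad(m, d, out + i));
//       BlendedStore(r, m, d, out + i);
//     }
//   }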

#if !HWY_SVE_HAVE_BFLOAT16

template <size_t N, int kPow2>
HWY_API VBF16 Load(Simd<bfloat16_t, N, kPow2> d,
                   const bfloat16_t* HWY_RESTRICT p) {
  return Load(RebindToUnsigned<decltype(d)>(),
              reinterpret_cast<const uint16_t * HWY_RESTRICT>(p));
}

#endif  // !HWY_SVE_HAVE_BFLOAT16

#if HWY_TARGET == HWY_SVE2_128
// On the HWY_SVE2_128 target, LoadDup128 is the same as Load since vectors
// cannot exceed 16 bytes on the HWY_SVE2_128 target.
template <class D>
HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) {
  return Load(d, p);
}
#else
// If D().MaxBytes() <= 16 is true, simply do a Load operation.
template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) {
  return Load(d, p);
}

// If D().MaxBytes() > 16 is true, need to load the vector using ld1rq
template <class D, HWY_IF_V_SIZE_GT_D(D, 16),
          hwy::EnableIf<!IsSame<TFromD<D>, bfloat16_t>()>* = nullptr>
HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) {
  return detail::LoadDupFull128(d, p);
}

#if !HWY_SVE_HAVE_BFLOAT16

template <class D, HWY_IF_V_SIZE_GT_D(D, 16), HWY_IF_BF16_D(D)>
HWY_API VBF16 LoadDup128(D d, const bfloat16_t* HWY_RESTRICT p) {
  return detail::LoadDupFull128(
      RebindToUnsigned<decltype(d)>(),
      reinterpret_cast<const uint16_t * HWY_RESTRICT>(p));
}
#endif  // !HWY_SVE_HAVE_BFLOAT16

#endif  // HWY_TARGET != HWY_SVE2_128

#if !HWY_SVE_HAVE_BFLOAT16

template <size_t N, int kPow2>
HWY_API void Store(VBF16 v, Simd<bfloat16_t, N, kPow2> d,
                   bfloat16_t* HWY_RESTRICT p) {
  Store(v, RebindToUnsigned<decltype(d)>(),
        reinterpret_cast<uint16_t * HWY_RESTRICT>(p));
}

#endif

// ------------------------------ Load/StoreU

// SVE only requires lane alignment, not natural alignment of the entire
// vector.
template <class D>
HWY_API VFromD<D> LoadU(D d, const TFromD<D>* HWY_RESTRICT p) {
  return Load(d, p);
}

template <class V, class D>
HWY_API void StoreU(const V v, D d, TFromD<D>* HWY_RESTRICT p) {
  Store(v, d, p);
}

// ------------------------------ MaskedLoadOr

// SVE MaskedLoad hard-codes zero, so this requires an extra blend.
template <class D>
HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, MFromD<D> m, D d,
                               const TFromD<D>* HWY_RESTRICT p) {
  return IfThenElse(m, MaskedLoad(m, d, p), v);
}

// ------------------------------ ScatterOffset/Index

#ifdef HWY_NATIVE_SCATTER
#undef HWY_NATIVE_SCATTER
#else
#define HWY_NATIVE_SCATTER
#endif

#define HWY_SVE_SCATTER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP)             \
  template <size_t N, int kPow2>                                             \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v,                                 \
                    HWY_SVE_D(BASE, BITS, N, kPow2) d,                       \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base,               \
                    HWY_SVE_V(int, BITS) offset) {                           \
    sv##OP##_s##BITS##offset_##CHAR##BITS(detail::MakeMask(d), base, offset, \
                                          v);                                \
  }

#define HWY_SVE_MASKED_SCATTER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP) \
  template <size_t N, int kPow2>                                       \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, svbool_t m,               \
                    HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/,             \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base,         \
                    HWY_SVE_V(int, BITS) index) {                      \
    sv##OP##_s##BITS##index_##CHAR##BITS(m, base, index, v);           \
  }

HWY_SVE_FOREACH_UIF3264(HWY_SVE_SCATTER_OFFSET, ScatterOffset, st1_scatter)
HWY_SVE_FOREACH_UIF3264(HWY_SVE_MASKED_SCATTER_INDEX, MaskedScatterIndex,
                        st1_scatter)
#undef HWY_SVE_SCATTER_OFFSET
#undef HWY_SVE_MASKED_SCATTER_INDEX

template <class D>
HWY_API void ScatterIndex(VFromD<D> v, D d, TFromD<D>* HWY_RESTRICT p,
                          VFromD<RebindToSigned<D>> indices) {
  MaskedScatterIndex(v, detail::MakeMask(d), d, p, indices);
}

// ------------------------------ GatherOffset/Index

#ifdef HWY_NATIVE_GATHER
#undef HWY_NATIVE_GATHER
#else
#define HWY_NATIVE_GATHER
#endif

#define HWY_SVE_GATHER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP)             \
  template <size_t N, int kPow2>                                            \
  HWY_API HWY_SVE_V(BASE, BITS)                                             \
      NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d,                               \
           const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base,                 \
           HWY_SVE_V(int, BITS) offset) {                                   \
    return sv##OP##_s##BITS##offset_##CHAR##BITS(detail::MakeMask(d), base, \
                                                 offset);                   \
  }
#define HWY_SVE_MASKED_GATHER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP) \
  template <size_t N, int kPow2>                                      \
  HWY_API HWY_SVE_V(BASE, BITS)                                       \
      NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/,         \
           const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base,           \
           HWY_SVE_V(int, BITS) index) {                              \
    return sv##OP##_s##BITS##index_##CHAR##BITS(m, base, index);      \
  }

HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_OFFSET, GatherOffset, ld1_gather)
HWY_SVE_FOREACH_UIF3264(HWY_SVE_MASKED_GATHER_INDEX, MaskedGatherIndex,
                        ld1_gather)
#undef HWY_SVE_GATHER_OFFSET
#undef HWY_SVE_MASKED_GATHER_INDEX

template <class D>
HWY_API VFromD<D> GatherIndex(D d, const TFromD<D>* HWY_RESTRICT p,
                              VFromD<RebindToSigned<D>> indices) {
  return MaskedGatherIndex(detail::MakeMask(d), d, p, indices);
}
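
// Example (illustrative; `table` and `index_array` are the caller's):
// GatherIndex takes indices in element units, GatherOffset takes byte offsets.
//
//   const ScalableTag<float> df;
//   const RebindToSigned<decltype(df)> di;              // int32_t indices
//   const auto idx = Load(di, index_array);
//   const auto gathered = GatherIndex(df, table, idx);  // table[idx[i]]
//   // Equivalent via byte offsets: GatherOffset(df, table, ShiftLeft<2>(idx));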

// ------------------------------ LoadInterleaved2

// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2.
#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED
#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED
#else
#define HWY_NATIVE_LOAD_STORE_INTERLEAVED
#endif

#define HWY_SVE_LOAD2(BASE, CHAR, BITS, HALF, NAME, OP)                       \
  template <size_t N, int kPow2>                                              \
  HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d,                        \
                    const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned,     \
                    HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1) { \
    const HWY_SVE_TUPLE(BASE, BITS, 2) tuple =                                \
        sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned);                \
    v0 = svget2(tuple, 0);                                                    \
    v1 = svget2(tuple, 1);                                                    \
  }
HWY_SVE_FOREACH(HWY_SVE_LOAD2, LoadInterleaved2, ld2)

#undef HWY_SVE_LOAD2

// ------------------------------ LoadInterleaved3

#define HWY_SVE_LOAD3(BASE, CHAR, BITS, HALF, NAME, OP)                     \
  template <size_t N, int kPow2>                                            \
  HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d,                      \
                    const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned,   \
                    HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1, \
                    HWY_SVE_V(BASE, BITS) & v2) {                           \
    const HWY_SVE_TUPLE(BASE, BITS, 3) tuple =                              \
        sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned);              \
    v0 = svget3(tuple, 0);                                                  \
    v1 = svget3(tuple, 1);                                                  \
    v2 = svget3(tuple, 2);                                                  \
  }
HWY_SVE_FOREACH(HWY_SVE_LOAD3, LoadInterleaved3, ld3)

#undef HWY_SVE_LOAD3

// ------------------------------ LoadInterleaved4

#define HWY_SVE_LOAD4(BASE, CHAR, BITS, HALF, NAME, OP)                       \
  template <size_t N, int kPow2>                                              \
  HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d,                        \
                    const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned,     \
                    HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1,   \
                    HWY_SVE_V(BASE, BITS) & v2, HWY_SVE_V(BASE, BITS) & v3) { \
    const HWY_SVE_TUPLE(BASE, BITS, 4) tuple =                                \
        sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned);                \
    v0 = svget4(tuple, 0);                                                    \
    v1 = svget4(tuple, 1);                                                    \
    v2 = svget4(tuple, 2);                                                    \
    v3 = svget4(tuple, 3);                                                    \
  }
HWY_SVE_FOREACH(HWY_SVE_LOAD4, LoadInterleaved4, ld4)

#undef HWY_SVE_LOAD4

// ------------------------------ StoreInterleaved2

#define HWY_SVE_STORE2(BASE, CHAR, BITS, HALF, NAME, OP)                       \
  template <size_t N, int kPow2>                                               \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1,        \
                    HWY_SVE_D(BASE, BITS, N, kPow2) d,                         \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) {          \
    sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned, Create2(d, v0, v1)); \
  }
HWY_SVE_FOREACH(HWY_SVE_STORE2, StoreInterleaved2, st2)

#undef HWY_SVE_STORE2

// ------------------------------ StoreInterleaved3

#define HWY_SVE_STORE3(BASE, CHAR, BITS, HALF, NAME, OP)                \
  template <size_t N, int kPow2>                                        \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \
                    HWY_SVE_V(BASE, BITS) v2,                           \
                    HWY_SVE_D(BASE, BITS, N, kPow2) d,                  \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) {   \
    sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned,               \
                          Create3(d, v0, v1, v2));                      \
  }
HWY_SVE_FOREACH(HWY_SVE_STORE3, StoreInterleaved3, st3)

#undef HWY_SVE_STORE3

// ------------------------------ StoreInterleaved4

#define HWY_SVE_STORE4(BASE, CHAR, BITS, HALF, NAME, OP)                \
  template <size_t N, int kPow2>                                        \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \
                    HWY_SVE_V(BASE, BITS) v2, HWY_SVE_V(BASE, BITS) v3, \
                    HWY_SVE_D(BASE, BITS, N, kPow2) d,                  \
                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) {   \
    sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned,               \
                          Create4(d, v0, v1, v2, v3));                  \
  }
HWY_SVE_FOREACH(HWY_SVE_STORE4, StoreInterleaved4, st4)

#undef HWY_SVE_STORE4
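
// Example (illustrative; `rgb` is the caller's pointer to packed RGB bytes):
// LoadInterleaved3 reads 3 * Lanes(d) consecutive bytes and de-interleaves
// them into separate r/g/b planes; StoreInterleaved3 re-interleaves them.
//
//   const ScalableTag<uint8_t> d;
//   svuint8_t r, g, b;
//   LoadInterleaved3(d, rgb, r, g, b);
//   // ... process r, g, b ...
//   StoreInterleaved3(r, g, b, d, rgb);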

// ================================================== CONVERT

// ------------------------------ PromoteTo

// Same sign
#define HWY_SVE_PROMOTE_TO(BASE, CHAR, BITS, HALF, NAME, OP)                \
  template <size_t N, int kPow2>                                            \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(                                       \
      HWY_SVE_D(BASE, BITS, N, kPow2) /* tag */, HWY_SVE_V(BASE, HALF) v) { \
    return sv##OP##_##CHAR##BITS(v);                                        \
  }

HWY_SVE_FOREACH_UI16(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo)
HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo)
HWY_SVE_FOREACH_UI64(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo)

// 2x
template <size_t N, int kPow2>
HWY_API svuint32_t PromoteTo(Simd<uint32_t, N, kPow2> dto, svuint8_t vfrom) {
  const RepartitionToWide<DFromV<decltype(vfrom)>> d2;
  return PromoteTo(dto, PromoteTo(d2, vfrom));
}
template <size_t N, int kPow2>
HWY_API svint32_t PromoteTo(Simd<int32_t, N, kPow2> dto, svint8_t vfrom) {
  const RepartitionToWide<DFromV<decltype(vfrom)>> d2;
  return PromoteTo(dto, PromoteTo(d2, vfrom));
}
template <size_t N, int kPow2>
HWY_API svuint64_t PromoteTo(Simd<uint64_t, N, kPow2> dto, svuint16_t vfrom) {
  const RepartitionToWide<DFromV<decltype(vfrom)>> d2;
  return PromoteTo(dto, PromoteTo(d2, vfrom));
}
template <size_t N, int kPow2>
HWY_API svint64_t PromoteTo(Simd<int64_t, N, kPow2> dto, svint16_t vfrom) {
  const RepartitionToWide<DFromV<decltype(vfrom)>> d2;
  return PromoteTo(dto, PromoteTo(d2, vfrom));
}

// 3x
template <size_t N, int kPow2>
HWY_API svuint64_t PromoteTo(Simd<uint64_t, N, kPow2> dto, svuint8_t vfrom) {
  const RepartitionToNarrow<decltype(dto)> d4;
  const RepartitionToNarrow<decltype(d4)> d2;
  return PromoteTo(dto, PromoteTo(d4, PromoteTo(d2, vfrom)));
}
template <size_t N, int kPow2>
HWY_API svint64_t PromoteTo(Simd<int64_t, N, kPow2> dto, svint8_t vfrom) {
  const RepartitionToNarrow<decltype(dto)> d4;
  const RepartitionToNarrow<decltype(d4)> d2;
  return PromoteTo(dto, PromoteTo(d4, PromoteTo(d2, vfrom)));
}

// Sign change
template <class D, class V, HWY_IF_SIGNED_D(D), HWY_IF_UNSIGNED_V(V),
          HWY_IF_LANES_GT(sizeof(TFromD<D>), sizeof(TFromV<V>))>
HWY_API VFromD<D> PromoteTo(D di, V v) {
  const RebindToUnsigned<decltype(di)> du;
  return BitCast(di, PromoteTo(du, v));
}
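
// Example (illustrative; `bytes` is the caller's pointer): promotions widen
// the lower portion of the input, so the usual pattern is to load with a
// narrower Rebind of the destination tag, which has the same lane count:
//
//   const ScalableTag<uint32_t> d32;
//   const Rebind<uint8_t, decltype(d32)> d8;   // same lane count, u8 lanes
//   const svuint32_t widened = PromoteTo(d32, Load(d8, bytes));
//
// The signed-destination overload above simply promotes in the unsigned
// domain and BitCasts to the signed type.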
|
|
|
|
// ------------------------------ PromoteTo F
|
|
|
|
// Per-target flag to prevent generic_ops-inl.h from defining f16 conversions.
|
|
#ifdef HWY_NATIVE_F16C
|
|
#undef HWY_NATIVE_F16C
|
|
#else
|
|
#define HWY_NATIVE_F16C
|
|
#endif
|
|
|
|
// Unlike Highway's ZipLower, this returns the same type.
|
|
namespace detail {
|
|
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipLowerSame, zip1)
|
|
} // namespace detail
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svfloat32_t PromoteTo(Simd<float32_t, N, kPow2> /* d */,
|
|
const svfloat16_t v) {
|
|
// svcvt* expects inputs in even lanes, whereas Highway wants lower lanes, so
|
|
// first replicate each lane once.
|
|
const svfloat16_t vv = detail::ZipLowerSame(v, v);
|
|
return svcvt_f32_f16_x(detail::PTrue(Simd<float16_t, N, kPow2>()), vv);
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svfloat64_t PromoteTo(Simd<float64_t, N, kPow2> /* d */,
|
|
const svfloat32_t v) {
|
|
const svfloat32_t vv = detail::ZipLowerSame(v, v);
|
|
return svcvt_f64_f32_x(detail::PTrue(Simd<float32_t, N, kPow2>()), vv);
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svfloat64_t PromoteTo(Simd<float64_t, N, kPow2> /* d */,
|
|
const svint32_t v) {
|
|
const svint32_t vv = detail::ZipLowerSame(v, v);
|
|
return svcvt_f64_s32_x(detail::PTrue(Simd<int32_t, N, kPow2>()), vv);
|
|
}
|
|
|
|
// For 16-bit Compress
|
|
namespace detail {
|
|
HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi)
|
|
#undef HWY_SVE_PROMOTE_TO
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svfloat32_t PromoteUpperTo(Simd<float, N, kPow2> df, svfloat16_t v) {
|
|
const RebindToUnsigned<decltype(df)> du;
|
|
const RepartitionToNarrow<decltype(du)> dn;
|
|
return BitCast(df, PromoteUpperTo(du, BitCast(dn, v)));
|
|
}
|
|
|
|
} // namespace detail
|
|
|
|
// ------------------------------ DemoteTo U
|
|
|
|
namespace detail {
|
|
|
|
// Saturates unsigned vectors to half/quarter-width TN.
template <typename TN, class VU>
VU SaturateU(VU v) {
  return detail::MinN(v, static_cast<TFromV<VU>>(LimitsMax<TN>()));
}

// Saturates signed vectors to half/quarter-width TN.
template <typename TN, class VI>
VI SaturateI(VI v) {
  return detail::MinN(detail::MaxN(v, LimitsMin<TN>()), LimitsMax<TN>());
}

}  // namespace detail
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svint16_t v) {
|
|
#if HWY_SVE_HAVE_2
|
|
const svuint8_t vn = BitCast(dn, svqxtunb_s16(v));
|
|
#else
|
|
const DFromV<decltype(v)> di;
|
|
const RebindToUnsigned<decltype(di)> du;
|
|
using TN = TFromD<decltype(dn)>;
|
|
// First clamp negative numbers to zero and cast to unsigned.
|
|
const svuint16_t clamped = BitCast(du, detail::MaxN(v, 0));
|
|
// Saturate to unsigned-max and halve the width.
|
|
const svuint8_t vn = BitCast(dn, detail::SaturateU<TN>(clamped));
|
|
#endif
|
|
return svuzp1_u8(vn, vn);
|
|
}
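
// The non-SVE2 path operates lane-wise: MaxN(v, 0) clamps negative inputs to
// zero, SaturateU<uint8_t> then clamps to 255, and svuzp1_u8 packs the
// narrowed bytes into the lower half. Illustrative usage (hypothetical tags,
// not part of this header):
//   const ScalableTag<int16_t> d16;
//   const Rebind<uint8_t, decltype(d16)> d8;
//   DemoteTo(d8, Set(d16, 300));  // all lanes saturate to 255
//   DemoteTo(d8, Set(d16, -5));   // all lanes clamp to 0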
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint16_t DemoteTo(Simd<uint16_t, N, kPow2> dn, const svint32_t v) {
|
|
#if HWY_SVE_HAVE_2
|
|
const svuint16_t vn = BitCast(dn, svqxtunb_s32(v));
|
|
#else
|
|
const DFromV<decltype(v)> di;
|
|
const RebindToUnsigned<decltype(di)> du;
|
|
using TN = TFromD<decltype(dn)>;
|
|
// First clamp negative numbers to zero and cast to unsigned.
|
|
const svuint32_t clamped = BitCast(du, detail::MaxN(v, 0));
|
|
// Saturate to unsigned-max and halve the width.
|
|
const svuint16_t vn = BitCast(dn, detail::SaturateU<TN>(clamped));
|
|
#endif
|
|
return svuzp1_u16(vn, vn);
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svint32_t v) {
|
|
const DFromV<decltype(v)> di;
|
|
const RebindToUnsigned<decltype(di)> du;
|
|
const RepartitionToNarrow<decltype(du)> d2;
|
|
#if HWY_SVE_HAVE_2
|
|
const svuint16_t cast16 = BitCast(d2, svqxtnb_u16(svqxtunb_s32(v)));
|
|
#else
|
|
using TN = TFromD<decltype(dn)>;
|
|
// First clamp negative numbers to zero and cast to unsigned.
|
|
const svuint32_t clamped = BitCast(du, detail::MaxN(v, 0));
|
|
// Saturate to unsigned-max and quarter the width.
|
|
const svuint16_t cast16 = BitCast(d2, detail::SaturateU<TN>(clamped));
|
|
#endif
|
|
const svuint8_t x2 = BitCast(dn, svuzp1_u16(cast16, cast16));
|
|
return svuzp1_u8(x2, x2);
|
|
}
|
|
|
|
HWY_API svuint8_t U8FromU32(const svuint32_t v) {
|
|
const DFromV<svuint32_t> du32;
|
|
const RepartitionToNarrow<decltype(du32)> du16;
|
|
const RepartitionToNarrow<decltype(du16)> du8;
|
|
|
|
const svuint16_t cast16 = BitCast(du16, v);
|
|
const svuint16_t x2 = svuzp1_u16(cast16, cast16);
|
|
const svuint8_t cast8 = BitCast(du8, x2);
|
|
return svuzp1_u8(cast8, cast8);
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svuint16_t v) {
|
|
#if HWY_SVE_HAVE_2
|
|
const svuint8_t vn = BitCast(dn, svqxtnb_u16(v));
|
|
#else
|
|
using TN = TFromD<decltype(dn)>;
|
|
const svuint8_t vn = BitCast(dn, detail::SaturateU<TN>(v));
|
|
#endif
|
|
return svuzp1_u8(vn, vn);
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint16_t DemoteTo(Simd<uint16_t, N, kPow2> dn, const svuint32_t v) {
|
|
#if HWY_SVE_HAVE_2
|
|
const svuint16_t vn = BitCast(dn, svqxtnb_u32(v));
|
|
#else
|
|
using TN = TFromD<decltype(dn)>;
|
|
const svuint16_t vn = BitCast(dn, detail::SaturateU<TN>(v));
|
|
#endif
|
|
return svuzp1_u16(vn, vn);
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svuint32_t v) {
|
|
using TN = TFromD<decltype(dn)>;
|
|
return U8FromU32(detail::SaturateU<TN>(v));
|
|
}
|
|
|
|
// ------------------------------ Truncations
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint8_t TruncateTo(Simd<uint8_t, N, kPow2> /* tag */,
|
|
const svuint64_t v) {
|
|
const DFromV<svuint8_t> d;
|
|
const svuint8_t v1 = BitCast(d, v);
|
|
const svuint8_t v2 = svuzp1_u8(v1, v1);
|
|
const svuint8_t v3 = svuzp1_u8(v2, v2);
|
|
return svuzp1_u8(v3, v3);
|
|
}
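
// Each svuzp1_u8 keeps the even-indexed bytes, so three applications keep
// every 8th byte, i.e. the least-significant byte of each u64 lane
// (little-endian), packed into the lowest lanes of the result.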
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint16_t TruncateTo(Simd<uint16_t, N, kPow2> /* tag */,
|
|
const svuint64_t v) {
|
|
const DFromV<svuint16_t> d;
|
|
const svuint16_t v1 = BitCast(d, v);
|
|
const svuint16_t v2 = svuzp1_u16(v1, v1);
|
|
return svuzp1_u16(v2, v2);
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint32_t TruncateTo(Simd<uint32_t, N, kPow2> /* tag */,
|
|
const svuint64_t v) {
|
|
const DFromV<svuint32_t> d;
|
|
const svuint32_t v1 = BitCast(d, v);
|
|
return svuzp1_u32(v1, v1);
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint8_t TruncateTo(Simd<uint8_t, N, kPow2> /* tag */,
|
|
const svuint32_t v) {
|
|
const DFromV<svuint8_t> d;
|
|
const svuint8_t v1 = BitCast(d, v);
|
|
const svuint8_t v2 = svuzp1_u8(v1, v1);
|
|
return svuzp1_u8(v2, v2);
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint16_t TruncateTo(Simd<uint16_t, N, kPow2> /* tag */,
|
|
const svuint32_t v) {
|
|
const DFromV<svuint16_t> d;
|
|
const svuint16_t v1 = BitCast(d, v);
|
|
return svuzp1_u16(v1, v1);
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint8_t TruncateTo(Simd<uint8_t, N, kPow2> /* tag */,
|
|
const svuint16_t v) {
|
|
const DFromV<svuint8_t> d;
|
|
const svuint8_t v1 = BitCast(d, v);
|
|
return svuzp1_u8(v1, v1);
|
|
}
|
|
|
|
// ------------------------------ DemoteTo I
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svint8_t DemoteTo(Simd<int8_t, N, kPow2> dn, const svint16_t v) {
|
|
#if HWY_SVE_HAVE_2
|
|
const svint8_t vn = BitCast(dn, svqxtnb_s16(v));
|
|
#else
|
|
using TN = TFromD<decltype(dn)>;
|
|
const svint8_t vn = BitCast(dn, detail::SaturateI<TN>(v));
|
|
#endif
|
|
return svuzp1_s8(vn, vn);
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svint16_t DemoteTo(Simd<int16_t, N, kPow2> dn, const svint32_t v) {
|
|
#if HWY_SVE_HAVE_2
|
|
const svint16_t vn = BitCast(dn, svqxtnb_s32(v));
|
|
#else
|
|
using TN = TFromD<decltype(dn)>;
|
|
const svint16_t vn = BitCast(dn, detail::SaturateI<TN>(v));
|
|
#endif
|
|
return svuzp1_s16(vn, vn);
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svint8_t DemoteTo(Simd<int8_t, N, kPow2> dn, const svint32_t v) {
|
|
const RepartitionToWide<decltype(dn)> d2;
|
|
#if HWY_SVE_HAVE_2
|
|
const svint16_t cast16 = BitCast(d2, svqxtnb_s16(svqxtnb_s32(v)));
|
|
#else
|
|
using TN = TFromD<decltype(dn)>;
|
|
const svint16_t cast16 = BitCast(d2, detail::SaturateI<TN>(v));
|
|
#endif
|
|
const svint8_t v2 = BitCast(dn, svuzp1_s16(cast16, cast16));
|
|
return BitCast(dn, svuzp1_s8(v2, v2));
|
|
}
|
|
|
|
// ------------------------------ I64/U64 DemoteTo
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svint32_t DemoteTo(Simd<int32_t, N, kPow2> dn, const svint64_t v) {
|
|
const Rebind<uint64_t, decltype(dn)> du64;
|
|
const RebindToUnsigned<decltype(dn)> dn_u;
|
|
#if HWY_SVE_HAVE_2
|
|
const svuint64_t vn = BitCast(du64, svqxtnb_s64(v));
|
|
#else
|
|
using TN = TFromD<decltype(dn)>;
|
|
const svuint64_t vn = BitCast(du64, detail::SaturateI<TN>(v));
|
|
#endif
|
|
return BitCast(dn, TruncateTo(dn_u, vn));
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svint16_t DemoteTo(Simd<int16_t, N, kPow2> dn, const svint64_t v) {
|
|
const Rebind<uint64_t, decltype(dn)> du64;
|
|
const RebindToUnsigned<decltype(dn)> dn_u;
|
|
#if HWY_SVE_HAVE_2
|
|
const svuint64_t vn = BitCast(du64, svqxtnb_s32(svqxtnb_s64(v)));
|
|
#else
|
|
using TN = TFromD<decltype(dn)>;
|
|
const svuint64_t vn = BitCast(du64, detail::SaturateI<TN>(v));
|
|
#endif
|
|
return BitCast(dn, TruncateTo(dn_u, vn));
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svint8_t DemoteTo(Simd<int8_t, N, kPow2> dn, const svint64_t v) {
|
|
const Rebind<uint64_t, decltype(dn)> du64;
|
|
const RebindToUnsigned<decltype(dn)> dn_u;
|
|
using TN = TFromD<decltype(dn)>;
|
|
const svuint64_t vn = BitCast(du64, detail::SaturateI<TN>(v));
|
|
return BitCast(dn, TruncateTo(dn_u, vn));
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint32_t DemoteTo(Simd<uint32_t, N, kPow2> dn, const svint64_t v) {
|
|
const Rebind<uint64_t, decltype(dn)> du64;
|
|
#if HWY_SVE_HAVE_2
|
|
const svuint64_t vn = BitCast(du64, svqxtunb_s64(v));
|
|
#else
|
|
using TN = TFromD<decltype(dn)>;
|
|
// First clamp negative numbers to zero and cast to unsigned.
|
|
const svuint64_t clamped = BitCast(du64, detail::MaxN(v, 0));
|
|
// Saturate to unsigned-max
|
|
const svuint64_t vn = detail::SaturateU<TN>(clamped);
|
|
#endif
|
|
return TruncateTo(dn, vn);
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint16_t DemoteTo(Simd<uint16_t, N, kPow2> dn, const svint64_t v) {
|
|
const Rebind<uint64_t, decltype(dn)> du64;
|
|
#if HWY_SVE_HAVE_2
|
|
const svuint64_t vn = BitCast(du64, svqxtnb_u32(svqxtunb_s64(v)));
|
|
#else
|
|
using TN = TFromD<decltype(dn)>;
|
|
// First clamp negative numbers to zero and cast to unsigned.
|
|
const svuint64_t clamped = BitCast(du64, detail::MaxN(v, 0));
|
|
// Saturate to unsigned-max
|
|
const svuint64_t vn = detail::SaturateU<TN>(clamped);
|
|
#endif
|
|
return TruncateTo(dn, vn);
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svint64_t v) {
|
|
const Rebind<uint64_t, decltype(dn)> du64;
|
|
using TN = TFromD<decltype(dn)>;
|
|
// First clamp negative numbers to zero and cast to unsigned.
|
|
const svuint64_t clamped = BitCast(du64, detail::MaxN(v, 0));
|
|
// Saturate to unsigned-max
|
|
const svuint64_t vn = detail::SaturateU<TN>(clamped);
|
|
return TruncateTo(dn, vn);
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint32_t DemoteTo(Simd<uint32_t, N, kPow2> dn, const svuint64_t v) {
|
|
const Rebind<uint64_t, decltype(dn)> du64;
|
|
#if HWY_SVE_HAVE_2
|
|
const svuint64_t vn = BitCast(du64, svqxtnb_u64(v));
|
|
#else
|
|
using TN = TFromD<decltype(dn)>;
|
|
const svuint64_t vn = BitCast(du64, detail::SaturateU<TN>(v));
|
|
#endif
|
|
return TruncateTo(dn, vn);
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint16_t DemoteTo(Simd<uint16_t, N, kPow2> dn, const svuint64_t v) {
|
|
const Rebind<uint64_t, decltype(dn)> du64;
|
|
#if HWY_SVE_HAVE_2
|
|
const svuint64_t vn = BitCast(du64, svqxtnb_u32(svqxtnb_u64(v)));
|
|
#else
|
|
using TN = TFromD<decltype(dn)>;
|
|
const svuint64_t vn = BitCast(du64, detail::SaturateU<TN>(v));
|
|
#endif
|
|
return TruncateTo(dn, vn);
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svuint64_t v) {
|
|
const Rebind<uint64_t, decltype(dn)> du64;
|
|
using TN = TFromD<decltype(dn)>;
|
|
const svuint64_t vn = BitCast(du64, detail::SaturateU<TN>(v));
|
|
return TruncateTo(dn, vn);
|
|
}
|
|
|
|
// ------------------------------ ConcatEven/ConcatOdd
|
|
|
|
// WARNING: the upper half of these needs fixing up (uzp1/uzp2 use the
// full vector length, not rounded down to a power of two as we require).
namespace detail {
|
|
|
|
#define HWY_SVE_CONCAT_EVERY_SECOND(BASE, CHAR, BITS, HALF, NAME, OP) \
|
|
HWY_INLINE HWY_SVE_V(BASE, BITS) \
|
|
NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \
|
|
return sv##OP##_##CHAR##BITS(lo, hi); \
|
|
}
|
|
HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenFull, uzp1)
|
|
HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddFull, uzp2)
|
|
#if defined(__ARM_FEATURE_SVE_MATMUL_FP64)
|
|
HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenBlocks, uzp1q)
|
|
HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddBlocks, uzp2q)
|
|
#endif
|
|
#undef HWY_SVE_CONCAT_EVERY_SECOND
|
|
|
|
// Used to slide up / shift whole register left; mask indicates which range
// to take from lo, and the rest is filled from hi starting at its lowest.
|
|
#define HWY_SVE_SPLICE(BASE, CHAR, BITS, HALF, NAME, OP) \
|
|
HWY_API HWY_SVE_V(BASE, BITS) NAME( \
|
|
HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo, svbool_t mask) { \
|
|
return sv##OP##_##CHAR##BITS(mask, lo, hi); \
|
|
}
|
|
HWY_SVE_FOREACH(HWY_SVE_SPLICE, Splice, splice)
|
|
#undef HWY_SVE_SPLICE
|
|
|
|
} // namespace detail
|
|
|
|
template <class D>
|
|
HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) {
|
|
#if HWY_SVE_IS_POW2
|
|
if (detail::IsFull(d)) return detail::ConcatOddFull(hi, lo);
|
|
#endif
|
|
const VFromD<D> hi_odd = detail::ConcatOddFull(hi, hi);
|
|
const VFromD<D> lo_odd = detail::ConcatOddFull(lo, lo);
|
|
return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2));
|
|
}
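
// Example with Lanes(d) == 4 on a longer hardware vector: hi = [h0 h1 h2 h3],
// lo = [l0 l1 l2 l3]. ConcatOddFull(hi, hi) = [h1 h3 ...] and
// ConcatOddFull(lo, lo) = [l1 l3 ...]; Splice with FirstN(d, 2) then yields
// [l1 l3 h1 h3], i.e. the odd lanes of lo followed by the odd lanes of hi.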
|
|
|
|
template <class D>
|
|
HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) {
|
|
#if HWY_SVE_IS_POW2
|
|
if (detail::IsFull(d)) return detail::ConcatEvenFull(hi, lo);
|
|
#endif
|
|
const VFromD<D> hi_odd = detail::ConcatEvenFull(hi, hi);
|
|
const VFromD<D> lo_odd = detail::ConcatEvenFull(lo, lo);
|
|
return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2));
|
|
}
|
|
|
|
// ------------------------------ DemoteTo F
|
|
|
|
// We already toggled HWY_NATIVE_F16C above.
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svfloat16_t DemoteTo(Simd<float16_t, N, kPow2> d, const svfloat32_t v) {
|
|
const svfloat16_t in_even = svcvt_f16_f32_x(detail::PTrue(d), v);
|
|
return detail::ConcatEvenFull(in_even,
|
|
in_even); // lower half
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API VBF16 DemoteTo(Simd<bfloat16_t, N, kPow2> dbf16, svfloat32_t v) {
|
|
const svuint16_t in_even = BitCast(ScalableTag<uint16_t>(), v);
|
|
return BitCast(dbf16, detail::ConcatOddFull(in_even, in_even)); // lower half
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svfloat32_t DemoteTo(Simd<float32_t, N, kPow2> d, const svfloat64_t v) {
|
|
const svfloat32_t in_even = svcvt_f32_f64_x(detail::PTrue(d), v);
|
|
return detail::ConcatEvenFull(in_even,
|
|
in_even); // lower half
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svint32_t DemoteTo(Simd<int32_t, N, kPow2> d, const svfloat64_t v) {
|
|
const svint32_t in_even = svcvt_s32_f64_x(detail::PTrue(d), v);
|
|
return detail::ConcatEvenFull(in_even,
|
|
in_even); // lower half
|
|
}
|
|
|
|
// ------------------------------ ConvertTo F
|
|
|
|
#define HWY_SVE_CONVERT(BASE, CHAR, BITS, HALF, NAME, OP) \
|
|
/* signed integers */ \
|
|
template <size_t N, int kPow2> \
|
|
HWY_API HWY_SVE_V(BASE, BITS) \
|
|
NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(int, BITS) v) { \
|
|
return sv##OP##_##CHAR##BITS##_s##BITS##_x(HWY_SVE_PTRUE(BITS), v); \
|
|
} \
|
|
/* unsigned integers */ \
|
|
template <size_t N, int kPow2> \
|
|
HWY_API HWY_SVE_V(BASE, BITS) \
|
|
NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(uint, BITS) v) { \
|
|
return sv##OP##_##CHAR##BITS##_u##BITS##_x(HWY_SVE_PTRUE(BITS), v); \
|
|
} \
|
|
/* Truncates (rounds toward zero). */ \
|
|
template <size_t N, int kPow2> \
|
|
HWY_API HWY_SVE_V(int, BITS) \
|
|
NAME(HWY_SVE_D(int, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \
|
|
return sv##OP##_s##BITS##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \
|
|
}
|
|
|
|
// API only requires f32 but we provide f64 for use by Iota.
|
|
HWY_SVE_FOREACH_F(HWY_SVE_CONVERT, ConvertTo, cvt)
|
|
#undef HWY_SVE_CONVERT
|
|
|
|
// ------------------------------ NearestInt (Round, ConvertTo)
|
|
template <class VF, class DI = RebindToSigned<DFromV<VF>>>
|
|
HWY_API VFromD<DI> NearestInt(VF v) {
|
|
// No single instruction, round then truncate.
|
|
return ConvertTo(DI(), Round(v));
|
|
}
|
|
|
|
// ------------------------------ Iota (Add, ConvertTo)
|
|
|
|
#define HWY_SVE_IOTA(BASE, CHAR, BITS, HALF, NAME, OP) \
|
|
template <size_t N, int kPow2> \
|
|
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \
|
|
HWY_SVE_T(BASE, BITS) first) { \
|
|
return sv##OP##_##CHAR##BITS(first, 1); \
|
|
}
|
|
|
|
HWY_SVE_FOREACH_UI(HWY_SVE_IOTA, Iota, index)
|
|
#undef HWY_SVE_IOTA
|
|
|
|
template <class D, HWY_IF_FLOAT_D(D)>
|
|
HWY_API VFromD<D> Iota(const D d, TFromD<D> first) {
|
|
const RebindToSigned<D> di;
|
|
return detail::AddN(ConvertTo(d, Iota(di, 0)), first);
|
|
}
|
|
|
|
// ------------------------------ InterleaveLower
|
|
|
|
template <class D, class V>
|
|
HWY_API V InterleaveLower(D d, const V a, const V b) {
|
|
static_assert(IsSame<TFromD<D>, TFromV<V>>(), "D/V mismatch");
|
|
#if HWY_TARGET == HWY_SVE2_128
|
|
(void)d;
|
|
return detail::ZipLowerSame(a, b);
|
|
#else
|
|
// Move lower halves of blocks to lower half of vector.
|
|
const Repartition<uint64_t, decltype(d)> d64;
|
|
const auto a64 = BitCast(d64, a);
|
|
const auto b64 = BitCast(d64, b);
|
|
const auto a_blocks = detail::ConcatEvenFull(a64, a64); // lower half
|
|
const auto b_blocks = detail::ConcatEvenFull(b64, b64);
|
|
return detail::ZipLowerSame(BitCast(d, a_blocks), BitCast(d, b_blocks));
|
|
#endif
|
|
}
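
// Rationale for the non-128-bit path: zip1 alone would interleave across
// 128-bit block boundaries. BitCasting to u64 and taking ConcatEvenFull
// gathers the lower half of every block into the lower half of the vector,
// so the subsequent ZipLowerSame produces the per-block interleaving that
// the Highway API specifies.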
|
|
|
|
template <class V>
|
|
HWY_API V InterleaveLower(const V a, const V b) {
|
|
return InterleaveLower(DFromV<V>(), a, b);
|
|
}
|
|
|
|
// ------------------------------ InterleaveUpper
|
|
|
|
// Only use zip2 if vectors are a power of two; otherwise getting the actual
// "upper half" requires MaskUpperHalf.
|
|
#if HWY_TARGET == HWY_SVE2_128
|
|
namespace detail {
|
|
// Unlike Highway's ZipUpper, this returns the same type.
|
|
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipUpperSame, zip2)
|
|
} // namespace detail
|
|
#endif
|
|
|
|
// Full vector: guaranteed to have at least one block
|
|
template <class D, class V = VFromD<D>,
|
|
hwy::EnableIf<detail::IsFull(D())>* = nullptr>
|
|
HWY_API V InterleaveUpper(D d, const V a, const V b) {
|
|
#if HWY_TARGET == HWY_SVE2_128
|
|
(void)d;
|
|
return detail::ZipUpperSame(a, b);
|
|
#else
|
|
// Move upper halves of blocks to lower half of vector.
|
|
const Repartition<uint64_t, decltype(d)> d64;
|
|
const auto a64 = BitCast(d64, a);
|
|
const auto b64 = BitCast(d64, b);
|
|
const auto a_blocks = detail::ConcatOddFull(a64, a64); // lower half
|
|
const auto b_blocks = detail::ConcatOddFull(b64, b64);
|
|
return detail::ZipLowerSame(BitCast(d, a_blocks), BitCast(d, b_blocks));
|
|
#endif
|
|
}
|
|
|
|
// Capped/fraction: need runtime check
|
|
template <class D, class V = VFromD<D>,
|
|
hwy::EnableIf<!detail::IsFull(D())>* = nullptr>
|
|
HWY_API V InterleaveUpper(D d, const V a, const V b) {
|
|
// Less than one block: treat as capped
|
|
if (Lanes(d) * sizeof(TFromD<D>) < 16) {
|
|
const Half<decltype(d)> d2;
|
|
return InterleaveLower(d, UpperHalf(d2, a), UpperHalf(d2, b));
|
|
}
|
|
return InterleaveUpper(DFromV<V>(), a, b);
|
|
}
|
|
|
|
// ------------------------------ Per4LaneBlockShuffle
|
|
|
|
namespace detail {
|
|
|
|
template <size_t kLaneSize, size_t kVectSize, class V,
|
|
HWY_IF_NOT_T_SIZE_V(V, 8)>
|
|
HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x88> /*idx_3210_tag*/,
|
|
hwy::SizeTag<kLaneSize> /*lane_size_tag*/,
|
|
hwy::SizeTag<kVectSize> /*vect_size_tag*/,
|
|
V v) {
|
|
const DFromV<decltype(v)> d;
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
const RepartitionToWide<decltype(du)> dw;
|
|
|
|
const auto evens = BitCast(dw, ConcatEvenFull(v, v));
|
|
return BitCast(d, ZipLowerSame(evens, evens));
|
|
}
|
|
|
|
template <size_t kLaneSize, size_t kVectSize, class V,
|
|
HWY_IF_NOT_T_SIZE_V(V, 8)>
|
|
HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xDD> /*idx_3210_tag*/,
|
|
hwy::SizeTag<kLaneSize> /*lane_size_tag*/,
|
|
hwy::SizeTag<kVectSize> /*vect_size_tag*/,
|
|
V v) {
|
|
const DFromV<decltype(v)> d;
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
const RepartitionToWide<decltype(du)> dw;
|
|
|
|
const auto odds = BitCast(dw, ConcatOddFull(v, v));
|
|
return BitCast(d, ZipLowerSame(odds, odds));
|
|
}
|
|
|
|
} // namespace detail
|
|
|
|
// ================================================== COMBINE
|
|
|
|
namespace detail {
|
|
|
|
#if HWY_TARGET == HWY_SVE_256 || HWY_IDE
|
|
template <class D, HWY_IF_T_SIZE_D(D, 1)>
|
|
svbool_t MaskLowerHalf(D d) {
|
|
switch (Lanes(d)) {
|
|
case 32:
|
|
return svptrue_pat_b8(SV_VL16);
|
|
case 16:
|
|
return svptrue_pat_b8(SV_VL8);
|
|
case 8:
|
|
return svptrue_pat_b8(SV_VL4);
|
|
case 4:
|
|
return svptrue_pat_b8(SV_VL2);
|
|
default:
|
|
return svptrue_pat_b8(SV_VL1);
|
|
}
|
|
}
|
|
template <class D, HWY_IF_T_SIZE_D(D, 2)>
|
|
svbool_t MaskLowerHalf(D d) {
|
|
switch (Lanes(d)) {
|
|
case 16:
|
|
return svptrue_pat_b16(SV_VL8);
|
|
case 8:
|
|
return svptrue_pat_b16(SV_VL4);
|
|
case 4:
|
|
return svptrue_pat_b16(SV_VL2);
|
|
default:
|
|
return svptrue_pat_b16(SV_VL1);
|
|
}
|
|
}
|
|
template <class D, HWY_IF_T_SIZE_D(D, 4)>
|
|
svbool_t MaskLowerHalf(D d) {
|
|
switch (Lanes(d)) {
|
|
case 8:
|
|
return svptrue_pat_b32(SV_VL4);
|
|
case 4:
|
|
return svptrue_pat_b32(SV_VL2);
|
|
default:
|
|
return svptrue_pat_b32(SV_VL1);
|
|
}
|
|
}
|
|
template <class D, HWY_IF_T_SIZE_D(D, 8)>
|
|
svbool_t MaskLowerHalf(D d) {
|
|
switch (Lanes(d)) {
|
|
case 4:
|
|
return svptrue_pat_b64(SV_VL2);
|
|
default:
|
|
return svptrue_pat_b64(SV_VL1);
|
|
}
|
|
}
|
|
#endif
|
|
#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE
|
|
template <class D, HWY_IF_T_SIZE_D(D, 1)>
|
|
svbool_t MaskLowerHalf(D d) {
|
|
switch (Lanes(d)) {
|
|
case 16:
|
|
return svptrue_pat_b8(SV_VL8);
|
|
case 8:
|
|
return svptrue_pat_b8(SV_VL4);
|
|
case 4:
|
|
return svptrue_pat_b8(SV_VL2);
|
|
case 2:
|
|
case 1:
|
|
default:
|
|
return svptrue_pat_b8(SV_VL1);
|
|
}
|
|
}
|
|
template <class D, HWY_IF_T_SIZE_D(D, 2)>
|
|
svbool_t MaskLowerHalf(D d) {
|
|
switch (Lanes(d)) {
|
|
case 8:
|
|
return svptrue_pat_b16(SV_VL4);
|
|
case 4:
|
|
return svptrue_pat_b16(SV_VL2);
|
|
case 2:
|
|
case 1:
|
|
default:
|
|
return svptrue_pat_b16(SV_VL1);
|
|
}
|
|
}
|
|
template <class D, HWY_IF_T_SIZE_D(D, 4)>
|
|
svbool_t MaskLowerHalf(D d) {
|
|
return svptrue_pat_b32(Lanes(d) == 4 ? SV_VL2 : SV_VL1);
|
|
}
|
|
template <class D, HWY_IF_T_SIZE_D(D, 8)>
|
|
svbool_t MaskLowerHalf(D /*d*/) {
|
|
return svptrue_pat_b64(SV_VL1);
|
|
}
|
|
#endif // HWY_TARGET == HWY_SVE2_128
|
|
#if HWY_TARGET != HWY_SVE_256 && HWY_TARGET != HWY_SVE2_128
|
|
template <class D>
|
|
svbool_t MaskLowerHalf(D d) {
|
|
return FirstN(d, Lanes(d) / 2);
|
|
}
|
|
#endif
|
|
|
|
template <class D>
|
|
svbool_t MaskUpperHalf(D d) {
|
|
// TODO(janwas): WHILEGE on SVE2
|
|
if (HWY_SVE_IS_POW2 && IsFull(d)) {
|
|
return Not(MaskLowerHalf(d));
|
|
}
|
|
|
|
// For Splice to work as intended, make sure bits above Lanes(d) are zero.
|
|
return AndNot(MaskLowerHalf(d), detail::MakeMask(d));
|
|
}
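
// Example: for a full power-of-two vector with 8 lanes, MaskLowerHalf covers
// lanes [0, 4) and Not() of it is exactly lanes [4, 8). For fractional
// vectors, Not() would also set bits at and beyond Lanes(d), so AndNot with
// MakeMask(d) restricts the result to the valid lanes, as Splice requires.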
|
|
|
|
// Right-shift vector pair by constexpr; can be used to slide down (=N) or up
// (=Lanes()-N).
|
|
#define HWY_SVE_EXT(BASE, CHAR, BITS, HALF, NAME, OP) \
|
|
template <size_t kIndex> \
|
|
HWY_API HWY_SVE_V(BASE, BITS) \
|
|
NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \
|
|
return sv##OP##_##CHAR##BITS(lo, hi, kIndex); \
|
|
}
|
|
HWY_SVE_FOREACH(HWY_SVE_EXT, Ext, ext)
|
|
#undef HWY_SVE_EXT
|
|
|
|
} // namespace detail
|
|
|
|
// ------------------------------ ConcatUpperLower
|
|
template <class D, class V>
|
|
HWY_API V ConcatUpperLower(const D d, const V hi, const V lo) {
|
|
return IfThenElse(detail::MaskLowerHalf(d), lo, hi);
|
|
}
|
|
|
|
// ------------------------------ ConcatLowerLower
|
|
template <class D, class V>
|
|
HWY_API V ConcatLowerLower(const D d, const V hi, const V lo) {
|
|
if (detail::IsFull(d)) {
|
|
#if defined(__ARM_FEATURE_SVE_MATMUL_FP64) && HWY_TARGET == HWY_SVE_256
|
|
return detail::ConcatEvenBlocks(hi, lo);
|
|
#endif
|
|
#if HWY_TARGET == HWY_SVE2_128
|
|
const Repartition<uint64_t, D> du64;
|
|
const auto lo64 = BitCast(du64, lo);
|
|
return BitCast(d, InterleaveLower(du64, lo64, BitCast(du64, hi)));
|
|
#endif
|
|
}
|
|
return detail::Splice(hi, lo, detail::MaskLowerHalf(d));
|
|
}
|
|
|
|
// ------------------------------ ConcatLowerUpper
|
|
template <class D, class V>
|
|
HWY_API V ConcatLowerUpper(const D d, const V hi, const V lo) {
|
|
#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 // constexpr Lanes
|
|
if (detail::IsFull(d)) {
|
|
return detail::Ext<Lanes(d) / 2>(hi, lo);
|
|
}
|
|
#endif
|
|
return detail::Splice(hi, lo, detail::MaskUpperHalf(d));
|
|
}
|
|
|
|
// ------------------------------ ConcatUpperUpper
|
|
template <class D, class V>
|
|
HWY_API V ConcatUpperUpper(const D d, const V hi, const V lo) {
|
|
if (detail::IsFull(d)) {
|
|
#if defined(__ARM_FEATURE_SVE_MATMUL_FP64) && HWY_TARGET == HWY_SVE_256
|
|
return detail::ConcatOddBlocks(hi, lo);
|
|
#endif
|
|
#if HWY_TARGET == HWY_SVE2_128
|
|
const Repartition<uint64_t, D> du64;
|
|
const auto lo64 = BitCast(du64, lo);
|
|
return BitCast(d, InterleaveUpper(du64, lo64, BitCast(du64, hi)));
|
|
#endif
|
|
}
|
|
const svbool_t mask_upper = detail::MaskUpperHalf(d);
|
|
const V lo_upper = detail::Splice(lo, lo, mask_upper);
|
|
return IfThenElse(mask_upper, hi, lo_upper);
|
|
}
|
|
|
|
// ------------------------------ Combine
|
|
template <class D, class V2>
|
|
HWY_API VFromD<D> Combine(const D d, const V2 hi, const V2 lo) {
|
|
return ConcatLowerLower(d, hi, lo);
|
|
}
|
|
|
|
// ------------------------------ ZeroExtendVector
|
|
template <class D, class V>
|
|
HWY_API V ZeroExtendVector(const D d, const V lo) {
|
|
return Combine(d, Zero(Half<D>()), lo);
|
|
}
|
|
|
|
// ------------------------------ Lower/UpperHalf
|
|
|
|
template <class D2, class V>
|
|
HWY_API V LowerHalf(D2 /* tag */, const V v) {
|
|
return v;
|
|
}
|
|
|
|
template <class V>
|
|
HWY_API V LowerHalf(const V v) {
|
|
return v;
|
|
}
|
|
|
|
template <class DH, class V>
|
|
HWY_API V UpperHalf(const DH dh, const V v) {
|
|
const Twice<decltype(dh)> d;
|
|
// Cast so that we support bfloat16_t.
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
const VFromD<decltype(du)> vu = BitCast(du, v);
|
|
#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 // constexpr Lanes
|
|
return BitCast(d, detail::Ext<Lanes(dh)>(vu, vu));
|
|
#else
|
|
const MFromD<decltype(du)> mask = detail::MaskUpperHalf(du);
|
|
return BitCast(d, detail::Splice(vu, vu, mask));
|
|
#endif
|
|
}
|
|
|
|
// ================================================== REDUCE
|
|
|
|
// These return T, whereas the Highway op returns a broadcasted vector.
|
|
namespace detail {
|
|
#define HWY_SVE_REDUCE_ADD(BASE, CHAR, BITS, HALF, NAME, OP) \
|
|
HWY_API HWY_SVE_T(BASE, BITS) NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) v) { \
|
|
/* The intrinsic returns [u]int64_t; truncate to T so we can broadcast. */ \
|
|
using T = HWY_SVE_T(BASE, BITS); \
|
|
using TU = MakeUnsigned<T>; \
|
|
constexpr uint64_t kMask = LimitsMax<TU>(); \
|
|
return static_cast<T>(static_cast<TU>( \
|
|
static_cast<uint64_t>(sv##OP##_##CHAR##BITS(pg, v)) & kMask)); \
|
|
}
|
|
|
|
#define HWY_SVE_REDUCE(BASE, CHAR, BITS, HALF, NAME, OP) \
|
|
HWY_API HWY_SVE_T(BASE, BITS) NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) v) { \
|
|
return sv##OP##_##CHAR##BITS(pg, v); \
|
|
}
|
|
|
|
HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE_ADD, SumOfLanesM, addv)
|
|
HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, SumOfLanesM, addv)
|
|
|
|
HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MinOfLanesM, minv)
|
|
HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MaxOfLanesM, maxv)
|
|
// minnmv/maxnmv return NaN only if all lanes are NaN
|
|
HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MinOfLanesM, minnmv)
|
|
HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MaxOfLanesM, maxnmv)
|
|
|
|
#undef HWY_SVE_REDUCE
|
|
#undef HWY_SVE_REDUCE_ADD
|
|
} // namespace detail
|
|
|
|
template <class D, class V>
|
|
V SumOfLanes(D d, V v) {
|
|
return Set(d, detail::SumOfLanesM(detail::MakeMask(d), v));
|
|
}
|
|
|
|
template <class D, class V>
|
|
TFromV<V> ReduceSum(D d, V v) {
|
|
return detail::SumOfLanesM(detail::MakeMask(d), v);
|
|
}
|
|
|
|
template <class D, class V>
|
|
V MinOfLanes(D d, V v) {
|
|
return Set(d, detail::MinOfLanesM(detail::MakeMask(d), v));
|
|
}
|
|
|
|
template <class D, class V>
|
|
V MaxOfLanes(D d, V v) {
|
|
return Set(d, detail::MaxOfLanesM(detail::MakeMask(d), v));
|
|
}
|
|
|
|
// ================================================== SWIZZLE
|
|
|
|
// ------------------------------ GetLane
|
|
|
|
namespace detail {
|
|
#define HWY_SVE_GET_LANE(BASE, CHAR, BITS, HALF, NAME, OP) \
|
|
HWY_INLINE HWY_SVE_T(BASE, BITS) \
|
|
NAME(HWY_SVE_V(BASE, BITS) v, svbool_t mask) { \
|
|
return sv##OP##_##CHAR##BITS(mask, v); \
|
|
}
|
|
|
|
HWY_SVE_FOREACH(HWY_SVE_GET_LANE, GetLaneM, lasta)
|
|
HWY_SVE_FOREACH(HWY_SVE_GET_LANE, ExtractLastMatchingLaneM, lastb)
|
|
#undef HWY_SVE_GET_LANE
|
|
} // namespace detail
|
|
|
|
template <class V>
|
|
HWY_API TFromV<V> GetLane(V v) {
|
|
return detail::GetLaneM(v, detail::PFalse());
|
|
}
|
|
|
|
// ------------------------------ ExtractLane
|
|
template <class V>
|
|
HWY_API TFromV<V> ExtractLane(V v, size_t i) {
|
|
return detail::GetLaneM(v, FirstN(DFromV<V>(), i));
|
|
}
|
|
|
|
// ------------------------------ InsertLane (IfThenElse)
|
|
template <class V>
|
|
HWY_API V InsertLane(const V v, size_t i, TFromV<V> t) {
|
|
const DFromV<V> d;
|
|
const auto is_i = detail::EqN(Iota(d, 0), static_cast<TFromV<V>>(i));
|
|
return IfThenElse(RebindMask(d, is_i), Set(d, t), v);
|
|
}
|
|
|
|
// ------------------------------ DupEven
|
|
|
|
namespace detail {
|
|
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveEven, trn1)
|
|
} // namespace detail
|
|
|
|
template <class V>
|
|
HWY_API V DupEven(const V v) {
|
|
return detail::InterleaveEven(v, v);
|
|
}
|
|
|
|
// ------------------------------ DupOdd
|
|
|
|
namespace detail {
|
|
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveOdd, trn2)
|
|
} // namespace detail
|
|
|
|
template <class V>
|
|
HWY_API V DupOdd(const V v) {
|
|
return detail::InterleaveOdd(v, v);
|
|
}
|
|
|
|
// ------------------------------ OddEven
|
|
|
|
#if HWY_SVE_HAVE_2
|
|
|
|
#define HWY_SVE_ODD_EVEN(BASE, CHAR, BITS, HALF, NAME, OP) \
|
|
HWY_API HWY_SVE_V(BASE, BITS) \
|
|
NAME(HWY_SVE_V(BASE, BITS) odd, HWY_SVE_V(BASE, BITS) even) { \
|
|
return sv##OP##_##CHAR##BITS(even, odd, /*xor=*/0); \
|
|
}
|
|
|
|
HWY_SVE_FOREACH_UI(HWY_SVE_ODD_EVEN, OddEven, eortb_n)
|
|
#undef HWY_SVE_ODD_EVEN
|
|
|
|
template <class V, HWY_IF_FLOAT_V(V)>
|
|
HWY_API V OddEven(const V odd, const V even) {
|
|
const DFromV<V> d;
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
return BitCast(d, OddEven(BitCast(du, odd), BitCast(du, even)));
|
|
}
|
|
|
|
#else
|
|
|
|
template <class V>
|
|
HWY_API V OddEven(const V odd, const V even) {
|
|
const auto odd_in_even = detail::Ext<1>(odd, odd);
|
|
return detail::InterleaveEven(even, odd_in_even);
|
|
}
|
|
|
|
#endif // HWY_TARGET
|
|
|
|
// ------------------------------ OddEvenBlocks
|
|
template <class V>
|
|
HWY_API V OddEvenBlocks(const V odd, const V even) {
|
|
const DFromV<V> d;
|
|
#if HWY_TARGET == HWY_SVE_256
|
|
return ConcatUpperLower(d, odd, even);
|
|
#elif HWY_TARGET == HWY_SVE2_128
|
|
(void)odd;
|
|
(void)d;
|
|
return even;
|
|
#else
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
using TU = TFromD<decltype(du)>;
|
|
constexpr size_t kShift = CeilLog2(16 / sizeof(TU));
|
|
const auto idx_block = ShiftRight<kShift>(Iota(du, 0));
|
|
const auto lsb = detail::AndN(idx_block, static_cast<TU>(1));
|
|
const svbool_t is_even = detail::EqN(lsb, static_cast<TU>(0));
|
|
return IfThenElse(is_even, even, odd);
|
|
#endif
|
|
}
|
|
|
|
// ------------------------------ TableLookupLanes
|
|
|
|
template <class D, class VI>
|
|
HWY_API VFromD<RebindToUnsigned<D>> IndicesFromVec(D d, VI vec) {
|
|
using TI = TFromV<VI>;
|
|
static_assert(sizeof(TFromD<D>) == sizeof(TI), "Index/lane size mismatch");
|
|
const RebindToUnsigned<D> du;
|
|
const auto indices = BitCast(du, vec);
|
|
#if HWY_IS_DEBUG_BUILD
|
|
using TU = MakeUnsigned<TI>;
|
|
const size_t twice_max_lanes = Lanes(d) * 2;
|
|
HWY_DASSERT(AllTrue(
|
|
du, Eq(indices,
|
|
detail::AndN(indices, static_cast<TU>(twice_max_lanes - 1)))));
|
|
#else
|
|
(void)d;
|
|
#endif
|
|
return indices;
|
|
}
|
|
|
|
template <class D, typename TI>
|
|
HWY_API VFromD<RebindToUnsigned<D>> SetTableIndices(D d, const TI* idx) {
|
|
static_assert(sizeof(TFromD<D>) == sizeof(TI), "Index size must match lane");
|
|
return IndicesFromVec(d, LoadU(Rebind<TI, D>(), idx));
|
|
}
|
|
|
|
#define HWY_SVE_TABLE(BASE, CHAR, BITS, HALF, NAME, OP) \
|
|
HWY_API HWY_SVE_V(BASE, BITS) \
|
|
NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(uint, BITS) idx) { \
|
|
return sv##OP##_##CHAR##BITS(v, idx); \
|
|
}
|
|
|
|
HWY_SVE_FOREACH(HWY_SVE_TABLE, TableLookupLanes, tbl)
|
|
#undef HWY_SVE_TABLE
|
|
|
|
#if HWY_SVE_HAVE_2
|
|
namespace detail {
|
|
#define HWY_SVE_TABLE2(BASE, CHAR, BITS, HALF, NAME, OP) \
|
|
HWY_API HWY_SVE_V(BASE, BITS) \
|
|
NAME(HWY_SVE_TUPLE(BASE, BITS, 2) tuple, HWY_SVE_V(uint, BITS) idx) { \
|
|
return sv##OP##_##CHAR##BITS(tuple, idx); \
|
|
}
|
|
|
|
HWY_SVE_FOREACH(HWY_SVE_TABLE2, NativeTwoTableLookupLanes, tbl2)
|
|
#undef HWY_SVE_TABLE2
|
|
} // namespace detail
|
|
#endif // HWY_SVE_HAVE_2
|
|
|
|
template <class D>
|
|
HWY_API VFromD<D> TwoTablesLookupLanes(D d, VFromD<D> a, VFromD<D> b,
|
|
VFromD<RebindToUnsigned<D>> idx) {
|
|
// SVE2 has an instruction for this, but it only works for full 2^n vectors.
|
|
#if HWY_SVE_HAVE_2 && HWY_SVE_IS_POW2
|
|
if (detail::IsFull(d)) {
|
|
return detail::NativeTwoTableLookupLanes(Create2(d, a, b), idx);
|
|
}
|
|
#endif
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
using TU = TFromD<decltype(du)>;
|
|
|
|
const size_t num_of_lanes = Lanes(d);
|
|
const auto idx_mod = detail::AndN(idx, static_cast<TU>(num_of_lanes - 1));
|
|
const auto sel_a_mask = Eq(idx, idx_mod);
|
|
|
|
const auto a_lookup_result = TableLookupLanes(a, idx_mod);
|
|
const auto b_lookup_result = TableLookupLanes(b, idx_mod);
|
|
return IfThenElse(sel_a_mask, a_lookup_result, b_lookup_result);
|
|
}
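
// Worked example of the fallback path with Lanes(d) == 4: index 5 gives
// idx_mod = 5 & 3 = 1; since 5 != 1, sel_a_mask is false and the lane comes
// from b[1]. Index 2 is unchanged by the AND and therefore selects a[2].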
|
|
|
|
template <class V>
|
|
HWY_API V TwoTablesLookupLanes(V a, V b,
|
|
VFromD<RebindToUnsigned<DFromV<V>>> idx) {
|
|
const DFromV<decltype(a)> d;
|
|
return TwoTablesLookupLanes(d, a, b, idx);
|
|
}
|
|
|
|
// ------------------------------ SwapAdjacentBlocks (TableLookupLanes)
|
|
|
|
namespace detail {
|
|
|
|
template <typename T, size_t N, int kPow2>
|
|
constexpr size_t LanesPerBlock(Simd<T, N, kPow2> d) {
|
|
// We might have a capped vector smaller than a block, so honor that.
|
|
return HWY_MIN(16 / sizeof(T), MaxLanes(d));
|
|
}
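
// E.g. for uint16_t lanes a full 128-bit block holds 16 / 2 = 8 lanes; with a
// capped tag whose MaxLanes(d) == 4, LanesPerBlock returns HWY_MIN(8, 4) = 4.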
|
|
|
|
} // namespace detail
|
|
|
|
template <class V>
|
|
HWY_API V SwapAdjacentBlocks(const V v) {
|
|
const DFromV<V> d;
|
|
#if HWY_TARGET == HWY_SVE_256
|
|
return ConcatLowerUpper(d, v, v);
|
|
#elif HWY_TARGET == HWY_SVE2_128
|
|
(void)d;
|
|
return v;
|
|
#else
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
constexpr auto kLanesPerBlock =
|
|
static_cast<TFromD<decltype(du)>>(detail::LanesPerBlock(d));
|
|
const VFromD<decltype(du)> idx = detail::XorN(Iota(du, 0), kLanesPerBlock);
|
|
return TableLookupLanes(v, idx);
|
|
#endif
|
|
}
|
|
|
|
// ------------------------------ Reverse
|
|
|
|
namespace detail {
|
|
|
|
#define HWY_SVE_REVERSE(BASE, CHAR, BITS, HALF, NAME, OP) \
|
|
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
|
|
return sv##OP##_##CHAR##BITS(v); \
|
|
}
|
|
|
|
HWY_SVE_FOREACH(HWY_SVE_REVERSE, ReverseFull, rev)
|
|
#undef HWY_SVE_REVERSE
|
|
|
|
} // namespace detail
|
|
|
|
template <class D, class V>
|
|
HWY_API V Reverse(D d, V v) {
|
|
using T = TFromD<D>;
|
|
const auto reversed = detail::ReverseFull(v);
|
|
if (HWY_SVE_IS_POW2 && detail::IsFull(d)) return reversed;
|
|
  // Shift right to remove extra (non-pow2 and remainder) lanes.
  // TODO(janwas): on SVE2, use WHILEGE.
  // Avoids FirstN truncating to the return vector size. Must also avoid Not
  // because that is limited to SV_POW2.
|
|
const ScalableTag<T> dfull;
|
|
const svbool_t all_true = detail::AllPTrue(dfull);
|
|
const size_t all_lanes = detail::AllHardwareLanes<T>();
|
|
const size_t want_lanes = Lanes(d);
|
|
HWY_DASSERT(want_lanes <= all_lanes);
|
|
const svbool_t mask =
|
|
svnot_b_z(all_true, FirstN(dfull, all_lanes - want_lanes));
|
|
return detail::Splice(reversed, reversed, mask);
|
|
}
|
|
|
|
// ------------------------------ Reverse2
|
|
|
|
// Per-target flag to prevent generic_ops-inl.h defining 8-bit Reverse2/4/8.
|
|
#ifdef HWY_NATIVE_REVERSE2_8
|
|
#undef HWY_NATIVE_REVERSE2_8
|
|
#else
|
|
#define HWY_NATIVE_REVERSE2_8
|
|
#endif
|
|
|
|
template <class D, HWY_IF_T_SIZE_D(D, 1)>
|
|
HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) {
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
const RepartitionToWide<decltype(du)> dw;
|
|
return BitCast(d, svrevb_u16_x(detail::PTrue(d), BitCast(dw, v)));
|
|
}
|
|
|
|
template <class D, HWY_IF_T_SIZE_D(D, 2)>
|
|
HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) {
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
const RepartitionToWide<decltype(du)> dw;
|
|
return BitCast(d, svrevh_u32_x(detail::PTrue(d), BitCast(dw, v)));
|
|
}
|
|
|
|
template <class D, HWY_IF_T_SIZE_D(D, 4)>
|
|
HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) {
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
const RepartitionToWide<decltype(du)> dw;
|
|
return BitCast(d, svrevw_u64_x(detail::PTrue(d), BitCast(dw, v)));
|
|
}
|
|
|
|
template <class D, HWY_IF_T_SIZE_D(D, 8)>
|
|
HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) { // 3210
|
|
#if HWY_TARGET == HWY_SVE2_128
|
|
if (detail::IsFull(d)) {
|
|
return detail::Ext<1>(v, v);
|
|
}
|
|
#endif
|
|
(void)d;
|
|
const auto odd_in_even = detail::Ext<1>(v, v); // x321
|
|
return detail::InterleaveEven(odd_in_even, v); // 2301
|
|
}
|
|
|
|
// ------------------------------ Reverse4 (TableLookupLanes)
|
|
|
|
template <class D, HWY_IF_T_SIZE_D(D, 1)>
|
|
HWY_API VFromD<D> Reverse4(D d, const VFromD<D> v) {
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
const RepartitionToWide<RepartitionToWide<decltype(du)>> du32;
|
|
return BitCast(d, svrevb_u32_x(detail::PTrue(d), BitCast(du32, v)));
|
|
}
|
|
|
|
template <class D, HWY_IF_T_SIZE_D(D, 2)>
|
|
HWY_API VFromD<D> Reverse4(D d, const VFromD<D> v) {
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
const RepartitionToWide<RepartitionToWide<decltype(du)>> du64;
|
|
return BitCast(d, svrevh_u64_x(detail::PTrue(d), BitCast(du64, v)));
|
|
}
|
|
|
|
template <class D, HWY_IF_T_SIZE_D(D, 4)>
|
|
HWY_API VFromD<D> Reverse4(D d, const VFromD<D> v) {
|
|
if (HWY_TARGET == HWY_SVE2_128 && detail::IsFull(d)) {
|
|
return detail::ReverseFull(v);
|
|
}
|
|
// TODO(janwas): is this approach faster than Shuffle0123?
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
const auto idx = detail::XorN(Iota(du, 0), 3);
|
|
return TableLookupLanes(v, idx);
|
|
}
|
|
|
|
template <class D, HWY_IF_T_SIZE_D(D, 8)>
|
|
HWY_API VFromD<D> Reverse4(D d, const VFromD<D> v) {
|
|
if (HWY_TARGET == HWY_SVE_256 && detail::IsFull(d)) {
|
|
return detail::ReverseFull(v);
|
|
}
|
|
// TODO(janwas): is this approach faster than Shuffle0123?
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
const auto idx = detail::XorN(Iota(du, 0), 3);
|
|
return TableLookupLanes(v, idx);
|
|
}
|
|
|
|
// ------------------------------ Reverse8 (TableLookupLanes)
|
|
|
|
template <class D, HWY_IF_T_SIZE_D(D, 1)>
|
|
HWY_API VFromD<D> Reverse8(D d, const VFromD<D> v) {
|
|
const Repartition<uint64_t, decltype(d)> du64;
|
|
return BitCast(d, svrevb_u64_x(detail::PTrue(d), BitCast(du64, v)));
|
|
}
|
|
|
|
template <class D, HWY_IF_NOT_T_SIZE_D(D, 1)>
|
|
HWY_API VFromD<D> Reverse8(D d, const VFromD<D> v) {
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
const auto idx = detail::XorN(Iota(du, 0), 7);
|
|
return TableLookupLanes(v, idx);
|
|
}
|
|
|
|
// ------------------------------- ReverseBits
|
|
|
|
#ifdef HWY_NATIVE_REVERSE_BITS_UI8
|
|
#undef HWY_NATIVE_REVERSE_BITS_UI8
|
|
#else
|
|
#define HWY_NATIVE_REVERSE_BITS_UI8
|
|
#endif
|
|
|
|
#ifdef HWY_NATIVE_REVERSE_BITS_UI16_32_64
|
|
#undef HWY_NATIVE_REVERSE_BITS_UI16_32_64
|
|
#else
|
|
#define HWY_NATIVE_REVERSE_BITS_UI16_32_64
|
|
#endif
|
|
|
|
#define HWY_SVE_REVERSE_BITS(BASE, CHAR, BITS, HALF, NAME, OP) \
|
|
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
|
|
const DFromV<decltype(v)> d; \
|
|
return sv##OP##_##CHAR##BITS##_x(detail::PTrue(d), v); \
|
|
}
|
|
|
|
HWY_SVE_FOREACH_UI(HWY_SVE_REVERSE_BITS, ReverseBits, rbit)
|
|
#undef HWY_SVE_REVERSE_BITS
|
|
|
|
// ------------------------------ SlideUpLanes
|
|
|
|
template <class D>
|
|
HWY_API VFromD<D> SlideUpLanes(D d, VFromD<D> v, size_t amt) {
|
|
return detail::Splice(v, Zero(d), FirstN(d, amt));
|
|
}
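
// Splice takes the first `amt` lanes from Zero(d), then fills the remaining
// lanes from the start of v, so the result is amt zeros followed by
// v[0], v[1], ...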
|
|
|
|
// ------------------------------ Slide1Up
|
|
|
|
#ifdef HWY_NATIVE_SLIDE1_UP_DOWN
|
|
#undef HWY_NATIVE_SLIDE1_UP_DOWN
|
|
#else
|
|
#define HWY_NATIVE_SLIDE1_UP_DOWN
|
|
#endif
|
|
|
|
template <class D>
|
|
HWY_API VFromD<D> Slide1Up(D d, VFromD<D> v) {
|
|
return SlideUpLanes(d, v, 1);
|
|
}
|
|
|
|
// ------------------------------ SlideDownLanes (TableLookupLanes)
|
|
|
|
template <class D>
|
|
HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) {
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
using TU = TFromD<decltype(du)>;
|
|
const auto idx = Iota(du, static_cast<TU>(amt));
|
|
return IfThenElseZero(FirstN(d, Lanes(d) - amt), TableLookupLanes(v, idx));
|
|
}
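
// The Iota indices gather v[amt + i] into lane i; the FirstN mask then zeroes
// the last `amt` lanes of the result, which for fractional vectors would
// otherwise pick up lanes of v at or beyond Lanes(d).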
|
|
|
|
// ------------------------------ Slide1Down
|
|
|
|
template <class D>
|
|
HWY_API VFromD<D> Slide1Down(D d, VFromD<D> v) {
|
|
return SlideDownLanes(d, v, 1);
|
|
}
|
|
|
|
// ------------------------------ Block insert/extract/broadcast ops
|
|
#if HWY_TARGET != HWY_SVE2_128
|
|
|
|
#ifdef HWY_NATIVE_BLK_INSERT_EXTRACT
|
|
#undef HWY_NATIVE_BLK_INSERT_EXTRACT
|
|
#else
|
|
#define HWY_NATIVE_BLK_INSERT_EXTRACT
|
|
#endif
|
|
|
|
template <int kBlockIdx, class V>
|
|
HWY_API V InsertBlock(V v, V blk_to_insert) {
|
|
const DFromV<decltype(v)> d;
|
|
static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(),
|
|
"Invalid block index");
|
|
|
|
#if HWY_TARGET == HWY_SVE_256
|
|
return (kBlockIdx == 0) ? ConcatUpperLower(d, v, blk_to_insert)
|
|
: ConcatLowerLower(d, blk_to_insert, v);
|
|
#else
|
|
constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d);
|
|
|
|
constexpr size_t kBlockOffset =
|
|
static_cast<size_t>(kBlockIdx) * kLanesPerBlock;
|
|
const auto splice_mask = FirstN(d, kBlockOffset);
|
|
const auto sel_lo_mask = FirstN(d, kBlockOffset + kLanesPerBlock);
|
|
|
|
const auto splice_result = detail::Splice(blk_to_insert, v, splice_mask);
|
|
return IfThenElse(sel_lo_mask, splice_result, v);
|
|
#endif
|
|
}
|
|
|
|
template <int kBlockIdx, class V>
|
|
HWY_API V ExtractBlock(V v) {
|
|
const DFromV<decltype(v)> d;
|
|
static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(),
|
|
"Invalid block index");
|
|
|
|
if (kBlockIdx == 0) return v;
|
|
|
|
#if HWY_TARGET == HWY_SVE_256
|
|
return UpperHalf(Half<decltype(d)>(), v);
|
|
#else
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
using TU = TFromD<decltype(du)>;
|
|
constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d);
|
|
constexpr size_t kBlockOffset =
|
|
static_cast<size_t>(kBlockIdx) * kLanesPerBlock;
|
|
const auto splice_mask =
|
|
RebindMask(d, detail::LtN(Iota(du, static_cast<TU>(0u - kBlockOffset)),
|
|
static_cast<TU>(kLanesPerBlock)));
|
|
return detail::Splice(v, v, splice_mask);
|
|
#endif
|
|
}
|
|
|
|
template <int kBlockIdx, class V>
|
|
HWY_API V BroadcastBlock(V v) {
|
|
const DFromV<decltype(v)> d;
|
|
static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(),
|
|
"Invalid block index");
|
|
|
|
#if HWY_TARGET == HWY_SVE_256
|
|
return (kBlockIdx == 0) ? ConcatLowerLower(d, v, v)
|
|
: ConcatUpperUpper(d, v, v);
|
|
#else
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
using TU = TFromD<decltype(du)>;
|
|
constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d);
|
|
constexpr size_t kBlockOffset =
|
|
static_cast<size_t>(kBlockIdx) * kLanesPerBlock;
|
|
|
|
const auto idx = detail::AddN(
|
|
detail::AndN(Iota(du, TU{0}), static_cast<TU>(kLanesPerBlock - 1)),
|
|
static_cast<TU>(kBlockOffset));
|
|
return TableLookupLanes(v, idx);
|
|
#endif
|
|
}
|
|
|
|
#endif // HWY_TARGET != HWY_SVE2_128
|
|
|
|
// ------------------------------ Compress (PromoteTo)
|
|
|
|
template <typename T>
|
|
struct CompressIsPartition {
|
|
#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128
|
|
// Optimization for 64-bit lanes (could also be applied to 32-bit, but that
|
|
// requires a larger table).
|
|
enum { value = (sizeof(T) == 8) };
|
|
#else
|
|
enum { value = 0 };
|
|
#endif // HWY_TARGET == HWY_SVE_256
|
|
};
|
|
|
|
#define HWY_SVE_COMPRESS(BASE, CHAR, BITS, HALF, NAME, OP) \
|
|
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v, svbool_t mask) { \
|
|
return sv##OP##_##CHAR##BITS(mask, v); \
|
|
}
|
|
|
|
#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128
|
|
HWY_SVE_FOREACH_UI32(HWY_SVE_COMPRESS, Compress, compact)
|
|
HWY_SVE_FOREACH_F32(HWY_SVE_COMPRESS, Compress, compact)
|
|
#else
|
|
HWY_SVE_FOREACH_UIF3264(HWY_SVE_COMPRESS, Compress, compact)
|
|
#endif
|
|
#undef HWY_SVE_COMPRESS
|
|
|
|
#if HWY_TARGET == HWY_SVE_256 || HWY_IDE
|
|
template <class V, HWY_IF_T_SIZE_V(V, 8)>
|
|
HWY_API V Compress(V v, svbool_t mask) {
|
|
const DFromV<V> d;
|
|
const RebindToUnsigned<decltype(d)> du64;
|
|
|
|
  // Convert mask into bitfield via horizontal sum (faster than ORV) of masked
  // bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for
  // SetTableIndices.
|
|
const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2));
|
|
const size_t offset = detail::SumOfLanesM(mask, bits);
|
|
|
|
// See CompressIsPartition.
|
|
alignas(16) static constexpr uint64_t table[4 * 16] = {
|
|
// PrintCompress64x4Tables
|
|
0, 1, 2, 3, 0, 1, 2, 3, 1, 0, 2, 3, 0, 1, 2, 3, 2, 0, 1, 3, 0, 2,
|
|
1, 3, 1, 2, 0, 3, 0, 1, 2, 3, 3, 0, 1, 2, 0, 3, 1, 2, 1, 3, 0, 2,
|
|
0, 1, 3, 2, 2, 3, 0, 1, 0, 2, 3, 1, 1, 2, 3, 0, 0, 1, 2, 3};
|
|
return TableLookupLanes(v, SetTableIndices(d, table + offset));
|
|
}
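
// Worked example: bits = 1 << Iota(du64, 2) = [4, 8, 16, 32] = 4 * [1, 2, 4,
// 8]. For mask = [1, 0, 1, 0] (lanes 0 and 2 active), the masked sum is
// 4 + 16 = 20 = 4 * 5, so row 5 of the table, {0, 2, 1, 3}, is used and moves
// lanes 0 and 2 to the front.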
|
|
|
|
#endif // HWY_TARGET == HWY_SVE_256
|
|
#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE
|
|
template <class V, HWY_IF_T_SIZE_V(V, 8)>
|
|
HWY_API V Compress(V v, svbool_t mask) {
|
|
  // If mask == 10: swap via splice. A mask of 00 or 11 leaves v unchanged, 10
  // swaps upper/lower (the lower half is set to the upper half, and the
  // remaining upper half is filled from the lower half of the second v), and
  // 01 is invalid because it would ConcatLowerLower. zip1 and AndNot keep 10
  // unchanged and map everything else to 00.
|
|
const svbool_t maskLL = svzip1_b64(mask, mask); // broadcast lower lane
|
|
return detail::Splice(v, v, AndNot(maskLL, mask));
|
|
}
|
|
|
|
#endif // HWY_TARGET == HWY_SVE2_128
|
|
|
|
template <class V, HWY_IF_T_SIZE_V(V, 2)>
|
|
HWY_API V Compress(V v, svbool_t mask16) {
|
|
static_assert(!IsSame<V, svfloat16_t>(), "Must use overload");
|
|
const DFromV<V> d16;
|
|
|
|
// Promote vector and mask to 32-bit
|
|
const RepartitionToWide<decltype(d16)> dw;
|
|
const auto v32L = PromoteTo(dw, v);
|
|
const auto v32H = detail::PromoteUpperTo(dw, v);
|
|
const svbool_t mask32L = svunpklo_b(mask16);
|
|
const svbool_t mask32H = svunpkhi_b(mask16);
|
|
|
|
const auto compressedL = Compress(v32L, mask32L);
|
|
const auto compressedH = Compress(v32H, mask32H);
|
|
|
|
// Demote to 16-bit (already in range) - separately so we can splice
|
|
const V evenL = BitCast(d16, compressedL);
|
|
const V evenH = BitCast(d16, compressedH);
|
|
const V v16L = detail::ConcatEvenFull(evenL, evenL); // lower half
|
|
const V v16H = detail::ConcatEvenFull(evenH, evenH);
|
|
|
|
  // We need to combine two vectors of non-constexpr length, so the only option
  // is Splice, which requires us to synthesize a mask. NOTE: this function uses
  // full vectors (SV_ALL instead of SV_POW2), hence we need unmasked svcnt.
|
|
const size_t countL = detail::CountTrueFull(dw, mask32L);
|
|
const auto compressed_maskL = FirstN(d16, countL);
|
|
return detail::Splice(v16H, v16L, compressed_maskL);
|
|
}
|
|
|
|
// Must treat float16_t as integers so we can ConcatEven.
|
|
HWY_API svfloat16_t Compress(svfloat16_t v, svbool_t mask16) {
|
|
const DFromV<decltype(v)> df;
|
|
const RebindToSigned<decltype(df)> di;
|
|
return BitCast(df, Compress(BitCast(di, v), mask16));
|
|
}
|
|
|
|
// ------------------------------ CompressNot
|
|
|
|
// 2 or 4 bytes
|
|
template <class V, HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 2) | (1 << 4))>
|
|
HWY_API V CompressNot(V v, const svbool_t mask) {
|
|
return Compress(v, Not(mask));
|
|
}
|
|
|
|
template <class V, HWY_IF_T_SIZE_V(V, 8)>
|
|
HWY_API V CompressNot(V v, svbool_t mask) {
|
|
#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE
|
|
// If mask == 01: swap via splice. A mask of 00 or 11 leaves v unchanged, 10
|
|
// swaps upper/lower (the lower half is set to the upper half, and the
|
|
// remaining upper half is filled from the lower half of the second v), and
|
|
// 01 is invalid because it would ConcatLowerLower. zip1 and AndNot map
|
|
// 01 to 10, and everything else to 00.
|
|
const svbool_t maskLL = svzip1_b64(mask, mask); // broadcast lower lane
|
|
return detail::Splice(v, v, AndNot(mask, maskLL));
|
|
#endif
|
|
#if HWY_TARGET == HWY_SVE_256 || HWY_IDE
|
|
const DFromV<V> d;
|
|
const RebindToUnsigned<decltype(d)> du64;
|
|
|
|
// Convert mask into bitfield via horizontal sum (faster than ORV) of masked
|
|
// bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for
|
|
// SetTableIndices.
|
|
const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2));
|
|
const size_t offset = detail::SumOfLanesM(mask, bits);
|
|
|
|
// See CompressIsPartition.
|
|
alignas(16) static constexpr uint64_t table[4 * 16] = {
|
|
// PrintCompressNot64x4Tables
|
|
0, 1, 2, 3, 1, 2, 3, 0, 0, 2, 3, 1, 2, 3, 0, 1, 0, 1, 3, 2, 1, 3,
|
|
0, 2, 0, 3, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 1, 2, 0, 3, 0, 2, 1, 3,
|
|
2, 0, 1, 3, 0, 1, 2, 3, 1, 0, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3};
|
|
return TableLookupLanes(v, SetTableIndices(d, table + offset));
|
|
#endif // HWY_TARGET == HWY_SVE_256
|
|
|
|
return Compress(v, Not(mask));
|
|
}
|
|
|
|
// ------------------------------ CompressBlocksNot
|
|
HWY_API svuint64_t CompressBlocksNot(svuint64_t v, svbool_t mask) {
|
|
#if HWY_TARGET == HWY_SVE2_128
|
|
(void)mask;
|
|
return v;
|
|
#endif
|
|
#if HWY_TARGET == HWY_SVE_256 || HWY_IDE
|
|
uint64_t bits = 0; // predicate reg is 32-bit
|
|
CopyBytes<4>(&mask, &bits); // not same size - 64-bit more efficient
|
|
// Concatenate LSB for upper and lower blocks, pre-scale by 4 for table idx.
|
|
const size_t offset = ((bits & 1) ? 4u : 0u) + ((bits & 0x10000) ? 8u : 0u);
|
|
// See CompressIsPartition. Manually generated; flip halves if mask = [0, 1].
|
|
alignas(16) static constexpr uint64_t table[4 * 4] = {0, 1, 2, 3, 2, 3, 0, 1,
|
|
0, 1, 2, 3, 0, 1, 2, 3};
|
|
const ScalableTag<uint64_t> d;
|
|
return TableLookupLanes(v, SetTableIndices(d, table + offset));
|
|
#endif
|
|
|
|
return CompressNot(v, mask);
|
|
}
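
// In the SVE_256 branch above, predicate bits are allocated per byte, so for
// 64-bit lanes lane 0 maps to bit 0 and lane 2 (the lower lane of the upper
// 128-bit block) maps to bit 16; the two tests on `bits` therefore read the
// LSB of each block's mask.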
|
|
|
|
// ------------------------------ CompressStore
|
|
template <class V, class D, HWY_IF_NOT_T_SIZE_D(D, 1)>
|
|
HWY_API size_t CompressStore(const V v, const svbool_t mask, const D d,
|
|
TFromD<D>* HWY_RESTRICT unaligned) {
|
|
StoreU(Compress(v, mask), d, unaligned);
|
|
return CountTrue(d, mask);
|
|
}
|
|
|
|
// ------------------------------ CompressBlendedStore
|
|
template <class V, class D, HWY_IF_NOT_T_SIZE_D(D, 1)>
|
|
HWY_API size_t CompressBlendedStore(const V v, const svbool_t mask, const D d,
|
|
TFromD<D>* HWY_RESTRICT unaligned) {
|
|
const size_t count = CountTrue(d, mask);
|
|
const svbool_t store_mask = FirstN(d, count);
|
|
BlendedStore(Compress(v, mask), store_mask, d, unaligned);
|
|
return count;
|
|
}
|
|
|
|
// ================================================== MASK (2)
|
|
|
|
// ------------------------------ FindKnownLastTrue
|
|
template <class D>
|
|
HWY_API size_t FindKnownLastTrue(D d, svbool_t m) {
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
return static_cast<size_t>(detail::ExtractLastMatchingLaneM(
|
|
Iota(du, 0), And(m, detail::MakeMask(d))));
|
|
}
|
|
|
|
// ------------------------------ FindLastTrue
|
|
template <class D>
|
|
HWY_API intptr_t FindLastTrue(D d, svbool_t m) {
|
|
return AllFalse(d, m) ? intptr_t{-1}
|
|
: static_cast<intptr_t>(FindKnownLastTrue(d, m));
|
|
}
|
|
|
|
// ================================================== BLOCKWISE
|
|
|
|
// ------------------------------ CombineShiftRightBytes
|
|
|
|
// Prevent accidentally using these for 128-bit vectors - should not be
|
|
// necessary.
|
|
#if HWY_TARGET != HWY_SVE2_128
|
|
namespace detail {
|
|
|
|
// For x86-compatible behaviour mandated by Highway API: TableLookupBytes
// offsets are implicitly relative to the start of their 128-bit block.
|
|
template <class D, class V>
|
|
HWY_INLINE V OffsetsOf128BitBlocks(const D d, const V iota0) {
|
|
using T = MakeUnsigned<TFromD<D>>;
|
|
return detail::AndNotN(static_cast<T>(LanesPerBlock(d) - 1), iota0);
|
|
}
|
|
|
|
template <size_t kLanes, class D, HWY_IF_T_SIZE_D(D, 1)>
|
|
svbool_t FirstNPerBlock(D d) {
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du);
|
|
const svuint8_t idx_mod =
|
|
svdupq_n_u8(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock,
|
|
3 % kLanesPerBlock, 4 % kLanesPerBlock, 5 % kLanesPerBlock,
|
|
6 % kLanesPerBlock, 7 % kLanesPerBlock, 8 % kLanesPerBlock,
|
|
9 % kLanesPerBlock, 10 % kLanesPerBlock, 11 % kLanesPerBlock,
|
|
12 % kLanesPerBlock, 13 % kLanesPerBlock, 14 % kLanesPerBlock,
|
|
15 % kLanesPerBlock);
|
|
return detail::LtN(BitCast(du, idx_mod), kLanes);
|
|
}
|
|
template <size_t kLanes, class D, HWY_IF_T_SIZE_D(D, 2)>
|
|
svbool_t FirstNPerBlock(D d) {
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du);
|
|
const svuint16_t idx_mod =
|
|
svdupq_n_u16(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock,
|
|
3 % kLanesPerBlock, 4 % kLanesPerBlock, 5 % kLanesPerBlock,
|
|
6 % kLanesPerBlock, 7 % kLanesPerBlock);
|
|
return detail::LtN(BitCast(du, idx_mod), kLanes);
|
|
}
|
|
template <size_t kLanes, class D, HWY_IF_T_SIZE_D(D, 4)>
|
|
svbool_t FirstNPerBlock(D d) {
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du);
|
|
const svuint32_t idx_mod =
|
|
svdupq_n_u32(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock,
|
|
3 % kLanesPerBlock);
|
|
return detail::LtN(BitCast(du, idx_mod), kLanes);
|
|
}
|
|
template <size_t kLanes, class D, HWY_IF_T_SIZE_D(D, 8)>
|
|
svbool_t FirstNPerBlock(D d) {
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du);
|
|
const svuint64_t idx_mod =
|
|
svdupq_n_u64(0 % kLanesPerBlock, 1 % kLanesPerBlock);
|
|
return detail::LtN(BitCast(du, idx_mod), kLanes);
|
|
}
|
|
|
|
} // namespace detail
|
|
#endif // HWY_TARGET != HWY_SVE2_128
|
|
|
|
template <size_t kBytes, class D, class V = VFromD<D>>
|
|
HWY_API V CombineShiftRightBytes(const D d, const V hi, const V lo) {
|
|
const Repartition<uint8_t, decltype(d)> d8;
|
|
const auto hi8 = BitCast(d8, hi);
|
|
const auto lo8 = BitCast(d8, lo);
|
|
#if HWY_TARGET == HWY_SVE2_128
|
|
return BitCast(d, detail::Ext<kBytes>(hi8, lo8));
|
|
#else
|
|
const auto hi_up = detail::Splice(hi8, hi8, FirstN(d8, 16 - kBytes));
|
|
const auto lo_down = detail::Ext<kBytes>(lo8, lo8);
|
|
const svbool_t is_lo = detail::FirstNPerBlock<16 - kBytes>(d8);
|
|
return BitCast(d, IfThenElse(is_lo, lo_down, hi_up));
|
|
#endif
|
|
}
|
|
|
|
// ------------------------------ Shuffle2301
|
|
template <class V>
|
|
HWY_API V Shuffle2301(const V v) {
|
|
const DFromV<V> d;
|
|
static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types");
|
|
return Reverse2(d, v);
|
|
}
|
|
|
|
// ------------------------------ Shuffle2103
|
|
template <class V>
|
|
HWY_API V Shuffle2103(const V v) {
|
|
const DFromV<V> d;
|
|
const Repartition<uint8_t, decltype(d)> d8;
|
|
static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types");
|
|
const svuint8_t v8 = BitCast(d8, v);
|
|
return BitCast(d, CombineShiftRightBytes<12>(d8, v8, v8));
|
|
}
|
|
|
|
// ------------------------------ Shuffle0321
|
|
template <class V>
|
|
HWY_API V Shuffle0321(const V v) {
|
|
const DFromV<V> d;
|
|
const Repartition<uint8_t, decltype(d)> d8;
|
|
static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types");
|
|
const svuint8_t v8 = BitCast(d8, v);
|
|
return BitCast(d, CombineShiftRightBytes<4>(d8, v8, v8));
|
|
}
|
|
|
|
// ------------------------------ Shuffle1032
|
|
template <class V>
|
|
HWY_API V Shuffle1032(const V v) {
|
|
const DFromV<V> d;
|
|
const Repartition<uint8_t, decltype(d)> d8;
|
|
static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types");
|
|
const svuint8_t v8 = BitCast(d8, v);
|
|
return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8));
|
|
}
|
|
|
|
// ------------------------------ Shuffle01
|
|
template <class V>
|
|
HWY_API V Shuffle01(const V v) {
|
|
const DFromV<V> d;
|
|
const Repartition<uint8_t, decltype(d)> d8;
|
|
static_assert(sizeof(TFromD<decltype(d)>) == 8, "Defined for 64-bit types");
|
|
const svuint8_t v8 = BitCast(d8, v);
|
|
return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8));
|
|
}
|
|
|
|
// ------------------------------ Shuffle0123
|
|
template <class V>
|
|
HWY_API V Shuffle0123(const V v) {
|
|
return Shuffle2301(Shuffle1032(v));
|
|
}
|
|
|
|
// ------------------------------ ReverseBlocks (Reverse, Shuffle01)
|
|
template <class D, class V = VFromD<D>>
|
|
HWY_API V ReverseBlocks(D d, V v) {
|
|
#if HWY_TARGET == HWY_SVE_256
|
|
if (detail::IsFull(d)) {
|
|
return SwapAdjacentBlocks(v);
|
|
} else if (detail::IsFull(Twice<D>())) {
|
|
return v;
|
|
}
|
|
#elif HWY_TARGET == HWY_SVE2_128
|
|
(void)d;
|
|
return v;
|
|
#endif
|
|
const Repartition<uint64_t, D> du64;
|
|
return BitCast(d, Shuffle01(Reverse(du64, BitCast(du64, v))));
|
|
}
|
|
|
|
// ------------------------------ TableLookupBytes
|
|
|
|
template <class V, class VI>
|
|
HWY_API VI TableLookupBytes(const V v, const VI idx) {
|
|
const DFromV<VI> d;
|
|
const Repartition<uint8_t, decltype(d)> du8;
|
|
#if HWY_TARGET == HWY_SVE2_128
|
|
return BitCast(d, TableLookupLanes(BitCast(du8, v), BitCast(du8, idx)));
|
|
#else
|
|
const auto offsets128 = detail::OffsetsOf128BitBlocks(du8, Iota(du8, 0));
|
|
const auto idx8 = Add(BitCast(du8, idx), offsets128);
|
|
return BitCast(d, TableLookupLanes(BitCast(du8, v), idx8));
|
|
#endif
|
|
}
|
|
|
|
template <class V, class VI>
|
|
HWY_API VI TableLookupBytesOr0(const V v, const VI idx) {
|
|
const DFromV<VI> d;
|
|
// Mask size must match vector type, so cast everything to this type.
|
|
const Repartition<int8_t, decltype(d)> di8;
|
|
|
|
auto idx8 = BitCast(di8, idx);
|
|
const auto msb = detail::LtN(idx8, 0);
|
|
|
|
const auto lookup = TableLookupBytes(BitCast(di8, v), idx8);
|
|
return BitCast(d, IfThenZeroElse(msb, lookup));
|
|
}
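
// Usage sketch (illustrative only; names and values are hypothetical).
// Indices select bytes within their own 128-bit block, so one index vector
// works for any vector length; TableLookupBytesOr0 additionally returns zero
// for lanes whose index has the sign bit set:
//
//   namespace hn = hwy::HWY_NAMESPACE;
//   const hn::ScalableTag<uint8_t> d;
//   const auto bytes = hn::Iota(d, 0x40);             // arbitrary table data
//   const auto idx = hn::Set(d, 3);                   // byte 3 of each block
//   const auto gathered = hn::TableLookupBytes(bytes, idx);
//   const auto zeroed = hn::TableLookupBytesOr0(bytes, hn::Set(d, 0x80));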
|
|
|
|
// ------------------------------ Broadcast
|
|
|
|
#ifdef HWY_NATIVE_BROADCASTLANE
|
|
#undef HWY_NATIVE_BROADCASTLANE
|
|
#else
|
|
#define HWY_NATIVE_BROADCASTLANE
|
|
#endif
|
|
|
|
namespace detail {
|
|
#define HWY_SVE_BROADCAST(BASE, CHAR, BITS, HALF, NAME, OP) \
|
|
template <int kLane> \
|
|
HWY_INLINE HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
|
|
return sv##OP##_##CHAR##BITS(v, kLane); \
|
|
}
|
|
|
|
HWY_SVE_FOREACH(HWY_SVE_BROADCAST, BroadcastLane, dup_lane)
|
|
#undef HWY_SVE_BROADCAST
|
|
} // namespace detail
|
|
|
|
template <int kLane, class V>
|
|
HWY_API V Broadcast(const V v) {
|
|
const DFromV<V> d;
|
|
const RebindToUnsigned<decltype(d)> du;
|
|
constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du);
|
|
static_assert(0 <= kLane && kLane < kLanesPerBlock, "Invalid lane");
|
|
#if HWY_TARGET == HWY_SVE2_128
|
|
return detail::BroadcastLane<kLane>(v);
|
|
#else
|
|
auto idx = detail::OffsetsOf128BitBlocks(du, Iota(du, 0));
|
|
if (kLane != 0) {
|
|
idx = detail::AddN(idx, kLane);
|
|
}
|
|
return TableLookupLanes(v, idx);
|
|
#endif
|
|
}
|
|
|
|
template <int kLane, class V>
|
|
HWY_API V BroadcastLane(const V v) {
|
|
static_assert(0 <= kLane && kLane < HWY_MAX_LANES_V(V), "Invalid lane");
|
|
return detail::BroadcastLane<kLane>(v);
|
|
}
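
// Usage sketch (illustrative only; names and values are hypothetical).
// Broadcast repeats lane kLane within each 128-bit block, while BroadcastLane
// repeats one lane of the entire vector:
//
//   namespace hn = hwy::HWY_NAMESPACE;
//   const hn::ScalableTag<uint32_t> d;
//   const auto v = hn::Iota(d, 0);                  // 0,1,2,3, 4,5,6,7, ...
//   const auto per_block = hn::Broadcast<1>(v);     // 1,1,1,1, 5,5,5,5, ...
//   const auto whole = hn::BroadcastLane<1>(v);     // 1,1,1,1, 1,1,1,1, ...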
|
|
|
|
// ------------------------------ ShiftLeftLanes
|
|
|
|
template <size_t kLanes, class D, class V = VFromD<D>>
|
|
HWY_API V ShiftLeftLanes(D d, const V v) {
|
|
const auto zero = Zero(d);
|
|
const auto shifted = detail::Splice(v, zero, FirstN(d, kLanes));
|
|
#if HWY_TARGET == HWY_SVE2_128
|
|
return shifted;
|
|
#else
|
|
// Match x86 semantics by zeroing lower lanes in 128-bit blocks
|
|
return IfThenElse(detail::FirstNPerBlock<kLanes>(d), zero, shifted);
|
|
#endif
|
|
}
|
|
|
|
template <size_t kLanes, class V>
|
|
HWY_API V ShiftLeftLanes(const V v) {
|
|
return ShiftLeftLanes<kLanes>(DFromV<V>(), v);
|
|
}
|
|
|
|
// ------------------------------ ShiftRightLanes
|
|
template <size_t kLanes, class D, class V = VFromD<D>>
|
|
HWY_API V ShiftRightLanes(D d, V v) {
|
|
// For capped/fractional vectors, clear upper lanes so we shift in zeros.
|
|
if (!detail::IsFull(d)) {
|
|
v = IfThenElseZero(detail::MakeMask(d), v);
|
|
}
|
|
|
|
#if HWY_TARGET == HWY_SVE2_128
|
|
return detail::Ext<kLanes>(Zero(d), v);
|
|
#else
|
|
const auto shifted = detail::Ext<kLanes>(v, v);
|
|
// Match x86 semantics by zeroing upper lanes in 128-bit blocks
|
|
constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d);
|
|
const svbool_t mask = detail::FirstNPerBlock<kLanesPerBlock - kLanes>(d);
|
|
return IfThenElseZero(mask, shifted);
|
|
#endif
|
|
}
|
|
|
|
// ------------------------------ ShiftLeftBytes
|
|
|
|
template <int kBytes, class D, class V = VFromD<D>>
|
|
HWY_API V ShiftLeftBytes(const D d, const V v) {
|
|
const Repartition<uint8_t, decltype(d)> d8;
|
|
return BitCast(d, ShiftLeftLanes<kBytes>(BitCast(d8, v)));
|
|
}
|
|
|
|
template <int kBytes, class V>
|
|
HWY_API V ShiftLeftBytes(const V v) {
|
|
return ShiftLeftBytes<kBytes>(DFromV<V>(), v);
|
|
}
|
|
|
|
// ------------------------------ ShiftRightBytes
|
|
template <int kBytes, class D, class V = VFromD<D>>
|
|
HWY_API V ShiftRightBytes(const D d, const V v) {
|
|
const Repartition<uint8_t, decltype(d)> d8;
|
|
return BitCast(d, ShiftRightLanes<kBytes>(d8, BitCast(d8, v)));
|
|
}
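
// Usage sketch (illustrative only; names and values are hypothetical). As on
// x86, the byte shifts operate within each 128-bit block and shift in zeros:
//
//   namespace hn = hwy::HWY_NAMESPACE;
//   const hn::ScalableTag<uint8_t> d;
//   const auto v = hn::Iota(d, 1);                   // first block: 1..16
//   const auto left = hn::ShiftLeftBytes<2>(d, v);   // first block: 0,0,1..14
//   const auto right = hn::ShiftRightBytes<2>(d, v); // first block: 3..16,0,0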
|
|
|
|
// ------------------------------ ZipLower
|
|
|
|
template <class V, class DW = RepartitionToWide<DFromV<V>>>
|
|
HWY_API VFromD<DW> ZipLower(DW dw, V a, V b) {
|
|
const RepartitionToNarrow<DW> dn;
|
|
static_assert(IsSame<TFromD<decltype(dn)>, TFromV<V>>(), "D/V mismatch");
|
|
return BitCast(dw, InterleaveLower(dn, a, b));
|
|
}
|
|
template <class V, class D = DFromV<V>, class DW = RepartitionToWide<D>>
|
|
HWY_API VFromD<DW> ZipLower(const V a, const V b) {
|
|
return BitCast(DW(), InterleaveLower(D(), a, b));
|
|
}
|
|
|
|
// ------------------------------ ZipUpper
|
|
template <class V, class DW = RepartitionToWide<DFromV<V>>>
|
|
HWY_API VFromD<DW> ZipUpper(DW dw, V a, V b) {
|
|
const RepartitionToNarrow<DW> dn;
|
|
static_assert(IsSame<TFromD<decltype(dn)>, TFromV<V>>(), "D/V mismatch");
|
|
return BitCast(dw, InterleaveUpper(dn, a, b));
|
|
}
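
// Usage sketch (illustrative only; names and values are hypothetical). Zip*
// interleave narrow lanes and reinterpret the result as the wide type:
//
//   namespace hn = hwy::HWY_NAMESPACE;
//   const hn::ScalableTag<uint16_t> d16;
//   const hn::RepartitionToWide<decltype(d16)> d32;
//   const auto a = hn::Set(d16, 1);
//   const auto b = hn::Set(d16, 2);
//   // Every u32 lane is 0x00020001: 1 in the lower half, 2 in the upper.
//   const auto lo = hn::ZipLower(d32, a, b);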
|
|
|
|
// ================================================== Ops with dependencies
|
|
|
|
// ------------------------------ PromoteTo bfloat16 (ZipLower)
|
|
template <size_t N, int kPow2>
|
|
HWY_API svfloat32_t PromoteTo(Simd<float32_t, N, kPow2> df32, VBF16 v) {
|
|
const ScalableTag<uint16_t> du16;
|
|
return BitCast(df32, detail::ZipLowerSame(svdup_n_u16(0), BitCast(du16, v)));
|
|
}
|
|
|
|
// ------------------------------ ReorderDemote2To (OddEven)
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API VBF16 ReorderDemote2To(Simd<bfloat16_t, N, kPow2> dbf16, svfloat32_t a,
|
|
svfloat32_t b) {
|
|
const RebindToUnsigned<decltype(dbf16)> du16;
|
|
const Repartition<uint32_t, decltype(dbf16)> du32;
|
|
const svuint32_t b_in_even = ShiftRight<16>(BitCast(du32, b));
|
|
return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even)));
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svint16_t ReorderDemote2To(Simd<int16_t, N, kPow2> d16, svint32_t a,
|
|
svint32_t b) {
|
|
#if HWY_SVE_HAVE_2
|
|
(void)d16;
|
|
const svint16_t a_in_even = svqxtnb_s32(a);
|
|
return svqxtnt_s32(a_in_even, b);
|
|
#else
|
|
const svint16_t a16 = BitCast(d16, detail::SaturateI<int16_t>(a));
|
|
const svint16_t b16 = BitCast(d16, detail::SaturateI<int16_t>(b));
|
|
return detail::InterleaveEven(a16, b16);
|
|
#endif
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint16_t ReorderDemote2To(Simd<uint16_t, N, kPow2> d16, svint32_t a,
|
|
svint32_t b) {
|
|
#if HWY_SVE_HAVE_2
|
|
(void)d16;
|
|
const svuint16_t a_in_even = svqxtunb_s32(a);
|
|
return svqxtunt_s32(a_in_even, b);
|
|
#else
|
|
const Repartition<uint32_t, decltype(d16)> du32;
|
|
const svuint32_t clamped_a = BitCast(du32, detail::MaxN(a, 0));
|
|
const svuint32_t clamped_b = BitCast(du32, detail::MaxN(b, 0));
|
|
const svuint16_t a16 = BitCast(d16, detail::SaturateU<uint16_t>(clamped_a));
|
|
const svuint16_t b16 = BitCast(d16, detail::SaturateU<uint16_t>(clamped_b));
|
|
return detail::InterleaveEven(a16, b16);
|
|
#endif
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint16_t ReorderDemote2To(Simd<uint16_t, N, kPow2> d16, svuint32_t a,
|
|
svuint32_t b) {
|
|
#if HWY_SVE_HAVE_2
|
|
(void)d16;
|
|
const svuint16_t a_in_even = svqxtnb_u32(a);
|
|
return svqxtnt_u32(a_in_even, b);
|
|
#else
|
|
const svuint16_t a16 = BitCast(d16, detail::SaturateU<uint16_t>(a));
|
|
const svuint16_t b16 = BitCast(d16, detail::SaturateU<uint16_t>(b));
|
|
return detail::InterleaveEven(a16, b16);
|
|
#endif
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svint8_t ReorderDemote2To(Simd<int8_t, N, kPow2> d8, svint16_t a,
|
|
svint16_t b) {
|
|
#if HWY_SVE_HAVE_2
|
|
(void)d8;
|
|
const svint8_t a_in_even = svqxtnb_s16(a);
|
|
return svqxtnt_s16(a_in_even, b);
|
|
#else
|
|
const svint8_t a8 = BitCast(d8, detail::SaturateI<int8_t>(a));
|
|
const svint8_t b8 = BitCast(d8, detail::SaturateI<int8_t>(b));
|
|
return detail::InterleaveEven(a8, b8);
|
|
#endif
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint8_t ReorderDemote2To(Simd<uint8_t, N, kPow2> d8, svint16_t a,
|
|
svint16_t b) {
|
|
#if HWY_SVE_HAVE_2
|
|
(void)d8;
|
|
const svuint8_t a_in_even = svqxtunb_s16(a);
|
|
return svqxtunt_s16(a_in_even, b);
|
|
#else
|
|
const Repartition<uint16_t, decltype(d8)> du16;
|
|
const svuint16_t clamped_a = BitCast(du16, detail::MaxN(a, 0));
|
|
const svuint16_t clamped_b = BitCast(du16, detail::MaxN(b, 0));
|
|
const svuint8_t a8 = BitCast(d8, detail::SaturateU<uint8_t>(clamped_a));
|
|
const svuint8_t b8 = BitCast(d8, detail::SaturateU<uint8_t>(clamped_b));
|
|
return detail::InterleaveEven(a8, b8);
|
|
#endif
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint8_t ReorderDemote2To(Simd<uint8_t, N, kPow2> d8, svuint16_t a,
|
|
svuint16_t b) {
|
|
#if HWY_SVE_HAVE_2
|
|
(void)d8;
|
|
const svuint8_t a_in_even = svqxtnb_u16(a);
|
|
return svqxtnt_u16(a_in_even, b);
|
|
#else
|
|
const svuint8_t a8 = BitCast(d8, detail::SaturateU<uint8_t>(a));
|
|
const svuint8_t b8 = BitCast(d8, detail::SaturateU<uint8_t>(b));
|
|
return detail::InterleaveEven(a8, b8);
|
|
#endif
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svint32_t ReorderDemote2To(Simd<int32_t, N, kPow2> d32, svint64_t a,
|
|
svint64_t b) {
|
|
#if HWY_SVE_HAVE_2
|
|
(void)d32;
|
|
const svint32_t a_in_even = svqxtnb_s64(a);
|
|
return svqxtnt_s64(a_in_even, b);
|
|
#else
|
|
const svint32_t a32 = BitCast(d32, detail::SaturateI<int32_t>(a));
|
|
const svint32_t b32 = BitCast(d32, detail::SaturateI<int32_t>(b));
|
|
return detail::InterleaveEven(a32, b32);
|
|
#endif
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint32_t ReorderDemote2To(Simd<uint32_t, N, kPow2> d32, svint64_t a,
|
|
svint64_t b) {
|
|
#if HWY_SVE_HAVE_2
|
|
(void)d32;
|
|
const svuint32_t a_in_even = svqxtunb_s64(a);
|
|
return svqxtunt_s64(a_in_even, b);
|
|
#else
|
|
const Repartition<uint64_t, decltype(d32)> du64;
|
|
const svuint64_t clamped_a = BitCast(du64, detail::MaxN(a, 0));
|
|
const svuint64_t clamped_b = BitCast(du64, detail::MaxN(b, 0));
|
|
const svuint32_t a32 = BitCast(d32, detail::SaturateU<uint32_t>(clamped_a));
|
|
const svuint32_t b32 = BitCast(d32, detail::SaturateU<uint32_t>(clamped_b));
|
|
return detail::InterleaveEven(a32, b32);
|
|
#endif
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint32_t ReorderDemote2To(Simd<uint32_t, N, kPow2> d32, svuint64_t a,
|
|
svuint64_t b) {
|
|
#if HWY_SVE_HAVE_2
|
|
(void)d32;
|
|
const svuint32_t a_in_even = svqxtnb_u64(a);
|
|
return svqxtnt_u64(a_in_even, b);
|
|
#else
|
|
const svuint32_t a32 = BitCast(d32, detail::SaturateU<uint32_t>(a));
|
|
const svuint32_t b32 = BitCast(d32, detail::SaturateU<uint32_t>(b));
|
|
return detail::InterleaveEven(a32, b32);
|
|
#endif
|
|
}
|
|
|
|
template <class D, class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL(TFromD<D>),
|
|
HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V),
|
|
HWY_IF_T_SIZE_V(V, sizeof(TFromD<D>) * 2)>
|
|
HWY_API VFromD<D> OrderedDemote2To(D dn, V a, V b) {
|
|
const Half<decltype(dn)> dnh;
|
|
const auto demoted_a = DemoteTo(dnh, a);
|
|
const auto demoted_b = DemoteTo(dnh, b);
|
|
return Combine(dn, demoted_b, demoted_a);
|
|
}
|
|
|
|
template <class D, HWY_IF_BF16_D(D)>
|
|
HWY_API VBF16 OrderedDemote2To(D dn, svfloat32_t a, svfloat32_t b) {
|
|
const Half<decltype(dn)> dnh;
|
|
const RebindToUnsigned<decltype(dn)> dn_u;
|
|
const RebindToUnsigned<decltype(dnh)> dnh_u;
|
|
const auto demoted_a = DemoteTo(dnh, a);
|
|
const auto demoted_b = DemoteTo(dnh, b);
|
|
return BitCast(
|
|
dn, Combine(dn_u, BitCast(dnh_u, demoted_b), BitCast(dnh_u, demoted_a)));
|
|
}
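
// Usage sketch (illustrative only; names and values are hypothetical).
// ReorderDemote2To may return the demoted lanes of a and b in an interleaved,
// implementation-defined order, whereas OrderedDemote2To keeps all of a's
// lanes before all of b's:
//
//   namespace hn = hwy::HWY_NAMESPACE;
//   const hn::ScalableTag<int32_t> d32;
//   const hn::RepartitionToNarrow<decltype(d32)> d16;
//   const auto a = hn::Set(d32, 70000);    // saturates to 32767
//   const auto b = hn::Set(d32, -70000);   // saturates to -32768
//   const auto ordered = hn::OrderedDemote2To(d16, a, b);
//   const auto reordered = hn::ReorderDemote2To(d16, a, b);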
|
|
|
|
// ------------------------------ ZeroIfNegative (Lt, IfThenElse)
|
|
template <class V>
|
|
HWY_API V ZeroIfNegative(const V v) {
|
|
return IfThenZeroElse(detail::LtN(v, 0), v);
|
|
}
|
|
|
|
// ------------------------------ BroadcastSignBit (ShiftRight)
|
|
template <class V>
|
|
HWY_API V BroadcastSignBit(const V v) {
|
|
return ShiftRight<sizeof(TFromV<V>) * 8 - 1>(v);
|
|
}
|
|
|
|
// ------------------------------ IfNegativeThenElse (BroadcastSignBit)
|
|
template <class V>
|
|
HWY_API V IfNegativeThenElse(V v, V yes, V no) {
|
|
static_assert(IsSigned<TFromV<V>>(), "Only works for signed/float");
|
|
const DFromV<V> d;
|
|
const RebindToSigned<decltype(d)> di;
|
|
|
|
const svbool_t m = detail::LtN(BitCast(di, v), 0);
|
|
return IfThenElse(m, yes, no);
|
|
}
|
|
|
|
// ------------------------------ AverageRound (ShiftRight)
|
|
|
|
#if HWY_SVE_HAVE_2
|
|
HWY_SVE_FOREACH_U08(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd)
|
|
HWY_SVE_FOREACH_U16(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd)
|
|
#else
|
|
template <class V>
|
|
V AverageRound(const V a, const V b) {
|
|
return ShiftRight<1>(detail::AddN(Add(a, b), 1));
|
|
}
|
|
#endif // HWY_SVE_HAVE_2
|
|
|
|
// ------------------------------ LoadMaskBits (TestBit)
|
|
|
|
// `p` points to at least 8 readable bytes, not all of which need be valid.
|
|
template <class D, HWY_IF_T_SIZE_D(D, 1)>
|
|
HWY_INLINE svbool_t LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) {
|
|
// TODO(janwas): with SVE2.1, load to vector, then PMOV
|
|
const RebindToUnsigned<D> du;
|
|
const svuint8_t iota = Iota(du, 0);
|
|
|
|
// Load correct number of bytes (bits/8) with 7 zeros after each.
|
|
const svuint8_t bytes = BitCast(du, svld1ub_u64(detail::PTrue(d), bits));
|
|
// Replicate bytes 8x such that each byte contains the bit that governs it.
|
|
const svuint8_t rep8 = svtbl_u8(bytes, detail::AndNotN(7, iota));
|
|
|
|
const svuint8_t bit =
|
|
svdupq_n_u8(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128);
|
|
return TestBit(rep8, bit);
|
|
}
|
|
|
|
template <class D, HWY_IF_T_SIZE_D(D, 2)>
|
|
HWY_INLINE svbool_t LoadMaskBits(D /* tag */,
|
|
const uint8_t* HWY_RESTRICT bits) {
|
|
const RebindToUnsigned<D> du;
|
|
const Repartition<uint8_t, D> du8;
|
|
|
|
// There may be up to 128 bits; avoid reading past the end.
|
|
const svuint8_t bytes = svld1(FirstN(du8, (Lanes(du) + 7) / 8), bits);
|
|
|
|
// Replicate bytes 16x such that each lane contains the bit that governs it.
|
|
const svuint8_t rep16 = svtbl_u8(bytes, ShiftRight<4>(Iota(du8, 0)));
|
|
|
|
const svuint16_t bit = svdupq_n_u16(1, 2, 4, 8, 16, 32, 64, 128);
|
|
return TestBit(BitCast(du, rep16), bit);
|
|
}
|
|
|
|
template <class D, HWY_IF_T_SIZE_D(D, 4)>
|
|
HWY_INLINE svbool_t LoadMaskBits(D /* tag */,
|
|
const uint8_t* HWY_RESTRICT bits) {
|
|
const RebindToUnsigned<D> du;
|
|
const Repartition<uint8_t, D> du8;
|
|
|
|
// Upper bound = 2048 bits / 32 bit = 64 bits; at least 8 bytes are readable,
|
|
// so we can skip computing the actual length (Lanes(du)+7)/8.
|
|
const svuint8_t bytes = svld1(FirstN(du8, 8), bits);
|
|
|
|
// Replicate bytes 32x such that each lane contains the bit that governs it.
|
|
const svuint8_t rep32 = svtbl_u8(bytes, ShiftRight<5>(Iota(du8, 0)));
|
|
|
|
// 1, 2, 4, 8, 16, 32, 64, 128, 1, 2 ..
|
|
const svuint32_t bit = Shl(Set(du, 1), detail::AndN(Iota(du, 0), 7));
|
|
|
|
return TestBit(BitCast(du, rep32), bit);
|
|
}
|
|
|
|
template <class D, HWY_IF_T_SIZE_D(D, 8)>
|
|
HWY_INLINE svbool_t LoadMaskBits(D /* tag */,
|
|
const uint8_t* HWY_RESTRICT bits) {
|
|
const RebindToUnsigned<D> du;
|
|
|
|
// Max 2048 bits = 32 lanes = 32 input bits; replicate those into each lane.
|
|
// The "at least 8 byte" guarantee in quick_reference ensures this is safe.
|
|
uint32_t mask_bits;
|
|
CopyBytes<4>(bits, &mask_bits); // copy from bytes
|
|
const auto vbits = Set(du, mask_bits);
|
|
|
|
// 2 ^ {0,1, .., 31}, will not have more lanes than that.
|
|
const svuint64_t bit = Shl(Set(du, 1), Iota(du, 0));
|
|
|
|
return TestBit(vbits, bit);
|
|
}
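
// Usage sketch (illustrative only; names and values are hypothetical). Bits
// are consumed starting from the least-significant bit of bits[0], one bit
// per lane:
//
//   namespace hn = hwy::HWY_NAMESPACE;
//   const hn::ScalableTag<uint32_t> d;
//   const uint8_t kBits[8] = {5};  // binary 101: lanes 0 and 2 active
//   const auto mask = hn::LoadMaskBits(d, kBits);
//   // hn::CountTrue(d, mask) == 2; u32 vectors always have >= 4 lanes.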
|
|
|
|
// ------------------------------ StoreMaskBits
|
|
|
|
namespace detail {
|
|
|
|
// For each mask lane (governing lane type T), store 1 or 0 in BYTE lanes.
|
|
template <class T, HWY_IF_T_SIZE(T, 1)>
|
|
HWY_INLINE svuint8_t BoolFromMask(svbool_t m) {
|
|
return svdup_n_u8_z(m, 1);
|
|
}
|
|
template <class T, HWY_IF_T_SIZE(T, 2)>
|
|
HWY_INLINE svuint8_t BoolFromMask(svbool_t m) {
|
|
const ScalableTag<uint8_t> d8;
|
|
const svuint8_t b16 = BitCast(d8, svdup_n_u16_z(m, 1));
|
|
return detail::ConcatEvenFull(b16, b16); // lower half
|
|
}
|
|
template <class T, HWY_IF_T_SIZE(T, 4)>
|
|
HWY_INLINE svuint8_t BoolFromMask(svbool_t m) {
|
|
return U8FromU32(svdup_n_u32_z(m, 1));
|
|
}
|
|
template <class T, HWY_IF_T_SIZE(T, 8)>
|
|
HWY_INLINE svuint8_t BoolFromMask(svbool_t m) {
|
|
const ScalableTag<uint32_t> d32;
|
|
const svuint32_t b64 = BitCast(d32, svdup_n_u64_z(m, 1));
|
|
return U8FromU32(detail::ConcatEvenFull(b64, b64)); // lower half
|
|
}
|
|
|
|
// Compacts groups of 8 u8 into 8 contiguous bits in a 64-bit lane.
|
|
HWY_INLINE svuint64_t BitsFromBool(svuint8_t x) {
|
|
const ScalableTag<uint8_t> d8;
|
|
const ScalableTag<uint16_t> d16;
|
|
const ScalableTag<uint32_t> d32;
|
|
const ScalableTag<uint64_t> d64;
|
|
// TODO(janwas): could use SVE2 BDEP, but it's optional.
|
|
x = Or(x, BitCast(d8, ShiftRight<7>(BitCast(d16, x))));
|
|
x = Or(x, BitCast(d8, ShiftRight<14>(BitCast(d32, x))));
|
|
x = Or(x, BitCast(d8, ShiftRight<28>(BitCast(d64, x))));
|
|
return BitCast(d64, x);
|
|
}
|
|
|
|
} // namespace detail
|
|
|
|
// `p` points to at least 8 writable bytes.
|
|
// TODO(janwas): specialize for HWY_SVE_256
|
|
// TODO(janwas): with SVE2.1, use PMOV to store to vector, then StoreU
|
|
template <class D>
|
|
HWY_API size_t StoreMaskBits(D d, svbool_t m, uint8_t* bits) {
|
|
svuint64_t bits_in_u64 =
|
|
detail::BitsFromBool(detail::BoolFromMask<TFromD<D>>(m));
|
|
|
|
const size_t num_bits = Lanes(d);
|
|
const size_t num_bytes = (num_bits + 8 - 1) / 8; // Round up, see below
|
|
|
|
// Truncate each u64 to 8 bits and store to u8.
|
|
svst1b_u64(FirstN(ScalableTag<uint64_t>(), num_bytes), bits, bits_in_u64);
|
|
|
|
  // If the final byte is not full, clear its undefined upper bits. This can
  // happen for capped/fractional vectors, or for large T on small hardware
  // vectors.
|
|
if (num_bits < 8) {
|
|
const int mask = static_cast<int>((1ull << num_bits) - 1);
|
|
bits[0] = static_cast<uint8_t>(bits[0] & mask);
|
|
}
|
|
// Else: we wrote full bytes because num_bits is a power of two >= 8.
|
|
|
|
return num_bytes;
|
|
}
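
// Usage sketch (illustrative only; names and values are hypothetical):
//
//   namespace hn = hwy::HWY_NAMESPACE;
//   const hn::ScalableTag<float> d;
//   const auto m = hn::FirstN(d, 3);
//   uint8_t bits[2048 / 8] = {0};  // >= 8 bytes; enough for any SVE vector
//   const size_t num_bytes = hn::StoreMaskBits(d, m, bits);
//   // bits[0] == 7 (binary 111) because exactly the first three lanes are
//   // active; num_bytes == (Lanes(d) + 7) / 8.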
|
|
|
|
// ------------------------------ CompressBits (LoadMaskBits)
|
|
template <class V, HWY_IF_NOT_T_SIZE_V(V, 1)>
|
|
HWY_INLINE V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) {
|
|
return Compress(v, LoadMaskBits(DFromV<V>(), bits));
|
|
}
|
|
|
|
// ------------------------------ CompressBitsStore (LoadMaskBits)
|
|
template <class D, HWY_IF_NOT_T_SIZE_D(D, 1)>
|
|
HWY_API size_t CompressBitsStore(VFromD<D> v, const uint8_t* HWY_RESTRICT bits,
|
|
D d, TFromD<D>* HWY_RESTRICT unaligned) {
|
|
return CompressStore(v, LoadMaskBits(d, bits), d, unaligned);
|
|
}
|
|
|
|
// ------------------------------ Expand (StoreMaskBits)
|
|
|
|
#ifdef HWY_NATIVE_EXPAND
|
|
#undef HWY_NATIVE_EXPAND
|
|
#else
|
|
#define HWY_NATIVE_EXPAND
|
|
#endif
|
|
|
|
namespace detail {
|
|
|
|
HWY_INLINE svuint8_t IndicesForExpandFromBits(uint64_t mask_bits) {
|
|
const CappedTag<uint8_t, 8> du8;
|
|
alignas(16) static constexpr uint8_t table[8 * 256] = {
|
|
// PrintExpand8x8Tables
|
|
128, 128, 128, 128, 128, 128, 128, 128, //
|
|
0, 128, 128, 128, 128, 128, 128, 128, //
|
|
128, 0, 128, 128, 128, 128, 128, 128, //
|
|
0, 1, 128, 128, 128, 128, 128, 128, //
|
|
128, 128, 0, 128, 128, 128, 128, 128, //
|
|
0, 128, 1, 128, 128, 128, 128, 128, //
|
|
128, 0, 1, 128, 128, 128, 128, 128, //
|
|
0, 1, 2, 128, 128, 128, 128, 128, //
|
|
128, 128, 128, 0, 128, 128, 128, 128, //
|
|
0, 128, 128, 1, 128, 128, 128, 128, //
|
|
128, 0, 128, 1, 128, 128, 128, 128, //
|
|
0, 1, 128, 2, 128, 128, 128, 128, //
|
|
128, 128, 0, 1, 128, 128, 128, 128, //
|
|
0, 128, 1, 2, 128, 128, 128, 128, //
|
|
128, 0, 1, 2, 128, 128, 128, 128, //
|
|
0, 1, 2, 3, 128, 128, 128, 128, //
|
|
128, 128, 128, 128, 0, 128, 128, 128, //
|
|
0, 128, 128, 128, 1, 128, 128, 128, //
|
|
128, 0, 128, 128, 1, 128, 128, 128, //
|
|
0, 1, 128, 128, 2, 128, 128, 128, //
|
|
128, 128, 0, 128, 1, 128, 128, 128, //
|
|
0, 128, 1, 128, 2, 128, 128, 128, //
|
|
128, 0, 1, 128, 2, 128, 128, 128, //
|
|
0, 1, 2, 128, 3, 128, 128, 128, //
|
|
128, 128, 128, 0, 1, 128, 128, 128, //
|
|
0, 128, 128, 1, 2, 128, 128, 128, //
|
|
128, 0, 128, 1, 2, 128, 128, 128, //
|
|
0, 1, 128, 2, 3, 128, 128, 128, //
|
|
128, 128, 0, 1, 2, 128, 128, 128, //
|
|
0, 128, 1, 2, 3, 128, 128, 128, //
|
|
128, 0, 1, 2, 3, 128, 128, 128, //
|
|
0, 1, 2, 3, 4, 128, 128, 128, //
|
|
128, 128, 128, 128, 128, 0, 128, 128, //
|
|
0, 128, 128, 128, 128, 1, 128, 128, //
|
|
128, 0, 128, 128, 128, 1, 128, 128, //
|
|
0, 1, 128, 128, 128, 2, 128, 128, //
|
|
128, 128, 0, 128, 128, 1, 128, 128, //
|
|
0, 128, 1, 128, 128, 2, 128, 128, //
|
|
128, 0, 1, 128, 128, 2, 128, 128, //
|
|
0, 1, 2, 128, 128, 3, 128, 128, //
|
|
128, 128, 128, 0, 128, 1, 128, 128, //
|
|
0, 128, 128, 1, 128, 2, 128, 128, //
|
|
128, 0, 128, 1, 128, 2, 128, 128, //
|
|
0, 1, 128, 2, 128, 3, 128, 128, //
|
|
128, 128, 0, 1, 128, 2, 128, 128, //
|
|
0, 128, 1, 2, 128, 3, 128, 128, //
|
|
128, 0, 1, 2, 128, 3, 128, 128, //
|
|
0, 1, 2, 3, 128, 4, 128, 128, //
|
|
128, 128, 128, 128, 0, 1, 128, 128, //
|
|
0, 128, 128, 128, 1, 2, 128, 128, //
|
|
128, 0, 128, 128, 1, 2, 128, 128, //
|
|
0, 1, 128, 128, 2, 3, 128, 128, //
|
|
128, 128, 0, 128, 1, 2, 128, 128, //
|
|
0, 128, 1, 128, 2, 3, 128, 128, //
|
|
128, 0, 1, 128, 2, 3, 128, 128, //
|
|
0, 1, 2, 128, 3, 4, 128, 128, //
|
|
128, 128, 128, 0, 1, 2, 128, 128, //
|
|
0, 128, 128, 1, 2, 3, 128, 128, //
|
|
128, 0, 128, 1, 2, 3, 128, 128, //
|
|
0, 1, 128, 2, 3, 4, 128, 128, //
|
|
128, 128, 0, 1, 2, 3, 128, 128, //
|
|
0, 128, 1, 2, 3, 4, 128, 128, //
|
|
128, 0, 1, 2, 3, 4, 128, 128, //
|
|
0, 1, 2, 3, 4, 5, 128, 128, //
|
|
128, 128, 128, 128, 128, 128, 0, 128, //
|
|
0, 128, 128, 128, 128, 128, 1, 128, //
|
|
128, 0, 128, 128, 128, 128, 1, 128, //
|
|
0, 1, 128, 128, 128, 128, 2, 128, //
|
|
128, 128, 0, 128, 128, 128, 1, 128, //
|
|
0, 128, 1, 128, 128, 128, 2, 128, //
|
|
128, 0, 1, 128, 128, 128, 2, 128, //
|
|
0, 1, 2, 128, 128, 128, 3, 128, //
|
|
128, 128, 128, 0, 128, 128, 1, 128, //
|
|
0, 128, 128, 1, 128, 128, 2, 128, //
|
|
128, 0, 128, 1, 128, 128, 2, 128, //
|
|
0, 1, 128, 2, 128, 128, 3, 128, //
|
|
128, 128, 0, 1, 128, 128, 2, 128, //
|
|
0, 128, 1, 2, 128, 128, 3, 128, //
|
|
128, 0, 1, 2, 128, 128, 3, 128, //
|
|
0, 1, 2, 3, 128, 128, 4, 128, //
|
|
128, 128, 128, 128, 0, 128, 1, 128, //
|
|
0, 128, 128, 128, 1, 128, 2, 128, //
|
|
128, 0, 128, 128, 1, 128, 2, 128, //
|
|
0, 1, 128, 128, 2, 128, 3, 128, //
|
|
128, 128, 0, 128, 1, 128, 2, 128, //
|
|
0, 128, 1, 128, 2, 128, 3, 128, //
|
|
128, 0, 1, 128, 2, 128, 3, 128, //
|
|
0, 1, 2, 128, 3, 128, 4, 128, //
|
|
128, 128, 128, 0, 1, 128, 2, 128, //
|
|
0, 128, 128, 1, 2, 128, 3, 128, //
|
|
128, 0, 128, 1, 2, 128, 3, 128, //
|
|
0, 1, 128, 2, 3, 128, 4, 128, //
|
|
128, 128, 0, 1, 2, 128, 3, 128, //
|
|
0, 128, 1, 2, 3, 128, 4, 128, //
|
|
128, 0, 1, 2, 3, 128, 4, 128, //
|
|
0, 1, 2, 3, 4, 128, 5, 128, //
|
|
128, 128, 128, 128, 128, 0, 1, 128, //
|
|
0, 128, 128, 128, 128, 1, 2, 128, //
|
|
128, 0, 128, 128, 128, 1, 2, 128, //
|
|
0, 1, 128, 128, 128, 2, 3, 128, //
|
|
128, 128, 0, 128, 128, 1, 2, 128, //
|
|
0, 128, 1, 128, 128, 2, 3, 128, //
|
|
128, 0, 1, 128, 128, 2, 3, 128, //
|
|
0, 1, 2, 128, 128, 3, 4, 128, //
|
|
128, 128, 128, 0, 128, 1, 2, 128, //
|
|
0, 128, 128, 1, 128, 2, 3, 128, //
|
|
128, 0, 128, 1, 128, 2, 3, 128, //
|
|
0, 1, 128, 2, 128, 3, 4, 128, //
|
|
128, 128, 0, 1, 128, 2, 3, 128, //
|
|
0, 128, 1, 2, 128, 3, 4, 128, //
|
|
128, 0, 1, 2, 128, 3, 4, 128, //
|
|
0, 1, 2, 3, 128, 4, 5, 128, //
|
|
128, 128, 128, 128, 0, 1, 2, 128, //
|
|
0, 128, 128, 128, 1, 2, 3, 128, //
|
|
128, 0, 128, 128, 1, 2, 3, 128, //
|
|
0, 1, 128, 128, 2, 3, 4, 128, //
|
|
128, 128, 0, 128, 1, 2, 3, 128, //
|
|
0, 128, 1, 128, 2, 3, 4, 128, //
|
|
128, 0, 1, 128, 2, 3, 4, 128, //
|
|
0, 1, 2, 128, 3, 4, 5, 128, //
|
|
128, 128, 128, 0, 1, 2, 3, 128, //
|
|
0, 128, 128, 1, 2, 3, 4, 128, //
|
|
128, 0, 128, 1, 2, 3, 4, 128, //
|
|
0, 1, 128, 2, 3, 4, 5, 128, //
|
|
128, 128, 0, 1, 2, 3, 4, 128, //
|
|
0, 128, 1, 2, 3, 4, 5, 128, //
|
|
128, 0, 1, 2, 3, 4, 5, 128, //
|
|
0, 1, 2, 3, 4, 5, 6, 128, //
|
|
128, 128, 128, 128, 128, 128, 128, 0, //
|
|
0, 128, 128, 128, 128, 128, 128, 1, //
|
|
128, 0, 128, 128, 128, 128, 128, 1, //
|
|
0, 1, 128, 128, 128, 128, 128, 2, //
|
|
128, 128, 0, 128, 128, 128, 128, 1, //
|
|
0, 128, 1, 128, 128, 128, 128, 2, //
|
|
128, 0, 1, 128, 128, 128, 128, 2, //
|
|
0, 1, 2, 128, 128, 128, 128, 3, //
|
|
128, 128, 128, 0, 128, 128, 128, 1, //
|
|
0, 128, 128, 1, 128, 128, 128, 2, //
|
|
128, 0, 128, 1, 128, 128, 128, 2, //
|
|
0, 1, 128, 2, 128, 128, 128, 3, //
|
|
128, 128, 0, 1, 128, 128, 128, 2, //
|
|
0, 128, 1, 2, 128, 128, 128, 3, //
|
|
128, 0, 1, 2, 128, 128, 128, 3, //
|
|
0, 1, 2, 3, 128, 128, 128, 4, //
|
|
128, 128, 128, 128, 0, 128, 128, 1, //
|
|
0, 128, 128, 128, 1, 128, 128, 2, //
|
|
128, 0, 128, 128, 1, 128, 128, 2, //
|
|
0, 1, 128, 128, 2, 128, 128, 3, //
|
|
128, 128, 0, 128, 1, 128, 128, 2, //
|
|
0, 128, 1, 128, 2, 128, 128, 3, //
|
|
128, 0, 1, 128, 2, 128, 128, 3, //
|
|
0, 1, 2, 128, 3, 128, 128, 4, //
|
|
128, 128, 128, 0, 1, 128, 128, 2, //
|
|
0, 128, 128, 1, 2, 128, 128, 3, //
|
|
128, 0, 128, 1, 2, 128, 128, 3, //
|
|
0, 1, 128, 2, 3, 128, 128, 4, //
|
|
128, 128, 0, 1, 2, 128, 128, 3, //
|
|
0, 128, 1, 2, 3, 128, 128, 4, //
|
|
128, 0, 1, 2, 3, 128, 128, 4, //
|
|
0, 1, 2, 3, 4, 128, 128, 5, //
|
|
128, 128, 128, 128, 128, 0, 128, 1, //
|
|
0, 128, 128, 128, 128, 1, 128, 2, //
|
|
128, 0, 128, 128, 128, 1, 128, 2, //
|
|
0, 1, 128, 128, 128, 2, 128, 3, //
|
|
128, 128, 0, 128, 128, 1, 128, 2, //
|
|
0, 128, 1, 128, 128, 2, 128, 3, //
|
|
128, 0, 1, 128, 128, 2, 128, 3, //
|
|
0, 1, 2, 128, 128, 3, 128, 4, //
|
|
128, 128, 128, 0, 128, 1, 128, 2, //
|
|
0, 128, 128, 1, 128, 2, 128, 3, //
|
|
128, 0, 128, 1, 128, 2, 128, 3, //
|
|
0, 1, 128, 2, 128, 3, 128, 4, //
|
|
128, 128, 0, 1, 128, 2, 128, 3, //
|
|
0, 128, 1, 2, 128, 3, 128, 4, //
|
|
128, 0, 1, 2, 128, 3, 128, 4, //
|
|
0, 1, 2, 3, 128, 4, 128, 5, //
|
|
128, 128, 128, 128, 0, 1, 128, 2, //
|
|
0, 128, 128, 128, 1, 2, 128, 3, //
|
|
128, 0, 128, 128, 1, 2, 128, 3, //
|
|
0, 1, 128, 128, 2, 3, 128, 4, //
|
|
128, 128, 0, 128, 1, 2, 128, 3, //
|
|
0, 128, 1, 128, 2, 3, 128, 4, //
|
|
128, 0, 1, 128, 2, 3, 128, 4, //
|
|
0, 1, 2, 128, 3, 4, 128, 5, //
|
|
128, 128, 128, 0, 1, 2, 128, 3, //
|
|
0, 128, 128, 1, 2, 3, 128, 4, //
|
|
128, 0, 128, 1, 2, 3, 128, 4, //
|
|
0, 1, 128, 2, 3, 4, 128, 5, //
|
|
128, 128, 0, 1, 2, 3, 128, 4, //
|
|
0, 128, 1, 2, 3, 4, 128, 5, //
|
|
128, 0, 1, 2, 3, 4, 128, 5, //
|
|
0, 1, 2, 3, 4, 5, 128, 6, //
|
|
128, 128, 128, 128, 128, 128, 0, 1, //
|
|
0, 128, 128, 128, 128, 128, 1, 2, //
|
|
128, 0, 128, 128, 128, 128, 1, 2, //
|
|
0, 1, 128, 128, 128, 128, 2, 3, //
|
|
128, 128, 0, 128, 128, 128, 1, 2, //
|
|
0, 128, 1, 128, 128, 128, 2, 3, //
|
|
128, 0, 1, 128, 128, 128, 2, 3, //
|
|
0, 1, 2, 128, 128, 128, 3, 4, //
|
|
128, 128, 128, 0, 128, 128, 1, 2, //
|
|
0, 128, 128, 1, 128, 128, 2, 3, //
|
|
128, 0, 128, 1, 128, 128, 2, 3, //
|
|
0, 1, 128, 2, 128, 128, 3, 4, //
|
|
128, 128, 0, 1, 128, 128, 2, 3, //
|
|
0, 128, 1, 2, 128, 128, 3, 4, //
|
|
128, 0, 1, 2, 128, 128, 3, 4, //
|
|
0, 1, 2, 3, 128, 128, 4, 5, //
|
|
128, 128, 128, 128, 0, 128, 1, 2, //
|
|
0, 128, 128, 128, 1, 128, 2, 3, //
|
|
128, 0, 128, 128, 1, 128, 2, 3, //
|
|
0, 1, 128, 128, 2, 128, 3, 4, //
|
|
128, 128, 0, 128, 1, 128, 2, 3, //
|
|
0, 128, 1, 128, 2, 128, 3, 4, //
|
|
128, 0, 1, 128, 2, 128, 3, 4, //
|
|
0, 1, 2, 128, 3, 128, 4, 5, //
|
|
128, 128, 128, 0, 1, 128, 2, 3, //
|
|
0, 128, 128, 1, 2, 128, 3, 4, //
|
|
128, 0, 128, 1, 2, 128, 3, 4, //
|
|
0, 1, 128, 2, 3, 128, 4, 5, //
|
|
128, 128, 0, 1, 2, 128, 3, 4, //
|
|
0, 128, 1, 2, 3, 128, 4, 5, //
|
|
128, 0, 1, 2, 3, 128, 4, 5, //
|
|
0, 1, 2, 3, 4, 128, 5, 6, //
|
|
128, 128, 128, 128, 128, 0, 1, 2, //
|
|
0, 128, 128, 128, 128, 1, 2, 3, //
|
|
128, 0, 128, 128, 128, 1, 2, 3, //
|
|
0, 1, 128, 128, 128, 2, 3, 4, //
|
|
128, 128, 0, 128, 128, 1, 2, 3, //
|
|
0, 128, 1, 128, 128, 2, 3, 4, //
|
|
128, 0, 1, 128, 128, 2, 3, 4, //
|
|
0, 1, 2, 128, 128, 3, 4, 5, //
|
|
128, 128, 128, 0, 128, 1, 2, 3, //
|
|
0, 128, 128, 1, 128, 2, 3, 4, //
|
|
128, 0, 128, 1, 128, 2, 3, 4, //
|
|
0, 1, 128, 2, 128, 3, 4, 5, //
|
|
128, 128, 0, 1, 128, 2, 3, 4, //
|
|
0, 128, 1, 2, 128, 3, 4, 5, //
|
|
128, 0, 1, 2, 128, 3, 4, 5, //
|
|
0, 1, 2, 3, 128, 4, 5, 6, //
|
|
128, 128, 128, 128, 0, 1, 2, 3, //
|
|
0, 128, 128, 128, 1, 2, 3, 4, //
|
|
128, 0, 128, 128, 1, 2, 3, 4, //
|
|
0, 1, 128, 128, 2, 3, 4, 5, //
|
|
128, 128, 0, 128, 1, 2, 3, 4, //
|
|
0, 128, 1, 128, 2, 3, 4, 5, //
|
|
128, 0, 1, 128, 2, 3, 4, 5, //
|
|
0, 1, 2, 128, 3, 4, 5, 6, //
|
|
128, 128, 128, 0, 1, 2, 3, 4, //
|
|
0, 128, 128, 1, 2, 3, 4, 5, //
|
|
128, 0, 128, 1, 2, 3, 4, 5, //
|
|
0, 1, 128, 2, 3, 4, 5, 6, //
|
|
128, 128, 0, 1, 2, 3, 4, 5, //
|
|
0, 128, 1, 2, 3, 4, 5, 6, //
|
|
128, 0, 1, 2, 3, 4, 5, 6, //
|
|
0, 1, 2, 3, 4, 5, 6, 7};
|
|
return Load(du8, table + mask_bits * 8);
|
|
}
|
|
|
|
template <class D, HWY_IF_T_SIZE_D(D, 1)>
|
|
HWY_INLINE svuint8_t LaneIndicesFromByteIndices(D, svuint8_t idx) {
|
|
return idx;
|
|
}
|
|
template <class D, class DU = RebindToUnsigned<D>, HWY_IF_NOT_T_SIZE_D(D, 1)>
|
|
HWY_INLINE VFromD<DU> LaneIndicesFromByteIndices(D, svuint8_t idx) {
|
|
return PromoteTo(DU(), idx);
|
|
}
|
|
|
|
// General case for vectors whose length is unknown at compile time; handles
// 8 lanes per iteration.
|
|
template <class V>
|
|
HWY_INLINE V ExpandLoop(V v, svbool_t mask) {
|
|
const DFromV<V> d;
|
|
uint8_t mask_bytes[256 / 8];
|
|
StoreMaskBits(d, mask, mask_bytes);
|
|
|
|
// ShiftLeftLanes is expensive, so we're probably better off storing to memory
|
|
// and loading the final result.
|
|
alignas(16) TFromV<V> out[2 * MaxLanes(d)];
|
|
|
|
svbool_t next = svpfalse_b();
|
|
size_t input_consumed = 0;
|
|
const V iota = Iota(d, 0);
|
|
for (size_t i = 0; i < Lanes(d); i += 8) {
|
|
uint64_t mask_bits = mask_bytes[i / 8];
|
|
|
|
    // We want to skip past the v lanes already consumed. There is no
    // instruction for shifting a vector register by a variable number of
    // lanes, but we can splice.
|
|
const V vH = detail::Splice(v, v, next);
|
|
input_consumed += PopCount(mask_bits);
|
|
next = detail::GeN(iota, static_cast<TFromV<V>>(input_consumed));
|
|
|
|
const auto idx = detail::LaneIndicesFromByteIndices(
|
|
d, detail::IndicesForExpandFromBits(mask_bits));
|
|
const V expand = TableLookupLanes(vH, idx);
|
|
StoreU(expand, d, out + i);
|
|
}
|
|
return LoadU(d, out);
|
|
}
|
|
|
|
} // namespace detail
|
|
|
|
template <class V, HWY_IF_T_SIZE_V(V, 1)>
|
|
HWY_API V Expand(V v, svbool_t mask) {
|
|
#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE
|
|
const DFromV<V> d;
|
|
uint8_t mask_bytes[256 / 8];
|
|
StoreMaskBits(d, mask, mask_bytes);
|
|
const uint64_t maskL = mask_bytes[0];
|
|
const uint64_t maskH = mask_bytes[1];
|
|
|
|
  // We want to skip past the v bytes already consumed by the lower-half
  // expansion. There is no instruction for shifting a vector register by a
  // variable number of bytes, but we can splice. Instead of GeN,
  // Not(FirstN()) would also work.
|
|
using T = TFromV<V>;
|
|
const T countL = static_cast<T>(PopCount(maskL));
|
|
const V vH = detail::Splice(v, v, detail::GeN(Iota(d, 0), countL));
|
|
|
|
const svuint8_t idxL = detail::IndicesForExpandFromBits(maskL);
|
|
const svuint8_t idxH = detail::IndicesForExpandFromBits(maskH);
|
|
return Combine(d, TableLookupLanes(vH, idxH), TableLookupLanes(v, idxL));
|
|
#else
|
|
return detail::ExpandLoop(v, mask);
|
|
#endif
|
|
}
|
|
|
|
template <class V, HWY_IF_T_SIZE_V(V, 2)>
|
|
HWY_API V Expand(V v, svbool_t mask) {
|
|
#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE // 16x8
|
|
const DFromV<V> d;
|
|
const RebindToUnsigned<decltype(d)> du16;
|
|
const Rebind<uint8_t, decltype(d)> du8;
|
|
// Convert mask into bitfield via horizontal sum (faster than ORV) of 8 bits.
|
|
// Pre-multiply by N so we can use it as an offset for Load.
|
|
const svuint16_t bits = Shl(Set(du16, 1), Iota(du16, 3));
|
|
const size_t offset = detail::SumOfLanesM(mask, bits);
|
|
|
|
// Storing as 8-bit reduces table size from 4 KiB to 2 KiB. We cannot apply
|
|
// the nibble trick used below because not all indices fit within one lane.
|
|
alignas(16) static constexpr uint8_t table[8 * 256] = {
|
|
// PrintExpand16x8LaneTables
|
|
255, 255, 255, 255, 255, 255, 255, 255, //
|
|
0, 255, 255, 255, 255, 255, 255, 255, //
|
|
255, 0, 255, 255, 255, 255, 255, 255, //
|
|
0, 1, 255, 255, 255, 255, 255, 255, //
|
|
255, 255, 0, 255, 255, 255, 255, 255, //
|
|
0, 255, 1, 255, 255, 255, 255, 255, //
|
|
255, 0, 1, 255, 255, 255, 255, 255, //
|
|
0, 1, 2, 255, 255, 255, 255, 255, //
|
|
255, 255, 255, 0, 255, 255, 255, 255, //
|
|
0, 255, 255, 1, 255, 255, 255, 255, //
|
|
255, 0, 255, 1, 255, 255, 255, 255, //
|
|
0, 1, 255, 2, 255, 255, 255, 255, //
|
|
255, 255, 0, 1, 255, 255, 255, 255, //
|
|
0, 255, 1, 2, 255, 255, 255, 255, //
|
|
255, 0, 1, 2, 255, 255, 255, 255, //
|
|
0, 1, 2, 3, 255, 255, 255, 255, //
|
|
255, 255, 255, 255, 0, 255, 255, 255, //
|
|
0, 255, 255, 255, 1, 255, 255, 255, //
|
|
255, 0, 255, 255, 1, 255, 255, 255, //
|
|
0, 1, 255, 255, 2, 255, 255, 255, //
|
|
255, 255, 0, 255, 1, 255, 255, 255, //
|
|
0, 255, 1, 255, 2, 255, 255, 255, //
|
|
255, 0, 1, 255, 2, 255, 255, 255, //
|
|
0, 1, 2, 255, 3, 255, 255, 255, //
|
|
255, 255, 255, 0, 1, 255, 255, 255, //
|
|
0, 255, 255, 1, 2, 255, 255, 255, //
|
|
255, 0, 255, 1, 2, 255, 255, 255, //
|
|
0, 1, 255, 2, 3, 255, 255, 255, //
|
|
255, 255, 0, 1, 2, 255, 255, 255, //
|
|
0, 255, 1, 2, 3, 255, 255, 255, //
|
|
255, 0, 1, 2, 3, 255, 255, 255, //
|
|
0, 1, 2, 3, 4, 255, 255, 255, //
|
|
255, 255, 255, 255, 255, 0, 255, 255, //
|
|
0, 255, 255, 255, 255, 1, 255, 255, //
|
|
255, 0, 255, 255, 255, 1, 255, 255, //
|
|
0, 1, 255, 255, 255, 2, 255, 255, //
|
|
255, 255, 0, 255, 255, 1, 255, 255, //
|
|
0, 255, 1, 255, 255, 2, 255, 255, //
|
|
255, 0, 1, 255, 255, 2, 255, 255, //
|
|
0, 1, 2, 255, 255, 3, 255, 255, //
|
|
255, 255, 255, 0, 255, 1, 255, 255, //
|
|
0, 255, 255, 1, 255, 2, 255, 255, //
|
|
255, 0, 255, 1, 255, 2, 255, 255, //
|
|
0, 1, 255, 2, 255, 3, 255, 255, //
|
|
255, 255, 0, 1, 255, 2, 255, 255, //
|
|
0, 255, 1, 2, 255, 3, 255, 255, //
|
|
255, 0, 1, 2, 255, 3, 255, 255, //
|
|
0, 1, 2, 3, 255, 4, 255, 255, //
|
|
255, 255, 255, 255, 0, 1, 255, 255, //
|
|
0, 255, 255, 255, 1, 2, 255, 255, //
|
|
255, 0, 255, 255, 1, 2, 255, 255, //
|
|
0, 1, 255, 255, 2, 3, 255, 255, //
|
|
255, 255, 0, 255, 1, 2, 255, 255, //
|
|
0, 255, 1, 255, 2, 3, 255, 255, //
|
|
255, 0, 1, 255, 2, 3, 255, 255, //
|
|
0, 1, 2, 255, 3, 4, 255, 255, //
|
|
255, 255, 255, 0, 1, 2, 255, 255, //
|
|
0, 255, 255, 1, 2, 3, 255, 255, //
|
|
255, 0, 255, 1, 2, 3, 255, 255, //
|
|
0, 1, 255, 2, 3, 4, 255, 255, //
|
|
255, 255, 0, 1, 2, 3, 255, 255, //
|
|
0, 255, 1, 2, 3, 4, 255, 255, //
|
|
255, 0, 1, 2, 3, 4, 255, 255, //
|
|
0, 1, 2, 3, 4, 5, 255, 255, //
|
|
255, 255, 255, 255, 255, 255, 0, 255, //
|
|
0, 255, 255, 255, 255, 255, 1, 255, //
|
|
255, 0, 255, 255, 255, 255, 1, 255, //
|
|
0, 1, 255, 255, 255, 255, 2, 255, //
|
|
255, 255, 0, 255, 255, 255, 1, 255, //
|
|
0, 255, 1, 255, 255, 255, 2, 255, //
|
|
255, 0, 1, 255, 255, 255, 2, 255, //
|
|
0, 1, 2, 255, 255, 255, 3, 255, //
|
|
255, 255, 255, 0, 255, 255, 1, 255, //
|
|
0, 255, 255, 1, 255, 255, 2, 255, //
|
|
255, 0, 255, 1, 255, 255, 2, 255, //
|
|
0, 1, 255, 2, 255, 255, 3, 255, //
|
|
255, 255, 0, 1, 255, 255, 2, 255, //
|
|
0, 255, 1, 2, 255, 255, 3, 255, //
|
|
255, 0, 1, 2, 255, 255, 3, 255, //
|
|
0, 1, 2, 3, 255, 255, 4, 255, //
|
|
255, 255, 255, 255, 0, 255, 1, 255, //
|
|
0, 255, 255, 255, 1, 255, 2, 255, //
|
|
255, 0, 255, 255, 1, 255, 2, 255, //
|
|
0, 1, 255, 255, 2, 255, 3, 255, //
|
|
255, 255, 0, 255, 1, 255, 2, 255, //
|
|
0, 255, 1, 255, 2, 255, 3, 255, //
|
|
255, 0, 1, 255, 2, 255, 3, 255, //
|
|
0, 1, 2, 255, 3, 255, 4, 255, //
|
|
255, 255, 255, 0, 1, 255, 2, 255, //
|
|
0, 255, 255, 1, 2, 255, 3, 255, //
|
|
255, 0, 255, 1, 2, 255, 3, 255, //
|
|
0, 1, 255, 2, 3, 255, 4, 255, //
|
|
255, 255, 0, 1, 2, 255, 3, 255, //
|
|
0, 255, 1, 2, 3, 255, 4, 255, //
|
|
255, 0, 1, 2, 3, 255, 4, 255, //
|
|
0, 1, 2, 3, 4, 255, 5, 255, //
|
|
255, 255, 255, 255, 255, 0, 1, 255, //
|
|
0, 255, 255, 255, 255, 1, 2, 255, //
|
|
255, 0, 255, 255, 255, 1, 2, 255, //
|
|
0, 1, 255, 255, 255, 2, 3, 255, //
|
|
255, 255, 0, 255, 255, 1, 2, 255, //
|
|
0, 255, 1, 255, 255, 2, 3, 255, //
|
|
255, 0, 1, 255, 255, 2, 3, 255, //
|
|
0, 1, 2, 255, 255, 3, 4, 255, //
|
|
255, 255, 255, 0, 255, 1, 2, 255, //
|
|
0, 255, 255, 1, 255, 2, 3, 255, //
|
|
255, 0, 255, 1, 255, 2, 3, 255, //
|
|
0, 1, 255, 2, 255, 3, 4, 255, //
|
|
255, 255, 0, 1, 255, 2, 3, 255, //
|
|
0, 255, 1, 2, 255, 3, 4, 255, //
|
|
255, 0, 1, 2, 255, 3, 4, 255, //
|
|
0, 1, 2, 3, 255, 4, 5, 255, //
|
|
255, 255, 255, 255, 0, 1, 2, 255, //
|
|
0, 255, 255, 255, 1, 2, 3, 255, //
|
|
255, 0, 255, 255, 1, 2, 3, 255, //
|
|
0, 1, 255, 255, 2, 3, 4, 255, //
|
|
255, 255, 0, 255, 1, 2, 3, 255, //
|
|
0, 255, 1, 255, 2, 3, 4, 255, //
|
|
255, 0, 1, 255, 2, 3, 4, 255, //
|
|
0, 1, 2, 255, 3, 4, 5, 255, //
|
|
255, 255, 255, 0, 1, 2, 3, 255, //
|
|
0, 255, 255, 1, 2, 3, 4, 255, //
|
|
255, 0, 255, 1, 2, 3, 4, 255, //
|
|
0, 1, 255, 2, 3, 4, 5, 255, //
|
|
255, 255, 0, 1, 2, 3, 4, 255, //
|
|
0, 255, 1, 2, 3, 4, 5, 255, //
|
|
255, 0, 1, 2, 3, 4, 5, 255, //
|
|
0, 1, 2, 3, 4, 5, 6, 255, //
|
|
255, 255, 255, 255, 255, 255, 255, 0, //
|
|
0, 255, 255, 255, 255, 255, 255, 1, //
|
|
255, 0, 255, 255, 255, 255, 255, 1, //
|
|
0, 1, 255, 255, 255, 255, 255, 2, //
|
|
255, 255, 0, 255, 255, 255, 255, 1, //
|
|
0, 255, 1, 255, 255, 255, 255, 2, //
|
|
255, 0, 1, 255, 255, 255, 255, 2, //
|
|
0, 1, 2, 255, 255, 255, 255, 3, //
|
|
255, 255, 255, 0, 255, 255, 255, 1, //
|
|
0, 255, 255, 1, 255, 255, 255, 2, //
|
|
255, 0, 255, 1, 255, 255, 255, 2, //
|
|
0, 1, 255, 2, 255, 255, 255, 3, //
|
|
255, 255, 0, 1, 255, 255, 255, 2, //
|
|
0, 255, 1, 2, 255, 255, 255, 3, //
|
|
255, 0, 1, 2, 255, 255, 255, 3, //
|
|
0, 1, 2, 3, 255, 255, 255, 4, //
|
|
255, 255, 255, 255, 0, 255, 255, 1, //
|
|
0, 255, 255, 255, 1, 255, 255, 2, //
|
|
255, 0, 255, 255, 1, 255, 255, 2, //
|
|
0, 1, 255, 255, 2, 255, 255, 3, //
|
|
255, 255, 0, 255, 1, 255, 255, 2, //
|
|
0, 255, 1, 255, 2, 255, 255, 3, //
|
|
255, 0, 1, 255, 2, 255, 255, 3, //
|
|
0, 1, 2, 255, 3, 255, 255, 4, //
|
|
255, 255, 255, 0, 1, 255, 255, 2, //
|
|
0, 255, 255, 1, 2, 255, 255, 3, //
|
|
255, 0, 255, 1, 2, 255, 255, 3, //
|
|
0, 1, 255, 2, 3, 255, 255, 4, //
|
|
255, 255, 0, 1, 2, 255, 255, 3, //
|
|
0, 255, 1, 2, 3, 255, 255, 4, //
|
|
255, 0, 1, 2, 3, 255, 255, 4, //
|
|
0, 1, 2, 3, 4, 255, 255, 5, //
|
|
255, 255, 255, 255, 255, 0, 255, 1, //
|
|
0, 255, 255, 255, 255, 1, 255, 2, //
|
|
255, 0, 255, 255, 255, 1, 255, 2, //
|
|
0, 1, 255, 255, 255, 2, 255, 3, //
|
|
255, 255, 0, 255, 255, 1, 255, 2, //
|
|
0, 255, 1, 255, 255, 2, 255, 3, //
|
|
255, 0, 1, 255, 255, 2, 255, 3, //
|
|
0, 1, 2, 255, 255, 3, 255, 4, //
|
|
255, 255, 255, 0, 255, 1, 255, 2, //
|
|
0, 255, 255, 1, 255, 2, 255, 3, //
|
|
255, 0, 255, 1, 255, 2, 255, 3, //
|
|
0, 1, 255, 2, 255, 3, 255, 4, //
|
|
255, 255, 0, 1, 255, 2, 255, 3, //
|
|
0, 255, 1, 2, 255, 3, 255, 4, //
|
|
255, 0, 1, 2, 255, 3, 255, 4, //
|
|
0, 1, 2, 3, 255, 4, 255, 5, //
|
|
255, 255, 255, 255, 0, 1, 255, 2, //
|
|
0, 255, 255, 255, 1, 2, 255, 3, //
|
|
255, 0, 255, 255, 1, 2, 255, 3, //
|
|
0, 1, 255, 255, 2, 3, 255, 4, //
|
|
255, 255, 0, 255, 1, 2, 255, 3, //
|
|
0, 255, 1, 255, 2, 3, 255, 4, //
|
|
255, 0, 1, 255, 2, 3, 255, 4, //
|
|
0, 1, 2, 255, 3, 4, 255, 5, //
|
|
255, 255, 255, 0, 1, 2, 255, 3, //
|
|
0, 255, 255, 1, 2, 3, 255, 4, //
|
|
255, 0, 255, 1, 2, 3, 255, 4, //
|
|
0, 1, 255, 2, 3, 4, 255, 5, //
|
|
255, 255, 0, 1, 2, 3, 255, 4, //
|
|
0, 255, 1, 2, 3, 4, 255, 5, //
|
|
255, 0, 1, 2, 3, 4, 255, 5, //
|
|
0, 1, 2, 3, 4, 5, 255, 6, //
|
|
255, 255, 255, 255, 255, 255, 0, 1, //
|
|
0, 255, 255, 255, 255, 255, 1, 2, //
|
|
255, 0, 255, 255, 255, 255, 1, 2, //
|
|
0, 1, 255, 255, 255, 255, 2, 3, //
|
|
255, 255, 0, 255, 255, 255, 1, 2, //
|
|
0, 255, 1, 255, 255, 255, 2, 3, //
|
|
255, 0, 1, 255, 255, 255, 2, 3, //
|
|
0, 1, 2, 255, 255, 255, 3, 4, //
|
|
255, 255, 255, 0, 255, 255, 1, 2, //
|
|
0, 255, 255, 1, 255, 255, 2, 3, //
|
|
255, 0, 255, 1, 255, 255, 2, 3, //
|
|
0, 1, 255, 2, 255, 255, 3, 4, //
|
|
255, 255, 0, 1, 255, 255, 2, 3, //
|
|
0, 255, 1, 2, 255, 255, 3, 4, //
|
|
255, 0, 1, 2, 255, 255, 3, 4, //
|
|
0, 1, 2, 3, 255, 255, 4, 5, //
|
|
255, 255, 255, 255, 0, 255, 1, 2, //
|
|
0, 255, 255, 255, 1, 255, 2, 3, //
|
|
255, 0, 255, 255, 1, 255, 2, 3, //
|
|
0, 1, 255, 255, 2, 255, 3, 4, //
|
|
255, 255, 0, 255, 1, 255, 2, 3, //
|
|
0, 255, 1, 255, 2, 255, 3, 4, //
|
|
255, 0, 1, 255, 2, 255, 3, 4, //
|
|
0, 1, 2, 255, 3, 255, 4, 5, //
|
|
255, 255, 255, 0, 1, 255, 2, 3, //
|
|
0, 255, 255, 1, 2, 255, 3, 4, //
|
|
255, 0, 255, 1, 2, 255, 3, 4, //
|
|
0, 1, 255, 2, 3, 255, 4, 5, //
|
|
255, 255, 0, 1, 2, 255, 3, 4, //
|
|
0, 255, 1, 2, 3, 255, 4, 5, //
|
|
255, 0, 1, 2, 3, 255, 4, 5, //
|
|
0, 1, 2, 3, 4, 255, 5, 6, //
|
|
255, 255, 255, 255, 255, 0, 1, 2, //
|
|
0, 255, 255, 255, 255, 1, 2, 3, //
|
|
255, 0, 255, 255, 255, 1, 2, 3, //
|
|
0, 1, 255, 255, 255, 2, 3, 4, //
|
|
255, 255, 0, 255, 255, 1, 2, 3, //
|
|
0, 255, 1, 255, 255, 2, 3, 4, //
|
|
255, 0, 1, 255, 255, 2, 3, 4, //
|
|
0, 1, 2, 255, 255, 3, 4, 5, //
|
|
255, 255, 255, 0, 255, 1, 2, 3, //
|
|
0, 255, 255, 1, 255, 2, 3, 4, //
|
|
255, 0, 255, 1, 255, 2, 3, 4, //
|
|
0, 1, 255, 2, 255, 3, 4, 5, //
|
|
255, 255, 0, 1, 255, 2, 3, 4, //
|
|
0, 255, 1, 2, 255, 3, 4, 5, //
|
|
255, 0, 1, 2, 255, 3, 4, 5, //
|
|
0, 1, 2, 3, 255, 4, 5, 6, //
|
|
255, 255, 255, 255, 0, 1, 2, 3, //
|
|
0, 255, 255, 255, 1, 2, 3, 4, //
|
|
255, 0, 255, 255, 1, 2, 3, 4, //
|
|
0, 1, 255, 255, 2, 3, 4, 5, //
|
|
255, 255, 0, 255, 1, 2, 3, 4, //
|
|
0, 255, 1, 255, 2, 3, 4, 5, //
|
|
255, 0, 1, 255, 2, 3, 4, 5, //
|
|
0, 1, 2, 255, 3, 4, 5, 6, //
|
|
255, 255, 255, 0, 1, 2, 3, 4, //
|
|
0, 255, 255, 1, 2, 3, 4, 5, //
|
|
255, 0, 255, 1, 2, 3, 4, 5, //
|
|
0, 1, 255, 2, 3, 4, 5, 6, //
|
|
255, 255, 0, 1, 2, 3, 4, 5, //
|
|
0, 255, 1, 2, 3, 4, 5, 6, //
|
|
255, 0, 1, 2, 3, 4, 5, 6, //
|
|
0, 1, 2, 3, 4, 5, 6, 7};
|
|
const svuint16_t indices = PromoteTo(du16, Load(du8, table + offset));
|
|
return TableLookupLanes(v, indices); // already zeros mask=false lanes
|
|
#else
|
|
return detail::ExpandLoop(v, mask);
|
|
#endif
|
|
}
|
|
|
|
template <class V, HWY_IF_T_SIZE_V(V, 4)>
|
|
HWY_API V Expand(V v, svbool_t mask) {
|
|
#if HWY_TARGET == HWY_SVE_256 || HWY_IDE // 32x8
|
|
const DFromV<V> d;
|
|
const RebindToUnsigned<decltype(d)> du32;
|
|
// Convert mask into bitfield via horizontal sum (faster than ORV).
|
|
const svuint32_t bits = Shl(Set(du32, 1), Iota(du32, 0));
|
|
const size_t code = detail::SumOfLanesM(mask, bits);
|
|
|
|
alignas(16) constexpr uint32_t packed_array[256] = {
|
|
// PrintExpand32x8.
|
|
0xffffffff, 0xfffffff0, 0xffffff0f, 0xffffff10, 0xfffff0ff, 0xfffff1f0,
|
|
0xfffff10f, 0xfffff210, 0xffff0fff, 0xffff1ff0, 0xffff1f0f, 0xffff2f10,
|
|
0xffff10ff, 0xffff21f0, 0xffff210f, 0xffff3210, 0xfff0ffff, 0xfff1fff0,
|
|
0xfff1ff0f, 0xfff2ff10, 0xfff1f0ff, 0xfff2f1f0, 0xfff2f10f, 0xfff3f210,
|
|
0xfff10fff, 0xfff21ff0, 0xfff21f0f, 0xfff32f10, 0xfff210ff, 0xfff321f0,
|
|
0xfff3210f, 0xfff43210, 0xff0fffff, 0xff1ffff0, 0xff1fff0f, 0xff2fff10,
|
|
0xff1ff0ff, 0xff2ff1f0, 0xff2ff10f, 0xff3ff210, 0xff1f0fff, 0xff2f1ff0,
|
|
0xff2f1f0f, 0xff3f2f10, 0xff2f10ff, 0xff3f21f0, 0xff3f210f, 0xff4f3210,
|
|
0xff10ffff, 0xff21fff0, 0xff21ff0f, 0xff32ff10, 0xff21f0ff, 0xff32f1f0,
|
|
0xff32f10f, 0xff43f210, 0xff210fff, 0xff321ff0, 0xff321f0f, 0xff432f10,
|
|
0xff3210ff, 0xff4321f0, 0xff43210f, 0xff543210, 0xf0ffffff, 0xf1fffff0,
|
|
0xf1ffff0f, 0xf2ffff10, 0xf1fff0ff, 0xf2fff1f0, 0xf2fff10f, 0xf3fff210,
|
|
0xf1ff0fff, 0xf2ff1ff0, 0xf2ff1f0f, 0xf3ff2f10, 0xf2ff10ff, 0xf3ff21f0,
|
|
0xf3ff210f, 0xf4ff3210, 0xf1f0ffff, 0xf2f1fff0, 0xf2f1ff0f, 0xf3f2ff10,
|
|
0xf2f1f0ff, 0xf3f2f1f0, 0xf3f2f10f, 0xf4f3f210, 0xf2f10fff, 0xf3f21ff0,
|
|
0xf3f21f0f, 0xf4f32f10, 0xf3f210ff, 0xf4f321f0, 0xf4f3210f, 0xf5f43210,
|
|
0xf10fffff, 0xf21ffff0, 0xf21fff0f, 0xf32fff10, 0xf21ff0ff, 0xf32ff1f0,
|
|
0xf32ff10f, 0xf43ff210, 0xf21f0fff, 0xf32f1ff0, 0xf32f1f0f, 0xf43f2f10,
|
|
0xf32f10ff, 0xf43f21f0, 0xf43f210f, 0xf54f3210, 0xf210ffff, 0xf321fff0,
|
|
0xf321ff0f, 0xf432ff10, 0xf321f0ff, 0xf432f1f0, 0xf432f10f, 0xf543f210,
|
|
0xf3210fff, 0xf4321ff0, 0xf4321f0f, 0xf5432f10, 0xf43210ff, 0xf54321f0,
|
|
0xf543210f, 0xf6543210, 0x0fffffff, 0x1ffffff0, 0x1fffff0f, 0x2fffff10,
|
|
0x1ffff0ff, 0x2ffff1f0, 0x2ffff10f, 0x3ffff210, 0x1fff0fff, 0x2fff1ff0,
|
|
0x2fff1f0f, 0x3fff2f10, 0x2fff10ff, 0x3fff21f0, 0x3fff210f, 0x4fff3210,
|
|
0x1ff0ffff, 0x2ff1fff0, 0x2ff1ff0f, 0x3ff2ff10, 0x2ff1f0ff, 0x3ff2f1f0,
|
|
0x3ff2f10f, 0x4ff3f210, 0x2ff10fff, 0x3ff21ff0, 0x3ff21f0f, 0x4ff32f10,
|
|
0x3ff210ff, 0x4ff321f0, 0x4ff3210f, 0x5ff43210, 0x1f0fffff, 0x2f1ffff0,
|
|
0x2f1fff0f, 0x3f2fff10, 0x2f1ff0ff, 0x3f2ff1f0, 0x3f2ff10f, 0x4f3ff210,
|
|
0x2f1f0fff, 0x3f2f1ff0, 0x3f2f1f0f, 0x4f3f2f10, 0x3f2f10ff, 0x4f3f21f0,
|
|
0x4f3f210f, 0x5f4f3210, 0x2f10ffff, 0x3f21fff0, 0x3f21ff0f, 0x4f32ff10,
|
|
0x3f21f0ff, 0x4f32f1f0, 0x4f32f10f, 0x5f43f210, 0x3f210fff, 0x4f321ff0,
|
|
0x4f321f0f, 0x5f432f10, 0x4f3210ff, 0x5f4321f0, 0x5f43210f, 0x6f543210,
|
|
0x10ffffff, 0x21fffff0, 0x21ffff0f, 0x32ffff10, 0x21fff0ff, 0x32fff1f0,
|
|
0x32fff10f, 0x43fff210, 0x21ff0fff, 0x32ff1ff0, 0x32ff1f0f, 0x43ff2f10,
|
|
0x32ff10ff, 0x43ff21f0, 0x43ff210f, 0x54ff3210, 0x21f0ffff, 0x32f1fff0,
|
|
0x32f1ff0f, 0x43f2ff10, 0x32f1f0ff, 0x43f2f1f0, 0x43f2f10f, 0x54f3f210,
|
|
0x32f10fff, 0x43f21ff0, 0x43f21f0f, 0x54f32f10, 0x43f210ff, 0x54f321f0,
|
|
0x54f3210f, 0x65f43210, 0x210fffff, 0x321ffff0, 0x321fff0f, 0x432fff10,
|
|
0x321ff0ff, 0x432ff1f0, 0x432ff10f, 0x543ff210, 0x321f0fff, 0x432f1ff0,
|
|
0x432f1f0f, 0x543f2f10, 0x432f10ff, 0x543f21f0, 0x543f210f, 0x654f3210,
|
|
0x3210ffff, 0x4321fff0, 0x4321ff0f, 0x5432ff10, 0x4321f0ff, 0x5432f1f0,
|
|
0x5432f10f, 0x6543f210, 0x43210fff, 0x54321ff0, 0x54321f0f, 0x65432f10,
|
|
0x543210ff, 0x654321f0, 0x6543210f, 0x76543210};
|
|
|
|
// For lane i, shift the i-th 4-bit index down and mask with 0xF because
|
|
// svtbl zeros outputs if the index is out of bounds.
|
|
const svuint32_t packed = Set(du32, packed_array[code]);
|
|
const svuint32_t indices = detail::AndN(Shr(packed, svindex_u32(0, 4)), 0xF);
|
|
return TableLookupLanes(v, indices); // already zeros mask=false lanes
|
|
#elif HWY_TARGET == HWY_SVE2_128 // 32x4
|
|
const DFromV<V> d;
|
|
const RebindToUnsigned<decltype(d)> du32;
|
|
// Convert mask into bitfield via horizontal sum (faster than ORV).
|
|
const svuint32_t bits = Shl(Set(du32, 1), Iota(du32, 0));
|
|
const size_t offset = detail::SumOfLanesM(mask, bits);
|
|
|
|
alignas(16) constexpr uint32_t packed_array[16] = {
|
|
// PrintExpand64x4Nibble - same for 32x4.
|
|
0x0000ffff, 0x0000fff0, 0x0000ff0f, 0x0000ff10, 0x0000f0ff, 0x0000f1f0,
|
|
0x0000f10f, 0x0000f210, 0x00000fff, 0x00001ff0, 0x00001f0f, 0x00002f10,
|
|
0x000010ff, 0x000021f0, 0x0000210f, 0x00003210};
|
|
|
|
// For lane i, shift the i-th 4-bit index down and mask with 0xF because
|
|
// svtbl zeros outputs if the index is out of bounds.
|
|
const svuint32_t packed = Set(du32, packed_array[offset]);
|
|
const svuint32_t indices = detail::AndN(Shr(packed, svindex_u32(0, 4)), 0xF);
|
|
return TableLookupLanes(v, indices); // already zeros mask=false lanes
|
|
#else
|
|
return detail::ExpandLoop(v, mask);
|
|
#endif
|
|
}
|
|
|
|
template <class V, HWY_IF_T_SIZE_V(V, 8)>
|
|
HWY_API V Expand(V v, svbool_t mask) {
|
|
#if HWY_TARGET == HWY_SVE_256 || HWY_IDE // 64x4
|
|
const DFromV<V> d;
|
|
const RebindToUnsigned<decltype(d)> du64;
|
|
|
|
// Convert mask into bitfield via horizontal sum (faster than ORV) of masked
|
|
// bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for
|
|
// SetTableIndices.
|
|
const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2));
|
|
const size_t offset = detail::SumOfLanesM(mask, bits);
|
|
|
|
alignas(16) static constexpr uint64_t table[4 * 16] = {
|
|
// PrintExpand64x4Tables - small enough to store uncompressed.
|
|
255, 255, 255, 255, 0, 255, 255, 255, 255, 0, 255, 255, 0, 1, 255, 255,
|
|
255, 255, 0, 255, 0, 255, 1, 255, 255, 0, 1, 255, 0, 1, 2, 255,
|
|
255, 255, 255, 0, 0, 255, 255, 1, 255, 0, 255, 1, 0, 1, 255, 2,
|
|
255, 255, 0, 1, 0, 255, 1, 2, 255, 0, 1, 2, 0, 1, 2, 3};
|
|
// This already zeros mask=false lanes.
|
|
return TableLookupLanes(v, SetTableIndices(d, table + offset));
|
|
#elif HWY_TARGET == HWY_SVE2_128 // 64x2
|
|
// Same as Compress, just zero out the mask=false lanes.
|
|
return IfThenElseZero(mask, Compress(v, mask));
|
|
#else
|
|
return detail::ExpandLoop(v, mask);
|
|
#endif
|
|
}
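
// Usage sketch (illustrative only; names and values are hypothetical). Expand
// scatters the first CountTrue(mask) input lanes to the positions where the
// mask is true and zeros all other lanes:
//
//   namespace hn = hwy::HWY_NAMESPACE;
//   const hn::ScalableTag<uint32_t> d;
//   const uint8_t kBits[8] = {0x0A};         // mask lanes 1 and 3
//   const auto mask = hn::LoadMaskBits(d, kBits);
//   const auto v = hn::Iota(d, 1);           // 1, 2, 3, 4, ...
//   // Result: 0, 1, 0, 2, 0, ... (the first two inputs land in lanes 1, 3).
//   const auto expanded = hn::Expand(v, mask);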
|
|
|
|
// ------------------------------ LoadExpand
|
|
|
|
template <class D>
|
|
HWY_API VFromD<D> LoadExpand(MFromD<D> mask, D d,
|
|
const TFromD<D>* HWY_RESTRICT unaligned) {
|
|
return Expand(LoadU(d, unaligned), mask);
|
|
}
|
|
|
|
// ------------------------------ MulEven (InterleaveEven)
|
|
|
|
#if HWY_SVE_HAVE_2
|
|
namespace detail {
|
|
#define HWY_SVE_MUL_EVEN(BASE, CHAR, BITS, HALF, NAME, OP) \
|
|
HWY_API HWY_SVE_V(BASE, BITS) \
|
|
NAME(HWY_SVE_V(BASE, HALF) a, HWY_SVE_V(BASE, HALF) b) { \
|
|
return sv##OP##_##CHAR##BITS(a, b); \
|
|
}
|
|
|
|
HWY_SVE_FOREACH_UI16(HWY_SVE_MUL_EVEN, MulEvenNative, mullb)
|
|
HWY_SVE_FOREACH_UI32(HWY_SVE_MUL_EVEN, MulEvenNative, mullb)
|
|
HWY_SVE_FOREACH_UI64(HWY_SVE_MUL_EVEN, MulEvenNative, mullb)
|
|
HWY_SVE_FOREACH_UI16(HWY_SVE_MUL_EVEN, MulOddNative, mullt)
|
|
HWY_SVE_FOREACH_UI32(HWY_SVE_MUL_EVEN, MulOddNative, mullt)
|
|
HWY_SVE_FOREACH_UI64(HWY_SVE_MUL_EVEN, MulOddNative, mullt)
|
|
#undef HWY_SVE_MUL_EVEN
|
|
} // namespace detail
|
|
#endif
|
|
|
|
template <class V, class DW = RepartitionToWide<DFromV<V>>,
|
|
HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2) | (1 << 4))>
|
|
HWY_API VFromD<DW> MulEven(const V a, const V b) {
|
|
#if HWY_SVE_HAVE_2
|
|
return BitCast(DW(), detail::MulEvenNative(a, b));
|
|
#else
|
|
const auto lo = Mul(a, b);
|
|
const auto hi = MulHigh(a, b);
|
|
return BitCast(DW(), detail::InterleaveEven(lo, hi));
|
|
#endif
|
|
}
|
|
|
|
template <class V, class DW = RepartitionToWide<DFromV<V>>,
|
|
HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2) | (1 << 4))>
|
|
HWY_API VFromD<DW> MulOdd(const V a, const V b) {
|
|
#if HWY_SVE_HAVE_2
|
|
return BitCast(DW(), detail::MulOddNative(a, b));
|
|
#else
|
|
const auto lo = Mul(a, b);
|
|
const auto hi = MulHigh(a, b);
|
|
return BitCast(DW(), detail::InterleaveOdd(lo, hi));
|
|
#endif
|
|
}
|
|
|
|
HWY_API svuint64_t MulEven(const svuint64_t a, const svuint64_t b) {
|
|
const auto lo = Mul(a, b);
|
|
const auto hi = MulHigh(a, b);
|
|
return detail::InterleaveEven(lo, hi);
|
|
}
|
|
|
|
HWY_API svuint64_t MulOdd(const svuint64_t a, const svuint64_t b) {
|
|
const auto lo = Mul(a, b);
|
|
const auto hi = MulHigh(a, b);
|
|
return detail::InterleaveOdd(lo, hi);
|
|
}
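
// Usage sketch (illustrative only; names and values are hypothetical). For
// narrow inputs, MulEven/MulOdd return full-width products of the even/odd
// lanes:
//
//   namespace hn = hwy::HWY_NAMESPACE;
//   const hn::ScalableTag<uint32_t> d32;
//   const auto a = hn::Set(d32, 0xFFFFFFFFu);
//   const auto b = hn::Set(d32, 2u);
//   // Every u64 lane is 0x1FFFFFFFE; nothing is truncated to 32 bits.
//   const auto wide = hn::MulEven(a, b);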
|
|
|
|
// ------------------------------ WidenMulPairwiseAdd
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svfloat32_t WidenMulPairwiseAdd(Simd<float, N, kPow2> df32, VBF16 a,
|
|
VBF16 b) {
|
|
#if HWY_SVE_HAVE_BFLOAT16
|
|
const svfloat32_t even = svbfmlalb_f32(Zero(df32), a, b);
|
|
return svbfmlalt_f32(even, a, b);
|
|
#else
|
|
const RebindToUnsigned<decltype(df32)> du32;
|
|
// Using shift/and instead of Zip leads to the odd/even order that
|
|
// RearrangeToOddPlusEven prefers.
|
|
using VU32 = VFromD<decltype(du32)>;
|
|
const VU32 odd = Set(du32, 0xFFFF0000u);
|
|
const VU32 ae = ShiftLeft<16>(BitCast(du32, a));
|
|
const VU32 ao = And(BitCast(du32, a), odd);
|
|
const VU32 be = ShiftLeft<16>(BitCast(du32, b));
|
|
const VU32 bo = And(BitCast(du32, b), odd);
|
|
return MulAdd(BitCast(df32, ae), BitCast(df32, be),
|
|
Mul(BitCast(df32, ao), BitCast(df32, bo)));
|
|
#endif // HWY_SVE_HAVE_BFLOAT16
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svint32_t WidenMulPairwiseAdd(Simd<int32_t, N, kPow2> d32, svint16_t a,
|
|
svint16_t b) {
|
|
#if HWY_SVE_HAVE_2
|
|
(void)d32;
|
|
return svmlalt_s32(svmullb_s32(a, b), a, b);
|
|
#else
|
|
const svbool_t pg = detail::PTrue(d32);
|
|
// Shifting extracts the odd lanes as RearrangeToOddPlusEven prefers.
|
|
// Fortunately SVE has sign-extension for the even lanes.
|
|
const svint32_t ae = svexth_s32_x(pg, BitCast(d32, a));
|
|
const svint32_t be = svexth_s32_x(pg, BitCast(d32, b));
|
|
const svint32_t ao = ShiftRight<16>(BitCast(d32, a));
|
|
const svint32_t bo = ShiftRight<16>(BitCast(d32, b));
|
|
return svmla_s32_x(pg, svmul_s32_x(pg, ao, bo), ae, be);
|
|
#endif
|
|
}
|
|
|
|
template <size_t N, int kPow2>
|
|
HWY_API svuint32_t WidenMulPairwiseAdd(Simd<uint32_t, N, kPow2> d32,
|
|
svuint16_t a, svuint16_t b) {
|
|
#if HWY_SVE_HAVE_2
|
|
(void)d32;
|
|
return svmlalt_u32(svmullb_u32(a, b), a, b);
|
|
#else
|
|
const svbool_t pg = detail::PTrue(d32);
|
|
// Shifting extracts the odd lanes as RearrangeToOddPlusEven prefers.
|
|
// Fortunately SVE has sign-extension for the even lanes.
|
|
const svuint32_t ae = svexth_u32_x(pg, BitCast(d32, a));
|
|
const svuint32_t be = svexth_u32_x(pg, BitCast(d32, b));
|
|
const svuint32_t ao = ShiftRight<16>(BitCast(d32, a));
|
|
const svuint32_t bo = ShiftRight<16>(BitCast(d32, b));
|
|
return svmla_u32_x(pg, svmul_u32_x(pg, ao, bo), ae, be);
|
|
#endif
|
|
}
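
// Usage sketch (illustrative only; names and values are hypothetical). Each
// widened output lane is the sum of the two products of one 16-bit lane pair:
//
//   namespace hn = hwy::HWY_NAMESPACE;
//   const hn::ScalableTag<int16_t> d16;
//   const hn::RepartitionToWide<decltype(d16)> d32;
//   const auto a = hn::Set(d16, 3);
//   const auto b = hn::Set(d16, 4);
//   // Every i32 lane is 3*4 + 3*4 = 24.
//   const auto sums = hn::WidenMulPairwiseAdd(d32, a, b);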
|
|
|
|
// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower)
|
|
|
|
template <size_t N, int kPow2>
HWY_API svfloat32_t ReorderWidenMulAccumulate(Simd<float, N, kPow2> df32,
                                              VBF16 a, VBF16 b,
                                              const svfloat32_t sum0,
                                              svfloat32_t& sum1) {
#if HWY_SVE_HAVE_BFLOAT16
  (void)df32;
  sum1 = svbfmlalt_f32(sum1, a, b);
  return svbfmlalb_f32(sum0, a, b);
#else
  const RebindToUnsigned<decltype(df32)> du32;
  // Using shift/and instead of Zip leads to the odd/even order that
  // RearrangeToOddPlusEven prefers.
  using VU32 = VFromD<decltype(du32)>;
  const VU32 odd = Set(du32, 0xFFFF0000u);
  const VU32 ae = ShiftLeft<16>(BitCast(du32, a));
  const VU32 ao = And(BitCast(du32, a), odd);
  const VU32 be = ShiftLeft<16>(BitCast(du32, b));
  const VU32 bo = And(BitCast(du32, b), odd);
  sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1);
  return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0);
#endif  // HWY_SVE_HAVE_BFLOAT16
}

template <size_t N, int kPow2>
HWY_API svint32_t ReorderWidenMulAccumulate(Simd<int32_t, N, kPow2> d32,
                                            svint16_t a, svint16_t b,
                                            const svint32_t sum0,
                                            svint32_t& sum1) {
#if HWY_SVE_HAVE_2
  (void)d32;
  sum1 = svmlalt_s32(sum1, a, b);
  return svmlalb_s32(sum0, a, b);
#else
  const svbool_t pg = detail::PTrue(d32);
  // Shifting extracts the odd lanes as RearrangeToOddPlusEven prefers.
  // Fortunately SVE has sign-extension for the even lanes.
  const svint32_t ae = svexth_s32_x(pg, BitCast(d32, a));
  const svint32_t be = svexth_s32_x(pg, BitCast(d32, b));
  const svint32_t ao = ShiftRight<16>(BitCast(d32, a));
  const svint32_t bo = ShiftRight<16>(BitCast(d32, b));
  sum1 = svmla_s32_x(pg, sum1, ao, bo);
  return svmla_s32_x(pg, sum0, ae, be);
#endif
}

template <size_t N, int kPow2>
HWY_API svuint32_t ReorderWidenMulAccumulate(Simd<uint32_t, N, kPow2> d32,
                                             svuint16_t a, svuint16_t b,
                                             const svuint32_t sum0,
                                             svuint32_t& sum1) {
#if HWY_SVE_HAVE_2
  (void)d32;
  sum1 = svmlalt_u32(sum1, a, b);
  return svmlalb_u32(sum0, a, b);
#else
  const svbool_t pg = detail::PTrue(d32);
  // Shifting extracts the odd lanes as RearrangeToOddPlusEven prefers.
  // Fortunately SVE has zero-extension for the even lanes.
  const svuint32_t ae = svexth_u32_x(pg, BitCast(d32, a));
  const svuint32_t be = svexth_u32_x(pg, BitCast(d32, b));
  const svuint32_t ao = ShiftRight<16>(BitCast(d32, a));
  const svuint32_t bo = ShiftRight<16>(BitCast(d32, b));
  sum1 = svmla_u32_x(pg, sum1, ao, bo);
  return svmla_u32_x(pg, sum0, ae, be);
#endif
}

// ------------------------------ RearrangeToOddPlusEven
template <class VW>
HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) {
  // sum0 is the sum of bottom/even lanes and sum1 of top/odd lanes.
  return Add(sum0, sum1);
}

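// Usage sketch (illustrative only; d32, a16, b16 and the surrounding loop are
// hypothetical, not part of this file):
//   const ScalableTag<int32_t> d32;
//   svint32_t sum0 = Zero(d32), sum1 = Zero(d32);
//   // ... inside a loop over input vectors a16, b16 (svint16_t):
//   sum0 = ReorderWidenMulAccumulate(d32, a16, b16, sum0, sum1);
//   // after the loop:
//   const svint32_t dot32 = RearrangeToOddPlusEven(sum0, sum1);
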
// ------------------------------ SumOfMulQuadAccumulate

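// Informative: each widened output lane accumulates the dot product of the
// corresponding group of four narrow lanes (svdot semantics), i.e.
// sum[i] += a[4 * i] * b[4 * i] + ... + a[4 * i + 3] * b[4 * i + 3].
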
#ifdef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE
#undef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE
#else
#define HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE
#endif

template <class DI32, HWY_IF_I32_D(DI32)>
HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 /*di32*/, svint8_t a,
                                            svint8_t b, svint32_t sum) {
  return svdot_s32(sum, a, b);
}

#ifdef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE
#undef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE
#else
#define HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE
#endif

template <class DU32, HWY_IF_U32_D(DU32)>
HWY_API VFromD<DU32> SumOfMulQuadAccumulate(DU32 /*du32*/, svuint8_t a,
                                            svuint8_t b, svuint32_t sum) {
  return svdot_u32(sum, a, b);
}

#ifdef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE
#undef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE
#else
#define HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE
#endif

template <class DI32, HWY_IF_I32_D(DI32)>
HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 di32, svuint8_t a_u,
                                            svint8_t b_i, svint32_t sum) {
  // TODO: use svusdot_u32 on SVE targets that require support for both SVE2
  // and SVE I8MM.

  const RebindToUnsigned<decltype(di32)> du32;
  const Repartition<uint8_t, decltype(di32)> du8;

  const auto b_u = BitCast(du8, b_i);
  const auto result_sum0 = svdot_u32(BitCast(du32, sum), a_u, b_u);
  const auto result_sum1 =
      ShiftLeft<8>(svdot_u32(Zero(du32), a_u, ShiftRight<7>(b_u)));

  return BitCast(di32, Sub(result_sum0, result_sum1));
}

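// Note on the mixed-sign overload above: reinterpreting b as unsigned gives
// b_i == b_u - 256 * (b_u >> 7) per 8-bit lane, hence
// sum(a_u * b_i) == dot(a_u, b_u) - (dot(a_u, b_u >> 7) << 8), which is
// exactly result_sum0 - result_sum1 (modulo 2^32, i.e. in two's complement).
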
#ifdef HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE
#undef HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE
#else
#define HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE
#endif

template <class DI64, HWY_IF_I64_D(DI64)>
HWY_API VFromD<DI64> SumOfMulQuadAccumulate(DI64 /*di64*/, svint16_t a,
                                            svint16_t b, svint64_t sum) {
  return svdot_s64(sum, a, b);
}

#ifdef HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE
#undef HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE
#else
#define HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE
#endif

template <class DU64, HWY_IF_U64_D(DU64)>
HWY_API VFromD<DU64> SumOfMulQuadAccumulate(DU64 /*du64*/, svuint16_t a,
                                            svuint16_t b, svuint64_t sum) {
  return svdot_u64(sum, a, b);
}

// ------------------------------ AESRound / CLMul

#if defined(__ARM_FEATURE_SVE2_AES) || \
    (HWY_SVE_HAVE_2 && HWY_HAVE_RUNTIME_DISPATCH)

// Per-target flag to prevent generic_ops-inl.h from defining AESRound.
#ifdef HWY_NATIVE_AES
#undef HWY_NATIVE_AES
#else
#define HWY_NATIVE_AES
#endif

HWY_API svuint8_t AESRound(svuint8_t state, svuint8_t round_key) {
  // It is not clear whether E and MC fuse like they did on NEON.
  return Xor(svaesmc_u8(svaese_u8(state, svdup_n_u8(0))), round_key);
}

HWY_API svuint8_t AESLastRound(svuint8_t state, svuint8_t round_key) {
  return Xor(svaese_u8(state, svdup_n_u8(0)), round_key);
}

HWY_API svuint8_t AESInvMixColumns(svuint8_t state) {
  return svaesimc_u8(state);
}

HWY_API svuint8_t AESRoundInv(svuint8_t state, svuint8_t round_key) {
  return Xor(svaesimc_u8(svaesd_u8(state, svdup_n_u8(0))), round_key);
}

HWY_API svuint8_t AESLastRoundInv(svuint8_t state, svuint8_t round_key) {
  return Xor(svaesd_u8(state, svdup_n_u8(0)), round_key);
}

template <uint8_t kRcon>
HWY_API svuint8_t AESKeyGenAssist(svuint8_t v) {
  alignas(16) static constexpr uint8_t kRconXorMask[16] = {
      0, kRcon, 0, 0, 0, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, 0};
  alignas(16) static constexpr uint8_t kRotWordShuffle[16] = {
      0, 13, 10, 7, 1, 14, 11, 4, 8, 5, 2, 15, 9, 6, 3, 12};
  const DFromV<decltype(v)> d;
  const Repartition<uint32_t, decltype(d)> du32;
  const auto w13 = BitCast(d, DupOdd(BitCast(du32, v)));
  const auto sub_word_result = AESLastRound(w13, LoadDup128(d, kRconXorMask));
  return TableLookupBytes(sub_word_result, LoadDup128(d, kRotWordShuffle));
}

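// Note (assumption, not verified against all targets): AESKeyGenAssist above
// is meant to match the x86 AESKEYGENASSIST semantics used by the generic key
// expansion: SubWord of words 1 and 3 via AESLastRound (whose kRconXorMask
// key also applies the round constant), followed by a byte shuffle that
// performs RotWord and undoes the ShiftRows step of AESLastRound.
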
HWY_API svuint64_t CLMulLower(const svuint64_t a, const svuint64_t b) {
  return svpmullb_pair(a, b);
}

HWY_API svuint64_t CLMulUpper(const svuint64_t a, const svuint64_t b) {
  return svpmullt_pair(a, b);
}

#endif  // __ARM_FEATURE_SVE2_AES

// ------------------------------ Lt128

namespace detail {
#define HWY_SVE_DUP(BASE, CHAR, BITS, HALF, NAME, OP)                        \
  template <size_t N, int kPow2>                                             \
  HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/, svbool_t m) { \
    return sv##OP##_b##BITS(m, m);                                           \
  }

HWY_SVE_FOREACH_U(HWY_SVE_DUP, DupEvenB, trn1)  // actually for bool
HWY_SVE_FOREACH_U(HWY_SVE_DUP, DupOddB, trn2)   // actually for bool
#undef HWY_SVE_DUP

#if HWY_TARGET == HWY_SVE_256 || HWY_IDE
template <class D>
HWY_INLINE svuint64_t Lt128Vec(D d, const svuint64_t a, const svuint64_t b) {
  static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64");
  const svbool_t eqHx = Eq(a, b);  // only odd lanes used
  // Convert to vector: more pipelines can execute vector TRN* instructions
  // than the predicate version.
  const svuint64_t ltHL = VecFromMask(d, Lt(a, b));
  // Move into upper lane: ltL if the upper half is equal, otherwise ltH.
  // Requires an extra IfThenElse because INSR, EXT, TRN2 are unpredicated.
  const svuint64_t ltHx = IfThenElse(eqHx, DupEven(ltHL), ltHL);
  // Duplicate upper lane into lower.
  return DupOdd(ltHx);
}
#endif
}  // namespace detail

template <class D>
HWY_INLINE svbool_t Lt128(D d, const svuint64_t a, const svuint64_t b) {
#if HWY_TARGET == HWY_SVE_256
  return MaskFromVec(detail::Lt128Vec(d, a, b));
#else
  static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64");
  const svbool_t eqHx = Eq(a, b);  // only odd lanes used
  const svbool_t ltHL = Lt(a, b);
  // Move into upper lane: ltL if the upper half is equal, otherwise ltH.
  const svbool_t ltHx = svsel_b(eqHx, detail::DupEvenB(d, ltHL), ltHL);
  // Duplicate upper lane into lower.
  return detail::DupOddB(d, ltHx);
#endif  // HWY_TARGET != HWY_SVE_256
}

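// Informative: Lt128 treats each adjacent pair of u64 lanes as one 128-bit
// value (even lane = low half, odd lane = high half) and returns a mask in
// which both lanes of a pair are true iff that 128-bit a < b. Sketch
// (illustrative; du64, a and b are hypothetical):
//   const ScalableTag<uint64_t> du64;
//   const svbool_t lt = Lt128(du64, a, b);  // per-pair 128-bit comparison
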
// ------------------------------ Lt128Upper

template <class D>
HWY_INLINE svbool_t Lt128Upper(D d, svuint64_t a, svuint64_t b) {
  static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64");
  const svbool_t ltHL = Lt(a, b);
  return detail::DupOddB(d, ltHL);
}

// ------------------------------ Eq128, Ne128

#if HWY_TARGET == HWY_SVE_256 || HWY_IDE
namespace detail {

template <class D>
HWY_INLINE svuint64_t Eq128Vec(D d, const svuint64_t a, const svuint64_t b) {
  static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64");
  // Convert to vector: more pipelines can execute vector TRN* instructions
  // than the predicate version.
  const svuint64_t eqHL = VecFromMask(d, Eq(a, b));
  // Duplicate upper and lower.
  const svuint64_t eqHH = DupOdd(eqHL);
  const svuint64_t eqLL = DupEven(eqHL);
  return And(eqLL, eqHH);
}

template <class D>
HWY_INLINE svuint64_t Ne128Vec(D d, const svuint64_t a, const svuint64_t b) {
  static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64");
  // Convert to vector: more pipelines can execute vector TRN* instructions
  // than the predicate version.
  const svuint64_t neHL = VecFromMask(d, Ne(a, b));
  // Duplicate upper and lower.
  const svuint64_t neHH = DupOdd(neHL);
  const svuint64_t neLL = DupEven(neHL);
  return Or(neLL, neHH);
}

}  // namespace detail
#endif

template <class D>
HWY_INLINE svbool_t Eq128(D d, const svuint64_t a, const svuint64_t b) {
#if HWY_TARGET == HWY_SVE_256
  return MaskFromVec(detail::Eq128Vec(d, a, b));
#else
  static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64");
  const svbool_t eqHL = Eq(a, b);
  const svbool_t eqHH = detail::DupOddB(d, eqHL);
  const svbool_t eqLL = detail::DupEvenB(d, eqHL);
  return And(eqLL, eqHH);
#endif  // HWY_TARGET != HWY_SVE_256
}

template <class D>
HWY_INLINE svbool_t Ne128(D d, const svuint64_t a, const svuint64_t b) {
#if HWY_TARGET == HWY_SVE_256
  return MaskFromVec(detail::Ne128Vec(d, a, b));
#else
  static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64");
  const svbool_t neHL = Ne(a, b);
  const svbool_t neHH = detail::DupOddB(d, neHL);
  const svbool_t neLL = detail::DupEvenB(d, neHL);
  return Or(neLL, neHH);
#endif  // HWY_TARGET != HWY_SVE_256
}

// ------------------------------ Eq128Upper, Ne128Upper

template <class D>
HWY_INLINE svbool_t Eq128Upper(D d, svuint64_t a, svuint64_t b) {
  static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64");
  const svbool_t eqHL = Eq(a, b);
  return detail::DupOddB(d, eqHL);
}

template <class D>
HWY_INLINE svbool_t Ne128Upper(D d, svuint64_t a, svuint64_t b) {
  static_assert(IsSame<TFromD<D>, uint64_t>(), "D must be u64");
  const svbool_t neHL = Ne(a, b);
  return detail::DupOddB(d, neHL);
}

// ------------------------------ Min128, Max128 (Lt128)

template <class D>
HWY_INLINE svuint64_t Min128(D d, const svuint64_t a, const svuint64_t b) {
#if HWY_TARGET == HWY_SVE_256
  return IfVecThenElse(detail::Lt128Vec(d, a, b), a, b);
#else
  return IfThenElse(Lt128(d, a, b), a, b);
#endif
}

template <class D>
HWY_INLINE svuint64_t Max128(D d, const svuint64_t a, const svuint64_t b) {
#if HWY_TARGET == HWY_SVE_256
  return IfVecThenElse(detail::Lt128Vec(d, b, a), a, b);
#else
  return IfThenElse(Lt128(d, b, a), a, b);
#endif
}

template <class D>
HWY_INLINE svuint64_t Min128Upper(D d, const svuint64_t a, const svuint64_t b) {
  return IfThenElse(Lt128Upper(d, a, b), a, b);
}

template <class D>
HWY_INLINE svuint64_t Max128Upper(D d, const svuint64_t a, const svuint64_t b) {
  return IfThenElse(Lt128Upper(d, b, a), a, b);
}

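// Informative: Min128/Max128 order whole 128-bit pairs via Lt128, so both u64
// lanes of each pair come from the same input; the *Upper variants compare
// only the upper (odd) lane of each pair.
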
// -------------------- LeadingZeroCount, TrailingZeroCount, HighestSetBitIndex

#ifdef HWY_NATIVE_LEADING_ZERO_COUNT
#undef HWY_NATIVE_LEADING_ZERO_COUNT
#else
#define HWY_NATIVE_LEADING_ZERO_COUNT
#endif

#define HWY_SVE_LEADING_ZERO_COUNT(BASE, CHAR, BITS, HALF, NAME, OP)    \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) {         \
    const DFromV<decltype(v)> d;                                        \
    return BitCast(d, sv##OP##_##CHAR##BITS##_x(detail::PTrue(d), v));  \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_LEADING_ZERO_COUNT, LeadingZeroCount, clz)
#undef HWY_SVE_LEADING_ZERO_COUNT

template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)>
HWY_API V TrailingZeroCount(V v) {
  return LeadingZeroCount(ReverseBits(v));
}

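// Informative: the trailing zero count of v equals the leading zero count of
// the bit-reversed value, e.g. a u32 lane holding 8 (0b1000) yields 3.
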
template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)>
HWY_API V HighestSetBitIndex(V v) {
  const DFromV<decltype(v)> d;
  using T = TFromD<decltype(d)>;
  return BitCast(d, Sub(Set(d, T{sizeof(T) * 8 - 1}), LeadingZeroCount(v)));
}

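// Informative: HighestSetBitIndex(v) == (lane bits - 1) - LeadingZeroCount(v);
// e.g. a u32 lane holding 8 has 28 leading zeros, so the index is 3.
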
// ================================================== END MACROS
namespace detail {  // for code folding
#undef HWY_SVE_ALL_PTRUE
#undef HWY_SVE_D
#undef HWY_SVE_FOREACH
#undef HWY_SVE_FOREACH_BF16
#undef HWY_SVE_FOREACH_F
#undef HWY_SVE_FOREACH_F16
#undef HWY_SVE_FOREACH_F32
#undef HWY_SVE_FOREACH_F64
#undef HWY_SVE_FOREACH_I
#undef HWY_SVE_FOREACH_I08
#undef HWY_SVE_FOREACH_I16
#undef HWY_SVE_FOREACH_I32
#undef HWY_SVE_FOREACH_I64
#undef HWY_SVE_FOREACH_IF
#undef HWY_SVE_FOREACH_U
#undef HWY_SVE_FOREACH_U08
#undef HWY_SVE_FOREACH_U16
#undef HWY_SVE_FOREACH_U32
#undef HWY_SVE_FOREACH_U64
#undef HWY_SVE_FOREACH_UI
#undef HWY_SVE_FOREACH_UI08
#undef HWY_SVE_FOREACH_UI16
#undef HWY_SVE_FOREACH_UI32
#undef HWY_SVE_FOREACH_UI64
#undef HWY_SVE_FOREACH_UIF3264
#undef HWY_SVE_HAVE_2
#undef HWY_SVE_PTRUE
#undef HWY_SVE_RETV_ARGPV
#undef HWY_SVE_RETV_ARGPVN
#undef HWY_SVE_RETV_ARGPVV
#undef HWY_SVE_RETV_ARGV
#undef HWY_SVE_RETV_ARGVN
#undef HWY_SVE_RETV_ARGVV
#undef HWY_SVE_RETV_ARGVVV
#undef HWY_SVE_T
#undef HWY_SVE_UNDEFINED
#undef HWY_SVE_V

}  // namespace detail
// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();