initial SSE implementation of memcmp for x64 platforms

This commit is contained in:
Nic Barker 2025-01-31 10:05:32 +13:00
parent 85acb86dbc
commit b58bdd1a1d

66
clay.h
View File

@ -16,7 +16,10 @@
#include <stdlib.h>
#include <string.h>
#ifdef __aarch64__
// SIMD includes on supported platforms
#ifdef __x86_64__
#include <emmintrin.h>
#elif __aarch64__
#include <arm_neon.h>
#endif
@ -1409,34 +1412,15 @@ void Clay__CloseElement(void) {
}
bool Clay__MemCmp(const char *s1, const char *s2, int32_t length);
#ifdef __x86_64__
bool Clay__MemCmp(const char *s1, const char *s2, int32_t length) {
for (int32_t i = 0; i < length; i++) {
if (s1[i] != s2[i]) {
return false;
}
}
return true;
}
#elif defined(__aarch64__)
bool Clay__MemCmp(const char *s1, const char *s2, int32_t length) {
while (length >= 16) {
uint8x16_t v1 = vld1q_u8((const uint8_t *)s1);
uint8x16_t v2 = vld1q_u8((const uint8_t *)s2);
__m128i v1 = _mm_loadu_si128((const __m128i *)s1);
__m128i v2 = _mm_loadu_si128((const __m128i *)s2);
// Compare vectors
uint8x16_t cmp = vceqq_u8(v1, v2);
uint64_t mask = vgetq_lane_u64(vreinterpretq_u64_u8(cmp), 0) &
vgetq_lane_u64(vreinterpretq_u64_u8(cmp), 1);
if (mask != UINT64_MAX) { // If there's a difference
for (int32_t i = 0; i < 16; i++) {
if (s1[i] != s2[i]) {
if (_mm_movemask_epi8(_mm_cmpeq_epi8(v1, v2)) != 0xFFFF) { // If any byte differs
return false;
}
}
}
s1 += 16;
s2 += 16;
@ -1452,6 +1436,42 @@ bool Clay__MemCmp(const char *s1, const char *s2, int32_t length);
s2++;
}
return true;
}
#elif defined(__aarch64__)
bool Clay__MemCmp(const char *s1, const char *s2, int32_t length) {
while (length >= 16) {
uint8x16_t v1 = vld1q_u8((const uint8_t *)s1);
uint8x16_t v2 = vld1q_u8((const uint8_t *)s2);
// Compare vectors
if (vminvq_u32(vceqq_u8(v1, v2)) != 0xFFFFFFFF) { // If there's a difference
return false;
}
s1 += 16;
s2 += 16;
length -= 16;
}
// Handle remaining bytes
while (length--) {
if (*s1 != *s2) {
return false;
}
s1++;
s2++;
}
return true;
}
#else
bool Clay__MemCmp(const char *s1, const char *s2, int32_t length) {
for (int32_t i = 0; i < length; i++) {
if (s1[i] != s2[i]) {
return false;
}
}
return true;
}
#endif