Review notes (text ahead of the "diff --git" line is skipped by git apply / patch):
- Line breaks, indentation and blank lines were rebuilt from a whitespace-collapsed
  copy so that every hunk matches its @@ counts. Regenerate the patch against the
  repository before applying it.
- Header names were missing from the #include lines of that copy. <emmintrin.h>
  (SSE2) and <arm_neon.h> are inferred from the intrinsics the patch uses. The two
  leading #include context lines of the first hunk were dropped because their names
  could not be recovered, so that hunk now starts at line 18.
- Changed relative to the reviewed patch: "#elif defined(__aarch64__)" in the
  include guard, an explicit vreinterpretq_u32_u8() in front of vminvq_u32(), and
  comments. The post-image hash on the index line predates these changes.
- Not changed: the NEON tail loop "while (length--)" does not stop for a negative
  length, whereas the scalar fallback does. The x86_64 tail loop lies outside the
  hunks; if it has the same shape, guard both together.

diff --git a/clay.h b/clay.h
index d2810db..e04ed2f 100644
--- a/clay.h
+++ b/clay.h
@@ -18,5 +18,8 @@
 
-#ifdef __aarch64__
+// SIMD includes on supported platforms
+#ifdef __x86_64__
+#include <emmintrin.h>
+#elif defined(__aarch64__)
 #include <arm_neon.h>
 #endif
 
@@ -1409,33 +1412,17 @@ void Clay__CloseElement(void) {
 }
 
+// Byte-wise equality test: returns true when the first `length` bytes of s1 and s2
+// all match, and false as soon as a mismatch is found. Despite the name, the result
+// is a boolean "equal" flag, not memcmp()'s signed ordering value.
 bool Clay__MemCmp(const char *s1, const char *s2, int32_t length);
-
 #ifdef __x86_64__
-    bool Clay__MemCmp(const char *s1, const char *s2, int32_t length) {
-        for (int32_t i = 0; i < length; i++) {
-            if (s1[i] != s2[i]) {
-                return false;
-            }
-        }
-        return true;
-    }
-#elif defined(__aarch64__)
     bool Clay__MemCmp(const char *s1, const char *s2, int32_t length) {
         while (length >= 16) {
-            uint8x16_t v1 = vld1q_u8((const uint8_t *)s1);
-            uint8x16_t v2 = vld1q_u8((const uint8_t *)s2);
+            __m128i v1 = _mm_loadu_si128((const __m128i *)s1);
+            __m128i v2 = _mm_loadu_si128((const __m128i *)s2);
 
-            // Compare vectors
-            uint8x16_t cmp = vceqq_u8(v1, v2);
-            uint64_t mask = vgetq_lane_u64(vreinterpretq_u64_u8(cmp), 0) &
-                            vgetq_lane_u64(vreinterpretq_u64_u8(cmp), 1);
-
-            if (mask != UINT64_MAX) { // If there's a difference
-                for (int32_t i = 0; i < 16; i++) {
-                    if (s1[i] != s2[i]) {
-                        return false;
-                    }
-                }
+            if (_mm_movemask_epi8(_mm_cmpeq_epi8(v1, v2)) != 0xFFFF) { // If any byte differs
+                return false;
             }
 
             s1 += 16;
@@ -1452,6 +1439,42 @@ bool Clay__MemCmp(const char *s1, const char *s2, int32_t length);
             s2++;
         }
 
+        return true;
+    }
+#elif defined(__aarch64__)
+    bool Clay__MemCmp(const char *s1, const char *s2, int32_t length) {
+        while (length >= 16) {
+            uint8x16_t v1 = vld1q_u8((const uint8_t *)s1);
+            uint8x16_t v2 = vld1q_u8((const uint8_t *)s2);
+
+            // Matching bytes compare to 0xFF; the minimum over the 32-bit lanes is all-ones only if all 16 matched
+            if (vminvq_u32(vreinterpretq_u32_u8(vceqq_u8(v1, v2))) != 0xFFFFFFFF) { // If there's a difference
+                return false;
+            }
+
+            s1 += 16;
+            s2 += 16;
+            length -= 16;
+        }
+
+        // Handle remaining bytes
+        while (length--) {
+            if (*s1 != *s2) {
+                return false;
+            }
+            s1++;
+            s2++;
+        }
+
+        return true;
+    }
+#else
+    bool Clay__MemCmp(const char *s1, const char *s2, int32_t length) {
+        for (int32_t i = 0; i < length; i++) {
+            if (s1[i] != s2[i]) {
+                return false;
+            }
+        }
         return true;
     }
 #endif