1/* strstr optimized with 512-bit AVX-512 instructions
2 Copyright (C) 2022-2024 Free Software Foundation, Inc.
3 This file is part of the GNU C Library.
4
5 The GNU C Library is free software; you can redistribute it and/or
6 modify it under the terms of the GNU Lesser General Public
7 License as published by the Free Software Foundation; either
8 version 2.1 of the License, or (at your option) any later version.
9
10 The GNU C Library is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13 Lesser General Public License for more details.
14
15 You should have received a copy of the GNU Lesser General Public
16 License along with the GNU C Library; if not, see
17 <https://www.gnu.org/licenses/>. */
18
19#include <immintrin.h>
20#include <inttypes.h>
21#include <stdbool.h>
22#include <string.h>
23
/* All 64 mask bits set: selects every byte lane of a 512-bit register.  */
#define FULL_MMASK64 0xffffffffffffffff
/* 64-bit constant one, used for the "mask up to lowest set bit" idiom
   x ^ (x - 1).  */
#define ONE_64BIT 0x1ull
/* Width of a ZMM register in bytes.  */
#define ZMM_SIZE_IN_BYTES 64
/* Page size assumed by the page-crossing checks below.  */
#define PAGESIZE 4096

/* Plain-integer stand-ins for the mask-register helpers.  Every expansion
   is fully parenthesized so it composes safely with surrounding
   operators.  */
#define cvtmask64_u64(...) ((uint64_t) (__VA_ARGS__))
#define kshiftri_mask64(x, y) ((x) >> (y))
#define kand_mask64(x, y) ((x) & (y))
32
/*
  Locate the first "edge" in NED: the smallest index i such that
  ned[i] != ned[i + 1] and ned[i + 1] is not the terminating NUL.
  Returns that index, or 0 when the needle consists of one repeated
  character (so a return value of 0 also covers an edge at index 0).
  NED must contain at least one character, since ned[1] is always read.
  Example: in 'aab' the first edge is 'ab', at index 1.
 */
static inline size_t
find_edge_in_needle (const char *ned)
{
  for (size_t pos = 0; ned[pos + 1] != '\0'; pos++)
    {
      /* Two adjacent bytes differ: this is the edge.  */
      if (ned[pos + 1] != ned[pos])
        return pos;
    }
  return 0;
}
50
/*
  Byte-by-byte comparison of the needle against the haystack.  Starting at
  needle offset IND, checks that ned[i] == hay[hay_index + i] for every i
  up to (not including) the needle's terminating NUL.  Returns true when
  all remaining needle bytes match, false at the first mismatch.  A NUL in
  the haystack counts as a mismatch, so the scan never runs past the end
  of the haystack string.
 */
static inline bool
verify_string_match (const char *hay, const size_t hay_index, const char *ned,
                     size_t ind)
{
  for (size_t pos = ind; ned[pos] != '\0'; pos++)
    {
      if (hay[hay_index + pos] != ned[pos])
        return false;
    }
  return true;
}
66
67/*
68 Compare needle with haystack at specified location. The first 64 bytes are
69 compared using a ZMM register.
70 */
71static inline bool
72verify_string_match_avx512 (const char *hay, const size_t hay_index,
73 const char *ned, const __mmask64 ned_mask,
74 const __m512i ned_zmm)
75{
76 /* check first 64 bytes using zmm and then scalar */
77 __m512i hay_zmm = _mm512_loadu_si512 (P: hay + hay_index); // safe to do so
78 __mmask64 match = _mm512_mask_cmpneq_epi8_mask (ned_mask, hay_zmm, ned_zmm);
79 if (match != 0x0) // failed the first few chars
80 return false;
81 else if (ned_mask == FULL_MMASK64)
82 return verify_string_match (hay, hay_index, ned, ZMM_SIZE_IN_BYTES);
83 return true;
84}
85
86char *
87__strstr_avx512 (const char *haystack, const char *ned)
88{
89 char first = ned[0];
90 if (first == '\0')
91 return (char *)haystack;
92 if (ned[1] == '\0')
93 return (char *)strchr (haystack, ned[0]);
94
95 size_t edge = find_edge_in_needle (ned);
96
97 /* ensure haystack is as long as the pos of edge in needle */
98 for (int ii = 0; ii < edge; ++ii)
99 {
100 if (haystack[ii] == '\0')
101 return NULL;
102 }
103
104 /*
105 Load 64 bytes of the needle and save it to a zmm register
106 Read one cache line at a time to avoid loading across a page boundary
107 */
108 __mmask64 ned_load_mask = _bzhi_u64 (
109 FULL_MMASK64, Y: 64 - ((uintptr_t) (ned) & 63));
110 __m512i ned_zmm = _mm512_maskz_loadu_epi8 (U: ned_load_mask, P: ned);
111 __mmask64 ned_nullmask
112 = _mm512_mask_testn_epi8_mask (U: ned_load_mask, A: ned_zmm, B: ned_zmm);
113
114 if (__glibc_unlikely (ned_nullmask == 0x0))
115 {
116 ned_zmm = _mm512_loadu_si512 (P: ned);
117 ned_nullmask = _mm512_testn_epi8_mask (A: ned_zmm, B: ned_zmm);
118 ned_load_mask = ned_nullmask ^ (ned_nullmask - ONE_64BIT);
119 if (ned_nullmask != 0x0)
120 ned_load_mask = ned_load_mask >> 1;
121 }
122 else
123 {
124 ned_load_mask = ned_nullmask ^ (ned_nullmask - ONE_64BIT);
125 ned_load_mask = ned_load_mask >> 1;
126 }
127 const __m512i ned0 = _mm512_set1_epi8 (w: ned[edge]);
128 const __m512i ned1 = _mm512_set1_epi8 (w: ned[edge + 1]);
129
130 /*
131 Read the bytes of haystack in the current cache line
132 */
133 size_t hay_index = edge;
134 __mmask64 loadmask = _bzhi_u64 (
135 FULL_MMASK64, Y: 64 - ((uintptr_t) (haystack + hay_index) & 63));
136 /* First load is a partial cache line */
137 __m512i hay0 = _mm512_maskz_loadu_epi8 (U: loadmask, P: haystack + hay_index);
138 /* Search for NULL and compare only till null char */
139 uint64_t nullmask
140 = cvtmask64_u64 (_mm512_mask_testn_epi8_mask (loadmask, hay0, hay0));
141 uint64_t cmpmask = nullmask ^ (nullmask - ONE_64BIT);
142 cmpmask = cmpmask & cvtmask64_u64 (loadmask);
143 /* Search for the 2 characters of needle */
144 __mmask64 k0 = _mm512_cmpeq_epi8_mask (hay0, ned0);
145 __mmask64 k1 = _mm512_cmpeq_epi8_mask (hay0, ned1);
146 k1 = kshiftri_mask64 (k1, 1);
147 /* k2 masks tell us if both chars from needle match */
148 uint64_t k2 = cvtmask64_u64 (kand_mask64 (k0, k1)) & cmpmask;
149 /* For every match, search for the entire needle for a full match */
150 while (k2)
151 {
152 uint64_t bitcount = _tzcnt_u64 (k2);
153 k2 = _blsr_u64 (k2);
154 size_t match_pos = hay_index + bitcount - edge;
155 if (((uintptr_t) (haystack + match_pos) & (PAGESIZE - 1))
156 < PAGESIZE - 1 - ZMM_SIZE_IN_BYTES)
157 {
158 /*
159 * Use vector compare as long as you are not crossing a page
160 */
161 if (verify_string_match_avx512 (hay: haystack, hay_index: match_pos, ned,
162 ned_mask: ned_load_mask, ned_zmm))
163 return (char *)haystack + match_pos;
164 }
165 else
166 {
167 if (verify_string_match (hay: haystack, hay_index: match_pos, ned, ind: 0))
168 return (char *)haystack + match_pos;
169 }
170 }
171 /* We haven't checked for potential match at the last char yet */
172 haystack = (const char *)(((uintptr_t) (haystack + hay_index) | 63));
173 hay_index = 0;
174
175 /*
176 Loop over one cache line at a time to prevent reading over page
177 boundary
178 */
179 __m512i hay1;
180 while (nullmask == 0)
181 {
182 hay0 = _mm512_loadu_si512 (P: haystack + hay_index);
183 hay1 = _mm512_load_si512 (P: haystack + hay_index
184 + 1); // Always 64 byte aligned
185 nullmask = cvtmask64_u64 (_mm512_testn_epi8_mask (hay1, hay1));
186 /* Compare only till null char */
187 cmpmask = nullmask ^ (nullmask - ONE_64BIT);
188 k0 = _mm512_cmpeq_epi8_mask (hay0, ned0);
189 k1 = _mm512_cmpeq_epi8_mask (hay1, ned1);
190 /* k2 masks tell us if both chars from needle match */
191 k2 = cvtmask64_u64 (kand_mask64 (k0, k1)) & cmpmask;
192 /* For every match, compare full strings for potential match */
193 while (k2)
194 {
195 uint64_t bitcount = _tzcnt_u64 (k2);
196 k2 = _blsr_u64 (k2);
197 size_t match_pos = hay_index + bitcount - edge;
198 if (((uintptr_t) (haystack + match_pos) & (PAGESIZE - 1))
199 < PAGESIZE - 1 - ZMM_SIZE_IN_BYTES)
200 {
201 /*
202 * Use vector compare as long as you are not crossing a page
203 */
204 if (verify_string_match_avx512 (hay: haystack, hay_index: match_pos, ned,
205 ned_mask: ned_load_mask, ned_zmm))
206 return (char *)haystack + match_pos;
207 }
208 else
209 {
210 /* Compare byte by byte */
211 if (verify_string_match (hay: haystack, hay_index: match_pos, ned, ind: 0))
212 return (char *)haystack + match_pos;
213 }
214 }
215 hay_index += ZMM_SIZE_IN_BYTES;
216 }
217 return NULL;
218}
219

/* Source: glibc/sysdeps/x86_64/multiarch/strstr-avx512.c */