1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
|
#include "prism/internal/strpbrk.h"
#include "prism/compiler/accel.h"
#include "prism/compiler/inline.h"
#include "prism/compiler/unused.h"
#include "prism/internal/bit.h"
#include "prism/internal/diagnostic.h"
#include "prism/internal/encoding.h"
#include "prism/internal/parser.h"
#include <assert.h>
#include <stdbool.h>
#include <string.h>
/**
 * Record a PM_ERR_INVALID_MULTIBYTE_CHARACTER diagnostic on the parser's error
 * list.
 *
 * @param parser The parser whose error list receives the diagnostic.
 * @param start  Offset of the first offending byte, relative to parser->start.
 * @param length Number of bytes covered by the diagnostic.
 */
static PRISM_INLINE void
pm_strpbrk_invalid_multibyte_character(pm_parser_t *parser, uint32_t start, uint32_t length) {
    // The byte found at the start of the span is passed along as the format
    // argument of the diagnostic.
    const uint8_t leading_byte = parser->start[start];

    pm_diagnostic_list_append_format(
        &parser->metadata_arena,
        &parser->error_list,
        start,
        length,
        PM_ERR_INVALID_MULTIBYTE_CHARACTER,
        leading_byte
    );
}
/**
 * Lock the parser's explicit encoding to the parser's current encoding.
 *
 * If an explicit encoding was already recorded and it differs from the current
 * encoding, the only expected value is the UTF-8 entry (recorded when a
 * Unicode escape sequence was seen earlier); in that case a
 * PM_ERR_MIXED_ENCODING diagnostic naming the current encoding is appended.
 *
 * @param parser The parser to update.
 * @param start  Offset of the diagnostic, relative to parser->start.
 * @param length Length in bytes of the diagnostic.
 */
static PRISM_INLINE void
pm_strpbrk_explicit_encoding_set(pm_parser_t *parser, uint32_t start, uint32_t length) {
    const pm_encoding_t *locked = parser->explicit_encoding;

    // Nothing to check when no encoding was locked yet, or when the lock
    // already agrees with the current encoding.
    if (locked != NULL && locked != parser->encoding) {
        if (locked == PM_ENCODING_UTF_8_ENTRY) {
            // A Unicode escape sequence was found earlier and it conflicts
            // with the current encoding.
            pm_diagnostic_list_append_format(
                &parser->metadata_arena,
                &parser->error_list,
                start,
                length,
                PM_ERR_MIXED_ENCODING,
                parser->encoding->name
            );
        } else {
            // No other explicit encoding should ever have been recorded.
            assert(false && "unreachable");
        }
    }

    parser->explicit_encoding = parser->encoding;
}
/**
* Scan forward through ASCII bytes looking for a byte that is in the given
* charset. Returns true if a match was found, storing its offset in *index.
* Returns false if no match was found, storing the number of ASCII bytes
* consumed in *index (so the caller can skip past them).
*
* All charset characters must be ASCII (< 0x80). The scanner stops at non-ASCII
* bytes, returning control to the caller's encoding-aware loop.
*
* Exactly one implementation is selected at compile time. There are three
* optimized variants, plus a no-op fallback for unsupported platforms:
* 1. NEON — processes 16 bytes per iteration on aarch64.
* 2. SSSE3 — processes 16 bytes per iteration on x86-64.
* 3. SWAR — little-endian fallback, processes 8 bytes per iteration.
*/
#if defined(PRISM_HAS_NEON) || defined(PRISM_HAS_SSSE3) || defined(PRISM_HAS_SWAR)
/**
 * Rebuild the parser's cached strpbrk lookup structures when the charset
 * differs from the one that was cached last. When the charset is unchanged the
 * function returns immediately, so repeated calls with the same breakpoints
 * skip table construction.
 *
 * Three structures are (re)built:
 *   - low_lut / high_lut: nibble-indexed tables. Bit h of low_lut[lo] is set
 *     when the byte ((h << 4) | lo) is a breakpoint, and high_lut[h] holds the
 *     single bit (1 << h). A byte therefore matches when
 *     (low_lut[lo] & high_lut[hi]) != 0.
 *   - table: a bitmap indexed by byte value (word = byte >> 6, bit =
 *     byte & 0x3F) used by the scalar loops.
 *
 * @param parser  The parser that owns the cache.
 * @param charset NUL-terminated breakpoint list. It is compared against the
 *     cached key over sizeof(parser->strpbrk_cache.charset) bytes, so it must
 *     point at a NUL-padded buffer of at least that size.
 */
static PRISM_INLINE void
pm_strpbrk_cache_update(pm_parser_t *parser, const uint8_t *charset) {
    const size_t key_size = sizeof(parser->strpbrk_cache.charset);

    // The key is the whole NUL-padded buffer, so one fixed-size comparison
    // covers both the content and the length of the charset.
    if (memcmp(parser->strpbrk_cache.charset, charset, key_size) == 0) {
        return;
    }

    memset(parser->strpbrk_cache.low_lut, 0, sizeof(parser->strpbrk_cache.low_lut));
    memset(parser->strpbrk_cache.high_lut, 0, sizeof(parser->strpbrk_cache.high_lut));
    memset(parser->strpbrk_cache.table, 0, sizeof(parser->strpbrk_cache.table));

    // NUL is always a breakpoint: the strchr-based loops match the charset's
    // terminator, so the tables mirror that to keep both paths in agreement.
    parser->strpbrk_cache.low_lut[0x00] |= (uint8_t) (1 << 0);
    parser->strpbrk_cache.high_lut[0x00] = (uint8_t) (1 << 0);
    parser->strpbrk_cache.table[0] |= (uint64_t) 1;

    size_t charset_len = 0;
    while (charset[charset_len] != '\0') {
        const uint8_t byte = charset[charset_len];
        const uint8_t high_nibble = (uint8_t) (byte >> 4);
        const uint8_t class_bit = (uint8_t) (1 << high_nibble);

        parser->strpbrk_cache.low_lut[byte & 0x0F] |= class_bit;
        parser->strpbrk_cache.high_lut[high_nibble] = class_bit;
        parser->strpbrk_cache.table[byte >> 6] |= (uint64_t) 1 << (byte & 0x3F);

        charset_len++;
    }

    // Remember the new key: the charset with its terminator, followed by NUL
    // padding up to the full key size.
    memcpy(parser->strpbrk_cache.charset, charset, charset_len + 1);
    memset(parser->strpbrk_cache.charset + charset_len + 1, 0, key_size - charset_len - 1);
}
#endif
#if defined(PRISM_HAS_NEON)
#include <arm_neon.h>
/**
* NEON variant of the ASCII scanner. Classifies 16 bytes per iteration with the
* cached nibble lookup tables, then finishes byte by byte using the cached
* bitmap.
*
* @param parser  The parser that owns the strpbrk cache.
* @param source  Start of the bytes to scan.
* @param maximum Number of bytes available from source.
* @param charset Breakpoint list; used to refresh the cache before scanning.
* @param index   Out: offset of the match when true is returned, otherwise the
*     number of leading ASCII non-matching bytes that were consumed.
* @return true if a breakpoint byte was found, false otherwise.
*/
static PRISM_INLINE bool
scan_strpbrk_ascii(pm_parser_t *parser, const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) {
pm_strpbrk_cache_update(parser, charset);
uint8x16_t low_lut = vld1q_u8(parser->strpbrk_cache.low_lut);
uint8x16_t high_lut = vld1q_u8(parser->strpbrk_cache.high_lut);
uint8x16_t mask_0f = vdupq_n_u8(0x0F);
uint8x16_t mask_80 = vdupq_n_u8(0x80);
size_t idx = 0;
// Vector loop: only runs while a full 16-byte chunk is available.
while (idx + 16 <= maximum) {
uint8x16_t v = vld1q_u8(source + idx);
// If any byte has the high bit set, this chunk contains non-ASCII data.
// Leave the vector loop; the scalar tail below consumes the ASCII bytes
// that precede the first non-ASCII byte and then stops there.
if (vmaxvq_u8(vandq_u8(v, mask_80)) != 0) break;
// Look up each byte's low nibble and high nibble in the cached tables.
// A byte is a breakpoint iff the two lookups share a set bit, in which
// case vtstq_u8 produces a non-zero lane for it.
uint8x16_t lo_class = vqtbl1q_u8(low_lut, vandq_u8(v, mask_0f));
uint8x16_t hi_class = vqtbl1q_u8(high_lut, vshrq_n_u8(v, 4));
uint8x16_t matched = vtstq_u8(lo_class, hi_class);
// No lane matched: skip the whole chunk.
if (vmaxvq_u8(matched) == 0) {
idx += 16;
continue;
}
// Find the position of the first matching byte. The 16 lanes are viewed
// as two 64-bit halves; the bit position reported by pm_ctzll
// (presumably a count of trailing zero bits) divided by 8 gives the
// byte position within a half.
uint64_t lo64 = vgetq_lane_u64(vreinterpretq_u64_u8(matched), 0);
if (lo64 != 0) {
*index = idx + pm_ctzll(lo64) / 8;
return true;
}
// The match must be in the second half, i.e. bytes 8..15 of the chunk.
uint64_t hi64 = vgetq_lane_u64(vreinterpretq_u64_u8(matched), 1);
*index = idx + 8 + pm_ctzll(hi64) / 8;
return true;
}
// Scalar tail: handles the bytes the vector loop did not consume, stopping
// at the first non-ASCII byte or at maximum.
while (idx < maximum && source[idx] < 0x80) {
uint8_t byte = source[idx];
// Bitmap test: word (byte >> 6), bit (byte & 0x3F).
if (parser->strpbrk_cache.table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) {
*index = idx;
return true;
}
idx++;
}
// No match: report how many bytes were consumed so the caller can resume.
*index = idx;
return false;
}
#elif defined(PRISM_HAS_SSSE3)
#include <tmmintrin.h>
/**
* SSSE3 variant of the ASCII scanner. Classifies 16 bytes per iteration with
* the cached nibble lookup tables, then finishes byte by byte using the cached
* bitmap.
*
* @param parser  The parser that owns the strpbrk cache.
* @param source  Start of the bytes to scan.
* @param maximum Number of bytes available from source.
* @param charset Breakpoint list; used to refresh the cache before scanning.
* @param index   Out: offset of the match when true is returned, otherwise the
*     number of leading ASCII non-matching bytes that were consumed.
* @return true if a breakpoint byte was found, false otherwise.
*/
static PRISM_INLINE bool
scan_strpbrk_ascii(pm_parser_t *parser, const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) {
pm_strpbrk_cache_update(parser, charset);
__m128i low_lut = _mm_loadu_si128((const __m128i *) parser->strpbrk_cache.low_lut);
__m128i high_lut = _mm_loadu_si128((const __m128i *) parser->strpbrk_cache.high_lut);
__m128i mask_0f = _mm_set1_epi8(0x0F);
size_t idx = 0;
// Vector loop: only runs while a full 16-byte chunk is available.
while (idx + 16 <= maximum) {
__m128i v = _mm_loadu_si128((const __m128i *) (source + idx));
// _mm_movemask_epi8 gathers the high bit of every byte, so a non-zero
// result means the chunk contains a non-ASCII byte. Leave the vector
// loop; the scalar tail below consumes the ASCII bytes that precede it.
if (_mm_movemask_epi8(v) != 0) break;
// Nibble-based classification using pshufb (SSSE3), same as NEON
// vqtbl1q_u8. A byte matches iff (low_lut[lo_nib] & high_lut[hi_nib]) != 0.
// The 16-bit shift moves bits across byte boundaries, so the result is
// masked with 0x0F to keep only each byte's own high nibble.
__m128i lo_class = _mm_shuffle_epi8(low_lut, _mm_and_si128(v, mask_0f));
__m128i hi_class = _mm_shuffle_epi8(high_lut, _mm_and_si128(_mm_srli_epi16(v, 4), mask_0f));
__m128i matched = _mm_and_si128(lo_class, hi_class);
// Build a 16-bit mask with one bit per byte that did NOT match (bytes of
// `matched` equal to zero).
int mask = _mm_movemask_epi8(_mm_cmpeq_epi8(matched, _mm_setzero_si128()));
if (mask == 0xFFFF) {
// All bytes were zero, so there is no match in this chunk.
idx += 16;
continue;
}
// Invert the mask so set bits mark matching bytes; the lowest set bit is
// the first match within the chunk.
*index = idx + pm_ctzll((uint64_t) (~mask & 0xFFFF));
return true;
}
// Scalar tail: handles the bytes the vector loop did not consume, stopping
// at the first non-ASCII byte or at maximum.
while (idx < maximum && source[idx] < 0x80) {
uint8_t byte = source[idx];
// Bitmap test: word (byte >> 6), bit (byte & 0x3F).
if (parser->strpbrk_cache.table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) {
*index = idx;
return true;
}
idx++;
}
// No match: report how many bytes were consumed so the caller can resume.
*index = idx;
return false;
}
#elif defined(PRISM_HAS_SWAR)
/**
 * SWAR variant of the ASCII scanner. Eight bytes at a time are tested for
 * non-ASCII content with a single 64-bit mask; the bytes of an all-ASCII word
 * are then checked individually against the cached bitmap.
 *
 * @param parser  The parser that owns the strpbrk cache.
 * @param source  Start of the bytes to scan.
 * @param maximum Number of bytes available from source.
 * @param charset Breakpoint list; used to refresh the cache before scanning.
 * @param index   Out: offset of the match when true is returned, otherwise the
 *     number of leading ASCII non-matching bytes that were consumed.
 * @return true if a breakpoint byte was found, false otherwise.
 */
static PRISM_INLINE bool
scan_strpbrk_ascii(pm_parser_t *parser, const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) {
    pm_strpbrk_cache_update(parser, charset);

    // The high bit of every byte in a 64-bit word.
    const uint64_t high_bits = 0x8080808080808080ULL;
    size_t offset = 0;

    for (; offset + 8 <= maximum; offset += 8) {
        uint64_t chunk;
        memcpy(&chunk, source + offset, sizeof(chunk));

        // Any non-ASCII byte ends the word loop; the byte loop below consumes
        // whatever ASCII bytes precede it.
        if ((chunk & high_bits) != 0) break;

        for (size_t lane = 0; lane < 8; lane++) {
            const uint8_t byte = source[offset + lane];
            const uint64_t bit = (uint64_t) 1 << (byte & 0x3F);

            if ((parser->strpbrk_cache.table[byte >> 6] & bit) != 0) {
                *index = offset + lane;
                return true;
            }
        }
    }

    // Byte-at-a-time tail, stopping at the first non-ASCII byte or at maximum.
    for (; offset < maximum && source[offset] < 0x80; offset++) {
        const uint8_t byte = source[offset];
        const uint64_t bit = (uint64_t) 1 << (byte & 0x3F);

        if ((parser->strpbrk_cache.table[byte >> 6] & bit) != 0) {
            *index = offset;
            return true;
        }
    }

    // No match: report how many bytes were consumed so the caller can resume.
    *index = offset;
    return false;
}
#else
/**
* No-op fallback used when none of PRISM_HAS_NEON, PRISM_HAS_SSSE3, or
* PRISM_HAS_SWAR is defined. It inspects nothing: *index is set to 0 and false
* is returned, so the caller scans the whole input with its own loop.
*/
static PRISM_INLINE bool
scan_strpbrk_ascii(PRISM_UNUSED pm_parser_t *parser, PRISM_UNUSED const uint8_t *source, PRISM_UNUSED size_t maximum, PRISM_UNUSED const uint8_t *charset, size_t *index) {
// Zero bytes consumed, no match found.
*index = 0;
return false;
}
#endif
/**
 * Search path used while the parser's encoding has not been changed, which
 * treats the input as UTF-8.
 *
 * @param parser   The parser, used for reporting invalid characters.
 * @param source   Base pointer of the region being searched.
 * @param charset  NUL-terminated list of breakpoint bytes.
 * @param index    Offset into source at which scanning starts.
 * @param maximum  Number of bytes available from source.
 * @param validate When true, invalid multibyte sequences are reported as
 *     diagnostics; when false, they are skipped one byte at a time.
 * @return Pointer to the first byte that strchr finds in charset (a NUL byte
 *     always qualifies because strchr matches the terminator), or NULL if the
 *     region is exhausted.
 */
static PRISM_INLINE const uint8_t *
pm_strpbrk_utf8(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t index, size_t maximum, bool validate) {
    while (index < maximum) {
        const uint8_t byte = source[index];

        if (strchr((const char *) charset, byte) != NULL) {
            return source + index;
        }

        // ASCII bytes are always exactly one byte wide.
        if (byte < 0x80) {
            index++;
            continue;
        }

        const size_t width = pm_encoding_utf_8_char_width(source + index, (ptrdiff_t) (maximum - index));

        if (width > 0) {
            index += width;
            continue;
        }

        if (!validate) {
            index++;
            continue;
        }

        // Invalid multibyte character. Swallow every following position that
        // is also not the start of a valid character, so that one diagnostic
        // covers the whole run instead of one diagnostic per byte.
        const size_t invalid_start = index;

        index++;
        while (index < maximum && pm_encoding_utf_8_char_width(source + index, (ptrdiff_t) (maximum - index)) == 0) {
            index++;
        }

        pm_strpbrk_invalid_multibyte_character(
            parser,
            (uint32_t) ((source + invalid_start) - parser->start),
            (uint32_t) (index - invalid_start)
        );
    }

    return NULL;
}
/**
 * Search path used when the parser's encoding is ASCII-8BIT. Every byte is
 * treated as a complete character, so the scan advances one byte at a time.
 *
 * @param parser   The parser, used for locking the explicit encoding.
 * @param source   Base pointer of the region being searched.
 * @param charset  NUL-terminated list of breakpoint bytes.
 * @param index    Offset into source at which scanning starts.
 * @param maximum  Number of bytes available from source.
 * @param validate When true, every byte >= 0x80 locks the parser's explicit
 *     encoding (which may report a mixed-encoding diagnostic).
 * @return Pointer to the first byte that strchr finds in charset (a NUL byte
 *     always qualifies because strchr matches the terminator), or NULL if the
 *     region is exhausted.
 */
static PRISM_INLINE const uint8_t *
pm_strpbrk_ascii_8bit(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t index, size_t maximum, bool validate) {
    while (index < maximum) {
        if (strchr((const char *) charset, source[index]) != NULL) {
            return source + index;
        }

        if (validate && source[index] >= 0x80) {
            // Anchor any resulting diagnostic at the non-ASCII byte itself
            // (source + index). Using the base of the searched region would
            // point the diagnostic at an unrelated earlier byte whenever
            // index > 0.
            pm_strpbrk_explicit_encoding_set(parser, (uint32_t) ((source + index) - parser->start), 1);
        }

        index++;
    }

    return NULL;
}
/**
 * Search path for multibyte encodings. Non-ASCII characters are stepped over
 * using the encoding's char_width callback so that a breakpoint byte appearing
 * as a trailing byte of a multibyte character is not reported.
 *
 * @param parser   The parser, used for its encoding and for diagnostics.
 * @param source   Base pointer of the region being searched.
 * @param charset  NUL-terminated list of breakpoint bytes.
 * @param index    Offset into source at which scanning starts.
 * @param maximum  Number of bytes available from source.
 * @param validate When true, non-ASCII characters lock the parser's explicit
 *     encoding and invalid sequences are reported; when false, invalid bytes
 *     are skipped one at a time.
 * @return Pointer to the first byte that strchr finds in charset (a NUL byte
 *     always qualifies because strchr matches the terminator), or NULL if the
 *     region is exhausted.
 */
static PRISM_INLINE const uint8_t *
pm_strpbrk_multi_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t index, size_t maximum, bool validate) {
    const pm_encoding_t *encoding = parser->encoding;

    while (index < maximum) {
        if (strchr((const char *) charset, source[index]) != NULL) {
            return source + index;
        }

        if (source[index] < 0x80) {
            index++;
        } else {
            size_t width = encoding->char_width(source + index, (ptrdiff_t) (maximum - index));

            if (validate) {
                // Anchor any resulting diagnostic at the character being
                // examined (source + index), consistent with the
                // invalid-multibyte diagnostic below, rather than at the base
                // of the searched region.
                pm_strpbrk_explicit_encoding_set(parser, (uint32_t) ((source + index) - parser->start), (uint32_t) width);
            }

            if (width > 0) {
                index += width;
            } else if (!validate) {
                index++;
            } else {
                // At this point we know we have an invalid multibyte character.
                // Walk forward until the next position that starts a valid
                // character so that a single diagnostic covers the whole run
                // instead of one diagnostic per byte.
                const size_t start = index;

                do {
                    index++;
                } while (index < maximum && encoding->char_width(source + index, (ptrdiff_t) (maximum - index)) == 0);

                pm_strpbrk_invalid_multibyte_character(parser, (uint32_t) ((source + start) - parser->start), (uint32_t) (index - start));
            }
        }
    }

    return NULL;
}
/**
 * Search path for encodings in which every character is a single byte. When
 * validation is off the encoding is never consulted and the scan advances one
 * byte at a time; when validation is on, non-ASCII bytes are checked with the
 * encoding's char_width callback.
 *
 * @param parser   The parser, used for its encoding and for diagnostics.
 * @param source   Base pointer of the region being searched.
 * @param charset  NUL-terminated list of breakpoint bytes.
 * @param index    Offset into source at which scanning starts.
 * @param maximum  Number of bytes available from source.
 * @param validate When true, non-ASCII bytes lock the parser's explicit
 *     encoding and invalid bytes are reported as diagnostics.
 * @return Pointer to the first byte that strchr finds in charset (a NUL byte
 *     always qualifies because strchr matches the terminator), or NULL if the
 *     region is exhausted.
 */
static PRISM_INLINE const uint8_t *
pm_strpbrk_single_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t index, size_t maximum, bool validate) {
    const pm_encoding_t *encoding = parser->encoding;

    while (index < maximum) {
        if (strchr((const char *) charset, source[index]) != NULL) {
            return source + index;
        }

        if (source[index] < 0x80 || !validate) {
            index++;
        } else {
            size_t width = encoding->char_width(source + index, (ptrdiff_t) (maximum - index));

            // Anchor any resulting diagnostic at the byte being examined
            // (source + index), consistent with the invalid-character
            // diagnostic below, rather than at the base of the searched
            // region.
            pm_strpbrk_explicit_encoding_set(parser, (uint32_t) ((source + index) - parser->start), (uint32_t) width);

            if (width > 0) {
                index += width;
            } else {
                // At this point we know we have an invalid character. Walk
                // forward until the next position that starts a valid
                // character so that a single diagnostic covers the whole run
                // instead of one diagnostic per byte.
                const size_t start = index;

                do {
                    index++;
                } while (index < maximum && encoding->char_width(source + index, (ptrdiff_t) (maximum - index)) == 0);

                pm_strpbrk_invalid_multibyte_character(parser, (uint32_t) ((source + start) - parser->start), (uint32_t) (index - start));
            }
        }
    }

    return NULL;
}
/**
 * A length-bounded, encoding-aware replacement for the standard library's
 * strpbrk.
 *
 * The standard function cannot be used here for several reasons:
 *
 *   - It has undefined behavior when the source is not NUL-terminated, and
 *     pm_parse does not require NUL-terminated input (which lets callers hand
 *     it memory obtained from mmap).
 *   - It has no way to limit the search to a maximum length.
 *   - Some encodings allow a charset byte to appear as the trailing byte of a
 *     multibyte character (for example the backslash in Shift_JIS), so the
 *     search has to step over whole characters in those encodings.
 *
 * A fast ASCII scan runs first; if it finds no match, the remainder is handled
 * by the loop matching the parser's encoding, resuming where the ASCII scan
 * stopped.
 *
 * Note that a NUL byte in source is always reported as a match: the
 * strchr-based loops match the charset's terminator and the fast scanner's
 * tables include NUL. Callers that want to continue past an embedded NUL must
 * do so themselves.
 *
 * @param parser   The parser, used for its encoding, cache, and diagnostics.
 * @param source   Start of the bytes to search.
 * @param charset  NUL-terminated list of breakpoint bytes.
 * @param length   Number of bytes to search; values <= 0 yield NULL.
 * @param validate Whether invalid or mixed-encoding characters should be
 *     reported as diagnostics while scanning.
 * @return Pointer to the first matching byte, or NULL if there is none.
 */
const uint8_t *
pm_strpbrk(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, ptrdiff_t length, bool validate) {
    if (length <= 0) {
        return NULL;
    }

    const size_t maximum = (size_t) length;
    size_t index = 0;

    // Fast path: consume leading ASCII bytes. On failure, index holds the
    // number of bytes already examined.
    if (scan_strpbrk_ascii(parser, source, maximum, charset, &index)) {
        return source + index;
    }

    // The encoding was never changed, so the default UTF-8 loop applies.
    if (!parser->encoding_changed) {
        return pm_strpbrk_utf8(parser, source, charset, index, maximum, validate);
    }

    if (parser->encoding == PM_ENCODING_ASCII_8BIT_ENTRY) {
        return pm_strpbrk_ascii_8bit(parser, source, charset, index, maximum, validate);
    }

    return parser->encoding->multibyte
        ? pm_strpbrk_multi_byte(parser, source, charset, index, maximum, validate)
        : pm_strpbrk_single_byte(parser, source, charset, index, maximum, validate);
}
|