summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ext/strscan/strscan.c25
-rw-r--r--test/strscan/test_stringscanner.rb80
2 files changed, 89 insertions, 16 deletions
diff --git a/ext/strscan/strscan.c b/ext/strscan/strscan.c
index 606c44bc96..e272f92249 100644
--- a/ext/strscan/strscan.c
+++ b/ext/strscan/strscan.c
@@ -686,14 +686,6 @@ strscan_do_scan(VALUE self, VALUE pattern, int succptr, int getstr, int headonly
{
struct strscanner *p;
- if (headonly) {
- if (!RB_TYPE_P(pattern, T_REGEXP)) {
- StringValue(pattern);
- }
- }
- else {
- Check_Type(pattern, T_REGEXP);
- }
GET_SCANNER(self, p);
CLEAR_MATCH_STATUS(p);
@@ -714,14 +706,25 @@ strscan_do_scan(VALUE self, VALUE pattern, int succptr, int getstr, int headonly
}
}
else {
+ StringValue(pattern);
rb_enc_check(p->str, pattern);
if (S_RESTLEN(p) < RSTRING_LEN(pattern)) {
return Qnil;
}
- if (memcmp(CURPTR(p), RSTRING_PTR(pattern), RSTRING_LEN(pattern)) != 0) {
- return Qnil;
+
+ if (headonly) {
+ if (memcmp(CURPTR(p), RSTRING_PTR(pattern), RSTRING_LEN(pattern)) != 0) {
+ return Qnil;
+ }
+ set_registers(p, RSTRING_LEN(pattern));
+ } else {
+ long pos = rb_memsearch(RSTRING_PTR(pattern), RSTRING_LEN(pattern),
+ CURPTR(p), S_RESTLEN(p), rb_enc_get(pattern));
+ if (pos == -1) {
+ return Qnil;
+ }
+ set_registers(p, RSTRING_LEN(pattern) + pos);
}
- set_registers(p, RSTRING_LEN(pattern));
}
MATCHED(p);
diff --git a/test/strscan/test_stringscanner.rb b/test/strscan/test_stringscanner.rb
index 143cf7197d..9b7b7910d0 100644
--- a/test/strscan/test_stringscanner.rb
+++ b/test/strscan/test_stringscanner.rb
@@ -262,7 +262,7 @@ module StringScannerTests
end
def test_scan
- s = create_string_scanner('stra strb strc', true)
+ s = create_string_scanner("stra strb\0strc", true)
tmp = s.scan(/\w+/)
assert_equal 'stra', tmp
@@ -270,7 +270,7 @@ module StringScannerTests
assert_equal ' ', tmp
assert_equal 'strb', s.scan(/\w+/)
- assert_equal ' ', s.scan(/\s+/)
+ assert_equal "\u0000", s.scan(/\0/)
tmp = s.scan(/\w+/)
assert_equal 'strc', tmp
@@ -312,11 +312,14 @@ module StringScannerTests
end
def test_scan_string
- s = create_string_scanner('stra strb strc')
+ s = create_string_scanner("stra strb\0strc")
assert_equal 'str', s.scan('str')
assert_equal 'str', s[0]
assert_equal 3, s.pos
assert_equal 'a ', s.scan('a ')
+ assert_equal 'strb', s.scan('strb')
+ assert_equal "\u0000", s.scan("\0")
+ assert_equal 'strc', s.scan('strc')
str = 'stra strb strc'.dup
s = create_string_scanner(str, false)
@@ -668,13 +671,47 @@ module StringScannerTests
assert_equal(nil, s.exist?(/e/))
end
- def test_exist_p_string
+ def test_exist_p_invalid_argument
s = create_string_scanner("test string")
assert_raise(TypeError) do
- s.exist?(" ")
+ s.exist?(1)
end
end
+ def test_exist_p_string
+ omit("not implemented on TruffleRuby") if RUBY_ENGINE == "truffleruby"
+ s = create_string_scanner("test string")
+ assert_equal(3, s.exist?("s"))
+ assert_equal(0, s.pos)
+ s.scan("test")
+ assert_equal(2, s.exist?("s"))
+ assert_equal(4, s.pos)
+ assert_equal(nil, s.exist?("e"))
+ end
+
+ def test_scan_until
+ s = create_string_scanner("Foo Bar\0Baz")
+ assert_equal("Foo", s.scan_until(/Foo/))
+ assert_equal(3, s.pos)
+ assert_equal(" Bar", s.scan_until(/Bar/))
+ assert_equal(7, s.pos)
+ assert_equal(nil, s.skip_until(/Qux/))
+ assert_equal("\u0000Baz", s.scan_until(/Baz/))
+ assert_equal(11, s.pos)
+ end
+
+ def test_scan_until_string
+ omit("not implemented on TruffleRuby") if RUBY_ENGINE == "truffleruby"
+ s = create_string_scanner("Foo Bar\0Baz")
+ assert_equal("Foo", s.scan_until("Foo"))
+ assert_equal(3, s.pos)
+ assert_equal(" Bar", s.scan_until("Bar"))
+ assert_equal(7, s.pos)
+ assert_equal(nil, s.skip_until("Qux"))
+ assert_equal("\u0000Baz", s.scan_until("Baz"))
+ assert_equal(11, s.pos)
+ end
+
def test_skip_until
s = create_string_scanner("Foo Bar Baz")
assert_equal(3, s.skip_until(/Foo/))
@@ -684,6 +721,16 @@ module StringScannerTests
assert_equal(nil, s.skip_until(/Qux/))
end
+ def test_skip_until_string
+ omit("not implemented on TruffleRuby") if RUBY_ENGINE == "truffleruby"
+ s = create_string_scanner("Foo Bar Baz")
+ assert_equal(3, s.skip_until("Foo"))
+ assert_equal(3, s.pos)
+ assert_equal(4, s.skip_until("Bar"))
+ assert_equal(7, s.pos)
+ assert_equal(nil, s.skip_until("Qux"))
+ end
+
def test_check_until
s = create_string_scanner("Foo Bar Baz")
assert_equal("Foo", s.check_until(/Foo/))
@@ -693,6 +740,16 @@ module StringScannerTests
assert_equal(nil, s.check_until(/Qux/))
end
+ def test_check_until_string
+ omit("not implemented on TruffleRuby") if RUBY_ENGINE == "truffleruby"
+ s = create_string_scanner("Foo Bar Baz")
+ assert_equal("Foo", s.check_until("Foo"))
+ assert_equal(0, s.pos)
+ assert_equal("Foo Bar", s.check_until("Bar"))
+ assert_equal(0, s.pos)
+ assert_equal(nil, s.check_until("Qux"))
+ end
+
def test_search_full
s = create_string_scanner("Foo Bar Baz")
assert_equal(8, s.search_full(/Bar /, false, false))
@@ -705,6 +762,19 @@ module StringScannerTests
assert_equal(11, s.pos)
end
+ def test_search_full_string
+ omit("not implemented on TruffleRuby") if RUBY_ENGINE == "truffleruby"
+ s = create_string_scanner("Foo Bar Baz")
+ assert_equal(8, s.search_full("Bar ", false, false))
+ assert_equal(0, s.pos)
+ assert_equal("Foo Bar ", s.search_full("Bar ", false, true))
+ assert_equal(0, s.pos)
+ assert_equal(8, s.search_full("Bar ", true, false))
+ assert_equal(8, s.pos)
+ assert_equal("Baz", s.search_full("az", true, true))
+ assert_equal(11, s.pos)
+ end
+
def test_peek
s = create_string_scanner("test string")
assert_equal("test st", s.peek(7))