summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- ChangeLog 7
-rw-r--r-- array.c 64
-rw-r--r-- test/ruby/test_array.rb 56
3 files changed, 108 insertions, 19 deletions
diff --git a/ChangeLog b/ChangeLog
index c7853b63b8..d5474793aa 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,3 +1,10 @@
+Fri Aug 27 07:57:34 2010 Nobuyoshi Nakada <nobu@ruby-lang.org>
+
+ * array.c (rb_ary_shuffle_bang): bail out from modification during
+ shuffle.
+
+ * array.c (rb_ary_sample): ditto.
+
Fri Aug 27 05:11:51 2010 Tanaka Akira <akr@fsij.org>
* ext/pathname/pathname.c (path_sysopen): Pathname#sysopen translated
diff --git a/array.c b/array.c
index 0e51bb0b98..8a828c506c 100644
--- a/array.c
+++ b/array.c
@@ -20,6 +20,8 @@
#endif
#include <assert.h>
+#define numberof(array) (int)(sizeof(array) / sizeof((array)[0]))
+
VALUE rb_cArray;
static ID id_cmp;
@@ -3748,8 +3750,8 @@ static VALUE sym_random;
static VALUE
rb_ary_shuffle_bang(int argc, VALUE *argv, VALUE ary)
{
- VALUE *ptr, opts, randgen = rb_cRandom;
- long i = RARRAY_LEN(ary);
+ VALUE *ptr, opts, *snap_ptr, randgen = rb_cRandom;
+ long i, snap_len;
if (OPTHASH_GIVEN_P(opts)) {
randgen = rb_hash_lookup2(opts, sym_random, randgen);
@@ -3758,10 +3760,17 @@ rb_ary_shuffle_bang(int argc, VALUE *argv, VALUE ary)
rb_raise(rb_eArgError, "wrong number of arguments (%d for 0)", argc);
}
rb_ary_modify(ary);
+ i = RARRAY_LEN(ary);
ptr = RARRAY_PTR(ary);
+ snap_len = i;
+ snap_ptr = ptr;
while (i) {
long j = RAND_UPTO(i);
- VALUE tmp = ptr[--i];
+ VALUE tmp;
+ if (snap_len != RARRAY_LEN(ary) || snap_ptr != RARRAY_PTR(ary)) {
+ rb_raise(rb_eRuntimeError, "modified during shuffle");
+ }
+ tmp = ptr[--i];
ptr[i] = ptr[j];
ptr[j] = tmp;
}
@@ -3814,37 +3823,54 @@ static VALUE
rb_ary_sample(int argc, VALUE *argv, VALUE ary)
{
VALUE nv, result, *ptr;
- VALUE opts, randgen = rb_cRandom;
+ VALUE opts, snap, randgen = rb_cRandom;
long n, len, i, j, k, idx[10];
+ double rnds[numberof(idx)];
- len = RARRAY_LEN(ary);
if (OPTHASH_GIVEN_P(opts)) {
randgen = rb_hash_lookup2(opts, sym_random, randgen);
}
+ ptr = RARRAY_PTR(ary);
+ len = RARRAY_LEN(ary);
if (argc == 0) {
if (len == 0) return Qnil;
- i = len == 1 ? 0 : RAND_UPTO(len);
+ if (len == 1) {
+ i = 0;
+ }
+ else {
+ double x = rb_random_real(randgen);
+ if ((len = RARRAY_LEN(ary)) == 0) return Qnil;
+ i = (long)(x * len);
+ }
return RARRAY_PTR(ary)[i];
}
rb_scan_args(argc, argv, "1", &nv);
n = NUM2LONG(nv);
if (n < 0) rb_raise(rb_eArgError, "negative sample number");
- ptr = RARRAY_PTR(ary);
+ if (n > len) n = len;
+ if (n <= numberof(idx)) {
+ for (i = 0; i < n; ++i) {
+ rnds[i] = rb_random_real(randgen);
+ }
+ }
len = RARRAY_LEN(ary);
+ ptr = RARRAY_PTR(ary);
if (n > len) n = len;
switch (n) {
- case 0: return rb_ary_new2(0);
+ case 0:
+ return rb_ary_new2(0);
case 1:
- return rb_ary_new4(1, &ptr[RAND_UPTO(len)]);
+ i = (long)(rnds[0] * len);
+ return rb_ary_new4(1, &ptr[i]);
case 2:
- i = RAND_UPTO(len);
- j = RAND_UPTO(len-1);
+ i = (long)(rnds[0] * len);
+ j = (long)(rnds[1] * (len-1));
if (j >= i) j++;
return rb_ary_new3(2, ptr[i], ptr[j]);
case 3:
- i = RAND_UPTO(len);
- j = RAND_UPTO(len-1);
- k = RAND_UPTO(len-2);
+ i = (long)(rnds[0] * len);
+ j = (long)(rnds[1] * (len-1));
+ k = (long)(rnds[2] * (len-2));
{
long l = j, g = i;
if (j >= i) l = i, g = ++j;
@@ -3852,12 +3878,12 @@ rb_ary_sample(int argc, VALUE *argv, VALUE ary)
}
return rb_ary_new3(3, ptr[i], ptr[j], ptr[k]);
}
- if ((size_t)n < sizeof(idx)/sizeof(idx[0])) {
+ if (n <= numberof(idx)) {
VALUE *ptr_result;
- long sorted[sizeof(idx)/sizeof(idx[0])];
- sorted[0] = idx[0] = RAND_UPTO(len);
+ long sorted[numberof(idx)];
+ sorted[0] = idx[0] = (long)(rnds[0] * len);
for (i=1; i<n; i++) {
- k = RAND_UPTO(--len);
+ k = (long)(rnds[i] * --len);
for (j = 0; j < i; ++j) {
if (k < sorted[j]) break;
++k;
@@ -3874,6 +3900,7 @@ rb_ary_sample(int argc, VALUE *argv, VALUE ary)
else {
VALUE *ptr_result;
result = rb_ary_new4(len, ptr);
+ RBASIC(result)->klass = 0;
ptr_result = RARRAY_PTR(result);
RB_GC_GUARD(ary);
for (i=0; i<n; i++) {
@@ -3882,6 +3909,7 @@ rb_ary_sample(int argc, VALUE *argv, VALUE ary)
ptr_result[j] = ptr_result[i];
ptr_result[i] = nv;
}
+ RBASIC(result)->klass = rb_cArray;
}
ARY_SET_LEN(result, n);
diff --git a/test/ruby/test_array.rb b/test/ruby/test_array.rb
index 4c3aba0589..44f71d3495 100644
--- a/test/ruby/test_array.rb
+++ b/test/ruby/test_array.rb
@@ -1901,7 +1901,6 @@ class TestArray < Test::Unit::TestCase
end
def test_shuffle_random
- cc = nil
gen = proc do
10000000
end
@@ -1911,6 +1910,16 @@ class TestArray < Test::Unit::TestCase
assert_raise(RangeError) {
[*0..2].shuffle(random: gen)
}
+
+ ary = (0...10000).to_a
+ gen = proc do
+ ary.replace([])
+ 0.5
+ end
+ class << gen
+ alias rand call
+ end
+ assert_raise(RuntimeError) {ary.shuffle!(random: gen)}
end
def test_sample
@@ -1951,6 +1960,51 @@ class TestArray < Test::Unit::TestCase
end
end
+ def test_sample_random  # Array#sample driven by a caller-supplied random: object, including one that mutates the receiver
+ ary = (0...10000).to_a  # 10,000 elements, 0 through 9999
+ assert_raise(ArgumentError) {ary.sample(1, 2, random: nil)}  # two positional arguments are rejected
+ gen0 = proc do  # gen0: side-effect-free generator that always returns 0.5
+ 0.5
+ end
+ class << gen0  # make the proc answer #rand (aliased to #call)
+ alias rand call
+ end
+ gen1 = proc do  # gen1: also returns 0.5, but empties ary every time it is called
+ ary.replace([])
+ 0.5
+ end
+ class << gen1  # same #rand alias as for gen0
+ alias rand call
+ end
+ assert_equal(5000, ary.sample(random: gen0))  # no-count form returns a single element
+ assert_nil(ary.sample(random: gen1))  # ary is emptied while the random number is drawn, so nil comes back
+ assert_equal([], ary)  # gen1's side effect is visible after the call
+ ary = (0...10000).to_a  # rebuild ary; gen1 emptied it
+ assert_equal([5000], ary.sample(1, random: gen0))
+ assert_equal([], ary.sample(1, random: gen1))  # with a count, an emptied receiver yields an empty array
+ assert_equal([], ary)
+ ary = (0...10000).to_a
+ assert_equal([5000, 4999], ary.sample(2, random: gen0))
+ assert_equal([], ary.sample(2, random: gen1))
+ assert_equal([], ary)
+ ary = (0...10000).to_a
+ assert_equal([5000, 4999, 5001], ary.sample(3, random: gen0))
+ assert_equal([], ary.sample(3, random: gen1))
+ assert_equal([], ary)
+ ary = (0...10000).to_a
+ assert_equal([5000, 4999, 5001, 4998], ary.sample(4, random: gen0))
+ assert_equal([], ary.sample(4, random: gen1))
+ assert_equal([], ary)
+ ary = (0...10000).to_a
+ assert_equal([5000, 4999, 5001, 4998, 5002, 4997, 5003, 4996, 5004, 4995], ary.sample(10, random: gen0))
+ assert_equal([], ary.sample(10, random: gen1))
+ assert_equal([], ary)
+ ary = (0...10000).to_a
+ assert_equal([5000, 0, 5001, 2, 5002, 4, 5003, 6, 5004, 8, 5005], ary.sample(11, random: gen0))  # count of 11 yields a differently shaped result than the counts up to 10 above
+ ary.sample(11, random: gen1) # implementation detail, may change in the future
+ assert_equal([], ary)  # only the emptying side effect is checked for the 11-element case
+ end
+
def test_cycle
a = []
[0, 1, 2].cycle do |i|