summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarc-Andre Lafortune <github@marc-andre.ca>2020-07-29 16:59:06 -0400
committerMarc-Andre Lafortune <github@marc-andre.ca>2020-07-30 09:53:42 -0400
commit1b1ea7b3bc9484e6e59d716fce2965a2f39d1e3d (patch)
tree8450c5350a8773ff0cbda3c5fbfedc84a063f031
parent2bd1f827f14e06575e128a5e4928cee79592e61b (diff)
Fix Array#flatten for recursive array when given positive depth [Bug #17092]
-rw-r--r--array.c44
-rw-r--r--test/ruby/test_array.rb14
2 files changed, 37 insertions, 21 deletions
diff --git a/array.c b/array.c
index 345ee327f9..31acc579fa 100644
--- a/array.c
+++ b/array.c
@@ -6943,8 +6943,6 @@ flatten(VALUE ary, int level)
}
if (i == RARRAY_LEN(ary)) {
return ary;
- } else if (tmp == ary) {
- rb_raise(rb_eArgError, "tried to flatten recursive array");
}
result = ary_new(0, RARRAY_LEN(ary));
@@ -6955,12 +6953,14 @@ flatten(VALUE ary, int level)
rb_ary_push(stack, ary);
rb_ary_push(stack, LONG2NUM(i + 1));
- vmemo = rb_hash_new();
- RBASIC_CLEAR_CLASS(vmemo);
- memo = st_init_numtable();
- rb_hash_st_table_set(vmemo, memo);
- st_insert(memo, (st_data_t)ary, (st_data_t)Qtrue);
- st_insert(memo, (st_data_t)tmp, (st_data_t)Qtrue);
+ if (level < 0) {
+ vmemo = rb_hash_new();
+ RBASIC_CLEAR_CLASS(vmemo);
+ memo = st_init_numtable();
+ rb_hash_st_table_set(vmemo, memo);
+ st_insert(memo, (st_data_t)ary, (st_data_t)Qtrue);
+ st_insert(memo, (st_data_t)tmp, (st_data_t)Qtrue);
+ }
ary = tmp;
i = 0;
@@ -6974,20 +6974,24 @@ flatten(VALUE ary, int level)
}
tmp = rb_check_array_type(elt);
if (RBASIC(result)->klass) {
- RB_GC_GUARD(vmemo);
- st_clear(memo);
+ if (level < 0) {
+ RB_GC_GUARD(vmemo);
+ st_clear(memo);
+ }
rb_raise(rb_eRuntimeError, "flatten reentered");
}
if (NIL_P(tmp)) {
rb_ary_push(result, elt);
}
else {
- id = (st_data_t)tmp;
- if (st_is_member(memo, id)) {
- st_clear(memo);
- rb_raise(rb_eArgError, "tried to flatten recursive array");
+ if (level < 0) {
+ id = (st_data_t)tmp;
+ if (st_is_member(memo, id)) {
+ st_clear(memo);
+ rb_raise(rb_eArgError, "tried to flatten recursive array");
+ }
+ st_insert(memo, id, (st_data_t)Qtrue);
}
- st_insert(memo, id, (st_data_t)Qtrue);
rb_ary_push(stack, ary);
rb_ary_push(stack, LONG2NUM(i));
ary = tmp;
@@ -6997,14 +7001,18 @@ flatten(VALUE ary, int level)
if (RARRAY_LEN(stack) == 0) {
break;
}
- id = (st_data_t)ary;
- st_delete(memo, &id, 0);
+ if (level < 0) {
+ id = (st_data_t)ary;
+ st_delete(memo, &id, 0);
+ }
tmp = rb_ary_pop(stack);
i = NUM2LONG(tmp);
ary = rb_ary_pop(stack);
}
- st_clear(memo);
+ if (level < 0) {
+ st_clear(memo);
+ }
RBASIC_SET_CLASS(result, rb_obj_class(ary));
return result;
diff --git a/test/ruby/test_array.rb b/test/ruby/test_array.rb
index 46de8e08fc..9e36e74e71 100644
--- a/test/ruby/test_array.rb
+++ b/test/ruby/test_array.rb
@@ -886,6 +886,17 @@ class TestArray < Test::Unit::TestCase
assert_raise(NoMethodError, bug12738) { a.flatten.m }
end
+ def test_flatten_recursive
+ a = []
+ a << a
+ assert_raise(ArgumentError) { a.flatten }
+ b = [1]; c = [2, b]; b << c
+ assert_raise(ArgumentError) { b.flatten }
+
+ assert_equal([1, 2, b], b.flatten(1))
+ assert_equal([1, 2, 1, 2, 1, c], b.flatten(4))
+ end
+
def test_flatten!
a1 = @cls[ 1, 2, 3]
a2 = @cls[ 5, 6 ]
@@ -2649,9 +2660,6 @@ class TestArray < Test::Unit::TestCase
def test_flatten_error
a = []
- a << a
- assert_raise(ArgumentError) { a.flatten }
-
f = [].freeze
assert_raise(ArgumentError) { a.flatten!(1, 2) }
assert_raise(TypeError) { a.flatten!(:foo) }