summaryrefslogtreecommitdiff
path: root/array.c
diff options
context:
space:
mode:
authorDylan Thacker-Smith <dylan.smith@shopify.com>2019-09-27 12:24:25 -0400
committerNobuyoshi Nakada <nobu@ruby-lang.org>2019-09-28 01:24:24 +0900
commita1fda16b238f24cf55814ecc18f716cbfff8dd91 (patch)
treebcd7856a7f56f0a9629998abb757c7f1daf17ea7 /array.c
parent869e2dd8c8efc1e7a043c9eee82d97c47befbcc7 (diff)
Optimize Array#flatten and flatten! for already flattened arrays (#2495)
* Optimize Array#flatten and flatten! for already flattened arrays * Add benchmark for Array#flatten and Array#flatten! [Bug #16119]
Diffstat (limited to 'array.c')
-rw-r--r--array.c43
1 files changed, 33 insertions, 10 deletions
diff --git a/array.c b/array.c
index 825d9f7126..37456147b1 100644
--- a/array.c
+++ b/array.c
@@ -5122,21 +5122,43 @@ rb_ary_count(int argc, VALUE *argv, VALUE ary)
}
static VALUE
-flatten(VALUE ary, int level, int *modified)
+flatten(VALUE ary, int level)
{
- long i = 0;
+ long i;
VALUE stack, result, tmp, elt, vmemo;
st_table *memo;
st_data_t id;
- stack = ary_new(0, ARY_DEFAULT_SIZE);
+ for (i = 0; i < RARRAY_LEN(ary); i++) {
+ elt = RARRAY_AREF(ary, i);
+ tmp = rb_check_array_type(elt);
+ if (!NIL_P(tmp)) {
+ break;
+ }
+ }
+ if (i == RARRAY_LEN(ary)) {
+ return ary;
+ } else if (tmp == ary) {
+ rb_raise(rb_eArgError, "tried to flatten recursive array");
+ }
+
result = ary_new(0, RARRAY_LEN(ary));
+ ary_memcpy(result, 0, i, RARRAY_CONST_PTR_TRANSIENT(ary));
+ ARY_SET_LEN(result, i);
+
+ stack = ary_new(0, ARY_DEFAULT_SIZE);
+ rb_ary_push(stack, ary);
+ rb_ary_push(stack, LONG2NUM(i + 1));
+
vmemo = rb_hash_new();
RBASIC_CLEAR_CLASS(vmemo);
memo = st_init_numtable();
rb_hash_st_table_set(vmemo, memo);
st_insert(memo, (st_data_t)ary, (st_data_t)Qtrue);
- *modified = 0;
+ st_insert(memo, (st_data_t)tmp, (st_data_t)Qtrue);
+
+ ary = tmp;
+ i = 0;
while (1) {
while (i < RARRAY_LEN(ary)) {
@@ -5155,7 +5177,6 @@ flatten(VALUE ary, int level, int *modified)
rb_ary_push(result, elt);
}
else {
- *modified = 1;
id = (st_data_t)tmp;
if (st_lookup(memo, id, 0)) {
st_clear(memo);
@@ -5215,9 +5236,8 @@ rb_ary_flatten_bang(int argc, VALUE *argv, VALUE ary)
if (!NIL_P(lv)) level = NUM2INT(lv);
if (level == 0) return Qnil;
- result = flatten(ary, level, &mod);
- if (mod == 0) {
- ary_discard(result);
+ result = flatten(ary, level);
+ if (result == ary) {
return Qnil;
}
if (!(mod = ARY_EMBED_P(result))) rb_obj_freeze(result);
@@ -5252,7 +5272,7 @@ rb_ary_flatten_bang(int argc, VALUE *argv, VALUE ary)
static VALUE
rb_ary_flatten(int argc, VALUE *argv, VALUE ary)
{
- int mod = 0, level = -1;
+ int level = -1;
VALUE result;
if (rb_check_arity(argc, 0, 1) && !NIL_P(argv[0])) {
@@ -5260,7 +5280,10 @@ rb_ary_flatten(int argc, VALUE *argv, VALUE ary)
if (level == 0) return ary_make_shared_copy(ary);
}
- result = flatten(ary, level, &mod);
+ result = flatten(ary, level);
+ if (result == ary) {
+ result = ary_make_shared_copy(ary);
+ }
OBJ_INFECT(result, ary);
return result;