summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--NEWS5
-rw-r--r--array.c70
-rw-r--r--test/ruby/test_array.rb38
3 files changed, 113 insertions, 0 deletions
diff --git a/NEWS b/NEWS
index f4234d0fbb..052b7b9a0c 100644
--- a/NEWS
+++ b/NEWS
@@ -39,6 +39,11 @@ sufficient information, see the ChangeLog file or Redmine
* `Array`
+ * New methods:
+
+ * Added `Array#union` instance method.
+ [Feature #14097]
+
* Aliased methods:
* `Array#filter` is a new alias for `Array#select`.
diff --git a/array.c b/array.c
index 3c36bb66e6..ae3cc49e26 100644
--- a/array.c
+++ b/array.c
@@ -4322,6 +4322,75 @@ rb_ary_or(VALUE ary1, VALUE ary2)
/*
* call-seq:
+ * ary.union(other_ary1, other_ary2,...) -> ary
+ *
+ * Set Union --- Returns a new array by joining +other_ary+s with +self+,
+ * excluding any duplicates and preserving the order from the given arrays.
+ *
+ * It compares elements using their #hash and #eql? methods for efficiency.
+ *
+ * [ "a", "b", "c" ].union( [ "c", "d", "a" ] ) #=> [ "a", "b", "c", "d" ]
+ * [ "a" ].union( [ ["e", "b"], ["a", "c", "b"] ] ) #=> [ "a", "e", "b", "c" ]
+ * [ "a" ].union #=> [ "a" ]
+ *
+ * See also Array#|.
+ */
+
+static VALUE
+rb_ary_union_multi(int argc, VALUE *argv, VALUE ary)
+{
+ int i;
+ long j;
+ long sum;
+ VALUE hash, ary_union;
+
+ sum = RARRAY_LEN(ary);
+ for (i = 0; i < argc; i++){
+ argv[i] = to_ary(argv[i]);
+ sum += RARRAY_LEN(argv[i]);
+ }
+
+ if (sum <= SMALL_ARRAY_LEN) {
+ ary_union = rb_ary_new();
+
+ for (j = 0; j < RARRAY_LEN(ary); j++) {
+ VALUE elt = rb_ary_elt(ary, j);
+ if (rb_ary_includes_by_eql(ary_union, elt)) continue;
+ rb_ary_push(ary_union, elt);
+ }
+
+ for (i = 0; i < argc; i++) {
+ VALUE argv_i = argv[i];
+
+ for (j = 0; j < RARRAY_LEN(argv_i); j++) {
+ VALUE elt = rb_ary_elt(argv_i, j);
+ if (rb_ary_includes_by_eql(ary_union, elt)) continue;
+ rb_ary_push(ary_union, elt);
+ }
+ }
+ return ary_union;
+ }
+
+ hash = ary_make_hash(ary);
+
+ for (i = 0; i < argc; i++) {
+ VALUE argv_i = argv[i];
+
+ for (j = 0; j < RARRAY_LEN(argv_i); j++) {
+ VALUE elt = RARRAY_AREF(argv_i, j);
+ if (!st_update(RHASH_TBL_RAW(hash), (st_data_t)elt, ary_hash_orset, (st_data_t)elt)) {
+ RB_OBJ_WRITTEN(hash, Qundef, elt);
+ }
+ }
+ }
+
+ ary_union = rb_hash_values(hash);
+ ary_recycle_hash(hash);
+ return ary_union;
+}
+
+/*
+ * call-seq:
* ary.max -> obj
* ary.max {|a, b| block} -> obj
* ary.max(n) -> array
@@ -6296,6 +6365,7 @@ Init_Array(void)
rb_define_method(rb_cArray, "first", rb_ary_first, -1);
rb_define_method(rb_cArray, "last", rb_ary_last, -1);
rb_define_method(rb_cArray, "concat", rb_ary_concat_multi, -1);
+ rb_define_method(rb_cArray, "union", rb_ary_union_multi, -1);
rb_define_method(rb_cArray, "<<", rb_ary_push, 1);
rb_define_method(rb_cArray, "push", rb_ary_push_m, -1);
rb_define_alias(rb_cArray, "append", "push");
diff --git a/test/ruby/test_array.rb b/test/ruby/test_array.rb
index 9962745f1c..e1a77b73b1 100644
--- a/test/ruby/test_array.rb
+++ b/test/ruby/test_array.rb
@@ -1884,6 +1884,44 @@ class TestArray < Test::Unit::TestCase
assert_equal((1..128).to_a, b)
end
+ def test_union
+ assert_equal(@cls[], @cls[].union(@cls[]))
+ assert_equal(@cls[1], @cls[1].union(@cls[]))
+ assert_equal(@cls[1], @cls[].union(@cls[1]))
+ assert_equal(@cls[1], @cls[].union(@cls[], @cls[1]))
+ assert_equal(@cls[1], @cls[1].union(@cls[1]))
+ assert_equal(@cls[1], @cls[1].union(@cls[1], @cls[1], @cls[1]))
+
+ assert_equal(@cls[1,2], @cls[1].union(@cls[2]))
+ assert_equal(@cls[1,2], @cls[1, 1].union(@cls[2, 2]))
+ assert_equal(@cls[1,2], @cls[1, 2].union(@cls[1, 2]))
+ assert_equal(@cls[1,2], @cls[1, 1].union(@cls[1, 1], @cls[1, 2], @cls[2, 1], @cls[2, 2, 2]))
+
+ a = %w(a b c)
+ b = %w(a b c d e)
+ c = a.union(b)
+ assert_equal(c, b)
+ assert_not_same(c, b)
+ assert_equal(%w(a b c), a)
+ assert_equal(%w(a b c d e), b)
+ assert(a.none?(&:frozen?))
+ assert(b.none?(&:frozen?))
+ assert(c.none?(&:frozen?))
+ end
+
+ def test_union_big_array
+ assert_equal(@cls[1,2], (@cls[1]*64).union(@cls[2]*64))
+ assert_equal(@cls[1,2,3], (@cls[1, 2]*64).union(@cls[1, 2]*64, @cls[3]*60))
+
+ a = (1..64).to_a
+ b = (1..128).to_a
+ c = a | b
+ assert_equal(c, b)
+ assert_not_same(c, b)
+ assert_equal((1..64).to_a, a)
+ assert_equal((1..128).to_a, b)
+ end
+
def test_combination
a = @cls[]
assert_equal(1, a.combination(0).size)