summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorknu <knu@b2dd03c8-39d4-4d8f-98ff-823fe69b080e>2018-11-24 08:38:35 +0000
committerknu <knu@b2dd03c8-39d4-4d8f-98ff-823fe69b080e>2018-11-24 08:38:35 +0000
commit045b0e54d884f2e67449caaaec670a76ed16f3d4 (patch)
treeb13fdfce8348eb8cbef757cf8c8c7eec215f749f
parentc0e20037f3fb16d75b55394e7bf0a1d3ef8b7b87 (diff)
Implement Enumerator#+ and Enumerable#chain [Feature #15144]
They return an Enumerator::Chain object which is a subclass of Enumerator, which represents a chain of enumerables that works as a single enumerator. ```ruby e = (1..3).chain([4, 5]) e.to_a #=> [1, 2, 3, 4, 5] e = (1..3).each + [4, 5] e.to_a #=> [1, 2, 3, 4, 5] ``` git-svn-id: svn+ssh://ci.ruby-lang.org/ruby/trunk@65949 b2dd03c8-39d4-4d8f-98ff-823fe69b080e
-rw-r--r--enumerator.c314
-rw-r--r--test/ruby/test_enumerator.rb116
2 files changed, 429 insertions, 1 deletions
diff --git a/enumerator.c b/enumerator.c
index 274583a3de..10b395f66f 100644
--- a/enumerator.c
+++ b/enumerator.c
@@ -12,6 +12,7 @@
************************************************/
+#include "ruby/ruby.h"
#include "internal.h"
#include "id.h"
@@ -161,6 +162,13 @@ struct proc_entry {
static VALUE generator_allocate(VALUE klass);
static VALUE generator_init(VALUE obj, VALUE proc);
+static VALUE rb_cEnumChain;
+
+struct enum_chain {
+ VALUE enums;
+ long pos;
+};
+
static VALUE rb_cArithSeq;
/*
@@ -2412,6 +2420,300 @@ stop_result(VALUE self)
}
/*
+ * Document-class: Enumerator::Chain
+ *
+ * Enumerator::Chain is a subclass of Enumerator, which represents a
+ * chain of enumerables that works as a single enumerator.
+ *
+ * This type of objects can be created by Enumerable#chain and
+ * Enumerator#+.
+ */
+
+static void
+enum_chain_mark(void *p)
+{
+ struct enum_chain *ptr = p;
+ rb_gc_mark(ptr->enums);
+}
+
+#define enum_chain_free RUBY_TYPED_DEFAULT_FREE
+
+static size_t
+enum_chain_memsize(const void *p)
+{
+ return sizeof(struct enum_chain);
+}
+
+static const rb_data_type_t enum_chain_data_type = {
+ "chain",
+ {
+ enum_chain_mark,
+ enum_chain_free,
+ enum_chain_memsize,
+ },
+ 0, 0, RUBY_TYPED_FREE_IMMEDIATELY
+};
+
+static struct enum_chain *
+enum_chain_ptr(VALUE obj)
+{
+ struct enum_chain *ptr;
+
+ TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr);
+ if (!ptr || ptr->enums == Qundef) {
+ rb_raise(rb_eArgError, "uninitialized chain");
+ }
+ return ptr;
+}
+
+/* :nodoc: */
+static VALUE
+enum_chain_allocate(VALUE klass)
+{
+ struct enum_chain *ptr;
+ VALUE obj;
+
+ obj = TypedData_Make_Struct(klass, struct enum_chain, &enum_chain_data_type, ptr);
+ ptr->enums = Qundef;
+ ptr->pos = -1;
+
+ return obj;
+}
+
+/*
+ * call-seq:
+ * Enumerator::Chain.new(*enums) -> enum
+ *
+ * Generates a new enumerator object that iterates over the elements
+ * of given enumerable objects in sequence.
+ *
+ * e = Enumerator::Chain.new(1..3, [4, 5])
+ * e.to_a #=> [1, 2, 3, 4, 5]
+ * e.size #=> 5
+ */
+static VALUE
+enum_chain_initialize(VALUE obj, VALUE enums)
+{
+ struct enum_chain *ptr;
+
+ rb_check_frozen(obj);
+ TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr);
+
+ if (!ptr) rb_raise(rb_eArgError, "unallocated chain");
+
+ ptr->enums = rb_obj_freeze(enums);
+ ptr->pos = -1;
+
+ return obj;
+}
+
+/* :nodoc: */
+static VALUE
+enum_chain_init_copy(VALUE obj, VALUE orig)
+{
+ struct enum_chain *ptr0, *ptr1;
+
+ if (!OBJ_INIT_COPY(obj, orig)) return obj;
+ ptr0 = enum_chain_ptr(orig);
+
+ TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr1);
+
+ if (!ptr1) rb_raise(rb_eArgError, "unallocated chain");
+
+ ptr1->enums = ptr0->enums;
+ ptr1->pos = ptr0->pos;
+
+ return obj;
+}
+
+static VALUE
+enum_chain_total_size(VALUE enums)
+{
+ VALUE total = INT2FIX(0);
+
+ RARRAY_PTR_USE(enums, ptr, {
+ long i;
+
+ for (i = 0; i < RARRAY_LEN(enums); i++) {
+ VALUE size = enum_size(ptr[i]);
+
+ if (NIL_P(size) || (RB_TYPE_P(size, T_FLOAT) && isinf(NUM2DBL(size)))) {
+ return size;
+ }
+ if (!RB_INTEGER_TYPE_P(size)) {
+ return Qnil;
+ }
+
+ total = rb_funcall(total, '+', 1, size);
+ }
+ });
+
+ return total;
+}
+
+/*
+ * call-seq:
+ * obj.size -> integer
+ *
+ * Returns the total size of the enumerator chain calculated by
+ * summing up the size of each enumerable in the chain. If any of the
+ * enumerables reports its size as nil or Float::INFINITY, that value
+ * is returned as the total size.
+ */
+static VALUE
+enum_chain_size(VALUE obj)
+{
+ return enum_chain_total_size(enum_chain_ptr(obj)->enums);
+}
+
+static VALUE
+enum_chain_enum_size(VALUE obj, VALUE args, VALUE eobj)
+{
+ return enum_chain_size(obj);
+}
+
+static VALUE
+enum_chain_yield_block(VALUE arg, VALUE block, int argc, VALUE *argv)
+{
+ return rb_funcallv(block, rb_intern("call"), argc, argv);
+}
+
+static VALUE
+enum_chain_enum_no_size(VALUE obj, VALUE args, VALUE eobj)
+{
+ return Qnil;
+}
+
+/*
+ * call-seq:
+ * obj.each(*args) { |...| ... } -> obj
+ * obj.each(*args) -> enumerator
+ *
+ * Iterates over the elements of the first enumerable by calling the
+ * "each" method on it with the given arguments, then proceeds to the
+ * following enumerables in sequence until all of the enumerables are
+ * exhausted.
+ *
+ * If no block is given, returns an enumerator.
+ */
+static VALUE
+enum_chain_each(int argc, VALUE *argv, VALUE obj)
+{
+ VALUE enums, block;
+ struct enum_chain *objptr;
+
+ RETURN_SIZED_ENUMERATOR(obj, argc, argv, argc > 0 ? enum_chain_enum_no_size : enum_chain_enum_size);
+
+ objptr = enum_chain_ptr(obj);
+ enums = objptr->enums;
+ block = rb_block_proc();
+
+ RARRAY_PTR_USE(enums, ptr, {
+ long i;
+
+ for (i = 0; i < RARRAY_LEN(enums); i++) {
+ objptr->pos = i;
+ rb_block_call(ptr[i], id_each, argc, argv, enum_chain_yield_block, block);
+ }
+ });
+
+ return obj;
+}
+
+/*
+ * call-seq:
+ * obj.rewind -> obj
+ *
+ * Rewinds the enumerator chain by calling the "rewind" method on each
+ * enumerable in reverse order. Each call is performed only if the
+ * enumerable responds to the method.
+ */
+static VALUE
+enum_chain_rewind(VALUE obj)
+{
+ struct enum_chain *objptr = enum_chain_ptr(obj);
+ VALUE enums = objptr->enums;
+
+ RARRAY_PTR_USE(enums, ptr, {
+ long i;
+
+ for (i = objptr->pos; 0 <= i && i < RARRAY_LEN(enums); objptr->pos = --i) {
+ rb_check_funcall(ptr[i], id_rewind, 0, 0);
+ }
+ });
+
+ return obj;
+}
+
+static VALUE
+inspect_enum_chain(VALUE obj, VALUE dummy, int recur)
+{
+ VALUE klass = rb_obj_class(obj);
+ struct enum_chain *ptr;
+
+ TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr);
+
+ if (!ptr || ptr->enums == Qundef) {
+ return rb_sprintf("#<%"PRIsVALUE": uninitialized>", rb_class_path(klass));
+ }
+
+ if (recur) {
+ return rb_sprintf("#<%"PRIsVALUE": ...>", rb_class_path(klass));
+ }
+
+ return rb_sprintf("#<%"PRIsVALUE": %+"PRIsVALUE">", rb_class_path(klass), ptr->enums);
+}
+
+/*
+ * call-seq:
+ * obj.inspect -> string
+ *
+ * Returns a printable version of the enumerator chain.
+ */
+static VALUE
+enum_chain_inspect(VALUE obj)
+{
+ return rb_exec_recursive(inspect_enum_chain, obj, 0);
+}
+
+/*
+ * call-seq:
+ * e.chain(*enums) -> enumerator
+ *
+ * Returns an enumerator object generated from this enumerator and
+ * given enumerables.
+ *
+ * e = (1..3).chain([4, 5])
+ * e.to_a #=> [1, 2, 3, 4, 5]
+ */
+static VALUE
+enum_chain(int argc, VALUE *argv, VALUE obj)
+{
+ VALUE enums = rb_ary_new_from_values(1, &obj);
+ rb_ary_cat(enums, argv, argc);
+
+ return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums);
+}
+
+/*
+ * call-seq:
+ * e + enum -> enumerator
+ *
+ * Returns an enumerator object generated from this enumerator and a
+ * given enumerable.
+ *
+ * e = (1..3).each + [4, 5]
+ * e.to_a #=> [1, 2, 3, 4, 5]
+ */
+static VALUE
+enumerator_plus(VALUE obj, VALUE eobj)
+{
+ VALUE enums = rb_ary_new_from_args(2, obj, eobj);
+
+ return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums);
+}
+
+/*
* Document-class: Enumerator::ArithmeticSequence
*
* Enumerator::ArithmeticSequence is a subclass of Enumerator,
@@ -2907,6 +3209,8 @@ InitVM_Enumerator(void)
rb_define_method(rb_cEnumerator, "rewind", enumerator_rewind, 0);
rb_define_method(rb_cEnumerator, "inspect", enumerator_inspect, 0);
rb_define_method(rb_cEnumerator, "size", enumerator_size, 0);
+ rb_define_method(rb_cEnumerator, "+", enumerator_plus, 1);
+ rb_define_method(rb_mEnumerable, "chain", enum_chain, -1);
/* Lazy */
rb_cLazy = rb_define_class_under(rb_cEnumerator, "Lazy", rb_cEnumerator);
@@ -2960,6 +3264,16 @@ InitVM_Enumerator(void)
rb_define_method(rb_cYielder, "yield", yielder_yield, -2);
rb_define_method(rb_cYielder, "<<", yielder_yield_push, 1);
+ /* Chain */
+ rb_cEnumChain = rb_define_class_under(rb_cEnumerator, "Chain", rb_cEnumerator);
+ rb_define_alloc_func(rb_cEnumChain, enum_chain_allocate);
+ rb_define_method(rb_cEnumChain, "initialize", enum_chain_initialize, -2);
+ rb_define_method(rb_cEnumChain, "initialize_copy", enum_chain_init_copy, 1);
+ rb_define_method(rb_cEnumChain, "each", enum_chain_each, -1);
+ rb_define_method(rb_cEnumChain, "size", enum_chain_size, 0);
+ rb_define_method(rb_cEnumChain, "rewind", enum_chain_rewind, 0);
+ rb_define_method(rb_cEnumChain, "inspect", enum_chain_inspect, 0);
+
/* ArithmeticSequence */
rb_cArithSeq = rb_define_class_under(rb_cEnumerator, "ArithmeticSequence", rb_cEnumerator);
rb_undef_alloc_func(rb_cArithSeq);
diff --git a/test/ruby/test_enumerator.rb b/test/ruby/test_enumerator.rb
index 0839c2c3dd..afd356105d 100644
--- a/test/ruby/test_enumerator.rb
+++ b/test/ruby/test_enumerator.rb
@@ -670,5 +670,119 @@ class TestEnumerator < Test::Unit::TestCase
assert_equal([0, 1], u.force)
assert_equal([0, 1], u.force)
end
-end
+ def test_enum_chain_and_plus
+ r = 1..5
+
+ e1 = r.chain()
+ assert_kind_of(Enumerator::Chain, e1)
+ assert_equal(5, e1.size)
+ ary = []
+ e1.each { |x| ary << x }
+ assert_equal([1, 2, 3, 4, 5], ary)
+
+ e2 = r.chain([6, 7, 8])
+ assert_kind_of(Enumerator::Chain, e2)
+ assert_equal(8, e2.size)
+ ary = []
+ e2.each { |x| ary << x }
+ assert_equal([1, 2, 3, 4, 5, 6, 7, 8], ary)
+
+ e3 = r.chain([6, 7], 8.step)
+ assert_kind_of(Enumerator::Chain, e3)
+ assert_equal(Float::INFINITY, e3.size)
+ ary = []
+ e3.take(10).each { |x| ary << x }
+ assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], ary)
+
+ # `a + b + c` should not return `Enumerator::Chain.new(a, b, c)`
+ # because it is expected that `(a + b).each` be called.
+ e4 = e2.dup
+ class << e4
+ attr_reader :each_is_called
+ def each
+ super
+ @each_is_called = true
+ end
+ end
+ e5 = e4 + 9.step
+ assert_kind_of(Enumerator::Chain, e5)
+ assert_equal(Float::INFINITY, e5.size)
+ ary = []
+ e5.take(10).each { |x| ary << x }
+ assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], ary)
+ assert_equal(true, e4.each_is_called)
+ end
+
+ def test_chained_enums
+ a = (1..5).each
+
+ e0 = Enumerator::Chain.new()
+ assert_kind_of(Enumerator::Chain, e0)
+ assert_equal(0, e0.size)
+ ary = []
+ e0.each { |x| ary << x }
+ assert_equal([], ary)
+
+ e1 = Enumerator::Chain.new(a)
+ assert_kind_of(Enumerator::Chain, e1)
+ assert_equal(5, e1.size)
+ ary = []
+ e1.each { |x| ary << x }
+ assert_equal([1, 2, 3, 4, 5], ary)
+
+ e2 = Enumerator::Chain.new(a, [6, 7, 8])
+ assert_kind_of(Enumerator::Chain, e2)
+ assert_equal(8, e2.size)
+ ary = []
+ e2.each { |x| ary << x }
+ assert_equal([1, 2, 3, 4, 5, 6, 7, 8], ary)
+
+ e3 = Enumerator::Chain.new(a, [6, 7], 8.step)
+ assert_kind_of(Enumerator::Chain, e3)
+ assert_equal(Float::INFINITY, e3.size)
+ ary = []
+ e3.take(10).each { |x| ary << x }
+ assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], ary)
+
+ e4 = Enumerator::Chain.new(a, Enumerator.new { |y| y << 6 << 7 << 8 })
+ assert_kind_of(Enumerator::Chain, e4)
+ assert_equal(nil, e4.size)
+ ary = []
+ e4.each { |x| ary << x }
+ assert_equal([1, 2, 3, 4, 5, 6, 7, 8], ary)
+
+ e5 = Enumerator::Chain.new(e1, e2)
+ assert_kind_of(Enumerator::Chain, e5)
+ assert_equal(13, e5.size)
+ ary = []
+ e5.each { |x| ary << x }
+ assert_equal([1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8], ary)
+
+ rewound = []
+ e1.define_singleton_method(:rewind) { rewound << object_id }
+ e2.define_singleton_method(:rewind) { rewound << object_id }
+ e5.rewind
+ assert_equal(rewound, [e2.object_id, e1.object_id])
+
+ rewound = []
+ a = [1]
+ e6 = Enumerator::Chain.new(a)
+ a.define_singleton_method(:rewind) { rewound << object_id }
+ e6.rewind
+ assert_equal(rewound, [])
+
+ assert_equal(
+ '#<Enumerator::Chain: [' +
+ '#<Enumerator::Chain: [' +
+ '#<Enumerator: 1..5:each>' +
+ ']>, ' +
+ '#<Enumerator::Chain: [' +
+ '#<Enumerator: 1..5:each>, ' +
+ '[6, 7, 8]' +
+ ']>' +
+ ']>',
+ e5.inspect
+ )
+ end
+end