summaryrefslogtreecommitdiff
path: root/enumerator.c
diff options
context:
space:
mode:
authorknu <knu@b2dd03c8-39d4-4d8f-98ff-823fe69b080e>2018-11-24 08:38:35 +0000
committerknu <knu@b2dd03c8-39d4-4d8f-98ff-823fe69b080e>2018-11-24 08:38:35 +0000
commit045b0e54d884f2e67449caaaec670a76ed16f3d4 (patch)
treeb13fdfce8348eb8cbef757cf8c8c7eec215f749f /enumerator.c
parentc0e20037f3fb16d75b55394e7bf0a1d3ef8b7b87 (diff)
Implement Enumerator#+ and Enumerable#chain [Feature #15144]
They return an Enumerator::Chain object which is a subclass of Enumerator, which represents a chain of enumerables that works as a single enumerator. ```ruby e = (1..3).chain([4, 5]) e.to_a #=> [1, 2, 3, 4, 5] e = (1..3).each + [4, 5] e.to_a #=> [1, 2, 3, 4, 5] ``` git-svn-id: svn+ssh://ci.ruby-lang.org/ruby/trunk@65949 b2dd03c8-39d4-4d8f-98ff-823fe69b080e
Diffstat (limited to 'enumerator.c')
-rw-r--r--enumerator.c314
1 files changed, 314 insertions, 0 deletions
diff --git a/enumerator.c b/enumerator.c
index 274583a3de..10b395f66f 100644
--- a/enumerator.c
+++ b/enumerator.c
@@ -12,6 +12,7 @@
************************************************/
+#include "ruby/ruby.h"
#include "internal.h"
#include "id.h"
@@ -161,6 +162,13 @@ struct proc_entry {
static VALUE generator_allocate(VALUE klass);
static VALUE generator_init(VALUE obj, VALUE proc);
+static VALUE rb_cEnumChain;
+
+struct enum_chain {
+ VALUE enums;
+ long pos;
+};
+
static VALUE rb_cArithSeq;
/*
@@ -2412,6 +2420,300 @@ stop_result(VALUE self)
}
/*
+ * Document-class: Enumerator::Chain
+ *
+ * Enumerator::Chain is a subclass of Enumerator, which represents a
+ * chain of enumerables that works as a single enumerator.
+ *
+ * This type of objects can be created by Enumerable#chain and
+ * Enumerator#+.
+ */
+
+static void
+enum_chain_mark(void *p)
+{
+ struct enum_chain *ptr = p;
+ rb_gc_mark(ptr->enums);
+}
+
+#define enum_chain_free RUBY_TYPED_DEFAULT_FREE
+
+static size_t
+enum_chain_memsize(const void *p)
+{
+ return sizeof(struct enum_chain);
+}
+
+static const rb_data_type_t enum_chain_data_type = {
+ "chain",
+ {
+ enum_chain_mark,
+ enum_chain_free,
+ enum_chain_memsize,
+ },
+ 0, 0, RUBY_TYPED_FREE_IMMEDIATELY
+};
+
+static struct enum_chain *
+enum_chain_ptr(VALUE obj)
+{
+ struct enum_chain *ptr;
+
+ TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr);
+ if (!ptr || ptr->enums == Qundef) {
+ rb_raise(rb_eArgError, "uninitialized chain");
+ }
+ return ptr;
+}
+
+/* :nodoc: */
+static VALUE
+enum_chain_allocate(VALUE klass)
+{
+ struct enum_chain *ptr;
+ VALUE obj;
+
+ obj = TypedData_Make_Struct(klass, struct enum_chain, &enum_chain_data_type, ptr);
+ ptr->enums = Qundef;
+ ptr->pos = -1;
+
+ return obj;
+}
+
+/*
+ * call-seq:
+ * Enumerator::Chain.new(*enums) -> enum
+ *
+ * Generates a new enumerator object that iterates over the elements
+ * of given enumerable objects in sequence.
+ *
+ * e = Enumerator::Chain.new(1..3, [4, 5])
+ * e.to_a #=> [1, 2, 3, 4, 5]
+ * e.size #=> 5
+ */
+static VALUE
+enum_chain_initialize(VALUE obj, VALUE enums)
+{
+ struct enum_chain *ptr;
+
+ rb_check_frozen(obj);
+ TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr);
+
+ if (!ptr) rb_raise(rb_eArgError, "unallocated chain");
+
+ ptr->enums = rb_obj_freeze(enums);
+ ptr->pos = -1;
+
+ return obj;
+}
+
+/* :nodoc: */
+static VALUE
+enum_chain_init_copy(VALUE obj, VALUE orig)
+{
+ struct enum_chain *ptr0, *ptr1;
+
+ if (!OBJ_INIT_COPY(obj, orig)) return obj;
+ ptr0 = enum_chain_ptr(orig);
+
+ TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr1);
+
+ if (!ptr1) rb_raise(rb_eArgError, "unallocated chain");
+
+ ptr1->enums = ptr0->enums;
+ ptr1->pos = ptr0->pos;
+
+ return obj;
+}
+
+static VALUE
+enum_chain_total_size(VALUE enums)
+{
+ VALUE total = INT2FIX(0);
+
+ RARRAY_PTR_USE(enums, ptr, {
+ long i;
+
+ for (i = 0; i < RARRAY_LEN(enums); i++) {
+ VALUE size = enum_size(ptr[i]);
+
+ if (NIL_P(size) || (RB_TYPE_P(size, T_FLOAT) && isinf(NUM2DBL(size)))) {
+ return size;
+ }
+ if (!RB_INTEGER_TYPE_P(size)) {
+ return Qnil;
+ }
+
+ total = rb_funcall(total, '+', 1, size);
+ }
+ });
+
+ return total;
+}
+
+/*
+ * call-seq:
+ * obj.size -> integer
+ *
+ * Returns the total size of the enumerator chain calculated by
+ * summing up the size of each enumerable in the chain. If any of the
+ * enumerables reports its size as nil or Float::INFINITY, that value
+ * is returned as the total size.
+ */
+static VALUE
+enum_chain_size(VALUE obj)
+{
+ return enum_chain_total_size(enum_chain_ptr(obj)->enums);
+}
+
+static VALUE
+enum_chain_enum_size(VALUE obj, VALUE args, VALUE eobj)
+{
+ return enum_chain_size(obj);
+}
+
+static VALUE
+enum_chain_yield_block(VALUE arg, VALUE block, int argc, VALUE *argv)
+{
+ return rb_funcallv(block, rb_intern("call"), argc, argv);
+}
+
+static VALUE
+enum_chain_enum_no_size(VALUE obj, VALUE args, VALUE eobj)
+{
+ return Qnil;
+}
+
+/*
+ * call-seq:
+ * obj.each(*args) { |...| ... } -> obj
+ * obj.each(*args) -> enumerator
+ *
+ * Iterates over the elements of the first enumerable by calling the
+ * "each" method on it with the given arguments, then proceeds to the
+ * following enumerables in sequence until all of the enumerables are
+ * exhausted.
+ *
+ * If no block is given, returns an enumerator.
+ */
+static VALUE
+enum_chain_each(int argc, VALUE *argv, VALUE obj)
+{
+ VALUE enums, block;
+ struct enum_chain *objptr;
+
+ RETURN_SIZED_ENUMERATOR(obj, argc, argv, argc > 0 ? enum_chain_enum_no_size : enum_chain_enum_size);
+
+ objptr = enum_chain_ptr(obj);
+ enums = objptr->enums;
+ block = rb_block_proc();
+
+ RARRAY_PTR_USE(enums, ptr, {
+ long i;
+
+ for (i = 0; i < RARRAY_LEN(enums); i++) {
+ objptr->pos = i;
+ rb_block_call(ptr[i], id_each, argc, argv, enum_chain_yield_block, block);
+ }
+ });
+
+ return obj;
+}
+
+/*
+ * call-seq:
+ * obj.rewind -> obj
+ *
+ * Rewinds the enumerator chain by calling the "rewind" method on each
+ * enumerable in reverse order. Each call is performed only if the
+ * enumerable responds to the method.
+ */
+static VALUE
+enum_chain_rewind(VALUE obj)
+{
+ struct enum_chain *objptr = enum_chain_ptr(obj);
+ VALUE enums = objptr->enums;
+
+ RARRAY_PTR_USE(enums, ptr, {
+ long i;
+
+ for (i = objptr->pos; 0 <= i && i < RARRAY_LEN(enums); objptr->pos = --i) {
+ rb_check_funcall(ptr[i], id_rewind, 0, 0);
+ }
+ });
+
+ return obj;
+}
+
+static VALUE
+inspect_enum_chain(VALUE obj, VALUE dummy, int recur)
+{
+ VALUE klass = rb_obj_class(obj);
+ struct enum_chain *ptr;
+
+ TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr);
+
+ if (!ptr || ptr->enums == Qundef) {
+ return rb_sprintf("#<%"PRIsVALUE": uninitialized>", rb_class_path(klass));
+ }
+
+ if (recur) {
+ return rb_sprintf("#<%"PRIsVALUE": ...>", rb_class_path(klass));
+ }
+
+ return rb_sprintf("#<%"PRIsVALUE": %+"PRIsVALUE">", rb_class_path(klass), ptr->enums);
+}
+
+/*
+ * call-seq:
+ * obj.inspect -> string
+ *
+ * Returns a printable version of the enumerator chain.
+ */
+static VALUE
+enum_chain_inspect(VALUE obj)
+{
+ return rb_exec_recursive(inspect_enum_chain, obj, 0);
+}
+
+/*
+ * call-seq:
+ * e.chain(*enums) -> enumerator
+ *
+ * Returns an enumerator object generated from this enumerator and
+ * given enumerables.
+ *
+ * e = (1..3).chain([4, 5])
+ * e.to_a #=> [1, 2, 3, 4, 5]
+ */
+static VALUE
+enum_chain(int argc, VALUE *argv, VALUE obj)
+{
+ VALUE enums = rb_ary_new_from_values(1, &obj);
+ rb_ary_cat(enums, argv, argc);
+
+ return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums);
+}
+
+/*
+ * call-seq:
+ * e + enum -> enumerator
+ *
+ * Returns an enumerator object generated from this enumerator and a
+ * given enumerable.
+ *
+ * e = (1..3).each + [4, 5]
+ * e.to_a #=> [1, 2, 3, 4, 5]
+ */
+static VALUE
+enumerator_plus(VALUE obj, VALUE eobj)
+{
+ VALUE enums = rb_ary_new_from_args(2, obj, eobj);
+
+ return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums);
+}
+
+/*
* Document-class: Enumerator::ArithmeticSequence
*
* Enumerator::ArithmeticSequence is a subclass of Enumerator,
@@ -2907,6 +3209,8 @@ InitVM_Enumerator(void)
rb_define_method(rb_cEnumerator, "rewind", enumerator_rewind, 0);
rb_define_method(rb_cEnumerator, "inspect", enumerator_inspect, 0);
rb_define_method(rb_cEnumerator, "size", enumerator_size, 0);
+ rb_define_method(rb_cEnumerator, "+", enumerator_plus, 1);
+ rb_define_method(rb_mEnumerable, "chain", enum_chain, -1);
/* Lazy */
rb_cLazy = rb_define_class_under(rb_cEnumerator, "Lazy", rb_cEnumerator);
@@ -2960,6 +3264,16 @@ InitVM_Enumerator(void)
rb_define_method(rb_cYielder, "yield", yielder_yield, -2);
rb_define_method(rb_cYielder, "<<", yielder_yield_push, 1);
+ /* Chain */
+ rb_cEnumChain = rb_define_class_under(rb_cEnumerator, "Chain", rb_cEnumerator);
+ rb_define_alloc_func(rb_cEnumChain, enum_chain_allocate);
+ rb_define_method(rb_cEnumChain, "initialize", enum_chain_initialize, -2);
+ rb_define_method(rb_cEnumChain, "initialize_copy", enum_chain_init_copy, 1);
+ rb_define_method(rb_cEnumChain, "each", enum_chain_each, -1);
+ rb_define_method(rb_cEnumChain, "size", enum_chain_size, 0);
+ rb_define_method(rb_cEnumChain, "rewind", enum_chain_rewind, 0);
+ rb_define_method(rb_cEnumChain, "inspect", enum_chain_inspect, 0);
+
/* ArithmeticSequence */
rb_cArithSeq = rb_define_class_under(rb_cEnumerator, "ArithmeticSequence", rb_cEnumerator);
rb_undef_alloc_func(rb_cArithSeq);