summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ChangeLog5
-rw-r--r--io.c74
-rw-r--r--test/ruby/test_io.rb26
3 files changed, 97 insertions, 8 deletions
diff --git a/ChangeLog b/ChangeLog
index 0627ef6669..492aee213d 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,3 +1,8 @@
+Sun Apr 20 04:45:13 2008 Tanaka Akira <akr@fsij.org>
+
+ * io.c (copy_stream_body): use readpartial and write method for
+ non-IOs such as StringIO and ARGF.
+
Fri Apr 18 20:57:33 2008 Yusuke Endoh <mame@tsg.ne.jp>
* test/ruby/test_array.rb: add tests to achieve over 95% test coverage
diff --git a/io.c b/io.c
index f3a7c597dc..fa8246cf08 100644
--- a/io.c
+++ b/io.c
@@ -124,7 +124,7 @@ VALUE rb_default_rs;
static VALUE argf;
-static ID id_write, id_read, id_getc, id_flush, id_encode;
+static ID id_write, id_read, id_getc, id_flush, id_encode, id_readpartial;
struct timeval rb_time_interval(VALUE);
@@ -6250,10 +6250,11 @@ rb_io_s_read(int argc, VALUE *argv, VALUE io)
struct copy_stream_struct {
VALUE src;
VALUE dst;
+ off_t copy_length; /* (off_t)-1 if not specified */
+ off_t src_offset; /* (off_t)-1 if not specified */
+
int src_fd;
int dst_fd;
- off_t copy_length;
- off_t src_offset;
int close_src;
int close_dst;
off_t total;
@@ -6567,6 +6568,49 @@ finish:
}
static VALUE
+copy_stream_fallback_body(VALUE arg)
+{
+ struct copy_stream_struct *stp = (struct copy_stream_struct *)arg;
+ const int buflen = 16*1024;
+ VALUE n;
+ VALUE buf = rb_str_buf_new(buflen);
+ if (stp->copy_length == (off_t)-1) {
+ while (1) {
+ rb_funcall(stp->src, id_readpartial,
+ 2, INT2FIX(buflen), buf);
+ n = rb_io_write(stp->dst, buf);
+ stp->total += NUM2LONG(n);
+ }
+ }
+ else {
+ long rest = stp->copy_length;
+ while (0 < rest) {
+ long l = buflen < rest ? buflen : rest;
+ long numwrote;
+ rb_funcall(stp->src, id_readpartial,
+ 2, INT2FIX(l), buf);
+ n = rb_io_write(stp->dst, buf);
+ numwrote = NUM2LONG(n);
+ stp->total += numwrote;
+ rest -= numwrote;
+ }
+ }
+ return Qnil;
+}
+
+static VALUE
+copy_stream_fallback(struct copy_stream_struct *stp)
+{
+ if (stp->src_offset != (off_t)-1) {
+ rb_raise(rb_eArgError, "cannot specify src_offset");
+ }
+ rb_rescue2(copy_stream_fallback_body, (VALUE)stp,
+ (VALUE (*) (ANYARGS))0, (VALUE)0,
+ rb_eEOFError, (VALUE)0);
+ return Qnil;
+}
+
+static VALUE
copy_stream_body(VALUE arg)
{
struct copy_stream_struct *stp = (struct copy_stream_struct *)arg;
@@ -6577,6 +6621,21 @@ copy_stream_body(VALUE arg)
stp->th = GET_THREAD();
+ stp->total = 0;
+
+ if (stp->src == argf ||
+ stp->dst == argf ||
+ !(TYPE(stp->src) == T_FILE ||
+ rb_respond_to(stp->src, rb_intern("to_io")) ||
+ TYPE(stp->src) == T_STRING ||
+ rb_respond_to(stp->src, rb_intern("to_path"))) ||
+ !(TYPE(stp->dst) == T_FILE ||
+ rb_respond_to(stp->dst, rb_intern("to_io")) ||
+ TYPE(stp->dst) == T_STRING ||
+ rb_respond_to(stp->dst, rb_intern("to_path")))) {
+ return copy_stream_fallback(stp);
+ }
+
src_io = rb_check_convert_type(stp->src, T_FILE, "IO", "to_io");
if (!NIL_P(src_io)) {
GetOpenFile(src_io, src_fptr);
@@ -6616,8 +6675,6 @@ copy_stream_body(VALUE arg)
}
stp->dst_fd = dst_fd;
- stp->total = 0;
-
if (src_fptr && dst_fptr && src_fptr->rbuf_len && dst_fptr->wbuf_len) {
long len = src_fptr->rbuf_len;
VALUE str;
@@ -6708,6 +6765,9 @@ rb_io_s_copy_stream(int argc, VALUE *argv, VALUE io)
rb_scan_args(argc, argv, "22", &src, &dst, &length, &src_offset);
+ st.src = src;
+ st.dst = dst;
+
if (NIL_P(length))
st.copy_length = (off_t)-1;
else
@@ -6718,9 +6778,6 @@ rb_io_s_copy_stream(int argc, VALUE *argv, VALUE io)
else
st.src_offset = NUM2OFFT(src_offset);
- st.src = src;
- st.dst = dst;
-
rb_ensure(copy_stream_body, (VALUE)&st, copy_stream_finalize, (VALUE)&st);
return OFFT2NUM(st.total);
@@ -7344,6 +7401,7 @@ Init_IO(void)
id_getc = rb_intern("getc");
id_flush = rb_intern("flush");
id_encode = rb_intern("encode");
+ id_readpartial = rb_intern("readpartial");
rb_define_global_function("syscall", rb_f_syscall, -1);
diff --git a/test/ruby/test_io.rb b/test/ruby/test_io.rb
index 0cb8a775e2..d2292446fd 100644
--- a/test/ruby/test_io.rb
+++ b/test/ruby/test_io.rb
@@ -2,6 +2,7 @@ require 'test/unit'
require 'tmpdir'
require 'io/nonblock'
require 'socket'
+require 'stringio'
class TestIO < Test::Unit::TestCase
def test_gets_rs
@@ -393,8 +394,33 @@ class TestIO < Test::Unit::TestCase
result = t.value
assert_equal(megacontent, result)
}
+ }
+ end
+
+ def test_copy_stream_strio
+ src = StringIO.new("abcd")
+ dst = StringIO.new
+ ret = IO.copy_stream(src, dst)
+ assert_equal(4, ret)
+ assert_equal("abcd", dst.string)
+ assert_equal(4, src.pos)
+ end
+ def test_copy_stream_strio_len
+ src = StringIO.new("abcd")
+ dst = StringIO.new
+ ret = IO.copy_stream(src, dst, 3)
+ assert_equal(3, ret)
+ assert_equal("abc", dst.string)
+ assert_equal(3, src.pos)
+ end
+ def test_copy_stream_strio_off
+ src = StringIO.new("abcd")
+ dst = StringIO.new
+ assert_raise(ArgumentError) {
+ IO.copy_stream(src, dst, 3, 1)
}
end
+
end