From 135a0d26a11cf0d0825e4f80e00e2b430555d831 Mon Sep 17 00:00:00 2001 From: Samuel Williams Date: Wed, 24 May 2023 10:17:35 +0900 Subject: Improvements to `IO::Buffer` `read`/`write`/`pread`/`pwrite`. (#7826) - Fix IO::Buffer `read`/`write` to use a minimum length. --- io_buffer.c | 247 +++++++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 170 insertions(+), 77 deletions(-) (limited to 'io_buffer.c') diff --git a/io_buffer.c b/io_buffer.c index 04838f0033..7790b65876 100644 --- a/io_buffer.c +++ b/io_buffer.c @@ -991,17 +991,23 @@ io_buffer_readonly_p(VALUE self) return RBOOL(rb_io_buffer_readonly_p(self)); } -VALUE -rb_io_buffer_lock(VALUE self) +static void +io_buffer_lock(struct rb_io_buffer *buffer) { - struct rb_io_buffer *buffer = NULL; - TypedData_Get_Struct(self, struct rb_io_buffer, &rb_io_buffer_type, buffer); - if (buffer->flags & RB_IO_BUFFER_LOCKED) { rb_raise(rb_eIOBufferLockedError, "Buffer already locked!"); } buffer->flags |= RB_IO_BUFFER_LOCKED; +} + +VALUE +rb_io_buffer_lock(VALUE self) +{ + struct rb_io_buffer *buffer = NULL; + TypedData_Get_Struct(self, struct rb_io_buffer, &rb_io_buffer_type, buffer); + + io_buffer_lock(buffer); return self; } @@ -2433,18 +2439,120 @@ io_buffer_default_size(size_t page_size) return platform_agnostic_default_size; } +struct io_buffer_blocking_region_argument { + struct rb_io_buffer *buffer; + rb_blocking_function_t *function; + void *data; + int descriptor; +}; + +static VALUE +io_buffer_blocking_region_begin(VALUE _argument) +{ + struct io_buffer_blocking_region_argument *argument = (void*)_argument; + + return rb_thread_io_blocking_region(argument->function, argument->data, argument->descriptor); +} + +static VALUE +io_buffer_blocking_region_ensure(VALUE _argument) +{ + struct io_buffer_blocking_region_argument *argument = (void*)_argument; + + io_buffer_unlock(argument->buffer); + + return Qnil; +} + +static VALUE +io_buffer_blocking_region(struct rb_io_buffer *buffer, rb_blocking_function_t *function, void *data, int descriptor) +{ + struct io_buffer_blocking_region_argument argument = { + .buffer = buffer, + .function = function, + .data = data, + .descriptor = descriptor, + }; + + // If the buffer is already locked, we can skip the ensure (unlock): + if (buffer->flags & RB_IO_BUFFER_LOCKED) { + return io_buffer_blocking_region_begin((VALUE)&argument); + } + else { + // The buffer should be locked for the duration of the blocking region: + io_buffer_lock(buffer); + + return rb_ensure(io_buffer_blocking_region_begin, (VALUE)&argument, io_buffer_blocking_region_ensure, (VALUE)&argument); + } +} + +static inline struct rb_io_buffer * +io_buffer_extract_arguments(VALUE self, int argc, VALUE argv[], size_t *length, size_t *offset) +{ + struct rb_io_buffer *buffer = NULL; + TypedData_Get_Struct(self, struct rb_io_buffer, &rb_io_buffer_type, buffer); + + *offset = 0; + if (argc >= 2) { + if (rb_int_negative_p(argv[1])) { + rb_raise(rb_eArgError, "Offset can't be negative!"); + } + + *offset = NUM2SIZET(argv[1]); + } + + if (argc >= 1 && !NIL_P(argv[0])) { + if (rb_int_negative_p(argv[0])) { + rb_raise(rb_eArgError, "Length can't be negative!"); + } + + *length = NUM2SIZET(argv[0]); + } + else { + *length = buffer->size - *offset; + } + + return buffer; +} + struct io_buffer_read_internal_argument { int descriptor; - void *base; + + // The base pointer to read from: + char *base; + // The size of the buffer: size_t size; + + // The minimum number of bytes to read: + size_t length; }; static VALUE io_buffer_read_internal(void *_argument) { + size_t total = 0; struct io_buffer_read_internal_argument *argument = _argument; - ssize_t result = read(argument->descriptor, argument->base, argument->size); - return rb_fiber_scheduler_io_result(result, errno); + + while (true) { + ssize_t result = read(argument->descriptor, argument->base, argument->size); + + if (result < 0) { + return rb_fiber_scheduler_io_result(result, errno); + } + else if (result == 0) { + return rb_fiber_scheduler_io_result(total, 0); + } + else { + total += result; + + if (total >= argument->length) { + return rb_fiber_scheduler_io_result(total, 0); + } + + argument->base = argument->base + result; + argument->size = argument->size - result; + } + } } VALUE @@ -2475,10 +2583,11 @@ rb_io_buffer_read(VALUE self, VALUE io, size_t length, size_t offset) struct io_buffer_read_internal_argument argument = { .descriptor = descriptor, .base = base, - .size = length, + .size = size, + .length = length, }; - return rb_thread_io_blocking_region(io_buffer_read_internal, &argument, descriptor); + return io_buffer_blocking_region(buffer, io_buffer_read_internal, &argument, descriptor); } /* @@ -2508,23 +2617,12 @@ rb_io_buffer_read(VALUE self, VALUE io, size_t length, size_t offset) static VALUE io_buffer_read(int argc, VALUE *argv, VALUE self) { - rb_check_arity(argc, 2, 3); + rb_check_arity(argc, 1, 3); VALUE io = argv[0]; - if (rb_int_negative_p(argv[1])) { - rb_raise(rb_eArgError, "Length can't be negative!"); - } - size_t length = NUM2SIZET(argv[1]); - - size_t offset = 0; - if (argc >= 3) { - if (rb_int_negative_p(argv[2])) { - rb_raise(rb_eArgError, "Offset can't be negative!"); - } - - offset = NUM2SIZET(argv[2]); - } + size_t length, offset; + io_buffer_extract_arguments(self, argc-1, argv+1, &length, &offset); return rb_io_buffer_read(self, io, length, offset); } @@ -2597,7 +2695,7 @@ rb_io_buffer_pread(VALUE self, VALUE io, rb_off_t from, size_t length, size_t of .offset = from, }; - return rb_thread_io_blocking_region(io_buffer_pread_internal, &argument, descriptor); + return io_buffer_blocking_region(buffer, io_buffer_pread_internal, &argument, descriptor); } /* @@ -2629,41 +2727,55 @@ rb_io_buffer_pread(VALUE self, VALUE io, rb_off_t from, size_t length, size_t of static VALUE io_buffer_pread(int argc, VALUE *argv, VALUE self) { - rb_check_arity(argc, 3, 4); + rb_check_arity(argc, 2, 4); VALUE io = argv[0]; rb_off_t from = NUM2OFFT(argv[1]); - size_t length; - if (rb_int_negative_p(argv[2])) { - rb_raise(rb_eArgError, "Length can't be negative!"); - } - length = NUM2SIZET(argv[2]); - - size_t offset = 0; - if (argc >= 4) { - if (rb_int_negative_p(argv[3])) { - rb_raise(rb_eArgError, "Offset can't be negative!"); - } - - offset = NUM2SIZET(argv[3]); - } + size_t length, offset; + io_buffer_extract_arguments(self, argc-2, argv+2, &length, &offset); return rb_io_buffer_pread(self, io, from, length, offset); } struct io_buffer_write_internal_argument { int descriptor; - const void *base; + + // The base pointer to write from: + const char *base; + // The size of the buffer: size_t size; + + // The minimum length to write: + size_t length; }; static VALUE io_buffer_write_internal(void *_argument) { + size_t total = 0; struct io_buffer_write_internal_argument *argument = _argument; - ssize_t result = write(argument->descriptor, argument->base, argument->size); - return rb_fiber_scheduler_io_result(result, errno); + + while (true) { + ssize_t result = write(argument->descriptor, argument->base, argument->size); + + if (result < 0) { + return rb_fiber_scheduler_io_result(result, errno); + } + else if (result == 0) { + return rb_fiber_scheduler_io_result(total, 0); + } + else { + total += result; + + if (total >= argument->length) { + return rb_fiber_scheduler_io_result(total, 0); + } + + argument->base = argument->base + result; + argument->size = argument->size - result; + } + } } VALUE @@ -2694,18 +2806,22 @@ rb_io_buffer_write(VALUE self, VALUE io, size_t length, size_t offset) struct io_buffer_write_internal_argument argument = { .descriptor = descriptor, .base = base, - .size = length, + .size = size, + .length = length, }; - return rb_thread_io_blocking_region(io_buffer_write_internal, &argument, descriptor); + return io_buffer_blocking_region(buffer, io_buffer_write_internal, &argument, descriptor); } /* - * call-seq: write(io, length, [offset]) -> written length or -errno + * call-seq: write(io, [length, [offset]]) -> written length or -errno * - * Writes +length+ bytes from buffer into +io+, starting at + * Writes at least +length+ bytes from buffer into +io+, starting at * +offset+ in the buffer. If an error occurs, return -errno. * + * If +length+ is not given or nil, the whole buffer is written, minus + * the offset. If +length+ is zero, write will be called once. + * * If +offset+ is not given, the bytes are taken from the beginning * of the buffer. * @@ -2717,23 +2833,12 @@ rb_io_buffer_write(VALUE self, VALUE io, size_t length, size_t offset) static VALUE io_buffer_write(int argc, VALUE *argv, VALUE self) { - rb_check_arity(argc, 2, 3); + rb_check_arity(argc, 1, 3); VALUE io = argv[0]; - if (rb_int_negative_p(argv[1])) { - rb_raise(rb_eArgError, "Length can't be negative!"); - } - size_t length = NUM2SIZET(argv[1]); - - size_t offset = 0; - if (argc >= 3) { - if (rb_int_negative_p(argv[2])) { - rb_raise(rb_eArgError, "Offset can't be negative!"); - } - - offset = NUM2SIZET(argv[2]); - } + size_t length, offset; + io_buffer_extract_arguments(self, argc-1, argv+1, &length, &offset); return rb_io_buffer_write(self, io, length, offset); } @@ -2806,7 +2911,7 @@ rb_io_buffer_pwrite(VALUE self, VALUE io, rb_off_t from, size_t length, size_t o .offset = from, }; - return rb_thread_io_blocking_region(io_buffer_pwrite_internal, &argument, descriptor); + return io_buffer_blocking_region(buffer, io_buffer_pwrite_internal, &argument, descriptor); } /* @@ -2828,25 +2933,13 @@ rb_io_buffer_pwrite(VALUE self, VALUE io, rb_off_t from, size_t length, size_t o static VALUE io_buffer_pwrite(int argc, VALUE *argv, VALUE self) { - rb_check_arity(argc, 3, 4); + rb_check_arity(argc, 2, 4); VALUE io = argv[0]; rb_off_t from = NUM2OFFT(argv[1]); - size_t length; - if (rb_int_negative_p(argv[2])) { - rb_raise(rb_eArgError, "Length can't be negative!"); - } - length = NUM2SIZET(argv[2]); - - size_t offset = 0; - if (argc >= 4) { - if (rb_int_negative_p(argv[3])) { - rb_raise(rb_eArgError, "Offset can't be negative!"); - } - - offset = NUM2SIZET(argv[3]); - } + size_t length, offset; + io_buffer_extract_arguments(self, argc-2, argv+2, &length, &offset); return rb_io_buffer_pwrite(self, io, from, length, offset); } -- cgit v1.2.3