From 4c53dc970bf82e4c5fb237be4b2404bcb07496d2 Mon Sep 17 00:00:00 2001 From: Samuel Williams Date: Thu, 11 Feb 2021 19:17:54 +1300 Subject: Add hook for `Timeout.timeout`. --- include/ruby/fiber/scheduler.h | 2 ++ lib/timeout.rb | 8 +++++++- scheduler.c | 8 ++++++++ test/fiber/scheduler.rb | 28 ++++++++++++++++++++++++---- test/fiber/test_timeout.rb | 30 ++++++++++++++++++++++++++++++ 5 files changed, 71 insertions(+), 5 deletions(-) create mode 100644 test/fiber/test_timeout.rb diff --git a/include/ruby/fiber/scheduler.h b/include/ruby/fiber/scheduler.h index 36d1726b19..213a8b8a1f 100644 --- a/include/ruby/fiber/scheduler.h +++ b/include/ruby/fiber/scheduler.h @@ -22,6 +22,8 @@ VALUE rb_fiber_scheduler_make_timeout(struct timeval *timeout); VALUE rb_fiber_scheduler_close(VALUE scheduler); +VALUE rb_fiber_scheduler_timeout_raise(VALUE scheduler, VALUE duration); + VALUE rb_fiber_scheduler_kernel_sleep(VALUE scheduler, VALUE duration); VALUE rb_fiber_scheduler_kernel_sleepv(VALUE scheduler, int argc, VALUE * argv); diff --git a/lib/timeout.rb b/lib/timeout.rb index 9026ad51d6..43f26f5869 100644 --- a/lib/timeout.rb +++ b/lib/timeout.rb @@ -76,9 +76,15 @@ module Timeout # Note that this is both a method of module Timeout, so you can include # Timeout into your classes so they have a #timeout method, as well as # a module method, so you can call it directly as Timeout.timeout(). - def timeout(sec, klass = nil, message = nil) #:yield: +sec+ + def timeout(sec, klass = nil, message = nil, &block) #:yield: +sec+ return yield(sec) if sec == nil or sec.zero? + message ||= "execution expired".freeze + + if scheduler = Fiber.scheduler and scheduler.respond_to?(:timeout_raise) + return scheduler.timeout_raise(sec, klass || Error, message, &block) + end + from = "from #{caller_locations(1, 1)[0]}" if $DEBUG e = Error bl = proc do |exception| diff --git a/scheduler.c b/scheduler.c index f2b1b00fa1..3403eb1801 100644 --- a/scheduler.c +++ b/scheduler.c @@ -17,6 +17,7 @@ static ID id_close; static ID id_block; static ID id_unblock; +static ID id_timeout_raise; static ID id_kernel_sleep; static ID id_process_wait; @@ -32,6 +33,7 @@ Init_Fiber_Scheduler(void) id_block = rb_intern_const("block"); id_unblock = rb_intern_const("unblock"); + id_timeout_raise = rb_intern_const("timeout_raise"); id_kernel_sleep = rb_intern_const("kernel_sleep"); id_process_wait = rb_intern_const("process_wait"); @@ -108,6 +110,12 @@ rb_fiber_scheduler_make_timeout(struct timeval *timeout) return Qnil; } +VALUE +rb_fiber_scheduler_timeout_raise(VALUE scheduler, VALUE timeout) +{ + return rb_funcall(scheduler, id_timeout_raise, 1, timeout); +} + VALUE rb_fiber_scheduler_kernel_sleep(VALUE scheduler, VALUE timeout) { diff --git a/test/fiber/scheduler.rb b/test/fiber/scheduler.rb index f2fb304e19..8a8585fcbe 100644 --- a/test/fiber/scheduler.rb +++ b/test/fiber/scheduler.rb @@ -81,10 +81,12 @@ class Scheduler waiting, @waiting = @waiting, {} waiting.each do |fiber, timeout| - if timeout <= time - fiber.resume - else - @waiting[fiber] = timeout + if fiber.alive? + if timeout <= time + fiber.resume + else + @waiting[fiber] = timeout + end end end end @@ -127,6 +129,24 @@ class Scheduler Process.clock_gettime(Process::CLOCK_MONOTONIC) end + def timeout_raise(duration, klass, message, &block) + fiber = Fiber.current + + self.fiber do + sleep(duration) + + if fiber&.alive? + fiber.raise(klass, message) + end + end + + begin + yield(duration) + ensure + fiber = nil + end + end + def process_wait(pid, flags) # $stderr.puts [__method__, pid, flags, Fiber.current].inspect diff --git a/test/fiber/test_timeout.rb b/test/fiber/test_timeout.rb new file mode 100644 index 0000000000..b974aa0e35 --- /dev/null +++ b/test/fiber/test_timeout.rb @@ -0,0 +1,30 @@ +# frozen_string_literal: true +require 'test/unit' +require_relative 'scheduler' + +require 'timeout' + +class TestFiberTimeout < Test::Unit::TestCase + def test_timeout_raise + error = nil + + thread = Thread.new do + scheduler = Scheduler.new + Fiber.set_scheduler scheduler + + Fiber.schedule do + begin + Timeout.timeout(0.01) do + sleep(1) + end + rescue + error = $! + end + end + end + + thread.join + + assert_kind_of(Timeout::Error, error) + end +end -- cgit v1.2.3