summaryrefslogtreecommitdiff
path: root/ext/monitor/monitor.c
blob: 43a18f58af29e24d77c1f343eafdb0deb834ebcd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
#include "ruby/ruby.h"

/* Thread::Monitor */

struct rb_monitor {
    long count;
    const VALUE owner;
    const VALUE mutex;
};

static void
monitor_mark(void *ptr)
{
    struct rb_monitor *mc = ptr;
    rb_gc_mark(mc->owner);
    rb_gc_mark(mc->mutex);
}

static size_t
monitor_memsize(const void *ptr)
{
    return sizeof(struct rb_monitor);
}

static const rb_data_type_t monitor_data_type = {
    "monitor",
    {monitor_mark, RUBY_TYPED_DEFAULT_FREE, monitor_memsize,},
    0, 0, RUBY_TYPED_FREE_IMMEDIATELY|RUBY_TYPED_WB_PROTECTED
};

static VALUE
monitor_alloc(VALUE klass)
{
    struct rb_monitor *mc;
    VALUE obj;

    obj = TypedData_Make_Struct(klass, struct rb_monitor, &monitor_data_type, mc);
    RB_OBJ_WRITE(obj, &mc->mutex, rb_mutex_new());
    RB_OBJ_WRITE(obj, &mc->owner, Qnil);
    mc->count = 0;

    return obj;
}

static struct rb_monitor *
monitor_ptr(VALUE monitor)
{
    struct rb_monitor *mc;
    TypedData_Get_Struct(monitor, struct rb_monitor, &monitor_data_type, mc);
    return mc;
}

static int
mc_owner_p(struct rb_monitor *mc)
{
    return mc->owner == rb_fiber_current();
}

static VALUE
monitor_try_enter(VALUE monitor)
{
    struct rb_monitor *mc = monitor_ptr(monitor);

    if (!mc_owner_p(mc)) {
        if (!rb_mutex_trylock(mc->mutex)) {
            return Qfalse;
        }
        RB_OBJ_WRITE(monitor, &mc->owner, rb_fiber_current());
        mc->count = 0;
    }
    mc->count += 1;
    return Qtrue;
}

static VALUE
monitor_enter(VALUE monitor)
{
    struct rb_monitor *mc = monitor_ptr(monitor);
    if (!mc_owner_p(mc)) {
        rb_mutex_lock(mc->mutex);
        RB_OBJ_WRITE(monitor, &mc->owner, rb_fiber_current());
        mc->count = 0;
    }
    mc->count++;
    return Qnil;
}

static VALUE
monitor_check_owner(VALUE monitor)
{
    struct rb_monitor *mc = monitor_ptr(monitor);
    if (!mc_owner_p(mc)) {
        rb_raise(rb_eThreadError, "current fiber not owner");
    }
    return Qnil;
}

static VALUE
monitor_exit(VALUE monitor)
{
    monitor_check_owner(monitor);

    struct rb_monitor *mc = monitor_ptr(monitor);

    if (mc->count <= 0) rb_bug("monitor_exit: count:%d\n", (int)mc->count);
    mc->count--;

    if (mc->count == 0) {
        RB_OBJ_WRITE(monitor, &mc->owner, Qnil);
        rb_mutex_unlock(mc->mutex);
    }
    return Qnil;
}

static VALUE
monitor_locked_p(VALUE monitor)
{
    struct rb_monitor *mc = monitor_ptr(monitor);
    return rb_mutex_locked_p(mc->mutex);
}

static VALUE
monitor_owned_p(VALUE monitor)
{
    struct rb_monitor *mc = monitor_ptr(monitor);
    return (rb_mutex_locked_p(mc->mutex) && mc_owner_p(mc)) ? Qtrue : Qfalse;
}

static VALUE
monitor_exit_for_cond(VALUE monitor)
{
    struct rb_monitor *mc = monitor_ptr(monitor);
    long cnt = mc->count;
    RB_OBJ_WRITE(monitor, &mc->owner, Qnil);
    mc->count = 0;
    return LONG2NUM(cnt);
}

struct wait_for_cond_data {
    VALUE monitor;
    VALUE cond;
    VALUE timeout;
    VALUE count;
};

static VALUE
monitor_wait_for_cond_body(VALUE v)
{
    struct wait_for_cond_data *data = (struct wait_for_cond_data *)v;
    struct rb_monitor *mc = monitor_ptr(data->monitor);
    // cond.wait(monitor.mutex, timeout)
    rb_funcall(data->cond, rb_intern("wait"), 2, mc->mutex, data->timeout);
    return Qtrue;
}

static VALUE
monitor_enter_for_cond(VALUE v)
{
    // assert(rb_mutex_owned_p(mc->mutex) == Qtrue)
    // but rb_mutex_owned_p is not exported...

    struct wait_for_cond_data *data = (struct wait_for_cond_data *)v;
    struct rb_monitor *mc = monitor_ptr(data->monitor);
    RB_OBJ_WRITE(data->monitor, &mc->owner, rb_fiber_current());
    mc->count = NUM2LONG(data->count);
    return Qnil;
}

static VALUE
monitor_wait_for_cond(VALUE monitor, VALUE cond, VALUE timeout)
{
    VALUE count = monitor_exit_for_cond(monitor);
    struct wait_for_cond_data data = {
        monitor,
        cond,
        timeout,
        count,
    };

    return rb_ensure(monitor_wait_for_cond_body, (VALUE)&data,
                     monitor_enter_for_cond, (VALUE)&data);
}

static VALUE
monitor_sync_body(VALUE monitor)
{
    return rb_yield_values(0);
}

static VALUE
monitor_sync_ensure(VALUE monitor)
{
    return monitor_exit(monitor);
}

static VALUE
monitor_synchronize(VALUE monitor)
{
    monitor_enter(monitor);
    return rb_ensure(monitor_sync_body, monitor, monitor_sync_ensure, monitor);
}

void
Init_monitor(void)
{
#if HAVE_RB_EXT_RACTOR_SAFE
    rb_ext_ractor_safe(true);
#endif

    VALUE rb_cMonitor = rb_define_class("Monitor", rb_cObject);
    rb_define_alloc_func(rb_cMonitor, monitor_alloc);

    rb_define_method(rb_cMonitor, "try_enter", monitor_try_enter, 0);
    rb_define_method(rb_cMonitor, "enter", monitor_enter, 0);
    rb_define_method(rb_cMonitor, "exit", monitor_exit, 0);
    rb_define_method(rb_cMonitor, "synchronize", monitor_synchronize, 0);

    /* internal methods for MonitorMixin */
    rb_define_method(rb_cMonitor, "mon_locked?", monitor_locked_p, 0);
    rb_define_method(rb_cMonitor, "mon_check_owner", monitor_check_owner, 0);
    rb_define_method(rb_cMonitor, "mon_owned?", monitor_owned_p, 0);

    /* internal methods for MonitorMixin::ConditionVariable */
    rb_define_method(rb_cMonitor, "wait_for_cond", monitor_wait_for_cond, 2);
}