summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKenta Murata <mrkn@mrkn.jp>2020-12-02 09:42:05 +0900
committerKenta Murata <mrkn@mrkn.jp>2020-12-02 09:44:03 +0900
commit82dc0c6aa380ea736adcd5ea54ec8f77a9269007 (patch)
tree5b7e1f3501299bc2a83c5a0459ebfd56856f71b9
parent7172272c4ca290b0b8d5bed4dd9de84eb1561303 (diff)
memory_view.c: Check availability in rb_memory_view_get
-rw-r--r--ext/-test-/memory_view/memory_view.c6
-rw-r--r--memory_view.c4
-rw-r--r--test/ruby/test_memory_view.rb13
3 files changed, 22 insertions, 1 deletions
diff --git a/ext/-test-/memory_view/memory_view.c b/ext/-test-/memory_view/memory_view.c
index 7f1f007ba4..f7c5090087 100644
--- a/ext/-test-/memory_view/memory_view.c
+++ b/ext/-test-/memory_view/memory_view.c
@@ -35,7 +35,8 @@ exportable_string_get_memory_view(VALUE obj, rb_memory_view_t *view, int flags)
static int
exportable_string_memory_view_available_p(VALUE obj)
{
- return Qtrue;
+ VALUE str = rb_ivar_get(obj, id_str);
+ return !NIL_P(str);
}
static const rb_memory_view_entry_t exportable_string_memory_view_entry = {
@@ -232,6 +233,9 @@ memory_view_ref_count_while_exporting(VALUE mod, VALUE obj, VALUE n)
static VALUE
expstr_initialize(VALUE obj, VALUE s)
{
+ if (!NIL_P(s)) {
+ Check_Type(s, T_STRING);
+ }
rb_ivar_set(obj, id_str, s);
return Qnil;
}
diff --git a/memory_view.c b/memory_view.c
index aade3a4aaf..4a4245abbf 100644
--- a/memory_view.c
+++ b/memory_view.c
@@ -592,6 +592,10 @@ rb_memory_view_get(VALUE obj, rb_memory_view_t* view, int flags)
VALUE klass = CLASS_OF(obj);
const rb_memory_view_entry_t *entry = lookup_memory_view_entry(klass);
if (entry) {
+ if (!(*entry->available_p_func)(obj)) {
+ return 0;
+ }
+
int rv = (*entry->get_func)(obj, view, flags);
if (rv) {
register_exported_object(view->obj);
diff --git a/test/ruby/test_memory_view.rb b/test/ruby/test_memory_view.rb
index 668d738974..2432f713d1 100644
--- a/test/ruby/test_memory_view.rb
+++ b/test/ruby/test_memory_view.rb
@@ -197,6 +197,13 @@ class TestMemoryView < Test::Unit::TestCase
assert_equal(expected_result, members)
end
+ def test_rb_memory_view_available_p
+ es = MemoryViewTestUtils::ExportableString.new("ruby")
+ assert_equal(true, MemoryViewTestUtils.available?(es))
+ es = MemoryViewTestUtils::ExportableString.new(nil)
+ assert_equal(false, MemoryViewTestUtils.available?(es))
+ end
+
def test_ref_count_with_exported_object
es = MemoryViewTestUtils::ExportableString.new("ruby")
assert_equal(1, MemoryViewTestUtils.ref_count_while_exporting(es, 1))
@@ -223,6 +230,12 @@ class TestMemoryView < Test::Unit::TestCase
memory_view_info)
end
+ def test_rb_memory_view_get_with_memory_view_unavailable_object
+ es = MemoryViewTestUtils::ExportableString.new(nil)
+ memory_view_info = MemoryViewTestUtils.get_memory_view_info(es)
+ assert_nil(memory_view_info)
+ end
+
def test_rb_memory_view_fill_contiguous_strides
row_major_strides = MemoryViewTestUtils.fill_contiguous_strides(3, 8, [2, 3, 4], true)
assert_equal([96, 32, 8],