summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarc-Andre Lafortune <github@marc-andre.ca>2021-01-12 17:20:50 -0500
committerMarc-Andre Lafortune <github@marc-andre.ca>2021-01-12 23:29:39 -0500
commitf48edc28dda3df962f289fd373c06a8dfeda3dc0 (patch)
tree2ea822f834a7e2cfa34d40bb3f33a135029f7420
parent75212f2fc6571bd9cab0381fbd0bde81e1b3159c (diff)
Fix method protection for modules in the ancestry chain.
[Fixes ruby/ostruct#23]
-rw-r--r--lib/ostruct.rb10
-rw-r--r--test/ostruct/test_ostruct.rb34
2 files changed, 43 insertions, 1 deletions
diff --git a/lib/ostruct.rb b/lib/ostruct.rb
index e255a5e704..e00e281d6d 100644
--- a/lib/ostruct.rb
+++ b/lib/ostruct.rb
@@ -223,7 +223,15 @@ class OpenStruct
elsif name.end_with?('!')
true
else
- method!(name).owner < OpenStruct
+ owner = method!(name).owner
+ if owner.class == ::Class
+ owner < ::OpenStruct
+ else
+ self.class.ancestors.any? do |mod|
+ return false if mod == ::OpenStruct
+ mod == owner
+ end
+ end
end
end
diff --git a/test/ostruct/test_ostruct.rb b/test/ostruct/test_ostruct.rb
index 30ea3d571c..d41eae9332 100644
--- a/test/ostruct/test_ostruct.rb
+++ b/test/ostruct/test_ostruct.rb
@@ -290,6 +290,40 @@ class TC_OpenStruct < Test::Unit::TestCase
assert_equal('hello', o.to_s)
end
+ def test_override_submodule
+ m = Module.new {
+ def foo; :protect_me; end
+ private def bar; :protect_me; end
+ def inspect; 'protect me'; end
+ }
+ m2 = Module.new {
+ def added_to_all_open_struct; :override_me; end
+ }
+ OpenStruct.class_eval do
+ include m2
+ # prepend case tbd
+ def added_to_all_open_struct_2; :override_me; end
+ end
+ c = Class.new(OpenStruct) { include m }
+ o = c.new(
+ foo: 1, bar: 2, inspect: '3', # in subclass: protected
+ table!: 4, # bang method: protected
+ each_pair: 5, to_s: 'hello', # others: not protected
+ # including those added by the user:
+ added_to_all_open_struct: 6, added_to_all_open_struct_2: 7,
+ )
+ # protected:
+ assert_equal(:protect_me, o.foo)
+ assert_equal(:protect_me, o.send(:bar))
+ assert_equal('protect me', o.inspect)
+ assert_not_equal(4, o.send(:table!))
+ # not protected:
+ assert_equal(5, o.each_pair)
+ assert_equal('hello', o.to_s)
+ assert_equal(6, o.added_to_all_open_struct)
+ assert_equal(7, o.added_to_all_open_struct_2)
+ end
+
def test_mistaken_subclass
sub = Class.new(OpenStruct) do
def [](k)