diff options
-rw-r--r-- | struct.c | 59 | ||||
-rw-r--r-- | test/ruby/test_data.rb | 59 |
2 files changed, 118 insertions, 0 deletions
@@ -1834,6 +1834,63 @@ rb_data_init_copy(VALUE copy, VALUE s) /* * call-seq: + * with(**kwargs) -> instance + * + * Returns a shallow copy of +self+ --- the instance variables of + * +self+ are copied, but not the objects they reference. + * + * If the method is supplied any keyword arguments, the copy will + * be created with the respective field values updated to use the + * supplied keyword argument values. Note that it is an error to + * supply a keyword that the Data class does not have as a member. + * + * Point = Data.define(:x, :y) + * + * origin = Point.new(x: 0, y: 0) + * + * up = origin.with(x: 1) + * right = origin.with(y: 1) + * up_and_right = up.with(y: 1) + * + * p origin # #<data Point x=0, y=0> + * p up # #<data Point x=1, y=0> + * p right # #<data Point x=0, y=1> + * p up_and_right # #<data Point x=1, y=1> + * + * out = origin.with(z: 1) # ArgumentError: unknown keyword: :z + * some_point = origin.with(1, 2) # ArgumentError: expected keyword arguments, got positional arguments + * + */ + +static VALUE +rb_data_with(int argc, const VALUE *argv, VALUE self) +{ + VALUE kwargs; + rb_scan_args(argc, argv, "0:", &kwargs); + if (NIL_P(kwargs)) { + return self; + } + + VALUE copy = rb_obj_alloc(rb_obj_class(self)); + rb_struct_init_copy(copy, self); + + struct struct_hash_set_arg arg; + arg.self = copy; + arg.unknown_keywords = Qnil; + rb_hash_foreach(kwargs, struct_hash_set_i, (VALUE)&arg); + // Freeze early before potentially raising, so that we don't leave an + // unfrozen copy on the heap, which could get exposed via ObjectSpace. + RB_OBJ_FREEZE_RAW(copy); + + if (arg.unknown_keywords != Qnil) { + rb_exc_raise(rb_keyword_error_new("unknown", arg.unknown_keywords)); + } + + return copy; +} + +/* + * call-seq: * inspect -> string * to_s -> string * @@ -2205,6 +2262,8 @@ InitVM_Struct(void) rb_define_method(rb_cData, "deconstruct", rb_data_deconstruct, 0); rb_define_method(rb_cData, "deconstruct_keys", rb_data_deconstruct_keys, 1); + + rb_define_method(rb_cData, "with", rb_data_with, -1); } #undef rb_intern diff --git a/test/ruby/test_data.rb b/test/ruby/test_data.rb index 4d28da6061..3cafb365ed 100644 --- a/test/ruby/test_data.rb +++ b/test/ruby/test_data.rb @@ -158,6 +158,65 @@ class TestData < Test::Unit::TestCase assert_not_operator(o1, :eql?, o3) end + def test_with + klass = Data.define(:foo, :bar) + source = klass.new(foo: 1, bar: 2) + + # Simple + test = source.with + assert_equal(source.object_id, test.object_id) + + # Changes + test = source.with(foo: 10) + + assert_equal(1, source.foo) + assert_equal(2, source.bar) + assert_equal(source, klass.new(foo: 1, bar: 2)) + + assert_equal(10, test.foo) + assert_equal(2, test.bar) + assert_equal(test, klass.new(foo: 10, bar: 2)) + + test = source.with(foo: 10, bar: 20) + + assert_equal(1, source.foo) + assert_equal(2, source.bar) + assert_equal(source, klass.new(foo: 1, bar: 2)) + + assert_equal(10, test.foo) + assert_equal(20, test.bar) + assert_equal(test, klass.new(foo: 10, bar: 20)) + + # Keyword splat + changes = { foo: 10, bar: 20 } + test = source.with(**changes) + + assert_equal(1, source.foo) + assert_equal(2, source.bar) + assert_equal(source, klass.new(foo: 1, bar: 2)) + + assert_equal(10, test.foo) + assert_equal(20, test.bar) + assert_equal(test, klass.new(foo: 10, bar: 20)) + + # Wrong protocol + assert_raise_with_message(ArgumentError, "wrong number of arguments (given 1, expected 0)") do + source.with(10) + end + assert_raise_with_message(ArgumentError, "unknown keywords: :baz, :quux") do + source.with(foo: 1, bar: 2, baz: 3, quux: 4) + end + assert_raise_with_message(ArgumentError, "wrong number of arguments (given 1, expected 0)") do + source.with(1, bar: 2) + end + assert_raise_with_message(ArgumentError, "wrong number of arguments (given 2, expected 0)") do + source.with(1, 2) + end + assert_raise_with_message(ArgumentError, "wrong number of arguments (given 1, expected 0)") do + source.with({ bar: 2 }) + end + end + def test_memberless klass = Data.define |