summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPhilip Mueller <mail@philip.in-aachen.net>2024-02-21 12:30:22 -0500
committergit <svn-admin@ruby-lang.org>2024-04-23 13:53:23 +0000
commitf7d1699f6714d8fe14ed92272584f68a79995e64 (patch)
tree0adbaa5b3b6f1c4e8db9f6fbeadf2b706e7a93c9
parent87b829aa942089c7614470184f02aedec9d72ec9 (diff)
[ruby/prism] Implement case equality on nodes
https://github.com/ruby/prism/commit/dc121e4fdf
-rw-r--r--prism/templates/lib/prism/node.rb.erb23
-rw-r--r--test/prism/ruby_api_test.rb15
2 files changed, 34 insertions, 4 deletions
diff --git a/prism/templates/lib/prism/node.rb.erb b/prism/templates/lib/prism/node.rb.erb
index 6b5a285315..f869a841c5 100644
--- a/prism/templates/lib/prism/node.rb.erb
+++ b/prism/templates/lib/prism/node.rb.erb
@@ -219,10 +219,10 @@ module Prism
def deconstruct_keys(keys)
{ <%= (node.fields.map { |field| "#{field.name}: #{field.name}" } + ["location: location"]).join(", ") %> }
end
-
<%- node.fields.each do |field| -%>
+
<%- if field.comment.nil? -%>
- # <%= "private " if field.is_a?(Prism::Template::FlagsField) %>attr_reader <%= field.name %>: <%= field.rbs_class %>
+ # <%= "protected " if field.is_a?(Prism::Template::FlagsField) %>attr_reader <%= field.name %>: <%= field.rbs_class %>
<%- else -%>
<%- field.each_comment_line do |line| -%>
#<%= line %>
@@ -248,9 +248,8 @@ module Prism
end
end
<%- else -%>
- attr_reader :<%= field.name -%><%= "\n private :#{field.name}" if field.is_a?(Prism::Template::FlagsField) %>
+ attr_reader :<%= field.name -%><%= "\n protected :#{field.name}" if field.is_a?(Prism::Template::FlagsField) %>
<%- end -%>
-
<%- end -%>
<%- node.fields.each do |field| -%>
<%- case field -%>
@@ -349,6 +348,22 @@ module Prism
def self.type
:<%= node.human %>
end
+
+ # Implements case-equality for the node. This is effectively == but without
+ # comparing the value of locations. Locations are checked only for presence.
+ def ===(other)
+ other.is_a?(<%= node.name %>)<%= " &&" if node.fields.any? %>
+ <%- node.fields.each_with_index do |field, index| -%>
+ <%- if field.is_a?(Prism::Template::LocationField) || field.is_a?(Prism::Template::OptionalLocationField) -%>
+ (<%= field.name %>.nil? == other.<%= field.name %>.nil?)<%= " &&" if index != node.fields.length - 1 %>
+ <%- elsif field.is_a?(Prism::Template::NodeListField) || field.is_a?(Prism::Template::ConstantListField) -%>
+ (<%= field.name %>.length == other.<%= field.name %>.length) &&
+ <%= field.name %>.zip(other.<%= field.name %>).all? { |left, right| left === right }<%= " &&" if index != node.fields.length - 1 %>
+ <%- else -%>
+ (<%= field.name %> === other.<%= field.name %>)<%= " &&" if index != node.fields.length - 1 %>
+ <%- end -%>
+ <%- end -%>
+ end
end
<%- end -%>
<%- flags.each_with_index do |flag, flag_index| -%>
diff --git a/test/prism/ruby_api_test.rb b/test/prism/ruby_api_test.rb
index 6418887147..bf493666d2 100644
--- a/test/prism/ruby_api_test.rb
+++ b/test/prism/ruby_api_test.rb
@@ -244,6 +244,21 @@ module Prism
assert_equal 16, base[parse_expression("0x1")]
end
+ def test_node_equality
+ assert_operator parse_expression("1"), :===, parse_expression("1")
+ assert_operator Prism.parse("1").value, :===, Prism.parse("1").value
+
+ complex_source = "class Something; @var = something.else { _1 }; end"
+ assert_operator parse_expression(complex_source), :===, parse_expression(complex_source)
+
+ refute_operator parse_expression("1"), :===, parse_expression("2")
+ refute_operator parse_expression("1"), :===, parse_expression("0x1")
+
+ complex_source_1 = "class Something; @var = something.else { _1 }; end"
+ complex_source_2 = "class Something; @var = something.else { _2 }; end"
+ refute_operator parse_expression(complex_source_1), :===, parse_expression(complex_source_2)
+ end
+
private
def parse_expression(source)