summaryrefslogtreecommitdiff
path: root/lib/prism/translation/ripper.rb
diff options
context:
space:
mode:
authorKevin Newton <kddnewton@gmail.com>2024-03-13 08:47:41 -0400
committergit <svn-admin@ruby-lang.org>2024-03-13 13:52:13 +0000
commitd1eaa97ec3cdbe38605379fc87a55987d6802dc7 (patch)
treeb7efdc26e943bf04664ce4c3818a29747651247d /lib/prism/translation/ripper.rb
parent3f8ef7ff7c09e67a48eff33804060803b9f11119 (diff)
[ruby/prism] Track parentheses in patterns
https://github.com/ruby/prism/commit/62db99f156
Diffstat (limited to 'lib/prism/translation/ripper.rb')
-rw-r--r--lib/prism/translation/ripper.rb20
1 files changed, 15 insertions, 5 deletions
diff --git a/lib/prism/translation/ripper.rb b/lib/prism/translation/ripper.rb
index 94156d4988..9f269f9eb8 100644
--- a/lib/prism/translation/ripper.rb
+++ b/lib/prism/translation/ripper.rb
@@ -583,13 +583,23 @@ module Prism
# foo => bar | baz
# ^^^^^^^^^
def visit_alternation_pattern_node(node)
- left = visit(node.left)
- right = visit(node.right)
+ left = visit_pattern_node(node.left)
+ right = visit_pattern_node(node.right)
bounds(node.location)
on_binary(left, :|, right)
end
+ # Visit a pattern within a pattern match. This is used to bypass the
+ # parenthesis node that can be used to wrap patterns.
+ private def visit_pattern_node(node)
+ if node.is_a?(ParenthesesNode)
+ visit(node.body)
+ else
+ visit(node)
+ end
+ end
+
# a and b
# ^^^^^^^
def visit_and_node(node)
@@ -1952,7 +1962,7 @@ module Prism
# This is a special case where we're not going to call on_in directly
# because we don't have access to the consequent. Instead, we'll return
# the component parts and let the parent node handle it.
- pattern = visit(node.pattern)
+ pattern = visit_pattern_node(node.pattern)
statements =
if node.statements.nil?
bounds(node.location)
@@ -2389,7 +2399,7 @@ module Prism
# ^^^^^^^^^^
def visit_match_predicate_node(node)
value = visit(node.value)
- pattern = on_in(visit(node.pattern), nil, nil)
+ pattern = on_in(visit_pattern_node(node.pattern), nil, nil)
on_case(value, pattern)
end
@@ -2398,7 +2408,7 @@ module Prism
# ^^^^^^^^^^
def visit_match_required_node(node)
value = visit(node.value)
- pattern = on_in(visit(node.pattern), nil, nil)
+ pattern = on_in(visit_pattern_node(node.pattern), nil, nil)
on_case(value, pattern)
end