summaryrefslogtreecommitdiff
path: root/prism/templates/lib/prism/serialize.rb.erb
diff options
context:
space:
mode:
Diffstat (limited to 'prism/templates/lib/prism/serialize.rb.erb')
-rw-r--r--prism/templates/lib/prism/serialize.rb.erb404
1 files changed, 404 insertions, 0 deletions
diff --git a/prism/templates/lib/prism/serialize.rb.erb b/prism/templates/lib/prism/serialize.rb.erb
new file mode 100644
index 0000000000..756821cf7d
--- /dev/null
+++ b/prism/templates/lib/prism/serialize.rb.erb
@@ -0,0 +1,404 @@
+require "stringio"
+require_relative "polyfill/unpack1"
+
+module Prism
+ # A module responsible for deserializing parse results.
+ module Serialize
+ # The major version of prism that we are expecting to find in the serialized
+ # strings.
+ MAJOR_VERSION = 0
+
+ # The minor version of prism that we are expecting to find in the serialized
+ # strings.
+ MINOR_VERSION = 29
+
+ # The patch version of prism that we are expecting to find in the serialized
+ # strings.
+ PATCH_VERSION = 0
+
+ # Deserialize the AST represented by the given string into a parse result.
+ def self.load(input, serialized)
+ input = input.dup
+ source = Source.for(input)
+ loader = Loader.new(source, serialized)
+ result = loader.load_result
+
+ input.force_encoding(loader.encoding)
+ result
+ end
+
+ # Deserialize the tokens represented by the given string into a parse
+ # result.
+ def self.load_tokens(source, serialized)
+ Loader.new(source, serialized).load_tokens_result
+ end
+
+ class Loader # :nodoc:
+ if RUBY_ENGINE == "truffleruby"
+ # StringIO is synchronized and that adds a high overhead on TruffleRuby.
+ class FastStringIO # :nodoc:
+ attr_accessor :pos
+
+ def initialize(string)
+ @string = string
+ @pos = 0
+ end
+
+ def getbyte
+ byte = @string.getbyte(@pos)
+ @pos += 1
+ byte
+ end
+
+ def read(n)
+ slice = @string.byteslice(@pos, n)
+ @pos += n
+ slice
+ end
+
+ def eof?
+ @pos >= @string.bytesize
+ end
+ end
+ else
+ FastStringIO = ::StringIO
+ end
+ private_constant :FastStringIO
+
+ attr_reader :encoding, :input, :serialized, :io
+ attr_reader :constant_pool_offset, :constant_pool, :source
+ attr_reader :start_line
+
+ def initialize(source, serialized)
+ @encoding = Encoding::UTF_8
+
+ @input = source.source.dup
+ raise unless serialized.encoding == Encoding::BINARY
+ @serialized = serialized
+ @io = FastStringIO.new(serialized)
+
+ @constant_pool_offset = nil
+ @constant_pool = nil
+
+ @source = source
+ define_load_node_lambdas unless RUBY_ENGINE == "ruby"
+ end
+
+ def load_header
+ raise "Invalid serialization" if io.read(5) != "PRISM"
+ raise "Invalid serialization" if io.read(3).unpack("C3") != [MAJOR_VERSION, MINOR_VERSION, PATCH_VERSION]
+ only_semantic_fields = io.getbyte
+ unless only_semantic_fields == 0
+ raise "Invalid serialization (location fields must be included but are not)"
+ end
+ end
+
+ def load_encoding
+ @encoding = Encoding.find(io.read(load_varuint))
+ @input = input.force_encoding(@encoding).freeze
+ @encoding
+ end
+
+ def load_start_line
+ source.instance_variable_set :@start_line, load_varsint
+ end
+
+ def load_line_offsets
+ source.instance_variable_set :@offsets, Array.new(load_varuint) { load_varuint }
+ end
+
+ def load_comments
+ Array.new(load_varuint) do
+ case load_varuint
+ when 0 then InlineComment.new(load_location_object)
+ when 1 then EmbDocComment.new(load_location_object)
+ end
+ end
+ end
+
+ DIAGNOSTIC_TYPES = [
+ <%- errors.each do |error| -%>
+ <%= error.name.downcase.to_sym.inspect %>,
+ <%- end -%>
+ <%- warnings.each do |warning| -%>
+ <%= warning.name.downcase.to_sym.inspect %>,
+ <%- end -%>
+ ].freeze
+
+ private_constant :DIAGNOSTIC_TYPES
+
+ def load_metadata
+ comments = load_comments
+ magic_comments = Array.new(load_varuint) { MagicComment.new(load_location_object, load_location_object) }
+ data_loc = load_optional_location_object
+ errors = Array.new(load_varuint) { ParseError.new(DIAGNOSTIC_TYPES[load_varuint], load_embedded_string, load_location_object, load_error_level) }
+ warnings = Array.new(load_varuint) { ParseWarning.new(DIAGNOSTIC_TYPES[load_varuint], load_embedded_string, load_location_object, load_warning_level) }
+ [comments, magic_comments, data_loc, errors, warnings]
+ end
+
+ def load_tokens
+ tokens = []
+ while type = TOKEN_TYPES.fetch(load_varuint)
+ start = load_varuint
+ length = load_varuint
+ lex_state = load_varuint
+ location = Location.new(@source, start, length)
+ tokens << [Token.new(source, type, location.slice, location), lex_state]
+ end
+
+ tokens
+ end
+
+ def load_tokens_result
+ tokens = load_tokens
+ encoding = load_encoding
+ load_start_line
+ load_line_offsets
+ comments, magic_comments, data_loc, errors, warnings = load_metadata
+ tokens.each { |token,| token.value.force_encoding(encoding) }
+
+ raise "Expected to consume all bytes while deserializing" unless @io.eof?
+ LexResult.new(tokens, comments, magic_comments, data_loc, errors, warnings, @source)
+ end
+
+ def load_nodes
+ load_header
+ load_encoding
+ load_start_line
+ load_line_offsets
+
+ comments, magic_comments, data_loc, errors, warnings = load_metadata
+
+ @constant_pool_offset = load_uint32
+ @constant_pool = Array.new(load_varuint, nil)
+
+ [load_node, comments, magic_comments, data_loc, errors, warnings]
+ end
+
+ def load_result
+ node, comments, magic_comments, data_loc, errors, warnings = load_nodes
+ ParseResult.new(node, comments, magic_comments, data_loc, errors, warnings, @source)
+ end
+
+ private
+
+ # variable-length integer using https://en.wikipedia.org/wiki/LEB128
+ # This is also what protobuf uses: https://protobuf.dev/programming-guides/encoding/#varints
+ def load_varuint
+ n = io.getbyte
+ if n < 128
+ n
+ else
+ n -= 128
+ shift = 0
+ while (b = io.getbyte) >= 128
+ n += (b - 128) << (shift += 7)
+ end
+ n + (b << (shift + 7))
+ end
+ end
+
+ def load_varsint
+ n = load_varuint
+ (n >> 1) ^ (-(n & 1))
+ end
+
+ def load_integer
+ negative = io.getbyte != 0
+ length = load_varuint
+
+ value = 0
+ length.times { |index| value |= (load_varuint << (index * 32)) }
+
+ value = -value if negative
+ value
+ end
+
+ def load_double
+ io.read(8).unpack1("D")
+ end
+
+ def load_uint32
+ io.read(4).unpack1("L")
+ end
+
+ def load_optional_node
+ if io.getbyte != 0
+ io.pos -= 1
+ load_node
+ end
+ end
+
+ def load_embedded_string
+ io.read(load_varuint).force_encoding(encoding)
+ end
+
+ def load_string
+ type = io.getbyte
+ case type
+ when 1
+ input.byteslice(load_varuint, load_varuint).force_encoding(encoding)
+ when 2
+ load_embedded_string
+ else
+ raise "Unknown serialized string type: #{type}"
+ end
+ end
+
+ def load_location
+ (load_varuint << 32) | load_varuint
+ end
+
+ def load_location_object
+ Location.new(source, load_varuint, load_varuint)
+ end
+
+ def load_optional_location
+ load_location if io.getbyte != 0
+ end
+
+ def load_optional_location_object
+ load_location_object if io.getbyte != 0
+ end
+
+ def load_constant(index)
+ constant = constant_pool[index]
+
+ unless constant
+ offset = constant_pool_offset + index * 8
+ start = @serialized.unpack1("L", offset: offset)
+ length = @serialized.unpack1("L", offset: offset + 4)
+
+ constant =
+ if start.nobits?(1 << 31)
+ input.byteslice(start, length).force_encoding(@encoding).to_sym
+ else
+ @serialized.byteslice(start & ((1 << 31) - 1), length).force_encoding(@encoding).to_sym
+ end
+
+ constant_pool[index] = constant
+ end
+
+ constant
+ end
+
+ def load_required_constant
+ load_constant(load_varuint - 1)
+ end
+
+ def load_optional_constant
+ index = load_varuint
+ load_constant(index - 1) if index != 0
+ end
+
+ def load_error_level
+ level = io.getbyte
+
+ case level
+ when 0
+ :syntax
+ when 1
+ :argument
+ when 2
+ :load
+ else
+ raise "Unknown level: #{level}"
+ end
+ end
+
+ def load_warning_level
+ level = io.getbyte
+
+ case level
+ when 0
+ :default
+ when 1
+ :verbose
+ else
+ raise "Unknown level: #{level}"
+ end
+ end
+
+ if RUBY_ENGINE == 'ruby'
+ def load_node
+ type = io.getbyte
+ location = load_location
+
+ case type
+ <%- nodes.each_with_index do |node, index| -%>
+ when <%= index + 1 %> then
+ <%- if node.needs_serialized_length? -%>
+ load_uint32
+ <%- end -%>
+ <%= node.name %>.new(
+ source, <%= (node.fields.map { |field|
+ case field
+ when Prism::Template::NodeField then "load_node"
+ when Prism::Template::OptionalNodeField then "load_optional_node"
+ when Prism::Template::StringField then "load_string"
+ when Prism::Template::NodeListField then "Array.new(load_varuint) { load_node }"
+ when Prism::Template::ConstantField then "load_required_constant"
+ when Prism::Template::OptionalConstantField then "load_optional_constant"
+ when Prism::Template::ConstantListField then "Array.new(load_varuint) { load_required_constant }"
+ when Prism::Template::LocationField then "load_location"
+ when Prism::Template::OptionalLocationField then "load_optional_location"
+ when Prism::Template::UInt8Field then "io.getbyte"
+ when Prism::Template::UInt32Field, Prism::Template::FlagsField then "load_varuint"
+ when Prism::Template::IntegerField then "load_integer"
+ when Prism::Template::DoubleField then "load_double"
+ else raise
+ end
+ } + ["location"]).join(", ") -%>)
+ <%- end -%>
+ end
+ end
+ else
+ def load_node
+ type = io.getbyte
+ @load_node_lambdas[type].call
+ end
+
+ def define_load_node_lambdas
+ @load_node_lambdas = [
+ nil,
+ <%- nodes.each do |node| -%>
+ -> {
+ location = load_location
+ <%- if node.needs_serialized_length? -%>
+ load_uint32
+ <%- end -%>
+ <%= node.name %>.new(
+ source, <%= (node.fields.map { |field|
+ case field
+ when Prism::Template::NodeField then "load_node"
+ when Prism::Template::OptionalNodeField then "load_optional_node"
+ when Prism::Template::StringField then "load_string"
+ when Prism::Template::NodeListField then "Array.new(load_varuint) { load_node }"
+ when Prism::Template::ConstantField then "load_required_constant"
+ when Prism::Template::OptionalConstantField then "load_optional_constant"
+ when Prism::Template::ConstantListField then "Array.new(load_varuint) { load_required_constant }"
+ when Prism::Template::LocationField then "load_location"
+ when Prism::Template::OptionalLocationField then "load_optional_location"
+ when Prism::Template::UInt8Field then "io.getbyte"
+ when Prism::Template::UInt32Field, Prism::Template::FlagsField then "load_varuint"
+ when Prism::Template::IntegerField then "load_integer"
+ when Prism::Template::DoubleField then "load_double"
+ else raise
+ end
+ } + ["location"]).join(", ") -%>)
+ },
+ <%- end -%>
+ ]
+ end
+ end
+ end
+
+ # The token types that can be indexed by their enum values.
+ TOKEN_TYPES = [
+ nil,
+ <%- tokens.each do |token| -%>
+ <%= token.name.to_sym.inspect %>,
+ <%- end -%>
+ ]
+ end
+end