diff options
Diffstat (limited to 'prism/templates/lib/prism/serialize.rb.erb')
-rw-r--r-- | prism/templates/lib/prism/serialize.rb.erb | 404 |
1 files changed, 404 insertions, 0 deletions
diff --git a/prism/templates/lib/prism/serialize.rb.erb b/prism/templates/lib/prism/serialize.rb.erb new file mode 100644 index 0000000000..756821cf7d --- /dev/null +++ b/prism/templates/lib/prism/serialize.rb.erb @@ -0,0 +1,404 @@ +require "stringio" +require_relative "polyfill/unpack1" + +module Prism + # A module responsible for deserializing parse results. + module Serialize + # The major version of prism that we are expecting to find in the serialized + # strings. + MAJOR_VERSION = 0 + + # The minor version of prism that we are expecting to find in the serialized + # strings. + MINOR_VERSION = 29 + + # The patch version of prism that we are expecting to find in the serialized + # strings. + PATCH_VERSION = 0 + + # Deserialize the AST represented by the given string into a parse result. + def self.load(input, serialized) + input = input.dup + source = Source.for(input) + loader = Loader.new(source, serialized) + result = loader.load_result + + input.force_encoding(loader.encoding) + result + end + + # Deserialize the tokens represented by the given string into a parse + # result. + def self.load_tokens(source, serialized) + Loader.new(source, serialized).load_tokens_result + end + + class Loader # :nodoc: + if RUBY_ENGINE == "truffleruby" + # StringIO is synchronized and that adds a high overhead on TruffleRuby. + class FastStringIO # :nodoc: + attr_accessor :pos + + def initialize(string) + @string = string + @pos = 0 + end + + def getbyte + byte = @string.getbyte(@pos) + @pos += 1 + byte + end + + def read(n) + slice = @string.byteslice(@pos, n) + @pos += n + slice + end + + def eof? + @pos >= @string.bytesize + end + end + else + FastStringIO = ::StringIO + end + private_constant :FastStringIO + + attr_reader :encoding, :input, :serialized, :io + attr_reader :constant_pool_offset, :constant_pool, :source + attr_reader :start_line + + def initialize(source, serialized) + @encoding = Encoding::UTF_8 + + @input = source.source.dup + raise unless serialized.encoding == Encoding::BINARY + @serialized = serialized + @io = FastStringIO.new(serialized) + + @constant_pool_offset = nil + @constant_pool = nil + + @source = source + define_load_node_lambdas unless RUBY_ENGINE == "ruby" + end + + def load_header + raise "Invalid serialization" if io.read(5) != "PRISM" + raise "Invalid serialization" if io.read(3).unpack("C3") != [MAJOR_VERSION, MINOR_VERSION, PATCH_VERSION] + only_semantic_fields = io.getbyte + unless only_semantic_fields == 0 + raise "Invalid serialization (location fields must be included but are not)" + end + end + + def load_encoding + @encoding = Encoding.find(io.read(load_varuint)) + @input = input.force_encoding(@encoding).freeze + @encoding + end + + def load_start_line + source.instance_variable_set :@start_line, load_varsint + end + + def load_line_offsets + source.instance_variable_set :@offsets, Array.new(load_varuint) { load_varuint } + end + + def load_comments + Array.new(load_varuint) do + case load_varuint + when 0 then InlineComment.new(load_location_object) + when 1 then EmbDocComment.new(load_location_object) + end + end + end + + DIAGNOSTIC_TYPES = [ + <%- errors.each do |error| -%> + <%= error.name.downcase.to_sym.inspect %>, + <%- end -%> + <%- warnings.each do |warning| -%> + <%= warning.name.downcase.to_sym.inspect %>, + <%- end -%> + ].freeze + + private_constant :DIAGNOSTIC_TYPES + + def load_metadata + comments = load_comments + magic_comments = Array.new(load_varuint) { MagicComment.new(load_location_object, load_location_object) } + data_loc = load_optional_location_object + errors = Array.new(load_varuint) { ParseError.new(DIAGNOSTIC_TYPES[load_varuint], load_embedded_string, load_location_object, load_error_level) } + warnings = Array.new(load_varuint) { ParseWarning.new(DIAGNOSTIC_TYPES[load_varuint], load_embedded_string, load_location_object, load_warning_level) } + [comments, magic_comments, data_loc, errors, warnings] + end + + def load_tokens + tokens = [] + while type = TOKEN_TYPES.fetch(load_varuint) + start = load_varuint + length = load_varuint + lex_state = load_varuint + location = Location.new(@source, start, length) + tokens << [Token.new(source, type, location.slice, location), lex_state] + end + + tokens + end + + def load_tokens_result + tokens = load_tokens + encoding = load_encoding + load_start_line + load_line_offsets + comments, magic_comments, data_loc, errors, warnings = load_metadata + tokens.each { |token,| token.value.force_encoding(encoding) } + + raise "Expected to consume all bytes while deserializing" unless @io.eof? + LexResult.new(tokens, comments, magic_comments, data_loc, errors, warnings, @source) + end + + def load_nodes + load_header + load_encoding + load_start_line + load_line_offsets + + comments, magic_comments, data_loc, errors, warnings = load_metadata + + @constant_pool_offset = load_uint32 + @constant_pool = Array.new(load_varuint, nil) + + [load_node, comments, magic_comments, data_loc, errors, warnings] + end + + def load_result + node, comments, magic_comments, data_loc, errors, warnings = load_nodes + ParseResult.new(node, comments, magic_comments, data_loc, errors, warnings, @source) + end + + private + + # variable-length integer using https://en.wikipedia.org/wiki/LEB128 + # This is also what protobuf uses: https://protobuf.dev/programming-guides/encoding/#varints + def load_varuint + n = io.getbyte + if n < 128 + n + else + n -= 128 + shift = 0 + while (b = io.getbyte) >= 128 + n += (b - 128) << (shift += 7) + end + n + (b << (shift + 7)) + end + end + + def load_varsint + n = load_varuint + (n >> 1) ^ (-(n & 1)) + end + + def load_integer + negative = io.getbyte != 0 + length = load_varuint + + value = 0 + length.times { |index| value |= (load_varuint << (index * 32)) } + + value = -value if negative + value + end + + def load_double + io.read(8).unpack1("D") + end + + def load_uint32 + io.read(4).unpack1("L") + end + + def load_optional_node + if io.getbyte != 0 + io.pos -= 1 + load_node + end + end + + def load_embedded_string + io.read(load_varuint).force_encoding(encoding) + end + + def load_string + type = io.getbyte + case type + when 1 + input.byteslice(load_varuint, load_varuint).force_encoding(encoding) + when 2 + load_embedded_string + else + raise "Unknown serialized string type: #{type}" + end + end + + def load_location + (load_varuint << 32) | load_varuint + end + + def load_location_object + Location.new(source, load_varuint, load_varuint) + end + + def load_optional_location + load_location if io.getbyte != 0 + end + + def load_optional_location_object + load_location_object if io.getbyte != 0 + end + + def load_constant(index) + constant = constant_pool[index] + + unless constant + offset = constant_pool_offset + index * 8 + start = @serialized.unpack1("L", offset: offset) + length = @serialized.unpack1("L", offset: offset + 4) + + constant = + if start.nobits?(1 << 31) + input.byteslice(start, length).force_encoding(@encoding).to_sym + else + @serialized.byteslice(start & ((1 << 31) - 1), length).force_encoding(@encoding).to_sym + end + + constant_pool[index] = constant + end + + constant + end + + def load_required_constant + load_constant(load_varuint - 1) + end + + def load_optional_constant + index = load_varuint + load_constant(index - 1) if index != 0 + end + + def load_error_level + level = io.getbyte + + case level + when 0 + :syntax + when 1 + :argument + when 2 + :load + else + raise "Unknown level: #{level}" + end + end + + def load_warning_level + level = io.getbyte + + case level + when 0 + :default + when 1 + :verbose + else + raise "Unknown level: #{level}" + end + end + + if RUBY_ENGINE == 'ruby' + def load_node + type = io.getbyte + location = load_location + + case type + <%- nodes.each_with_index do |node, index| -%> + when <%= index + 1 %> then + <%- if node.needs_serialized_length? -%> + load_uint32 + <%- end -%> + <%= node.name %>.new( + source, <%= (node.fields.map { |field| + case field + when Prism::Template::NodeField then "load_node" + when Prism::Template::OptionalNodeField then "load_optional_node" + when Prism::Template::StringField then "load_string" + when Prism::Template::NodeListField then "Array.new(load_varuint) { load_node }" + when Prism::Template::ConstantField then "load_required_constant" + when Prism::Template::OptionalConstantField then "load_optional_constant" + when Prism::Template::ConstantListField then "Array.new(load_varuint) { load_required_constant }" + when Prism::Template::LocationField then "load_location" + when Prism::Template::OptionalLocationField then "load_optional_location" + when Prism::Template::UInt8Field then "io.getbyte" + when Prism::Template::UInt32Field, Prism::Template::FlagsField then "load_varuint" + when Prism::Template::IntegerField then "load_integer" + when Prism::Template::DoubleField then "load_double" + else raise + end + } + ["location"]).join(", ") -%>) + <%- end -%> + end + end + else + def load_node + type = io.getbyte + @load_node_lambdas[type].call + end + + def define_load_node_lambdas + @load_node_lambdas = [ + nil, + <%- nodes.each do |node| -%> + -> { + location = load_location + <%- if node.needs_serialized_length? -%> + load_uint32 + <%- end -%> + <%= node.name %>.new( + source, <%= (node.fields.map { |field| + case field + when Prism::Template::NodeField then "load_node" + when Prism::Template::OptionalNodeField then "load_optional_node" + when Prism::Template::StringField then "load_string" + when Prism::Template::NodeListField then "Array.new(load_varuint) { load_node }" + when Prism::Template::ConstantField then "load_required_constant" + when Prism::Template::OptionalConstantField then "load_optional_constant" + when Prism::Template::ConstantListField then "Array.new(load_varuint) { load_required_constant }" + when Prism::Template::LocationField then "load_location" + when Prism::Template::OptionalLocationField then "load_optional_location" + when Prism::Template::UInt8Field then "io.getbyte" + when Prism::Template::UInt32Field, Prism::Template::FlagsField then "load_varuint" + when Prism::Template::IntegerField then "load_integer" + when Prism::Template::DoubleField then "load_double" + else raise + end + } + ["location"]).join(", ") -%>) + }, + <%- end -%> + ] + end + end + end + + # The token types that can be indexed by their enum values. + TOKEN_TYPES = [ + nil, + <%- tokens.each do |token| -%> + <%= token.name.to_sym.inspect %>, + <%- end -%> + ] + end +end |