summaryrefslogtreecommitdiff
path: root/prism/templates/lib/prism/serialize.rb.erb
blob: 36d5d5432df696a4b60a6c09c5bd59e0736e5c47 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
require "stringio"
require_relative "polyfill/string"

module Prism
  # A module responsible for deserializing parse results.
  module Serialize
    # The major version of prism that we are expecting to find in the serialized
    # strings. Loader#load_header compares it with the first of the three
    # version bytes that follow the "PRISM" prefix.
    MAJOR_VERSION = 0

    # The minor version of prism that we are expecting to find in the serialized
    # strings. Loader#load_header compares it with the second version byte.
    MINOR_VERSION = 25

    # The patch version of prism that we are expecting to find in the serialized
    # strings. Loader#load_header compares it with the third version byte.
    PATCH_VERSION = 0

    # Deserialize the AST represented by the given string into a parse result.
    #
    # The caller's +input+ is left untouched: a duplicate backs the Source, and
    # once loading has finished that duplicate is re-tagged with the encoding
    # reported by the loader.
    def self.load(input, serialized)
      copy = input.dup
      loader = Loader.new(Source.new(copy), serialized)

      loader.load_result.tap { copy.force_encoding(loader.encoding) }
    end

    # Deserialize the tokens represented by the given string into a parse
    # result. +source+ is handed to the Loader unchanged.
    def self.load_tokens(source, serialized)
      loader = Loader.new(source, serialized)
      loader.load_tokens_result
    end

    class Loader # :nodoc:
      if RUBY_ENGINE == "truffleruby"
        # StringIO is synchronized and that adds a high overhead on TruffleRuby.
        class FastStringIO # :nodoc:
          # The current byte offset into the wrapped string.
          attr_accessor :pos

          # Wrap +string+ and start reading from its first byte.
          def initialize(string)
            @pos = 0
            @string = string
          end

          # Return the byte at the cursor (nil past the end) and advance the
          # cursor by one.
          def getbyte
            @string.getbyte(@pos).tap { @pos += 1 }
          end

          # Return the +n+ bytes starting at the cursor and advance the cursor
          # by +n+.
          def read(n)
            @string.byteslice(@pos, n).tap { @pos += n }
          end

          # Whether the cursor has reached or passed the end of the string.
          def eof?
            @string.bytesize <= @pos
          end
        end
      else
        FastStringIO = ::StringIO
      end
      private_constant :FastStringIO

      attr_reader :encoding, :input, :serialized, :io
      attr_reader :constant_pool_offset, :constant_pool, :source
      attr_reader :start_line

      def initialize(source, serialized)
        @encoding = Encoding::UTF_8

        @input = source.source.dup
        raise unless serialized.encoding == Encoding::BINARY
        @serialized = serialized
        @io = FastStringIO.new(serialized)

        @constant_pool_offset = nil
        @constant_pool = nil

        @source = source
        define_load_node_lambdas unless RUBY_ENGINE == "ruby"
      end

      # Consume and validate the header: the 5-byte "PRISM" prefix, three
      # version bytes that must equal MAJOR/MINOR/PATCH_VERSION, and one flag
      # byte that must be 0. Raises a RuntimeError on any mismatch.
      def load_header
        magic = io.read(5)
        raise "Invalid serialization" unless magic == "PRISM"

        version = io.read(3).unpack("C3")
        raise "Invalid serialization" unless version == [MAJOR_VERSION, MINOR_VERSION, PATCH_VERSION]

        only_semantic_fields = io.getbyte
        return if only_semantic_fields == 0

        raise "Invalid serialization (location fields must be included but are not)"
      end

      # Read a length-prefixed encoding name, look the encoding up, re-tag and
      # freeze +input+ with it, and return the encoding.
      def load_encoding
        name = io.read(load_varuint)
        @encoding = Encoding.find(name)
        @input = input.force_encoding(@encoding).freeze
        @encoding
      end

      # Read a signed varint and store it on the source as @start_line.
      def load_start_line
        line = load_varsint
        source.instance_variable_set(:@start_line, line)
      end

      # Read a count followed by that many varuints and store the resulting
      # array on the source as @offsets.
      def load_line_offsets
        count = load_varuint
        offsets = Array.new(count) { load_varuint }
        source.instance_variable_set(:@offsets, offsets)
      end

      # Read a count followed by that many comments. Each comment is a type
      # tag (0 for InlineComment, 1 for EmbDocComment) followed by a location.
      #
      # Raises a RuntimeError for any other type tag instead of silently
      # leaving a nil entry in the returned array.
      def load_comments
        Array.new(load_varuint) do
          type = load_varuint

          case type
          when 0 then InlineComment.new(load_location_object)
          when 1 then EmbDocComment.new(load_location_object)
          else raise "Unknown comment type: #{type}"
          end
        end
      end

      # The diagnostic types, looked up in load_metadata by the diagnostic ID
      # read from the serialized payload. The template emits every error name
      # first and every warning name after, each as a downcased symbol.
      DIAGNOSTIC_TYPES = [
        <%- errors.each do |error| -%>
        <%= error.name.downcase.to_sym.inspect %>,
        <%- end -%>
        <%- warnings.each do |warning| -%>
        <%= warning.name.downcase.to_sym.inspect %>,
        <%- end -%>
      ].freeze

      private_constant :DIAGNOSTIC_TYPES

      # Read, in order: the comments, the magic comments (each a pair of
      # locations), the optional data location, the errors and the warnings.
      #
      # Each error and warning is a diagnostic ID, an embedded message string,
      # a location and a level byte. The ID is resolved through
      # DIAGNOSTIC_TYPES.fetch, so an ID outside the table raises an IndexError
      # rather than producing a diagnostic whose type is nil.
      #
      # Returns [comments, magic_comments, data_loc, errors, warnings].
      def load_metadata
        comments = load_comments
        magic_comments = Array.new(load_varuint) { MagicComment.new(load_location_object, load_location_object) }
        data_loc = load_optional_location_object

        errors = Array.new(load_varuint) do
          ParseError.new(
            DIAGNOSTIC_TYPES.fetch(load_varuint),
            load_embedded_string,
            load_location_object,
            load_error_level
          )
        end

        warnings = Array.new(load_varuint) do
          ParseWarning.new(
            DIAGNOSTIC_TYPES.fetch(load_varuint),
            load_embedded_string,
            load_location_object,
            load_warning_level
          )
        end

        [comments, magic_comments, data_loc, errors, warnings]
      end

      # Read tokens until the type lookup yields a falsy entry (index 0 of
      # TOKEN_TYPES is nil). Each token is a type ID, a start offset, a length
      # and a lex state; the result is an array of [token, lex_state] pairs.
      def load_tokens
        loaded = []

        while (type = TOKEN_TYPES.fetch(load_varuint))
          offset = load_varuint
          size = load_varuint
          state = load_varuint

          span = Location.new(@source, offset, size)
          loaded.push([Prism::Token.new(source, type, span.slice, span), state])
        end

        loaded
      end

      # Load a lex result: the tokens, then the encoding, start line, line
      # offsets and metadata. Every token value is re-tagged with the loaded
      # encoding. Raises a RuntimeError if any bytes are left unread.
      def load_tokens_result
        tokens = load_tokens
        encoding = load_encoding
        load_start_line
        load_line_offsets
        # [comments, magic_comments, data_loc, errors, warnings]
        metadata = load_metadata

        tokens.each do |(token, _lex_state)|
          token.value.force_encoding(encoding)
        end

        unless @io.eof?
          raise "Expected to consume all bytes while deserializing"
        end

        Prism::ParseResult.new(tokens, *metadata, @source)
      end

      # Load a full parse payload: header, encoding, start line, line offsets
      # and metadata, then the constant pool offset and an empty constant pool
      # of the recorded size, and finally the root node.
      #
      # Returns [node, comments, magic_comments, data_loc, errors, warnings].
      def load_nodes
        load_header
        load_encoding
        load_start_line
        load_line_offsets

        metadata = load_metadata

        # The pool must be in place before any node is read, because nodes
        # resolve their constants through it.
        @constant_pool_offset = load_uint32
        @constant_pool = Array.new(load_varuint, nil)

        [load_node, *metadata]
      end

      # Load the nodes and metadata and wrap them, together with the source,
      # in a Prism::ParseResult.
      def load_result
        Prism::ParseResult.new(*load_nodes, @source)
      end

      private

      # variable-length integer using https://en.wikipedia.org/wiki/LEB128
      # This is also what protobuf uses: https://protobuf.dev/programming-guides/encoding/#varints
      #
      # Each byte contributes its low 7 bits, least-significant group first; a
      # byte below 128 ends the number. Single-byte values return immediately.
      def load_varuint
        first = io.getbyte
        return first if first < 128

        value = first & 0x7f
        shift = 7

        while (byte = io.getbyte) >= 128
          value |= (byte & 0x7f) << shift
          shift += 7
        end

        value | (byte << shift)
      end

      def load_varsint
        n = load_varuint
        (n >> 1) ^ (-(n & 1))
      end

      # Read an arbitrary-size integer: a sign byte (non-zero means negative),
      # a word count, then that many varuints, each placed 32 bits above the
      # previous one.
      def load_integer
        negative = io.getbyte != 0
        word_count = load_varuint

        magnitude = (0...word_count).inject(0) do |acc, index|
          acc | (load_varuint << (index * 32))
        end

        negative ? -magnitude : magnitude
      end

      # Read 8 bytes and decode them as a native-format double.
      def load_double
        bytes = io.read(8)
        bytes.unpack("D").first
      end

      # Read 4 bytes and decode them as a native-endian unsigned 32-bit integer.
      def load_uint32
        bytes = io.read(4)
        bytes.unpack("L").first
      end

      # Peek at the next byte: a 0 is consumed and means "no node" (nil).
      # Anything else is pushed back so load_node can read it as the type byte.
      def load_optional_node
        return if io.getbyte == 0

        io.pos -= 1
        load_node
      end

      # Read a length-prefixed string stored inline in the serialized data and
      # tag it with the current encoding.
      def load_embedded_string
        length = load_varuint
        io.read(length).force_encoding(encoding)
      end

      # Read a string by its tag byte: 1 is an (offset, length) slice of
      # +input+, 2 is a string embedded in the serialized data. Any other tag
      # raises a RuntimeError.
      def load_string
        kind = io.getbyte
        return load_embedded_string if kind == 2
        raise "Unknown serialized string type: #{kind}" unless kind == 1

        offset = load_varuint
        length = load_varuint
        input.byteslice(offset, length).force_encoding(encoding)
      end

      # Read two varuints and pack them into one integer: the first in the
      # bits above 32, the second in the low bits.
      def load_location
        high = load_varuint
        low = load_varuint
        (high << 32) | low
      end

      # Read two varuints and build a Location on the source from them.
      def load_location_object
        first = load_varuint
        second = load_varuint
        Location.new(source, first, second)
      end

      # Read a presence byte; 0 means no location (nil), anything else is
      # followed by a packed location.
      def load_optional_location
        return if io.getbyte == 0

        load_location
      end

      # Read a presence byte; 0 means no location (nil), anything else is
      # followed by a location that is loaded as a Location object.
      def load_optional_location_object
        return if io.getbyte == 0

        load_location_object
      end

      # Return the symbol for the constant at +index+, decoding and caching it
      # on first use.
      #
      # Each pool entry is 8 bytes at constant_pool_offset + index * 8: a
      # uint32 start and a uint32 length. When bit 31 of start is set, the
      # bytes are read from the serialized data at start with that bit
      # cleared; otherwise they are read from +input+ at start.
      def load_constant(index)
        cached = constant_pool[index]
        return cached if cached

        entry = constant_pool_offset + index * 8
        start = @serialized.unpack1("L", offset: entry)
        length = @serialized.unpack1("L", offset: entry + 4)

        bytes =
          if start.anybits?(1 << 31)
            @serialized.byteslice(start & ((1 << 31) - 1), length)
          else
            input.byteslice(start, length)
          end

        constant_pool[index] = bytes.force_encoding(@encoding).to_sym
      end

      # Read a 1-based constant ID and load the constant at the matching
      # 0-based pool index.
      def load_required_constant
        index = load_varuint - 1
        load_constant(index)
      end

      # Read a constant ID; 0 means no constant (nil), any other value is a
      # 1-based ID resolved through load_constant.
      def load_optional_constant
        id = load_varuint
        return if id == 0

        load_constant(id - 1)
      end

      # Read one byte and map it to an error level: 0 is :syntax, 1 is
      # :argument, 2 is :load. Any other value raises a RuntimeError.
      def load_error_level
        level = io.getbyte

        { 0 => :syntax, 1 => :argument, 2 => :load }.fetch(level) do
          raise "Unknown level: #{level}"
        end
      end

      # Read one byte and map it to a warning level: 0 is :default, 1 is
      # :verbose. Any other value raises a RuntimeError.
      def load_warning_level
        level = io.getbyte

        { 0 => :default, 1 => :verbose }.fetch(level) do
          raise "Unknown level: #{level}"
        end
      end

      # Node loading comes in two generated flavors, selected once by engine
      # when this class body is evaluated. On CRuby a single case statement
      # dispatches on the type byte. On every other engine the type byte
      # indexes an array of lambdas that initialize builds through
      # define_load_node_lambdas.
      #
      # Both flavors read the same things for a node: the type byte, the
      # packed location, a uint32 that is read and discarded for node types
      # whose template reports needs_serialized_length?, and then one value per
      # template field in declaration order. The node is constructed with the
      # source first, the field values next and the location last.
      if RUBY_ENGINE == 'ruby'
        # Load one node, dispatching on its type byte. Type IDs are the
        # 1-based position of the node in the template's node list.
        def load_node
          type = io.getbyte
          location = load_location

          case type
          <%- nodes.each_with_index do |node, index| -%>
          when <%= index + 1 %> then
            <%- if node.needs_serialized_length? -%>
            load_uint32
            <%- end -%>
            <%= node.name %>.new(
              source, <%= (node.fields.map { |field|
              case field
              when Prism::Template::NodeField then "load_node"
              when Prism::Template::OptionalNodeField then "load_optional_node"
              when Prism::Template::StringField then "load_string"
              when Prism::Template::NodeListField then "Array.new(load_varuint) { load_node }"
              when Prism::Template::ConstantField then "load_required_constant"
              when Prism::Template::OptionalConstantField then "load_optional_constant"
              when Prism::Template::ConstantListField then "Array.new(load_varuint) { load_required_constant }"
              when Prism::Template::LocationField then "load_location"
              when Prism::Template::OptionalLocationField then "load_optional_location"
              when Prism::Template::UInt8Field then "io.getbyte"
              when Prism::Template::UInt32Field, Prism::Template::FlagsField then "load_varuint"
              when Prism::Template::IntegerField then "load_integer"
              when Prism::Template::DoubleField then "load_double"
              else raise
              end
            } + ["location"]).join(", ") -%>)
            <%- end -%>
          end
        end
      else
        # Load one node by calling the lambda stored at the index given by
        # its type byte.
        def load_node
          type = io.getbyte
          @load_node_lambdas[type].call
        end

        # Build the per-type lambdas. Index 0 holds nil so that each node's
        # 1-based type ID lines up with its lambda.
        def define_load_node_lambdas
          @load_node_lambdas = [
            nil,
            <%- nodes.each do |node| -%>
            -> {
              location = load_location
              <%- if node.needs_serialized_length? -%>
              load_uint32
              <%- end -%>
              <%= node.name %>.new(
                source, <%= (node.fields.map { |field|
                case field
                when Prism::Template::NodeField then "load_node"
                when Prism::Template::OptionalNodeField then "load_optional_node"
                when Prism::Template::StringField then "load_string"
                when Prism::Template::NodeListField then "Array.new(load_varuint) { load_node }"
                when Prism::Template::ConstantField then "load_required_constant"
                when Prism::Template::OptionalConstantField then "load_optional_constant"
                when Prism::Template::ConstantListField then "Array.new(load_varuint) { load_required_constant }"
                when Prism::Template::LocationField then "load_location"
                when Prism::Template::OptionalLocationField then "load_optional_location"
                when Prism::Template::UInt8Field then "io.getbyte"
                when Prism::Template::UInt32Field, Prism::Template::FlagsField then "load_varuint"
                when Prism::Template::IntegerField then "load_integer"
                when Prism::Template::DoubleField then "load_double"
                else raise
                end
              } + ["location"]).join(", ") -%>)
            },
            <%- end -%>
          ]
        end
      end
    end

    # The token types that can be indexed by their enum values.
    TOKEN_TYPES = [
      nil,
      <%- tokens.each do |token| -%>
      <%= token.name.to_sym.inspect %>,
      <%- end -%>
    ]
  end
end