summaryrefslogtreecommitdiff
path: root/lib/csv/writer.rb
blob: d49115fccf8ae54692989e2a8b5a2eb4dc4e5752 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
# frozen_string_literal: true

require_relative "match_p"
require_relative "row"

# Activate the CSV::MatchP refinement for the rest of this file, but only when
# that constant is defined.
# NOTE(review): presumably match_p.rb defines it only on Rubies that lack a
# native +match?+ -- confirm in match_p.rb.
using CSV::MatchP if CSV.const_defined?(:MatchP)

class CSV
  #
  # Note: Don't use this class directly. This is an internal class.
  #
  # A CSV::Writer receives an output, prepares the header, format and output.
  # It allows us to write new rows in the object and rewind it.
  #
  class Writer
    # Number of lines written since creation or the last #rewind. A header
    # line written by #initialize is counted too.
    attr_reader :lineno

    # Headers currently in use, or +nil+ when none are known. They come from
    # the +:headers+ option (Array or String) or, when +:headers+ is +true+,
    # from the first row written.
    attr_reader :headers

    #
    # Builds a writer that appends formatted lines to +output+ with +<<+.
    #
    # +options+ is a Hash. This class reads the keys +:encoding+, +:headers+,
    # +:write_headers+, +:header_fields_converter+, +:fields_converter+,
    # +:column_separator+, +:row_separator+, +:quote_character+,
    # +:force_quotes+, +:quote_empty+ and +:force_encoding+.
    #
    # Raises ArgumentError when +:force_quotes+ contains an unsupported
    # element, or a field name while no headers are known.
    #
    def initialize(output, options)
      @output = output
      @options = options
      @lineno = 0
      # The header line must not go through the fields converter, so the
      # converter is only enabled after the headers have been written.
      @fields_converter = nil
      prepare
      self << @headers if @options[:write_headers] && @headers
      @fields_converter = @options[:fields_converter]
    end

    #
    # Adds a new row
    #
    # +row+ may be a CSV::Row (its +fields+ are written), a Hash (values are
    # looked up by each known header, in header order) or any other object
    # that is written field by field as it is.
    #
    # Returns +self+ so that calls can be chained.
    #
    # Raises ArgumentError when +row+ is a Hash and no headers are known yet.
    #
    def <<(row)
      case row
      when Row
        row = row.fields
      when Hash
        if @headers.nil?
          message = "headers are required to write a Hash row: " +
                    row.inspect
          raise ArgumentError, message
        end
        row = @headers.collect {|header| row[header]}
      end

      # With headers: true, the first row written provides the headers.
      @headers ||= row if @use_headers
      @lineno += 1

      row = @fields_converter.convert(row, nil, lineno) if @fields_converter

      quoted_fields = row.collect.with_index do |field, i|
        quote(field, i)
      end
      line = quoted_fields.join(@column_separator) + @row_separator
      line = line.encode(@output_encoding) if @output_encoding
      @output << line

      self
    end

    #
    # Winds back to the beginning
    #
    # Resets the line counter. Headers that were learned from the first
    # written row are forgotten so that the next first row provides them
    # again; headers given as an Array or String by the +:headers+ option
    # are kept because they are only prepared once, in #initialize.
    #
    def rewind
      @lineno = 0
      headers_option = @options[:headers]
      unless headers_option.is_a?(Array) || headers_option.is_a?(String)
        @headers = nil
      end
    end

    private
    # Reads the options and sets up header, format and output state.
    # The order matters: preparing the format may look field names up in the
    # headers (see #prepare_force_quotes_fields).
    def prepare
      @encoding = @options[:encoding]

      prepare_header
      prepare_format
      prepare_output
    end

    # Initializes @headers and @use_headers from the +:headers+ option:
    # * Array: used as the headers.
    # * String: split into headers with CSV.parse_line and the configured
    #   separators and quote character.
    # * +true+: headers are unknown for now; #<< takes them from the first row.
    # * anything else: headers are not used.
    # Known headers go through the +:header_fields_converter+ and String
    # headers are frozen.
    def prepare_header
      headers = @options[:headers]
      case headers
      when Array
        @headers = headers
        @use_headers = true
      when String
        @headers = CSV.parse_line(headers,
                                  col_sep: @options[:column_separator],
                                  row_sep: @options[:row_separator],
                                  quote_char: @options[:quote_character])
        @use_headers = true
      when true
        @headers = nil
        @use_headers = true
      else
        @headers = nil
        @use_headers = false
      end
      return unless @headers

      converter = @options[:header_fields_converter]
      @headers = converter.convert(@headers, nil, 0)
      @headers.each do |header|
        header.freeze if header.is_a?(String)
      end
    end

    # Builds @force_quotes_fields, a Hash whose keys are the indexes of the
    # fields that must always be quoted.
    #
    # Each element of +force_quotes+ is either an Integer field index or a
    # String/Symbol field name that is resolved against the known headers.
    # Names that are not found in the headers are ignored.
    #
    # Raises ArgumentError for a field name when no headers are known, and
    # for elements of any other type.
    def prepare_force_quotes_fields(force_quotes)
      @force_quotes_fields = {}
      force_quotes.each do |name_or_index|
        case name_or_index
        when Integer
          index = name_or_index
          @force_quotes_fields[index] = true
        when String, Symbol
          name = name_or_index.to_s
          if @headers.nil?
            message = ":headers is required when you use field name " +
                      "in :force_quotes: " +
                      "#{name_or_index.inspect}: #{force_quotes.inspect}"
            raise ArgumentError, message
          end
          index = @headers.index(name)
          next if index.nil?
          @force_quotes_fields[index] = true
        else
          message = ":force_quotes element must be " +
                    "field index or field name: " +
                    "#{name_or_index.inspect}: #{force_quotes.inspect}"
          raise ArgumentError, message
        end
      end
    end

    # Prepares the separators, the quote character, the quoting mode and the
    # pattern that detects fields which need quotes.
    def prepare_format
      @column_separator = @options[:column_separator].to_s.encode(@encoding)
      row_separator = @options[:row_separator]
      if row_separator == :auto
        @row_separator = $INPUT_RECORD_SEPARATOR.encode(@encoding)
      else
        @row_separator = row_separator.to_s.encode(@encoding)
      end
      @quote_character = @options[:quote_character]

      # :force_quotes is either a list of fields to always quote or a flag
      # that applies to every field.
      force_quotes = @options[:force_quotes]
      if force_quotes.is_a?(Array)
        prepare_force_quotes_fields(force_quotes)
        @force_quotes = false
      elsif force_quotes
        @force_quotes_fields = nil
        @force_quotes = true
      else
        @force_quotes_fields = nil
        @force_quotes = false
      end

      # The pattern is only consulted when not every field is quoted. It
      # matches CR, LF, the column separator and the quote character.
      unless @force_quotes
        @quotable_pattern =
          Regexp.new("[\r\n".encode(@encoding) +
                     Regexp.escape(@column_separator) +
                     Regexp.escape(@quote_character.encode(@encoding)) +
                     "]".encode(@encoding))
      end
      @quote_empty = @options.fetch(:quote_empty, true)
    end

    # Decides how written lines relate to the encoding of a StringIO output.
    # Other kinds of output are left alone.
    #
    # When the encodings differ and +:force_encoding+ is set, every line is
    # transcoded to the output's encoding (@output_encoding). Otherwise, if
    # Encoding.compatible? finds a compatible encoding, the output is
    # switched to it and positioned at its end.
    def prepare_output
      @output_encoding = nil
      return unless @output.is_a?(StringIO)

      output_encoding = @output.internal_encoding || @output.external_encoding
      return if @encoding == output_encoding

      if @options[:force_encoding]
        @output_encoding = output_encoding
      else
        compatible_encoding = Encoding.compatible?(@encoding, output_encoding)
        if compatible_encoding
          @output.set_encoding(compatible_encoding)
          @output.seek(0, IO::SEEK_END)
        end
      end
    end

    # Returns +field+ as a String wrapped in quote characters, with every
    # quote character inside it doubled. The quote character is transcoded
    # to the encoding of the field first.
    def quote_field(field)
      field = String(field)
      encoded_quote_character = @quote_character.encode(field.encoding)
      encoded_quote_character +
        field.gsub(encoded_quote_character,
                   encoded_quote_character * 2) +
        encoded_quote_character
    end

    # Returns the text to write for +field+, the field at index +i+ of its
    # row, adding quotes when they are forced or needed.
    def quote(field, i)
      return quote_field(field) if @force_quotes
      if @force_quotes_fields && @force_quotes_fields[i]
        return quote_field(field)
      end

      # represent +nil+ fields as empty unquoted fields
      return "" if field.nil?

      field = String(field)  # Stringify fields
      # represent empty fields as empty quoted fields (unless :quote_empty is
      # disabled) and quote fields that contain special characters
      if (@quote_empty && field.empty?) ||
         (field.valid_encoding? && @quotable_pattern.match?(field))
        quote_field(field)
      else
        field  # unquoted field
      end
    end
  end
end