# cgit navigation: summary | refs | log | tree | commit | diff
# path: root/lib/drb/ssl.rb
# blob: 58d6b7d1e0bd537d759decb73268ed4771e8a970 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
require 'socket'
require 'openssl'
require 'drb/drb'
require 'singleton'

module DRb

  class DRbSSLSocket < DRbTCPSocket

    # Holds the SSL options for a DRbSSLSocket and builds the
    # OpenSSL::SSL::SSLContext that is used to wrap plain TCP sockets.
    #
    # Option values are read from the hash given to ::new and fall back to
    # DEFAULT (see #[]).
    class SSLConfig

      # Fallback option values consulted by #[].
      DEFAULT = {
        :SSLCertificate       => nil,
        :SSLPrivateKey        => nil,
        :SSLClientCA          => nil,
        :SSLCACertificatePath => nil,
        :SSLCACertificateFile => nil,
        :SSLVerifyMode        => ::OpenSSL::SSL::VERIFY_NONE,
        :SSLVerifyDepth       => nil,
        :SSLVerifyCallback    => nil,   # custom verification
        :SSLCertificateStore  => nil,
        # Must specify if you use auto generated certificate.
        :SSLCertName          => nil,   # e.g. [["CN","fqdn.example.com"]]
        :SSLCertComment       => "Generated by Ruby/OpenSSL"
      }

      # config:: Hash-like object of options; :SSLCertificate and
      #          :SSLPrivateKey are read immediately, everything else is
      #          looked up lazily through #[].
      def initialize(config)
        @config  = config
        @cert    = config[:SSLCertificate]
        @pkey    = config[:SSLPrivateKey]
        @ssl_ctx = nil
      end

      # Returns the configured value for +key+, or the DEFAULT entry when the
      # configured value is nil or false.
      def [](key)
        @config[key] || DEFAULT[key]
      end

      # Wraps +tcp+ in an SSLSocket using the current context, performs the
      # client side of the handshake and returns the SSLSocket.
      def connect(tcp)
        ssl = ::OpenSSL::SSL::SSLSocket.new(tcp, @ssl_ctx)
        ssl.sync = true
        ssl.connect
        ssl
      end

      # Wraps +tcp+ in an SSLSocket using the current context, performs the
      # server side of the handshake and returns the SSLSocket.
      def accept(tcp)
        ssl = ::OpenSSL::SSL::SSLSocket.new(tcp, @ssl_ctx)
        ssl.sync = true
        ssl.accept
        ssl
      end

      # Makes sure a certificate and a private key are available.
      #
      # Does nothing when both were supplied through the config.  Otherwise
      # (either one missing) a new RSA key and a self-signed certificate,
      # valid for 365 days and named after :SSLCertName, replace both.
      def setup_certificate
        return if @cert && @pkey

        # 2048 bits: smaller keys (the former 512 bits) are trivially
        # breakable and are refused by current OpenSSL security levels.
        rsa = OpenSSL::PKey::RSA.new(2048){|p, n|
          next unless self[:verbose]
          case p
          when 0; $stderr.putc "."  # BN_generate_prime
          when 1; $stderr.putc "+"  # BN_generate_prime
          when 2; $stderr.putc "*"  # searching good prime,
                                    # n = #of try,
                                    # but also data from BN_generate_prime
          when 3; $stderr.putc "\n" # found good prime, n==0 - p, n==1 - q,
                                    # but also data from BN_generate_prime
          else;   $stderr.putc "*"  # BN_generate_prime
          end
        }

        cert = OpenSSL::X509::Certificate.new
        # The version field is zero-based: 2 means X.509v3, which is
        # required for the extensions added below.
        cert.version = 2
        cert.serial = 0
        name = OpenSSL::X509::Name.new(self[:SSLCertName])
        # Self-signed: subject and issuer are the same name.
        cert.subject = name
        cert.issuer = name
        cert.not_before = Time.now
        cert.not_after = Time.now + (365*24*60*60)
        cert.public_key = rsa.public_key

        ef = OpenSSL::X509::ExtensionFactory.new(nil,cert)
        cert.extensions = [
          ef.create_extension("basicConstraints","CA:FALSE"),
          ef.create_extension("subjectKeyIdentifier", "hash") ]
        # The authority key identifier is derived from the issuer
        # certificate, which for a self-signed certificate is cert itself.
        ef.issuer_certificate = cert
        cert.add_extension(ef.create_extension("authorityKeyIdentifier",
                                               "keyid:always,issuer:always"))
        if comment = self[:SSLCertComment]
          cert.add_extension(ef.create_extension("nsComment", comment))
        end
        # SHA-256 instead of SHA-1, which is no longer accepted for
        # certificate signatures by current TLS stacks.
        cert.sign(rsa, OpenSSL::Digest::SHA256.new)

        @cert = cert
        @pkey = rsa
      end

      # Builds the SSLContext used by #connect and #accept from the current
      # certificate, key and the remaining SSL* options.
      def setup_ssl_context
        ctx = ::OpenSSL::SSL::SSLContext.new
        ctx.cert            = @cert
        ctx.key             = @pkey
        ctx.client_ca       = self[:SSLClientCA]
        ctx.ca_path         = self[:SSLCACertificatePath]
        ctx.ca_file         = self[:SSLCACertificateFile]
        ctx.verify_mode     = self[:SSLVerifyMode]
        ctx.verify_depth    = self[:SSLVerifyDepth]
        ctx.verify_callback = self[:SSLVerifyCallback]
        ctx.cert_store      = self[:SSLCertificateStore]
        @ssl_ctx = ctx
      end
    end

    # Splits a URI of the form "drbssl://host:port?option" into its parts.
    #
    # Returns [host, port, option]: host is a String (possibly empty), port
    # an Integer and option the text after "?" or nil when absent.
    #
    # Raises DRbBadScheme when +uri+ does not start with "drbssl:" and
    # DRbBadURI when it does but the rest cannot be parsed.
    def self.parse_uri(uri)
      # \A and \z anchor to the whole string; ^ and $ would match any single
      # line of a multi-line string and let trailing garbage through.
      if uri =~ /\Adrbssl:\/\/(.*?):(\d+)(\?(.*))?\z/
        host = $1
        port = $2.to_i
        option = $4
        [host, port, option]
      else
        raise(DRbBadScheme, uri) unless uri =~ /\Adrbssl:/
        raise(DRbBadURI, 'can\'t parse uri:' + uri)
      end
    end

    # Opens a client connection to the server named by +uri+.
    #
    # Connects a TCP socket to the host and port parsed from +uri+, builds
    # an SSLConfig from +config+, performs the SSL handshake and returns a
    # new instance wrapping the established SSL socket.
    #
    # Raises whatever parse_uri, TCPSocket.open or the handshake raise.
    def self.open(uri, config)
      host, port, = parse_uri(uri)
      # Object#untaint no longer exists on Ruby 3.2+; call it only where it
      # is still available.
      host.untaint if host.respond_to?(:untaint)
      port.untaint if port.respond_to?(:untaint)
      soc = TCPSocket.open(host, port)
      begin
        ssl_conf = SSLConfig::new(config)
        ssl_conf.setup_ssl_context
        ssl = ssl_conf.connect(soc)
        self.new(uri, ssl, ssl_conf, true)
      rescue Exception
        # Do not leak the TCP connection when SSL setup or the handshake
        # fails; the error itself is passed on unchanged.
        soc.close
        raise
      end
    end

    # Opens a listening server socket for +uri+ (default "drbssl://:0").
    #
    # With an empty host the name from getservername is used and the socket
    # comes from open_server_inaddr_any; otherwise TCPServer.open binds to
    # the given host.  When the requested port is 0 the port reported by the
    # socket is used in the resulting URI.  An SSLConfig is built from
    # +config+ (generating a certificate if none was supplied) and a new,
    # not yet SSL-established instance is returned.
    def self.open_server(uri, config)
      uri = 'drbssl://:0' unless uri
      host, port, = parse_uri(uri)
      if host.size == 0
        host = getservername
        soc = open_server_inaddr_any(host, port)
      else
        soc = TCPServer.open(host, port)
      end
      begin
        port = soc.addr[1] if port == 0
        # NOTE(review): this is an instance variable of the class object,
        # not of the socket created below; kept as in the original.
        @uri = "drbssl://#{host}:#{port}"

        ssl_conf = SSLConfig.new(config)
        ssl_conf.setup_certificate
        ssl_conf.setup_ssl_context
        self.new(@uri, soc, ssl_conf, false)
      rescue Exception
        # Release the listening socket when certificate or context setup
        # fails, instead of keeping the port bound until GC.
        soc.close
        raise
      end
    end

    # Separates +uri+ into its base URI and its option string.
    #
    # Returns a two-element Array: "drbssl://host:port" rebuilt from the
    # parsed host and port, and the option part (nil when +uri+ has none).
    # +config+ is accepted but not used here.
    def self.uri_option(uri, config)
      parsed_host, parsed_port, query = parse_uri(uri)
      base = "drbssl://#{parsed_host}:#{parsed_port}"
      [base, query]
    end

    # uri::            URI string handed on to the superclass.
    # soc::            an SSL socket (is_established true) or a plain server
    #                  socket (is_established false); its #to_io is handed
    #                  on to the superclass.
    # config::         the SSLConfig handed on to the superclass.
    # is_established:: when true, +soc+ is remembered as the SSL stream.
    def initialize(uri, soc, config, is_established)
      # Only an already established connection has an SSL stream to keep.
      if is_established
        @ssl = soc
      else
        @ssl = nil
      end
      super(uri, soc.to_io, config)
    end
    
    # Returns the object stored in @ssl (the SSL socket of an established
    # connection), or nil when there is none.
    def stream
      @ssl
    end

    # Closes the SSL stream, if there is one, and forgets it; then lets the
    # superclass perform its own close.
    def close
      if (active = @ssl)
        active.close
        @ssl = nil
      end
      super
    end
      
    # Waits for the next acceptable client and returns a new instance of
    # this class wrapping the established SSL connection.
    #
    # Connections refused by @acl (when one is set) are closed and skipped.
    # When the SSL handshake fails with OpenSSL::SSL::SSLError the client
    # socket is closed, a warning is printed if the :verbose option is set,
    # and the method goes back to waiting for the next client.
    def accept
      begin
        while true
          soc = @socket.accept
          break if (@acl ? @acl.allow_socket?(soc) : true)
          soc.close
        end
        begin
          ssl = @config.accept(soc)
        rescue Exception
          # The handshake failed: close the accepted TCP socket so that
          # each bad client does not leak a file descriptor.
          soc.close
          raise
        end
        self.class.new(uri, ssl, @config, true)
      rescue OpenSSL::SSL::SSLError
        warn("#{__FILE__}:#{__LINE__}: warning: #{$!.message} (#{$!.class})") if @config[:verbose]
        # Deliberately unbounded: a failed handshake from one client must
        # not stop the server from accepting the next one.
        retry
      end
    end
  end
  
  # Hand the DRbSSLSocket class to DRbProtocol.add_protocol when this file is
  # loaded -- presumably so that "drbssl://" URIs are dispatched to it;
  # confirm against DRbProtocol.
  DRbProtocol.add_protocol(DRbSSLSocket)
end