diff options
Diffstat (limited to 'test/net')
-rw-r--r-- | test/net/smtp/test_smtp.rb | 2 | ||||
-rw-r--r-- | test/net/smtp/test_sslcontext.rb | 128 | ||||
-rw-r--r-- | test/net/smtp/test_starttls.rb | 121 |
3 files changed, 250 insertions, 1 deletions
diff --git a/test/net/smtp/test_smtp.rb b/test/net/smtp/test_smtp.rb index fccf137cdc..af30bb7221 100644 --- a/test/net/smtp/test_smtp.rb +++ b/test/net/smtp/test_smtp.rb @@ -137,7 +137,7 @@ module Net smtp = Net::SMTP.new("localhost", servers[0].local_address.ip_port) smtp.enable_tls smtp.open_timeout = 1 - smtp.start do + smtp.start(tls_verify: false) do end ensure sock.close if sock diff --git a/test/net/smtp/test_sslcontext.rb b/test/net/smtp/test_sslcontext.rb new file mode 100644 index 0000000000..f3f3b347ad --- /dev/null +++ b/test/net/smtp/test_sslcontext.rb @@ -0,0 +1,128 @@ +require 'net/smtp' +require 'test/unit' + +module Net + class TestSSLContext < Test::Unit::TestCase + class MySMTP < SMTP + attr_reader :__ssl_context, :__tls_hostname + + def initialize(socket) + @fake_socket = socket + super("smtp.example.com") + end + + def tcp_socket(*) + @fake_socket + end + + def ssl_socket_connect(*) + end + + def tlsconnect(*) + super + @fake_socket + end + + def ssl_socket(socket, context) + @__ssl_context = context + s = super + hostname = @__tls_hostname = '' + s.define_singleton_method(:post_connection_check){ |name| hostname.replace(name) } + s + end + end + + def teardown + @server_thread&.exit + @server_socket&.close + @client_socket&.close + end + + def start_smtpd(starttls) + @server_socket, @client_socket = UNIXSocket.pair + @starttls_executed = false + @server_thread = Thread.new(@server_socket) do |s| + s.puts "220 fakeserver\r\n" + while cmd = s.gets&.chomp + case cmd + when /\AEHLO / + s.puts "250-fakeserver\r\n" + s.puts "250-STARTTLS\r\n" if starttls + s.puts "250 8BITMIME\r\n" + when /\ASTARTTLS/ + @starttls_executed = true + s.puts "220 2.0.0 Ready to start TLS\r\n" + else + raise "unsupported command: #{cmd}" + end + end + end + @client_socket + end + + def test_default + smtp = MySMTP.new(start_smtpd(true)) + smtp.start + assert_equal(OpenSSL::SSL::VERIFY_PEER, smtp.__ssl_context.verify_mode) + end + + def test_enable_tls + smtp = MySMTP.new(start_smtpd(true)) + context = OpenSSL::SSL::SSLContext.new + smtp.enable_tls(context) + smtp.start + assert_equal(context, smtp.__ssl_context) + end + + def test_enable_tls_before_disable_starttls + smtp = MySMTP.new(start_smtpd(true)) + context = OpenSSL::SSL::SSLContext.new + smtp.enable_tls(context) + smtp.disable_starttls + smtp.start + assert_equal(context, smtp.__ssl_context) + end + + def test_enable_starttls + smtp = MySMTP.new(start_smtpd(true)) + context = OpenSSL::SSL::SSLContext.new + smtp.enable_starttls(context) + smtp.start + assert_equal(context, smtp.__ssl_context) + end + + def test_enable_starttls_before_disable_tls + smtp = MySMTP.new(start_smtpd(true)) + context = OpenSSL::SSL::SSLContext.new + smtp.enable_starttls(context) + smtp.disable_tls + smtp.start + assert_equal(context, smtp.__ssl_context) + end + + def test_start_with_tls_verify_true + smtp = MySMTP.new(start_smtpd(true)) + smtp.start(tls_verify: true) + assert_equal(OpenSSL::SSL::VERIFY_PEER, smtp.__ssl_context.verify_mode) + end + + def test_start_with_tls_verify_false + smtp = MySMTP.new(start_smtpd(true)) + smtp.start(tls_verify: false) + assert_equal(OpenSSL::SSL::VERIFY_NONE, smtp.__ssl_context.verify_mode) + end + + def test_start_with_tls_hostname + smtp = MySMTP.new(start_smtpd(true)) + smtp.start(tls_hostname: "localhost") + assert_equal("localhost", smtp.__tls_hostname) + end + + def test_start_without_tls_hostname + smtp = MySMTP.new(start_smtpd(true)) + smtp.start + assert_equal("smtp.example.com", smtp.__tls_hostname) + end + + end +end diff --git a/test/net/smtp/test_starttls.rb b/test/net/smtp/test_starttls.rb new file mode 100644 index 0000000000..98835c952a --- /dev/null +++ b/test/net/smtp/test_starttls.rb @@ -0,0 +1,121 @@ +require 'net/smtp' +require 'test/unit' + +module Net + class TestStarttls < Test::Unit::TestCase + class MySMTP < SMTP + def initialize(socket) + @fake_socket = socket + super("smtp.example.com") + end + + def tcp_socket(*) + @fake_socket + end + + def tlsconnect(*) + @fake_socket + end + end + + def teardown + @server_thread&.exit + @server_socket&.close + @client_socket&.close + end + + def start_smtpd(starttls) + @server_socket, @client_socket = UNIXSocket.pair + @starttls_executed = false + @server_thread = Thread.new(@server_socket) do |s| + s.puts "220 fakeserver\r\n" + while cmd = s.gets&.chomp + case cmd + when /\AEHLO / + s.puts "250-fakeserver\r\n" + s.puts "250-STARTTLS\r\n" if starttls + s.puts "250 8BITMIME\r\n" + when /\ASTARTTLS/ + @starttls_executed = true + s.puts "220 2.0.0 Ready to start TLS\r\n" + else + raise "unsupported command: #{cmd}" + end + end + end + @client_socket + end + + def test_default_with_starttls_capable + smtp = MySMTP.new(start_smtpd(true)) + smtp.start + assert(@starttls_executed) + end + + def test_default_without_starttls_capable + smtp = MySMTP.new(start_smtpd(false)) + smtp.start + assert(!@starttls_executed) + end + + def test_enable_starttls_with_starttls_capable + smtp = MySMTP.new(start_smtpd(true)) + smtp.enable_starttls + smtp.start + assert(@starttls_executed) + end + + def test_enable_starttls_without_starttls_capable + smtp = MySMTP.new(start_smtpd(false)) + smtp.enable_starttls + err = assert_raise(Net::SMTPUnsupportedCommand) { smtp.start } + assert_equal("STARTTLS is not supported on this server", err.message) + end + + def test_enable_starttls_auto_with_starttls_capable + smtp = MySMTP.new(start_smtpd(true)) + smtp.enable_starttls_auto + smtp.start + assert(@starttls_executed) + end + + def test_tls_with_starttls_capable + smtp = MySMTP.new(start_smtpd(true)) + smtp.enable_tls + smtp.start + assert(!@starttls_executed) + end + + def test_tls_without_starttls_capable + smtp = MySMTP.new(start_smtpd(false)) + smtp.enable_tls + end + + def test_disable_starttls + smtp = MySMTP.new(start_smtpd(true)) + smtp.disable_starttls + smtp.start + assert(!@starttls_executed) + end + + def test_enable_tls_and_enable_starttls + smtp = MySMTP.new(start_smtpd(true)) + smtp.enable_tls + err = assert_raise(ArgumentError) { smtp.enable_starttls } + assert_equal("SMTPS and STARTTLS is exclusive", err.message) + end + + def test_enable_tls_and_enable_starttls_auto + smtp = MySMTP.new(start_smtpd(true)) + smtp.enable_tls + err = assert_raise(ArgumentError) { smtp.enable_starttls_auto } + assert_equal("SMTPS and STARTTLS is exclusive", err.message) + end + + def test_enable_starttls_and_enable_starttls_auto + smtp = MySMTP.new(start_smtpd(true)) + smtp.enable_starttls + assert_nothing_raised { smtp.enable_starttls_auto } + end + end +end |