|  | // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 
|  | // Use of this source code is governed by a BSD-style license that can be | 
|  | // found in the LICENSE file. | 
|  |  | 
|  | #include "net/test/embedded_test_server/embedded_test_server.h" | 
|  |  | 
|  | #include <utility> | 
|  |  | 
|  | #include "base/bind.h" | 
|  | #include "base/files/file_path.h" | 
|  | #include "base/files/file_util.h" | 
|  | #include "base/location.h" | 
|  | #include "base/logging.h" | 
|  | #include "base/message_loop/message_loop.h" | 
|  | #include "base/message_loop/message_loop_current.h" | 
|  | #include "base/path_service.h" | 
|  | #include "base/process/process_metrics.h" | 
|  | #include "base/run_loop.h" | 
|  | #include "base/strings/string_util.h" | 
|  | #include "base/strings/stringprintf.h" | 
|  | #include "base/threading/thread_restrictions.h" | 
|  | #include "base/threading/thread_task_runner_handle.h" | 
|  | #include "crypto/rsa_private_key.h" | 
|  | #include "net/base/ip_endpoint.h" | 
|  | #include "net/base/net_errors.h" | 
|  | #include "net/base/port_util.h" | 
|  | #include "net/cert/pem_tokenizer.h" | 
|  | #include "net/cert/test_root_certs.h" | 
|  | #include "net/log/net_log_source.h" | 
|  | #include "net/socket/ssl_server_socket.h" | 
|  | #include "net/socket/stream_socket.h" | 
|  | #include "net/socket/tcp_server_socket.h" | 
|  | #include "net/ssl/ssl_info.h" | 
|  | #include "net/ssl/ssl_server_config.h" | 
|  | #include "net/test/cert_test_util.h" | 
|  | #include "net/test/embedded_test_server/default_handlers.h" | 
|  | #include "net/test/embedded_test_server/embedded_test_server_connection_listener.h" | 
|  | #include "net/test/embedded_test_server/http_connection.h" | 
|  | #include "net/test/embedded_test_server/http_request.h" | 
|  | #include "net/test/embedded_test_server/http_response.h" | 
|  | #include "net/test/embedded_test_server/request_handler_util.h" | 
|  | #include "net/test/test_data_directory.h" | 
|  |  | 
|  | namespace net { | 
|  | namespace test_server { | 
|  |  | 
|  | EmbeddedTestServer::EmbeddedTestServer() : EmbeddedTestServer(TYPE_HTTP) {} | 
|  |  | 
|  | EmbeddedTestServer::EmbeddedTestServer(Type type) | 
|  | : is_using_ssl_(type == TYPE_HTTPS), | 
|  | connection_listener_(nullptr), | 
|  | port_(0), | 
|  | cert_(CERT_OK), | 
|  | weak_factory_(this) { | 
|  | DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); | 
|  |  | 
|  | if (!is_using_ssl_) | 
|  | return; | 
|  | RegisterTestCerts(); | 
|  | } | 
|  |  | 
|  | EmbeddedTestServer::~EmbeddedTestServer() { | 
|  | DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); | 
|  |  | 
|  | if (Started() && !ShutdownAndWaitUntilComplete()) { | 
|  | LOG(ERROR) << "EmbeddedTestServer failed to shut down."; | 
|  | } | 
|  |  | 
|  | { | 
|  | // Thread::Join induced by test code should cause an assert. | 
|  | base::ScopedAllowBlockingForTesting allow_blocking; | 
|  |  | 
|  | io_thread_.reset(); | 
|  | } | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::RegisterTestCerts() { | 
|  | base::ScopedAllowBlockingForTesting allow_blocking; | 
|  | TestRootCerts* root_certs = TestRootCerts::GetInstance(); | 
|  | bool added_root_certs = root_certs->AddFromFile(GetRootCertPemPath()); | 
|  | DCHECK(added_root_certs) | 
|  | << "Failed to install root cert from EmbeddedTestServer"; | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::SetConnectionListener( | 
|  | EmbeddedTestServerConnectionListener* listener) { | 
|  | DCHECK(!io_thread_.get()) | 
|  | << "ConnectionListener must be set before starting the server."; | 
|  | connection_listener_ = listener; | 
|  | } | 
|  |  | 
|  | bool EmbeddedTestServer::Start(int port) { | 
|  | bool success = InitializeAndListen(port); | 
|  | if (!success) | 
|  | return false; | 
|  | StartAcceptingConnections(); | 
|  | return true; | 
|  | } | 
|  |  | 
|  | bool EmbeddedTestServer::InitializeAndListen(int port) { | 
|  | DCHECK(!Started()); | 
|  |  | 
|  | const int max_tries = 5; | 
|  | int num_tries = 0; | 
|  | bool is_valid_port = false; | 
|  |  | 
|  | do { | 
|  | if (++num_tries > max_tries) { | 
|  | LOG(ERROR) << "Failed to listen on a valid port after " << max_tries | 
|  | << " attempts."; | 
|  | listen_socket_.reset(); | 
|  | return false; | 
|  | } | 
|  |  | 
|  | listen_socket_.reset(new TCPServerSocket(nullptr, NetLogSource())); | 
|  |  | 
|  | int result = | 
|  | listen_socket_->ListenWithAddressAndPort("127.0.0.1", port, 10); | 
|  | if (result) { | 
|  | LOG(ERROR) << "Listen failed: " << ErrorToString(result); | 
|  | listen_socket_.reset(); | 
|  | return false; | 
|  | } | 
|  |  | 
|  | result = listen_socket_->GetLocalAddress(&local_endpoint_); | 
|  | if (result != OK) { | 
|  | LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result); | 
|  | listen_socket_.reset(); | 
|  | return false; | 
|  | } | 
|  |  | 
|  | port_ = local_endpoint_.port(); | 
|  | is_valid_port |= net::IsPortAllowedForScheme( | 
|  | port_, is_using_ssl_ ? url::kHttpsScheme : url::kHttpScheme); | 
|  | } while (!is_valid_port); | 
|  |  | 
|  | if (is_using_ssl_) { | 
|  | base_url_ = GURL("https://" + local_endpoint_.ToString()); | 
|  | if (cert_ == CERT_MISMATCHED_NAME || cert_ == CERT_COMMON_NAME_IS_DOMAIN) { | 
|  | base_url_ = GURL( | 
|  | base::StringPrintf("https://localhost:%d", local_endpoint_.port())); | 
|  | } | 
|  | } else { | 
|  | base_url_ = GURL("http://" + local_endpoint_.ToString()); | 
|  | } | 
|  |  | 
|  | listen_socket_->DetachFromThread(); | 
|  |  | 
|  | if (is_using_ssl_) | 
|  | InitializeSSLServerContext(); | 
|  | return true; | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::InitializeSSLServerContext() { | 
|  | base::ScopedAllowBlockingForTesting allow_blocking; | 
|  | base::FilePath certs_dir(GetTestCertsDirectory()); | 
|  | std::string cert_name = GetCertificateName(); | 
|  |  | 
|  | base::FilePath key_path = certs_dir.AppendASCII(cert_name); | 
|  | std::string key_string; | 
|  | CHECK(base::ReadFileToString(key_path, &key_string)); | 
|  | std::vector<std::string> headers; | 
|  | headers.push_back("PRIVATE KEY"); | 
|  | PEMTokenizer pem_tokenizer(key_string, headers); | 
|  | pem_tokenizer.GetNext(); | 
|  | std::vector<uint8_t> key_vector; | 
|  | key_vector.assign(pem_tokenizer.data().begin(), pem_tokenizer.data().end()); | 
|  |  | 
|  | std::unique_ptr<crypto::RSAPrivateKey> server_key( | 
|  | crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); | 
|  | context_ = | 
|  | CreateSSLServerContext(GetCertificate().get(), *server_key, ssl_config_); | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::StartAcceptingConnections() { | 
|  | DCHECK(!io_thread_.get()) | 
|  | << "Server must not be started while server is running"; | 
|  | base::Thread::Options thread_options; | 
|  | thread_options.message_loop_type = base::MessageLoop::TYPE_IO; | 
|  | io_thread_.reset(new base::Thread("EmbeddedTestServer IO Thread")); | 
|  | #if defined(STARBOARD) | 
|  | thread_options.stack_size = base::kUnitTestStackSize; | 
|  | #endif | 
|  | CHECK(io_thread_->StartWithOptions(thread_options)); | 
|  | CHECK(io_thread_->WaitUntilThreadStarted()); | 
|  |  | 
|  | io_thread_->task_runner()->PostTask( | 
|  | FROM_HERE, | 
|  | base::Bind(&EmbeddedTestServer::DoAcceptLoop, base::Unretained(this))); | 
|  | } | 
|  |  | 
|  | bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() { | 
|  | DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); | 
|  |  | 
|  | return PostTaskToIOThreadAndWait(base::Bind( | 
|  | &EmbeddedTestServer::ShutdownOnIOThread, base::Unretained(this))); | 
|  | } | 
|  |  | 
|  | // static | 
|  | base::FilePath EmbeddedTestServer::GetRootCertPemPath() { | 
|  | return GetTestCertsDirectory().AppendASCII("root_ca_cert.pem"); | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::ShutdownOnIOThread() { | 
|  | DCHECK(io_thread_->task_runner()->BelongsToCurrentThread()); | 
|  | weak_factory_.InvalidateWeakPtrs(); | 
|  | listen_socket_.reset(); | 
|  | connections_.clear(); | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::HandleRequest(HttpConnection* connection, | 
|  | std::unique_ptr<HttpRequest> request) { | 
|  | DCHECK(io_thread_->task_runner()->BelongsToCurrentThread()); | 
|  | request->base_url = base_url_; | 
|  |  | 
|  | SSLInfo ssl_info; | 
|  | if (connection->socket_->GetSSLInfo(&ssl_info) && | 
|  | ssl_info.early_data_received) { | 
|  | request->headers["Early-Data"] = "1"; | 
|  | } | 
|  |  | 
|  | for (const auto& monitor : request_monitors_) | 
|  | monitor.Run(*request); | 
|  |  | 
|  | std::unique_ptr<HttpResponse> response; | 
|  |  | 
|  | for (const auto& handler : request_handlers_) { | 
|  | response = handler.Run(*request); | 
|  | if (response) | 
|  | break; | 
|  | } | 
|  |  | 
|  | if (!response) { | 
|  | for (const auto& handler : default_request_handlers_) { | 
|  | response = handler.Run(*request); | 
|  | if (response) | 
|  | break; | 
|  | } | 
|  | } | 
|  |  | 
|  | if (!response) { | 
|  | LOG(WARNING) << "Request not handled. Returning 404: " | 
|  | << request->relative_url; | 
|  | std::unique_ptr<BasicHttpResponse> not_found_response( | 
|  | new BasicHttpResponse); | 
|  | not_found_response->set_code(HTTP_NOT_FOUND); | 
|  | response = std::move(not_found_response); | 
|  | } | 
|  |  | 
|  | response->SendResponse( | 
|  | base::Bind(&HttpConnection::SendResponseBytes, connection->GetWeakPtr()), | 
|  | base::Bind(&EmbeddedTestServer::DidClose, weak_factory_.GetWeakPtr(), | 
|  | connection)); | 
|  | } | 
|  |  | 
|  | GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const { | 
|  | DCHECK(Started()) << "You must start the server first."; | 
|  | DCHECK(base::StartsWith(relative_url, "/", base::CompareCase::SENSITIVE)) | 
|  | << relative_url; | 
|  | return base_url_.Resolve(relative_url); | 
|  | } | 
|  |  | 
|  | GURL EmbeddedTestServer::GetURL( | 
|  | const std::string& hostname, | 
|  | const std::string& relative_url) const { | 
|  | GURL local_url = GetURL(relative_url); | 
|  | GURL::Replacements replace_host; | 
|  | replace_host.SetHostStr(hostname); | 
|  | return local_url.ReplaceComponents(replace_host); | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::SetSSLConfig(ServerCertificate cert, | 
|  | const SSLServerConfig& ssl_config) { | 
|  | DCHECK(!Started()); | 
|  | cert_ = cert; | 
|  | ssl_config_ = ssl_config; | 
|  | } | 
|  |  | 
|  | bool EmbeddedTestServer::GetAddressList(AddressList* address_list) const { | 
|  | *address_list = AddressList(local_endpoint_); | 
|  | return true; | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::ResetSSLConfigOnIOThread( | 
|  | ServerCertificate cert, | 
|  | const SSLServerConfig& ssl_config) { | 
|  | cert_ = cert; | 
|  | ssl_config_ = ssl_config; | 
|  | connections_.clear(); | 
|  | InitializeSSLServerContext(); | 
|  | } | 
|  |  | 
|  | bool EmbeddedTestServer::ResetSSLConfig(ServerCertificate cert, | 
|  | const SSLServerConfig& ssl_config) { | 
|  | return PostTaskToIOThreadAndWait( | 
|  | base::BindRepeating(&EmbeddedTestServer::ResetSSLConfigOnIOThread, | 
|  | base::Unretained(this), cert, ssl_config)); | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::SetSSLConfig(ServerCertificate cert) { | 
|  | SetSSLConfig(cert, SSLServerConfig()); | 
|  | } | 
|  |  | 
|  | std::string EmbeddedTestServer::GetCertificateName() const { | 
|  | DCHECK(is_using_ssl_); | 
|  | switch (cert_) { | 
|  | case CERT_OK: | 
|  | case CERT_MISMATCHED_NAME: | 
|  | return "ok_cert.pem"; | 
|  | case CERT_COMMON_NAME_IS_DOMAIN: | 
|  | return "localhost_cert.pem"; | 
|  | case CERT_EXPIRED: | 
|  | return "expired_cert.pem"; | 
|  | case CERT_COMMON_NAME_ONLY: | 
|  | return "common_name_only.pem"; | 
|  | case CERT_SHA1_LEAF: | 
|  | return "sha1_leaf.pem"; | 
|  | } | 
|  |  | 
|  | return "ok_cert.pem"; | 
|  | } | 
|  |  | 
|  | scoped_refptr<X509Certificate> EmbeddedTestServer::GetCertificate() const { | 
|  | DCHECK(is_using_ssl_); | 
|  | base::FilePath certs_dir(GetTestCertsDirectory()); | 
|  |  | 
|  | base::ScopedAllowBlockingForTesting allow_blocking; | 
|  | return ImportCertFromFile(certs_dir, GetCertificateName()); | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::ServeFilesFromDirectory( | 
|  | const base::FilePath& directory) { | 
|  | RegisterDefaultHandler(base::Bind(&HandleFileRequest, directory)); | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::ServeFilesFromSourceDirectory( | 
|  | const std::string& relative) { | 
|  | base::FilePath test_data_dir; | 
|  | CHECK(base::PathService::Get(base::DIR_TEST_DATA, &test_data_dir)); | 
|  | ServeFilesFromDirectory(test_data_dir.AppendASCII(relative)); | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::ServeFilesFromSourceDirectory( | 
|  | const base::FilePath& relative) { | 
|  | base::FilePath test_data_dir; | 
|  | CHECK(base::PathService::Get(base::DIR_TEST_DATA, &test_data_dir)); | 
|  | ServeFilesFromDirectory(test_data_dir.Append(relative)); | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::AddDefaultHandlers(const base::FilePath& directory) { | 
|  | ServeFilesFromSourceDirectory(directory); | 
|  | RegisterDefaultHandlers(this); | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::RegisterRequestHandler( | 
|  | const HandleRequestCallback& callback) { | 
|  | DCHECK(!io_thread_.get()) | 
|  | << "Handlers must be registered before starting the server."; | 
|  | request_handlers_.push_back(callback); | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::RegisterRequestMonitor( | 
|  | const MonitorRequestCallback& callback) { | 
|  | DCHECK(!io_thread_.get()) | 
|  | << "Monitors must be registered before starting the server."; | 
|  | request_monitors_.push_back(callback); | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::RegisterDefaultHandler( | 
|  | const HandleRequestCallback& callback) { | 
|  | DCHECK(!io_thread_.get()) | 
|  | << "Handlers must be registered before starting the server."; | 
|  | default_request_handlers_.push_back(callback); | 
|  | } | 
|  |  | 
|  | std::unique_ptr<StreamSocket> EmbeddedTestServer::DoSSLUpgrade( | 
|  | std::unique_ptr<StreamSocket> connection) { | 
|  | DCHECK(io_thread_->task_runner()->BelongsToCurrentThread()); | 
|  |  | 
|  | return context_->CreateSSLServerSocket(std::move(connection)); | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::DoAcceptLoop() { | 
|  | while ( | 
|  | listen_socket_->Accept(&accepted_socket_, | 
|  | base::Bind(&EmbeddedTestServer::OnAcceptCompleted, | 
|  | base::Unretained(this))) == OK) { | 
|  | HandleAcceptResult(std::move(accepted_socket_)); | 
|  | } | 
|  | } | 
|  |  | 
|  | bool EmbeddedTestServer::FlushAllSocketsAndConnectionsOnUIThread() { | 
|  | return PostTaskToIOThreadAndWait( | 
|  | base::Bind(&EmbeddedTestServer::FlushAllSocketsAndConnections, | 
|  | base::Unretained(this))); | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::FlushAllSocketsAndConnections() { | 
|  | connections_.clear(); | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::OnAcceptCompleted(int rv) { | 
|  | DCHECK_NE(ERR_IO_PENDING, rv); | 
|  | HandleAcceptResult(std::move(accepted_socket_)); | 
|  | DoAcceptLoop(); | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::OnHandshakeDone(HttpConnection* connection, int rv) { | 
|  | if (connection->socket_->IsConnected()) | 
|  | ReadData(connection); | 
|  | else | 
|  | DidClose(connection); | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::HandleAcceptResult( | 
|  | std::unique_ptr<StreamSocket> socket) { | 
|  | DCHECK(io_thread_->task_runner()->BelongsToCurrentThread()); | 
|  | if (connection_listener_) | 
|  | connection_listener_->AcceptedSocket(*socket); | 
|  |  | 
|  | if (is_using_ssl_) | 
|  | socket = DoSSLUpgrade(std::move(socket)); | 
|  |  | 
|  | std::unique_ptr<HttpConnection> http_connection_ptr = | 
|  | std::make_unique<HttpConnection>( | 
|  | std::move(socket), base::Bind(&EmbeddedTestServer::HandleRequest, | 
|  | base::Unretained(this))); | 
|  | HttpConnection* http_connection = http_connection_ptr.get(); | 
|  | connections_[http_connection->socket_.get()] = std::move(http_connection_ptr); | 
|  |  | 
|  | if (is_using_ssl_) { | 
|  | SSLServerSocket* ssl_socket = | 
|  | static_cast<SSLServerSocket*>(http_connection->socket_.get()); | 
|  | int rv = ssl_socket->Handshake( | 
|  | base::Bind(&EmbeddedTestServer::OnHandshakeDone, base::Unretained(this), | 
|  | http_connection)); | 
|  | if (rv != ERR_IO_PENDING) | 
|  | OnHandshakeDone(http_connection, rv); | 
|  | } else { | 
|  | ReadData(http_connection); | 
|  | } | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::ReadData(HttpConnection* connection) { | 
|  | while (true) { | 
|  | int rv = | 
|  | connection->ReadData(base::Bind(&EmbeddedTestServer::OnReadCompleted, | 
|  | base::Unretained(this), connection)); | 
|  | if (rv == ERR_IO_PENDING) | 
|  | return; | 
|  | if (!HandleReadResult(connection, rv)) | 
|  | return; | 
|  | } | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::OnReadCompleted(HttpConnection* connection, int rv) { | 
|  | DCHECK_NE(ERR_IO_PENDING, rv); | 
|  | if (HandleReadResult(connection, rv)) | 
|  | ReadData(connection); | 
|  | } | 
|  |  | 
|  | bool EmbeddedTestServer::HandleReadResult(HttpConnection* connection, int rv) { | 
|  | DCHECK(io_thread_->task_runner()->BelongsToCurrentThread()); | 
|  | if (connection_listener_) | 
|  | connection_listener_->ReadFromSocket(*connection->socket_, rv); | 
|  | if (rv <= 0) { | 
|  | DidClose(connection); | 
|  | return false; | 
|  | } | 
|  |  | 
|  | // Once a single complete request has been received, there is no further need | 
|  | // for the connection and it may be destroyed once the response has been sent. | 
|  | if (connection->ConsumeData(rv)) | 
|  | return false; | 
|  |  | 
|  | return true; | 
|  | } | 
|  |  | 
|  | void EmbeddedTestServer::DidClose(HttpConnection* connection) { | 
|  | DCHECK(io_thread_->task_runner()->BelongsToCurrentThread()); | 
|  | DCHECK(connection); | 
|  | DCHECK_EQ(1u, connections_.count(connection->socket_.get())); | 
|  |  | 
|  | connections_.erase(connection->socket_.get()); | 
|  | } | 
|  |  | 
|  | HttpConnection* EmbeddedTestServer::FindConnection(StreamSocket* socket) { | 
|  | DCHECK(io_thread_->task_runner()->BelongsToCurrentThread()); | 
|  |  | 
|  | auto it = connections_.find(socket); | 
|  | if (it == connections_.end()) { | 
|  | return nullptr; | 
|  | } | 
|  | return it->second.get(); | 
|  | } | 
|  |  | 
|  | bool EmbeddedTestServer::PostTaskToIOThreadAndWait( | 
|  | const base::Closure& closure) { | 
|  | // Note that PostTaskAndReply below requires | 
|  | // base::ThreadTaskRunnerHandle::Get() to return a task runner for posting | 
|  | // the reply task. However, in order to make EmbeddedTestServer universally | 
|  | // usable, it needs to cope with the situation where it's running on a thread | 
|  | // on which a message loop is not (yet) available or as has been destroyed | 
|  | // already. | 
|  | // | 
|  | // To handle this situation, create temporary message loop to support the | 
|  | // PostTaskAndReply operation if the current thread as no message loop. | 
|  | std::unique_ptr<base::MessageLoop> temporary_loop; | 
|  | if (!base::MessageLoopCurrent::Get()) | 
|  | temporary_loop.reset(new base::MessageLoop()); | 
|  |  | 
|  | base::RunLoop run_loop; | 
|  | if (!io_thread_->task_runner()->PostTaskAndReply(FROM_HERE, closure, | 
|  | run_loop.QuitClosure())) { | 
|  | return false; | 
|  | } | 
|  | run_loop.Run(); | 
|  |  | 
|  | return true; | 
|  | } | 
|  |  | 
|  | }  // namespace test_server | 
|  | }  // namespace net |