| // Copyright 2021 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "net/test/embedded_test_server/connection_tracker.h" |
| |
| #include "base/containers/contains.h" |
| #include "base/run_loop.h" |
| #include "base/task/single_thread_task_runner.h" |
| #include "net/test/embedded_test_server/embedded_test_server.h" |
| #include "testing/gtest/include/gtest/gtest.h" |
| |
| namespace { |
| |
| bool GetPort(const net::StreamSocket& connection, uint16_t* port) { |
| // Gets the remote port of the peer, since the local port will always be |
| // the port the test server is listening on. This isn't strictly correct - |
| // it's possible for multiple peers to connect with the same remote port |
| // but different remote IPs - but the tests here assume that connections |
| // to the test server (running on localhost) will always come from |
| // localhost, and thus the peer port is all that's needed to distinguish |
| // two connections. This also would be problematic if the OS reused ports, |
| // but that's not something to worry about for these tests. |
| net::IPEndPoint address; |
| int result = connection.GetPeerAddress(&address); |
| if (result != net::OK) |
| return false; |
| *port = address.port(); |
| return true; |
| } |
| |
| } // namespace |
| |
| namespace net::test_server { |
| |
| ConnectionTracker::ConnectionTracker(EmbeddedTestServer* test_server) |
| : connection_listener_(this) { |
| test_server->SetConnectionListener(&connection_listener_); |
| } |
| |
| ConnectionTracker::~ConnectionTracker() = default; |
| |
| void ConnectionTracker::AcceptedSocketWithPort(uint16_t port) { |
| num_connected_sockets_++; |
| sockets_[port] = SocketStatus::kAccepted; |
| CheckAccepted(); |
| } |
| |
| void ConnectionTracker::ReadFromSocketWithPort(uint16_t port) { |
| EXPECT_TRUE(base::Contains(sockets_, port)); |
| if (sockets_[port] == SocketStatus::kAccepted) |
| num_read_sockets_++; |
| sockets_[port] = SocketStatus::kReadFrom; |
| if (read_loop_) { |
| read_loop_->Quit(); |
| read_loop_ = nullptr; |
| } |
| } |
| |
| // Returns the number of sockets that were accepted by the server. |
| size_t ConnectionTracker::GetAcceptedSocketCount() const { |
| return num_connected_sockets_; |
| } |
| |
| // Returns the number of sockets that were read from by the server. |
| size_t ConnectionTracker::GetReadSocketCount() const { |
| return num_read_sockets_; |
| } |
| |
| void ConnectionTracker::WaitUntilConnectionRead() { |
| base::RunLoop run_loop; |
| read_loop_ = &run_loop; |
| read_loop_->Run(); |
| } |
| |
| // This will wait for exactly |num_connections| items in |sockets_|. This method |
| // expects the server will not accept more than |num_connections| connections. |
| // |num_connections| must be greater than 0. |
| void ConnectionTracker::WaitForAcceptedConnections(size_t num_connections) { |
| DCHECK(!num_accepted_connections_loop_); |
| DCHECK_GT(num_connections, 0u); |
| base::RunLoop run_loop; |
| EXPECT_GE(num_connections, num_connected_sockets_); |
| num_accepted_connections_loop_ = &run_loop; |
| num_accepted_connections_needed_ = num_connections; |
| CheckAccepted(); |
| // Note that the previous call to CheckAccepted can quit this run loop |
| // before this call, which will make this call a no-op. |
| run_loop.Run(); |
| EXPECT_EQ(num_connections, num_connected_sockets_); |
| } |
| |
| // Helper function to stop the waiting for sockets to be accepted for |
| // WaitForAcceptedConnections. |num_accepted_connections_loop_| spins |
| // until |num_accepted_connections_needed_| sockets are accepted by the test |
| // server. The values will be null/0 if the loop is not running. |
| void ConnectionTracker::CheckAccepted() { |
| // |num_accepted_connections_loop_| null implies |
| // |num_accepted_connections_needed_| == 0. |
| DCHECK(num_accepted_connections_loop_ || |
| num_accepted_connections_needed_ == 0); |
| if (!num_accepted_connections_loop_ || |
| num_accepted_connections_needed_ != num_connected_sockets_) { |
| return; |
| } |
| |
| num_accepted_connections_loop_->Quit(); |
| num_accepted_connections_needed_ = 0; |
| num_accepted_connections_loop_ = nullptr; |
| } |
| |
| void ConnectionTracker::ResetCounts() { |
| sockets_.clear(); |
| num_connected_sockets_ = 0; |
| num_read_sockets_ = 0; |
| } |
| |
| ConnectionTracker::ConnectionListener::ConnectionListener( |
| ConnectionTracker* tracker) |
| : task_runner_(base::SingleThreadTaskRunner::GetCurrentDefault()), |
| tracker_(tracker) {} |
| |
| ConnectionTracker::ConnectionListener::~ConnectionListener() = default; |
| |
| // Gets called from the EmbeddedTestServer thread to be notified that |
| // a connection was accepted. |
| std::unique_ptr<net::StreamSocket> |
| ConnectionTracker::ConnectionListener::AcceptedSocket( |
| std::unique_ptr<net::StreamSocket> connection) { |
| uint16_t port; |
| if (GetPort(*connection, &port)) { |
| task_runner_->PostTask( |
| FROM_HERE, base::BindOnce(&ConnectionTracker::AcceptedSocketWithPort, |
| base::Unretained(tracker_), port)); |
| } |
| return connection; |
| } |
| |
| // Gets called from the EmbeddedTestServer thread to be notified that |
| // a connection was read from. |
| void ConnectionTracker::ConnectionListener::ReadFromSocket( |
| const net::StreamSocket& connection, |
| int rv) { |
| // Don't log a read if no data was transferred. This case often happens if |
| // the sockets of the test server are being flushed and disconnected. |
| if (rv <= 0) |
| return; |
| uint16_t port; |
| if (GetPort(connection, &port)) { |
| task_runner_->PostTask( |
| FROM_HERE, base::BindOnce(&ConnectionTracker::ReadFromSocketWithPort, |
| base::Unretained(tracker_), port)); |
| } |
| } |
| |
| } // namespace net::test_server |