| // Copyright (c) 2013 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "net/socket/tcp_client_socket.h" |
| |
| #include <utility> |
| |
| #include "base/callback_helpers.h" |
| #include "base/logging.h" |
| #include "base/memory/ptr_util.h" |
| #include "base/metrics/histogram_macros.h" |
| #include "base/time/time.h" |
| #include "net/base/io_buffer.h" |
| #include "net/base/ip_endpoint.h" |
| #include "net/base/net_errors.h" |
| #include "net/socket/socket_performance_watcher.h" |
| #include "net/traffic_annotation/network_traffic_annotation.h" |
| |
| namespace net { |
| |
| class NetLogWithSource; |
| |
| TCPClientSocket::TCPClientSocket( |
| const AddressList& addresses, |
| std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher, |
| net::NetLog* net_log, |
| const net::NetLogSource& source) |
| : TCPClientSocket( |
| std::make_unique<TCPSocket>(std::move(socket_performance_watcher), |
| net_log, |
| source), |
| addresses, |
| -1 /* current_address_index */, |
| nullptr /* bind_address */) {} |
| |
| TCPClientSocket::TCPClientSocket(std::unique_ptr<TCPSocket> connected_socket, |
| const IPEndPoint& peer_address) |
| : TCPClientSocket(std::move(connected_socket), |
| AddressList(peer_address), |
| 0 /* current_address_index */, |
| nullptr /* bind_address */) {} |
| |
| TCPClientSocket::~TCPClientSocket() { |
| Disconnect(); |
| } |
| |
| std::unique_ptr<TCPClientSocket> TCPClientSocket::CreateFromBoundSocket( |
| std::unique_ptr<TCPSocket> bound_socket, |
| const AddressList& addresses, |
| const IPEndPoint& bound_address) { |
| return base::WrapUnique(new TCPClientSocket( |
| std::move(bound_socket), addresses, -1 /* current_address_index */, |
| std::make_unique<IPEndPoint>(bound_address))); |
| } |
| |
| int TCPClientSocket::Bind(const IPEndPoint& address) { |
| if (current_address_index_ >= 0 || bind_address_) { |
| // Cannot bind the socket if we are already connected or connecting. |
| NOTREACHED(); |
| return ERR_UNEXPECTED; |
| } |
| |
| int result = OK; |
| if (!socket_->IsValid()) { |
| result = OpenSocket(address.GetFamily()); |
| if (result != OK) |
| return result; |
| } |
| |
| result = socket_->Bind(address); |
| if (result != OK) |
| return result; |
| |
| bind_address_.reset(new IPEndPoint(address)); |
| return OK; |
| } |
| |
| bool TCPClientSocket::SetKeepAlive(bool enable, int delay) { |
| return socket_->SetKeepAlive(enable, delay); |
| } |
| |
| bool TCPClientSocket::SetNoDelay(bool no_delay) { |
| return socket_->SetNoDelay(no_delay); |
| } |
| |
| void TCPClientSocket::SetBeforeConnectCallback( |
| const BeforeConnectCallback& before_connect_callback) { |
| DCHECK_EQ(CONNECT_STATE_NONE, next_connect_state_); |
| before_connect_callback_ = before_connect_callback; |
| } |
| |
| int TCPClientSocket::Connect(CompletionOnceCallback callback) { |
| DCHECK(!callback.is_null()); |
| |
| // If connecting or already connected, then just return OK. |
| if (socket_->IsValid() && current_address_index_ >= 0) |
| return OK; |
| |
| socket_->StartLoggingMultipleConnectAttempts(addresses_); |
| |
| // We will try to connect to each address in addresses_. Start with the |
| // first one in the list. |
| next_connect_state_ = CONNECT_STATE_CONNECT; |
| current_address_index_ = 0; |
| |
| int rv = DoConnectLoop(OK); |
| if (rv == ERR_IO_PENDING) { |
| connect_callback_ = std::move(callback); |
| } else { |
| socket_->EndLoggingMultipleConnectAttempts(rv); |
| } |
| |
| return rv; |
| } |
| |
| TCPClientSocket::TCPClientSocket(std::unique_ptr<TCPSocket> socket, |
| const AddressList& addresses, |
| int current_address_index, |
| std::unique_ptr<IPEndPoint> bind_address) |
| : socket_(std::move(socket)), |
| bind_address_(std::move(bind_address)), |
| addresses_(addresses), |
| current_address_index_(-1), |
| next_connect_state_(CONNECT_STATE_NONE), |
| previously_disconnected_(false), |
| total_received_bytes_(0), |
| was_ever_used_(false) { |
| DCHECK(socket_); |
| if (socket_->IsValid()) |
| socket_->SetDefaultOptionsForClient(); |
| } |
| |
| int TCPClientSocket::ReadCommon(IOBuffer* buf, |
| int buf_len, |
| CompletionOnceCallback callback, |
| bool read_if_ready) { |
| DCHECK(!callback.is_null()); |
| |
| // |socket_| is owned by |this| and the callback won't be run once |socket_| |
| // is gone/closed. Therefore, it is safe to use base::Unretained() here. |
| CompletionOnceCallback read_callback = |
| base::BindOnce(&TCPClientSocket::DidCompleteRead, base::Unretained(this), |
| std::move(callback)); |
| int result = |
| read_if_ready |
| ? socket_->ReadIfReady(buf, buf_len, std::move(read_callback)) |
| : socket_->Read(buf, buf_len, std::move(read_callback)); |
| if (result > 0) { |
| was_ever_used_ = true; |
| total_received_bytes_ += result; |
| } |
| |
| return result; |
| } |
| |
| int TCPClientSocket::DoConnectLoop(int result) { |
| DCHECK_NE(next_connect_state_, CONNECT_STATE_NONE); |
| |
| int rv = result; |
| do { |
| ConnectState state = next_connect_state_; |
| next_connect_state_ = CONNECT_STATE_NONE; |
| switch (state) { |
| case CONNECT_STATE_CONNECT: |
| DCHECK_EQ(OK, rv); |
| rv = DoConnect(); |
| break; |
| case CONNECT_STATE_CONNECT_COMPLETE: |
| rv = DoConnectComplete(rv); |
| break; |
| default: |
| NOTREACHED() << "bad state " << state; |
| rv = ERR_UNEXPECTED; |
| break; |
| } |
| } while (rv != ERR_IO_PENDING && next_connect_state_ != CONNECT_STATE_NONE); |
| |
| return rv; |
| } |
| |
| int TCPClientSocket::DoConnect() { |
| DCHECK_GE(current_address_index_, 0); |
| DCHECK_LT(current_address_index_, static_cast<int>(addresses_.size())); |
| |
| const IPEndPoint& endpoint = addresses_[current_address_index_]; |
| |
| if (previously_disconnected_) { |
| was_ever_used_ = false; |
| connection_attempts_.clear(); |
| previously_disconnected_ = false; |
| } |
| |
| next_connect_state_ = CONNECT_STATE_CONNECT_COMPLETE; |
| |
| if (socket_->IsValid()) { |
| DCHECK(bind_address_); |
| } else { |
| int result = OpenSocket(endpoint.GetFamily()); |
| if (result != OK) |
| return result; |
| |
| if (bind_address_) { |
| result = socket_->Bind(*bind_address_); |
| if (result != OK) { |
| socket_->Close(); |
| return result; |
| } |
| } |
| } |
| |
| if (before_connect_callback_) { |
| int result = before_connect_callback_.Run(); |
| DCHECK_NE(ERR_IO_PENDING, result); |
| if (result != net::OK) |
| return result; |
| } |
| |
| // Notify |socket_performance_watcher_| only if the |socket_| is reused to |
| // connect to a different IP Address. |
| if (socket_->socket_performance_watcher() && current_address_index_ != 0) |
| socket_->socket_performance_watcher()->OnConnectionChanged(); |
| |
| // |socket_| is owned by this class and the callback won't be run once |
| // |socket_| is gone. Therefore, it is safe to use base::Unretained() here. |
| return socket_->Connect(endpoint, |
| base::Bind(&TCPClientSocket::DidCompleteConnect, |
| base::Unretained(this))); |
| } |
| |
| int TCPClientSocket::DoConnectComplete(int result) { |
| if (result == OK) |
| return OK; // Done! |
| |
| connection_attempts_.push_back( |
| ConnectionAttempt(addresses_[current_address_index_], result)); |
| |
| // Close whatever partially connected socket we currently have. |
| DoDisconnect(); |
| |
| // Try to fall back to the next address in the list. |
| if (current_address_index_ + 1 < static_cast<int>(addresses_.size())) { |
| next_connect_state_ = CONNECT_STATE_CONNECT; |
| ++current_address_index_; |
| return OK; |
| } |
| |
| // Otherwise there is nothing to fall back to, so give up. |
| return result; |
| } |
| |
| void TCPClientSocket::Disconnect() { |
| DoDisconnect(); |
| current_address_index_ = -1; |
| bind_address_.reset(); |
| } |
| |
| void TCPClientSocket::DoDisconnect() { |
| total_received_bytes_ = 0; |
| EmitTCPMetricsHistogramsOnDisconnect(); |
| // If connecting or already connected, record that the socket has been |
| // disconnected. |
| previously_disconnected_ = socket_->IsValid() && current_address_index_ >= 0; |
| socket_->Close(); |
| } |
| |
| bool TCPClientSocket::IsConnected() const { |
| return socket_->IsConnected(); |
| } |
| |
| bool TCPClientSocket::IsConnectedAndIdle() const { |
| return socket_->IsConnectedAndIdle(); |
| } |
| |
| int TCPClientSocket::GetPeerAddress(IPEndPoint* address) const { |
| return socket_->GetPeerAddress(address); |
| } |
| |
| int TCPClientSocket::GetLocalAddress(IPEndPoint* address) const { |
| DCHECK(address); |
| |
| if (!socket_->IsValid()) { |
| if (bind_address_) { |
| *address = *bind_address_; |
| return OK; |
| } |
| return ERR_SOCKET_NOT_CONNECTED; |
| } |
| |
| return socket_->GetLocalAddress(address); |
| } |
| |
| const NetLogWithSource& TCPClientSocket::NetLog() const { |
| return socket_->net_log(); |
| } |
| |
| bool TCPClientSocket::WasEverUsed() const { |
| return was_ever_used_; |
| } |
| |
| void TCPClientSocket::EnableTCPFastOpenIfSupported() { |
| socket_->EnableTCPFastOpenIfSupported(); |
| } |
| |
| bool TCPClientSocket::WasAlpnNegotiated() const { |
| return false; |
| } |
| |
| NextProto TCPClientSocket::GetNegotiatedProtocol() const { |
| return kProtoUnknown; |
| } |
| |
| bool TCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) { |
| return false; |
| } |
| |
| int TCPClientSocket::Read(IOBuffer* buf, |
| int buf_len, |
| CompletionOnceCallback callback) { |
| return ReadCommon(buf, buf_len, std::move(callback), /*read_if_ready=*/false); |
| } |
| |
| int TCPClientSocket::ReadIfReady(IOBuffer* buf, |
| int buf_len, |
| CompletionOnceCallback callback) { |
| return ReadCommon(buf, buf_len, std::move(callback), /*read_if_ready=*/true); |
| } |
| |
| int TCPClientSocket::CancelReadIfReady() { |
| return socket_->CancelReadIfReady(); |
| } |
| |
| int TCPClientSocket::Write( |
| IOBuffer* buf, |
| int buf_len, |
| CompletionOnceCallback callback, |
| const NetworkTrafficAnnotationTag& traffic_annotation) { |
| DCHECK(!callback.is_null()); |
| |
| // |socket_| is owned by this class and the callback won't be run once |
| // |socket_| is gone. Therefore, it is safe to use base::Unretained() here. |
| CompletionOnceCallback write_callback = |
| base::BindOnce(&TCPClientSocket::DidCompleteWrite, base::Unretained(this), |
| std::move(callback)); |
| int result = socket_->Write(buf, buf_len, std::move(write_callback), |
| traffic_annotation); |
| if (result > 0) |
| was_ever_used_ = true; |
| |
| return result; |
| } |
| |
| int TCPClientSocket::SetReceiveBufferSize(int32_t size) { |
| return socket_->SetReceiveBufferSize(size); |
| } |
| |
| int TCPClientSocket::SetSendBufferSize(int32_t size) { |
| return socket_->SetSendBufferSize(size); |
| } |
| |
| SocketDescriptor TCPClientSocket::SocketDescriptorForTesting() const { |
| return socket_->SocketDescriptorForTesting(); |
| } |
| |
| void TCPClientSocket::GetConnectionAttempts(ConnectionAttempts* out) const { |
| *out = connection_attempts_; |
| } |
| |
| void TCPClientSocket::ClearConnectionAttempts() { |
| connection_attempts_.clear(); |
| } |
| |
| void TCPClientSocket::AddConnectionAttempts( |
| const ConnectionAttempts& attempts) { |
| connection_attempts_.insert(connection_attempts_.begin(), attempts.begin(), |
| attempts.end()); |
| } |
| |
| int64_t TCPClientSocket::GetTotalReceivedBytes() const { |
| return total_received_bytes_; |
| } |
| |
| void TCPClientSocket::ApplySocketTag(const SocketTag& tag) { |
| socket_->ApplySocketTag(tag); |
| } |
| |
| void TCPClientSocket::DidCompleteConnect(int result) { |
| DCHECK_EQ(next_connect_state_, CONNECT_STATE_CONNECT_COMPLETE); |
| DCHECK_NE(result, ERR_IO_PENDING); |
| DCHECK(!connect_callback_.is_null()); |
| |
| result = DoConnectLoop(result); |
| if (result != ERR_IO_PENDING) { |
| socket_->EndLoggingMultipleConnectAttempts(result); |
| std::move(connect_callback_).Run(result); |
| } |
| } |
| |
| void TCPClientSocket::DidCompleteRead(CompletionOnceCallback callback, |
| int result) { |
| if (result > 0) |
| total_received_bytes_ += result; |
| |
| DidCompleteReadWrite(std::move(callback), result); |
| } |
| |
| void TCPClientSocket::DidCompleteWrite(CompletionOnceCallback callback, |
| int result) { |
| DidCompleteReadWrite(std::move(callback), result); |
| } |
| |
| void TCPClientSocket::DidCompleteReadWrite(CompletionOnceCallback callback, |
| int result) { |
| if (result > 0) |
| was_ever_used_ = true; |
| std::move(callback).Run(result); |
| } |
| |
| int TCPClientSocket::OpenSocket(AddressFamily family) { |
| DCHECK(!socket_->IsValid()); |
| |
| int result = socket_->Open(family); |
| if (result != OK) |
| return result; |
| |
| socket_->SetDefaultOptionsForClient(); |
| |
| return OK; |
| } |
| |
| void TCPClientSocket::EmitTCPMetricsHistogramsOnDisconnect() { |
| base::TimeDelta rtt; |
| if (socket_->GetEstimatedRoundTripTime(&rtt)) { |
| UMA_HISTOGRAM_CUSTOM_TIMES("Net.TcpRtt.AtDisconnect", rtt, |
| base::TimeDelta::FromMilliseconds(1), |
| base::TimeDelta::FromMinutes(10), 100); |
| } |
| } |
| |
| } // namespace net |