blob: 7f581650b38939c1b2c91d1ebe24633909f0acc5 [file] [log] [blame]
// Copyright 2016 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/spdy/bidirectional_stream_spdy_impl.h"
#include <string>
#include "base/containers/span.h"
#include "base/macros.h"
#include "base/run_loop.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_piece.h"
#include "base/time/time.h"
#include "base/timer/mock_timer.h"
#include "net/base/load_timing_info.h"
#include "net/base/load_timing_info_test_util.h"
#include "net/base/net_errors.h"
#include "net/http/http_request_info.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_response_info.h"
#include "net/log/test_net_log.h"
#include "net/socket/socket_tag.h"
#include "net/socket/socket_test_util.h"
#include "net/spdy/spdy_session.h"
#include "net/spdy/spdy_test_util_common.h"
#include "net/test/cert_test_util.h"
#include "net/test/gtest_util.h"
#include "net/test/test_data_directory.h"
#include "net/test/test_with_scoped_task_environment.h"
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
using net::test::IsError;
using net::test::IsOk;
namespace net {
namespace {
// Payload used as the request body by the tests in this file.
constexpr char kBodyData[] = "Body data";
// Number of bytes in |kBodyData|. This counts the terminating NUL as well; the
// tests build their write buffers and expected DATA frames from exactly this
// many bytes. sizeof() is the element count here because sizeof(char) == 1.
constexpr size_t kBodyDataSize = sizeof(kBodyData);
// Size of the buffer to be allocated for each read.
constexpr size_t kReadBufferSize = 4096;
// Checks the load timing of a stream that is connected but was not the first
// request sent on its connection: the socket must be reported as reused, it
// must carry a valid NetLog source id, and the connect timing is verified with
// ExpectConnectTimingHasNoTimes().
void TestLoadTimingReused(const LoadTimingInfo& load_timing_info) {
  const auto& connect_timing = load_timing_info.connect_timing;

  EXPECT_TRUE(load_timing_info.socket_reused);
  EXPECT_NE(NetLogSource::kInvalidId, load_timing_info.socket_log_id);

  ExpectConnectTimingHasNoTimes(connect_timing);
  ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
}
// Checks the load timing of a stream that is connected over a fresh
// connection: the socket must not be reported as reused, it must carry a valid
// NetLog source id, and the connect timing is verified with the SSL and DNS
// flags passed to ExpectConnectTimingHasTimes().
void TestLoadTimingNotReused(const LoadTimingInfo& load_timing_info) {
  const auto& connect_timing = load_timing_info.connect_timing;
  const int expected_connect_times =
      CONNECT_TIMING_HAS_SSL_TIMES | CONNECT_TIMING_HAS_DNS_TIMES;

  EXPECT_FALSE(load_timing_info.socket_reused);
  EXPECT_NE(NetLogSource::kInvalidId, load_timing_info.socket_log_id);

  ExpectConnectTimingHasTimes(connect_timing, expected_connect_times);
  ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
}
class TestDelegateBase : public BidirectionalStreamImpl::Delegate {
public:
TestDelegateBase(base::WeakPtr<SpdySession> session,
IOBuffer* read_buf,
int read_buf_len)
: stream_(std::make_unique<BidirectionalStreamSpdyImpl>(session,
NetLogSource())),
read_buf_(read_buf),
read_buf_len_(read_buf_len),
loop_(nullptr),
error_(OK),
bytes_read_(0),
on_data_read_count_(0),
on_data_sent_count_(0),
do_not_start_read_(false),
run_until_completion_(false),
not_expect_callback_(false),
on_failed_called_(false) {}
~TestDelegateBase() override = default;
void OnStreamReady(bool request_headers_sent) override {
CHECK(!on_failed_called_);
}
void OnHeadersReceived(
const spdy::SpdyHeaderBlock& response_headers) override {
CHECK(!on_failed_called_);
CHECK(!not_expect_callback_);
response_headers_ = response_headers.Clone();
if (!do_not_start_read_)
StartOrContinueReading();
}
void OnDataRead(int bytes_read) override {
CHECK(!on_failed_called_);
CHECK(!not_expect_callback_);
on_data_read_count_++;
CHECK_GE(bytes_read, OK);
bytes_read_ += bytes_read;
data_received_.append(read_buf_->data(), bytes_read);
if (!do_not_start_read_)
StartOrContinueReading();
}
void OnDataSent() override {
CHECK(!on_failed_called_);
CHECK(!not_expect_callback_);
on_data_sent_count_++;
}
void OnTrailersReceived(const spdy::SpdyHeaderBlock& trailers) override {
CHECK(!on_failed_called_);
trailers_ = trailers.Clone();
if (run_until_completion_)
loop_->Quit();
}
void OnFailed(int error) override {
CHECK(!on_failed_called_);
CHECK(!not_expect_callback_);
CHECK_NE(OK, error);
error_ = error;
on_failed_called_ = true;
if (run_until_completion_)
loop_->Quit();
}
void Start(const BidirectionalStreamRequestInfo* request,
const NetLogWithSource& net_log) {
stream_->Start(request, net_log,
/*send_request_headers_automatically=*/false, this,
std::make_unique<base::OneShotTimer>(),
TRAFFIC_ANNOTATION_FOR_TESTS);
not_expect_callback_ = false;
}
void SendData(IOBuffer* data, int length, bool end_of_stream) {
SendvData({data}, {length}, end_of_stream);
}
void SendvData(const std::vector<scoped_refptr<IOBuffer>>& data,
const std::vector<int>& length,
bool end_of_stream) {
not_expect_callback_ = true;
stream_->SendvData(data, length, end_of_stream);
not_expect_callback_ = false;
}
// Sets whether the delegate should wait until the completion of the stream.
void SetRunUntilCompletion(bool run_until_completion) {
run_until_completion_ = run_until_completion;
loop_ = std::make_unique<base::RunLoop>();
}
// Wait until the stream reaches completion.
void WaitUntilCompletion() { loop_->Run(); }
// Starts or continues read data from |stream_| until there is no more
// byte can be read synchronously.
void StartOrContinueReading() {
int rv = ReadData();
while (rv > 0) {
rv = ReadData();
}
if (run_until_completion_ && rv == 0)
loop_->Quit();
}
// Calls ReadData on the |stream_| and updates internal states.
int ReadData() {
int rv = stream_->ReadData(read_buf_.get(), read_buf_len_);
if (rv > 0) {
data_received_.append(read_buf_->data(), rv);
bytes_read_ += rv;
}
return rv;
}
NextProto GetProtocol() const { return stream_->GetProtocol(); }
int64_t GetTotalReceivedBytes() const {
return stream_->GetTotalReceivedBytes();
}
int64_t GetTotalSentBytes() const {
return stream_->GetTotalSentBytes();
}
bool GetLoadTimingInfo(LoadTimingInfo* load_timing_info) const {
return stream_->GetLoadTimingInfo(load_timing_info);
}
// Const getters for internal states.
const std::string& data_received() const { return data_received_; }
int bytes_read() const { return bytes_read_; }
int error() const { return error_; }
const spdy::SpdyHeaderBlock& response_headers() const {
return response_headers_;
}
const spdy::SpdyHeaderBlock& trailers() const { return trailers_; }
int on_data_read_count() const { return on_data_read_count_; }
int on_data_sent_count() const { return on_data_sent_count_; }
bool on_failed_called() const { return on_failed_called_; }
// Sets whether the delegate should automatically start reading.
void set_do_not_start_read(bool do_not_start_read) {
do_not_start_read_ = do_not_start_read;
}
private:
std::unique_ptr<BidirectionalStreamSpdyImpl> stream_;
scoped_refptr<IOBuffer> read_buf_;
int read_buf_len_;
std::string data_received_;
std::unique_ptr<base::RunLoop> loop_;
spdy::SpdyHeaderBlock response_headers_;
spdy::SpdyHeaderBlock trailers_;
int error_;
int bytes_read_;
int on_data_read_count_;
int on_data_sent_count_;
bool do_not_start_read_;
bool run_until_completion_;
bool not_expect_callback_;
bool on_failed_called_;
DISALLOW_COPY_AND_ASSIGN(TestDelegateBase);
};
} // namespace
// Fixture for the tests in this file. It prepares an SSL socket data provider
// that negotiates HTTP/2 and offers InitSession() to create a SpdySession on
// top of scripted socket reads and writes. TearDown() checks that the whole
// script was consumed.
class BidirectionalStreamSpdyImplTest : public testing::TestWithParam<bool>,
                                        public WithScopedTaskEnvironment {
 public:
  BidirectionalStreamSpdyImplTest()
      : default_url_(kDefaultUrl),
        host_port_pair_(HostPortPair::FromURL(default_url_)),
        key_(host_port_pair_,
             ProxyServer::Direct(),
             PRIVACY_MODE_DISABLED,
             SocketTag()),
        ssl_data_(ASYNC, OK) {
    ssl_data_.next_proto = kProtoHTTP2;
    ssl_data_.ssl_info.cert =
        ImportCertFromFile(GetTestCertsDirectory(), "ok_cert.pem");
  }

 protected:
  // If a session was initialized, expects every scripted read and write to
  // have been consumed by the test.
  void TearDown() override {
    if (!sequenced_data_)
      return;
    EXPECT_TRUE(sequenced_data_->AllReadDataConsumed());
    EXPECT_TRUE(sequenced_data_->AllWriteDataConsumed());
  }

  // Initializes the session using a SequencedSocketData built from |reads| and
  // |writes|. Requires the test certificate to have been loaded by the
  // constructor.
  void InitSession(base::span<const MockRead> reads,
                   base::span<const MockWrite> writes) {
    ASSERT_TRUE(ssl_data_.ssl_info.cert.get());

    auto& socket_factory = session_deps_.socket_factory;
    socket_factory->AddSSLSocketDataProvider(&ssl_data_);
    sequenced_data_ = std::make_unique<SequencedSocketData>(reads, writes);
    socket_factory->AddSocketDataProvider(sequenced_data_.get());

    session_deps_.net_log = net_log_.bound().net_log();
    http_session_ = SpdySessionDependencies::SpdyCreateSession(&session_deps_);
    session_ = CreateSpdySession(http_session_.get(), key_, net_log_.bound());
  }

  BoundTestNetLog net_log_;
  SpdyTestUtil spdy_util_;
  SpdySessionDependencies session_deps_;
  const GURL default_url_;
  const HostPortPair host_port_pair_;
  const SpdySessionKey key_;
  // Null until InitSession() is called.
  std::unique_ptr<SequencedSocketData> sequenced_data_;
  std::unique_ptr<HttpNetworkSession> http_session_;
  base::WeakPtr<SpdySession> session_;

 private:
  SSLSocketDataProvider ssl_data_;
};
// Sends a POST whose body is written while the socket is paused after the
// response headers, then checks load timing, callback counts, the negotiated
// protocol and the sent/received byte totals.
TEST_F(BidirectionalStreamSpdyImplTest, SimplePostRequest) {
  // Frames the stream is expected to write: request headers, then the body
  // with FIN set.
  spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyPost(
      kDefaultUrl, 1, kBodyDataSize, LOW, nullptr, 0));
  spdy::SpdySerializedFrame request_body(spdy_util_.ConstructSpdyDataFrame(
      1, base::StringPiece(kBodyData, kBodyDataSize), /*fin=*/true));
  MockWrite writes[] = {
      CreateMockWrite(request_headers, 0),
      CreateMockWrite(request_body, 3),
  };

  // Frames fed to the stream: response headers, a pause, the response body
  // with FIN set, then EOF.
  spdy::SpdySerializedFrame response_headers(
      spdy_util_.ConstructSpdyPostReply(nullptr, 0));
  spdy::SpdySerializedFrame response_body(
      spdy_util_.ConstructSpdyDataFrame(1, /*fin=*/true));
  MockRead reads[] = {
      CreateMockRead(response_headers, 1),
      MockRead(ASYNC, ERR_IO_PENDING, 2),  // Force a pause.
      CreateMockRead(response_body, 4),
      MockRead(ASYNC, 0, 5),
  };
  InitSession(reads, writes);

  BidirectionalStreamRequestInfo request_info;
  request_info.method = "POST";
  request_info.url = default_url_;
  request_info.extra_headers.SetHeader(net::HttpRequestHeaders::kContentLength,
                                       base::NumberToString(kBodyDataSize));

  auto read_buffer = base::MakeRefCounted<IOBuffer>(kReadBufferSize);
  auto delegate = std::make_unique<TestDelegateBase>(
      session_, read_buffer.get(), kReadBufferSize);
  delegate->SetRunUntilCompletion(true);
  delegate->Start(&request_info, net_log_.bound());
  sequenced_data_->RunUntilPaused();

  // Write the whole body, ending the stream, while the socket data is paused.
  auto write_buffer = base::MakeRefCounted<StringIOBuffer>(
      std::string(kBodyData, kBodyDataSize));
  delegate->SendData(write_buffer.get(), write_buffer->size(), true);
  sequenced_data_->Resume();
  base::RunLoop().RunUntilIdle();
  delegate->WaitUntilCompletion();

  LoadTimingInfo load_timing_info;
  EXPECT_TRUE(delegate->GetLoadTimingInfo(&load_timing_info));
  TestLoadTimingNotReused(load_timing_info);

  EXPECT_EQ(1, delegate->on_data_read_count());
  EXPECT_EQ(1, delegate->on_data_sent_count());
  EXPECT_EQ(kProtoHTTP2, delegate->GetProtocol());
  EXPECT_EQ(CountWriteBytes(writes), delegate->GetTotalSentBytes());
  EXPECT_EQ(CountReadBytes(reads), delegate->GetTotalReceivedBytes());
}
// Starts two GET streams on the same session and checks that the first one
// reports load timing for a fresh connection while the second one reports a
// reused connection.
TEST_F(BidirectionalStreamSpdyImplTest, LoadTimingTwoRequests) {
  // Request headers for streams 1 and 3.
  spdy::SpdySerializedFrame first_request(
      spdy_util_.ConstructSpdyGet(nullptr, 0, /*stream_id=*/1, LOW));
  spdy::SpdySerializedFrame second_request(
      spdy_util_.ConstructSpdyGet(nullptr, 0, /*stream_id=*/3, LOW));
  MockWrite writes[] = {
      CreateMockWrite(first_request, 0),
      CreateMockWrite(second_request, 2),
  };

  // Response headers and FIN-carrying body frames for both streams.
  spdy::SpdySerializedFrame first_reply(
      spdy_util_.ConstructSpdyGetReply(nullptr, 0, /*stream_id=*/1));
  spdy::SpdySerializedFrame second_reply(
      spdy_util_.ConstructSpdyGetReply(nullptr, 0, /*stream_id=*/3));
  spdy::SpdySerializedFrame first_reply_body(
      spdy_util_.ConstructSpdyDataFrame(/*stream_id=*/1, /*fin=*/true));
  spdy::SpdySerializedFrame second_reply_body(
      spdy_util_.ConstructSpdyDataFrame(/*stream_id=*/3, /*fin=*/true));
  MockRead reads[] = {
      CreateMockRead(first_reply, 1),
      CreateMockRead(first_reply_body, 3),
      CreateMockRead(second_reply, 4),
      CreateMockRead(second_reply_body, 5),
      MockRead(ASYNC, 0, 6),
  };
  InitSession(reads, writes);

  // Both streams share the same request description.
  BidirectionalStreamRequestInfo request_info;
  request_info.method = "GET";
  request_info.url = default_url_;
  request_info.end_stream_on_headers = true;

  auto first_read_buffer = base::MakeRefCounted<IOBuffer>(kReadBufferSize);
  auto second_read_buffer = base::MakeRefCounted<IOBuffer>(kReadBufferSize);
  auto first_delegate = std::make_unique<TestDelegateBase>(
      session_, first_read_buffer.get(), kReadBufferSize);
  auto second_delegate = std::make_unique<TestDelegateBase>(
      session_, second_read_buffer.get(), kReadBufferSize);
  first_delegate->SetRunUntilCompletion(true);
  second_delegate->SetRunUntilCompletion(true);
  first_delegate->Start(&request_info, net_log_.bound());
  second_delegate->Start(&request_info, net_log_.bound());

  base::RunLoop().RunUntilIdle();
  first_delegate->WaitUntilCompletion();
  second_delegate->WaitUntilCompletion();

  LoadTimingInfo first_timing;
  EXPECT_TRUE(first_delegate->GetLoadTimingInfo(&first_timing));
  TestLoadTimingNotReused(first_timing);

  LoadTimingInfo second_timing;
  EXPECT_TRUE(second_delegate->GetLoadTimingInfo(&second_timing));
  TestLoadTimingReused(second_timing);
}
// Checks that calling SendData() on a stream that has already failed does not
// trigger any further delegate callback.
TEST_F(BidirectionalStreamSpdyImplTest, SendDataAfterStreamFailed) {
  // The stream writes its request headers and is then expected to reset
  // itself with PROTOCOL_ERROR.
  spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyPost(
      kDefaultUrl, 1, kBodyDataSize * 3, LOW, nullptr, 0));
  spdy::SpdySerializedFrame reset_frame(
      spdy_util_.ConstructSpdyRstStream(1, spdy::ERROR_CODE_PROTOCOL_ERROR));
  MockWrite writes[] = {
      CreateMockWrite(request_headers, 0),
      CreateMockWrite(reset_frame, 2),
  };

  // NOTE(review): presumably the upper-case header name is what makes the
  // stream fail with a protocol error - confirm against SpdyStream.
  const char* const kExtraHeaders[] = {"X-UpperCase", "yes"};
  spdy::SpdySerializedFrame response_headers(
      spdy_util_.ConstructSpdyGetReply(kExtraHeaders, 1, 1));
  MockRead reads[] = {
      CreateMockRead(response_headers, 1),
      MockRead(ASYNC, 0, 3),
  };
  InitSession(reads, writes);

  BidirectionalStreamRequestInfo request_info;
  request_info.method = "POST";
  request_info.url = default_url_;
  request_info.extra_headers.SetHeader(net::HttpRequestHeaders::kContentLength,
                                       base::NumberToString(kBodyDataSize * 3));

  auto read_buffer = base::MakeRefCounted<IOBuffer>(kReadBufferSize);
  auto delegate = std::make_unique<TestDelegateBase>(
      session_, read_buffer.get(), kReadBufferSize);
  delegate->SetRunUntilCompletion(true);
  delegate->Start(&request_info, net_log_.bound());
  base::RunLoop().RunUntilIdle();

  EXPECT_TRUE(delegate->on_failed_called());

  // Try to send data after OnFailed(), should not get called back.
  auto late_data = base::MakeRefCounted<StringIOBuffer>("dummy");
  delegate->SendData(late_data.get(), late_data->size(), false);
  base::RunLoop().RunUntilIdle();

  EXPECT_THAT(delegate->error(), IsError(ERR_SPDY_PROTOCOL_ERROR));
  EXPECT_EQ(0, delegate->on_data_read_count());
  EXPECT_EQ(0, delegate->on_data_sent_count());
  EXPECT_EQ(kProtoHTTP2, delegate->GetProtocol());
  // The bytes sent for |reset_frame| are not counted because it is sent after
  // SpdyStream::Delegate::OnClose is called, so only the first write counts.
  EXPECT_EQ(CountWriteBytes(base::make_span(writes, 1)),
            delegate->GetTotalSentBytes());
  EXPECT_EQ(0, delegate->GetTotalReceivedBytes());
}
// Instantiates the TEST_P tests of BidirectionalStreamSpdyImplTest once for
// each value of the fixture's bool parameter. TEST_P bodies read the value
// through GetParam(); RstWithNoErrorBeforeSendIsComplete uses it to choose
// between SendvData() (true) and SendData() (false).
INSTANTIATE_TEST_CASE_P(BidirectionalStreamSpdyImplTests,
BidirectionalStreamSpdyImplTest,
::testing::Bool());
// Tests that when a RST_STREAM with NO_ERROR is received, BidirectionalStream
// does not crash while processing pending writes. See crbug.com/650438.
// The test parameter selects SendvData() (true) or SendData() (false) for the
// writes issued after the reset.
TEST_P(BidirectionalStreamSpdyImplTest, RstWithNoErrorBeforeSendIsComplete) {
  const bool use_sendv = GetParam();
  // Expected OnDataSent() count once all writes have been issued.
  const int expected_sent_count = use_sendv ? 2 : 4;

  // Only the request headers are expected on the wire.
  spdy::SpdySerializedFrame request_headers(spdy_util_.ConstructSpdyPost(
      kDefaultUrl, 1, kBodyDataSize * 3, LOW, nullptr, 0));
  MockWrite writes[] = {CreateMockWrite(request_headers, 0)};

  // Response headers, a pause, RST_STREAM with NO_ERROR, then EOF.
  spdy::SpdySerializedFrame response_headers(
      spdy_util_.ConstructSpdyPostReply(nullptr, 0));
  spdy::SpdySerializedFrame reset_frame(
      spdy_util_.ConstructSpdyRstStream(1, spdy::ERROR_CODE_NO_ERROR));
  MockRead reads[] = {
      CreateMockRead(response_headers, 1),
      MockRead(ASYNC, ERR_IO_PENDING, 2),  // Force a pause.
      CreateMockRead(reset_frame, 3),
      MockRead(ASYNC, 0, 4),
  };
  InitSession(reads, writes);

  BidirectionalStreamRequestInfo request_info;
  request_info.method = "POST";
  request_info.url = default_url_;
  request_info.extra_headers.SetHeader(net::HttpRequestHeaders::kContentLength,
                                       base::NumberToString(kBodyDataSize * 3));

  auto read_buffer = base::MakeRefCounted<IOBuffer>(kReadBufferSize);
  auto delegate = std::make_unique<TestDelegateBase>(
      session_, read_buffer.get(), kReadBufferSize);
  delegate->SetRunUntilCompletion(true);
  delegate->Start(&request_info, net_log_.bound());
  sequenced_data_->RunUntilPaused();

  // Make a write pending before receiving RST_STREAM.
  auto write_buffer = base::MakeRefCounted<StringIOBuffer>(
      std::string(kBodyData, kBodyDataSize));
  delegate->SendData(write_buffer.get(), write_buffer->size(), false);
  sequenced_data_->Resume();
  base::RunLoop().RunUntilIdle();

  // Make sure OnClose() without an error completes any pending write().
  EXPECT_EQ(1, delegate->on_data_sent_count());
  EXPECT_FALSE(delegate->on_failed_called());

  // Issue three more buffers, the last one ending the stream, either in a
  // single vectored call or one call per buffer.
  if (use_sendv) {
    std::vector<scoped_refptr<IOBuffer>> three_buffers = {
        write_buffer.get(), write_buffer.get(), write_buffer.get()};
    std::vector<int> three_lengths = {
        write_buffer->size(), write_buffer->size(), write_buffer->size()};
    delegate->SendvData(three_buffers, three_lengths, /*end_of_stream=*/true);
    base::RunLoop().RunUntilIdle();
  } else {
    for (size_t i = 0; i < 3; ++i) {
      const bool is_last = (i == 2);
      delegate->SendData(write_buffer.get(), write_buffer->size(),
                         /*end_of_stream=*/is_last);
      base::RunLoop().RunUntilIdle();
    }
  }
  delegate->WaitUntilCompletion();

  LoadTimingInfo load_timing_info;
  EXPECT_TRUE(delegate->GetLoadTimingInfo(&load_timing_info));
  TestLoadTimingNotReused(load_timing_info);

  EXPECT_THAT(delegate->error(), IsError(OK));
  EXPECT_EQ(1, delegate->on_data_read_count());
  EXPECT_EQ(expected_sent_count, delegate->on_data_sent_count());
  EXPECT_EQ(kProtoHTTP2, delegate->GetProtocol());
  EXPECT_EQ(CountWriteBytes(base::make_span(writes, 1)),
            delegate->GetTotalSentBytes());
  // Should not count RST stream: the last two reads (RST_STREAM and EOF) are
  // left out of the expected total.
  EXPECT_EQ(CountReadBytes(base::make_span(reads).first(base::size(reads) - 2)),
            delegate->GetTotalReceivedBytes());

  // Now call SendData again should produce an error because end of stream
  // flag has been written.
  if (use_sendv) {
    std::vector<scoped_refptr<IOBuffer>> one_buffer = {write_buffer.get()};
    std::vector<int> one_length = {write_buffer->size()};
    delegate->SendvData(one_buffer, one_length, true);
  } else {
    delegate->SendData(write_buffer.get(), write_buffer->size(), true);
  }
  base::RunLoop().RunUntilIdle();

  EXPECT_THAT(delegate->error(), IsError(ERR_UNEXPECTED));
  EXPECT_TRUE(delegate->on_failed_called());
  // The failed write must not have been reported as sent.
  EXPECT_EQ(expected_sent_count, delegate->on_data_sent_count());
}
} // namespace net