blob: 06dc980cd723a69a5699f463246684c54509029b [file] [log] [blame]
Andrew Top0d1858f2019-05-15 22:01:47 -07001// Copyright 2018 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/websockets/websocket_basic_handshake_stream.h"
6
7#include <string>
8#include <utility>
9#include <vector>
10
11#include "base/logging.h"
12#include "net/base/address_list.h"
13#include "net/base/ip_address.h"
14#include "net/base/ip_endpoint.h"
15#include "net/base/net_errors.h"
16#include "net/base/test_completion_callback.h"
17#include "net/http/http_request_info.h"
18#include "net/http/http_response_info.h"
19#include "net/log/net_log_with_source.h"
20#include "net/socket/client_socket_handle.h"
21#include "net/socket/socket_test_util.h"
22#include "net/socket/websocket_endpoint_lock_manager.h"
23#include "net/traffic_annotation/network_traffic_annotation.h"
24#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
25#include "net/websockets/websocket_test_util.h"
26#include "url/gurl.h"
27#include "url/origin.h"
28
29namespace net {
30namespace {
31
32TEST(WebSocketBasicHandshakeStreamTest, ConnectionClosedOnFailure) {
33 std::string request = WebSocketStandardRequest(
34 "/", "www.example.org",
35 url::Origin::Create(GURL("http://origin.example.org")), "", "");
36 std::string response =
37 "HTTP/1.1 404 Not Found\r\n"
38 "Content-Length: 0\r\n"
39 "\r\n";
40 MockWrite writes[] = {MockWrite(SYNCHRONOUS, 0, request.c_str())};
41 MockRead reads[] = {MockRead(SYNCHRONOUS, 1, response.c_str()),
42 MockRead(SYNCHRONOUS, ERR_IO_PENDING, 2)};
43 IPEndPoint end_point(IPAddress(127, 0, 0, 1), 80);
44 SequencedSocketData sequenced_socket_data(
45 MockConnect(SYNCHRONOUS, OK, end_point), reads, writes);
46 auto socket = std::make_unique<MockTCPClientSocket>(
47 AddressList(end_point), nullptr, &sequenced_socket_data);
48 const int connect_result = socket->Connect(CompletionOnceCallback());
49 EXPECT_EQ(connect_result, OK);
50 const MockTCPClientSocket* const socket_ptr = socket.get();
51 auto handle = std::make_unique<ClientSocketHandle>();
52 handle->SetSocket(std::move(socket));
53 DummyConnectDelegate delegate;
54 WebSocketEndpointLockManager endpoint_lock_manager;
55 TestWebSocketStreamRequestAPI stream_request_api;
56 std::vector<std::string> extensions = {
57 "permessage-deflate; client_max_window_bits"};
58 WebSocketBasicHandshakeStream basic_handshake_stream(
59 std::move(handle), &delegate, false, {}, extensions, &stream_request_api,
60 &endpoint_lock_manager);
61 basic_handshake_stream.SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
62 HttpRequestInfo request_info;
63 request_info.url = GURL("ws://www.example.com/");
64 request_info.method = "GET";
65 request_info.traffic_annotation =
66 MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS);
67 TestCompletionCallback callback1;
68 NetLogWithSource net_log;
69 const int result1 =
70 callback1.GetResult(basic_handshake_stream.InitializeStream(
71 &request_info, true, LOWEST, net_log, callback1.callback()));
72 EXPECT_EQ(result1, OK);
73
74 auto request_headers = WebSocketCommonTestHeaders();
75 HttpResponseInfo response_info;
76 TestCompletionCallback callback2;
77 const int result2 = callback2.GetResult(basic_handshake_stream.SendRequest(
78 request_headers, &response_info, callback2.callback()));
79 EXPECT_EQ(result2, OK);
80
81 TestCompletionCallback callback3;
82 const int result3 = callback3.GetResult(
83 basic_handshake_stream.ReadResponseHeaders(callback2.callback()));
84 EXPECT_EQ(result3, ERR_INVALID_RESPONSE);
85
86 EXPECT_FALSE(socket_ptr->IsConnected());
87}
88
89} // namespace
90} // namespace net