blob: 991ea3147d1c407f0b72adbd1aaf3b896342500b [file] [log] [blame]
// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/websockets/websocket_handshake_handler.h"
#include "base/base64.h"
#include "base/md5.h"
#include "base/sha1.h"
#include "base/string_number_conversions.h"
#include "base/string_piece.h"
#include "base/string_util.h"
#include "base/stringprintf.h"
#include "googleurl/src/gurl.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_util.h"
namespace {
const size_t kRequestKey3Size = 8U;
const size_t kResponseKeySize = 16U;
// First version that introduced new WebSocket handshake which does not
// require sending "key3" or "response key" data after headers.
const int kMinVersionOfHybiNewHandshake = 4;
// Used when we calculate the value of Sec-WebSocket-Accept.
const char* const kWebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
void ParseHandshakeHeader(
const char* handshake_message, int len,
std::string* status_line,
std::string* headers) {
size_t i = base::StringPiece(handshake_message, len).find_first_of("\r\n");
if (i == base::StringPiece::npos) {
*status_line = std::string(handshake_message, len);
*headers = "";
return;
}
// |status_line| includes \r\n.
*status_line = std::string(handshake_message, i + 2);
int header_len = len - (i + 2) - 2;
if (header_len > 0) {
// |handshake_message| includes tailing \r\n\r\n.
// |headers| doesn't include 2nd \r\n.
*headers = std::string(handshake_message + i + 2, header_len);
} else {
*headers = "";
}
}
void FetchHeaders(const std::string& headers,
const char* const headers_to_get[],
size_t headers_to_get_len,
std::vector<std::string>* values) {
net::HttpUtil::HeadersIterator iter(headers.begin(), headers.end(), "\r\n");
while (iter.GetNext()) {
for (size_t i = 0; i < headers_to_get_len; i++) {
if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
headers_to_get[i])) {
values->push_back(iter.values());
}
}
}
}
bool GetHeaderName(std::string::const_iterator line_begin,
std::string::const_iterator line_end,
std::string::const_iterator* name_begin,
std::string::const_iterator* name_end) {
std::string::const_iterator colon = std::find(line_begin, line_end, ':');
if (colon == line_end) {
return false;
}
*name_begin = line_begin;
*name_end = colon;
if (*name_begin == *name_end || net::HttpUtil::IsLWS(**name_begin))
return false;
net::HttpUtil::TrimLWS(name_begin, name_end);
return true;
}
// Similar to HttpUtil::StripHeaders, but it preserves malformed headers, that
// is, lines that are not formatted as "<name>: <value>\r\n".
std::string FilterHeaders(
const std::string& headers,
const char* const headers_to_remove[],
size_t headers_to_remove_len) {
std::string filtered_headers;
StringTokenizer lines(headers.begin(), headers.end(), "\r\n");
while (lines.GetNext()) {
std::string::const_iterator line_begin = lines.token_begin();
std::string::const_iterator line_end = lines.token_end();
std::string::const_iterator name_begin;
std::string::const_iterator name_end;
bool should_remove = false;
if (GetHeaderName(line_begin, line_end, &name_begin, &name_end)) {
for (size_t i = 0; i < headers_to_remove_len; ++i) {
if (LowerCaseEqualsASCII(name_begin, name_end, headers_to_remove[i])) {
should_remove = true;
break;
}
}
}
if (!should_remove) {
filtered_headers.append(line_begin, line_end);
filtered_headers.append("\r\n");
}
}
return filtered_headers;
}
// Gets a key number from |key| and appends the number to |challenge|.
// The key number (/part_N/) is extracted as step 4.-8. in
// 5.2. Sending the server's opening handshake of
// http://www.ietf.org/id/draft-ietf-hybi-thewebsocketprotocol-00.txt
void GetKeyNumber(const std::string& key, std::string* challenge) {
uint32 key_number = 0;
uint32 spaces = 0;
for (size_t i = 0; i < key.size(); ++i) {
if (isdigit(key[i])) {
// key_number should not overflow. (it comes from
// WebCore/websockets/WebSocketHandshake.cpp).
key_number = key_number * 10 + key[i] - '0';
} else if (key[i] == ' ') {
++spaces;
}
}
// spaces should not be zero in valid handshake request.
if (spaces == 0)
return;
key_number /= spaces;
char part[4];
for (int i = 0; i < 4; i++) {
part[3 - i] = key_number & 0xFF;
key_number >>= 8;
}
challenge->append(part, 4);
}
int GetVersionFromRequest(const std::string& request_headers) {
std::vector<std::string> values;
const char* const headers_to_get[2] = { "sec-websocket-version",
"sec-websocket-draft" };
FetchHeaders(request_headers, headers_to_get, 2, &values);
DCHECK_LE(values.size(), 1U);
if (values.empty())
return 0;
int version;
bool conversion_success = base::StringToInt(values[0], &version);
DCHECK(conversion_success);
DCHECK_GE(version, 1);
return version;
}
} // anonymous namespace
namespace net {
WebSocketHandshakeRequestHandler::WebSocketHandshakeRequestHandler()
: original_length_(0),
raw_length_(0),
protocol_version_(-1) {}
bool WebSocketHandshakeRequestHandler::ParseRequest(
const char* data, int length) {
DCHECK_GT(length, 0);
std::string input(data, length);
int input_header_length =
HttpUtil::LocateEndOfHeaders(input.data(), input.size(), 0);
if (input_header_length <= 0)
return false;
ParseHandshakeHeader(input.data(),
input_header_length,
&status_line_,
&headers_);
// WebSocket protocol drafts hixie-76 (hybi-00), hybi-01, 02 and 03 require
// the clients to send key3 after the handshake request header fields.
// Hybi-04 and later drafts, on the other hand, no longer have key3
// in the handshake format.
protocol_version_ = GetVersionFromRequest(headers_);
DCHECK_GE(protocol_version_, 0);
if (protocol_version_ >= kMinVersionOfHybiNewHandshake) {
key3_ = "";
original_length_ = input_header_length;
return true;
}
if (input_header_length + kRequestKey3Size > input.size())
return false;
// Assumes WebKit doesn't send any data after handshake request message
// until handshake is finished.
// Thus, |key3_| is part of handshake message, and not in part
// of WebSocket frame stream.
DCHECK_EQ(kRequestKey3Size, input.size() - input_header_length);
key3_ = std::string(input.data() + input_header_length,
input.size() - input_header_length);
original_length_ = input.size();
return true;
}
size_t WebSocketHandshakeRequestHandler::original_length() const {
return original_length_;
}
void WebSocketHandshakeRequestHandler::AppendHeaderIfMissing(
const std::string& name, const std::string& value) {
DCHECK(!headers_.empty());
HttpUtil::AppendHeaderIfMissing(name.c_str(), value, &headers_);
}
void WebSocketHandshakeRequestHandler::RemoveHeaders(
const char* const headers_to_remove[],
size_t headers_to_remove_len) {
DCHECK(!headers_.empty());
headers_ = FilterHeaders(
headers_, headers_to_remove, headers_to_remove_len);
}
HttpRequestInfo WebSocketHandshakeRequestHandler::GetRequestInfo(
const GURL& url, std::string* challenge) {
HttpRequestInfo request_info;
request_info.url = url;
size_t method_end = base::StringPiece(status_line_).find_first_of(" ");
if (method_end != base::StringPiece::npos)
request_info.method = std::string(status_line_.data(), method_end);
request_info.extra_headers.Clear();
request_info.extra_headers.AddHeadersFromString(headers_);
request_info.extra_headers.RemoveHeader("Upgrade");
request_info.extra_headers.RemoveHeader("Connection");
if (protocol_version_ >= kMinVersionOfHybiNewHandshake) {
std::string key;
bool header_present =
request_info.extra_headers.GetHeader("Sec-WebSocket-Key", &key);
DCHECK(header_present);
request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key");
*challenge = key;
} else {
challenge->clear();
std::string key;
bool header_present =
request_info.extra_headers.GetHeader("Sec-WebSocket-Key1", &key);
DCHECK(header_present);
request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key1");
GetKeyNumber(key, challenge);
header_present =
request_info.extra_headers.GetHeader("Sec-WebSocket-Key2", &key);
DCHECK(header_present);
request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key2");
GetKeyNumber(key, challenge);
challenge->append(key3_);
}
return request_info;
}
bool WebSocketHandshakeRequestHandler::GetRequestHeaderBlock(
const GURL& url,
SpdyHeaderBlock* headers,
std::string* challenge,
int spdy_protocol_version) {
// Construct opening handshake request headers as a SPDY header block.
// For details, see WebSocket Layering over SPDY/3 Draft 8.
if (spdy_protocol_version <= 2) {
(*headers)["path"] = url.path();
(*headers)["version"] =
base::StringPrintf("%s%d", "WebSocket/", protocol_version_);
(*headers)["scheme"] = url.scheme();
} else {
(*headers)[":path"] = url.path();
(*headers)[":version"] =
base::StringPrintf("%s%d", "WebSocket/", protocol_version_);
(*headers)[":scheme"] = url.scheme();
}
HttpUtil::HeadersIterator iter(headers_.begin(), headers_.end(), "\r\n");
while (iter.GetNext()) {
if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(), "upgrade") ||
LowerCaseEqualsASCII(iter.name_begin(),
iter.name_end(),
"connection") ||
LowerCaseEqualsASCII(iter.name_begin(),
iter.name_end(),
"sec-websocket-version")) {
// These headers must be ignored.
continue;
} else if (LowerCaseEqualsASCII(iter.name_begin(),
iter.name_end(),
"sec-websocket-key")) {
*challenge = iter.values();
// Sec-WebSocket-Key is not sent to a server.
continue;
} else if (LowerCaseEqualsASCII(iter.name_begin(),
iter.name_end(),
"host") ||
LowerCaseEqualsASCII(iter.name_begin(),
iter.name_end(),
"origin") ||
LowerCaseEqualsASCII(iter.name_begin(),
iter.name_end(),
"sec-websocket-protocol") ||
LowerCaseEqualsASCII(iter.name_begin(),
iter.name_end(),
"sec-websocket-extensions")) {
// TODO(toyoshim): Some WebSocket extensions may not be compatible with
// SPDY. We should omit them from a Sec-WebSocket-Extension header.
std::string name;
if (spdy_protocol_version <= 2)
name = StringToLowerASCII(iter.name());
else
name = ":" + StringToLowerASCII(iter.name());
(*headers)[name] = iter.values();
continue;
}
// Others should be sent out to |headers|.
std::string name = StringToLowerASCII(iter.name());
SpdyHeaderBlock::iterator found = headers->find(name);
if (found == headers->end()) {
(*headers)[name] = iter.values();
} else {
// For now, websocket doesn't use multiple headers, but follows to http.
found->second.append(1, '\0'); // +=() doesn't append 0's
found->second.append(iter.values());
}
}
return true;
}
std::string WebSocketHandshakeRequestHandler::GetRawRequest() {
DCHECK(!status_line_.empty());
DCHECK(!headers_.empty());
// The following works on both hybi-04 and older handshake,
// because |key3_| is guaranteed to be empty if the handshake was hybi-04's.
std::string raw_request = status_line_ + headers_ + "\r\n" + key3_;
raw_length_ = raw_request.size();
return raw_request;
}
size_t WebSocketHandshakeRequestHandler::raw_length() const {
DCHECK_GT(raw_length_, 0);
return raw_length_;
}
int WebSocketHandshakeRequestHandler::protocol_version() const {
DCHECK_GE(protocol_version_, 0);
return protocol_version_;
}
WebSocketHandshakeResponseHandler::WebSocketHandshakeResponseHandler()
: original_header_length_(0),
protocol_version_(0) {}
WebSocketHandshakeResponseHandler::~WebSocketHandshakeResponseHandler() {}
int WebSocketHandshakeResponseHandler::protocol_version() const {
DCHECK_GE(protocol_version_, 0);
return protocol_version_;
}
void WebSocketHandshakeResponseHandler::set_protocol_version(
int protocol_version) {
DCHECK_GE(protocol_version, 0);
protocol_version_ = protocol_version;
}
size_t WebSocketHandshakeResponseHandler::ParseRawResponse(
const char* data, int length) {
DCHECK_GT(length, 0);
if (HasResponse()) {
DCHECK(!status_line_.empty());
// headers_ might be empty for wrong response from server.
return 0;
}
size_t old_original_length = original_.size();
original_.append(data, length);
// TODO(ukai): fail fast when response gives wrong status code.
original_header_length_ = HttpUtil::LocateEndOfHeaders(
original_.data(), original_.size(), 0);
if (!HasResponse())
return length;
ParseHandshakeHeader(original_.data(),
original_header_length_,
&status_line_,
&headers_);
int header_size = status_line_.size() + headers_.size();
DCHECK_GE(original_header_length_, header_size);
header_separator_ = std::string(original_.data() + header_size,
original_header_length_ - header_size);
key_ = std::string(original_.data() + original_header_length_,
GetResponseKeySize());
return original_header_length_ + GetResponseKeySize() - old_original_length;
}
bool WebSocketHandshakeResponseHandler::HasResponse() const {
return original_header_length_ > 0 &&
original_header_length_ + GetResponseKeySize() <= original_.size();
}
bool WebSocketHandshakeResponseHandler::ParseResponseInfo(
const HttpResponseInfo& response_info,
const std::string& challenge) {
if (!response_info.headers.get())
return false;
std::string response_message;
response_message = response_info.headers->GetStatusLine();
response_message += "\r\n";
if (protocol_version_ >= kMinVersionOfHybiNewHandshake)
response_message += "Upgrade: websocket\r\n";
else
response_message += "Upgrade: WebSocket\r\n";
response_message += "Connection: Upgrade\r\n";
if (protocol_version_ >= kMinVersionOfHybiNewHandshake) {
std::string hash = base::SHA1HashString(challenge + kWebSocketGuid);
std::string websocket_accept;
bool encode_success = base::Base64Encode(hash, &websocket_accept);
DCHECK(encode_success);
response_message += "Sec-WebSocket-Accept: " + websocket_accept + "\r\n";
}
void* iter = NULL;
std::string name;
std::string value;
while (response_info.headers->EnumerateHeaderLines(&iter, &name, &value)) {
response_message += name + ": " + value + "\r\n";
}
response_message += "\r\n";
if (protocol_version_ < kMinVersionOfHybiNewHandshake) {
base::MD5Digest digest;
base::MD5Sum(challenge.data(), challenge.size(), &digest);
const char* digest_data = reinterpret_cast<char*>(digest.a);
response_message.append(digest_data, sizeof(digest.a));
}
return ParseRawResponse(response_message.data(),
response_message.size()) == response_message.size();
}
bool WebSocketHandshakeResponseHandler::ParseResponseHeaderBlock(
const SpdyHeaderBlock& headers,
const std::string& challenge,
int spdy_protocol_version) {
SpdyHeaderBlock::const_iterator status;
if (spdy_protocol_version <= 2)
status = headers.find("status");
else
status = headers.find(":status");
if (status == headers.end())
return false;
std::string response_message;
response_message =
base::StringPrintf("%s%s\r\n", "HTTP/1.1 ", status->second.c_str());
response_message += "Upgrade: websocket\r\n";
response_message += "Connection: Upgrade\r\n";
std::string hash = base::SHA1HashString(challenge + kWebSocketGuid);
std::string websocket_accept;
bool encode_success = base::Base64Encode(hash, &websocket_accept);
DCHECK(encode_success);
response_message += "Sec-WebSocket-Accept: " + websocket_accept + "\r\n";
for (SpdyHeaderBlock::const_iterator iter = headers.begin();
iter != headers.end();
++iter) {
// For each value, if the server sends a NUL-separated list of values,
// we separate that back out into individual headers for each value
// in the list.
if ((spdy_protocol_version <= 2 &&
LowerCaseEqualsASCII(iter->first, "status")) ||
(spdy_protocol_version >= 3 &&
LowerCaseEqualsASCII(iter->first, ":status"))) {
// The status value is already handled as the first line of
// |response_message|. Just skip here.
continue;
}
const std::string& value = iter->second;
size_t start = 0;
size_t end = 0;
do {
end = value.find('\0', start);
std::string tval;
if (end != std::string::npos)
tval = value.substr(start, (end - start));
else
tval = value.substr(start);
if (spdy_protocol_version >= 3 &&
(LowerCaseEqualsASCII(iter->first, ":sec-websocket-protocol") ||
LowerCaseEqualsASCII(iter->first, ":sec-websocket-extensions")))
response_message += iter->first.substr(1) + ": " + tval + "\r\n";
else
response_message += iter->first + ": " + tval + "\r\n";
start = end + 1;
} while (end != std::string::npos);
}
response_message += "\r\n";
return ParseRawResponse(response_message.data(),
response_message.size()) == response_message.size();
}
void WebSocketHandshakeResponseHandler::GetHeaders(
const char* const headers_to_get[],
size_t headers_to_get_len,
std::vector<std::string>* values) {
DCHECK(HasResponse());
DCHECK(!status_line_.empty());
// headers_ might be empty for wrong response from server.
if (headers_.empty())
return;
FetchHeaders(headers_, headers_to_get, headers_to_get_len, values);
}
void WebSocketHandshakeResponseHandler::RemoveHeaders(
const char* const headers_to_remove[],
size_t headers_to_remove_len) {
DCHECK(HasResponse());
DCHECK(!status_line_.empty());
// headers_ might be empty for wrong response from server.
if (headers_.empty())
return;
headers_ = FilterHeaders(headers_, headers_to_remove, headers_to_remove_len);
}
std::string WebSocketHandshakeResponseHandler::GetRawResponse() const {
DCHECK(HasResponse());
return std::string(original_.data(),
original_header_length_ + GetResponseKeySize());
}
std::string WebSocketHandshakeResponseHandler::GetResponse() {
DCHECK(HasResponse());
DCHECK(!status_line_.empty());
// headers_ might be empty for wrong response from server.
DCHECK_EQ(GetResponseKeySize(), key_.size());
return status_line_ + headers_ + header_separator_ + key_;
}
size_t WebSocketHandshakeResponseHandler::GetResponseKeySize() const {
if (protocol_version_ >= kMinVersionOfHybiNewHandshake)
return 0;
return kResponseKeySize;
}
} // namespace net