// Copyright 2013 The Chromium Authors. All rights reserved.
/* Modifications: Copyright 2017 Google Inc. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cobalt/websocket/web_socket_handshake_helper.h"

#include <set>
#include <string>

#include "base/logging.h"
#include "base/string_number_conversions.h"
#include "base/string_piece.h"
#include "base/string_util.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_status_code.h"
#include "net/websockets/websocket_extension.h"
#include "net/websockets/websocket_extension_parser.h"
#include "net/websockets/websocket_frame.h"
#include "net/websockets/websocket_handshake_challenge.h"
#include "net/websockets/websocket_handshake_constants.h"
#include "starboard/system.h"

namespace {
// The following enum and file-local helper functions are adapted from the
// Chromium net source, commit id: 7321c9e7ee80ef15b65c2f39646a5a2d22a9c950, in
// src/net/websockets/websocket_basic_handshake_stream.cc.

// Outcome of looking up a response header that is expected to appear exactly
// once (see GetSingleHeaderValue()).
enum GetHeaderResult {
  GET_HEADER_OK = 0,        // Exactly one value was found.
  GET_HEADER_MISSING = 1,   // No value was found.
  GET_HEADER_MULTIPLE = 2,  // More than one value was found.
};

// Looks up |name| in |headers|, expecting it to carry exactly one value.
//
// Returns GET_HEADER_MISSING if no value is enumerated (|*value| is left
// untouched), GET_HEADER_MULTIPLE if a second value is enumerated, and
// GET_HEADER_OK otherwise. Whenever at least one value exists, |*value|
// receives the first one.
GetHeaderResult GetSingleHeaderValue(const net::HttpResponseHeaders* headers,
                                     const base::StringPiece& name,
                                     std::string* value) {
  const std::string header_name(name.as_string());
  void* cursor = NULL;
  std::string current;

  if (!headers->EnumerateHeader(&cursor, header_name, &current)) {
    return GET_HEADER_MISSING;
  }
  *value = current;

  // A second successful enumeration means the header carried more than one
  // value.
  if (headers->EnumerateHeader(&cursor, header_name, &current)) {
    return GET_HEADER_MULTIPLE;
  }
  return GET_HEADER_OK;
}

// Builds the failure text reported when the response lacks |header_name|.
// The header name is wrapped in single quotes.
std::string MissingHeaderMessage(const std::string& header_name) {
  std::string message("'");
  message += header_name;
  message += "' header is missing";
  return message;
}

// Builds the failure text reported when |header_name| appears more than once
// in the response. The header name is wrapped in single quotes.
std::string MultipleHeaderValuesMessage(const std::string& header_name) {
  std::string message("'");
  message += header_name;
  message += "' header must not appear more than once in a response";
  return message;
}

// Translates a GetSingleHeaderValue() outcome into a pass/fail answer.
//
// Returns false and fills |*failure_message| when |result| reports a missing
// or repeated header; returns true otherwise, leaving |*failure_message|
// untouched.
bool ValidateHeaderHasSingleValue(GetHeaderResult result,
                                  const std::string& header_name,
                                  std::string* failure_message) {
  switch (result) {
    case GET_HEADER_MISSING:
      *failure_message = MissingHeaderMessage(header_name);
      return false;
    case GET_HEADER_MULTIPLE:
      *failure_message = MultipleHeaderValuesMessage(header_name);
      return false;
    case GET_HEADER_OK:
      break;
  }
  // Any value reaching this point is expected to be GET_HEADER_OK.
  DCHECK_EQ(result, GET_HEADER_OK);
  return true;
}

// Checks that the response carries exactly one 'Upgrade' header whose value
// matches net::websockets::kWebSocketLowercase under LowerCaseEqualsASCII().
//
// Returns true on success; otherwise returns false with the reason stored in
// |*failure_message|.
bool ValidateUpgrade(const net::HttpResponseHeaders* headers,
                     std::string* failure_message) {
  std::string upgrade_value;
  const GetHeaderResult lookup = GetSingleHeaderValue(
      headers, net::websockets::kUpgrade, &upgrade_value);
  if (!ValidateHeaderHasSingleValue(lookup, net::websockets::kUpgrade,
                                    failure_message)) {
    return false;
  }

  if (LowerCaseEqualsASCII(upgrade_value,
                           net::websockets::kWebSocketLowercase)) {
    return true;
  }

  *failure_message =
      "'Upgrade' header value is not 'WebSocket': " + upgrade_value;
  return false;
}

// Checks that the response carries exactly one 'Sec-WebSocket-Accept' header
// and that its value is exactly equal to |expected|.
//
// Returns true on success; otherwise returns false with the reason stored in
// |*failure_message|.
bool ValidateSecWebSocketAccept(const net::HttpResponseHeaders* headers,
                                const std::string& expected,
                                std::string* failure_message) {
  std::string received;
  const GetHeaderResult lookup = GetSingleHeaderValue(
      headers, net::websockets::kSecWebSocketAccept, &received);
  if (!ValidateHeaderHasSingleValue(
          lookup, net::websockets::kSecWebSocketAccept, failure_message)) {
    return false;
  }

  if (received == expected) {
    return true;
  }

  *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value";
  return false;
}

// Checks that the response has a 'Connection' header that includes the
// 'Upgrade' token. Other tokens may be present alongside it, so the header is
// not required to have a single value.
//
// Returns true on success; otherwise returns false with the reason stored in
// |*failure_message|.
bool ValidateConnection(const net::HttpResponseHeaders* headers,
                        std::string* failure_message) {
  const char* const connection = net::HttpRequestHeaders::kConnection;

  if (!headers->HasHeader(connection)) {
    *failure_message = MissingHeaderMessage(connection);
    return false;
  }

  const bool has_upgrade_token =
      headers->HasHeaderValue(connection, net::websockets::kUpgrade);
  if (!has_upgrade_token) {
    *failure_message = "'Connection' header value must contain 'Upgrade'";
  }
  return has_upgrade_token;
}

// Validates the 'Sec-WebSocket-Protocol' values in the response against the
// sub-protocols that were requested.
//
// Fails (returns false and fills |*failure_message|) when, checked in this
// order:
//   1. more than one value is present;
//   2. a value is present although nothing was requested;
//   3. a value is present that was not among the requested ones;
//   4. sub-protocols were requested but no value is present.
// On success |*sub_protocol| receives the single value from the response, or
// an empty string when none was requested and none was returned.
bool ValidateSubProtocol(
    const net::HttpResponseHeaders* headers,
    const std::vector<std::string>& requested_sub_protocols,
    std::string* sub_protocol, std::string* failure_message) {
  const std::set<std::string> requested(requested_sub_protocols.begin(),
                                        requested_sub_protocols.end());
  void* cursor = NULL;
  std::string last_value;
  int values_seen = 0;
  bool saw_unrequested = false;
  bool saw_multiple = false;

  for (;;) {
    // Enumeration only stops early once both problems have been observed.
    if (saw_unrequested && saw_multiple) {
      break;
    }
    std::string current;
    if (!headers->EnumerateHeader(
            &cursor, net::websockets::kSecWebSocketProtocol, &current)) {
      break;
    }
    last_value = current;
    ++values_seen;
    if (requested.find(last_value) == requested.end()) {
      saw_unrequested = true;
    }
    if (values_seen > 1) {
      saw_multiple = true;
    }
  }

  if (saw_multiple) {
    *failure_message =
        MultipleHeaderValuesMessage(net::websockets::kSecWebSocketProtocol);
    return false;
  }

  const bool nothing_requested = requested_sub_protocols.empty();

  if (values_seen > 0 && nothing_requested) {
    std::string message(
        "Response must not include 'Sec-WebSocket-Protocol' "
        "header if not present in request: ");
    message += last_value;
    *failure_message = message;
    return false;
  }

  if (saw_unrequested) {
    *failure_message = "'Sec-WebSocket-Protocol' header value '" + last_value +
                       "' in response does not match any of sent values";
    return false;
  }

  if (values_seen == 0 && !nothing_requested) {
    *failure_message =
        "Sent non-empty 'Sec-WebSocket-Protocol' header "
        "but no response was received";
    return false;
  }

  *sub_protocol = last_value;
  return true;
}

// Ensures the response does not select any WebSocket extension.
//
// Every 'Sec-WebSocket-Extensions' value in |headers| is parsed. Returns false
// (with |*failure_message| filled) if a value is rejected by the parser or if
// the parser yields at least one extension; returns true otherwise, including
// when the header is absent.
bool ValidateExtensions(const net::HttpResponseHeaders* headers,
                        std::string* failure_message) {
  void* cursor = NULL;
  std::string extensions_header;
  while (headers->EnumerateHeader(
      &cursor, net::websockets::kSecWebSocketExtensions, &extensions_header)) {
    net::WebSocketExtensionParser parser;
    if (!parser.Parse(extensions_header)) {
      std::string message(
          "'Sec-WebSocket-Extensions' header value is "
          "rejected by the parser: ");
      message += extensions_header;
      *failure_message = message;
      return false;
    }

    if (!parser.extensions().empty()) {
      *failure_message = "Cobalt does not support any websocket extensions";
      return false;
    }
  }
  return true;
}

// Default key generator: fills a SecWebSocketKeyBytes buffer using
// SbSystemGetRandomData() and wraps it in a SecWebSocketKey.
cobalt::websocket::SecWebSocketKey GenerateRandomSecWebSocketKey() {
  typedef cobalt::websocket::SecWebSocketKey::SecWebSocketKeyBytes KeyBytes;
  KeyBytes entropy;
  SbSystemGetRandomData(&entropy, sizeof(entropy));
  return cobalt::websocket::SecWebSocketKey(entropy);
}

}  // namespace

namespace cobalt {
namespace websocket {

// Builds the HTTP upgrade request for |connect_url| and stores it in
// |*handshake_request| (any previous content is discarded).
//
// Side effects on this object: a new Sec-WebSocket-Key is generated,
// |requested_sub_protocols_| is set to |desired_sub_protocols|, and
// |handshake_challenge_response_| is set to the accept value computed from the
// new key (later compared against the response in IsResponseValid()).
//
// |origin| is emitted verbatim in the Origin header. When
// |desired_sub_protocols| is non-empty, its entries are sent comma-separated
// in a Sec-WebSocket-Protocol header.
void WebSocketHandshakeHelper::GenerateHandshakeRequest(
    const GURL& connect_url, const std::string& origin,
    const std::vector<std::string>& desired_sub_protocols,
    std::string* handshake_request) {
  DCHECK(handshake_request);
  GenerateSecWebSocketKey();

  // Host header value: the host, followed by ":<port>" unless the URL's port
  // is unspecified.
  std::string host_and_port(connect_url.host());
  if (connect_url.IntPort() != url_parse::PORT_UNSPECIFIED) {
    host_and_port.append(":");
    host_and_port.append(connect_url.port());
  }

  std::string& request = *handshake_request;
  request.clear();
  request.reserve(256);  // Large enough to avoid reallocation in most cases.

  // Pieces are appended one at a time, rather than concatenated with
  // operator+, so that no intermediate std::string objects are created.

  // Request line.
  request.append("GET ");
  request.append(connect_url.path());
  if (connect_url.has_query()) {
    request.append("?");
    request.append(connect_url.query());
  }
  request.append(" HTTP/1.1\r\n");

  // Fixed and per-request headers.
  request.append("Host:");
  request.append(host_and_port);
  request.append("\r\n");
  request.append(
      "Connection:Upgrade\r\n"
      "Pragma:no-cache\r\n"
      "Cache-Control:no-cache\r\n"
      "Upgrade:websocket\r\n"
      "Sec-WebSocket-Extensions:\r\n"
      "Sec-WebSocket-Version:13\r\n");
  request.append("Origin:");
  request.append(origin);
  request.append("\r\n");
  request.append("Sec-WebSocket-Key:");
  request.append(sec_websocket_key_.GetKeyEncodedInBase64());
  request.append("\r\n");
  request.append("User-Agent:");
  request.append(user_agent_);
  request.append("\r\n");

  if (!desired_sub_protocols.empty()) {
    request.append("Sec-WebSocket-Protocol:");
    request.append(JoinString(desired_sub_protocols, ","));
    request.append("\r\n");
  }

  // Blank line terminating the header section.
  request.append("\r\n");

  // Remember what was asked for so the response can be validated later.
  requested_sub_protocols_ = desired_sub_protocols;
  handshake_challenge_response_ = net::ComputeSecWebSocketAccept(
      sec_websocket_key_.GetKeyEncodedInBase64());
}

// Creates a helper that sends |user_agent| in the User-Agent header and uses
// GenerateRandomSecWebSocketKey() to produce the Sec-WebSocket-Key.
//
// |user_agent| is copied with StringPiece::as_string(), as done elsewhere in
// this file, rather than via std::string(data(), size()): a default-constructed
// piece may have a NULL data() pointer, which should not be handed to the
// std::string constructor.
WebSocketHandshakeHelper::WebSocketHandshakeHelper(
    const base::StringPiece user_agent)
    : sec_websocket_key_generator_function_(&GenerateRandomSecWebSocketKey),
      user_agent_(user_agent.as_string()) {}

// Creates a helper that sends |user_agent| in the User-Agent header and uses
// the caller-supplied |sec_websocket_key_generator_function| to produce the
// Sec-WebSocket-Key.
//
// |user_agent| is copied with StringPiece::as_string(), as done elsewhere in
// this file, rather than via std::string(data(), size()): a default-constructed
// piece may have a NULL data() pointer, which should not be handed to the
// std::string constructor.
WebSocketHandshakeHelper::WebSocketHandshakeHelper(
    const base::StringPiece user_agent,
    SecWebSocketKeyGeneratorFunction sec_websocket_key_generator_function)
    : sec_websocket_key_generator_function_(
          sec_websocket_key_generator_function),
      user_agent_(user_agent.as_string()) {}

// Validates the server's handshake response.
//
// The status code must be net::HTTP_SWITCHING_PROTOCOLS, and the Upgrade,
// Sec-WebSocket-Accept, Connection, Sec-WebSocket-Protocol and
// Sec-WebSocket-Extensions headers must each pass their check, evaluated in
// that order. On the first failure this returns false with the reason in
// |*failure_message|. On success it returns true, clears |*failure_message|,
// and |selected_subprotocol_| holds the sub-protocol taken from the response.
bool WebSocketHandshakeHelper::IsResponseValid(
    const net::HttpResponseHeaders& headers, std::string* failure_message) {
  DCHECK(failure_message);

  // The status code is checked before any header.
  const int status = headers.response_code();
  if (status != net::HTTP_SWITCHING_PROTOCOLS) {
    *failure_message = "Invalid response code " + base::IntToString(status);
    return false;
  }

  // && short-circuits, so validation stops at the first failing check and
  // |*failure_message| holds that check's reason. Cobalt does not support
  // extensions, so the last check only makes sure none are being selected.
  const bool headers_ok =
      ValidateUpgrade(&headers, failure_message) &&
      ValidateSecWebSocketAccept(&headers, handshake_challenge_response_,
                                 failure_message) &&
      ValidateConnection(&headers, failure_message) &&
      ValidateSubProtocol(&headers, requested_sub_protocols_,
                          &selected_subprotocol_, failure_message) &&
      ValidateExtensions(&headers, failure_message);
  if (!headers_ok) {
    return false;
  }

  failure_message->clear();
  return true;
}

// Replaces |sec_websocket_key_| with the value returned by the key generator
// function that was chosen at construction time.
void WebSocketHandshakeHelper::GenerateSecWebSocketKey() {
  sec_websocket_key_ = sec_websocket_key_generator_function_();
}

}  // namespace websocket
}  // namespace cobalt
