blob: 972492405ac06ff63fe8a30f86f4653f4502d294 [file] [log] [blame]
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/test/test_doh_server.h"
#include <string.h>
#include <memory>
#include "base/base64url.h"
#include "base/check.h"
#include "base/functional/bind.h"
#include "base/logging.h"
#include "base/memory/scoped_refptr.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_piece.h"
#include "base/synchronization/lock.h"
#include "net/base/io_buffer.h"
#include "net/base/url_util.h"
#include "net/dns/dns_names_util.h"
#include "net/dns/dns_query.h"
#include "net/dns/dns_response.h"
#include "net/dns/dns_test_util.h"
#include "net/dns/public/dns_protocol.h"
#include "net/http/http_status_code.h"
#include "net/test/embedded_test_server/embedded_test_server.h"
#include "net/test/embedded_test_server/http_request.h"
#include "net/test/embedded_test_server/http_response.h"
#include "url/gurl.h"
namespace net {
namespace {
const char kPath[] = "/dns-query";
std::unique_ptr<test_server::HttpResponse> MakeHttpErrorResponse(
HttpStatusCode status,
base::StringPiece error) {
auto response = std::make_unique<test_server::BasicHttpResponse>();
response->set_code(status);
response->set_content(std::string(error));
response->set_content_type("text/plain;charset=utf-8");
return response;
}
std::unique_ptr<test_server::HttpResponse> MakeHttpResponseFromDns(
const DnsResponse& dns_response) {
if (!dns_response.IsValid()) {
return MakeHttpErrorResponse(HTTP_INTERNAL_SERVER_ERROR,
"error making DNS response");
}
auto response = std::make_unique<test_server::BasicHttpResponse>();
response->set_code(HTTP_OK);
response->set_content(std::string(dns_response.io_buffer()->data(),
dns_response.io_buffer_size()));
response->set_content_type("application/dns-message");
return response;
}
} // namespace
TestDohServer::TestDohServer() {
server_.RegisterRequestHandler(base::BindRepeating(
&TestDohServer::HandleRequest, base::Unretained(this)));
}
TestDohServer::~TestDohServer() = default;
void TestDohServer::SetHostname(base::StringPiece name) {
DCHECK(!server_.Started());
hostname_ = std::string(name);
}
void TestDohServer::SetFailRequests(bool fail_requests) {
base::AutoLock lock(lock_);
fail_requests_ = fail_requests;
}
void TestDohServer::AddAddressRecord(base::StringPiece name,
const IPAddress& address,
base::TimeDelta ttl) {
AddRecord(BuildTestAddressRecord(std::string(name), address, ttl));
}
void TestDohServer::AddRecord(const DnsResourceRecord& record) {
base::AutoLock lock(lock_);
records_.insert(
std::make_pair(std::make_pair(record.name, record.type), record));
}
bool TestDohServer::Start() {
if (!InitializeAndListen()) {
return false;
}
StartAcceptingConnections();
return true;
}
bool TestDohServer::InitializeAndListen() {
if (hostname_) {
EmbeddedTestServer::ServerCertificateConfig cert_config;
cert_config.dns_names = {*hostname_};
server_.SetSSLConfig(cert_config);
} else {
// `CERT_OK` is valid for 127.0.0.1.
server_.SetSSLConfig(EmbeddedTestServer::CERT_OK);
}
return server_.InitializeAndListen();
}
void TestDohServer::StartAcceptingConnections() {
server_.StartAcceptingConnections();
}
bool TestDohServer::ShutdownAndWaitUntilComplete() {
return server_.ShutdownAndWaitUntilComplete();
}
std::string TestDohServer::GetTemplate() {
GURL url =
hostname_ ? server_.GetURL(*hostname_, kPath) : server_.GetURL(kPath);
return url.spec() + "{?dns}";
}
std::string TestDohServer::GetPostOnlyTemplate() {
GURL url =
hostname_ ? server_.GetURL(*hostname_, kPath) : server_.GetURL(kPath);
return url.spec();
}
int TestDohServer::QueriesServed() {
base::AutoLock lock(lock_);
return queries_served_;
}
std::unique_ptr<test_server::HttpResponse> TestDohServer::HandleRequest(
const test_server::HttpRequest& request) {
GURL request_url = request.GetURL();
if (request_url.path_piece() != kPath) {
return nullptr;
}
base::AutoLock lock(lock_);
queries_served_++;
if (fail_requests_) {
return MakeHttpErrorResponse(HTTP_NOT_FOUND, "failed request");
}
// See RFC 8484, Section 4.1.
std::string query;
if (request.method == test_server::METHOD_GET) {
std::string query_b64;
if (!GetValueForKeyInQuery(request_url, "dns", &query_b64) ||
!base::Base64UrlDecode(
query_b64, base::Base64UrlDecodePolicy::IGNORE_PADDING, &query)) {
return MakeHttpErrorResponse(HTTP_BAD_REQUEST,
"could not decode query string");
}
} else if (request.method == test_server::METHOD_POST) {
auto content_type = request.headers.find("content-type");
if (content_type == request.headers.end() ||
content_type->second != "application/dns-message") {
return MakeHttpErrorResponse(HTTP_BAD_REQUEST,
"unsupported content type");
}
query = request.content;
} else {
return MakeHttpErrorResponse(HTTP_BAD_REQUEST, "invalid method");
}
// Parse the DNS query.
auto query_buf = base::MakeRefCounted<IOBufferWithSize>(query.size());
memcpy(query_buf->data(), query.data(), query.size());
DnsQuery dns_query(std::move(query_buf));
if (!dns_query.Parse(query.size())) {
return MakeHttpErrorResponse(HTTP_BAD_REQUEST, "invalid DNS query");
}
absl::optional<std::string> name = dns_names_util::NetworkToDottedName(
dns_query.qname(), /*require_complete=*/true);
if (!name) {
DnsResponse response(dns_query.id(), /*is_authoritative=*/false,
/*answers=*/{}, /*authority_records=*/{},
/*additional_records=*/{}, dns_query,
dns_protocol::kRcodeFORMERR);
return MakeHttpResponseFromDns(response);
}
auto range = records_.equal_range(std::make_pair(*name, dns_query.qtype()));
std::vector<DnsResourceRecord> answers;
for (auto i = range.first; i != range.second; ++i) {
answers.push_back(i->second);
}
VLOG(1) << "Serving " << answers.size() << " records for " << *name
<< ", qtype " << dns_query.qtype();
// Note `answers` may be empty. NOERROR with no answers is how to express
// NODATA, so there is no need handle it specially.
//
// For now, this server does not support configuring additional records. When
// testing more complex HTTPS record cases, this will need to be extended.
//
// TODO(crbug.com/1251204): Add SOA records to test the default TTL.
DnsResponse response(dns_query.id(), /*is_authoritative=*/true,
/*answers=*/answers, /*authority_records=*/{},
/*additional_records=*/{}, dns_query);
return MakeHttpResponseFromDns(response);
}
} // namespace net