| // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "net/dns/dns_test_util.h" |
| |
| #include <string> |
| |
| #include "base/big_endian.h" |
| #include "base/bind.h" |
| #include "base/location.h" |
| #include "base/memory/weak_ptr.h" |
| #include "base/single_thread_task_runner.h" |
| #include "base/sys_byteorder.h" |
| #include "base/threading/thread_task_runner_handle.h" |
| #include "net/base/io_buffer.h" |
| #include "net/base/net_errors.h" |
| #include "net/dns/address_sorter.h" |
| #include "net/dns/dns_query.h" |
| #include "net/dns/dns_response.h" |
| #include "net/dns/dns_transaction.h" |
| #include "net/dns/dns_util.h" |
| #include "starboard/memory.h" |
| #include "testing/gtest/include/gtest/gtest.h" |
| |
| namespace net { |
| namespace { |
| |
| class MockAddressSorter : public AddressSorter { |
| public: |
| ~MockAddressSorter() override = default; |
| void Sort(const AddressList& list, CallbackType callback) const override { |
| // Do nothing. |
| std::move(callback).Run(true, list); |
| } |
| }; |
| |
| // A DnsTransaction which uses MockDnsClientRuleList to determine the response. |
| class MockTransaction : public DnsTransaction, |
| public base::SupportsWeakPtr<MockTransaction> { |
| public: |
| MockTransaction(const MockDnsClientRuleList& rules, |
| const std::string& hostname, |
| uint16_t qtype, |
| DnsTransactionFactory::CallbackType callback) |
| : result_(MockDnsClientRule::FAIL), |
| hostname_(hostname), |
| qtype_(qtype), |
| callback_(std::move(callback)), |
| started_(false), |
| delayed_(false) { |
| // Find the relevant rule which matches |qtype| and prefix of |hostname|. |
| for (size_t i = 0; i < rules.size(); ++i) { |
| const std::string& prefix = rules[i].prefix; |
| if ((rules[i].qtype == qtype) && |
| (hostname.size() >= prefix.size()) && |
| (hostname.compare(0, prefix.size(), prefix) == 0)) { |
| result_ = rules[i].result; |
| delayed_ = rules[i].delay; |
| |
| // Fill in an IP address for the result if one was not specified. |
| if (result_.ip.empty() && result_.type == MockDnsClientRule::OK) { |
| result_.ip = qtype_ == dns_protocol::kTypeA |
| ? IPAddress::IPv4Localhost() |
| : IPAddress::IPv6Localhost(); |
| } |
| break; |
| } |
| } |
| } |
| |
| const std::string& GetHostname() const override { return hostname_; } |
| |
| uint16_t GetType() const override { return qtype_; } |
| |
| void Start() override { |
| EXPECT_FALSE(started_); |
| started_ = true; |
| if (delayed_) |
| return; |
| // Using WeakPtr to cleanly cancel when transaction is destroyed. |
| base::ThreadTaskRunnerHandle::Get()->PostTask( |
| FROM_HERE, base::Bind(&MockTransaction::Finish, AsWeakPtr())); |
| } |
| |
| void FinishDelayedTransaction() { |
| EXPECT_TRUE(delayed_); |
| delayed_ = false; |
| Finish(); |
| } |
| |
| bool delayed() const { return delayed_; } |
| |
| private: |
| void Finish() { |
| switch (result_.type) { |
| case MockDnsClientRule::NODOMAIN: |
| case MockDnsClientRule::EMPTY: |
| case MockDnsClientRule::OK: { |
| std::string qname; |
| DNSDomainFromDot(hostname_, &qname); |
| DnsQuery query(0, qname, qtype_); |
| |
| DnsResponse response; |
| char* buffer = response.io_buffer()->data(); |
| int nbytes = query.io_buffer()->size(); |
| memcpy(buffer, query.io_buffer()->data(), nbytes); |
| dns_protocol::Header* header = |
| reinterpret_cast<dns_protocol::Header*>(buffer); |
| header->flags |= dns_protocol::kFlagResponse; |
| if (MockDnsClientRule::NODOMAIN == result_.type) { |
| header->flags |= base::HostToNet16(dns_protocol::kRcodeNXDOMAIN); |
| } |
| |
| const uint16_t kPointerToQueryName = |
| static_cast<uint16_t>(0xc000 | sizeof(*header)); |
| |
| const uint32_t kTTL = 86400; // One day. |
| |
| // Size of RDATA which is a IPv4 or IPv6 address. |
| if (MockDnsClientRule::OK == result_.type) |
| EXPECT_TRUE(result_.ip.IsValid()); |
| size_t rdata_size = result_.ip.size(); |
| |
| if (MockDnsClientRule::OK == result_.type) { |
| // Write the answer using the expected IP address. |
| // 12 is the sum of sizes of the compressed name reference, TYPE, |
| // CLASS, TTL and RDLENGTH. |
| size_t answer_size = 12 + rdata_size; |
| header->ancount = base::HostToNet16(1); |
| base::BigEndianWriter writer(buffer + nbytes, answer_size); |
| writer.WriteU16(kPointerToQueryName); |
| writer.WriteU16(qtype_); |
| writer.WriteU16(dns_protocol::kClassIN); |
| writer.WriteU32(kTTL); |
| writer.WriteU16(static_cast<uint16_t>(rdata_size)); |
| writer.WriteBytes(result_.ip.bytes().data(), rdata_size); |
| nbytes += answer_size; |
| } else { |
| // authority will be 12 bytes. |
| size_t authority_size = 12; |
| header->ancount = base::HostToNet16(0); |
| header->nscount = base::HostToNet16(1); |
| base::BigEndianWriter writer(buffer + nbytes, authority_size); |
| writer.WriteU16(kPointerToQueryName); |
| writer.WriteU16(dns_protocol::kTypeSOA); |
| writer.WriteU16(dns_protocol::kClassIN); |
| writer.WriteU32(kTTL); |
| writer.WriteU16(static_cast<uint16_t>(rdata_size)); |
| nbytes += authority_size; |
| } |
| EXPECT_TRUE(response.InitParse(nbytes, query)); |
| if (MockDnsClientRule::NODOMAIN == result_.type) { |
| std::move(callback_).Run(this, ERR_NAME_NOT_RESOLVED, &response); |
| } else { |
| std::move(callback_).Run(this, OK, &response); |
| } |
| } break; |
| case MockDnsClientRule::FAIL: |
| std::move(callback_).Run(this, ERR_NAME_NOT_RESOLVED, NULL); |
| break; |
| case MockDnsClientRule::TIMEOUT: |
| std::move(callback_).Run(this, ERR_DNS_TIMED_OUT, NULL); |
| break; |
| default: |
| NOTREACHED(); |
| break; |
| } |
| } |
| |
| void SetRequestContext(URLRequestContext*) override {} |
| void SetRequestPriority(RequestPriority priority) override {} |
| |
| MockDnsClientRule::Result result_; |
| const std::string hostname_; |
| const uint16_t qtype_; |
| DnsTransactionFactory::CallbackType callback_; |
| bool started_; |
| bool delayed_; |
| }; |
| |
| } // namespace |
| |
| // A DnsTransactionFactory which creates MockTransaction. |
| class MockTransactionFactory : public DnsTransactionFactory { |
| public: |
| explicit MockTransactionFactory(const MockDnsClientRuleList& rules) |
| : rules_(rules) {} |
| |
| ~MockTransactionFactory() override = default; |
| |
| std::unique_ptr<DnsTransaction> CreateTransaction( |
| const std::string& hostname, |
| uint16_t qtype, |
| DnsTransactionFactory::CallbackType callback, |
| const NetLogWithSource&) override { |
| std::unique_ptr<MockTransaction> transaction = |
| std::make_unique<MockTransaction>(rules_, hostname, qtype, |
| std::move(callback)); |
| if (transaction->delayed()) |
| delayed_transactions_.push_back(transaction->AsWeakPtr()); |
| #if defined(STARBOARD) |
| return std::move(transaction); |
| #else |
| return transaction; |
| #endif |
| } |
| |
| void AddEDNSOption(const OptRecordRdata::Opt& opt) override { |
| NOTREACHED() << "Not implemented"; |
| } |
| |
| void CompleteDelayedTransactions() { |
| DelayedTransactionList old_delayed_transactions; |
| old_delayed_transactions.swap(delayed_transactions_); |
| for (auto it = old_delayed_transactions.begin(); |
| it != old_delayed_transactions.end(); ++it) { |
| if (it->get()) |
| (*it)->FinishDelayedTransaction(); |
| } |
| } |
| |
| private: |
| typedef std::vector<base::WeakPtr<MockTransaction> > DelayedTransactionList; |
| |
| MockDnsClientRuleList rules_; |
| DelayedTransactionList delayed_transactions_; |
| }; |
| |
| MockDnsClient::MockDnsClient(const DnsConfig& config, |
| const MockDnsClientRuleList& rules) |
| : config_(config), |
| factory_(new MockTransactionFactory(rules)), |
| address_sorter_(new MockAddressSorter()) { |
| } |
| |
| MockDnsClient::~MockDnsClient() = default; |
| |
| void MockDnsClient::SetConfig(const DnsConfig& config) { |
| config_ = config; |
| } |
| |
| const DnsConfig* MockDnsClient::GetConfig() const { |
| return config_.IsValid() ? &config_ : NULL; |
| } |
| |
| DnsTransactionFactory* MockDnsClient::GetTransactionFactory() { |
| return config_.IsValid() ? factory_.get() : NULL; |
| } |
| |
| AddressSorter* MockDnsClient::GetAddressSorter() { |
| return address_sorter_.get(); |
| } |
| |
| void MockDnsClient::CompleteDelayedTransactions() { |
| factory_->CompleteDelayedTransactions(); |
| } |
| |
| } // namespace net |