|  | // Copyright 2014 The Chromium Authors. All rights reserved. | 
|  | // Use of this source code is governed by a BSD-style license that can be | 
|  | // found in the LICENSE file. | 
|  |  | 
|  | #include "net/ssl/channel_id_service.h" | 
|  |  | 
|  | #include <algorithm> | 
|  | #include <limits> | 
|  | #include <memory> | 
|  | #include <utility> | 
|  |  | 
|  | #include "base/atomic_sequence_num.h" | 
|  | #include "base/bind.h" | 
|  | #include "base/bind_helpers.h" | 
|  | #include "base/callback_helpers.h" | 
|  | #include "base/compiler_specific.h" | 
|  | #include "base/location.h" | 
|  | #include "base/logging.h" | 
|  | #include "base/macros.h" | 
|  | #include "base/memory/ptr_util.h" | 
|  | #include "base/metrics/histogram_macros.h" | 
|  | #include "base/rand_util.h" | 
|  | #include "base/single_thread_task_runner.h" | 
|  | #include "base/task/post_task.h" | 
|  | #include "base/task_runner.h" | 
|  | #include "base/threading/thread_task_runner_handle.h" | 
|  | #include "crypto/ec_private_key.h" | 
|  | #include "net/base/net_errors.h" | 
|  | #include "net/base/registry_controlled_domains/registry_controlled_domain.h" | 
|  | #include "net/cert/x509_certificate.h" | 
|  | #include "net/cert/x509_util.h" | 
|  | #include "url/gurl.h" | 
|  |  | 
|  | namespace net { | 
|  |  | 
|  | namespace { | 
|  |  | 
|  | base::AtomicSequenceNumber g_next_id; | 
|  |  | 
|  | // On success, returns a ChannelID object and sets |*error| to OK. | 
|  | // Otherwise, returns NULL, and |*error| will be set to a net error code. | 
|  | // |serial_number| is passed in because base::RandInt cannot be called from an | 
|  | // unjoined thread, due to relying on a non-leaked LazyInstance | 
|  | std::unique_ptr<ChannelIDStore::ChannelID> GenerateChannelID( | 
|  | const std::string& server_identifier, | 
|  | int* error) { | 
|  | std::unique_ptr<ChannelIDStore::ChannelID> result; | 
|  |  | 
|  | base::Time creation_time = base::Time::Now(); | 
|  | std::unique_ptr<crypto::ECPrivateKey> key(crypto::ECPrivateKey::Create()); | 
|  |  | 
|  | if (!key) { | 
|  | DLOG(ERROR) << "Unable to create channel ID key pair"; | 
|  | *error = ERR_KEY_GENERATION_FAILED; | 
|  | return result; | 
|  | } | 
|  |  | 
|  | result.reset(new ChannelIDStore::ChannelID(server_identifier, creation_time, | 
|  | std::move(key))); | 
|  | *error = OK; | 
|  | return result; | 
|  | } | 
|  |  | 
|  | }  // namespace | 
|  |  | 
|  | // ChannelIDServiceWorker takes care of the blocking process of performing key | 
|  | // generation. Will take care of deleting itself once Start() is called. | 
|  | class ChannelIDServiceWorker { | 
|  | public: | 
|  | typedef base::OnceCallback< | 
|  | void(const std::string&, int, std::unique_ptr<ChannelIDStore::ChannelID>)> | 
|  | WorkerDoneCallback; | 
|  |  | 
|  | ChannelIDServiceWorker(const std::string& server_identifier, | 
|  | WorkerDoneCallback callback) | 
|  | : server_identifier_(server_identifier), | 
|  | origin_task_runner_(base::ThreadTaskRunnerHandle::Get()), | 
|  | callback_(std::move(callback)) {} | 
|  |  | 
|  | // Starts the worker asynchronously. | 
|  | void Start(const scoped_refptr<base::TaskRunner>& task_runner) { | 
|  | DCHECK(origin_task_runner_->RunsTasksInCurrentSequence()); | 
|  |  | 
|  | auto callback = | 
|  | base::BindOnce(&ChannelIDServiceWorker::Run, base::Owned(this)); | 
|  |  | 
|  | if (task_runner) { | 
|  | task_runner->PostTask(FROM_HERE, std::move(callback)); | 
|  | } else { | 
|  | base::PostTaskWithTraits( | 
|  | FROM_HERE, | 
|  | {base::MayBlock(), base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN}, | 
|  | std::move(callback)); | 
|  | } | 
|  | } | 
|  |  | 
|  | private: | 
|  | void Run() { | 
|  | // Runs on a worker thread. | 
|  | int error = ERR_FAILED; | 
|  | std::unique_ptr<ChannelIDStore::ChannelID> channel_id = | 
|  | GenerateChannelID(server_identifier_, &error); | 
|  | origin_task_runner_->PostTask( | 
|  | FROM_HERE, base::BindOnce(std::move(callback_), server_identifier_, | 
|  | error, base::Passed(&channel_id))); | 
|  | } | 
|  |  | 
|  | const std::string server_identifier_; | 
|  | scoped_refptr<base::SequencedTaskRunner> origin_task_runner_; | 
|  | WorkerDoneCallback callback_; | 
|  |  | 
|  | DISALLOW_COPY_AND_ASSIGN(ChannelIDServiceWorker); | 
|  | }; | 
|  |  | 
|  | // A ChannelIDServiceJob is a one-to-one counterpart of an | 
|  | // ChannelIDServiceWorker. It lives only on the ChannelIDService's | 
|  | // origin task runner's thread. | 
|  | class ChannelIDServiceJob { | 
|  | public: | 
|  | ChannelIDServiceJob(bool create_if_missing) | 
|  | : create_if_missing_(create_if_missing) { | 
|  | } | 
|  |  | 
|  | ~ChannelIDServiceJob() { DCHECK(requests_.empty()); } | 
|  |  | 
|  | void AddRequest(ChannelIDService::Request* request, | 
|  | bool create_if_missing = false) { | 
|  | create_if_missing_ |= create_if_missing; | 
|  | requests_.push_back(request); | 
|  | } | 
|  |  | 
|  | void HandleResult(int error, std::unique_ptr<crypto::ECPrivateKey> key) { | 
|  | PostAll(error, std::move(key)); | 
|  | } | 
|  |  | 
|  | bool CreateIfMissing() const { return create_if_missing_; } | 
|  |  | 
|  | void CancelRequest(ChannelIDService::Request* req) { | 
|  | auto it = std::find(requests_.begin(), requests_.end(), req); | 
|  | if (it != requests_.end()) | 
|  | requests_.erase(it); | 
|  | } | 
|  |  | 
|  | private: | 
|  | void PostAll(int error, std::unique_ptr<crypto::ECPrivateKey> key) { | 
|  | std::vector<ChannelIDService::Request*> requests; | 
|  | requests_.swap(requests); | 
|  |  | 
|  | for (auto i = requests.begin(); i != requests.end(); i++) { | 
|  | std::unique_ptr<crypto::ECPrivateKey> key_copy; | 
|  | if (key) | 
|  | key_copy = key->Copy(); | 
|  | (*i)->Post(error, std::move(key_copy)); | 
|  | } | 
|  | } | 
|  |  | 
|  | std::vector<ChannelIDService::Request*> requests_; | 
|  | bool create_if_missing_; | 
|  | }; | 
|  |  | 
|  | ChannelIDService::Request::Request() : service_(NULL) { | 
|  | } | 
|  |  | 
|  | ChannelIDService::Request::~Request() { | 
|  | Cancel(); | 
|  | } | 
|  |  | 
|  | void ChannelIDService::Request::Cancel() { | 
|  | if (service_) { | 
|  | callback_.Reset(); | 
|  | job_->CancelRequest(this); | 
|  |  | 
|  | service_ = NULL; | 
|  | } | 
|  | } | 
|  |  | 
|  | void ChannelIDService::Request::RequestStarted( | 
|  | ChannelIDService* service, | 
|  | CompletionOnceCallback callback, | 
|  | std::unique_ptr<crypto::ECPrivateKey>* key, | 
|  | ChannelIDServiceJob* job) { | 
|  | DCHECK(service_ == NULL); | 
|  | service_ = service; | 
|  | callback_ = std::move(callback); | 
|  | key_ = key; | 
|  | job_ = job; | 
|  | } | 
|  |  | 
|  | void ChannelIDService::Request::Post( | 
|  | int error, | 
|  | std::unique_ptr<crypto::ECPrivateKey> key) { | 
|  | service_ = NULL; | 
|  | DCHECK(!callback_.is_null()); | 
|  | if (key) | 
|  | *key_ = std::move(key); | 
|  | // Running the callback might delete |this| (e.g. the callback cleans up | 
|  | // resources created for the request), so we can't touch any of our | 
|  | // members afterwards. Reset callback_ first. | 
|  | base::ResetAndReturn(&callback_).Run(error); | 
|  | } | 
|  |  | 
|  | ChannelIDService::ChannelIDService(ChannelIDStore* channel_id_store) | 
|  | : channel_id_store_(channel_id_store), | 
|  | id_(g_next_id.GetNext()), | 
|  | requests_(0), | 
|  | key_store_hits_(0), | 
|  | inflight_joins_(0), | 
|  | workers_created_(0), | 
|  | weak_ptr_factory_(this) {} | 
|  |  | 
|  | ChannelIDService::~ChannelIDService() { | 
|  | DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); | 
|  | } | 
|  |  | 
|  | // static | 
|  | std::string ChannelIDService::GetDomainForHost(const std::string& host) { | 
|  | std::string domain = | 
|  | registry_controlled_domains::GetDomainAndRegistry( | 
|  | host, registry_controlled_domains::INCLUDE_PRIVATE_REGISTRIES); | 
|  | if (domain.empty()) | 
|  | return host; | 
|  | return domain; | 
|  | } | 
|  |  | 
|  | int ChannelIDService::GetOrCreateChannelID( | 
|  | const std::string& host, | 
|  | std::unique_ptr<crypto::ECPrivateKey>* key, | 
|  | CompletionOnceCallback callback, | 
|  | Request* out_req) { | 
|  | DVLOG(1) << __func__ << " " << host; | 
|  | DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); | 
|  |  | 
|  | if (callback.is_null() || !key || host.empty()) { | 
|  | return ERR_INVALID_ARGUMENT; | 
|  | } | 
|  |  | 
|  | std::string domain = GetDomainForHost(host); | 
|  | if (domain.empty()) { | 
|  | return ERR_INVALID_ARGUMENT; | 
|  | } | 
|  |  | 
|  | requests_++; | 
|  |  | 
|  | // See if a request for the same domain is currently in flight. | 
|  | bool create_if_missing = true; | 
|  | if (JoinToInFlightRequest(domain, key, create_if_missing, &callback, | 
|  | out_req)) { | 
|  | return ERR_IO_PENDING; | 
|  | } | 
|  |  | 
|  | int err = LookupChannelID(domain, key, create_if_missing, &callback, out_req); | 
|  | if (err == ERR_FILE_NOT_FOUND) { | 
|  | // Sync lookup did not find a valid channel ID.  Start generating a new one. | 
|  | workers_created_++; | 
|  | ChannelIDServiceWorker* worker = new ChannelIDServiceWorker( | 
|  | domain, base::BindOnce(&ChannelIDService::GeneratedChannelID, | 
|  | weak_ptr_factory_.GetWeakPtr())); | 
|  | worker->Start(task_runner_); | 
|  |  | 
|  | // We are waiting for key generation.  Create a job & request to track it. | 
|  | ChannelIDServiceJob* job = new ChannelIDServiceJob(create_if_missing); | 
|  | inflight_[domain] = base::WrapUnique(job); | 
|  |  | 
|  | job->AddRequest(out_req); | 
|  | out_req->RequestStarted(this, std::move(callback), key, job); | 
|  | return ERR_IO_PENDING; | 
|  | } | 
|  |  | 
|  | return err; | 
|  | } | 
|  |  | 
|  | int ChannelIDService::GetChannelID(const std::string& host, | 
|  | std::unique_ptr<crypto::ECPrivateKey>* key, | 
|  | CompletionOnceCallback callback, | 
|  | Request* out_req) { | 
|  | DVLOG(1) << __func__ << " " << host; | 
|  | DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); | 
|  |  | 
|  | if (callback.is_null() || !key || host.empty()) { | 
|  | return ERR_INVALID_ARGUMENT; | 
|  | } | 
|  |  | 
|  | std::string domain = GetDomainForHost(host); | 
|  | if (domain.empty()) { | 
|  | return ERR_INVALID_ARGUMENT; | 
|  | } | 
|  |  | 
|  | requests_++; | 
|  |  | 
|  | // See if a request for the same domain currently in flight. | 
|  | bool create_if_missing = false; | 
|  | if (JoinToInFlightRequest(domain, key, create_if_missing, &callback, | 
|  | out_req)) { | 
|  | return ERR_IO_PENDING; | 
|  | } | 
|  |  | 
|  | int err = LookupChannelID(domain, key, create_if_missing, &callback, out_req); | 
|  | return err; | 
|  | } | 
|  |  | 
|  | void ChannelIDService::GotChannelID(int err, | 
|  | const std::string& server_identifier, | 
|  | std::unique_ptr<crypto::ECPrivateKey> key) { | 
|  | DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); | 
|  |  | 
|  | auto j = inflight_.find(server_identifier); | 
|  | if (j == inflight_.end()) { | 
|  | NOTREACHED(); | 
|  | return; | 
|  | } | 
|  |  | 
|  | if (err == OK) { | 
|  | // Async DB lookup found a valid channel ID. | 
|  | key_store_hits_++; | 
|  | // ChannelIDService::Request::Post will do the histograms and stuff. | 
|  | HandleResult(OK, server_identifier, std::move(key)); | 
|  | return; | 
|  | } | 
|  | // Async lookup failed or the channel ID was missing. Return the error | 
|  | // directly, unless the channel ID was missing and a request asked to create | 
|  | // one. | 
|  | if (err != ERR_FILE_NOT_FOUND || !j->second->CreateIfMissing()) { | 
|  | HandleResult(err, server_identifier, std::move(key)); | 
|  | return; | 
|  | } | 
|  | // At least one request asked to create a channel ID => start generating a new | 
|  | // one. | 
|  | workers_created_++; | 
|  | ChannelIDServiceWorker* worker = new ChannelIDServiceWorker( | 
|  | server_identifier, base::BindOnce(&ChannelIDService::GeneratedChannelID, | 
|  | weak_ptr_factory_.GetWeakPtr())); | 
|  | worker->Start(task_runner_); | 
|  | } | 
|  |  | 
|  | ChannelIDStore* ChannelIDService::GetChannelIDStore() { | 
|  | return channel_id_store_.get(); | 
|  | } | 
|  |  | 
|  | void ChannelIDService::GeneratedChannelID( | 
|  | const std::string& server_identifier, | 
|  | int error, | 
|  | std::unique_ptr<ChannelIDStore::ChannelID> channel_id) { | 
|  | DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); | 
|  |  | 
|  | std::unique_ptr<crypto::ECPrivateKey> key; | 
|  | if (error == OK) { | 
|  | key = channel_id->key()->Copy(); | 
|  | channel_id_store_->SetChannelID(std::move(channel_id)); | 
|  | } | 
|  | HandleResult(error, server_identifier, std::move(key)); | 
|  | } | 
|  |  | 
|  | void ChannelIDService::HandleResult(int error, | 
|  | const std::string& server_identifier, | 
|  | std::unique_ptr<crypto::ECPrivateKey> key) { | 
|  | DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); | 
|  |  | 
|  | auto j = inflight_.find(server_identifier); | 
|  | if (j == inflight_.end()) { | 
|  | NOTREACHED(); | 
|  | return; | 
|  | } | 
|  | std::unique_ptr<ChannelIDServiceJob> job = std::move(j->second); | 
|  | inflight_.erase(j); | 
|  |  | 
|  | job->HandleResult(error, std::move(key)); | 
|  | } | 
|  |  | 
|  | bool ChannelIDService::JoinToInFlightRequest( | 
|  | const std::string& domain, | 
|  | std::unique_ptr<crypto::ECPrivateKey>* key, | 
|  | bool create_if_missing, | 
|  | CompletionOnceCallback* callback, | 
|  | Request* out_req) { | 
|  | auto j = inflight_.find(domain); | 
|  | if (j == inflight_.end()) | 
|  | return false; | 
|  |  | 
|  | // A request for the same domain is in flight already. We'll attach our | 
|  | // callback, but we'll also mark it as requiring a channel ID if one's mising. | 
|  | ChannelIDServiceJob* job = j->second.get(); | 
|  | inflight_joins_++; | 
|  |  | 
|  | job->AddRequest(out_req, create_if_missing); | 
|  | out_req->RequestStarted(this, std::move(*callback), key, job); | 
|  | return true; | 
|  | } | 
|  |  | 
|  | int ChannelIDService::LookupChannelID( | 
|  | const std::string& domain, | 
|  | std::unique_ptr<crypto::ECPrivateKey>* key, | 
|  | bool create_if_missing, | 
|  | CompletionOnceCallback* callback, | 
|  | Request* out_req) { | 
|  | // Check if a channel ID key already exists for this domain. | 
|  | int err = channel_id_store_->GetChannelID( | 
|  | domain, key, | 
|  | base::BindOnce(&ChannelIDService::GotChannelID, | 
|  | weak_ptr_factory_.GetWeakPtr())); | 
|  |  | 
|  | if (err == OK) { | 
|  | // Sync lookup found a valid channel ID. | 
|  | DVLOG(1) << "Channel ID store had valid key for " << domain; | 
|  | key_store_hits_++; | 
|  | return OK; | 
|  | } | 
|  |  | 
|  | if (err == ERR_IO_PENDING) { | 
|  | // We are waiting for async DB lookup.  Create a job & request to track it. | 
|  | ChannelIDServiceJob* job = new ChannelIDServiceJob(create_if_missing); | 
|  | inflight_[domain] = base::WrapUnique(job); | 
|  |  | 
|  | job->AddRequest(out_req); | 
|  | out_req->RequestStarted(this, std::move(*callback), key, job); | 
|  | return ERR_IO_PENDING; | 
|  | } | 
|  |  | 
|  | return err; | 
|  | } | 
|  |  | 
|  | int ChannelIDService::channel_id_count() { | 
|  | return channel_id_store_->GetChannelIDCount(); | 
|  | } | 
|  |  | 
|  | }  // namespace net |