blob: 3562009032091c6f66810abc872b6162d2a3b844 [file] [log] [blame]
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "dispatch.h"
#include <cassert>
#include "cbor.h"
#include "error_support.h"
#include "find_by_first.h"
#include "frontend_channel.h"
#include "protocol_core.h"
namespace v8_crdtp {
// =============================================================================
// DispatchResponse - Error status and chaining / fall through
// =============================================================================
// Named constructors for DispatchResponse. Each one default-constructs a
// response, stamps the DispatchCode and, for the error kinds, the message.

// static
// A successful dispatch; no message is set.
DispatchResponse DispatchResponse::Success() {
  DispatchResponse response;
  response.code_ = DispatchCode::SUCCESS;
  return response;
}

// static
// Signals that this handler did not process the call; no message is set.
DispatchResponse DispatchResponse::FallThrough() {
  DispatchResponse response;
  response.code_ = DispatchCode::FALL_THROUGH;
  return response;
}

// static
// PARSE_ERROR carrying |message|.
DispatchResponse DispatchResponse::ParseError(std::string message) {
  DispatchResponse response;
  response.code_ = DispatchCode::PARSE_ERROR;
  response.message_ = std::move(message);
  return response;
}

// static
// INVALID_REQUEST carrying |message|.
DispatchResponse DispatchResponse::InvalidRequest(std::string message) {
  DispatchResponse response;
  response.code_ = DispatchCode::INVALID_REQUEST;
  response.message_ = std::move(message);
  return response;
}

// static
// METHOD_NOT_FOUND carrying |message|.
DispatchResponse DispatchResponse::MethodNotFound(std::string message) {
  DispatchResponse response;
  response.code_ = DispatchCode::METHOD_NOT_FOUND;
  response.message_ = std::move(message);
  return response;
}

// static
// INVALID_PARAMS carrying |message|.
DispatchResponse DispatchResponse::InvalidParams(std::string message) {
  DispatchResponse response;
  response.code_ = DispatchCode::INVALID_PARAMS;
  response.message_ = std::move(message);
  return response;
}

// static
// INTERNAL_ERROR with the fixed message "Internal error".
DispatchResponse DispatchResponse::InternalError() {
  DispatchResponse response;
  response.code_ = DispatchCode::INTERNAL_ERROR;
  response.message_ = "Internal error";
  return response;
}

// static
// SERVER_ERROR carrying |message|.
DispatchResponse DispatchResponse::ServerError(std::string message) {
  DispatchResponse response;
  response.code_ = DispatchCode::SERVER_ERROR;
  response.message_ = std::move(message);
  return response;
}
// =============================================================================
// Dispatchable - a shallow parser for CBOR encoded DevTools messages
// =============================================================================
namespace {
// Size in bytes of the fixed header in front of envelope contents, as used by
// Dispatchable to compute where parsing should end: two single bytes followed
// by a uint32_t.
// NOTE(review): presumably a tag byte, a byte-string marker byte and the
// 32-bit contents length; confirm against cbor.h.
constexpr size_t kEncodedEnvelopeHeaderSize = 2 + sizeof(uint32_t);
}  // namespace
// Shallow-parses |serialized|, which must be a CBOR envelope wrapping a map.
// Top-level keys go through MaybeParseProperty ("id", "method", "params",
// "sessionId"; any other key is an error). The first problem found is stored
// in status_ together with the byte position where it was detected, and
// parsing stops there. On success status_ is never assigned here.
// After the map, the message must have had an "id" and a non-empty "method".
// The parsed contents must also end exactly where the envelope length says,
// and nothing may follow the envelope.
// NOTE(review): |serialized| is stored as a span, not copied; the caller
// presumably keeps the bytes alive while this object is in use.
Dispatchable::Dispatchable(span<uint8_t> serialized) : serialized_(serialized) {
// Any CheckCBORMessage failure is reported as MESSAGE_MUST_BE_AN_OBJECT at the
// position CheckCBORMessage returned.
Status s = cbor::CheckCBORMessage(serialized);
if (!s.ok()) {
status_ = {Error::MESSAGE_MUST_BE_AN_OBJECT, s.pos};
return;
}
cbor::CBORTokenizer tokenizer(serialized);
if (tokenizer.TokenTag() == cbor::CBORTokenTag::ERROR_VALUE) {
status_ = tokenizer.Status();
return;
}
// We checked for the envelope start byte above (cbor::CheckCBORMessage), so
// the tokenizer must agree here, since it's not an error.
assert(tokenizer.TokenTag() == cbor::CBORTokenTag::ENVELOPE);
// Before we enter the envelope, we save the position that we
// expect to see after we're done parsing the envelope contents.
// This way we can compare and produce an error if the contents
// didn't fit exactly into the envelope length.
// NOTE(review): this assumes every envelope header is exactly
// kEncodedEnvelopeHeaderSize bytes long.
const size_t pos_past_envelope = tokenizer.Status().pos +
kEncodedEnvelopeHeaderSize +
tokenizer.GetEnvelopeContents().size();
tokenizer.EnterEnvelope();
if (tokenizer.TokenTag() == cbor::CBORTokenTag::ERROR_VALUE) {
status_ = tokenizer.Status();
return;
}
// The envelope payload must be a map (i.e. a JSON-style object).
if (tokenizer.TokenTag() != cbor::CBORTokenTag::MAP_START) {
status_ = {Error::MESSAGE_MUST_BE_AN_OBJECT, tokenizer.Status().pos};
return;
}
// Redundant with the check above; documents the invariant for readers.
assert(tokenizer.TokenTag() == cbor::CBORTokenTag::MAP_START);
tokenizer.Next(); // Now we should be pointed at the map key.
// Consume key/value pairs until the map's STOP token. A successful
// MaybeParseProperty call leaves the tokenizer past the value. A failing call
// has already set status_, so we just return.
while (tokenizer.TokenTag() != cbor::CBORTokenTag::STOP) {
switch (tokenizer.TokenTag()) {
case cbor::CBORTokenTag::DONE:
// Input ended before the map was closed.
status_ =
Status{Error::CBOR_UNEXPECTED_EOF_IN_MAP, tokenizer.Status().pos};
return;
case cbor::CBORTokenTag::ERROR_VALUE:
status_ = tokenizer.Status();
return;
case cbor::CBORTokenTag::STRING8:
if (!MaybeParseProperty(&tokenizer))
return;
break;
default:
// We require the top-level keys to be UTF8 (US-ASCII in practice).
status_ = Status{Error::CBOR_INVALID_MAP_KEY, tokenizer.Status().pos};
return;
}
}
// Step past the map's STOP token.
tokenizer.Next();
// "id" and "method" are mandatory. An empty method string counts as missing,
// since method_.empty() is the only check.
if (!has_call_id_) {
status_ = Status{Error::MESSAGE_MUST_HAVE_INTEGER_ID_PROPERTY,
tokenizer.Status().pos};
return;
}
if (method_.empty()) {
status_ = Status{Error::MESSAGE_MUST_HAVE_STRING_METHOD_PROPERTY,
tokenizer.Status().pos};
return;
}
// The contents of the envelope parsed OK, now check that we're at
// the expected position.
if (pos_past_envelope != tokenizer.Status().pos) {
status_ = Status{Error::CBOR_ENVELOPE_CONTENTS_LENGTH_MISMATCH,
tokenizer.Status().pos};
return;
}
// The envelope must be the entire message; any further bytes are junk.
if (tokenizer.TokenTag() != cbor::CBORTokenTag::DONE) {
status_ = Status{Error::CBOR_TRAILING_JUNK, tokenizer.Status().pos};
return;
}
}
// True when the constructor recorded no parse/validation error.
bool Dispatchable::ok() const {
  return status_.ok();
}

// Converts status_ into a DispatchResponse:
//  - Success() if parsing succeeded,
//  - InvalidRequest(status message) for errors IsMessageError() reports as
//    message-level problems,
//  - ParseError(ASCII rendering of the status) for everything else.
DispatchResponse Dispatchable::DispatchError() const {
  // TODO(johannes): Replace with DCHECK / similar?
  if (status_.ok())
    return DispatchResponse::Success();
  return status_.IsMessageError()
             ? DispatchResponse::InvalidRequest(status_.Message())
             : DispatchResponse::ParseError(status_.ToASCIIString());
}
// Called with the tokenizer positioned on a top-level STRING8 key. Known keys
// are routed to their dedicated parser. Any other key sets status_ to
// MESSAGE_HAS_UNKNOWN_PROPERTY. Returns false exactly when status_ was set to
// an error.
bool Dispatchable::MaybeParseProperty(cbor::CBORTokenizer* tokenizer) {
  using PropertyParser = bool (Dispatchable::*)(cbor::CBORTokenizer*);
  struct KnownProperty {
    const char* key;
    PropertyParser parse;
  };
  static constexpr KnownProperty kKnownProperties[] = {
      {"id", &Dispatchable::MaybeParseCallId},
      {"method", &Dispatchable::MaybeParseMethod},
      {"params", &Dispatchable::MaybeParseParams},
      {"sessionId", &Dispatchable::MaybeParseSessionId},
  };

  span<uint8_t> key = tokenizer->GetString8();
  for (const KnownProperty& property : kKnownProperties) {
    if (SpanEquals(SpanFrom(property.key), key))
      return (this->*property.parse)(tokenizer);
  }
  status_ =
      Status{Error::MESSAGE_HAS_UNKNOWN_PROPERTY, tokenizer->Status().pos};
  return false;
}
// Parses the value of an "id" key; the tokenizer starts on the key.
// A repeated "id" is reported as CBOR_DUPLICATE_MAP_KEY at the key's position.
// A non-INT32 value is reported as MESSAGE_MUST_HAVE_INTEGER_ID_PROPERTY.
// On success stores call_id_, sets has_call_id_ and advances past the value.
bool Dispatchable::MaybeParseCallId(cbor::CBORTokenizer* tokenizer) {
  if (has_call_id_) {
    status_ = Status{Error::CBOR_DUPLICATE_MAP_KEY, tokenizer->Status().pos};
    return false;
  }

  tokenizer->Next();  // Step from the key to its value.
  const bool value_is_int32 =
      tokenizer->TokenTag() == cbor::CBORTokenTag::INT32;
  if (!value_is_int32) {
    status_ = Status{Error::MESSAGE_MUST_HAVE_INTEGER_ID_PROPERTY,
                     tokenizer->Status().pos};
    return false;
  }

  has_call_id_ = true;
  call_id_ = tokenizer->GetInt32();
  tokenizer->Next();
  return true;
}
// Parses the value of a "method" key; the tokenizer starts on the key.
// The value must be STRING8; it is kept as a span into the message.
// NOTE(review): "already seen" is inferred from !method_.empty(), so a second
// "method" key after an empty-string method is not reported as a duplicate.
bool Dispatchable::MaybeParseMethod(cbor::CBORTokenizer* tokenizer) {
  const bool already_seen = !method_.empty();
  if (already_seen) {
    status_ = Status{Error::CBOR_DUPLICATE_MAP_KEY, tokenizer->Status().pos};
    return false;
  }

  tokenizer->Next();  // Step from the key to its value.
  if (tokenizer->TokenTag() == cbor::CBORTokenTag::STRING8) {
    method_ = tokenizer->GetString8();
    tokenizer->Next();
    return true;
  }
  status_ = Status{Error::MESSAGE_MUST_HAVE_STRING_METHOD_PROPERTY,
                   tokenizer->Status().pos};
  return false;
}
// Parses the value of a "params" key; the tokenizer starts on the key.
// A null value is accepted and leaves params_ unchanged. An envelope is stored
// in params_. Anything else is MESSAGE_MAY_HAVE_OBJECT_PARAMS_PROPERTY.
// A repeated "params" key is CBOR_DUPLICATE_MAP_KEY.
bool Dispatchable::MaybeParseParams(cbor::CBORTokenizer* tokenizer) {
  if (params_seen_) {
    status_ = Status{Error::CBOR_DUPLICATE_MAP_KEY, tokenizer->Status().pos};
    return false;
  }
  params_seen_ = true;

  tokenizer->Next();  // Step from the key to its value.
  switch (tokenizer->TokenTag()) {
    case cbor::CBORTokenTag::NULL_VALUE:
      break;
    case cbor::CBORTokenTag::ENVELOPE:
      params_ = tokenizer->GetEnvelope();
      break;
    default:
      status_ = Status{Error::MESSAGE_MAY_HAVE_OBJECT_PARAMS_PROPERTY,
                       tokenizer->Status().pos};
      return false;
  }
  tokenizer->Next();
  return true;
}
// Parses the value of a "sessionId" key; the tokenizer starts on the key.
// The value must be STRING8 and is kept as a span into the message.
// NOTE(review): "already seen" is inferred from !session_id_.empty(), so a
// repeat after an empty-string session id is not reported as a duplicate.
bool Dispatchable::MaybeParseSessionId(cbor::CBORTokenizer* tokenizer) {
  const bool already_seen = !session_id_.empty();
  if (already_seen) {
    status_ = Status{Error::CBOR_DUPLICATE_MAP_KEY, tokenizer->Status().pos};
    return false;
  }

  tokenizer->Next();  // Step from the key to its value.
  if (tokenizer->TokenTag() == cbor::CBORTokenTag::STRING8) {
    session_id_ = tokenizer->GetString8();
    tokenizer->Next();
    return true;
  }
  status_ = Status{Error::MESSAGE_MAY_HAVE_STRING_SESSION_ID_PROPERTY,
                   tokenizer->Status().pos};
  return false;
}
namespace {
class ProtocolError : public Serializable {
public:
explicit ProtocolError(DispatchResponse dispatch_response)
: dispatch_response_(std::move(dispatch_response)) {}
void AppendSerialized(std::vector<uint8_t>* out) const override {
Status status;
std::unique_ptr<ParserHandler> encoder = cbor::NewCBOREncoder(out, &status);
encoder->HandleMapBegin();
if (has_call_id_) {
encoder->HandleString8(SpanFrom("id"));
encoder->HandleInt32(call_id_);
}
encoder->HandleString8(SpanFrom("error"));
encoder->HandleMapBegin();
encoder->HandleString8(SpanFrom("code"));
encoder->HandleInt32(static_cast<int32_t>(dispatch_response_.Code()));
encoder->HandleString8(SpanFrom("message"));
encoder->HandleString8(SpanFrom(dispatch_response_.Message()));
if (!data_.empty()) {
encoder->HandleString8(SpanFrom("data"));
encoder->HandleString8(SpanFrom(data_));
}
encoder->HandleMapEnd();
encoder->HandleMapEnd();
assert(status.ok());
}
void SetCallId(int call_id) {
has_call_id_ = true;
call_id_ = call_id;
}
void SetData(std::string data) { data_ = std::move(data); }
private:
const DispatchResponse dispatch_response_;
std::string data_;
int call_id_ = 0;
bool has_call_id_ = false;
};
} // namespace
// =============================================================================
// Helpers for creating protocol responses and notifications.
// =============================================================================
std::unique_ptr<Serializable> CreateErrorResponse(
int call_id,
DispatchResponse dispatch_response,
const ErrorSupport* errors) {
auto protocol_error =
std::make_unique<ProtocolError>(std::move(dispatch_response));
protocol_error->SetCallId(call_id);
if (errors && !errors->Errors().empty()) {
protocol_error->SetData(
std::string(errors->Errors().begin(), errors->Errors().end()));
}
return protocol_error;
}
std::unique_ptr<Serializable> CreateErrorResponse(
int call_id,
DispatchResponse dispatch_response,
const DeserializerState& state) {
auto protocol_error =
std::make_unique<ProtocolError>(std::move(dispatch_response));
protocol_error->SetCallId(call_id);
// TODO(caseq): should we plumb the call name here?
protocol_error->SetData(state.ErrorMessage(MakeSpan("params")));
return protocol_error;
}
// Builds a protocol error with no "id" entry (SetCallId is never called).
std::unique_ptr<Serializable> CreateErrorNotification(
    DispatchResponse dispatch_response) {
  auto notification =
      std::make_unique<ProtocolError>(std::move(dispatch_response));
  return notification;
}
namespace {
class Response : public Serializable {
public:
Response(int call_id, std::unique_ptr<Serializable> params)
: call_id_(call_id), params_(std::move(params)) {}
void AppendSerialized(std::vector<uint8_t>* out) const override {
Status status;
std::unique_ptr<ParserHandler> encoder = cbor::NewCBOREncoder(out, &status);
encoder->HandleMapBegin();
encoder->HandleString8(SpanFrom("id"));
encoder->HandleInt32(call_id_);
encoder->HandleString8(SpanFrom("result"));
if (params_) {
params_->AppendSerialized(out);
} else {
encoder->HandleMapBegin();
encoder->HandleMapEnd();
}
encoder->HandleMapEnd();
assert(status.ok());
}
private:
const int call_id_;
std::unique_ptr<Serializable> params_;
};
class Notification : public Serializable {
public:
Notification(const char* method, std::unique_ptr<Serializable> params)
: method_(method), params_(std::move(params)) {}
void AppendSerialized(std::vector<uint8_t>* out) const override {
Status status;
std::unique_ptr<ParserHandler> encoder = cbor::NewCBOREncoder(out, &status);
encoder->HandleMapBegin();
encoder->HandleString8(SpanFrom("method"));
encoder->HandleString8(SpanFrom(method_));
encoder->HandleString8(SpanFrom("params"));
if (params_) {
params_->AppendSerialized(out);
} else {
encoder->HandleMapBegin();
encoder->HandleMapEnd();
}
encoder->HandleMapEnd();
assert(status.ok());
}
private:
const char* method_;
std::unique_ptr<Serializable> params_;
};
} // namespace
std::unique_ptr<Serializable> CreateResponse(
int call_id,
std::unique_ptr<Serializable> params) {
return std::make_unique<Response>(call_id, std::move(params));
}
std::unique_ptr<Serializable> CreateNotification(
const char* method,
std::unique_ptr<Serializable> params) {
return std::make_unique<Notification>(method, std::move(params));
}
// =============================================================================
// DomainDispatcher - Dispatching between protocol methods within a domain.
// =============================================================================
// Observes |dispatcher|; DomainDispatcher::weakPtr() registers it.
DomainDispatcher::WeakPtr::WeakPtr(DomainDispatcher* dispatcher)
    : dispatcher_(dispatcher) {}

// If dispatcher_ is still set, removes this object from the dispatcher's
// weak_ptrs_ registry so it is not touched after destruction.
DomainDispatcher::WeakPtr::~WeakPtr() {
  if (dispatcher_ == nullptr)
    return;
  dispatcher_->weak_ptrs_.erase(this);
}
DomainDispatcher::Callback::~Callback() = default;

// Drops the backend link; afterwards sendIfActive() and fallThroughIfActive()
// return without doing anything.
void DomainDispatcher::Callback::dispose() {
  backend_impl_ = nullptr;
}

// |method| is kept as a span (not copied). The bytes of |message| are copied
// into message_ so fallThroughIfActive() can forward them later.
DomainDispatcher::Callback::Callback(
    std::unique_ptr<DomainDispatcher::WeakPtr> backend_impl,
    int call_id,
    span<uint8_t> method,
    span<uint8_t> message)
    : backend_impl_(std::move(backend_impl)),
      call_id_(call_id),
      method_(method),
      message_(message.begin(), message.end()) {}
void DomainDispatcher::Callback::sendIfActive(
std::unique_ptr<Serializable> partialMessage,
const DispatchResponse& response) {
if (!backend_impl_ || !backend_impl_->get())
return;
backend_impl_->get()->sendResponse(call_id_, response,
std::move(partialMessage));
backend_impl_ = nullptr;
}
void DomainDispatcher::Callback::fallThroughIfActive() {
if (!backend_impl_ || !backend_impl_->get())
return;
backend_impl_->get()->channel()->FallThrough(call_id_, method_,
SpanFrom(message_));
backend_impl_ = nullptr;
}
// Starts attached to |frontendChannel|; clearFrontend() detaches it.
DomainDispatcher::DomainDispatcher(FrontendChannel* frontendChannel)
    : frontend_channel_(frontendChannel) {}

// Detaches from the frontend. clearFrontend() also calls dispose() on every
// registered WeakPtr.
DomainDispatcher::~DomainDispatcher() {
  clearFrontend();
}
// Sends either an error response (when |response| is an error; |result| is
// dropped) or a success response wrapping |result|. Does nothing once the
// frontend channel has been cleared.
void DomainDispatcher::sendResponse(int call_id,
                                    const DispatchResponse& response,
                                    std::unique_ptr<Serializable> result) {
  if (!frontend_channel_)
    return;
  std::unique_ptr<Serializable> serializable =
      response.IsError() ? CreateErrorResponse(call_id, response)
                         : CreateResponse(call_id, std::move(result));
  frontend_channel_->SendProtocolResponse(call_id, std::move(serializable));
}
// Returns true if |errors| holds any parameter errors. In that case an
// InvalidParams response carrying them is also sent, provided a frontend
// channel is still attached.
bool DomainDispatcher::MaybeReportInvalidParams(
    const Dispatchable& dispatchable,
    const ErrorSupport& errors) {
  const bool has_errors = !errors.Errors().empty();
  if (has_errors && frontend_channel_) {
    frontend_channel_->SendProtocolResponse(
        dispatchable.CallId(),
        CreateErrorResponse(
            dispatchable.CallId(),
            DispatchResponse::InvalidParams("Invalid parameters"), &errors));
  }
  return has_errors;
}

// Same as above, but the failure is taken from a DeserializerState status.
bool DomainDispatcher::MaybeReportInvalidParams(
    const Dispatchable& dispatchable,
    const DeserializerState& state) {
  const bool has_errors = !state.status().ok();
  if (has_errors && frontend_channel_) {
    frontend_channel_->SendProtocolResponse(
        dispatchable.CallId(),
        CreateErrorResponse(
            dispatchable.CallId(),
            DispatchResponse::InvalidParams("Invalid parameters"), state));
  }
  return has_errors;
}
void DomainDispatcher::clearFrontend() {
frontend_channel_ = nullptr;
for (auto& weak : weak_ptrs_)
weak->dispose();
weak_ptrs_.clear();
}
std::unique_ptr<DomainDispatcher::WeakPtr> DomainDispatcher::weakPtr() {
auto weak = std::make_unique<DomainDispatcher::WeakPtr>(this);
weak_ptrs_.insert(weak.get());
return weak;
}
// =============================================================================
// UberDispatcher - dispatches between domains (backends).
// =============================================================================
// Pairs "was the method found" with the closure that either performs the
// dispatch or reports the failure. |runnable| is a by-value sink, so it is
// moved into place rather than copied; copying would duplicate the captured
// Dispatchable and std::function state.
UberDispatcher::DispatchResult::DispatchResult(bool method_found,
                                               std::function<void()> runnable)
    : method_found_(method_found), runnable_(std::move(runnable)) {}

// Invokes the stored closure, then clears it so later calls are no-ops.
void UberDispatcher::DispatchResult::Run() {
  if (!runnable_)
    return;
  runnable_();
  runnable_ = nullptr;
}
// |frontend_channel| must be non-null (asserted). Dispatch() uses it to send
// "method not found" responses.
UberDispatcher::UberDispatcher(FrontendChannel* frontend_channel)
    : frontend_channel_(frontend_channel) {
  assert(frontend_channel != nullptr);
}

UberDispatcher::~UberDispatcher() = default;
// Sentinel index returned by DotIdx when there is no '.'.
constexpr size_t kNotFound = std::numeric_limits<size_t>::max();
namespace {
// Returns the index of the first '.' in |method|, or kNotFound if none.
// Empty spans are handled before calling memchr: memchr must not receive a
// null pointer even with a zero length, and an empty span may have a null
// data().
size_t DotIdx(span<uint8_t> method) {
  if (method.empty())
    return kNotFound;
  const void* dot = memchr(method.data(), '.', method.size());
  if (!dot)
    return kNotFound;
  return static_cast<size_t>(static_cast<const uint8_t*>(dot) - method.data());
}
}  // namespace
// Resolves dispatchable.Method() (after applying any registered redirect) as
// "<Domain>.<command>" and looks up the domain's dispatcher.
// - If the domain accepts the command, returns a found result whose closure
//   runs the domain handler on a copy of |dispatchable|.
// - Otherwise returns a not-found result whose closure sends a
//   MethodNotFound error naming the original (unredirected) method.
// The not-found closure captures |this|, so it must run while this
// UberDispatcher is alive.
UberDispatcher::DispatchResult UberDispatcher::Dispatch(
    const Dispatchable& dispatchable) const {
  span<uint8_t> method = FindByFirst(redirects_, dispatchable.Method(),
                                     /*default_value=*/dispatchable.Method());
  const size_t dot_idx = DotIdx(method);

  DomainDispatcher* dispatcher = nullptr;
  if (dot_idx != kNotFound)
    dispatcher = FindByFirst(dispatchers_, method.subspan(0, dot_idx));

  std::function<void(const Dispatchable&)> dispatched;
  if (dispatcher)
    dispatched = dispatcher->Dispatch(method.subspan(dot_idx + 1));

  if (dispatched) {
    return DispatchResult(
        true, [dispatchable, dispatched = std::move(dispatched)]() {
          dispatched(dispatchable);
        });
  }

  return DispatchResult(false, [this, dispatchable]() {
    span<uint8_t> requested = dispatchable.Method();
    std::string message = "'" +
                          std::string(requested.begin(), requested.end()) +
                          "' wasn't found";
    frontend_channel_->SendProtocolResponse(
        dispatchable.CallId(),
        CreateErrorResponse(dispatchable.CallId(),
                            DispatchResponse::MethodNotFound(
                                std::move(message))));
  });
}
// Strict weak ordering on (key, value) pairs that compares only the span keys
// via SpanLessThan. WireBackend uses it to merge new entries into redirects_
// and dispatchers_. operator() is const, as the standard Compare requirements
// expect: comparator objects may be copied and invoked through const paths.
template <typename T>
struct FirstLessThan {
  bool operator()(const std::pair<span<uint8_t>, T>& left,
                  const std::pair<span<uint8_t>, T>& right) const {
    return SpanLessThan(left.first, right.first);
  }
};
// Registers |dispatcher| under |domain| and adds |sorted_redirects| to the
// redirect table. Both tables stay sorted by key: new entries are appended and
// then merged into place with std::inplace_merge. That requires
// |sorted_redirects| to already be sorted by key.
void UberDispatcher::WireBackend(
    span<uint8_t> domain,
    const std::vector<std::pair<span<uint8_t>, span<uint8_t>>>&
        sorted_redirects,
    std::unique_ptr<DomainDispatcher> dispatcher) {
  const auto first_new_redirect = redirects_.insert(
      redirects_.end(), sorted_redirects.begin(), sorted_redirects.end());
  std::inplace_merge(redirects_.begin(), first_new_redirect, redirects_.end(),
                     FirstLessThan<span<uint8_t>>());

  const auto new_dispatcher = dispatchers_.insert(
      dispatchers_.end(), std::make_pair(domain, std::move(dispatcher)));
  std::inplace_merge(dispatchers_.begin(), new_dispatcher, dispatchers_.end(),
                     FirstLessThan<std::unique_ptr<DomainDispatcher>>());
}
} // namespace v8_crdtp