// blob: 563185185cf670a24ebd7b1de7293858ccc0d0a9 [file] [log] [blame]
// Copyright 2017 The Cobalt Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Note: This file is here for reference only, and is *NOT* part of Cobalt's
// build system. See adjacent README.md for build instructions.
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/ASTMatchers/ASTMatchers.h"
#include "clang/Frontend/FrontendActions.h"
#include "clang/Tooling/CommonOptionsParser.h"
#include "clang/Tooling/Tooling.h"
#include "llvm/Support/CommandLine.h"
#include <algorithm>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
using namespace clang::ast_matchers;
using namespace clang::tooling;
using namespace clang;
// My"Cool"String -> My\"Cool\"String
// Returns a copy of |string| that is safe to place between the double quotes
// of a JSON string literal. Quotes and backslashes are backslash-escaped, the
// common control characters use their short forms (\b \f \n \r \t), and any
// other character in the range 0x00-0x1f is written as \u00XX with lowercase
// hex digits. All other characters are copied through unchanged.
std::string EscapeJson(const std::string& string) {
  static const char kHexDigits[] = "0123456789abcdef";
  std::string escaped;
  for (const char c : string) {
    if (c == '"') {
      escaped += "\\\"";
    } else if (c == '\\') {
      escaped += "\\\\";
    } else if (c == '\b') {
      escaped += "\\b";
    } else if (c == '\f') {
      escaped += "\\f";
    } else if (c == '\n') {
      escaped += "\\n";
    } else if (c == '\r') {
      escaped += "\\r";
    } else if (c == '\t') {
      escaped += "\\t";
    } else if ('\x00' <= c && c <= '\x1f') {
      // Remaining control characters: the value fits in two hex digits, so
      // the upper two digits of the four-digit escape are always zero.
      escaped += "\\u00";
      escaped += kHexDigits[(c >> 4) & 0xf];
      escaped += kHexDigits[c & 0xf];
    } else {
      escaped += c;
    }
  }
  return escaped;
}
// Selects how KeyValue turns a string value into JSON text.
enum JsonConversionType {
// The value is plain text: KeyValue escapes it and wraps it in double quotes.
kNeedsEscape,
// The value is already JSON text: KeyValue stores it unchanged.
kAlreadyJson,
};
// Helper struct to make call sites of |ToJson| nice and fresh.
// Holds one JSON object member as two ready-to-emit strings: |key| is always
// an escaped, double-quoted JSON string, and |value| is JSON text whose form
// depends on which constructor was used.
struct KeyValue {
  // String value. With |kNeedsEscape| (the default) |value| is escaped and
  // quoted; with |kAlreadyJson| it is stored verbatim.
  KeyValue(const std::string& key, const std::string& value,
           JsonConversionType conversion_type = kNeedsEscape)
      : key('"' + EscapeJson(key) + '"') {
    switch (conversion_type) {
      case kNeedsEscape:
        this->value = '"' + EscapeJson(value) + '"';
        break;
      case kAlreadyJson:
        this->value = value;
        break;
      default:
        // Not a known conversion type.
        assert(false);
    }
  }

  // Signed integer value, emitted as a bare JSON number.
  KeyValue(const std::string& key, int value)
      : key('"' + EscapeJson(key) + '"'), value(std::to_string(value)) {}

  // Unsigned integer value, emitted as a bare JSON number.
  KeyValue(const std::string& key, unsigned value)
      : key('"' + EscapeJson(key) + '"'), value(std::to_string(value)) {}

  std::string key;
  std::string value;
};
// {{"foo", "bar"}, {"baz", 5}} -> {"foo": "bar", "baz": 5}
std::string ToJson(std::vector<KeyValue> key_values) {
std::string result = "{";
bool needs_comma = false;
for (auto& key_value : key_values) {
if (needs_comma) result += ", ";
needs_comma = true;
result += key_value.key + ": " + key_value.value;
}
result += "}";
return result;
}
// Get an unsugared and unqualified string representation of
// |qual_type|.
// Returns the placeholder "<INTERNALLY NULL>" when |qual_type| carries no
// type pointer.
std::string ToString(QualType qual_type) {
  if (const Type* type = qual_type.getTypePtrOrNull()) {
    // Strip sugar and qualifiers first, then print the canonical type.
    const QualType canonical = static_cast<QualType>(
        type->getUnqualifiedDesugaredType()->getCanonicalTypeUnqualified());
    return canonical.getAsString();
  }
  return "<INTERNALLY NULL>";
}
// Return whether |base_decl| is a base class of |derived|.
// A class counts as a base of itself. The search walks every direct base of
// |derived| recursively. Returns false when |derived| is a null type, is not
// a C++ class type, or names a class without a definition.
bool IsBaseOf(CXXRecordDecl* base_decl, QualType derived) {
  const Type* derived_type = derived.getTypePtrOrNull();
  if (!derived_type) {
    // A null QualType has no declaration to inspect.
    return false;
  }
  CXXRecordDecl* derived_decl = derived_type->getAsCXXRecordDecl();
  if (!derived_decl) {
    return false;
  }
  if (base_decl == derived_decl) {
    return true;
  }
  // |bases()| may only be queried on a class that has a definition.
  if (!derived_decl->hasDefinition()) {
    return false;
  }
  for (const auto& base : derived_decl->bases()) {
    if (IsBaseOf(base_decl, base.getType())) {
      return true;
    }
  }
  return false;
}
// We need to keep track of types we've already visited due to type cycles in
// templates. Consider, for example, "class Foo : public MyCrtp<Foo> {};".
// The set holds the clang |Type| nodes that one IsTraceable query has already
// examined.
using Visited = std::unordered_set<const Type*>;
// Utility functions to check if a type is roughly Traceable.
// "Is roughly Traceable" is defined as follows:
// 1. Traceable is roughly Traceable.
// 2. A type is roughly Traceable if any of its parent classes are roughly
// Traceable.
// 3. A type is roughly Traceable if any of its template arguments are roughly
// Traceable.
// 4. T* is roughly Traceable if T is roughly Traceable.
//
// Note that this definition is intentionally liberal in what it considers
// Traceable. False positives (such as the T* field in scoped_refptr) are
// expected to be handled by a white list that is implemented in a wrapper
// script (in order to facilitate modification of the white list without
// recompilation).
bool IsTraceable(QualType qual_type, Visited& visited);
bool IsTraceable(const RecordDecl* record_decl, Visited& visited);
// Entry point: checks |qual_type| starting from an empty set of visited
// types.
bool IsTraceable(QualType qual_type) {
  Visited seen_types;
  return IsTraceable(qual_type, seen_types);
}
// Recursive worker for a type. |visited| records every |Type| node examined
// so far in this query; a type that is null or already recorded reports
// false, which is what stops recursion on cyclic template types.
bool IsTraceable(QualType qual_type, Visited& visited) {
  const Type* type = qual_type.getTypePtrOrNull();
  if (!type) {
    return false;
  }
  // |insert| reports whether the element was newly added, so one call both
  // tests for and records the visit.
  if (!visited.insert(type).second) {
    return false;
  }

  // 1. Traceable is roughly Traceable. |ToString| yields the unsugared,
  // unqualified canonical spelling of the type.
  if (ToString(qual_type) == "class cobalt::script::Traceable") {
    return true;
  }

  // 4. T* is roughly Traceable if T is roughly Traceable.
  if (type->isPointerType() && IsTraceable(type->getPointeeType(), visited)) {
    return true;
  }

  // 3. A type is roughly Traceable if any of its template arguments are
  // roughly Traceable. Only type arguments are considered.
  if (auto* specialization = qual_type.getNonReferenceType()
                                 ->getAs<TemplateSpecializationType>()) {
    for (auto& argument : *specialization) {
      if (argument.getKind() == TemplateArgument::Type &&
          IsTraceable(argument.getAsType(), visited)) {
        return true;
      }
    }
  }

  // 2. Otherwise the answer depends on the base classes of the record, if
  // the type is a record at all.
  return IsTraceable(type->getAsCXXRecordDecl(), visited);
}
bool IsTraceable(const RecordDecl* record_decl,
std::unordered_set<const Type*>& visited) {
if (record_decl == nullptr) {
return false;
}
const CXXRecordDecl* cxx_record_decl =
dyn_cast<const CXXRecordDecl>(record_decl);
if (!cxx_record_decl || !cxx_record_decl->hasDefinition()) {
return false;
}
// 2. A type is roughly Traceable if any of its parent classes are roughly
// Traceable.
for (auto& base : cxx_record_decl->bases()) {
if (IsTraceable(base.getType(), visited)) {
return true;
}
}
return false;
}
// Search |stmt| (and its children) for a reference to a member
// expression named |target|.
// clang-format off
// Example tree:
// tracer->Trace(wrappable_);
// |-CXXMemberCallExpr 0x98b2b80 'void'
// | |-MemberExpr 0x98b2af8 '<bound member function type>' ->Trace 0x5d490b0
// | | `-ImplicitCastExpr 0x98b2ae0 'script::Tracer *' <LValueToRValue>
// | | `-DeclRefExpr 0x98b2ab8 'script::Tracer *' lvalue ParmVar 0x98b27e0 'tracer' 'script::Tracer *'
// | `-ImplicitCastExpr 0x98b2c78 'class cobalt::script::Traceable *' <DerivedToBase (Traceable)>
// | `-ImplicitCastExpr 0x98b2c60 'class cobalt::dom::DOMImplementation *' <UserDefinedConversion>
// | `-CXXMemberCallExpr 0x98b2c38 'class cobalt::dom::DOMImplementation *'
// | `-MemberExpr 0x98b2c00 '<bound member function type>' .operator cobalt::dom::DOMImplementation * 0x90f1ad8
// | `-ImplicitCastExpr 0x98b2be8 'const class scoped_refptr<class cobalt::dom::DOMImplementation>' lvalue <NoOp>
// | `-MemberExpr 0x98b2b48 'scoped_refptr<class cobalt::dom::DOMImplementation>':'class scoped_refptr<class cobalt::dom::DOMImplementation>' lvalue ->implementation_ 0x90f2be8
// | `-CXXThisExpr 0x98b2b30 'class cobalt::dom::Document *' this
// clang-format on
// Returns true as soon as a |MemberExpr| whose member declaration is named
// |target| is found in a depth-first walk of |stmt|. A null |stmt| contains
// no match; this matters because the child list of a statement can hold null
// entries, and |dyn_cast| must not be applied to a null pointer.
bool SearchForMemberExpr(Stmt* stmt, const std::string& target) {
  if (!stmt) {
    return false;
  }
  if (const MemberExpr* member_expr = dyn_cast<MemberExpr>(stmt)) {
    if (member_expr->getMemberDecl()->getNameAsString() == target) {
      return true;
    }
  }
  for (Stmt* child : stmt->children()) {
    if (SearchForMemberExpr(child, target)) {
      return true;
    }
  }
  return false;
}
// Attempt to get the |n|th child of |body| as a |T|. Returns nullptr if
// there are not enough children, or the child was not |dyn_cast|able to |T|.
// |n| is zero-based. A null |body|, a negative |n|, an |n| equal to or larger
// than the number of children, or a null child all yield nullptr.
template <typename T>
T* MaybeGetNthChildAs(Stmt* body, int n) {
  if (!body || n < 0) {
    return nullptr;
  }
  // Walk the children once, stopping at index |n|. This never steps past the
  // end of the child list, including when |n| equals the number of children.
  int index = 0;
  for (Stmt* child : body->children()) {
    if (index == n) {
      // |dyn_cast| must not be handed a null pointer.
      return child ? dyn_cast<T>(child) : nullptr;
    }
    ++index;
  }
  return nullptr;
}
// Check if |stmt| matches something like:
// clang-format off
// (this example is from inside void Document::TraceMembers(script::Tracer* tracer))
// Node::TraceMembers(tracer);
// Example tree:
// |-CXXMemberCallExpr 0x98b2a50 'void'
// | |-MemberExpr 0x98b29d8 '<bound member function type>' ->TraceMembers 0x7b979a0
// | | `-ImplicitCastExpr 0x98b2a80 'class cobalt::dom::Node *' <UncheckedDerivedToBase (Node)>
// | | `-CXXThisExpr 0x98b29c0 'class cobalt::dom::Document *' this
// | `-ImplicitCastExpr 0x98b2aa0 'script::Tracer *' <LValueToRValue>
// | `-DeclRefExpr 0x98b2a28 'script::Tracer *' lvalue ParmVar 0x98b27e0 'tracer' 'script::Tracer *'
// clang-format on
// Returns true only when |stmt| has the shape shown above: a member call
// whose object is |this| implicitly cast to a pointer to one of its own base
// classes, and whose argument at child index 1 is an implicit cast. The name
// of the called method is not checked.
bool IsBaseClassTracerCall(Stmt* stmt) {
  if (!stmt) {
    return false;
  }
  CXXMemberCallExpr* cxx_member_call_expr = dyn_cast<CXXMemberCallExpr>(stmt);
  if (!cxx_member_call_expr) {
    return false;
  }

  // Child 0 of the call is the callee: "<object>->Method".
  MemberExpr* member_expr =
      MaybeGetNthChildAs<MemberExpr>(cxx_member_call_expr, 0);
  if (!member_expr) {
    return false;
  }
  // The object must be an implicit cast applied directly to |this|.
  ImplicitCastExpr* object_cast_expr =
      MaybeGetNthChildAs<ImplicitCastExpr>(member_expr, 0);
  if (!object_cast_expr) {
    return false;
  }
  CXXThisExpr* cxx_this_expr =
      MaybeGetNthChildAs<CXXThisExpr>(object_cast_expr, 0);
  if (!cxx_this_expr) {
    return false;
  }

  // Yields the pointee of a pointer type, or a null QualType otherwise.
  auto GetPointeeType = [](QualType qual_type) -> QualType {
    const Type* type = qual_type.getTypePtrOrNull();
    if (!type || !type->isPointerType()) {
      return QualType();
    }
    return type->getPointeeType();
  };

  // Ensure that we are calling our base's TraceMembers.
  const QualType this_pointee = GetPointeeType(cxx_this_expr->getType());
  const QualType cast_pointee = GetPointeeType(object_cast_expr->getType());
  const Type* this_type = this_pointee.getTypePtrOrNull();
  const Type* cast_type = cast_pointee.getTypePtrOrNull();
  if (!this_type || !cast_type) {
    return false;
  }
  CXXRecordDecl* cast_decl = cast_type->getAsCXXRecordDecl();
  CXXRecordDecl* this_decl = this_type->getAsCXXRecordDecl();
  if (!cast_decl || !this_decl) {
    return false;
  }
  // Hand |IsBaseOf| the class type itself, not the pointer type of |this|:
  // |getAsCXXRecordDecl| does not look through pointers, so a pointer type
  // would never resolve to a class declaration.
  if (!IsBaseOf(cast_decl, this_pointee)) {
    return false;
  }

  // Child 1 of the call is the first argument.
  ImplicitCastExpr* argument_cast_expr =
      MaybeGetNthChildAs<ImplicitCastExpr>(cxx_member_call_expr, 1);
  if (!argument_cast_expr) {
    return false;
  }
  // TODO: Check if this is of type script::Tracer*
  return true;
}
// Match callback that audits one field declaration per invocation and prints
// its findings to stdout, one single-line JSON object per finding.
class FieldPrinter : public MatchFinder::MatchCallback {
 public:
  // Iterate over each field declaration (FieldDecl) of each class. Check
  // that the owning class's |TraceMembers| traces it.
  //
  // Fields that are not roughly Traceable, or whose parent is not a C++
  // class, are ignored. Otherwise up to two messages may be printed:
  //   - "needsTraceMembersDeclaration": the parent declares no method named
  //     TraceMembers.
  //   - "needsTracerTraceField": TraceMembers has a body, but no top-level
  //     member call in it mentions the field in its first argument.
  //   - "needsCallBaseTraceMembers": TraceMembers has a body whose first
  //     statement is not a call to a base class's method.
  void run(const MatchFinder::MatchResult& result) override {
    const FieldDecl* field_decl =
        result.Nodes.getNodeAs<clang::FieldDecl>("fieldDecl");
    if (!field_decl) {
      return;
    }
    const QualType field_type = field_decl->getType();
    if (!IsTraceable(field_type)) {
      return;
    }
    const CXXRecordDecl* owner =
        dyn_cast<const CXXRecordDecl>(field_decl->getParent());
    if (!owner) {
      return;
    }

    const std::string field_name = field_decl->getNameAsString();

    bool field_is_traced = false;
    bool has_declaration = false;
    bool has_body = false;
    bool calls_base = false;

    // Only the first method named TraceMembers is inspected.
    for (CXXMethodDecl* method : owner->methods()) {
      if (method->getNameInfo().getName().getAsString() != "TraceMembers") {
        continue;
      }
      has_declaration = true;
      if (Stmt* body = method->getBody()) {
        has_body = true;
        bool is_first_statement = true;
        for (Stmt* statement : body->children()) {
          // The base class call is only recognized as the first statement.
          if (is_first_statement) {
            is_first_statement = false;
            calls_base |= IsBaseClassTracerCall(statement);
          }
          // TODO: Check that this is actually a tracer call
          CXXMemberCallExpr* call = dyn_cast<CXXMemberCallExpr>(statement);
          if (!call || call->getNumArgs() == 0) {
            continue;
          }
          field_is_traced |= SearchForMemberExpr(call->getArg(0), field_name);
        }
      }
      break;
    }

    const std::string field_class = ToString(field_type);
    const std::string field_class_friendly = field_type.getAsString();
    const std::string parent_class =
        owner->getCanonicalDecl()->getQualifiedNameAsString();
    const std::string parent_class_friendly = owner->getNameAsString();

    // Both field-level messages carry the same payload and differ only in
    // their message type.
    // TODO: Maybe get some line/column numbers and a filename too?
    const auto report_field_issue = [&](const std::string& message_type) {
      std::cout << ToJson({
                       {"messageType", message_type},
                       {"fieldName", field_name},
                       {"fieldClass", field_class},
                       {"fieldClassFriendly", field_class_friendly},
                       {"parentClass", parent_class},
                       {"parentClassFriendly", parent_class_friendly},
                   })
                << '\n';
    };

    if (!has_declaration) {
      report_field_issue("needsTraceMembersDeclaration");
      return;
    }
    // Don't handle the case where there isn't a body. We will find it
    // later in another translation unit, or else the code would not be
    // able to link successfully.
    if (!has_body) {
      return;
    }
    if (!field_is_traced) {
      report_field_issue("needsTracerTraceField");
    }
    // Note that for a class that fails to call its base class's
    // |TraceMembers|, this message will be printed for each field
    // declaration it has. We expect the script wrapping us to de-dupe
    // these messages.
    if (calls_base) {
      return;
    }
    // We need to dump the base classes so our wrapper script can filter
    // out classes with direct bases of script::Wrappable and
    // script::Traceable.
    std::string base_names = "[";
    const char* separator = "";
    for (const auto& base : owner->bases()) {
      base_names += separator;
      base_names += '"' + EscapeJson(ToString(base.getType())) + '"';
      separator = ", ";
    }
    base_names += "]";
    std::cout << ToJson({
                     {"messageType", "needsCallBaseTraceMembers"},
                     {"parentClass", parent_class},
                     {"parentClassFriendly", parent_class_friendly},
                     {"baseNames", base_names, kAlreadyJson},
                 })
              << '\n';
  }
};
int main(int argc, const char** argv) {
llvm::cl::OptionCategory my_tool_category("verify-trace-members options");
llvm::cl::extrahelp common_help(CommonOptionsParser::HelpMessage);
llvm::cl::extrahelp more_help("\nMore help...");
CommonOptionsParser options_parser(argc, argv, my_tool_category);
ClangTool tool(options_parser.getCompilations(),
options_parser.getSourcePathList());
FieldPrinter field_printer;
MatchFinder match_finder;
DeclarationMatcher field_matcher = fieldDecl().bind("fieldDecl");
match_finder.addMatcher(field_matcher, &field_printer);
return tool.run(newFrontendActionFactory(&match_finder).get());
}