// Copyright 2017 The Cobalt Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Note: This file is here for reference only, and is *NOT* part of Cobalt's
// build system.  See adjacent README.md for build instructions.

#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/ASTMatchers/ASTMatchers.h"
#include "clang/Frontend/FrontendActions.h"
#include "clang/Tooling/CommonOptionsParser.h"
#include "clang/Tooling/Tooling.h"
#include "llvm/Support/CommandLine.h"

#include <algorithm>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

using namespace clang::ast_matchers;
using namespace clang::tooling;
using namespace clang;

// My"Cool"String -> My\"Cool\"String
std::string EscapeJson(const std::string& string) {
  std::ostringstream o;
  for (auto c : string) {
    switch (c) {
      case '"':
        o << "\\\"";
        break;
      case '\\':
        o << "\\\\";
        break;
      case '\b':
        o << "\\b";
        break;
      case '\f':
        o << "\\f";
        break;
      case '\n':
        o << "\\n";
        break;
      case '\r':
        o << "\\r";
        break;
      case '\t':
        o << "\\t";
        break;
      default:
        if ('\x00' <= c && c <= '\x1f') {
          o << "\\u" << std::hex << std::setw(4) << std::setfill('0')
            << static_cast<int>(c);
        } else {
          o << c;
        }
    }
  }
  return o.str();
}

enum JsonConversionType {
  kNeedsEscape,
  kAlreadyJson,
};

// Helper struct to make call sites of |ToJson| nice and fresh.
struct KeyValue {
  KeyValue(const std::string& key, const std::string& value,
           JsonConversionType conversion_type = kNeedsEscape) {
    this->key = '"' + EscapeJson(key) + '"';
    if (conversion_type == kNeedsEscape) {
      this->value = '"' + EscapeJson(value) + '"';
    } else if (conversion_type == kAlreadyJson) {
      this->value = value;
    } else {
      assert(false);
    }
  }
  KeyValue(const std::string& key, int value) {
    this->key = '"' + EscapeJson(key) + '"';
    this->value = std::to_string(value);
  }
  KeyValue(const std::string& key, unsigned value) {
    this->key = '"' + EscapeJson(key) + '"';
    this->value = std::to_string(value);
  }

  std::string key;
  std::string value;
};

// {{"foo", "bar"}, {"baz", 5}} -> {"foo": "bar", "baz": 5}
std::string ToJson(std::vector<KeyValue> key_values) {
  std::string result = "{";
  bool needs_comma = false;
  for (auto& key_value : key_values) {
    if (needs_comma) result += ", ";
    needs_comma = true;
    result += key_value.key + ": " + key_value.value;
  }
  result += "}";
  return result;
}

// Get an unsugared and unqualified string representation of
// |qual_type|.
std::string ToString(QualType qual_type) {
  const Type* type = qual_type.getTypePtrOrNull();
  if (!type) {
    return "<INTERNALLY NULL>";
  }
  const Type* unqualified_desugared_type = type->getUnqualifiedDesugaredType();
  CanQualType can_qual_type =
      unqualified_desugared_type->getCanonicalTypeUnqualified();
  return static_cast<QualType>(can_qual_type).getAsString();
}

// Return whether |base_decl| is a base class of |derived|.
bool IsBaseOf(CXXRecordDecl* base_decl, QualType derived) {
  const Type* derived_type = derived.getTypePtrOrNull();
  CXXRecordDecl* derived_decl = derived_type->getAsCXXRecordDecl();
  if (!derived_decl) {
    return false;
  }
  if (base_decl == derived_decl) {
    return true;
  }
  for (auto& base : derived_decl->bases()) {
    if (IsBaseOf(base_decl, base.getType())) {
      return true;
    }
  }
  return false;
}

// We need to keep track of types we've already visited due to type cycles in
// templates. Consider, for example, "class Foo : public MyCrtp<Foo> {};".
using Visited = std::unordered_set<const Type*>;

// Utility functions to check if a type is roughly Traceable.
// "Is roughly Traceable" is defined as follows:
//   1. Traceable is roughly Traceable.
//   2. A type is roughly Traceable if any of its parent classes are roughly
//   Traceable.
//   3. A type is roughly Traceable if any of its template arguments are roughly
//   Traceable.
//   4. T* is roughly Traceable if T is roughly Traceable.
//
// Note that this definition is intentionally liberal in what it considers
// Traceable.  False positives (such as the T* field in scoped_refptr) are
// expected to be handled by a white list that is implemented in a wrapper
// script (in order to facilitate modification of the white list without
// recompilation).
bool IsTraceable(QualType qual_type, Visited& visited);
bool IsTraceable(const RecordDecl* record_decl, Visited& visited);
bool IsTraceable(QualType qual_type) {
  Visited visited;
  return IsTraceable(qual_type, visited);
}

bool IsTraceable(QualType qual_type, Visited& visited) {
  const Type* type = qual_type.getTypePtrOrNull();
  if (!type) {
    return false;
  }
  if (visited.count(type) == 1) {
    return false;
  }
  visited.insert(type);

  // 1. Traceable is roughly Traceable.
  const Type* unqualified_desugared_type = type->getUnqualifiedDesugaredType();
  CanQualType can_qual_type =
      unqualified_desugared_type->getCanonicalTypeUnqualified();
  if (static_cast<QualType>(can_qual_type).getAsString() ==
      "class cobalt::script::Traceable") {
    return true;
  }

  // 4. T* is roughly Traceable if T is roughly Traceable.
  if (type->isPointerType()) {
    if (IsTraceable(type->getPointeeType(), visited)) {
      return true;
    }
  }

  // 3. A type is roughly Traceable if any of its template arguments are
  // roughly Traceable.
  if (auto* tst = qual_type.getNonReferenceType()
                      ->getAs<TemplateSpecializationType>()) {
    for (auto& arg : *tst) {
      if (arg.getKind() == TemplateArgument::Type) {
        if (IsTraceable(arg.getAsType(), visited)) {
          return true;
        }
      }
    }
  }

  return IsTraceable(type->getAsCXXRecordDecl(), visited);
}

bool IsTraceable(const RecordDecl* record_decl,
                 std::unordered_set<const Type*>& visited) {
  if (record_decl == nullptr) {
    return false;
  }
  const CXXRecordDecl* cxx_record_decl =
      dyn_cast<const CXXRecordDecl>(record_decl);
  if (!cxx_record_decl || !cxx_record_decl->hasDefinition()) {
    return false;
  }
  // 2. A type is roughly Traceable if any of its parent classes are roughly
  // Traceable.
  for (auto& base : cxx_record_decl->bases()) {
    if (IsTraceable(base.getType(), visited)) {
      return true;
    }
  }
  return false;
}

// Search |stmt| (and its children) for a reference to a member
// expression named |target|.
// clang-format off
// Example tree:
// tracer->Trace(wrappable_);
//   |-CXXMemberCallExpr 0x98b2b80 'void'
//   | |-MemberExpr 0x98b2af8 '<bound member function type>' ->Trace 0x5d490b0
//   | | `-ImplicitCastExpr 0x98b2ae0 'script::Tracer *' <LValueToRValue>
//   | |   `-DeclRefExpr 0x98b2ab8 'script::Tracer *' lvalue ParmVar 0x98b27e0 'tracer' 'script::Tracer *'
//   | `-ImplicitCastExpr 0x98b2c78 'class cobalt::script::Traceable *' <DerivedToBase (Traceable)>
//   |   `-ImplicitCastExpr 0x98b2c60 'class cobalt::dom::DOMImplementation *' <UserDefinedConversion>
//   |     `-CXXMemberCallExpr 0x98b2c38 'class cobalt::dom::DOMImplementation *'
//   |       `-MemberExpr 0x98b2c00 '<bound member function type>' .operator cobalt::dom::DOMImplementation * 0x90f1ad8
//   |         `-ImplicitCastExpr 0x98b2be8 'const class scoped_refptr<class cobalt::dom::DOMImplementation>' lvalue <NoOp>
//   |           `-MemberExpr 0x98b2b48 'scoped_refptr<class cobalt::dom::DOMImplementation>':'class scoped_refptr<class cobalt::dom::DOMImplementation>' lvalue ->implementation_ 0x90f2be8
//   |             `-CXXThisExpr 0x98b2b30 'class cobalt::dom::Document *' this
// clang-format on
bool SearchForMemberExpr(Stmt* stmt, const std::string& target) {
  if (MemberExpr* member_expr = dyn_cast<MemberExpr>(stmt)) {
    if (member_expr->getMemberDecl()->getNameAsString() == target) {
      return true;
    }
  }
  for (auto& child : stmt->children()) {
    if (SearchForMemberExpr(child, target)) {
      return true;
    }
  }
  return false;
}

// Attempt to get the |n|th child of |body| as a |T|.  Returns nullptr if
// there are not enough children, or the child was not |dyn_cast|able to |T|.
template <typename T>
T* MaybeGetNthChildAs(Stmt* body, int n) {
  auto children = body->children();
  if (std::distance(children.begin(), children.end()) < n) {
    return nullptr;
  }
  auto it = children.begin();
  std::advance(it, n);
  return dyn_cast<T>(*it);
}

// Check if |stmt| matches something like:
// clang-format off
// (this example is from inside void Document::TraceMembers(script::Tracer* tracer))
//   Node::TraceMembers(tracer);
// Example tree:
//   |-CXXMemberCallExpr 0x98b2a50 'void'
//   | |-MemberExpr 0x98b29d8 '<bound member function type>' ->TraceMembers 0x7b979a0
//   | | `-ImplicitCastExpr 0x98b2a80 'class cobalt::dom::Node *' <UncheckedDerivedToBase (Node)>
//   | |   `-CXXThisExpr 0x98b29c0 'class cobalt::dom::Document *' this
//   | `-ImplicitCastExpr 0x98b2aa0 'script::Tracer *' <LValueToRValue>
//   |   `-DeclRefExpr 0x98b2a28 'script::Tracer *' lvalue ParmVar 0x98b27e0 'tracer' 'script::Tracer *'
// clang-format on
bool IsBaseClassTracerCall(Stmt* stmt) {
  CXXMemberCallExpr* cxx_member_call_expr = dyn_cast<CXXMemberCallExpr>(stmt);
  if (!cxx_member_call_expr) {
    return false;
  }

  {
    MemberExpr* member_expr =
        MaybeGetNthChildAs<MemberExpr>(cxx_member_call_expr, 0);
    if (!member_expr) {
      return false;
    }
    {
      ImplicitCastExpr* implicit_cast_expr =
          MaybeGetNthChildAs<ImplicitCastExpr>(member_expr, 0);
      if (!implicit_cast_expr) {
        return false;
      }
      {
        CXXThisExpr* cxx_this_expr =
            MaybeGetNthChildAs<CXXThisExpr>(implicit_cast_expr, 0);
        if (!cxx_this_expr) {
          return false;
        }

        auto GetPointeeType = [](QualType qual_type) -> const Type* {
          const Type* type = qual_type.getTypePtrOrNull();
          if (!type) {
            return nullptr;
          }

          if (!type->isPointerType()) {
            return nullptr;
          }

          QualType pointee_qual_type = type->getPointeeType();
          return pointee_qual_type.getTypePtrOrNull();
        };

        // Ensure that we are calling our base's TraceMembers.
        const Type* this_type = GetPointeeType(cxx_this_expr->getType());
        const Type* cast_type = GetPointeeType(implicit_cast_expr->getType());
        if (!this_type || !cast_type) {
          return false;
        }

        CXXRecordDecl* cast_decl = cast_type->getAsCXXRecordDecl();
        CXXRecordDecl* this_decl = cast_type->getAsCXXRecordDecl();

        if (!cast_decl || !this_decl) {
          return false;
        }

        if (!IsBaseOf(cast_decl, cxx_this_expr->getType())) {
          return false;
        }
      }
    }
  }

  {
    ImplicitCastExpr* implicit_cast_expr =
        MaybeGetNthChildAs<ImplicitCastExpr>(cxx_member_call_expr, 1);
    if (!implicit_cast_expr) {
      return false;
    }
    // TODO: Check if this is of type script::Tracer*
  }

  return true;
}

class FieldPrinter : public MatchFinder::MatchCallback {
 public:
  // Iterate over each field declaration (FieldDecl) of each class.  Check
  // that the owning class's |TraceMembers| traces it.
  void run(const MatchFinder::MatchResult& result) override {
    const FieldDecl* field_decl =
        result.Nodes.getNodeAs<clang::FieldDecl>("fieldDecl");
    if (!field_decl) {
      return;
    }

    QualType qual_type = field_decl->getType();
    if (!IsTraceable(qual_type)) {
      return;
    }

    const RecordDecl* parent = field_decl->getParent();
    const CXXRecordDecl* cxx_parent = dyn_cast<const CXXRecordDecl>(parent);
    if (!cxx_parent) {
      return;
    }

    bool is_traced = false;
    bool had_trace_members_declaration = false;
    bool had_trace_members_body = false;
    bool calls_base_trace_members = false;

    auto methods = cxx_parent->methods();
    auto trace_members_it =
        std::find_if(methods.begin(), methods.end(), [](CXXMethodDecl* method) {
          return method->getNameInfo().getName().getAsString() ==
                 "TraceMembers";
        });

    if (trace_members_it != methods.end()) {
      auto method = *trace_members_it;
      had_trace_members_declaration = true;
      Stmt* body = method->getBody();
      if (body) {
        had_trace_members_body = true;
        bool first_statement = true;
        for (Stmt* stmt : body->children()) {
          if (first_statement) {
            first_statement = false;
            calls_base_trace_members |= IsBaseClassTracerCall(stmt);
          }

          if (CXXMemberCallExpr* call = dyn_cast<CXXMemberCallExpr>(stmt)) {
            // TODO: Check that this is actually a tracer call
            int nargs = call->getNumArgs();
            if (nargs == 0) {
              continue;
            }
            Expr* arg = call->getArg(0);
            is_traced |=
                SearchForMemberExpr(arg, field_decl->getNameAsString());
          }
        }
      }
    }

    std::string field_name = field_decl->getNameAsString();
    std::string field_class = ToString(qual_type);
    std::string field_class_friendly = qual_type.getAsString();
    std::string parent_class =
        cxx_parent->getCanonicalDecl()->getQualifiedNameAsString();
    std::string parent_class_friendly = cxx_parent->getNameAsString();
    // TODO: Maybe get some line/column numbers and a filename too?

    if (!had_trace_members_declaration) {
      std::cout << ToJson({
                       {"messageType", "needsTraceMembersDeclaration"},
                       {"fieldName", field_name},
                       {"fieldClass", field_class},
                       {"fieldClassFriendly", field_class_friendly},
                       {"parentClass", parent_class},
                       {"parentClassFriendly", parent_class_friendly},
                   })
                << '\n';
    } else {
      if (had_trace_members_body) {
        if (!is_traced) {
          std::cout << ToJson({
                           {"messageType", "needsTracerTraceField"},
                           {"fieldName", field_name},
                           {"fieldClass", field_class},
                           {"fieldClassFriendly", field_class_friendly},
                           {"parentClass", parent_class},
                           {"parentClassFriendly", parent_class_friendly},
                       })
                    << '\n';
        }
        // Note that for a class that fails to call its base class's
        // |TraceMembers|, this message will be printed for each field
        // declaration it has.  We expect the script wrapping us to de-dupe
        // these messages.
        if (!calls_base_trace_members) {
          // We need to dump the base classes so our wrapper script can filter
          // out classes with direct bases of script::Wrappable and
          // script::Traceable.
          std::string bases = "[";
          bool first = true;
          for (auto base : cxx_parent->bases()) {
            if (!first) bases += ", ";
            first = false;
            bases += '"' + EscapeJson(ToString(base.getType())) + '"';
          }
          bases += "]";
          std::cout << ToJson({
                           {"messageType", "needsCallBaseTraceMembers"},
                           {"parentClass", parent_class},
                           {"parentClassFriendly", parent_class_friendly},
                           {"baseNames", bases, kAlreadyJson},
                       })
                    << '\n';
        }
      }
      // Don't handle the case where there isn't a body.  We will find it
      // later in another translation unit, or else the code would not be
      // able to link successfully.
    }
  }
};

int main(int argc, const char** argv) {
  llvm::cl::OptionCategory my_tool_category("verify-trace-members options");
  llvm::cl::extrahelp common_help(CommonOptionsParser::HelpMessage);
  llvm::cl::extrahelp more_help("\nMore help...");

  CommonOptionsParser options_parser(argc, argv, my_tool_category);
  ClangTool tool(options_parser.getCompilations(),
                 options_parser.getSourcePathList());

  FieldPrinter field_printer;
  MatchFinder match_finder;
  DeclarationMatcher field_matcher = fieldDecl().bind("fieldDecl");
  match_finder.addMatcher(field_matcher, &field_printer);

  return tool.run(newFrontendActionFactory(&match_finder).get());
}
