// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "media/learning/impl/extra_trees_trainer.h"

#include <set>

#include "base/bind.h"
#include "base/check_op.h"
#include "media/learning/impl/voting_ensemble.h"

namespace media {
namespace learning {

// Nothing to set up eagerly: the RandomTreeTrainer is created lazily by the
// first call to Train(), so that it can be constructed with rng().
ExtraTreesTrainer::ExtraTreesTrainer() = default;

// Members (tree trainer, any partially built forest, converter, and the copy of
// the training data) are released by their own destructors.
ExtraTreesTrainer::~ExtraTreesTrainer() = default;

void ExtraTreesTrainer::Train(const LearningTask& task,
                              const TrainingData& training_data,
                              TrainedModelCB model_cb) {
  // Only one training session may run at a time.  A completed session moves
  // both |trees_| and |converter_| into the model it produces, which is what
  // leaves them empty for these checks.
  DCHECK_EQ(trees_.size(), 0u);
  DCHECK_EQ(converter_.get(), nullptr);

  task_ = task;
  trees_.reserve(task.rf_number_of_trees);

  // The tree trainer is built here rather than in the constructor so that it
  // receives whatever rng is current at training time (mostly matters for
  // tests that install one via SetRngForTesting).
  // TODO(liberato): We should always take the rng in the ctor, rather than
  // via SetRngForTesting.  Then we can do this earlier.
  if (tree_trainer_ == nullptr)
    tree_trainer_ = std::make_unique<RandomTreeTrainer>(rng());

  // RandomTree understands nominal features directly, so one-hot conversion is
  // opt-in: it is slow, although RandomTree's native handling of nominals is
  // only an approximation of it.
  if (!task_.use_one_hot_conversion) {
    converted_training_data_ = training_data;
  } else {
    converter_ = std::make_unique<OneHotConverter>(task, training_data);
    converted_training_data_ = converter_->Convert(training_data);
    // Trees must be trained against the post-conversion feature description.
    task_ = converter_->converted_task();
  }

  // Kick off the first tree.  A null model tells OnRandomTreeModel() that
  // there is nothing to add to the forest yet.
  OnRandomTreeModel(std::move(model_cb), /*model=*/nullptr);
}

void ExtraTreesTrainer::OnRandomTreeModel(TrainedModelCB model_cb,
                                          std::unique_ptr<Model> model) {
  // Train() passes a null |model| to start the process; anything else is a
  // freshly trained tree to add to the forest.
  if (model != nullptr)
    trees_.push_back(std::move(model));

  // More trees still to build: hand the next one to the tree trainer and route
  // its result back here.  The callback is bound to a weak pointer, so it will
  // not run on a destroyed |this|.
  // NOTE(review): if RandomTreeTrainer::Train() replies synchronously, this
  // recurses once per tree, and |model_cb| below runs while those frames are
  // still on the stack -- confirm the tree trainer's callback contract.
  if (trees_.size() != task_.rf_number_of_trees) {
    tree_trainer_->Train(
        task_, converted_training_data_,
        base::BindOnce(&ExtraTreesTrainer::OnRandomTreeModel, AsWeakPtr(),
                       std::move(model_cb)));
    return;
  }

  // The forest is complete.  Moving |trees_| and |converter_| out also resets
  // this trainer for the next call to Train().
  std::unique_ptr<Model> ensemble =
      std::make_unique<VotingEnsemble>(std::move(trees_));

  // When one-hot conversion was used, keep the converter alongside the model
  // that was trained on its output.
  if (converter_) {
    ensemble = std::make_unique<ConvertingModel>(std::move(converter_),
                                                 std::move(ensemble));
  }

  std::move(model_cb).Run(std::move(ensemble));
}

}  // namespace learning
}  // namespace media
