// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <memory>
#include <utility>

#include "base/bind.h"
#include "base/callback_helpers.h"
#include "base/macros.h"
#include "base/memory/ptr_util.h"
#include "base/test/task_environment.h"
#include "base/threading/thread.h"
#include "media/learning/mojo/public/cpp/mojo_learning_task_controller.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace media {
namespace learning {

class MojoLearningTaskControllerTest : public ::testing::Test {
 public:
  // Impl of a mojom::LearningTaskController that remembers call arguments.
  class FakeMojoLearningTaskController : public mojom::LearningTaskController {
   public:
    void BeginObservation(
        const base::UnguessableToken& id,
        const FeatureVector& features,
        const absl::optional<TargetValue>& default_target) override {
      begin_args_.id_ = id;
      begin_args_.features_ = features;
      begin_args_.default_target_ = default_target;
    }

    void CompleteObservation(const base::UnguessableToken& id,
                             const ObservationCompletion& completion) override {
      complete_args_.id_ = id;
      complete_args_.completion_ = completion;
    }

    void CancelObservation(const base::UnguessableToken& id) override {
      cancel_args_.id_ = id;
    }

    void UpdateDefaultTarget(
        const base::UnguessableToken& id,
        const absl::optional<TargetValue>& default_target) override {
      update_default_args_.id_ = id;
      update_default_args_.default_target_ = default_target;
    }

    void PredictDistribution(const FeatureVector& features,
                             PredictDistributionCallback callback) override {
      predict_args_.features_ = features;
      predict_args_.callback_ = std::move(callback);
    }

    struct {
      base::UnguessableToken id_;
      FeatureVector features_;
      absl::optional<TargetValue> default_target_;
    } begin_args_;

    struct {
      base::UnguessableToken id_;
      ObservationCompletion completion_;
    } complete_args_;

    struct {
      base::UnguessableToken id_;
    } cancel_args_;

    struct {
      base::UnguessableToken id_;
      absl::optional<TargetValue> default_target_;
    } update_default_args_;

    struct {
      FeatureVector features_;
      PredictDistributionCallback callback_;
    } predict_args_;
  };

 public:
  MojoLearningTaskControllerTest()
      : learning_controller_receiver_(&fake_learning_controller_) {}
  ~MojoLearningTaskControllerTest() override = default;

  void SetUp() override {
    // Create a LearningTask.
    task_.name = "MyLearningTask";

    // Tell |learning_controller_| to forward to the fake learner impl.
    mojo::Remote<media::learning::mojom::LearningTaskController> remote(
        learning_controller_receiver_.BindNewPipeAndPassRemote());
    learning_controller_ =
        std::make_unique<MojoLearningTaskController>(task_, std::move(remote));
  }

  // Mojo stuff.
  base::test::TaskEnvironment task_environment_;

  LearningTask task_;
  FakeMojoLearningTaskController fake_learning_controller_;
  mojo::Receiver<mojom::LearningTaskController> learning_controller_receiver_;

  // The learner under test.
  std::unique_ptr<MojoLearningTaskController> learning_controller_;
};

TEST_F(MojoLearningTaskControllerTest, GetLearningTask) {
  // The controller should hand back the task it was constructed with.
  const auto& reported_task = learning_controller_->GetLearningTask();
  EXPECT_EQ(task_.name, reported_task.name);
}

TEST_F(MojoLearningTaskControllerTest, BeginWithoutDefaultTarget) {
  // Beginning an observation with no default target should forward the id
  // and features, and leave the default target unset on the receiving side.
  const base::UnguessableToken token = base::UnguessableToken::Create();
  const FeatureVector feature_vector = {FeatureValue(123), FeatureValue(456)};
  learning_controller_->BeginObservation(token, feature_vector, absl::nullopt,
                                         absl::nullopt);

  // Run pending tasks so the Mojo message reaches the fake.
  task_environment_.RunUntilIdle();

  const auto& received = fake_learning_controller_.begin_args_;
  EXPECT_EQ(token, received.id_);
  EXPECT_EQ(feature_vector, received.features_);
  EXPECT_FALSE(received.default_target_.has_value());
}

TEST_F(MojoLearningTaskControllerTest, BeginWithDefaultTarget) {
  // Beginning an observation with a default target should forward the id,
  // the features and that default target unchanged.
  const base::UnguessableToken token = base::UnguessableToken::Create();
  const TargetValue fallback_target(987);
  const FeatureVector feature_vector = {FeatureValue(123), FeatureValue(456)};
  learning_controller_->BeginObservation(token, feature_vector, fallback_target,
                                         absl::nullopt);

  // Run pending tasks so the Mojo message reaches the fake.
  task_environment_.RunUntilIdle();

  const auto& received = fake_learning_controller_.begin_args_;
  EXPECT_EQ(token, received.id_);
  EXPECT_EQ(feature_vector, received.features_);
  EXPECT_EQ(fallback_target, received.default_target_);
}

TEST_F(MojoLearningTaskControllerTest, UpdateDefaultTargetToValue) {
  // Test if we can update the default target to a non-nullopt. The observation
  // begins without a default target; the later update must reach the fake and
  // refer to the same observation id as the begin call.
  base::UnguessableToken id = base::UnguessableToken::Create();
  FeatureVector features = {FeatureValue(123), FeatureValue(456)};
  learning_controller_->BeginObservation(id, features, absl::nullopt,
                                         absl::nullopt);
  TargetValue default_target(987);
  learning_controller_->UpdateDefaultTarget(id, default_target);
  task_environment_.RunUntilIdle();

  // The begin call arrived, for this id, without a default target.
  EXPECT_EQ(id, fake_learning_controller_.begin_args_.id_);
  EXPECT_EQ(features, fake_learning_controller_.begin_args_.features_);
  EXPECT_FALSE(fake_learning_controller_.begin_args_.default_target_);

  // The update carried the new default target for the same observation.
  EXPECT_EQ(id, fake_learning_controller_.update_default_args_.id_);
  EXPECT_EQ(default_target,
            fake_learning_controller_.update_default_args_.default_target_);
}

TEST_F(MojoLearningTaskControllerTest, UpdateDefaultTargetToNoValue) {
  // Test if we can update the default target to nullopt. The observation
  // begins with a default target; the later update clears it for the same
  // observation id.
  base::UnguessableToken id = base::UnguessableToken::Create();
  FeatureVector features = {FeatureValue(123), FeatureValue(456)};
  TargetValue default_target(987);
  learning_controller_->BeginObservation(id, features, default_target,
                                         absl::nullopt);
  learning_controller_->UpdateDefaultTarget(id, absl::nullopt);
  task_environment_.RunUntilIdle();

  // The begin call arrived, for this id, with the original default target.
  EXPECT_EQ(id, fake_learning_controller_.begin_args_.id_);
  EXPECT_EQ(features, fake_learning_controller_.begin_args_.features_);
  EXPECT_EQ(default_target,
            fake_learning_controller_.begin_args_.default_target_);

  // The update cleared the default target for the same observation.
  EXPECT_EQ(id, fake_learning_controller_.update_default_args_.id_);
  EXPECT_FALSE(fake_learning_controller_.update_default_args_.default_target_);
}

TEST_F(MojoLearningTaskControllerTest, Complete) {
  // Completing an observation should forward the id and the completion's
  // target value.
  const base::UnguessableToken token = base::UnguessableToken::Create();
  ObservationCompletion outcome(TargetValue(1234));
  learning_controller_->CompleteObservation(token, outcome);

  // Run pending tasks so the Mojo message reaches the fake.
  task_environment_.RunUntilIdle();

  const auto& received = fake_learning_controller_.complete_args_;
  EXPECT_EQ(token, received.id_);
  EXPECT_EQ(outcome.target_value, received.completion_.target_value);
}

TEST_F(MojoLearningTaskControllerTest, Cancel) {
  // Cancelling should forward exactly the observation id.
  const base::UnguessableToken token = base::UnguessableToken::Create();
  learning_controller_->CancelObservation(token);

  // Run pending tasks so the Mojo message reaches the fake.
  task_environment_.RunUntilIdle();

  const auto& received = fake_learning_controller_.cancel_args_;
  EXPECT_EQ(token, received.id_);
}

TEST_F(MojoLearningTaskControllerTest, PredictDistribution) {
  FeatureVector features = {FeatureValue(123), FeatureValue(456)};

  // Store the optional itself instead of dereferencing it inside the callback:
  // a nullopt reply would otherwise be undefined behavior. It also lets the
  // test tell "callback never ran" apart from "callback ran".
  absl::optional<TargetHistogram> observed_prediction;
  learning_controller_->PredictDistribution(
      features, base::BindOnce(
                    [](absl::optional<TargetHistogram>* test_storage,
                       const absl::optional<TargetHistogram>& predicted) {
                      *test_storage = predicted;
                    },
                    &observed_prediction));
  task_environment_.RunUntilIdle();
  EXPECT_EQ(features, fake_learning_controller_.predict_args_.features_);
  EXPECT_FALSE(fake_learning_controller_.predict_args_.callback_.is_null());
  // The fake holds on to the callback, so no prediction has arrived yet.
  EXPECT_FALSE(observed_prediction.has_value());

  TargetHistogram expected_prediction;
  expected_prediction[TargetValue(1)] = 1.0;
  expected_prediction[TargetValue(2)] = 2.0;
  expected_prediction[TargetValue(3)] = 3.0;
  std::move(fake_learning_controller_.predict_args_.callback_)
      .Run(expected_prediction);
  task_environment_.RunUntilIdle();
  ASSERT_TRUE(observed_prediction.has_value());
  EXPECT_EQ(*observed_prediction, expected_prediction);
}

}  // namespace learning
}  // namespace media
