blob: 9d7beecef00e4e14520712f57c716c3a4de88057 [file] [log] [blame]
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <memory>
#include <utility>
#include "base/bind.h"
#include "base/callback_helpers.h"
#include "base/macros.h"
#include "base/memory/ptr_util.h"
#include "base/test/task_environment.h"
#include "base/threading/thread.h"
#include "media/learning/mojo/public/cpp/mojo_learning_task_controller.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace media {
namespace learning {
class MojoLearningTaskControllerTest : public ::testing::Test {
 public:
  // Stand-in for the receiving side of the mojom::LearningTaskController
  // pipe. Every method just copies its arguments into the matching |*_args_|
  // member, replacing whatever an earlier call stored, so tests can check
  // exactly what arrived over mojo.
  class FakeMojoLearningTaskController : public mojom::LearningTaskController {
   public:
    // Most recent arguments to BeginObservation().
    struct BeginArgs {
      base::UnguessableToken id_;
      FeatureVector features_;
      absl::optional<TargetValue> default_target_;
    };

    // Most recent arguments to CompleteObservation().
    struct CompleteArgs {
      base::UnguessableToken id_;
      ObservationCompletion completion_;
    };

    // Most recent argument to CancelObservation().
    struct CancelArgs {
      base::UnguessableToken id_;
    };

    // Most recent arguments to UpdateDefaultTarget().
    struct UpdateDefaultArgs {
      base::UnguessableToken id_;
      absl::optional<TargetValue> default_target_;
    };

    // Most recent arguments to PredictDistribution(). The reply callback is
    // kept un-run so that the test decides when, and with what, to answer.
    struct PredictArgs {
      FeatureVector features_;
      PredictDistributionCallback callback_;
    };

    // mojom::LearningTaskController implementation.
    void BeginObservation(
        const base::UnguessableToken& id,
        const FeatureVector& features,
        const absl::optional<TargetValue>& default_target) override {
      begin_args_ = BeginArgs{id, features, default_target};
    }

    void CompleteObservation(const base::UnguessableToken& id,
                             const ObservationCompletion& completion) override {
      complete_args_ = CompleteArgs{id, completion};
    }

    void CancelObservation(const base::UnguessableToken& id) override {
      cancel_args_ = CancelArgs{id};
    }

    void UpdateDefaultTarget(
        const base::UnguessableToken& id,
        const absl::optional<TargetValue>& default_target) override {
      update_default_args_ = UpdateDefaultArgs{id, default_target};
    }

    void PredictDistribution(const FeatureVector& features,
                             PredictDistributionCallback callback) override {
      predict_args_ = PredictArgs{features, std::move(callback)};
    }

    BeginArgs begin_args_;
    CompleteArgs complete_args_;
    CancelArgs cancel_args_;
    UpdateDefaultArgs update_default_args_;
    PredictArgs predict_args_;
  };

  MojoLearningTaskControllerTest() = default;
  ~MojoLearningTaskControllerTest() override = default;

  void SetUp() override {
    task_.name = "MyLearningTask";

    // Open a fresh pipe whose receiving end dispatches into
    // |fake_learning_controller_|, and give the sending end to the
    // controller under test.
    mojo::Remote<media::learning::mojom::LearningTaskController> remote(
        learning_controller_receiver_.BindNewPipeAndPassRemote());
    learning_controller_ =
        std::make_unique<MojoLearningTaskController>(task_, std::move(remote));
  }

  // Drives mojo message delivery via RunUntilIdle() in the tests.
  base::test::TaskEnvironment task_environment_;
  LearningTask task_;
  FakeMojoLearningTaskController fake_learning_controller_;
  // Declared after |fake_learning_controller_| so the fake already exists
  // when the receiver is constructed with a pointer to it.
  mojo::Receiver<mojom::LearningTaskController> learning_controller_receiver_{
      &fake_learning_controller_};

  // Object under test; created in SetUp().
  std::unique_ptr<MojoLearningTaskController> learning_controller_;
};
TEST_F(MojoLearningTaskControllerTest, GetLearningTask) {
  // The controller must report back the task it was constructed with.
  const LearningTask& reported_task = learning_controller_->GetLearningTask();
  EXPECT_EQ(reported_task.name, task_.name);
}
TEST_F(MojoLearningTaskControllerTest, BeginWithoutDefaultTarget) {
  const FeatureVector features = {FeatureValue(123), FeatureValue(456)};
  const base::UnguessableToken id = base::UnguessableToken::Create();

  // Both trailing optional arguments are left unset.
  learning_controller_->BeginObservation(id, features, absl::nullopt,
                                         absl::nullopt);
  task_environment_.RunUntilIdle();

  // The fake should have received the id and features, with no default.
  const auto& received = fake_learning_controller_.begin_args_;
  EXPECT_EQ(id, received.id_);
  EXPECT_EQ(features, received.features_);
  EXPECT_FALSE(received.default_target_);
}
TEST_F(MojoLearningTaskControllerTest, BeginWithDefaultTarget) {
  const FeatureVector features = {FeatureValue(123), FeatureValue(456)};
  const TargetValue default_target(987);
  const base::UnguessableToken id = base::UnguessableToken::Create();

  learning_controller_->BeginObservation(id, features, default_target,
                                         absl::nullopt);
  task_environment_.RunUntilIdle();

  // The default target should survive the trip across the pipe.
  const auto& received = fake_learning_controller_.begin_args_;
  EXPECT_EQ(id, received.id_);
  EXPECT_EQ(features, received.features_);
  EXPECT_EQ(default_target, received.default_target_);
}
TEST_F(MojoLearningTaskControllerTest, UpdateDefaultTargetToValue) {
  // Start an observation with no default target, then set one afterwards.
  const base::UnguessableToken id = base::UnguessableToken::Create();
  const FeatureVector features = {FeatureValue(123), FeatureValue(456)};
  learning_controller_->BeginObservation(id, features, absl::nullopt,
                                         absl::nullopt);

  const TargetValue new_default(987);
  learning_controller_->UpdateDefaultTarget(id, new_default);
  task_environment_.RunUntilIdle();

  const auto& fake = fake_learning_controller_;
  EXPECT_EQ(id, fake.update_default_args_.id_);
  EXPECT_EQ(features, fake.begin_args_.features_);
  EXPECT_EQ(new_default, fake.update_default_args_.default_target_);
}
TEST_F(MojoLearningTaskControllerTest, UpdateDefaultTargetToNoValue) {
  // Start an observation with a default target, then clear it.
  const base::UnguessableToken id = base::UnguessableToken::Create();
  const FeatureVector features = {FeatureValue(123), FeatureValue(456)};
  const TargetValue initial_default(987);
  learning_controller_->BeginObservation(id, features, initial_default,
                                         absl::nullopt);

  learning_controller_->UpdateDefaultTarget(id, absl::nullopt);
  task_environment_.RunUntilIdle();

  const auto& fake = fake_learning_controller_;
  EXPECT_EQ(id, fake.update_default_args_.id_);
  EXPECT_EQ(features, fake.begin_args_.features_);
  EXPECT_FALSE(fake.update_default_args_.default_target_.has_value());
}
TEST_F(MojoLearningTaskControllerTest, Complete) {
  const base::UnguessableToken id = base::UnguessableToken::Create();
  const ObservationCompletion completion(TargetValue(1234));

  learning_controller_->CompleteObservation(id, completion);
  task_environment_.RunUntilIdle();

  // Only the target value is compared here.
  const auto& received = fake_learning_controller_.complete_args_;
  EXPECT_EQ(id, received.id_);
  EXPECT_EQ(completion.target_value, received.completion_.target_value);
}
TEST_F(MojoLearningTaskControllerTest, Cancel) {
  // Cancelling should forward just the observation id.
  const base::UnguessableToken id = base::UnguessableToken::Create();
  learning_controller_->CancelObservation(id);
  task_environment_.RunUntilIdle();
  EXPECT_EQ(fake_learning_controller_.cancel_args_.id_, id);
}
// Verifies that PredictDistribution() forwards the features over mojo, and
// that the reply sent by the fake reaches the caller's callback unchanged.
TEST_F(MojoLearningTaskControllerTest, PredictDistribution) {
  FeatureVector features = {FeatureValue(123), FeatureValue(456)};

  // Stays nullopt until the reply callback runs. This lets the test tell
  // "no reply / empty reply" apart from a real prediction, without
  // dereferencing an empty optional.
  absl::optional<TargetHistogram> observed_prediction;
  learning_controller_->PredictDistribution(
      features, base::BindOnce(
                    [](absl::optional<TargetHistogram>* test_storage,
                       const absl::optional<TargetHistogram>& predicted) {
                      *test_storage = predicted;
                    },
                    &observed_prediction));
  task_environment_.RunUntilIdle();

  EXPECT_EQ(features, fake_learning_controller_.predict_args_.features_);
  // Fatal assertion: running a null callback below would crash the test
  // binary rather than report a clean failure.
  ASSERT_FALSE(fake_learning_controller_.predict_args_.callback_.is_null());

  TargetHistogram expected_prediction;
  expected_prediction[TargetValue(1)] = 1.0;
  expected_prediction[TargetValue(2)] = 2.0;
  expected_prediction[TargetValue(3)] = 3.0;
  std::move(fake_learning_controller_.predict_args_.callback_)
      .Run(expected_prediction);
  task_environment_.RunUntilIdle();

  ASSERT_TRUE(observed_prediction.has_value());
  EXPECT_EQ(*observed_prediction, expected_prediction);
}
} // namespace learning
} // namespace media