| // Copyright 2018 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include <memory> |
| #include <utility> |
| #include <vector> |
| |
| #include "base/bind.h" |
| #include "base/test/task_environment.h" |
| #include "base/threading/sequenced_task_runner_handle.h" |
| #include "media/learning/common/learning_task_controller.h" |
| #include "media/learning/impl/learning_session_impl.h" |
| #include "testing/gtest/include/gtest/gtest.h" |
| |
| namespace media { |
| namespace learning { |
| |
| class LearningSessionImplTest : public testing::Test { |
| public: |
| class FakeLearningTaskController; |
| using ControllerVector = std::vector<FakeLearningTaskController*>; |
| using TaskRunnerVector = std::vector<base::SequencedTaskRunner*>; |
| |
| class FakeLearningTaskController : public LearningTaskController { |
| public: |
| // Send ControllerVector* as void*, else it complains that args can't be |
| // forwarded. Adding base::Unretained() doesn't help. |
| FakeLearningTaskController(void* controllers, |
| const LearningTask& task, |
| SequenceBoundFeatureProvider feature_provider) |
| : feature_provider_(std::move(feature_provider)) { |
| static_cast<ControllerVector*>(controllers)->push_back(this); |
| // As a complete hack, call the only public method on fp so that |
| // we can verify that it was given to us by the session. |
| if (!feature_provider_.is_null()) { |
| feature_provider_.AsyncCall(&FeatureProvider::AddFeatures) |
| .WithArgs(FeatureVector(), FeatureProvider::FeatureVectorCB()); |
| } |
| } |
| |
| void BeginObservation( |
| base::UnguessableToken id, |
| const FeatureVector& features, |
| const absl::optional<TargetValue>& default_target, |
| const absl::optional<ukm::SourceId>& source_id) override { |
| id_ = id; |
| observation_features_ = features; |
| default_target_ = default_target; |
| source_id_ = source_id; |
| } |
| |
| void CompleteObservation(base::UnguessableToken id, |
| const ObservationCompletion& completion) override { |
| EXPECT_EQ(id_, id); |
| example_.features = std::move(observation_features_); |
| example_.target_value = completion.target_value; |
| example_.weight = completion.weight; |
| } |
| |
| void CancelObservation(base::UnguessableToken id) override { |
| cancelled_id_ = id; |
| } |
| |
| void UpdateDefaultTarget( |
| base::UnguessableToken id, |
| const absl::optional<TargetValue>& default_target) override { |
| // Should not be called, since LearningTaskControllerImpl doesn't support |
| // default values. |
| updated_id_ = id; |
| } |
| |
| const LearningTask& GetLearningTask() override { |
| NOTREACHED(); |
| return LearningTask::Empty(); |
| } |
| |
| void PredictDistribution(const FeatureVector& features, |
| PredictionCB callback) override { |
| predict_features_ = features; |
| predict_cb_ = std::move(callback); |
| } |
| |
| SequenceBoundFeatureProvider feature_provider_; |
| base::UnguessableToken id_; |
| FeatureVector observation_features_; |
| FeatureVector predict_features_; |
| PredictionCB predict_cb_; |
| absl::optional<TargetValue> default_target_; |
| absl::optional<ukm::SourceId> source_id_; |
| LabelledExample example_; |
| |
| // Most recently cancelled id. |
| base::UnguessableToken cancelled_id_; |
| |
| // Id of most recently changed default target value. |
| absl::optional<base::UnguessableToken> updated_id_; |
| }; |
| |
| class FakeFeatureProvider : public FeatureProvider { |
| public: |
| FakeFeatureProvider(bool* flag_ptr) : flag_ptr_(flag_ptr) {} |
| |
| // Do nothing, except note that we were called. |
| void AddFeatures(FeatureVector features, |
| FeatureProvider::FeatureVectorCB cb) override { |
| *flag_ptr_ = true; |
| } |
| |
| bool* flag_ptr_ = nullptr; |
| }; |
| |
| LearningSessionImplTest() { |
| task_runner_ = base::SequencedTaskRunnerHandle::Get(); |
| session_ = std::make_unique<LearningSessionImpl>(task_runner_); |
| session_->SetTaskControllerFactoryCBForTesting(base::BindRepeating( |
| [](ControllerVector* controllers, TaskRunnerVector* task_runners, |
| scoped_refptr<base::SequencedTaskRunner> task_runner, |
| const LearningTask& task, |
| SequenceBoundFeatureProvider feature_provider) |
| -> base::SequenceBound<LearningTaskController> { |
| task_runners->push_back(task_runner.get()); |
| return base::SequenceBound<FakeLearningTaskController>( |
| task_runner, static_cast<void*>(controllers), task, |
| std::move(feature_provider)); |
| }, |
| &task_controllers_, &task_runners_)); |
| |
| task_0_.name = "task_0"; |
| task_1_.name = "task_1"; |
| } |
| |
| ~LearningSessionImplTest() override { |
| // To prevent a memory leak, reset the session. This will post destruction |
| // of other objects, so RunUntilIdle(). |
| session_.reset(); |
| task_environment_.RunUntilIdle(); |
| } |
| |
| base::test::TaskEnvironment task_environment_; |
| |
| scoped_refptr<base::SequencedTaskRunner> task_runner_; |
| |
| std::unique_ptr<LearningSessionImpl> session_; |
| |
| LearningTask task_0_; |
| LearningTask task_1_; |
| |
| ControllerVector task_controllers_; |
| TaskRunnerVector task_runners_; |
| }; |
| |
| TEST_F(LearningSessionImplTest, RegisteringTasksCreatesControllers) { |
| EXPECT_EQ(task_controllers_.size(), 0u); |
| EXPECT_EQ(task_runners_.size(), 0u); |
| |
| session_->RegisterTask(task_0_); |
| task_environment_.RunUntilIdle(); |
| EXPECT_EQ(task_controllers_.size(), 1u); |
| EXPECT_EQ(task_runners_.size(), 1u); |
| EXPECT_EQ(task_runners_[0], task_runner_.get()); |
| |
| session_->RegisterTask(task_1_); |
| task_environment_.RunUntilIdle(); |
| EXPECT_EQ(task_controllers_.size(), 2u); |
| EXPECT_EQ(task_runners_.size(), 2u); |
| EXPECT_EQ(task_runners_[1], task_runner_.get()); |
| |
| // Make sure controllers are being returned for the right tasks. |
| // Note: this test passes because LearningSessionController::GetController() |
| // returns a wrapper around a FakeLTC, instead of the FakeLTC itself. The |
| // wrapper internally built by LearningSessionImpl has a proper implementation |
| // of GetLearningTask(), whereas the FakeLTC does not. |
| std::unique_ptr<LearningTaskController> ltc_0 = |
| session_->GetController(task_0_.name); |
| EXPECT_EQ(ltc_0->GetLearningTask().name, task_0_.name); |
| |
| std::unique_ptr<LearningTaskController> ltc_1 = |
| session_->GetController(task_1_.name); |
| EXPECT_EQ(ltc_1->GetLearningTask().name, task_1_.name); |
| } |
| |
| TEST_F(LearningSessionImplTest, ExamplesAreForwardedToCorrectTask) { |
| session_->RegisterTask(task_0_); |
| session_->RegisterTask(task_1_); |
| |
| base::UnguessableToken id = base::UnguessableToken::Create(); |
| |
| LabelledExample example_0({FeatureValue(123), FeatureValue(456)}, |
| TargetValue(1234)); |
| std::unique_ptr<LearningTaskController> ltc_0 = |
| session_->GetController(task_0_.name); |
| ukm::SourceId source_id(123); |
| ltc_0->BeginObservation(id, example_0.features, absl::nullopt, source_id); |
| ltc_0->CompleteObservation( |
| id, ObservationCompletion(example_0.target_value, example_0.weight)); |
| |
| LabelledExample example_1({FeatureValue(321), FeatureValue(654)}, |
| TargetValue(4321)); |
| |
| std::unique_ptr<LearningTaskController> ltc_1 = |
| session_->GetController(task_1_.name); |
| ltc_1->BeginObservation(id, example_1.features); |
| ltc_1->CompleteObservation( |
| id, ObservationCompletion(example_1.target_value, example_1.weight)); |
| |
| task_environment_.RunUntilIdle(); |
| EXPECT_EQ(task_controllers_[0]->example_, example_0); |
| EXPECT_EQ(task_controllers_[0]->source_id_, source_id); |
| EXPECT_EQ(task_controllers_[1]->example_, example_1); |
| } |
| |
| TEST_F(LearningSessionImplTest, ControllerLifetimeScopedToSession) { |
| session_->RegisterTask(task_0_); |
| |
| std::unique_ptr<LearningTaskController> controller = |
| session_->GetController(task_0_.name); |
| |
| // Destroy the session. |controller| should still be usable, though it won't |
| // forward requests anymore. |
| session_.reset(); |
| task_environment_.RunUntilIdle(); |
| |
| // Should not crash. |
| controller->BeginObservation(base::UnguessableToken::Create(), |
| FeatureVector()); |
| } |
| |
| TEST_F(LearningSessionImplTest, FeatureProviderIsForwarded) { |
| // Verify that a FeatureProvider actually gets forwarded to the LTC. |
| bool flag = false; |
| session_->RegisterTask( |
| task_0_, base::SequenceBound<FakeFeatureProvider>(task_runner_, &flag)); |
| task_environment_.RunUntilIdle(); |
| // Registering the task should create a FakeLearningTaskController, which will |
| // call AddFeatures on the fake FeatureProvider. |
| EXPECT_TRUE(flag); |
| } |
| |
| TEST_F(LearningSessionImplTest, DestroyingControllerCancelsObservations) { |
| session_->RegisterTask(task_0_); |
| |
| std::unique_ptr<LearningTaskController> controller = |
| session_->GetController(task_0_.name); |
| task_environment_.RunUntilIdle(); |
| |
| // Start an observation and verify that it starts. |
| base::UnguessableToken id = base::UnguessableToken::Create(); |
| controller->BeginObservation(id, FeatureVector()); |
| task_environment_.RunUntilIdle(); |
| EXPECT_EQ(task_controllers_[0]->id_, id); |
| EXPECT_NE(task_controllers_[0]->cancelled_id_, id); |
| |
| // Should result in cancelling the observation. |
| controller.reset(); |
| task_environment_.RunUntilIdle(); |
| EXPECT_EQ(task_controllers_[0]->cancelled_id_, id); |
| } |
| |
| TEST_F(LearningSessionImplTest, |
| DestroyingControllerCompletesObservationsWithDefaultValues) { |
| // Also verifies that we don't send the default to the underlying controller, |
| // because LearningTaskControllerImpl doesn't support it. |
| session_->RegisterTask(task_0_); |
| |
| std::unique_ptr<LearningTaskController> controller = |
| session_->GetController(task_0_.name); |
| task_environment_.RunUntilIdle(); |
| |
| // Start an observation and verify that it doesn't forward the default target. |
| base::UnguessableToken id = base::UnguessableToken::Create(); |
| TargetValue default_target(123); |
| controller->BeginObservation(id, FeatureVector(), default_target); |
| task_environment_.RunUntilIdle(); |
| EXPECT_EQ(task_controllers_[0]->id_, id); |
| EXPECT_FALSE(task_controllers_[0]->default_target_); |
| |
| // Should complete the observation. |
| controller.reset(); |
| task_environment_.RunUntilIdle(); |
| EXPECT_EQ(task_controllers_[0]->example_.target_value, default_target); |
| } |
| |
| TEST_F(LearningSessionImplTest, ChangeDefaultTargetToValue) { |
| session_->RegisterTask(task_0_); |
| |
| std::unique_ptr<LearningTaskController> controller = |
| session_->GetController(task_0_.name); |
| task_environment_.RunUntilIdle(); |
| |
| // Start an observation without a default, then add one. |
| base::UnguessableToken id = base::UnguessableToken::Create(); |
| controller->BeginObservation(id, FeatureVector(), absl::nullopt); |
| TargetValue default_target(123); |
| controller->UpdateDefaultTarget(id, default_target); |
| task_environment_.RunUntilIdle(); |
| EXPECT_EQ(task_controllers_[0]->id_, id); |
| |
| // Should complete the observation. |
| controller.reset(); |
| task_environment_.RunUntilIdle(); |
| EXPECT_EQ(task_controllers_[0]->example_.target_value, default_target); |
| |
| // Shouldn't notify the underlying controller. |
| EXPECT_FALSE(task_controllers_[0]->updated_id_); |
| } |
| |
| TEST_F(LearningSessionImplTest, ChangeDefaultTargetToNoValue) { |
| session_->RegisterTask(task_0_); |
| |
| std::unique_ptr<LearningTaskController> controller = |
| session_->GetController(task_0_.name); |
| task_environment_.RunUntilIdle(); |
| |
| // Start an observation with a default, then remove it. |
| base::UnguessableToken id = base::UnguessableToken::Create(); |
| TargetValue default_target(123); |
| controller->BeginObservation(id, FeatureVector(), default_target); |
| controller->UpdateDefaultTarget(id, absl::nullopt); |
| task_environment_.RunUntilIdle(); |
| EXPECT_EQ(task_controllers_[0]->id_, id); |
| |
| // Should cancel the observation. |
| controller.reset(); |
| task_environment_.RunUntilIdle(); |
| EXPECT_EQ(task_controllers_[0]->cancelled_id_, id); |
| |
| // Shouldn't notify the underlying controller. |
| EXPECT_FALSE(task_controllers_[0]->updated_id_); |
| } |
| |
| TEST_F(LearningSessionImplTest, PredictDistribution) { |
| session_->RegisterTask(task_0_); |
| |
| std::unique_ptr<LearningTaskController> controller = |
| session_->GetController(task_0_.name); |
| task_environment_.RunUntilIdle(); |
| |
| FeatureVector features = {FeatureValue(123), FeatureValue(456)}; |
| TargetHistogram observed_prediction; |
| controller->PredictDistribution( |
| features, base::BindOnce( |
| [](TargetHistogram* test_storage, |
| const absl::optional<TargetHistogram>& predicted) { |
| *test_storage = *predicted; |
| }, |
| &observed_prediction)); |
| task_environment_.RunUntilIdle(); |
| EXPECT_EQ(features, task_controllers_[0]->predict_features_); |
| EXPECT_FALSE(task_controllers_[0]->predict_cb_.is_null()); |
| |
| TargetHistogram expected_prediction; |
| expected_prediction[TargetValue(1)] = 1.0; |
| expected_prediction[TargetValue(2)] = 2.0; |
| expected_prediction[TargetValue(3)] = 3.0; |
| std::move(task_controllers_[0]->predict_cb_).Run(expected_prediction); |
| task_environment_.RunUntilIdle(); |
| EXPECT_EQ(expected_prediction, observed_prediction); |
| } |
| |
| } // namespace learning |
| } // namespace media |