blob: bb3b17155b4272b975ae7d39e3bc302702bcdd15 [file] [log] [blame]
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <memory>
#include <utility>
#include <vector>
#include "base/bind.h"
#include "base/test/task_environment.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "media/learning/common/learning_task_controller.h"
#include "media/learning/impl/learning_session_impl.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace media {
namespace learning {
class LearningSessionImplTest : public testing::Test {
public:
class FakeLearningTaskController;
using ControllerVector = std::vector<FakeLearningTaskController*>;
using TaskRunnerVector = std::vector<base::SequencedTaskRunner*>;
class FakeLearningTaskController : public LearningTaskController {
public:
// Send ControllerVector* as void*, else it complains that args can't be
// forwarded. Adding base::Unretained() doesn't help.
FakeLearningTaskController(void* controllers,
const LearningTask& task,
SequenceBoundFeatureProvider feature_provider)
: feature_provider_(std::move(feature_provider)) {
static_cast<ControllerVector*>(controllers)->push_back(this);
// As a complete hack, call the only public method on fp so that
// we can verify that it was given to us by the session.
if (!feature_provider_.is_null()) {
feature_provider_.AsyncCall(&FeatureProvider::AddFeatures)
.WithArgs(FeatureVector(), FeatureProvider::FeatureVectorCB());
}
}
void BeginObservation(
base::UnguessableToken id,
const FeatureVector& features,
const absl::optional<TargetValue>& default_target,
const absl::optional<ukm::SourceId>& source_id) override {
id_ = id;
observation_features_ = features;
default_target_ = default_target;
source_id_ = source_id;
}
void CompleteObservation(base::UnguessableToken id,
const ObservationCompletion& completion) override {
EXPECT_EQ(id_, id);
example_.features = std::move(observation_features_);
example_.target_value = completion.target_value;
example_.weight = completion.weight;
}
void CancelObservation(base::UnguessableToken id) override {
cancelled_id_ = id;
}
void UpdateDefaultTarget(
base::UnguessableToken id,
const absl::optional<TargetValue>& default_target) override {
// Should not be called, since LearningTaskControllerImpl doesn't support
// default values.
updated_id_ = id;
}
const LearningTask& GetLearningTask() override {
NOTREACHED();
return LearningTask::Empty();
}
void PredictDistribution(const FeatureVector& features,
PredictionCB callback) override {
predict_features_ = features;
predict_cb_ = std::move(callback);
}
SequenceBoundFeatureProvider feature_provider_;
base::UnguessableToken id_;
FeatureVector observation_features_;
FeatureVector predict_features_;
PredictionCB predict_cb_;
absl::optional<TargetValue> default_target_;
absl::optional<ukm::SourceId> source_id_;
LabelledExample example_;
// Most recently cancelled id.
base::UnguessableToken cancelled_id_;
// Id of most recently changed default target value.
absl::optional<base::UnguessableToken> updated_id_;
};
class FakeFeatureProvider : public FeatureProvider {
public:
FakeFeatureProvider(bool* flag_ptr) : flag_ptr_(flag_ptr) {}
// Do nothing, except note that we were called.
void AddFeatures(FeatureVector features,
FeatureProvider::FeatureVectorCB cb) override {
*flag_ptr_ = true;
}
bool* flag_ptr_ = nullptr;
};
LearningSessionImplTest() {
task_runner_ = base::SequencedTaskRunnerHandle::Get();
session_ = std::make_unique<LearningSessionImpl>(task_runner_);
session_->SetTaskControllerFactoryCBForTesting(base::BindRepeating(
[](ControllerVector* controllers, TaskRunnerVector* task_runners,
scoped_refptr<base::SequencedTaskRunner> task_runner,
const LearningTask& task,
SequenceBoundFeatureProvider feature_provider)
-> base::SequenceBound<LearningTaskController> {
task_runners->push_back(task_runner.get());
return base::SequenceBound<FakeLearningTaskController>(
task_runner, static_cast<void*>(controllers), task,
std::move(feature_provider));
},
&task_controllers_, &task_runners_));
task_0_.name = "task_0";
task_1_.name = "task_1";
}
~LearningSessionImplTest() override {
// To prevent a memory leak, reset the session. This will post destruction
// of other objects, so RunUntilIdle().
session_.reset();
task_environment_.RunUntilIdle();
}
base::test::TaskEnvironment task_environment_;
scoped_refptr<base::SequencedTaskRunner> task_runner_;
std::unique_ptr<LearningSessionImpl> session_;
LearningTask task_0_;
LearningTask task_1_;
ControllerVector task_controllers_;
TaskRunnerVector task_runners_;
};
TEST_F(LearningSessionImplTest, RegisteringTasksCreatesControllers) {
  EXPECT_EQ(task_controllers_.size(), 0u);
  EXPECT_EQ(task_runners_.size(), 0u);

  // Each registration should create exactly one controller, and the factory
  // should receive |task_runner_| for it.
  session_->RegisterTask(task_0_);
  task_environment_.RunUntilIdle();
  EXPECT_EQ(task_controllers_.size(), 1u);
  EXPECT_EQ(task_runners_.size(), 1u);
  EXPECT_EQ(task_runners_[0], task_runner_.get());

  session_->RegisterTask(task_1_);
  task_environment_.RunUntilIdle();
  EXPECT_EQ(task_controllers_.size(), 2u);
  EXPECT_EQ(task_runners_.size(), 2u);
  EXPECT_EQ(task_runners_[1], task_runner_.get());

  // Make sure controllers are being returned for the right tasks.
  // Note: this test passes because LearningSessionImpl::GetController()
  // returns a wrapper around a FakeLTC, instead of the FakeLTC itself. The
  // wrapper internally built by LearningSessionImpl has a proper implementation
  // of GetLearningTask(), whereas the FakeLTC does not.
  std::unique_ptr<LearningTaskController> ltc_0 =
      session_->GetController(task_0_.name);
  EXPECT_EQ(ltc_0->GetLearningTask().name, task_0_.name);
  std::unique_ptr<LearningTaskController> ltc_1 =
      session_->GetController(task_1_.name);
  EXPECT_EQ(ltc_1->GetLearningTask().name, task_1_.name);
}
TEST_F(LearningSessionImplTest, ExamplesAreForwardedToCorrectTask) {
  // Register two tasks and push one complete observation through each task's
  // controller. Each fake should end up holding only its own example.
  session_->RegisterTask(task_0_);
  session_->RegisterTask(task_1_);

  const base::UnguessableToken token = base::UnguessableToken::Create();

  // Observation for task 0, which also carries a UKM source id.
  LabelledExample first({FeatureValue(123), FeatureValue(456)},
                        TargetValue(1234));
  std::unique_ptr<LearningTaskController> first_ltc =
      session_->GetController(task_0_.name);
  const ukm::SourceId ukm_source(123);
  first_ltc->BeginObservation(token, first.features, absl::nullopt,
                              ukm_source);
  first_ltc->CompleteObservation(
      token, ObservationCompletion(first.target_value, first.weight));

  // Observation for task 1, relying on BeginObservation()'s trailing defaults.
  LabelledExample second({FeatureValue(321), FeatureValue(654)},
                         TargetValue(4321));
  std::unique_ptr<LearningTaskController> second_ltc =
      session_->GetController(task_1_.name);
  second_ltc->BeginObservation(token, second.features);
  second_ltc->CompleteObservation(
      token, ObservationCompletion(second.target_value, second.weight));

  task_environment_.RunUntilIdle();

  FakeLearningTaskController* const fake_0 = task_controllers_[0];
  FakeLearningTaskController* const fake_1 = task_controllers_[1];
  EXPECT_EQ(fake_0->example_, first);
  EXPECT_EQ(fake_0->source_id_, ukm_source);
  EXPECT_EQ(fake_1->example_, second);
}
TEST_F(LearningSessionImplTest, ControllerLifetimeScopedToSession) {
  // A controller handed out by the session has to remain safe to call after
  // the session is destroyed, even though it no longer forwards anything.
  session_->RegisterTask(task_0_);
  std::unique_ptr<LearningTaskController> ltc =
      session_->GetController(task_0_.name);

  session_.reset();
  task_environment_.RunUntilIdle();

  // The session is gone; this call must not crash.
  ltc->BeginObservation(base::UnguessableToken::Create(), FeatureVector());
}
TEST_F(LearningSessionImplTest, FeatureProviderIsForwarded) {
  // The FakeLearningTaskController built for a registered task calls
  // AddFeatures() on whatever FeatureProvider it was given. FakeFeatureProvider
  // records that call in |add_features_called|, so seeing it flip to true
  // shows the provider reached the controller.
  bool add_features_called = false;
  base::SequenceBound<FakeFeatureProvider> provider(task_runner_,
                                                    &add_features_called);
  session_->RegisterTask(task_0_, std::move(provider));

  task_environment_.RunUntilIdle();
  EXPECT_TRUE(add_features_called);
}
TEST_F(LearningSessionImplTest, DestroyingControllerCancelsObservations) {
  session_->RegisterTask(task_0_);
  std::unique_ptr<LearningTaskController> ltc =
      session_->GetController(task_0_.name);
  task_environment_.RunUntilIdle();

  // Begin an observation. The fake should see it begin but not be cancelled.
  const base::UnguessableToken token = base::UnguessableToken::Create();
  ltc->BeginObservation(token, FeatureVector());
  task_environment_.RunUntilIdle();

  FakeLearningTaskController* const fake = task_controllers_[0];
  EXPECT_EQ(fake->id_, token);
  EXPECT_NE(fake->cancelled_id_, token);

  // Dropping the controller while the observation is outstanding should
  // cancel it.
  ltc.reset();
  task_environment_.RunUntilIdle();
  EXPECT_EQ(fake->cancelled_id_, token);
}
TEST_F(LearningSessionImplTest,
       DestroyingControllerCompletesObservationsWithDefaultValues) {
  // The default target must not be forwarded to the underlying controller,
  // because LearningTaskControllerImpl doesn't support it. When the controller
  // handed out by the session is destroyed, the observation should still be
  // completed using that default.
  session_->RegisterTask(task_0_);
  std::unique_ptr<LearningTaskController> ltc =
      session_->GetController(task_0_.name);
  task_environment_.RunUntilIdle();

  const base::UnguessableToken token = base::UnguessableToken::Create();
  const TargetValue fallback(123);
  ltc->BeginObservation(token, FeatureVector(), fallback);
  task_environment_.RunUntilIdle();

  FakeLearningTaskController* const fake = task_controllers_[0];
  EXPECT_EQ(fake->id_, token);
  // The fake must not have been given the default.
  EXPECT_FALSE(fake->default_target_);

  // Destroying the controller completes the observation with the default.
  ltc.reset();
  task_environment_.RunUntilIdle();
  EXPECT_EQ(fake->example_.target_value, fallback);
}
TEST_F(LearningSessionImplTest, ChangeDefaultTargetToValue) {
  session_->RegisterTask(task_0_);
  std::unique_ptr<LearningTaskController> ltc =
      session_->GetController(task_0_.name);
  task_environment_.RunUntilIdle();

  // Begin with no default target, then supply one afterwards.
  const base::UnguessableToken token = base::UnguessableToken::Create();
  ltc->BeginObservation(token, FeatureVector(), absl::nullopt);
  const TargetValue late_default(123);
  ltc->UpdateDefaultTarget(token, late_default);
  task_environment_.RunUntilIdle();

  FakeLearningTaskController* const fake = task_controllers_[0];
  EXPECT_EQ(fake->id_, token);

  // Destroying the controller should complete the observation with the
  // updated default...
  ltc.reset();
  task_environment_.RunUntilIdle();
  EXPECT_EQ(fake->example_.target_value, late_default);
  // ...without ever calling UpdateDefaultTarget() on the fake.
  EXPECT_FALSE(fake->updated_id_);
}
TEST_F(LearningSessionImplTest, ChangeDefaultTargetToNoValue) {
  session_->RegisterTask(task_0_);
  std::unique_ptr<LearningTaskController> ltc =
      session_->GetController(task_0_.name);
  task_environment_.RunUntilIdle();

  // Begin with a default target, then clear it.
  const base::UnguessableToken token = base::UnguessableToken::Create();
  const TargetValue initial_default(123);
  ltc->BeginObservation(token, FeatureVector(), initial_default);
  ltc->UpdateDefaultTarget(token, absl::nullopt);
  task_environment_.RunUntilIdle();

  FakeLearningTaskController* const fake = task_controllers_[0];
  EXPECT_EQ(fake->id_, token);

  // With no default left, destroying the controller should cancel the
  // observation...
  ltc.reset();
  task_environment_.RunUntilIdle();
  EXPECT_EQ(fake->cancelled_id_, token);
  // ...without ever calling UpdateDefaultTarget() on the fake.
  EXPECT_FALSE(fake->updated_id_);
}
TEST_F(LearningSessionImplTest, PredictDistribution) {
  session_->RegisterTask(task_0_);
  std::unique_ptr<LearningTaskController> controller =
      session_->GetController(task_0_.name);
  task_environment_.RunUntilIdle();

  // Ask for a prediction. The fake captures both the features and the
  // callback; the result is stored in |observed_prediction|.
  FeatureVector features = {FeatureValue(123), FeatureValue(456)};
  TargetHistogram observed_prediction;
  controller->PredictDistribution(
      features, base::BindOnce(
                    [](TargetHistogram* test_storage,
                       const absl::optional<TargetHistogram>& predicted) {
                      // Fail the test cleanly instead of dereferencing an
                      // empty optional if no prediction is delivered.
                      ASSERT_TRUE(predicted.has_value());
                      *test_storage = *predicted;
                    },
                    &observed_prediction));
  task_environment_.RunUntilIdle();
  EXPECT_EQ(features, task_controllers_[0]->predict_features_);
  EXPECT_FALSE(task_controllers_[0]->predict_cb_.is_null());

  // Answer the request from the fake. The histogram should reach the caller's
  // callback unchanged.
  TargetHistogram expected_prediction;
  expected_prediction[TargetValue(1)] = 1.0;
  expected_prediction[TargetValue(2)] = 2.0;
  expected_prediction[TargetValue(3)] = 3.0;
  std::move(task_controllers_[0]->predict_cb_).Run(expected_prediction);
  task_environment_.RunUntilIdle();
  EXPECT_EQ(expected_prediction, observed_prediction);
}
} // namespace learning
} // namespace media