blob: 4d2068fbf02fc00b0c6f4dde322f7b23c0df6c18 [file] [log] [blame]
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/common/target_histogram.h"
#include <sstream>
namespace media {
namespace learning {
TargetHistogram::TargetHistogram() = default;
TargetHistogram::TargetHistogram(const TargetHistogram& rhs) = default;
TargetHistogram::TargetHistogram(TargetHistogram&& rhs) = default;
TargetHistogram::~TargetHistogram() = default;
TargetHistogram& TargetHistogram::operator=(const TargetHistogram& rhs) =
default;
TargetHistogram& TargetHistogram::operator=(TargetHistogram&& rhs) = default;
bool TargetHistogram::operator==(const TargetHistogram& rhs) const {
return rhs.total_counts() == total_counts() && rhs.counts_ == counts_;
}
TargetHistogram& TargetHistogram::operator+=(const TargetHistogram& rhs) {
for (auto& rhs_pair : rhs.counts())
counts_[rhs_pair.first] += rhs_pair.second;
return *this;
}
TargetHistogram& TargetHistogram::operator+=(const TargetValue& rhs) {
counts_[rhs]++;
return *this;
}
TargetHistogram& TargetHistogram::operator+=(const LabelledExample& example) {
counts_[example.target_value] += example.weight;
return *this;
}
double TargetHistogram::operator[](const TargetValue& value) const {
auto iter = counts_.find(value);
if (iter == counts_.end())
return 0;
return iter->second;
}
double& TargetHistogram::operator[](const TargetValue& value) {
return counts_[value];
}
bool TargetHistogram::FindSingularMax(TargetValue* value_out,
double* counts_out) const {
if (!counts_.size())
return false;
double unused_counts;
if (!counts_out)
counts_out = &unused_counts;
auto iter = counts_.begin();
*value_out = iter->first;
*counts_out = iter->second;
bool singular_max = true;
for (iter++; iter != counts_.end(); iter++) {
if (iter->second > *counts_out) {
*value_out = iter->first;
*counts_out = iter->second;
singular_max = true;
} else if (iter->second == *counts_out) {
// If this turns out to be the max, then it's not singular.
singular_max = false;
}
}
return singular_max;
}
double TargetHistogram::Average() const {
double total_value = 0.;
double total_counts = 0;
for (auto& iter : counts_) {
total_value += iter.first.value() * iter.second;
total_counts += iter.second;
}
if (!total_counts)
return 0.;
return total_value / total_counts;
}
void TargetHistogram::Normalize() {
double total = total_counts();
for (auto& iter : counts_)
iter.second /= total;
}
std::string TargetHistogram::ToString() const {
std::ostringstream ss;
ss << "[";
for (auto& entry : counts_)
ss << " " << entry.first << ":" << entry.second;
ss << " ]";
return ss.str();
}
std::ostream& operator<<(std::ostream& out,
const media::learning::TargetHistogram& dist) {
return out << dist.ToString();
}
} // namespace learning
} // namespace media