// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "media/cdm/win/media_foundation_cdm.h"

#include <wchar.h>

#include "base/bind.h"
#include "base/strings/utf_string_conversions.h"
#include "base/test/mock_callback.h"
#include "base/test/task_environment.h"
#include "media/base/mock_filters.h"
#include "media/base/test_helpers.h"
#include "media/base/win/media_foundation_cdm_proxy.h"
#include "media/base/win/mf_helpers.h"
#include "media/base/win/mf_mocks.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

using ::testing::_;
using ::testing::DoAll;
using ::testing::InSequence;
using ::testing::IsEmpty;
using ::testing::NotNull;
using ::testing::Return;
using ::testing::SetArgPointee;
using ::testing::SetArgReferee;
using ::testing::StrEq;
using ::testing::StrictMock;
using ::testing::WithoutArgs;

namespace media {

namespace {

// Session ID the mocked MF CDM session is set up to return in tests that use
// a single session.
const char kSessionId[] = "session_id";
// Raw expiration value handed back by the mocked GetExpiration() call.
const double kExpirationMs = 123456789.0;
// `kExpirationMs` converted through base::Time::FromJsTime(); this is the value
// tests expect to see in OnSessionExpirationUpdate().
const auto kExpirationTime = base::Time::FromJsTime(kExpirationMs);

// Returns the characters of `str` as a byte vector, one uint8_t per char.
std::vector<uint8_t> StringToVector(const std::string& str) {
  std::vector<uint8_t> bytes;
  bytes.assign(str.begin(), str.end());
  return bytes;
}

// gmock's testing::InvokeArgument<N> does not work with base::OnceCallback, so
// this action template is used to run such a callback instead. `k` selects the
// k-th mock-call argument, `T` is that callback's type, and `p0` is the value
// passed to Run().
ACTION_TEMPLATE(InvokeCallbackArgument,
                HAS_2_TEMPLATE_PARAMS(int, k, typename, T),
                AND_1_VALUE_PARAMS(p0)) {
  // The const_cast drops constness from the stored argument so that it can be
  // moved-from, which is how the callback gets run below.
  T& callback = const_cast<T&>(std::get<k>(args));
  std::move(callback).Run(p0);
}

}  // namespace

using Microsoft::WRL::ComPtr;

class MediaFoundationCdmTest : public testing::Test {
 public:
  MediaFoundationCdmTest()
      : mf_cdm_(MakeComPtr<MockMFCdm>()),
        mf_cdm_session_(MakeComPtr<MockMFCdmSession>()),
        cdm_(base::MakeRefCounted<MediaFoundationCdm>(
            base::BindRepeating(&MediaFoundationCdmTest::CreateMFCdm,
                                base::Unretained(this)),
            is_type_supported_cb_.Get(),
            store_client_token_cb_.Get(),
            base::BindRepeating(&MockCdmClient::OnSessionMessage,
                                base::Unretained(&cdm_client_)),
            base::BindRepeating(&MockCdmClient::OnSessionClosed,
                                base::Unretained(&cdm_client_)),
            base::BindRepeating(&MockCdmClient::OnSessionKeysChange,
                                base::Unretained(&cdm_client_)),
            base::BindRepeating(&MockCdmClient::OnSessionExpirationUpdate,
                                base::Unretained(&cdm_client_)))) {}

  ~MediaFoundationCdmTest() override = default;

  void CreateMFCdm(HRESULT& hresult,
                   Microsoft::WRL::ComPtr<IMFContentDecryptionModule>& mf_cdm) {
    if (can_initialize_) {
      hresult = S_OK;
      mf_cdm = mf_cdm_;
    } else {
      hresult = E_FAIL;
      mf_cdm.Reset();
    }
  }

  void Initialize() { ASSERT_SUCCESS(cdm_->Initialize()); }

  void InitializeAndExpectFailure() {
    can_initialize_ = false;
    ASSERT_FAILED(cdm_->Initialize());
  }

  void SetGenerateRequestExpectations(
      ComPtr<MockMFCdmSession> mf_cdm_session,
      const char* session_id,
      IMFContentDecryptionModuleSessionCallbacks** mf_cdm_session_callbacks,
      bool expect_message = true) {
    std::vector<uint8_t> license_request = StringToVector("request");

    // Session ID to return. Will be released by |mf_cdm_session_|.
    std::wstring wide_session_id = base::UTF8ToWide(session_id);
    LPWSTR mf_session_id = nullptr;
    ASSERT_SUCCESS(
        CopyCoTaskMemWideString(wide_session_id.data(), &mf_session_id));

    COM_EXPECT_CALL(mf_cdm_session,
                    GenerateRequest(StrEq(L"webm"), NotNull(), _))
        .WillOnce(WithoutArgs([=] {  // Capture local variables by value.
          (*mf_cdm_session_callbacks)
              ->KeyMessage(MF_MEDIAKEYSESSION_MESSAGETYPE_LICENSE_REQUEST,
                           license_request.data(), license_request.size(),
                           nullptr);
          return S_OK;
        }));

    COM_EXPECT_CALL(mf_cdm_session, GetSessionId(_))
        .WillOnce(DoAll(SetArgPointee<0>(mf_session_id), Return(S_OK)));

    if (expect_message) {
      EXPECT_CALL(cdm_client_,
                  OnSessionMessage(session_id, CdmMessageType::LICENSE_REQUEST,
                                   license_request));
    }
  }

  void CreateSessionAndGenerateRequest() {
    std::vector<uint8_t> init_data = StringToVector("init_data");

    COM_EXPECT_CALL(mf_cdm_,
                    CreateSession(MF_MEDIAKEYSESSION_TYPE_TEMPORARY, _, _))
        .WillOnce(DoAll(SaveComPtr<1>(&mf_cdm_session_callbacks_),
                        SetComPointee<2>(mf_cdm_session_.Get()), Return(S_OK)));

    SetGenerateRequestExpectations(mf_cdm_session_, kSessionId,
                                   &mf_cdm_session_callbacks_);

    cdm_->CreateSessionAndGenerateRequest(
        CdmSessionType::kTemporary, EmeInitDataType::WEBM, init_data,
        std::make_unique<MockCdmSessionPromise>(/*expect_success=*/true,
                                                &session_id_));

    task_environment_.RunUntilIdle();
    EXPECT_EQ(session_id_, kSessionId);
  }

  void OnCdmProxyReceived(scoped_refptr<MediaFoundationCdmProxy> mf_cdm_proxy) {
    mf_cdm_proxy_ = std::move(mf_cdm_proxy);
  }

 protected:
  base::test::TaskEnvironment task_environment_;

  StrictMock<MockCdmClient> cdm_client_;
  StrictMock<base::MockCallback<MediaFoundationCdm::IsTypeSupportedCB>>
      is_type_supported_cb_;
  base::MockCallback<MediaFoundationCdm::StoreClientTokenCB>
      store_client_token_cb_;
  ComPtr<MockMFCdm> mf_cdm_;
  ComPtr<MockMFCdmSession> mf_cdm_session_;
  ComPtr<IMFContentDecryptionModuleSessionCallbacks> mf_cdm_session_callbacks_;
  scoped_refptr<MediaFoundationCdm> cdm_;
  bool can_initialize_ = true;
  std::string session_id_;
  scoped_refptr<MediaFoundationCdmProxy> mf_cdm_proxy_;
};

// The promise succeeds when the MF CDM accepts the server certificate.
TEST_F(MediaFoundationCdmTest, SetServerCertificate) {
  Initialize();

  std::vector<uint8_t> server_cert = StringToVector("certificate");

  // The mocked MF CDM must see exactly this buffer and size.
  COM_EXPECT_CALL(mf_cdm_,
                  SetServerCertificate(server_cert.data(), server_cert.size()))
      .WillOnce(Return(S_OK));

  auto promise = std::make_unique<MockCdmPromise>(/*expect_success=*/true);
  cdm_->SetServerCertificate(server_cert, std::move(promise));
}

// The promise is expected to fail when the MF CDM returns E_FAIL for the
// server certificate.
TEST_F(MediaFoundationCdmTest, SetServerCertificate_Failure) {
  Initialize();

  std::vector<uint8_t> server_cert = StringToVector("certificate");

  COM_EXPECT_CALL(mf_cdm_,
                  SetServerCertificate(server_cert.data(), server_cert.size()))
      .WillOnce(Return(E_FAIL));

  auto promise = std::make_unique<MockCdmPromise>(/*expect_success=*/false);
  cdm_->SetServerCertificate(server_cert, std::move(promise));
}

// With no HDCP requirement the key status is expected to be USABLE. No
// expectation is set on the strict `is_type_supported_cb_`, so any call to it
// would fail the test.
TEST_F(MediaFoundationCdmTest, GetStatusForPolicy_HdcpNone_KeyStatusUsable) {
  Initialize();

  // Start from a value that differs from the expected one: if the promise
  // never writes `key_status`, the check below fails deterministically instead
  // of reading an uninitialized variable.
  CdmKeyInformation::KeyStatus key_status =
      CdmKeyInformation::KeyStatus::OUTPUT_RESTRICTED;
  cdm_->GetStatusForPolicy(HdcpVersion::kHdcpVersionNone,
                           std::make_unique<MockCdmKeyStatusPromise>(
                               /*expect_success=*/true, &key_status));
  EXPECT_EQ(CdmKeyInformation::KeyStatus::USABLE, key_status);
}

// HDCP 1.1 is expected to be checked through the is-type-supported callback
// with an "hdcp=1" feature string; a "supported" answer yields USABLE.
TEST_F(MediaFoundationCdmTest, GetStatusForPolicy_HdcpV1_1_KeyStatusUsable) {
  Initialize();
  EXPECT_CALL(is_type_supported_cb_,
              Run("video/mp4;codecs=\"avc1\";features=\"hdcp=1\"", _))
      .WillOnce(
          InvokeCallbackArgument<1,
                                 MediaFoundationCdm::IsTypeSupportedResultCB>(
              /*is_supported=*/true));

  // Start from a value that differs from the expected one: if the promise
  // never writes `key_status`, the check below fails deterministically instead
  // of reading an uninitialized variable.
  CdmKeyInformation::KeyStatus key_status =
      CdmKeyInformation::KeyStatus::OUTPUT_RESTRICTED;
  cdm_->GetStatusForPolicy(HdcpVersion::kHdcpVersion1_1,
                           std::make_unique<MockCdmKeyStatusPromise>(
                               /*expect_success=*/true, &key_status));
  EXPECT_EQ(CdmKeyInformation::KeyStatus::USABLE, key_status);
}

// HDCP 2.3 is expected to be checked through the is-type-supported callback
// with an "hdcp=2" feature string; a "not supported" answer yields
// OUTPUT_RESTRICTED.
TEST_F(MediaFoundationCdmTest,
       GetStatusForPolicy_HdcpV2_3_KeyStatusOutputRestricted) {
  Initialize();
  EXPECT_CALL(is_type_supported_cb_,
              Run("video/mp4;codecs=\"avc1\";features=\"hdcp=2\"", _))
      .WillOnce(
          InvokeCallbackArgument<1,
                                 MediaFoundationCdm::IsTypeSupportedResultCB>(
              /*is_supported=*/false));

  // Start from a value that differs from the expected one: if the promise
  // never writes `key_status`, the check below fails deterministically instead
  // of reading an uninitialized variable.
  CdmKeyInformation::KeyStatus key_status =
      CdmKeyInformation::KeyStatus::USABLE;
  cdm_->GetStatusForPolicy(HdcpVersion::kHdcpVersion2_3,
                           std::make_unique<MockCdmKeyStatusPromise>(
                               /*expect_success=*/true, &key_status));
  EXPECT_EQ(CdmKeyInformation::KeyStatus::OUTPUT_RESTRICTED, key_status);
}

// Happy path: after a successful Initialize(), the fixture helper creates a
// session, generates a request and checks that the promise resolves with
// kSessionId.
TEST_F(MediaFoundationCdmTest, CreateSessionAndGenerateRequest) {
  Initialize();
  CreateSessionAndGenerateRequest();
}

// Two sessions are requested back to back, before the task environment runs.
// Each one must resolve with its own session ID.
TEST_F(MediaFoundationCdmTest, CreateSessionAndGenerateRequest_Parallel) {
  Initialize();

  const char kFirstSessionId[] = "session_id_1";
  const char kSecondSessionId[] = "session_id_2";
  std::vector<uint8_t> init_data_bytes = StringToVector("init_data");

  // Separate mock sessions and callback slots for the two requests.
  auto first_session = MakeComPtr<MockMFCdmSession>();
  auto second_session = MakeComPtr<MockMFCdmSession>();
  ComPtr<IMFContentDecryptionModuleSessionCallbacks> first_callbacks;
  ComPtr<IMFContentDecryptionModuleSessionCallbacks> second_callbacks;

  // The first CreateSession() call yields the first session, the second call
  // yields the second one.
  COM_EXPECT_CALL(mf_cdm_,
                  CreateSession(MF_MEDIAKEYSESSION_TYPE_TEMPORARY, _, _))
      .WillOnce(DoAll(SaveComPtr<1>(&first_callbacks),
                      SetComPointee<2>(first_session.Get()), Return(S_OK)))
      .WillOnce(DoAll(SaveComPtr<1>(&second_callbacks),
                      SetComPointee<2>(second_session.Get()), Return(S_OK)));

  SetGenerateRequestExpectations(first_session, kFirstSessionId,
                                 &first_callbacks);
  SetGenerateRequestExpectations(second_session, kSecondSessionId,
                                 &second_callbacks);

  std::string first_result;
  std::string second_result;

  auto first_promise = std::make_unique<MockCdmSessionPromise>(
      /*expect_success=*/true, &first_result);
  cdm_->CreateSessionAndGenerateRequest(
      CdmSessionType::kTemporary, EmeInitDataType::WEBM, init_data_bytes,
      std::move(first_promise));

  auto second_promise = std::make_unique<MockCdmSessionPromise>(
      /*expect_success=*/true, &second_result);
  cdm_->CreateSessionAndGenerateRequest(
      CdmSessionType::kTemporary, EmeInitDataType::WEBM, init_data_bytes,
      std::move(second_promise));

  task_environment_.RunUntilIdle();
  EXPECT_EQ(first_result, kFirstSessionId);
  EXPECT_EQ(second_result, kSecondSessionId);
}

// After a failed Initialize(), creating a session must fail and leave the
// session ID empty.
TEST_F(MediaFoundationCdmTest, InitializeFailure) {
  InitializeAndExpectFailure();

  std::vector<uint8_t> init_data_bytes = StringToVector("init_data");
  auto session_promise = std::make_unique<MockCdmSessionPromise>(
      /*expect_success=*/false, &session_id_);
  cdm_->CreateSessionAndGenerateRequest(
      CdmSessionType::kTemporary, EmeInitDataType::WEBM, init_data_bytes,
      std::move(session_promise));

  task_environment_.RunUntilIdle();
  EXPECT_TRUE(session_id_.empty());
}

// A failing IMFContentDecryptionModule::CreateSession() must fail the promise
// and leave the session ID empty.
TEST_F(MediaFoundationCdmTest,
       CreateSessionAndGenerateRequest_CreateSessionFailure) {
  Initialize();

  std::vector<uint8_t> init_data_bytes = StringToVector("init_data");

  COM_EXPECT_CALL(mf_cdm_,
                  CreateSession(MF_MEDIAKEYSESSION_TYPE_TEMPORARY, _, _))
      .WillOnce(Return(E_FAIL));

  auto session_promise = std::make_unique<MockCdmSessionPromise>(
      /*expect_success=*/false, &session_id_);
  cdm_->CreateSessionAndGenerateRequest(
      CdmSessionType::kTemporary, EmeInitDataType::WEBM, init_data_bytes,
      std::move(session_promise));

  task_environment_.RunUntilIdle();
  EXPECT_TRUE(session_id_.empty());
}

// CreateSession() succeeds but GenerateRequest() on the new session returns
// E_FAIL; the promise must fail and the session ID must stay empty.
TEST_F(MediaFoundationCdmTest,
       CreateSessionAndGenerateRequest_GenerateRequestFailure) {
  Initialize();

  std::vector<uint8_t> init_data_bytes = StringToVector("init_data");

  COM_EXPECT_CALL(mf_cdm_,
                  CreateSession(MF_MEDIAKEYSESSION_TYPE_TEMPORARY, _, _))
      .WillOnce(DoAll(SaveComPtr<1>(&mf_cdm_session_callbacks_),
                      SetComPointee<2>(mf_cdm_session_.Get()), Return(S_OK)));

  // The request must carry the full init data; the mocked session rejects it.
  COM_EXPECT_CALL(
      mf_cdm_session_,
      GenerateRequest(StrEq(L"webm"), NotNull(), init_data_bytes.size()))
      .WillOnce(Return(E_FAIL));

  auto session_promise = std::make_unique<MockCdmSessionPromise>(
      /*expect_success=*/false, &session_id_);
  cdm_->CreateSessionAndGenerateRequest(
      CdmSessionType::kTemporary, EmeInitDataType::WEBM, init_data_bytes,
      std::move(session_promise));

  task_environment_.RunUntilIdle();
  EXPECT_TRUE(session_id_.empty());
}

// When two sessions report the same session ID, the first creation succeeds
// and the second one fails.
TEST_F(MediaFoundationCdmTest,
       CreateSessionAndGenerateRequest_DuplicateSessionId) {
  Initialize();

  std::vector<uint8_t> init_data_bytes = StringToVector("init_data");

  auto first_session = MakeComPtr<MockMFCdmSession>();
  auto second_session = MakeComPtr<MockMFCdmSession>();
  ComPtr<IMFContentDecryptionModuleSessionCallbacks> first_callbacks;
  ComPtr<IMFContentDecryptionModuleSessionCallbacks> second_callbacks;

  COM_EXPECT_CALL(mf_cdm_,
                  CreateSession(MF_MEDIAKEYSESSION_TYPE_TEMPORARY, _, _))
      .WillOnce(DoAll(SaveComPtr<1>(&first_callbacks),
                      SetComPointee<2>(first_session.Get()), Return(S_OK)))
      .WillOnce(DoAll(SaveComPtr<1>(&second_callbacks),
                      SetComPointee<2>(second_session.Get()), Return(S_OK)));

  // Both mock sessions hand back kSessionId. Only the first one is expected to
  // produce a session message; the second is rejected as a duplicate.
  SetGenerateRequestExpectations(first_session, kSessionId, &first_callbacks);
  SetGenerateRequestExpectations(second_session, kSessionId, &second_callbacks,
                                 /*expect_message=*/false);

  std::string first_result;
  std::string second_result;

  auto first_promise = std::make_unique<MockCdmSessionPromise>(
      /*expect_success=*/true, &first_result);
  cdm_->CreateSessionAndGenerateRequest(
      CdmSessionType::kTemporary, EmeInitDataType::WEBM, init_data_bytes,
      std::move(first_promise));

  auto second_promise = std::make_unique<MockCdmSessionPromise>(
      /*expect_success=*/false, &second_result);
  cdm_->CreateSessionAndGenerateRequest(
      CdmSessionType::kTemporary, EmeInitDataType::WEBM, init_data_bytes,
      std::move(second_promise));

  task_environment_.RunUntilIdle();
  EXPECT_EQ(first_result, kSessionId);
  EXPECT_TRUE(second_result.empty());
}

// LoadSession() is not implemented, so the promise is expected to fail and no
// session ID is produced.
TEST_F(MediaFoundationCdmTest, LoadSession) {
  Initialize();

  auto session_promise = std::make_unique<MockCdmSessionPromise>(
      /*expect_success=*/false, &session_id_);
  cdm_->LoadSession(CdmSessionType::kPersistentLicense, kSessionId,
                    std::move(session_promise));

  task_environment_.RunUntilIdle();
  EXPECT_TRUE(session_id_.empty());
}

// A successful Update() that signals a key status change must lead to a
// keys-change notification and an expiration update on the client.
TEST_F(MediaFoundationCdmTest, UpdateSession) {
  Initialize();
  CreateSessionAndGenerateRequest();

  std::vector<uint8_t> license_response = StringToVector("response");

  // While handling Update(), the mocked session reports a key status change
  // through the callbacks captured at session creation.
  COM_EXPECT_CALL(mf_cdm_session_, Update(NotNull(), license_response.size()))
      .WillOnce(DoAll([this] { mf_cdm_session_callbacks_->KeyStatusChanged(); },
                      Return(S_OK)));
  COM_EXPECT_CALL(mf_cdm_session_, GetKeyStatuses(_, _)).WillOnce(Return(S_OK));
  COM_EXPECT_CALL(mf_cdm_session_, GetExpiration(_))
      .WillOnce(DoAll(SetArgPointee<0>(kExpirationMs), Return(S_OK)));

  // Notifications the client must receive as a result.
  EXPECT_CALL(cdm_client_, OnSessionKeysChangeCalled(_, true));
  EXPECT_CALL(cdm_client_, OnSessionExpirationUpdate(_, kExpirationTime));

  auto promise = std::make_unique<MockCdmPromise>(/*expect_success=*/true);
  cdm_->UpdateSession(session_id_, license_response, std::move(promise));
  task_environment_.RunUntilIdle();
}

// Updating a session ID that was never created must fail the promise. No
// expectation is set on the mocked session for this call.
TEST_F(MediaFoundationCdmTest, UpdateSession_InvalidSessionId) {
  Initialize();
  CreateSessionAndGenerateRequest();

  std::vector<uint8_t> license_response = StringToVector("response");
  auto promise = std::make_unique<MockCdmPromise>(/*expect_success=*/false);
  cdm_->UpdateSession("invalid_session_id", license_response,
                      std::move(promise));
  task_environment_.RunUntilIdle();
}

// When the mocked session's Update() returns E_FAIL the promise must fail.
TEST_F(MediaFoundationCdmTest, UpdateSession_Failure) {
  Initialize();
  CreateSessionAndGenerateRequest();

  std::vector<uint8_t> license_response = StringToVector("response");
  COM_EXPECT_CALL(mf_cdm_session_, Update(NotNull(), license_response.size()))
      .WillOnce(Return(E_FAIL));

  auto promise = std::make_unique<MockCdmPromise>(/*expect_success=*/false);
  cdm_->UpdateSession(session_id_, license_response, std::move(promise));
  task_environment_.RunUntilIdle();
}

// A successful Close() resolves the promise and notifies the client that the
// session was closed with reason kClose.
TEST_F(MediaFoundationCdmTest, CloseSession) {
  Initialize();
  CreateSessionAndGenerateRequest();

  COM_EXPECT_CALL(mf_cdm_session_, Close()).WillOnce(Return(S_OK));
  EXPECT_CALL(cdm_client_,
              OnSessionClosed(kSessionId, CdmSessionClosedReason::kClose));

  auto promise = std::make_unique<MockCdmPromise>(/*expect_success=*/true);
  cdm_->CloseSession(session_id_, std::move(promise));
  task_environment_.RunUntilIdle();
}

// When Close() returns E_FAIL the promise must fail; the strict client mock
// has no OnSessionClosed() expectation here.
TEST_F(MediaFoundationCdmTest, CloseSession_Failure) {
  Initialize();
  CreateSessionAndGenerateRequest();

  COM_EXPECT_CALL(mf_cdm_session_, Close()).WillOnce(Return(E_FAIL));

  auto promise = std::make_unique<MockCdmPromise>(/*expect_success=*/false);
  cdm_->CloseSession(session_id_, std::move(promise));
  task_environment_.RunUntilIdle();
}

// A successful Remove() resolves the promise and triggers an expiration update
// on the client.
TEST_F(MediaFoundationCdmTest, RemoveSession) {
  Initialize();
  CreateSessionAndGenerateRequest();

  COM_EXPECT_CALL(mf_cdm_session_, Remove()).WillOnce(Return(S_OK));
  // The expiration reported by the mocked session is forwarded to the client.
  COM_EXPECT_CALL(mf_cdm_session_, GetExpiration(_))
      .WillOnce(DoAll(SetArgPointee<0>(kExpirationMs), Return(S_OK)));
  EXPECT_CALL(cdm_client_, OnSessionExpirationUpdate(_, kExpirationTime));

  auto promise = std::make_unique<MockCdmPromise>(/*expect_success=*/true);
  cdm_->RemoveSession(session_id_, std::move(promise));
  task_environment_.RunUntilIdle();
}

// When Remove() returns E_FAIL the promise must fail; no expiration update is
// expected from the strict client mock.
TEST_F(MediaFoundationCdmTest, RemoveSession_Failure) {
  Initialize();
  CreateSessionAndGenerateRequest();

  COM_EXPECT_CALL(mf_cdm_session_, Remove()).WillOnce(Return(E_FAIL));

  auto promise = std::make_unique<MockCdmPromise>(/*expect_success=*/false);
  cdm_->RemoveSession(session_id_, std::move(promise));
  task_environment_.RunUntilIdle();
}

// A hardware context reset closes the existing session with reason
// kHardwareContextReset; afterwards a new session can still be created.
TEST_F(MediaFoundationCdmTest, HardwareContextReset) {
  Initialize();
  CreateSessionAndGenerateRequest();

  // Obtain the CDM proxy through the CDM context.
  auto on_proxy_received = base::BindOnce(
      &MediaFoundationCdmTest::OnCdmProxyReceived, base::Unretained(this));
  cdm_->GetCdmContext()->GetMediaFoundationCdmProxy(
      std::move(on_proxy_received));
  task_environment_.RunUntilIdle();
  ASSERT_TRUE(mf_cdm_proxy_);

  // The reset must close the open session and tell the client why.
  COM_EXPECT_CALL(mf_cdm_session_, Close()).WillOnce(Return(S_OK));
  EXPECT_CALL(cdm_client_,
              OnSessionClosed(kSessionId,
                              CdmSessionClosedReason::kHardwareContextReset));
  mf_cdm_proxy_->OnHardwareContextReset();

  // Creating another session afterwards is expected to succeed.
  CreateSessionAndGenerateRequest();
}

// If the MF CDM cannot be created again after a hardware context reset, a
// subsequent session creation must fail.
TEST_F(MediaFoundationCdmTest, HardwareContextReset_InitializeFailure) {
  Initialize();
  CreateSessionAndGenerateRequest();

  // Obtain the CDM proxy through the CDM context.
  auto on_proxy_received = base::BindOnce(
      &MediaFoundationCdmTest::OnCdmProxyReceived, base::Unretained(this));
  cdm_->GetCdmContext()->GetMediaFoundationCdmProxy(
      std::move(on_proxy_received));
  task_environment_.RunUntilIdle();
  ASSERT_TRUE(mf_cdm_proxy_);

  // From here on CreateMFCdm() reports E_FAIL, so the next `Initialize()`
  // fails.
  can_initialize_ = false;

  COM_EXPECT_CALL(mf_cdm_session_, Close()).WillOnce(Return(S_OK));
  EXPECT_CALL(cdm_client_,
              OnSessionClosed(kSessionId,
                              CdmSessionClosedReason::kHardwareContextReset));
  mf_cdm_proxy_->OnHardwareContextReset();

  std::vector<uint8_t> init_data_bytes = StringToVector("init_data");
  auto session_promise = std::make_unique<MockCdmSessionPromise>(
      /*expect_success=*/false, &session_id_);
  cdm_->CreateSessionAndGenerateRequest(
      CdmSessionType::kTemporary, EmeInitDataType::WEBM, init_data_bytes,
      std::move(session_promise));
  task_environment_.RunUntilIdle();
}

}  // namespace media
