| // Copyright 2017 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "base/win/com_init_check_hook.h" |
| |
#include <objbase.h>
#include <shlobj.h>
#include <stdint.h>
#include <wrl/client.h>
| |
| #include "base/test/gtest_util.h" |
| #include "base/win/com_init_util.h" |
| #include "base/win/patch_util.h" |
| #include "base/win/scoped_com_initializer.h" |
| #include "testing/gtest/include/gtest/gtest.h" |
| |
| namespace base { |
| namespace win { |
| |
| using Microsoft::WRL::ComPtr; |
| |
| TEST(ComInitCheckHook, AssertNotInitialized) { |
| ComInitCheckHook com_check_hook; |
| AssertComApartmentType(ComApartmentType::NONE); |
| ComPtr<IUnknown> shell_link; |
| #if defined(COM_INIT_CHECK_HOOK_ENABLED) |
| EXPECT_DCHECK_DEATH(::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL, |
| IID_PPV_ARGS(&shell_link))); |
| #else |
| EXPECT_EQ(CO_E_NOTINITIALIZED, |
| ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL, |
| IID_PPV_ARGS(&shell_link))); |
| #endif |
| } |
| |
| TEST(ComInitCheckHook, HookRemoval) { |
| AssertComApartmentType(ComApartmentType::NONE); |
| { ComInitCheckHook com_check_hook; } |
| ComPtr<IUnknown> shell_link; |
| EXPECT_EQ(CO_E_NOTINITIALIZED, |
| ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL, |
| IID_PPV_ARGS(&shell_link))); |
| } |
| |
| TEST(ComInitCheckHook, NoAssertComInitialized) { |
| ComInitCheckHook com_check_hook; |
| ScopedCOMInitializer com_initializer; |
| ComPtr<IUnknown> shell_link; |
| EXPECT_TRUE(SUCCEEDED(::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL, |
| IID_PPV_ARGS(&shell_link)))); |
| } |
| |
| TEST(ComInitCheckHook, MultipleHooks) { |
| ComInitCheckHook com_check_hook_1; |
| ComInitCheckHook com_check_hook_2; |
| AssertComApartmentType(ComApartmentType::NONE); |
| ComPtr<IUnknown> shell_link; |
| #if defined(COM_INIT_CHECK_HOOK_ENABLED) |
| EXPECT_DCHECK_DEATH(::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL, |
| IID_PPV_ARGS(&shell_link))); |
| #else |
| EXPECT_EQ(CO_E_NOTINITIALIZED, |
| ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL, |
| IID_PPV_ARGS(&shell_link))); |
| #endif |
| } |
| |
| TEST(ComInitCheckHook, UnexpectedHook) { |
| #if defined(COM_INIT_CHECK_HOOK_ENABLED) |
| HMODULE ole32_library = ::LoadLibrary(L"ole32.dll"); |
| ASSERT_TRUE(ole32_library); |
| |
| uint32_t co_create_instance_padded_address = |
| reinterpret_cast<uint32_t>( |
| GetProcAddress(ole32_library, "CoCreateInstance")) - |
| 5; |
| const unsigned char* co_create_instance_bytes = |
| reinterpret_cast<const unsigned char*>(co_create_instance_padded_address); |
| const unsigned char original_byte = co_create_instance_bytes[0]; |
| const unsigned char unexpected_byte = 0xdb; |
| ASSERT_EQ(static_cast<DWORD>(NO_ERROR), |
| internal::ModifyCode( |
| reinterpret_cast<void*>(co_create_instance_padded_address), |
| reinterpret_cast<const void*>(&unexpected_byte), |
| sizeof(unexpected_byte))); |
| |
| EXPECT_DCHECK_DEATH({ ComInitCheckHook com_check_hook; }); |
| |
| // If this call fails, really bad things are going to happen to other tests |
| // so CHECK here. |
| CHECK_EQ(static_cast<DWORD>(NO_ERROR), |
| internal::ModifyCode( |
| reinterpret_cast<void*>(co_create_instance_padded_address), |
| reinterpret_cast<const void*>(&original_byte), |
| sizeof(original_byte))); |
| |
| ::FreeLibrary(ole32_library); |
| ole32_library = nullptr; |
| #endif |
| } |
| |
| TEST(ComInitCheckHook, ExternallyHooked) { |
| #if defined(COM_INIT_CHECK_HOOK_ENABLED) |
| HMODULE ole32_library = ::LoadLibrary(L"ole32.dll"); |
| ASSERT_TRUE(ole32_library); |
| |
| uint32_t co_create_instance_address = reinterpret_cast<uint32_t>( |
| GetProcAddress(ole32_library, "CoCreateInstance")); |
| const unsigned char* co_create_instance_bytes = |
| reinterpret_cast<const unsigned char*>(co_create_instance_address); |
| const unsigned char original_byte = co_create_instance_bytes[0]; |
| const unsigned char jmp_byte = 0xe9; |
| ASSERT_EQ(static_cast<DWORD>(NO_ERROR), |
| internal::ModifyCode( |
| reinterpret_cast<void*>(co_create_instance_address), |
| reinterpret_cast<const void*>(&jmp_byte), sizeof(jmp_byte))); |
| |
| // Externally patched instances should crash so we catch these cases on bots. |
| EXPECT_DCHECK_DEATH({ ComInitCheckHook com_check_hook; }); |
| |
| // If this call fails, really bad things are going to happen to other tests |
| // so CHECK here. |
| CHECK_EQ( |
| static_cast<DWORD>(NO_ERROR), |
| internal::ModifyCode(reinterpret_cast<void*>(co_create_instance_address), |
| reinterpret_cast<const void*>(&original_byte), |
| sizeof(original_byte))); |
| |
| ::FreeLibrary(ole32_library); |
| ole32_library = nullptr; |
| #endif |
| } |
| |
| TEST(ComInitCheckHook, UnexpectedChangeDuringHook) { |
| #if defined(COM_INIT_CHECK_HOOK_ENABLED) |
| HMODULE ole32_library = ::LoadLibrary(L"ole32.dll"); |
| ASSERT_TRUE(ole32_library); |
| |
| uint32_t co_create_instance_padded_address = |
| reinterpret_cast<uint32_t>( |
| GetProcAddress(ole32_library, "CoCreateInstance")) - |
| 5; |
| const unsigned char* co_create_instance_bytes = |
| reinterpret_cast<const unsigned char*>(co_create_instance_padded_address); |
| const unsigned char original_byte = co_create_instance_bytes[0]; |
| const unsigned char unexpected_byte = 0xdb; |
| ASSERT_EQ(static_cast<DWORD>(NO_ERROR), |
| internal::ModifyCode( |
| reinterpret_cast<void*>(co_create_instance_padded_address), |
| reinterpret_cast<const void*>(&unexpected_byte), |
| sizeof(unexpected_byte))); |
| |
| EXPECT_DCHECK_DEATH({ |
| ComInitCheckHook com_check_hook; |
| |
| internal::ModifyCode( |
| reinterpret_cast<void*>(co_create_instance_padded_address), |
| reinterpret_cast<const void*>(&unexpected_byte), |
| sizeof(unexpected_byte)); |
| }); |
| |
| // If this call fails, really bad things are going to happen to other tests |
| // so CHECK here. |
| CHECK_EQ(static_cast<DWORD>(NO_ERROR), |
| internal::ModifyCode( |
| reinterpret_cast<void*>(co_create_instance_padded_address), |
| reinterpret_cast<const void*>(&original_byte), |
| sizeof(original_byte))); |
| |
| ::FreeLibrary(ole32_library); |
| ole32_library = nullptr; |
| #endif |
| } |
| |
| } // namespace win |
| } // namespace base |