blob: f2d3b62994b87141d4c722bc8cdcb3a97dce40df [file] [log] [blame]
#!/usr/bin/python
#
# Copyright 2022 The Cobalt Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Provides utilities for reading the sharding configuration from JSON and parsing
it into the corresponding allocation of tests to shards.
"""
import os
import json
import logging
_SHARDING_CONFIG_FILE = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), 'sharding_configuration.json')


class ShardingTestConfig(object):
  """Encapsulates the sharding configuration for a given platform.

  The configuration is a list of shards. Each shard is a dict that maps a test
  target name to either '*' (run the whole test in that shard) or a two-element
  list [index, count] (run chunk |index| of |count|, with
  count >= index >= 1).

  On construction a "default" shard holding every test or test chunk that the
  configuration leaves unassigned is computed and stored at index 0 of
  |platform_sharding_config|, ahead of the configured shards.
  """

  # Kinds of run decision returned by get_test_run_config().
  RUN_FULL_TEST = 'run_full_test'
  RUN_PARTIAL_TEST = 'run_partial_test'
  SKIP_TEST = 'skip_test'

  def __init__(self, platform, test_targets, platform_sharding_config=None):
    """Builds the sharding configuration, including the default shard.

    Args:
      platform: Platform name used to look up the configuration in the JSON
        file when |platform_sharding_config| is not supplied.
      test_targets: Iterable of all test target names for the platform.
      platform_sharding_config: Optional list of shard dicts. When it is
        omitted (or otherwise falsy, e.g. an empty list) the configuration is
        read from the sharding configuration file instead. The supplied list
        is never modified.

    Raises:
      RuntimeError: The sharding configuration file could not be found.
      ValueError: A test would contribute more than one chunk to the default
        shard (see _create_default_shard).
    """
    if platform_sharding_config:
      # Work on a shallow copy: the default shard is inserted below, and doing
      # that on the caller's list would leak the extra shard back to the
      # caller (and stack up default shards if the list is reused).
      shards = list(platform_sharding_config)
    else:
      # Only the file read can raise FileNotFoundError, so keep the try body
      # limited to it.
      try:
        shards = self._read_platform_sharding_config(platform)
      except FileNotFoundError as e:
        raise RuntimeError('No sharding configuration file found.') from e
    self.platform_sharding_config = shards

    # Create the default shard, and add it to the front of the config.
    default_shard = self._create_default_shard(shards, test_targets)
    shards.insert(0, default_shard)
    logging.debug('%s', json.dumps(shards, sort_keys=True, indent=2))

  def _read_platform_sharding_config(self,
                                     platform,
                                     filename=_SHARDING_CONFIG_FILE):
    """Loads the shard list for |platform| from the JSON file |filename|.

    Lookup order:
      1. A key exactly equal to |platform|.
      2. The first key (in file order) that contains |platform| as a
         substring.
      3. The 'default' key.

    Args:
      platform: Platform name to look up.
      filename: Path of the JSON sharding configuration file.

    Returns:
      The value stored under the selected key.

    Raises:
      FileNotFoundError: |filename| does not exist.
      KeyError: No key matches and the file has no 'default' entry.
    """
    with open(filename, 'r', encoding='utf-8') as file:
      sharding_json = json.load(file)

    # Prefer an exact match so that e.g. 'raspi-2' is not shadowed by an
    # earlier 'raspi-2-skia' entry that merely contains it.
    if platform in sharding_json:
      return sharding_json[platform]
    for platform_key, platform_config in sharding_json.items():
      if platform in platform_key:
        return platform_config
    return sharding_json['default']

  def _find_unmapped_test_chunks(self, platform_sharding_config):
    """Determines which test chunks are not assigned to any shard.

    All possible chunks of a test are enumerated from the chunk count given
    the first time the test appears as a chunk, and every chunk index that is
    listed in the configuration is then removed.

    E.g. if the sharding configuration consists of:
      config = [
        {
          "foo_test": [2, 3],
        },
        {
          "foo_test": [1, 3],
        }
      ]
    Then we can assume that the [3, 3] chunk is unmapped.

    Args:
      platform_sharding_config: List of shard dicts (without default shard).

    Returns:
      A tuple (unmapped_chunk_index_by_test, chunk_totals_by_test):
        - unmapped_chunk_index_by_test maps each test named in the config to
          the set of its chunk indices that no shard runs (empty for a test
          that is listed as a whole).
        - chunk_totals_by_test maps each chunked test to its chunk count.

    Raises:
      KeyError: A chunk index is listed twice for a test, or lies outside
        1..count (set.remove() fails on the missing index).
    """
    chunk_totals_by_test = {}
    unmapped_chunk_index_by_test = {}
    for shard_chunk in platform_sharding_config:
      for test_name, chunk_info in shard_chunk.items():
        # The entire test is represented by '*' in the |chunk_info| field; any
        # value that is not a list is treated as "the whole test is mapped".
        if not isinstance(chunk_info, list):
          unmapped_chunk_index_by_test[test_name] = set()
          continue

        # The test chunk is represented by [index, count], where:
        #   count >= index >= 1
        chunk_index, chunk_total = chunk_info
        if test_name not in chunk_totals_by_test:
          # First chunk seen for this test: start with every index unmapped.
          chunk_totals_by_test[test_name] = chunk_total
          unmapped_chunk_index_by_test[test_name] = set(
              range(1, chunk_total + 1))
        unmapped_chunk_index_by_test[test_name].remove(chunk_index)

    logging.debug('Printing all tests with unmapped chunks:')
    for test_name, unmapped in unmapped_chunk_index_by_test.items():
      logging.debug('%s: %s', test_name, unmapped)
    logging.debug('Printing all tests with corresponding chunk sizes:')
    logging.debug('%s',
                  json.dumps(chunk_totals_by_test, sort_keys=True, indent=2))
    return unmapped_chunk_index_by_test, chunk_totals_by_test

  def _create_default_shard(self, platform_sharding_config, test_targets):
    """Generates the default shard.

    The default shard consists of all unmapped test chunks, and unmapped
    tests. It is generated from the list of all test targets and the currently
    configured set of test chunks.

    A test chunk is a subset of the test cases of a test suite. In the
    configuration it is written as [INDEX, COUNT] with COUNT >= INDEX >= 1.
    get_test_run_config() reports it with a zero-based index (INDEX - 1),
    which is the form expected by arguments such as:
      $ ./foo_test --gtest_shard_index=INDEX-1 --gtest_total_shards=COUNT

    Args:
      platform_sharding_config: List of shard dicts (without default shard).
      test_targets: Iterable of all test target names.

    Returns:
      A shard dict mapping each unassigned test to '*' and each test with
      exactly one unassigned chunk to [index, count].

    Raises:
      ValueError: A test has more than one unmapped chunk; a single shard
        cannot hold several chunks of the same test.
    """
    unmapped_chunk_index_by_test, chunk_totals_by_test = (
        self._find_unmapped_test_chunks(platform_sharding_config))

    # Add the entire test to the default shard, if the test is not mentioned
    # anywhere in the sharding configuration.
    unlisted_tests = set(test_targets) - set(unmapped_chunk_index_by_test)
    default_shard = {name: '*' for name in unlisted_tests}

    # Add the unmapped test chunks.
    for test_name, unmapped_chunks in unmapped_chunk_index_by_test.items():
      if not unmapped_chunks:
        # Test has no unmapped chunks.
        continue
      # Shard cannot contain multiple chunks from the same test.
      if len(unmapped_chunks) > 1:
        raise ValueError('Invalid Sharding Configuration: default shard must '
                         'not contain multiple chunks from the same test ('
                         f'test:{test_name} unmapped_chunks:'
                         f'{unmapped_chunks}).')
      # Exactly one chunk is left over: add it to the default shard.
      chunk_index = next(iter(unmapped_chunks))
      default_shard[test_name] = [chunk_index, chunk_totals_by_test[test_name]]
    return default_shard

  def get_test_run_config(self, test_target, shard_index):
    """Returns how |test_target| is run in the shard at |shard_index|.

    Args:
      test_target: Name of the test target.
      shard_index: Index into |platform_sharding_config|; index 0 is the
        default shard.

    Returns:
      A tuple (run_kind, sub_shard_index, sub_shard_count):
        - (RUN_FULL_TEST, 0, 0) if the shard runs the whole test.
        - (RUN_PARTIAL_TEST, index, count) if the shard runs one chunk of the
          test; |index| is zero-based (the configured index minus one).
        - (SKIP_TEST, 0, 0) if the test is not part of the shard.

    Raises:
      IndexError: |shard_index| is out of range.
    """
    shard_test_targets = self.platform_sharding_config[shard_index]
    if test_target not in shard_test_targets:
      return (self.SKIP_TEST, 0, 0)

    test_run_config = shard_test_targets[test_target]
    if test_run_config == '*':
      return (self.RUN_FULL_TEST, 0, 0)

    # Configured chunk indices are 1-based; report a zero-based index.
    sub_shard_index = test_run_config[0] - 1
    sub_shard_count = test_run_config[1]
    return (self.RUN_PARTIAL_TEST, sub_shard_index, sub_shard_count)