# Copyright (c) 2014 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
import datetime
import mox
import unittest
import common
from autotest_lib.frontend import setup_django_environment
from autotest_lib.frontend.afe import frontend_test_utils
from autotest_lib.frontend.afe import models
from autotest_lib.client.common_lib import error
from autotest_lib.client.common_lib import global_config
from autotest_lib.server.cros.dynamic_suite import frontend_wrappers
from autotest_lib.scheduler.shard import shard_client
class ShardClientTest(mox.MoxTestBase,
                      frontend_test_utils.FrontendTestMixin):
    """Unit tests for functions in shard_client.py.

    Each test records the 'shard_heartbeat' RPCs it expects on a mocked
    RetryingAFE, replays them, drives the shard client and finally verifies
    that every recorded expectation was consumed.
    """

    GLOBAL_AFE_HOSTNAME = 'foo_autotest'

    def setUp(self):
        """Point the shard at a fake global AFE and set up an empty test DB."""
        super(ShardClientTest, self).setUp()

        global_config.global_config.override_config_value(
            'SHARD', 'global_afe_hostname', self.GLOBAL_AFE_HOSTNAME)

        self._frontend_common_setup(fill_data=False)

    def setup_mocks(self):
        """Replace RetryingAFE with a mock and record its construction.

        The mock instance is stored in self.afe so that tests can record
        RPC expectations on it (see expect_heartbeat).
        """
        self.mox.StubOutClassWithMocks(frontend_wrappers, 'RetryingAFE')
        # While mox is in record mode this call registers the expected
        # constructor invocation and hands back the mock instance.
        self.afe = frontend_wrappers.RetryingAFE(server=mox.IgnoreArg(),
                                                 delay_sec=5,
                                                 timeout_min=5)

    def setup_global_config(self):
        """Override the SHARD config so this process is treated as a shard."""
        global_config.global_config.override_config_value(
            'SHARD', 'is_slave_shard', 'True')
        global_config.global_config.override_config_value(
            'SHARD', 'shard_hostname', 'host1')

    def expect_heartbeat(self, shard_hostname='host1',
                         known_job_ids=None, known_host_ids=None,
                         known_host_statuses=None, hqes=None, jobs=None,
                         side_effect=None, return_hosts=None,
                         return_jobs=None, return_suite_keyvals=None):
        """Record one expected 'shard_heartbeat' RPC on the mocked AFE.

        Every list-like argument defaults to None, which stands for an empty
        list; a fresh list is created per call so that recorded expectations
        never share a mutable default object.

        @param shard_hostname: Expected shard_hostname argument of the RPC.
        @param known_job_ids: Expected known_job_ids argument of the RPC.
        @param known_host_ids: Expected known_host_ids argument of the RPC.
        @param known_host_statuses: Expected known_host_statuses argument of
                                    the RPC.
        @param hqes: Expected hqes argument of the RPC (may be a mox
                     comparator such as mox.IgnoreArg()).
        @param jobs: Expected jobs argument of the RPC (may be a mox
                     comparator such as mox.IgnoreArg()).
        @param side_effect: Optional callable mox invokes with the arguments
                            of the actual RPC call.
        @param return_hosts: Value of the 'hosts' key in the mocked response.
        @param return_jobs: Value of the 'jobs' key in the mocked response.
        @param return_suite_keyvals: Value of the 'suite_keyvals' key in the
                                     mocked response.
        """
        def _list_or_empty(value):
            # Only None is substituted, so comparators and explicitly passed
            # lists are forwarded untouched.
            return [] if value is None else value

        call = self.afe.run(
            'shard_heartbeat', shard_hostname=shard_hostname,
            hqes=_list_or_empty(hqes), jobs=_list_or_empty(jobs),
            known_job_ids=_list_or_empty(known_job_ids),
            known_host_ids=_list_or_empty(known_host_ids),
            known_host_statuses=_list_or_empty(known_host_statuses),
        )

        if side_effect:
            call = call.WithSideEffects(side_effect)

        call.AndReturn({
            'hosts': _list_or_empty(return_hosts),
            'jobs': _list_or_empty(return_jobs),
            'suite_keyvals': _list_or_empty(return_suite_keyvals),
        })

    def tearDown(self):
        """Tear down the test DB and drop all global_config overrides."""
        try:
            self._frontend_common_teardown()
        finally:
            # Without this global_config will keep state over test cases,
            # so reset it even when the frontend teardown fails.
            global_config.global_config.reset_config_values()

    def _get_sample_serialized_host(self):
        """Return a new dict describing a serialized host with id 2."""
        return {'aclgroup_set': [],
                'dirty': True,
                'hostattribute_set': [],
                'hostname': u'host1',
                u'id': 2,
                'invalid': False,
                'labels': [],
                'leased': True,
                'lock_time': None,
                'locked': False,
                'protection': 0,
                'shard': None,
                'status': u'Ready'}

    def _get_sample_serialized_job(self):
        """Return a new dict describing a serialized job with id 1.

        The job carries a single queued host queue entry (id 1) whose meta
        host is the 'board:lumpy' label.
        """
        return {'control_file': u'foo',
                'control_type': 2,
                'created_on': datetime.datetime(2014, 9, 23, 15, 56, 10, 0),
                'dependency_labels': [{u'id': 1,
                                       'invalid': False,
                                       'kernel_config': u'',
                                       'name': u'board:lumpy',
                                       'only_if_needed': False,
                                       'platform': False}],
                'email_list': u'',
                'hostqueueentry_set': [{'aborted': False,
                                        'active': False,
                                        'complete': False,
                                        'deleted': False,
                                        'execution_subdir': u'',
                                        'finished_on': None,
                                        u'id': 1,
                                        'meta_host': {u'id': 1,
                                                      'invalid': False,
                                                      'kernel_config': u'',
                                                      'name': u'board:lumpy',
                                                      'only_if_needed': False,
                                                      'platform': False},
                                        'started_on': None,
                                        'status': u'Queued'}],
                u'id': 1,
                'jobkeyval_set': [],
                'max_runtime_hrs': 72,
                'max_runtime_mins': 1440,
                'name': u'dummy',
                'owner': u'autotest_system',
                'parse_failed_repair': True,
                'priority': 40,
                'parent_job_id': 0,
                'reboot_after': 0,
                'reboot_before': 1,
                'run_reset': True,
                'run_verify': False,
                'shard': {'hostname': u'shard1', u'id': 1},
                'synch_count': 0,
                'test_retry': 0,
                'timeout': 24,
                'timeout_mins': 1440}

    def _get_sample_serialized_suite_keyvals(self):
        """Return a new dict describing a serialized keyval for job_id 0."""
        return {'id': 1,
                'job_id': 0,
                'key': 'test_key',
                'value': 'test_value'}

    def _assert_heartbeat_raises_when_processing_fails(self, sut):
        """Run one heartbeat whose response processing raises RuntimeError.

        sut.process_heartbeat_response is swapped for a function that always
        raises, do_heartbeat is asserted to propagate that RuntimeError, and
        the original handler is put back afterwards, even if the assertion
        itself fails.

        @param sut: The shard client under test.
        """
        original_process_heartbeat_response = sut.process_heartbeat_response

        def failing_process_heartbeat_response(*args, **kwargs):
            raise RuntimeError

        sut.process_heartbeat_response = failing_process_heartbeat_response
        try:
            self.assertRaises(RuntimeError, sut.do_heartbeat)
        finally:
            sut.process_heartbeat_response = (
                    original_process_heartbeat_response)

    def testHeartbeat(self):
        """Trigger heartbeat, verify RPCs and persisting of the responses."""
        self.setup_mocks()

        global_config.global_config.override_config_value(
            'SHARD', 'shard_hostname', 'host1')

        # First heartbeat: nothing is known yet; the response delivers one
        # host, one job and one suite keyval.
        self.expect_heartbeat(
            return_hosts=[self._get_sample_serialized_host()],
            return_jobs=[self._get_sample_serialized_job()],
            return_suite_keyvals=[
                self._get_sample_serialized_suite_keyvals()])

        # Second heartbeat: the host and job from before are reported as
        # known; the response carries the same host id with a new hostname.
        modified_sample_host = self._get_sample_serialized_host()
        modified_sample_host['hostname'] = 'host2'

        self.expect_heartbeat(
            return_hosts=[modified_sample_host],
            known_host_ids=[modified_sample_host['id']],
            known_host_statuses=[modified_sample_host['status']],
            known_job_ids=[1])

        # The parameter names must match the keyword arguments of the
        # recorded RPC, because mox forwards them to the side effect by name.
        def verify_upload_jobs_and_hqes(name, shard_hostname, jobs, hqes,
                                        known_host_ids, known_host_statuses,
                                        known_job_ids):
            self.assertEqual(len(jobs), 1)
            self.assertEqual(len(hqes), 1)
            self.assertEqual(hqes[0]['status'], 'Completed')

        # Third heartbeat: the uploaded jobs/hqes are inspected by the side
        # effect above instead of being matched literally.
        self.expect_heartbeat(
            jobs=mox.IgnoreArg(), hqes=mox.IgnoreArg(),
            known_host_ids=[modified_sample_host['id']],
            known_host_statuses=[modified_sample_host['status']],
            known_job_ids=[], side_effect=verify_upload_jobs_and_hqes)

        self.mox.ReplayAll()
        sut = shard_client.get_shard_client()

        sut.do_heartbeat()

        # Check if dummy object was saved to DB
        host = models.Host.objects.get(id=2)
        self.assertEqual(host.hostname, 'host1')

        # Check if suite keyval was saved to DB
        suite_keyval = models.JobKeyval.objects.filter(job_id=0)[0]
        self.assertEqual(suite_keyval.key, 'test_key')

        sut.do_heartbeat()

        # Ensure it wasn't overwritten
        host = models.Host.objects.get(id=2)
        self.assertEqual(host.hostname, 'host1')

        # Detach the job from its shard and complete its queue entry; the
        # side effect of the third heartbeat checks what gets uploaded.
        job = models.Job.objects.all()[0]
        job.shard = None
        job.save()
        hqe = job.hostqueueentry_set.all()[0]
        hqe.status = 'Completed'
        hqe.save()

        sut.do_heartbeat()

        self.mox.VerifyAll()

    def testFailAndRedownloadJobs(self):
        """Jobs are offered again after processing a response has failed."""
        self.setup_mocks()
        self.setup_global_config()

        job1_serialized = self._get_sample_serialized_job()
        job2_serialized = self._get_sample_serialized_job()
        job2_serialized['id'] = 2
        job2_serialized['hostqueueentry_set'][0]['id'] = 2

        # 1st heartbeat: job1 is returned, but processing will fail.
        self.expect_heartbeat(return_jobs=[job1_serialized])
        # 2nd heartbeat: no job is reported as known, both jobs are returned.
        self.expect_heartbeat(return_jobs=[job1_serialized, job2_serialized])
        # 3rd heartbeat: both jobs are reported as known.
        self.expect_heartbeat(known_job_ids=[job1_serialized['id'],
                                             job2_serialized['id']])
        # 4th heartbeat: only job2 is reported as known, since the queue
        # entries of job1 are marked complete beforehand.
        self.expect_heartbeat(known_job_ids=[job2_serialized['id']])

        self.mox.ReplayAll()
        sut = shard_client.get_shard_client()

        self._assert_heartbeat_raises_when_processing_fails(sut)

        sut.do_heartbeat()
        sut.do_heartbeat()

        job1 = models.Job.objects.get(pk=job1_serialized['id'])
        job1.hostqueueentry_set.all().update(complete=True)

        sut.do_heartbeat()

        self.mox.VerifyAll()

    def testFailAndRedownloadHosts(self):
        """Hosts are offered again after processing a response has failed."""
        self.setup_mocks()
        self.setup_global_config()

        host1_serialized = self._get_sample_serialized_host()
        host2_serialized = self._get_sample_serialized_host()
        host2_serialized['id'] = 3
        host2_serialized['hostname'] = 'host2'

        # 1st heartbeat: host1 is returned, but processing will fail.
        self.expect_heartbeat(return_hosts=[host1_serialized])
        # 2nd heartbeat: no host is reported as known, both are returned.
        self.expect_heartbeat(return_hosts=[host1_serialized,
                                            host2_serialized])
        # 3rd heartbeat: both hosts are reported as known.
        self.expect_heartbeat(known_host_ids=[host1_serialized['id'],
                                              host2_serialized['id']],
                              known_host_statuses=[host1_serialized['status'],
                                                   host2_serialized['status']])

        self.mox.ReplayAll()
        sut = shard_client.get_shard_client()

        self._assert_heartbeat_raises_when_processing_fails(sut)

        # The failed heartbeat must not have persisted any host.
        self.assertEqual(models.Host.objects.count(), 0)

        sut.do_heartbeat()
        sut.do_heartbeat()

        self.mox.VerifyAll()

    def testHeartbeatNoShardMode(self):
        """Ensure an exception is thrown when run on a non-shard machine."""
        self.mox.ReplayAll()

        self.assertRaises(error.HeartbeatOnlyAllowedInShardModeException,
                          shard_client.get_shard_client)

        self.mox.VerifyAll()

    def testLoop(self):
        """Test looping over heartbeats and aborting that loop works."""
        self.setup_mocks()
        self.setup_global_config()

        global_config.global_config.override_config_value(
            'SHARD', 'heartbeat_pause_sec', '0.01')

        self.expect_heartbeat()

        # Declared up front so the closure below can refer to it; the name is
        # looked up when the side effect runs, i.e. after sut is assigned.
        sut = None

        def shutdown_sut(*args, **kwargs):
            sut.shutdown()

        # The second heartbeat asks the client to shut down from inside the
        # RPC, which is what lets sut.loop() return.
        self.expect_heartbeat(side_effect=shutdown_sut)

        self.mox.ReplayAll()
        sut = shard_client.get_shard_client()
        sut.loop()

        self.mox.VerifyAll()
# Run the test cases of this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()