#!/usr/bin/python
# Copyright 2017 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import Queue
import array
import collections
import os
import shutil
import tempfile
import threading
import unittest
from contextlib import contextmanager
from multiprocessing import connection

import common
from autotest_lib.site_utils import lxc
from autotest_lib.site_utils.lxc import unittest_setup
from autotest_lib.site_utils.lxc.container_pool import message
from autotest_lib.site_utils.lxc.container_pool import service
from autotest_lib.site_utils.lxc.container_pool import unittest_client


# Minimal stand-in for a host dir object: the tests only supply a `path`.
FakeHostDir = collections.namedtuple('FakeHostDir', ['path'])


class ServiceTests(unittest.TestCase):
    """Unit tests for the Service class."""

    @classmethod
    def setUpClass(cls):
        """Creates a directory for running the unit tests. """
        # Explicitly use /tmp as the tmpdir.  Board specific TMPDIRs inside of
        # the chroot are set to a path that causes the socket address to exceed
        # the maximum allowable length.
        cls.test_dir = tempfile.mkdtemp(prefix='service_unittest_', dir='/tmp')


    @classmethod
    def tearDownClass(cls):
        """Deletes the test directory. """
        shutil.rmtree(cls.test_dir)


    def setUp(self):
        """Per-test setup."""
        # Put each test in its own test dir, so it's hermetic.  The instance
        # attribute shadows the class-level test_dir, which remains the parent
        # directory removed in tearDownClass.
        self.test_dir = tempfile.mkdtemp(dir=ServiceTests.test_dir)
        self.host_dir = FakeHostDir(self.test_dir)
        self.address = os.path.join(self.test_dir,
                                    lxc.DEFAULT_CONTAINER_POOL_SOCKET)


    def testConnection(self):
        """Tests a simple connection to the pool service."""
        with self.run_service():
            self.assertTrue(self._pool_is_healthy())


    def testAbortedConnection(self):
        """Tests that a closed connection doesn't crash the service."""
        with self.run_service():
            # Open a raw connection and drop it without sending anything.
            client = connection.Client(self.address)
            client.close()
            self.assertTrue(self._pool_is_healthy())


    def testCorruptedMessage(self):
        """Tests that corrupted messages don't crash the service."""
        with self.run_service(), self.create_client() as client:
            # Send a raw array of bytes.  This will cause an unpickling error.
            client.send_bytes(array.array('i', range(1, 10)))
            # Verify that the container pool closed the connection.
            with self.assertRaises(EOFError):
                client.recv()
            # Verify that the main container pool service is still alive.
            self.assertTrue(self._pool_is_healthy())


    def testInvalidMessageClass(self):
        """Tests that bad messages don't crash the service."""
        with self.run_service(), self.create_client() as client:
            # Send a valid object but not of the right Message class.
            client.send('foo')
            # Verify that the container pool closed the connection.
            with self.assertRaises(EOFError):
                client.recv()
            # Verify that the main container pool service is still alive.
            self.assertTrue(self._pool_is_healthy())


    def testInvalidMessageType(self):
        """Tests that messages with a bad type don't crash the service."""
        with self.run_service(), self.create_client() as client:
            # Send a genuine Message object, but with an unrecognized type.
            client.send(message.Message('foo', None))
            # Verify that the container pool closed the connection.
            with self.assertRaises(EOFError):
                client.recv()
            # Verify that the main container pool service is still alive.
            self.assertTrue(self._pool_is_healthy())


    def testStop(self):
        """Tests stopping the service."""
        with self.run_service() as svc, self.create_client() as client:
            self.assertTrue(svc.is_running())
            client.send(message.shutdown())
            client.recv()  # wait for ack
            self.assertFalse(svc.is_running())


    def testStatus(self):
        """Tests querying service status."""
        pool = MockPool()
        with self.run_service(pool), self.create_client() as client:
            client.send(message.status())
            self._assert_status_matches_pool(client.recv(), pool)

            # Change some values, ensure the changes are reflected.
            pool.capacity = 42
            pool.size = 19
            pool.worker_count = 3
            error_count = 8
            for e in range(error_count):
                pool.errors.put(e)
            client.send(message.status())
            self._assert_status_matches_pool(client.recv(), pool)


    def testGet(self):
        """Tests getting a container from the pool."""
        test_pool = MockPool()
        fake_container = MockContainer()
        test_id = lxc.ContainerId.create(42)
        test_pool.containers.put(fake_container)

        with self.run_service(test_pool):
            with self.create_client() as client:
                client.send(message.get(test_id))
                test_container = client.recv()
                # The container handed back must carry the requested id.
                self.assertEqual(test_id, test_container.id)


    def testGet_timeoutImmediate(self):
        """Tests that a get without an explicit timeout on an empty pool
        returns None."""
        test_id = lxc.ContainerId.create(42)
        with self.run_service():
            with self.create_client() as client:
                client.send(message.get(test_id))
                test_container = client.recv()
                self.assertIsNone(test_container)


    def testGet_timeoutDelayed(self):
        """Tests that a get with an explicit timeout on an empty pool
        returns None."""
        test_id = lxc.ContainerId.create(42)
        with self.run_service():
            with self.create_client() as client:
                client.send(message.get(test_id, timeout=1))
                test_container = client.recv()
                self.assertIsNone(test_container)


    def testMultipleClients(self):
        """Tests multiple simultaneous connections."""
        with self.run_service():
            with self.create_client() as client0:
                with self.create_client() as client1:

                    msg0 = 'two driven jocks help fax my big quiz'
                    msg1 = 'how quickly daft jumping zebras vex'

                    # Both requests are in flight before either reply is
                    # read, so the two connections are active at once.
                    client0.send(message.echo(msg0))
                    client1.send(message.echo(msg1))

                    echo0 = client0.recv()
                    echo1 = client1.recv()

                    self.assertEqual(msg0, echo0)
                    self.assertEqual(msg1, echo1)


    def _assert_status_matches_pool(self, status, pool):
        """Asserts that a status response reflects the given pool's state.

        @param status: The dict-like response received for a status message.
        @param pool: The MockPool instance backing the running service.
        """
        self.assertTrue(status['running'])
        self.assertEqual(self.address, status['socket_path'])
        self.assertEqual(pool.capacity, status['pool capacity'])
        self.assertEqual(pool.size, status['pool size'])
        self.assertEqual(pool.worker_count, status['pool worker count'])
        self.assertEqual(pool.errors.qsize(), status['pool errors'])


    def _pool_is_healthy(self):
        """Verifies that the pool service is still functioning.

        Sends an echo message and tests for a response.  This is a stronger
        signal of aliveness than checking Service.is_running, but a False
        return value does not necessarily indicate that the pool service shut
        down cleanly.  Use Service.is_running to check that.

        @return: True if the service echoed the message back unchanged.
        """
        with self.create_client() as client:
            msg = 'foobar'
            client.send(message.echo(msg))
            return client.recv() == msg


    @contextmanager
    def run_service(self, pool=None):
        """Creates and cleans up a Service instance.

        The service's start method runs on a separate thread for the duration
        of the with-block.  On exit the service is stopped and the thread is
        joined with a one second timeout.

        @param pool: The pool object to hand to the Service.  A fresh, empty
                     MockPool is used when omitted.

        @yield: The Service instance.
        """
        if pool is None:
            pool = MockPool()
        svc = service.Service(self.host_dir, pool)
        thread = threading.Thread(name='service', target=svc.start)
        thread.start()
        try:
            yield svc
        finally:
            svc.stop()
            # Bounded wait: join returns after at most one second even if the
            # service thread has not exited.
            thread.join(1)


    @contextmanager
    def create_client(self):
        """Creates and cleans up a client connection.

        @yield: A client connected to this test's service address.
        """
        client = unittest_client.connect(self.address)
        try:
            yield client
        finally:
            client.close()


class MockPool(object):
    """A mock pool class for testing the service."""

    def __init__(self):
        """Initializes a mock empty pool."""
        self.capacity = 0
        self.size = 0
        self.worker_count = 0
        self.errors = Queue.Queue()
        self.containers = Queue.Queue()


    def cleanup(self):
        """Required by pool interface.  Does nothing."""
        pass


    def get(self, timeout=0):
        """Required by pool interface.

        @param timeout: How long to block waiting for a container.  The call
                        does not block at all unless this is greater than zero.

        @return: The next container from the containers queue, or None if the
                 queue yielded nothing.
        """
        try:
            return self.containers.get(block=(timeout > 0), timeout=timeout)
        except Queue.Empty:
            return None


class MockContainer(object):
    """A mock container class for testing the service."""

    def __init__(self):
        """Initializes a mock container."""
        self.id = None
        self.name = 'test_container'


if __name__ == '__main__':
    unittest_setup.setup(require_sudo=False)
    unittest.main()