普通文本  |  193行  |  6.41 KB

#!/usr/bin/python
# Copyright 2016 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import mock
import unittest

import common

from autotest_lib.server import utils
from autotest_lib.server.hosts import host_info
from autotest_lib.server.hosts import base_label

 # pylint: disable=missing-docstring


class TestBaseLabel(base_label.BaseLabel):
    """TestBaseLabel is used for testing/validating BaseLabel methods."""

    _NAME = 'base_label'

    def exists(self, host):
        return host.exists


class TestBaseLabels(base_label.BaseLabel):
    """
    TestBaseLabels is used for testing/validating BaseLabel methods.

    This is a variation of BaseLabel with multiple labels for _NAME
    to ensure we handle a label that contains a list of labels for
    its _NAME attribute.
    """

    _NAME = ['base_label_1' , 'base_label_2']


class TestStringPrefixLabel(base_label.StringPrefixLabel):
    """
    TestBaseLabels is used for testing/validating StringPrefixLabel methods.

    This test class is to check that we properly construct the prefix labels
    with the label passed in during construction.
    """

    _NAME = 'prefix'

    def __init__(self, label='postfix'):
        self.label_to_return = label


    def generate_labels(self, _):
        return [self.label_to_return]


class MockAFEHost(utils.EmptyAFEHost):

    def __init__(self, labels=None, attributes=None):
        self.labels = labels or []
        self.attributes = attributes or {}


class MockHost(object):

    def __init__(self, exists=True, store=None):
        self.hostname = 'hostname'
        self.exists = exists
        self.host_info_store = store


class BaseLabelUnittests(unittest.TestCase):
    """Unittest for testing base_label.BaseLabel."""

    def setUp(self):
        self.test_base_label = TestBaseLabel()
        self.test_base_labels = TestBaseLabels()


    def test_generate_labels(self):
        """Let's make sure generate_labels() returns the labels expected."""
        self.assertEqual(self.test_base_label.generate_labels(None),
                         [self.test_base_label._NAME])


    def test_get(self):
        """Let's make sure the logic in get() works as expected."""
        # We should get labels here.
        self.assertEqual(self.test_base_label.get(MockHost(exists=True)),
                         [self.test_base_label._NAME])
        # We should get nothing here.
        self.assertEqual(self.test_base_label.get(MockHost(exists=False)),
                         [])


    def test_get_all_labels(self):
        """Check that we get the expected labels for get_all_labels()."""
        prefix_tbl, full_tbl = self.test_base_label.get_all_labels()
        prefix_tbls, full_tbls = self.test_base_labels.get_all_labels()

        # We want to check that we always get a list of labels regardless if
        # the label class attribute _NAME is a list or a string.
        self.assertEqual(full_tbl, set([self.test_base_label._NAME]))
        self.assertEqual(full_tbls, set(self.test_base_labels._NAME))

        # We want to make sure we get nothing on the prefix_* side of things
        # since BaseLabel shouldn't be a prefix for any label.
        self.assertEqual(prefix_tbl, set())
        self.assertEqual(prefix_tbls, set())


class StringPrefixLabelUnittests(unittest.TestCase):
    """Unittest for testing base_label.StringPrefixLabel."""

    def setUp(self):
        self.postfix_label = 'postfix_label'
        self.test_label = TestStringPrefixLabel(label=self.postfix_label)


    def test_get(self):
        """Let's make sure that the labels we get are prefixed."""
        self.assertEqual(self.test_label.get(None),
                         ['%s:%s' % (self.test_label._NAME,
                                     self.postfix_label)])


    def test_get_all_labels(self):
        """Check that we only get prefix labels and no full labels."""
        prefix_labels, postfix_labels = self.test_label.get_all_labels()
        self.assertEqual(prefix_labels, set(['%s:' % self.test_label._NAME]))
        self.assertEqual(postfix_labels, set())


class LabelRetrieverUnittests(unittest.TestCase):
    """Unittest for testing base_label.LabelRetriever."""

    def setUp(self):
        label_list = [TestStringPrefixLabel(), TestBaseLabel()]
        self.retriever = base_label.LabelRetriever(label_list)
        self.retriever._populate_known_labels(label_list)


    def test_populate_known_labels(self):
        """Check that _populate_known_labels() works as expected."""
        full_names = set([TestBaseLabel._NAME])
        prefix_names = set(['%s:' % TestStringPrefixLabel._NAME])
        # Check on a normal retriever.
        self.assertEqual(self.retriever.label_full_names, full_names)
        self.assertEqual(self.retriever.label_prefix_names, prefix_names)


    def test_is_known_label(self):
        """Check _is_known_label() detects/skips the right labels."""
        # This will be a list of tuples of label and expected return bool.
        # Make sure Full matches match correctly
        labels_to_check = [(TestBaseLabel._NAME, True),
                           ('%s:' % TestStringPrefixLabel._NAME, True),
                           # Make sure partial matches fail.
                           (TestBaseLabel._NAME[:2], False),
                           ('%s:' % TestStringPrefixLabel._NAME[:2], False),
                           ('no_label_match', False)]

        for label, expected_known in labels_to_check:
            self.assertEqual(self.retriever._is_known_label(label),
                             expected_known)


    def test_update_labels(self):
        """Check that we add/remove the expected labels in update_labels()."""
        label_to_add = 'label_to_add'
        label_to_remove = 'prefix:label_to_remove'
        store = host_info.InMemoryHostInfoStore(
                info=host_info.HostInfo(
                        labels=[label_to_remove, TestBaseLabel._NAME],
                ),
        )
        mockhost = MockHost(store=store)

        retriever = base_label.LabelRetriever(
                [TestStringPrefixLabel(label=label_to_add),
                 TestBaseLabel()])
        retriever.update_labels(mockhost)
        self.assertEqual(
                set(store.get().labels),
                {'%s:%s' % (TestStringPrefixLabel._NAME, label_to_add),
                 TestBaseLabel._NAME},
        )


if __name__ == '__main__':
    unittest.main()