#!/usr/bin/env python
#
# Copyright 2016 - The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for acloud.internal.lib.utils."""

import errno
import getpass
import grp
import os
import shutil
import subprocess
import tempfile
import time
import unittest

import mock

from acloud import errors
from acloud.internal.lib import driver_test_lib
from acloud.internal.lib import utils

# Tkinter may not be supported so mock it out.
try:
    import Tkinter
except ImportError:
    Tkinter = mock.Mock()


class FakeTkinter(object):
    """Fake implementation of Tkinter.Tk() reporting a fixed screen size."""

    def __init__(self, width=None, height=None):
        self.width = width
        self.height = height

    # pylint: disable=invalid-name
    def winfo_screenheight(self):
        """Return the screen height."""
        return self.height

    # pylint: disable=invalid-name
    def winfo_screenwidth(self):
        """Return the screen width."""
        return self.width


# pylint: disable=too-many-public-methods
class UtilsTest(driver_test_lib.BaseDriverTest):
    """Test Utils.

    unittest's default loader only collects methods whose names start with
    the lowercase prefix "test", so every test method below uses it.
    Module-level functions are always replaced through self.Patch so the
    patch is undone after each test instead of leaking into later tests.
    """

    def testTempDirSuccess(self):
        """Test create a temp dir."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        self.Patch(shutil, "rmtree")
        with utils.TempDir():
            pass
        # Verify.
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirExceptionRaised(self):
        """Test create a temp dir and exception is raised within with-clause."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        self.Patch(shutil, "rmtree")

        class ExpectedException(Exception):
            """Expected exception."""

        def _Call():
            with utils.TempDir():
                raise ExpectedException("Expected exception.")

        # Verify. ExpectedException should be raised.
        self.assertRaises(ExpectedException, _Call)
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirWhenDeleteTempDirNoLongerExist(self):  # pylint: disable=invalid-name
        """Test create a temp dir and dir no longer exists during deletion."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = EnvironmentError()
        expected_error.errno = errno.ENOENT
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        def _Call():
            with utils.TempDir():
                pass

        # Verify no exception should be raised when rmtree raises
        # EnvironmentError with errno.ENOENT, i.e.
        # directory no longer exists.
        _Call()
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirWhenDeleteEncounterError(self):
        """Test create a temp dir and encountered error during deletion."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = OSError("Expected OS Error")
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        def _Call():
            with utils.TempDir():
                pass

        # Verify OSError should be raised.
        self.assertRaises(OSError, _Call)
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirOriginalErrorRaised(self):
        """Test original error is raised even if tmp dir deletion failed."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = OSError("Expected OS Error")
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        class ExpectedException(Exception):
            """Expected exception."""

        def _Call():
            with utils.TempDir():
                raise ExpectedException("Expected Exception")

        # Verify.
        # ExpectedException should be raised, and OSError
        # should not be raised.
        self.assertRaises(ExpectedException, _Call)
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testCreateSshKeyPairKeyAlreadyExists(self):  # pylint: disable=invalid-name
        """Test that no ssh-keygen call is made when the key pair exists."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", side_effect=[True, True])
        self.Patch(subprocess, "check_call")
        self.Patch(os, "makedirs", return_value=True)
        utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_call.call_count, 0)  # pylint: disable=no-member

    def testCreateSshKeyPairKeyAreCreated(self):
        """Test that ssh-keygen is invoked once when no key exists."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", return_value=False)
        self.Patch(os, "makedirs", return_value=True)
        self.Patch(subprocess, "check_call")
        self.Patch(os, "rename")
        utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_call.call_count, 1)  # pylint: disable=no-member
        subprocess.check_call.assert_called_with(  # pylint: disable=no-member
            utils.SSH_KEYGEN_CMD +
            ["-C", getpass.getuser(), "-f", private_key],
            stdout=mock.ANY,
            stderr=mock.ANY)

    def testCreatePublicKeyAreCreated(self):
        """Test that the public key is regenerated from the private key."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", side_effect=[False, True, True])
        self.Patch(os, "makedirs", return_value=True)
        mock_open = mock.mock_open(read_data=public_key)
        self.Patch(subprocess, "check_output")
        self.Patch(os, "rename")
        with mock.patch("__builtin__.open", mock_open):
            utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_output.call_count, 1)  # pylint: disable=no-member
        subprocess.check_output.assert_called_with(  # pylint: disable=no-member
            utils.SSH_KEYGEN_PUB_CMD + ["-f", private_key])

    def testRetryOnException(self):
        """Test RetryOnException retries num_retry times before re-raising."""

        def _IsValueError(exc):
            return isinstance(exc, ValueError)

        num_retry = 5

        @utils.RetryOnException(_IsValueError, num_retry)
        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        sentinel = mock.MagicMock()
        self.assertRaises(ValueError, _RaiseAndRetry, sentinel)
        # One initial attempt plus num_retry retries.
        self.assertEqual(1 + num_retry, sentinel.alert.call_count)

    def testRetryExceptionType(self):
        """Test RetryExceptionType function."""

        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        num_retry = 5
        sentinel = mock.MagicMock()
        self.assertRaises(
            ValueError,
            utils.RetryExceptionType,
            (KeyError, ValueError),
            num_retry,
            _RaiseAndRetry,
            0,  # sleep_multiplier
            1,  # retry_backoff_factor
            sentinel=sentinel)
        self.assertEqual(1 + num_retry, sentinel.alert.call_count)

    def testRetry(self):
        """Test Retry sleeps with an exponential backoff between attempts."""
        mock_sleep = self.Patch(time, "sleep")

        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        num_retry = 5
        sentinel = mock.MagicMock()
        self.assertRaises(
            ValueError,
            utils.RetryExceptionType,
            (ValueError, KeyError),
            num_retry,
            _RaiseAndRetry,
            1,  # sleep_multiplier
            2,  # retry_backoff_factor
            sentinel=sentinel)

        self.assertEqual(1 + num_retry, sentinel.alert.call_count)
        mock_sleep.assert_has_calls(
            [
                mock.call(1),
                mock.call(2),
                mock.call(4),
                mock.call(8),
                mock.call(16)
            ])

    @mock.patch("__builtin__.raw_input")
    def testGetAnswerFromList(self, mock_raw_input):
        """Test GetAnswerFromList."""
        answer_list = ["image1.zip", "image2.zip", "image3.zip"]
        # An answer of 0 is expected to make GetAnswerFromList exit.
        mock_raw_input.return_value = 0
        with self.assertRaises(SystemExit):
            utils.GetAnswerFromList(answer_list)
        mock_raw_input.side_effect = [1, 2, 3, 4]
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image1.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image2.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image3.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list,
                                                 enable_choose_all=True),
                         answer_list)

    @unittest.skipIf(isinstance(Tkinter, mock.Mock),
                     "Tkinter mocked out, test case not needed.")
    @mock.patch.object(Tkinter, "Tk")
    def testCalculateVNCScreenRatio(self, mock_tk):
        """Test Calculating the scale ratio of VNC display."""
        # Get scale-down ratio if screen height is smaller than AVD height.
        mock_tk.return_value = FakeTkinter(height=800, width=1200)
        avd_h = 1920
        avd_w = 1080
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.4)

        # Get scale-down ratio if screen width is smaller than AVD width.
        mock_tk.return_value = FakeTkinter(height=800, width=1200)
        avd_h = 900
        avd_w = 1920
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.6)

        # Scale ratio = 1 if screen is larger than AVD.
        mock_tk.return_value = FakeTkinter(height=1080, width=1920)
        avd_h = 800
        avd_w = 1280
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 1)

        # Get the scale if ratio of width is smaller than the
        # ratio of height.
        mock_tk.return_value = FakeTkinter(height=1200, width=800)
        avd_h = 1920
        avd_w = 1080
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.6)

    def testCheckUserInGroups(self):
        """Test CheckUserInGroups."""
        self.Patch(os, "getgroups", return_value=[1, 2, 3])
        gr1 = mock.MagicMock()
        gr1.gr_name = "fake_gr_1"
        gr2 = mock.MagicMock()
        gr2.gr_name = "fake_gr_2"
        gr3 = mock.MagicMock()
        gr3.gr_name = "fake_gr_3"
        self.Patch(grp, "getgrgid", side_effect=[gr1, gr2, gr3])

        # User in all required groups should return true.
        self.assertTrue(
            utils.CheckUserInGroups(
                ["fake_gr_1", "fake_gr_2"]))

        # User not in all required groups should return False.
        # side_effect is an iterator consumed by the first call, so re-arm it.
        self.Patch(grp, "getgrgid", side_effect=[gr1, gr2, gr3])
        self.assertFalse(
            utils.CheckUserInGroups(
                ["fake_gr_1", "fake_gr_4"]))

    @mock.patch.object(utils, "CheckUserInGroups")
    def testAddUserGroupsToCmd(self, mock_user_group):
        """Test AddUserGroupsToCmd."""
        command = "test_command"
        groups = ["group1", "group2"]
        # Don't add user group in command
        mock_user_group.return_value = True
        expected_value = "test_command"
        self.assertEqual(expected_value, utils.AddUserGroupsToCmd(command,
                                                                  groups))

        # Add user group in command
        mock_user_group.return_value = False
        expected_value = "sg group1 <<EOF\nsg group2\ntest_command\nEOF"
        self.assertEqual(expected_value, utils.AddUserGroupsToCmd(command,
                                                                  groups))

    def testScpPullFileSuccess(self):
        """Test scp pull file successfully."""
        mock_check_call = self.Patch(subprocess, "check_call")
        utils.ScpPullFile("/tmp/test", "/tmp/test_1.log", "192.168.0.1")
        mock_check_call.assert_called_with(utils.SCP_CMD + [
            "192.168.0.1:/tmp/test", "/tmp/test_1.log"])

    def testScpPullFileWithUserNameSuccess(self):
        """Test scp pull file successfully with a remote user name."""
        mock_check_call = self.Patch(subprocess, "check_call")
        utils.ScpPullFile("/tmp/test", "/tmp/test_1.log", "192.168.0.1",
                          user_name="abc")
        mock_check_call.assert_called_with(utils.SCP_CMD + [
            "abc@192.168.0.1:/tmp/test", "/tmp/test_1.log"])

    # pylint: disable=invalid-name
    def testScpPullFileWithUserNameWithRsaKeySuccess(self):
        """Test scp pull file successfully with a user name and RSA key."""
        mock_check_call = self.Patch(subprocess, "check_call")
        utils.ScpPullFile("/tmp/test", "/tmp/test_1.log", "192.168.0.1",
                          user_name="abc", rsa_key_file="/tmp/my_key")
        mock_check_call.assert_called_with(utils.SCP_CMD + [
            "-i", "/tmp/my_key", "abc@192.168.0.1:/tmp/test",
            "/tmp/test_1.log"])

    def testScpPullFileScpFailure(self):
        """Test scp pull file failure raises DeviceConnectionError."""
        self.Patch(
            subprocess, "check_call",
            side_effect=subprocess.CalledProcessError(123, "fake",
                                                      "fake error"))
        self.assertRaises(
            errors.DeviceConnectionError,
            utils.ScpPullFile,
            "/tmp/test",
            "/tmp/test_1.log",
            "192.168.0.1")

    def testTimeoutException(self):
        """Test TimeoutException."""

        @utils.TimeoutException(1, "should time out")
        def functionThatWillTimeOut():
            """Test decorator of @utils.TimeoutException should timeout."""
            time.sleep(5)

        self.assertRaises(errors.FunctionTimeoutError,
                          functionThatWillTimeOut)

    def testTimeoutExceptionNoTimeout(self):
        """Test No TimeoutException."""

        @utils.TimeoutException(5, "shouldn't time out")
        def functionThatShouldNotTimeout():
            """Test decorator of @utils.TimeoutException shouldn't timeout."""
            return None

        try:
            functionThatShouldNotTimeout()
        except errors.FunctionTimeoutError:
            self.fail("shouldn't timeout")

    def testAutoConnectCreateSSHTunnelFail(self):
        """Test AutoConnect returns empty ports when the SSH tunnel fails."""
        fake_ip_addr = "1.1.1.1"
        fake_rsa_key_file = "/tmp/rsa_file"
        fake_target_vnc_port = 8888
        target_adb_port = 9999
        ssh_user = "fake_user"
        call_side_effect = subprocess.CalledProcessError(123, "fake",
                                                         "fake error")
        result = utils.ForwardedPorts(vnc_port=None, adb_port=None)
        self.Patch(subprocess, "check_call", side_effect=call_side_effect)
        self.assertEqual(result, utils.AutoConnect(fake_ip_addr,
                                                   fake_rsa_key_file,
                                                   fake_target_vnc_port,
                                                   target_adb_port,
                                                   ssh_user))


if __name__ == "__main__":
    unittest.main()