#
# Copyright (C) 2016 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import socket
import unittest
import logging
import errno
from socket import error as socket_error

from vts.runners.host import errors
from vts.proto import AndroidSystemControlMessage_pb2 as SysMsg_pb2
from vts.runners.host.tcp_server import callback_server

# Default host/port pair for the server under test.
# NOTE(review): HOST and PORT are not referenced by the tests in this module.
HOST = "localhost"
PORT = 0

# Port used by the error-case test; the test expects connections to it to be
# refused, i.e. nothing should be listening there.
ERROR_PORT = 380


class TestMethods(unittest.TestCase):
    """Unit tests for callback_server.CallbackServer.

    The common scenarios are testing whether we receive the expected data
    from the server, and whether we get the correct error when we try to
    connect to the server's host on a wrong port.

    Attributes:
        _callback_server: an instance of CallbackServer that is used to
                         start and stop the TCP server.
        _counter: This is used to keep track of number of calls made to the
                  callback function.
    """
    _callback_server = None
    _counter = 0

    def setUp(self):
        """Creates a fresh CallbackServer and starts it before each test."""
        self._callback_server = callback_server.CallbackServer()
        self._callback_server.Start()

    def tearDown(self):
        """Stops the server started in setUp.

        This function calls callback_server.CallbackServer.Stop, which
        shuts down the server.
        """
        self._callback_server.Stop()

    def DoErrorCase(self):
        """Connects to the server's host on ERROR_PORT, expecting a refusal.

        This function tests the case that throws an exception, i.e. sending
        a request to ERROR_PORT, on which nothing is expected to listen.

        Raises:
            errors.ConnectionRefusedError: the connection was refused
                (errno ECONNREFUSED); this is the expected, passing outcome
                checked by testDoErrorCase().
            socket.error: any other socket error is re-raised unchanged,
                which makes the test fail.
        """
        host = self._callback_server.ip

        # Create a socket (SOCK_STREAM means a TCP socket)
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

        try:
            # Connect to server; this should result in Connection refused error
            sock.connect((host, ERROR_PORT))
        except socket_error as e:
            # Compare the errno we got against the one we expect.
            if e.errno == errno.ECONNREFUSED:
                raise errors.ConnectionRefusedError  # Test is a success here
            # Any other error fails the test. A bare ``raise`` preserves the
            # original traceback (``raise e`` resets it on Python 2).
            raise
        finally:
            sock.close()

    def ConnectToServer(self, func_id):
        """Connects to the TCP server, sends a request and reads the reply.

        The request is framed as the decimal length of the serialized
        message, a newline, and then the serialized message itself.

        Args:
            func_id: This is the unique key corresponding to a function and
                also the id field of the request_message that we send to the
                server.

        Returns:
            response_message: the AndroidSystemCallbackResponseMessage parsed
                from the bytes the TCP host returns.

        Raises:
            TcpServerConnectionError: a socket error occurred while
                connecting to, sending to, or receiving from the server.
        """
        # This object is sent to the TCP host
        request_message = SysMsg_pb2.AndroidSystemCallbackRequestMessage()
        request_message.id = func_id

        # The final object that this function returns
        response_message = SysMsg_pb2.AndroidSystemCallbackResponseMessage()

        # Create a socket (SOCK_STREAM means a TCP socket)
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        host = self._callback_server.ip
        port = self._callback_server.port
        logging.info('Sending Request to host %s using port %s', host, port)

        try:
            # Connect to server and send request_message
            sock.connect((host, port))

            message = request_message.SerializeToString()
            # SerializeToString() returns bytes on Python 3, where
            # ``str + bytes`` raises TypeError, so encode the ASCII length
            # header before concatenating. On Python 2 this stays a plain str.
            header = ("%d\n" % len(message)).encode("ascii")
            sock.sendall(header + message)
            logging.info("Sent: %s", message)

            # Receive the response from the server.
            # NOTE(review): a single recv(1024) assumes the whole reply
            # arrives in one read of at most 1024 bytes; confirm against the
            # server's framing.
            received_message = sock.recv(1024)
            response_message.ParseFromString(received_message)
            logging.info('Received: %s', received_message)
        except socket_error as e:
            logging.error(e)
            raise errors.TcpServerConnectionError('Exception occurred.')
        finally:
            sock.close()

        return response_message

    def testDoErrorCase(self):
        """Unit test for error cases."""
        with self.assertRaises(errors.ConnectionRefusedError):
            self.DoErrorCase()

    def testCallback(self):
        """Tests two callback use cases."""
        self.TestNormalCase()
        self.TestDoRegisterCallback()

    def TestNormalCase(self):
        """Tests the normal request to TCPServer.

        This function sends the request to the Tcp server where the request
        should be a success.

        This function also checks the register callback feature by ensuring
        that callback_func() is called and the value of the counter
        (self._counter) is increased by one.
        """
        def callback_func():
            self._counter += 1

        # Function should be registered with RegisterCallback; on the fresh
        # server created in setUp the first id handed out is '0'.
        func_id = self._callback_server.RegisterCallback(callback_func)
        self.assertEqual(func_id, '0')

        # Capture the previous value of the counter
        prev_value = self._counter

        # Connect to server
        response_message = self.ConnectToServer(func_id)

        # Confirm that callback_func() was called, thereby increasing the
        # value of the counter by 1
        self.assertEqual(self._counter, prev_value + 1)

        # Also confirm if query resulted in a success
        self.assertEqual(response_message.response_code, SysMsg_pb2.SUCCESS)

    def TestDoRegisterCallback(self):
        """Checks the register/unregister callback functionality of the Server.

        While the function is registered, a request with its id must call it
        once (counter incremented by 1) and succeed. After it is unregistered,
        a request with the same id must not call it (counter unchanged) and
        must fail.
        """
        def callback_func():
            self._counter += 1

        # Capture the previous value of the counter
        prev_value = self._counter

        # Function should be registered with RegisterCallback
        func_id = self._callback_server.RegisterCallback(callback_func)

        # Looking the function up again must yield the same id.
        found_func_id = self._callback_server.GetCallbackId(callback_func)
        self.assertEqual(func_id, found_func_id)

        # Connect to server
        response_message = self.ConnectToServer(func_id)

        # Confirm that callback_func() was called exactly once.
        self.assertEqual(self._counter, prev_value + 1)

        # Also confirm the request succeeded.
        self.assertEqual(response_message.response_code, SysMsg_pb2.SUCCESS)

        # Now unregister the function with UnregisterCallback and send a
        # request with the same (now stale) id again.
        self._callback_server.UnregisterCallback(func_id)

        # Capture the previous value of the counter
        prev_value = self._counter

        # Connect to server
        response_message = self.ConnectToServer(func_id)

        # Confirm that callback_func() was not called.
        self.assertEqual(self._counter, prev_value)

        # Also confirm the request for the unregistered id failed.
        self.assertEqual(response_message.response_code, SysMsg_pb2.FAIL)

if __name__ == '__main__':
    # Run this module's tests when it is executed directly as a script.
    unittest.main()