#
# Copyright (C) 2016 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import socket
import socketserver
import threading

from vts.runners.host import errors
from vts.proto import AndroidSystemControlMessage_pb2 as SysMsg
from vts.proto import ComponentSpecificationMessage_pb2 as CompSpecMsg
from vts.utils.python.mirror import pb2py

# Registry of callbacks: maps a string callback ID (as returned by
# CallbackServer.RegisterCallback) to the registered callable.
_functions = dict()


class CallbackServerError(errors.VtsError):
    """Raised when an error occurs in VTS TCP server."""


class CallbackRequestHandler(socketserver.StreamRequestHandler):
    """The request handler class for our server."""

    def handle(self):
        """Receives requests from clients.

        When a callback happens on the target side, a request message is
        posted to the host side and is handled here. The message is parsed and
        the appropriate callback function on the host side is called.

        Wire format, as read below: one line holding the decimal byte length
        of the payload, followed by exactly that many bytes of a serialized
        AndroidSystemCallbackRequestMessage. The reply is the serialized
        AndroidSystemCallbackResponseMessage, written back with no length
        prefix.

        Raises:
            ValueError: if the header line is non-empty but not an integer.
                An empty header is logged and the request is skipped.
        """
        header = self.rfile.readline().strip()
        try:
            msg_len = int(header)
        except ValueError:
            if header:
                logging.exception(
                    "Unable to convert '%s' into an integer, which "
                    "is required for reading the next message.", header)
                raise
            logging.error(
                'CallbackRequestHandler received empty message header. Skipping...')
            return

        # Read the request message.
        received_data = self.rfile.read(msg_len)
        logging.debug("Received callback message: %s", received_data)
        request_message = SysMsg.AndroidSystemCallbackRequestMessage()
        request_message.ParseFromString(received_data)
        logging.debug('Handling callback ID: %s', request_message.id)

        response_message = SysMsg.AndroidSystemCallbackResponseMessage()
        # Call the appropriate callback function and construct the response
        # message.
        # NOTE(review): if the callback itself raises, the exception escapes
        # handle() and no response is sent for this request.
        if request_message.id in _functions:
            args = tuple(
                [pb2py.Convert(arg) for arg in request_message.arg])
            _functions[request_message.id](*args)
            response_message.response_code = SysMsg.SUCCESS
        else:
            logging.error("Callback function ID %s is not registered!",
                          request_message.id)
            response_message.response_code = SysMsg.FAIL

        # Send the response back to the client.
        message = response_message.SerializeToString()
        # self.request is the TCP socket connected to the client.
        self.request.sendall(message)


class CallbackServer(object):
    """Runs a socketserver.TCPServer on a separate daemon thread.

    Attributes:
        _server: an instance of socketserver.TCPServer, or None when the
                 server is not running.
        _port: the port number the server is bound to (0 until started).
        _ip: the IP address the server is bound to ("" until started).
        _hostname: host name / address the server binds to.
    """

    def __init__(self):
        self._server = None
        self._port = 0  # Port 0 means to select an arbitrary unused port
        self._ip = ""  # Used to store the IP address for the server
        self._hostname = "localhost"  # IP address to which initial connection is made

    def RegisterCallback(self, callback_func):
        """Registers a callback function.

        IDs are assigned as one more than the largest ID currently in the
        registry, starting at "0".

        Args:
            callback_func: The function to register.

        Returns:
            string, Id of the registered callback function.

        Raises:
            CallbackServerError is raised if callback_func is already
            registered.
        """
        if self.GetCallbackId(callback_func) is not None:
            raise CallbackServerError("Function is already registered")
        next_id = 0
        if _functions:
            next_id = max(int(func_id) for func_id in _functions) + 1
        _functions[str(next_id)] = callback_func
        return str(next_id)

    def UnregisterCallback(self, func_id):
        """Removes a callback function from the registry.

        Args:
            func_id: string, the ID returned by RegisterCallback.

        Raises:
            CallbackServerError is raised if the func_id is not registered.
        """
        try:
            _functions.pop(func_id)
        except KeyError:
            raise CallbackServerError(
                "Can't remove function ID '%s', which is not registered." %
                func_id)

    def GetCallbackId(self, callback_func):
        """Gets the ID of a registered callback function.

        Args:
            callback_func: The function to look up.

        Returns:
            string, Id of the callback function if found, None otherwise.
        """
        # The registry is keyed by ID, so search its values. `==` (rather
        # than `is`) lets an equivalent bound method match too.
        for func_id, func in _functions.items():
            if func == callback_func:
                return func_id
        return None

    def Start(self, port=0):
        """Starts the server.

        Args:
            port: integer, number of the port on which the server listens.
                  Default is 0, which means auto-select a port available.

        Returns:
            tuple (IP Address, port number) the server is bound to.

        Raises:
            CallbackServerError is raised if the server fails to start.
        """
        server = None
        try:
            server = socketserver.TCPServer(
                (self._hostname, port), CallbackRequestHandler)
            ip, bound_port = server.server_address
            # serve_forever runs on a daemon thread so it does not keep the
            # process alive. Plain TCPServer (no ThreadingMixIn) handles
            # requests one at a time on that single thread.
            server_thread = threading.Thread(target=server.serve_forever)
            server_thread.daemon = True
            server_thread.start()
        except (RuntimeError, IOError, socket.error) as e:
            logging.exception(e)
            # Release the socket; the server must not be published, since
            # calling shutdown() on a server whose serve_forever() never ran
            # would block forever.
            if server is not None:
                server.server_close()
            raise CallbackServerError(
                'Failed to start CallbackServer on (%s:%s).' %
                (self._hostname, port)) from e

        self._server = server
        self._ip, self._port = ip, bound_port
        logging.info('TcpServer %s started (%s:%s)', server_thread.name,
                     self._ip, self._port)
        return self._ip, self._port

    def Stop(self):
        """Stops the server.

        Does nothing if the server is not running (Start() was never called,
        failed, or Stop() was already called).
        """
        if self._server is None:
            return
        self._server.shutdown()
        self._server.server_close()
        self._server = None

    @property
    def ip(self):
        """The IP address the server is bound to ("" before Start)."""
        return self._ip

    @property
    def port(self):
        """The port the server is bound to (0 before Start)."""
        return self._port