# Plain text | 613 lines | 23.15 KB

#
# Copyright (C) 2016 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json
import logging
import os
import socket
import time
import types

from vts.proto import AndroidSystemControlMessage_pb2 as SysMsg_pb2
from vts.proto import ComponentSpecificationMessage_pb2 as CompSpecMsg_pb2
from vts.runners.host import const
from vts.runners.host import errors
from vts.utils.python.mirror import mirror_object

from google.protobuf import text_format

# Default connection endpoint, read from the environment; each is None when
# the corresponding environment variable is not set.
TARGET_IP = os.environ.get("TARGET_IP")
TARGET_PORT = os.environ.get("TARGET_PORT")

# Socket tuning knobs (timeouts in seconds, retry as an attempt count).
_DEFAULT_SOCKET_TIMEOUT_SECS = 1800
_SOCKET_CONN_TIMEOUT_SECS = 60
_SOCKET_CONN_RETRY_NUMBER = 5

# Readable names for command_type values, used when logging sent commands.
COMMAND_TYPE_NAME = dict([
    (1, "LIST_HALS"),
    (2, "SET_HOST_INFO"),
    (101, "CHECK_DRIVER_SERVICE"),
    (102, "LAUNCH_DRIVER_SERVICE"),
    (103, "VTS_AGENT_COMMAND_READ_SPECIFICATION"),
    (201, "LIST_APIS"),
    (202, "CALL_API"),
    (203, "VTS_AGENT_COMMAND_GET_ATTRIBUTE"),
    (301, "VTS_AGENT_COMMAND_EXECUTE_SHELL_COMMAND"),
])


class VtsTcpClient(object):
    """VTS TCP Client class.

    Talks to the VTS agent on a target device over a TCP socket. Every
    request is a serialized AndroidSystemControlCommandMessage and every
    reply a serialized AndroidSystemControlResponseMessage; both are framed
    on the wire as "<decimal payload length>\\n<payload bytes>".

    Attribute:
        connection: a TCP socket instance, or None when disconnected.
        channel: a file object wrapping `connection`, used to write and read
                 framed messages; None when disconnected.
        _mode: the connection mode (adb_forwarding or ssh_tunnel). Stored
               but not read by any method of this class.
        timeout: int, default tcp connection timeout in seconds.
    """

    def __init__(self,
                 mode="adb_forwarding",
                 timeout=_DEFAULT_SOCKET_TIMEOUT_SECS):
        self.connection = None
        self.channel = None
        self._mode = mode
        self.timeout = timeout

    @property
    def timeout(self):
        """Get TCP connection timeout.

        This function assumes the timeout setter has run in __init__ before
        any getter call.

        Returns:
            int, timeout
        """
        return self._timeout

    @timeout.setter
    def timeout(self, timeout):
        """Set TCP connection timeout.

        Args:
            timeout: int, TCP connection timeout in seconds.
        """
        self._timeout = timeout

    @staticmethod
    def _IsSuccess(resp):
        """Returns True iff `resp` is a response whose code is SUCCESS.

        RecvResponse returns None when every read attempt timed out, so a
        None response counts as a failure.
        """
        return resp is not None and resp.response_code == SysMsg_pb2.SUCCESS

    @staticmethod
    def _MergeTextProto(text, message):
        """Merges the text-format proto `text` into `message`.

        A ParseError is logged rather than raised, so `message` may be left
        incomplete.
        """
        try:
            text_format.Merge(text, message)
        except text_format.ParseError as e:
            logging.exception(e)
            logging.error("Paring error\n%s", text)

    @staticmethod
    def _RpcError(arg, resp_code):
        """Logs a failed RPC and returns the exception the caller raises.

        Args:
            arg: the `arg` that was sent with the failed command.
            resp_code: the response code received, or None if no response.

        Returns:
            errors.VtsTcpCommunicationError describing the failure.
        """
        logging.error("NOTICE - Likely a crash discovery!")
        logging.error("SysMsg_pb2.SUCCESS is %s", SysMsg_pb2.SUCCESS)
        return errors.VtsTcpCommunicationError(
            "RPC Error, response code for %s is %s" % (arg, resp_code))

    def _SubmoduleMirror(self, result):
        """Wraps the submodule spec carried by `result` in a MirrorObject."""
        logging.info("returned a submodule spec")
        logging.info("spec: %s", result.return_type_submodule_spec)
        return mirror_object.MirrorObject(
            self, result.return_type_submodule_spec, None)

    def Connect(self,
                ip=TARGET_IP,
                command_port=TARGET_PORT,
                callback_port=None,
                retry=_SOCKET_CONN_RETRY_NUMBER,
                timeout=None):
        """Connects to a target device.

        Args:
            ip: string, the IP address of a target device.
            command_port: int, the TCP port which can be used to connect to
                          a target device.
            callback_port: int, the TCP port number of a host-side callback
                           server. When set, it is sent to the target with a
                           SET_HOST_INFO command after connecting.
            retry: int, the number of connection attempts before giving up.
                   Values below 1 are treated as a single attempt.
            timeout: tcp connection timeout in seconds; self.timeout is used
                     when None.

        Returns:
            True if success. False if command_port is not set, or if the
            SET_HOST_INFO request got no response or a non-SUCCESS one.

        Raises:
            errors.VtsTcpClientCreationError when every connection attempt
            fails.
        """
        if not command_port:
            logging.error("ip %s, command_port %s, callback_port %s invalid",
                          ip, command_port, callback_port)
            return False

        connection_timeout = self._timeout if timeout is None else timeout
        # Always try at least once; otherwise self.connection could still be
        # None when makefile() is called below.
        attempts = max(1, retry)
        for attempt in xrange(attempts):
            try:
                self.connection = socket.create_connection(
                    (ip, command_port), timeout=connection_timeout)
                break
            except socket.error as e:
                logging.exception("Connect failed %s", e)
                if attempt + 1 == attempts:
                    raise errors.VtsTcpClientCreationError(
                        "Couldn't connect to %s:%s" % (ip, command_port))
                # Wait a bit before the next attempt (not after the last).
                time.sleep(1)
        self.channel = self.connection.makefile(mode="brw")

        if callback_port is not None:
            self.SendCommand(
                SysMsg_pb2.SET_HOST_INFO, callback_port=callback_port)
            if not self._IsSuccess(self.RecvResponse()):
                return False
        return True

    def Disconnect(self):
        """Disconnects from the target device.

        Closes the channel file object and then the socket. An error while
        closing the channel is logged so the socket is still closed.

        TODO(yim): Send a msg to the target side to teardown handler session
        and release memory before closing the socket.
        """
        if self.connection is None:
            return
        channel, self.channel = self.channel, None
        if channel is not None:
            try:
                channel.close()
            except (socket.error, IOError) as e:
                logging.warning("Failed to close channel: %s", e)
        self.connection.close()
        self.connection = None

    def ListHals(self, base_paths):
        """RPC to LIST_HALS.

        Returns:
            the response's file_names on success, None otherwise.
        """
        self.SendCommand(SysMsg_pb2.LIST_HALS, paths=base_paths)
        resp = self.RecvResponse()
        if self._IsSuccess(resp):
            return resp.file_names
        return None

    def CheckDriverService(self, service_name):
        """RPC to CHECK_DRIVER_SERVICE.

        Returns:
            True if the target answered SUCCESS, False otherwise.
        """
        self.SendCommand(
            SysMsg_pb2.CHECK_DRIVER_SERVICE, service_name=service_name)
        return self._IsSuccess(self.RecvResponse())

    def LaunchDriverService(self,
                            driver_type,
                            service_name,
                            bits,
                            file_path=None,
                            target_class=None,
                            target_type=None,
                            target_version=None,
                            target_package=None,
                            target_component_name=None,
                            hw_binder_service_name=None):
        """RPC to LAUNCH_DRIVER_SERVICE.

        Returns:
            For HAL driver types (HIDL, conventional, legacy): int(resp.result)
            on success, -1 on failure. For other driver types: True on
            success, False otherwise.
        """
        logging.info("service_name: %s", service_name)
        logging.info("file_path: %s", file_path)
        logging.info("bits: %s", bits)
        logging.info("driver_type: %s", driver_type)
        self.SendCommand(
            SysMsg_pb2.LAUNCH_DRIVER_SERVICE,
            driver_type=driver_type,
            service_name=service_name,
            bits=bits,
            file_path=file_path,
            target_class=target_class,
            target_type=target_type,
            target_version=target_version,
            target_package=target_package,
            target_component_name=target_component_name,
            hw_binder_service_name=hw_binder_service_name)
        resp = self.RecvResponse()
        logging.info("resp for LAUNCH_DRIVER_SERVICE: %s", resp)
        succeeded = self._IsSuccess(resp)
        hal_driver_types = (SysMsg_pb2.VTS_DRIVER_TYPE_HAL_HIDL,
                            SysMsg_pb2.VTS_DRIVER_TYPE_HAL_CONVENTIONAL,
                            SysMsg_pb2.VTS_DRIVER_TYPE_HAL_LEGACY)
        if driver_type in hal_driver_types:
            return int(resp.result) if succeeded else -1
        return succeeded

    def ListApis(self):
        """RPC to LIST_APIS.

        Returns:
            the response's spec on success, None otherwise.
        """
        self.SendCommand(SysMsg_pb2.LIST_APIS)
        resp = self.RecvResponse()
        logging.info("resp for LIST_APIS: %s", resp)
        if self._IsSuccess(resp):
            return resp.spec
        return None

    def _GetPythonDataOfNamedValues(self, values):
        """Converts struct/union member messages into a dict.

        Members with an empty name are keyed "attribute<N>", where N is the
        1-based position of the member among all members.
        """
        result = {}
        for index, value in enumerate(values, 1):
            key = value.name if len(value.name) > 0 else "attribute%d" % index
            result[key] = self.GetPythonDataOfVariableSpecMsg(value)
        return result

    def GetPythonDataOfVariableSpecMsg(self, var_spec_msg):
        """Returns the python native data structure for a given message.

        Args:
            var_spec_msg: VariableSpecificationMessage

        Returns:
            python native data structure (e.g., string, integer, list, dict).
            Structs and unions are both returned as dicts keyed by member
            name. HIDL interfaces are returned as the message itself.

        Raises:
            VtsUnsupportedTypeError if unsupported type is specified (this
                includes a TYPE_SCALAR message without scalar_type).
            VtsMalformedProtoStringError if StringDataValueMessage is
                not populated.
        """
        msg_type = var_spec_msg.type
        if msg_type == CompSpecMsg_pb2.TYPE_SCALAR:
            scalar_type = getattr(var_spec_msg, "scalar_type", "")
            if scalar_type:
                return getattr(var_spec_msg.scalar_value, scalar_type)
        elif msg_type == CompSpecMsg_pb2.TYPE_ENUM:
            scalar_type = getattr(var_spec_msg, "scalar_type", "")
            if scalar_type:
                return getattr(var_spec_msg.scalar_value, scalar_type)
            return var_spec_msg.scalar_value.int32_t
        elif msg_type == CompSpecMsg_pb2.TYPE_STRING:
            if hasattr(var_spec_msg, "string_value"):
                return getattr(var_spec_msg.string_value, "message", "")
            raise errors.VtsMalformedProtoStringError()
        elif msg_type == CompSpecMsg_pb2.TYPE_STRUCT:
            return self._GetPythonDataOfNamedValues(var_spec_msg.struct_value)
        elif msg_type == CompSpecMsg_pb2.TYPE_UNION:
            # Previously built via an undefined VtsReturnValueObject, which
            # raised NameError; unions now mirror the struct representation.
            return self._GetPythonDataOfNamedValues(var_spec_msg.union_value)
        elif msg_type in (CompSpecMsg_pb2.TYPE_VECTOR,
                          CompSpecMsg_pb2.TYPE_ARRAY):
            return [
                self.GetPythonDataOfVariableSpecMsg(vector_value)
                for vector_value in var_spec_msg.vector_value
            ]
        elif msg_type == CompSpecMsg_pb2.TYPE_HIDL_INTERFACE:
            logging.debug("var_spec_msg: %s", var_spec_msg)
            return var_spec_msg

        raise errors.VtsUnsupportedTypeError("unsupported type %s" %
                                             var_spec_msg.type)

    def CallApi(self, arg, caller_uid=None):
        """RPC to CALL_API.

        Args:
            arg: the API call request, forwarded as the command's `arg`.
            caller_uid: optional uid sent as driver_caller_uid.

        Returns:
            A MirrorObject if the API returned a submodule. Otherwise the
            python value(s) of return_type_hidl (a single value, or a list
            when there are several) or, for non-HIDL returns, the parsed
            FunctionSpecificationMessage; paired with
            {"coverage": raw_coverage_data} in a tuple when the message has
            that attribute.

        Raises:
            errors.VtsTcpCommunicationError if no successful response arrives
            or the driver reports "error".
        """
        self.SendCommand(SysMsg_pb2.CALL_API, arg=arg, caller_uid=caller_uid)
        resp = self.RecvResponse()
        resp_code = None if resp is None else resp.response_code
        if resp_code != SysMsg_pb2.SUCCESS:
            raise self._RpcError(arg, resp_code)
        if resp.result == "error":
            raise errors.VtsTcpCommunicationError(
                "API call error by the VTS driver.")

        result = CompSpecMsg_pb2.FunctionSpecificationMessage()
        self._MergeTextProto(resp.result, result)
        if result.return_type.type == CompSpecMsg_pb2.TYPE_SUBMODULE:
            return self._SubmoduleMirror(result)

        logging.info("result: %s", result.return_type_hidl)
        if len(result.return_type_hidl) == 1:
            result_value = self.GetPythonDataOfVariableSpecMsg(
                result.return_type_hidl[0])
        elif len(result.return_type_hidl) > 1:
            result_value = [
                self.GetPythonDataOfVariableSpecMsg(return_type_hidl)
                for return_type_hidl in result.return_type_hidl
            ]
        elif hasattr(result, "return_type"):  # For non-HIDL return value
            result_value = result
        else:
            result_value = None

        # NOTE(review): if raw_coverage_data is a declared field of
        # FunctionSpecificationMessage, hasattr is always True and a tuple
        # is always returned — confirm against the proto definition.
        if hasattr(result, "raw_coverage_data"):
            return result_value, {"coverage": result.raw_coverage_data}
        return result_value

    def GetAttribute(self, arg):
        """RPC to VTS_AGENT_COMMAND_GET_ATTRIBUTE.

        Args:
            arg: the attribute request, forwarded as the command's `arg`.

        Returns:
            A MirrorObject for a submodule attribute, the python value for a
            scalar attribute, otherwise the parsed
            FunctionSpecificationMessage.

        Raises:
            errors.VtsTcpCommunicationError if no successful response arrives
            or the target reports "error".
        """
        self.SendCommand(SysMsg_pb2.VTS_AGENT_COMMAND_GET_ATTRIBUTE, arg=arg)
        resp = self.RecvResponse()
        resp_code = None if resp is None else resp.response_code
        if resp_code != SysMsg_pb2.SUCCESS:
            raise self._RpcError(arg, resp_code)
        if resp.result == "error":
            raise errors.VtsTcpCommunicationError(
                "Get attribute request failed on target.")

        result = CompSpecMsg_pb2.FunctionSpecificationMessage()
        self._MergeTextProto(resp.result, result)
        return_type = result.return_type
        if return_type.type == CompSpecMsg_pb2.TYPE_SUBMODULE:
            return self._SubmoduleMirror(result)
        if return_type.type == CompSpecMsg_pb2.TYPE_SCALAR:
            return getattr(return_type.scalar_value, return_type.scalar_type)
        return result

    def ExecuteShellCommand(self, command, no_except=False):
        """RPC to VTS_AGENT_COMMAND_EXECUTE_SHELL_COMMAND.

        Args:
            command: string or list of string, command to execute on device
            no_except: bool, whether to throw exceptions. If set to True,
                       when exception happens, return code will be -1 and
                       str(err) will be in stderr. Result will maintain the
                       same length as with input command (1 for a single
                       string command).

        Returns:
            dictionary of list, command results that contains stdout,
            stderr, and exit_code.
        """
        if not no_except:
            return self.__ExecuteShellCommand(command)

        try:
            return self.__ExecuteShellCommand(command)
        except Exception as e:
            logging.exception(e)
            # A plain string is one command; len() of it would count chars.
            count = len(command) if isinstance(command, (list, tuple)) else 1
            return {
                const.STDOUT: [""] * count,
                const.STDERR: [str(e)] * count,
                const.EXIT_CODE: [-1] * count
            }

    def __ExecuteShellCommand(self, command):
        """RPC to VTS_AGENT_COMMAND_EXECUTE_SHELL_COMMAND.

        Args:
            command: string or list of string, command to execute on device

        Returns:
            dictionary of list, command results that contains stdout,
            stderr, and exit_code.
        """
        self.SendCommand(
            SysMsg_pb2.VTS_AGENT_COMMAND_EXECUTE_SHELL_COMMAND,
            shell_command=command)
        resp = self.RecvResponse(retries=2)
        logging.debug("resp for VTS_AGENT_COMMAND_EXECUTE_SHELL_COMMAND: %s",
                      resp)

        stdout = []
        stderr = []
        # NOTE(review): on failure exit_code stays the scalar -1, while on
        # success it is resp.exit_code (presumably a repeated field); callers
        # indexing it may need to handle both shapes.
        exit_code = -1

        if not resp:
            msg = "Framework error: TCP client did not receive response from device."
            logging.error(msg)
            stderr = [msg]
        elif resp.response_code != SysMsg_pb2.SUCCESS:
            msg = "Framework error: TCP client received unsuccessful response code."
            logging.error(msg)
            stderr = [msg]
        else:
            stdout = resp.stdout
            stderr = resp.stderr
            exit_code = resp.exit_code

        return {
            const.STDOUT: stdout,
            const.STDERR: stderr,
            const.EXIT_CODE: exit_code
        }

    def Ping(self):
        """RPC to send a PING request.

        Returns:
            True if the agent is alive, False otherwise.
        """
        self.SendCommand(SysMsg_pb2.PING)
        resp = self.RecvResponse()
        logging.info("resp for PING: %s", resp)
        return self._IsSuccess(resp)

    def ReadSpecification(self,
                          interface_name,
                          target_class,
                          target_type,
                          target_version,
                          target_package,
                          recursive=False):
        """RPC to VTS_AGENT_COMMAND_READ_SPECIFICATION.

        Args:
            other args: see SendCommand
            recursive: boolean, set to read the specification(s) imported by
                       this one and merge them into the result. Imports are
                       followed one level deep only, and
                       "android.hidl.base@1.0::types" is skipped. When
                       target_class / target_type is None, the imported
                       specs are requested with the component_class /
                       component_type of this specification.

        Returns:
            ComponentSpecificationMessage parsed from the response.

        Raises:
            errors.VtsTcpCommunicationError if no response arrives or the
            driver reports "error".
        """
        self.SendCommand(
            SysMsg_pb2.VTS_AGENT_COMMAND_READ_SPECIFICATION,
            service_name=interface_name,
            target_class=target_class,
            target_type=target_type,
            target_version=target_version,
            target_package=target_package)
        resp = self.RecvResponse(retries=2)
        logging.info("resp for VTS_AGENT_COMMAND_EXECUTE_READ_INTERFACE: %s",
                     resp)
        if resp is None:
            raise errors.VtsTcpCommunicationError(
                "No response for VTS_AGENT_COMMAND_READ_SPECIFICATION.")
        logging.info("proto: %s", resp.result)
        if resp.result == "error":
            raise errors.VtsTcpCommunicationError(
                "API call error by the VTS driver.")
        result = CompSpecMsg_pb2.ComponentSpecificationMessage()
        self._MergeTextProto(resp.result, result)

        # "import" is a Python keyword, hence getattr/hasattr.
        if recursive and hasattr(result, "import"):
            for imported_interface in getattr(result, "import"):
                if imported_interface == "android.hidl.base@1.0::types":
                    logging.warning(
                        "import android.hidl.base@1.0::types skipped")
                    continue
                # Imports look like "<package>@<version>::<name>".
                imported_name = imported_interface.split("::")[1]
                imported_version = float(
                    imported_interface.split("@")[1].split("::")[0])
                imported_package = imported_interface.split("@")[0]
                # TODO(yim): derive target_class and target_type from package
                # path or remove them. The original code read an undefined
                # `msg` here (NameError); use the spec just parsed instead.
                imported_result = self.ReadSpecification(
                    imported_name,
                    result.component_class
                    if target_class is None else target_class,
                    result.component_type
                    if target_type is None else target_type,
                    imported_version,
                    imported_package)
                result.MergeFrom(imported_result)

        return result

    def SendCommand(self,
                    command_type,
                    paths=None,
                    file_path=None,
                    bits=None,
                    target_class=None,
                    target_type=None,
                    target_version=None,
                    target_package=None,
                    target_component_name=None,
                    hw_binder_service_name=None,
                    module_name=None,
                    service_name=None,
                    callback_port=None,
                    driver_type=None,
                    shell_command=None,
                    caller_uid=None,
                    arg=None):
        """Sends a command.

        Args:
            command_type: integer, the command type.
            each of the other args are to fill in a field in
            AndroidSystemControlCommandMessage; None leaves a field unset.
            target_version is sent as round(target_version * 100), e.g. 1.0
            becomes 100.

        Raises:
            errors.VtsTcpCommunicationError if not connected.
        """
        if not self.channel:
            raise errors.VtsTcpCommunicationError(
                "channel is None, unable to send command.")

        command_msg = SysMsg_pb2.AndroidSystemControlCommandMessage()
        command_msg.command_type = command_type
        # Not every command type (e.g. PING) has a name in COMMAND_TYPE_NAME;
        # log the raw value instead of raising KeyError.
        logging.info("sending a command (type %s)",
                     COMMAND_TYPE_NAME.get(command_type, command_type))
        if command_type == SysMsg_pb2.CALL_API:
            logging.info("target API: %s", arg)

        if target_class is not None:
            command_msg.target_class = target_class

        if target_type is not None:
            command_msg.target_type = target_type

        if target_version is not None:
            # Round before truncating: e.g. 4.1 * 100 == 409.99999999999994,
            # which int() alone would turn into 409 instead of 410.
            command_msg.target_version = int(
                round(float(target_version) * 100))

        if target_package is not None:
            command_msg.target_package = target_package

        if target_component_name is not None:
            command_msg.target_component_name = target_component_name

        if hw_binder_service_name is not None:
            command_msg.hw_binder_service_name = hw_binder_service_name

        if module_name is not None:
            command_msg.module_name = module_name

        if service_name is not None:
            command_msg.service_name = service_name

        if driver_type is not None:
            command_msg.driver_type = driver_type

        if paths is not None:
            command_msg.paths.extend(paths)

        if file_path is not None:
            command_msg.file_path = file_path

        if bits is not None:
            command_msg.bits = bits

        if callback_port is not None:
            command_msg.callback_port = callback_port

        if caller_uid is not None:
            command_msg.driver_caller_uid = caller_uid

        if arg is not None:
            command_msg.arg = arg

        if shell_command is not None:
            if isinstance(shell_command, list):
                command_msg.shell_command.extend(shell_command)
            else:
                command_msg.shell_command.append(shell_command)

        logging.info("command %s", command_msg)
        message = command_msg.SerializeToString()
        message_len = len(message)
        logging.debug("sending %d bytes", message_len)
        # Frame: decimal payload length, newline, then the payload.
        self.channel.write(str(message_len) + b'\n')
        self.channel.write(message)
        self.channel.flush()

    def RecvResponse(self, retries=0):
        """Receives and parses the response, and returns the relevant ResponseMessage.

        Reads one "<length>\\n<payload>" frame from the channel.

        Args:
            retries: an integer indicating the max number of retries in case of
                     session timeout error.

        Returns:
            AndroidSystemControlResponseMessage, or None if every attempt
            raised socket.timeout.
        """
        for index in xrange(1 + retries):
            try:
                if index != 0:
                    logging.info("retrying...")
                header = self.channel.readline().strip("\n")
                # An empty header line is treated as a zero-length payload.
                length = int(header) if header else 0
                logging.info("resp %d bytes", length)
                data = self.channel.read(length)
                response_msg = SysMsg_pb2.AndroidSystemControlResponseMessage()
                response_msg.ParseFromString(data)
                logging.debug("Response %s", "success" if
                              response_msg.response_code == SysMsg_pb2.SUCCESS
                              else "fail")
                return response_msg
            except socket.timeout as e:
                logging.exception(e)
        return None