# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementations of interoperability test methods."""

import collections
import enum
import json
import os
import threading

from google import auth as google_auth
from google.auth import environment_vars as google_auth_environment_vars
from google.auth.transport import grpc as google_auth_transport_grpc
from google.auth.transport import requests as google_auth_transport_requests
import grpc
from grpc.beta import implementations

from src.proto.grpc.testing import empty_pb2
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc

_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial"
_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin"


def _maybe_echo_metadata(servicer_context):
    """Echoes the test metadata entries of the request back to the caller.

    If the invocation metadata contains _INITIAL_METADATA_KEY, its value is
    sent back as initial metadata; if it contains _TRAILING_METADATA_KEY, its
    value is set as trailing metadata. Absent keys are ignored.

    Args:
      servicer_context: The servicer context of the RPC being handled.
    """
    received = dict(servicer_context.invocation_metadata())

    if _INITIAL_METADATA_KEY in received:
        echoed_initial = ((_INITIAL_METADATA_KEY,
                           received[_INITIAL_METADATA_KEY]),)
        servicer_context.send_initial_metadata(echoed_initial)

    if _TRAILING_METADATA_KEY in received:
        echoed_trailing = ((_TRAILING_METADATA_KEY,
                            received[_TRAILING_METADATA_KEY]),)
        servicer_context.set_trailing_metadata(echoed_trailing)


def _maybe_echo_status_and_message(request, servicer_context):
    """Applies the status requested by the client to the response context.

    Args:
      request: A request message; when its `response_status` field is set,
        that field's `code` and `message` are applied to the context.
      servicer_context: The servicer context whose code and details are set.
    """
    if not request.HasField('response_status'):
        return
    requested_status = request.response_status
    servicer_context.set_code(requested_status.code)
    servicer_context.set_details(requested_status.message)


class TestService(test_pb2_grpc.TestServiceServicer):
    """Server-side implementation of the interoperability TestService."""

    def EmptyCall(self, request, context):
        """Echoes test metadata and returns an empty message."""
        _maybe_echo_metadata(context)
        return empty_pb2.Empty()

    def UnaryCall(self, request, context):
        """Returns a zero-filled payload of `request.response_size` bytes.

        Test metadata and any requested status are echoed first.
        """
        _maybe_echo_metadata(context)
        _maybe_echo_status_and_message(request, context)
        zero_body = b'\x00' * request.response_size
        response_payload = messages_pb2.Payload(
            type=messages_pb2.COMPRESSABLE, body=zero_body)
        return messages_pb2.SimpleResponse(payload=response_payload)

    def StreamingOutputCall(self, request, context):
        """Yields one zero-filled response per requested response parameter."""
        _maybe_echo_status_and_message(request, context)
        for parameters in request.response_parameters:
            zero_body = b'\x00' * parameters.size
            response_payload = messages_pb2.Payload(
                type=request.response_type, body=zero_body)
            yield messages_pb2.StreamingOutputCallResponse(
                payload=response_payload)

    def StreamingInputCall(self, request_iterator, context):
        """Returns the total size of all payload bodies received."""
        total_size = sum(
            len(request.payload.body)
            for request in request_iterator
            if request.payload is not None and request.payload.body)
        return messages_pb2.StreamingInputCallResponse(
            aggregated_payload_size=total_size)

    def FullDuplexCall(self, request_iterator, context):
        """Answers every incoming request with its requested responses.

        Test metadata is echoed once, when iteration starts; a requested
        status is applied for each request as it arrives.
        """
        _maybe_echo_metadata(context)
        for request in request_iterator:
            _maybe_echo_status_and_message(request, context)
            for parameters in request.response_parameters:
                zero_body = b'\x00' * parameters.size
                response_payload = messages_pb2.Payload(
                    type=request.payload.type, body=zero_body)
                yield messages_pb2.StreamingOutputCallResponse(
                    payload=response_payload)

    # NOTE(nathaniel): Apparently this is the same as the full-duplex call?
    # NOTE(atash): It isn't even called in the interop spec (Oct 22 2015)...
    def HalfDuplexCall(self, request_iterator, context):
        """Delegates to FullDuplexCall."""
        return self.FullDuplexCall(request_iterator, context)


def _expect_status_code(call, expected_code):
    """Checks that `call` finished with the expected status code.

    Args:
      call: An object exposing a `code()` method.
      expected_code: The status code the call is required to report.

    Raises:
      ValueError: If `call.code()` differs from `expected_code`.
    """
    if call.code() == expected_code:
        return
    message = 'expected code %s, got %s' % (expected_code, call.code())
    raise ValueError(message)


def _expect_status_details(call, expected_details):
    """Checks that `call` finished with the expected status details.

    Args:
      call: An object exposing a `details()` method.
      expected_details: The details string the call is required to report.

    Raises:
      ValueError: If `call.details()` differs from `expected_details`.
    """
    if call.details() == expected_details:
        return
    message = 'expected message %s, got %s' % (expected_details,
                                               call.details())
    raise ValueError(message)


def _validate_status_code_and_details(call, expected_code, expected_details):
    """Checks both the status code and the status details of `call`.

    The code is checked first, so a code mismatch is reported before any
    details mismatch.

    Args:
      call: An object exposing `code()` and `details()` methods.
      expected_code: The status code the call is required to report.
      expected_details: The details string the call is required to report.

    Raises:
      ValueError: If either the code or the details do not match.
    """
    if call.code() != expected_code:
        raise ValueError(
            'expected code %s, got %s' % (expected_code, call.code()))
    if call.details() != expected_details:
        raise ValueError(
            'expected message %s, got %s' % (expected_details,
                                             call.details()))


def _validate_payload_type_and_length(response, expected_type, expected_length):
    """Checks the payload type and body size of a response message.

    Args:
      response: A message with a `payload` that has `type` and `body`.
      expected_type: The payload type the response is required to carry.
      expected_length: The required length of the payload body.

    Raises:
      ValueError: If the payload type or the body length does not match. The
        type is checked first.
    """
    actual_type = response.payload.type
    # Payload types are plain integer enum values: compare by value, since
    # identity comparison of integers is an interpreter implementation detail.
    if actual_type != expected_type:
        raise ValueError('expected payload type %s, got %s' %
                         (expected_type, actual_type))
    actual_length = len(response.payload.body)
    if actual_length != expected_length:
        raise ValueError('expected payload body size %d, got %d' %
                         (expected_length, actual_length))


def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope,
                                 call_credentials):
    """Performs a large UnaryCall and validates the response payload.

    Args:
      stub: A TestService stub.
      fill_username: Value for the request's `fill_username` field.
      fill_oauth_scope: Value for the request's `fill_oauth_scope` field.
      call_credentials: Call credentials passed to the RPC, or None.

    Returns:
      The response message of the call.

    Raises:
      ValueError: If the response payload has the wrong type or size.
    """
    response_size = 314159
    request_payload = messages_pb2.Payload(body=b'\x00' * 271828)
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=response_size,
        payload=request_payload,
        fill_username=fill_username,
        fill_oauth_scope=fill_oauth_scope)
    response = stub.UnaryCall.future(
        request, credentials=call_credentials).result()
    _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
                                      response_size)
    return response


def _empty_unary(stub):
    """Calls EmptyCall and checks that an Empty message comes back.

    Raises:
      TypeError: If the response is not an `empty_pb2.Empty` instance.
    """
    response = stub.EmptyCall(empty_pb2.Empty())
    if isinstance(response, empty_pb2.Empty):
        return
    raise TypeError(
        'response is of type "%s", not empty_pb2.Empty!' % type(response))


def _large_unary(stub):
    """Runs the large unary exchange without username, scope or credentials."""
    _large_unary_common_behavior(
        stub,
        fill_username=False,
        fill_oauth_scope=False,
        call_credentials=None)


def _client_streaming(stub):
    """Streams four payloads and checks the aggregated size that comes back.

    Raises:
      ValueError: If the reported aggregate payload size is not 74922.
    """
    payload_body_sizes = (27182, 8, 1828, 45904)

    def _request_stream():
        # Requests are built lazily, one per payload size.
        for body_size in payload_body_sizes:
            yield messages_pb2.StreamingInputCallRequest(
                payload=messages_pb2.Payload(body=b'\x00' * body_size))

    response = stub.StreamingInputCall(_request_stream())
    # 74922 is the sum of the four payload body sizes above.
    if response.aggregated_payload_size == 74922:
        return
    raise ValueError('incorrect size %d!' % response.aggregated_payload_size)


def _server_streaming(stub):
    """Requests four streamed responses and validates each one.

    Besides checking the type and size of every payload, this verifies that
    the server sends exactly as many responses as were requested; a short or
    over-long stream is reported as a ValueError.

    Raises:
      ValueError: If a payload has the wrong type or size, or if the number
        of responses differs from the number requested.
    """
    sizes = (
        31415,
        9,
        2653,
        58979,
    )

    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=tuple(
            messages_pb2.ResponseParameters(size=size) for size in sizes))

    response_count = 0
    for response in stub.StreamingOutputCall(request):
        if response_count >= len(sizes):
            # Fail with a clear message rather than an IndexError below.
            raise ValueError(
                'expected exactly %d responses, got more' % len(sizes))
        _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
                                          sizes[response_count])
        response_count += 1

    if response_count != len(sizes):
        raise ValueError('expected exactly %d responses, got %d' %
                         (len(sizes), response_count))


class _Pipe(object):
    """A blocking FIFO iterator fed from another thread of control.

    Values passed to `add` are handed out by iteration in insertion order.
    Iteration blocks while the pipe is open and empty, and stops once the
    pipe has been closed and every queued value has been consumed. Used as a
    context manager, the pipe is closed on exit from the `with` block.
    """

    def __init__(self):
        self._condition = threading.Condition()
        # A deque gives O(1) removal from the front of the queue.
        self._values = collections.deque()
        self._open = True

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        """Returns the oldest queued value, blocking until one is available.

        Raises:
          StopIteration: If the pipe is closed and no values remain.
        """
        with self._condition:
            while not self._values and self._open:
                self._condition.wait()
            if self._values:
                return self._values.popleft()
            raise StopIteration()

    def add(self, value):
        """Queues `value` and wakes one waiting consumer."""
        with self._condition:
            self._values.append(value)
            self._condition.notify()

    def close(self):
        """Marks the pipe closed and wakes every waiting consumer."""
        with self._condition:
            self._open = False
            # Closing concerns all waiters: each one must get the chance to
            # observe the closed state, not just a single one.
            self._condition.notify_all()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()


def _ping_pong(stub):
    """Alternates requests and responses over a single FullDuplexCall.

    Each request is sent only after the response to the previous one has been
    received and validated.

    Raises:
      ValueError: If a response payload has the wrong type or size.
    """
    # (requested response size, request payload size) for each round trip.
    exchanges = (
        (31415, 27182),
        (9, 8),
        (2653, 1828),
        (58979, 45904),
    )

    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)
        for response_size, payload_size in exchanges:
            parameters = messages_pb2.ResponseParameters(size=response_size)
            request_payload = messages_pb2.Payload(body=b'\x00' * payload_size)
            pipe.add(
                messages_pb2.StreamingOutputCallRequest(
                    response_type=messages_pb2.COMPRESSABLE,
                    response_parameters=(parameters,),
                    payload=request_payload))
            response = next(response_iterator)
            _validate_payload_type_and_length(
                response, messages_pb2.COMPRESSABLE, response_size)


def _cancel_after_begin(stub):
    """Starts a StreamingInputCall and cancels it before sending anything.

    Raises:
      ValueError: If the call does not report itself cancelled, or its status
        code is not CANCELLED.
    """
    with _Pipe() as pipe:
        call = stub.StreamingInputCall.future(pipe)
        call.cancel()
        reported_cancelled = call.cancelled()
        if not reported_cancelled:
            raise ValueError('expected cancelled method to return True')
        status_code = call.code()
        if status_code is not grpc.StatusCode.CANCELLED:
            raise ValueError('expected status code CANCELLED')


def _cancel_after_first_response(stub):
    """Cancels a FullDuplexCall after its first response has arrived.

    Raises:
      ValueError: If reading from the call after cancellation succeeds.
      grpc.RpcError: If the call fails with a status other than CANCELLED.
    """
    response_size = 31415
    payload_size = 27182

    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)

        parameters = messages_pb2.ResponseParameters(size=response_size)
        pipe.add(
            messages_pb2.StreamingOutputCallRequest(
                response_type=messages_pb2.COMPRESSABLE,
                response_parameters=(parameters,),
                payload=messages_pb2.Payload(body=b'\x00' * payload_size)))
        # The first response is consumed but not inspected: its contents are
        # covered by the Ping Pong test.
        next(response_iterator)
        response_iterator.cancel()

        saw_cancellation = False
        try:
            next(response_iterator)
        except grpc.RpcError as rpc_error:
            if rpc_error.code() is not grpc.StatusCode.CANCELLED:
                raise
            saw_cancellation = True
        if not saw_cancellation:
            raise ValueError('expected call to be cancelled')


def _timeout_on_sleeping_server(stub):
    """Issues a FullDuplexCall with a 1 ms timeout and expects it to expire.

    Raises:
      ValueError: If a response is received instead of an error.
      grpc.RpcError: If the call fails with a status other than
        DEADLINE_EXCEEDED.
    """
    request_payload_size = 27182
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe, timeout=0.001)

        request_payload = messages_pb2.Payload(
            body=b'\x00' * request_payload_size)
        pipe.add(
            messages_pb2.StreamingOutputCallRequest(
                response_type=messages_pb2.COMPRESSABLE,
                payload=request_payload))

        deadline_exceeded = False
        try:
            next(response_iterator)
        except grpc.RpcError as rpc_error:
            if rpc_error.code() is not grpc.StatusCode.DEADLINE_EXCEEDED:
                raise
            deadline_exceeded = True
        if not deadline_exceeded:
            raise ValueError('expected call to exceed deadline')


def _empty_stream(stub):
    """Opens a FullDuplexCall, sends nothing, and expects no responses.

    Raises:
      ValueError: If the server sends a response.
    """
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)
        # Close the request stream right away, before any request is added.
        pipe.close()
        try:
            next(response_iterator)
        except StopIteration:
            pass
        else:
            raise ValueError('expected exactly 0 responses')


def _status_code_and_message(stub):
    """Asks the server to fail with a given status and verifies it does.

    The check is made once with UnaryCall and once with FullDuplexCall.

    Raises:
      ValueError: If the resulting status code or details do not match.
    """
    details = 'test status message'
    code = 2
    status = grpc.StatusCode.UNKNOWN  # code = 2
    requested_status = messages_pb2.EchoStatus(code=code, message=details)

    # Unary variant.
    unary_request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'),
        response_status=requested_status)
    unary_call = stub.UnaryCall.future(unary_request)
    _validate_status_code_and_details(unary_call, status, details)

    # Full-duplex variant.
    with _Pipe() as pipe:
        duplex_call = stub.FullDuplexCall(pipe)
        streaming_request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(size=1),),
            payload=messages_pb2.Payload(body=b'\x00'),
            response_status=messages_pb2.EchoStatus(
                code=code, message=details))
        pipe.add(streaming_request)  # sends the initial request.
    # Leaving the with block has closed the pipe.
    _validate_status_code_and_details(duplex_call, status, details)


def _unimplemented_method(test_service_stub):
    """Calls UnimplementedCall and expects status UNIMPLEMENTED.

    Raises:
      ValueError: If the call finishes with any other status code.
    """
    call = test_service_stub.UnimplementedCall.future(empty_pb2.Empty())
    _expect_status_code(call, grpc.StatusCode.UNIMPLEMENTED)


def _unimplemented_service(unimplemented_service_stub):
    """Calls UnimplementedCall on the given stub, expecting UNIMPLEMENTED.

    Raises:
      ValueError: If the call finishes with any other status code.
    """
    call = unimplemented_service_stub.UnimplementedCall.future(
        empty_pb2.Empty())
    _expect_status_code(call, grpc.StatusCode.UNIMPLEMENTED)


def _custom_metadata(stub):
    """Sends custom metadata and verifies that the server echoes it back.

    The check is made once with UnaryCall and once with FullDuplexCall.

    Raises:
      ValueError: If the echoed initial or trailing metadata is missing or
        differs from what was sent.
    """
    initial_metadata_value = "test_initial_metadata_value"
    # The trailing key ends in "-bin", so its value is binary: use bytes so
    # the comparison with the received value is like-for-like.
    trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b"
    metadata = ((_INITIAL_METADATA_KEY, initial_metadata_value),
                (_TRAILING_METADATA_KEY, trailing_metadata_value))

    def _validate_metadata(response):
        # .get() turns a missing key into the ValueError below rather than a
        # bare KeyError.
        initial_metadata = dict(response.initial_metadata())
        received_initial = initial_metadata.get(_INITIAL_METADATA_KEY)
        if received_initial != initial_metadata_value:
            raise ValueError('expected initial metadata %s, got %s' %
                             (initial_metadata_value, received_initial))
        trailing_metadata = dict(response.trailing_metadata())
        received_trailing = trailing_metadata.get(_TRAILING_METADATA_KEY)
        if received_trailing != trailing_metadata_value:
            # Report the value actually found in the trailing metadata.
            raise ValueError('expected trailing metadata %s, got %s' %
                             (trailing_metadata_value, received_trailing))

    # Testing with UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'))
    response_future = stub.UnaryCall.future(request, metadata=metadata)
    _validate_metadata(response_future)

    # Testing with FullDuplexCall
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe, metadata=metadata)
        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(size=1),))
        pipe.add(request)  # Sends the request
        next(response_iterator)  # Causes server to send trailing metadata
    # Dropping out of the with block closes the pipe
    _validate_metadata(response_iterator)


def _compute_engine_creds(stub, args):
    """Checks that the server reports the default service account as user.

    Args:
      stub: A TestService stub.
      args: Parsed arguments providing `default_service_account`.

    Raises:
      ValueError: If the username in the response differs from
        `args.default_service_account`.
    """
    response = _large_unary_common_behavior(stub, True, True, None)
    expected_username = args.default_service_account
    if expected_username == response.username:
        return
    raise ValueError('expected username %s, got %s' %
                     (args.default_service_account, response.username))


def _oauth2_auth_token(stub, args):
    """Checks the username and OAuth scope reported by the server.

    The expected username is the `client_email` of the JSON key file named by
    the Google credentials environment variable.

    Args:
      stub: A TestService stub.
      args: Parsed arguments providing `oauth_scope`.

    Raises:
      ValueError: If the username differs from the key file's client email,
        or the reported scope is not contained in `args.oauth_scope`.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Read the key file inside a context manager so the handle is closed.
    with open(json_key_filename, 'rb') as json_key_file:
        wanted_email = json.load(json_key_file)['client_email']
    response = _large_unary_common_behavior(stub, True, True, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' % (wanted_email,
                                                           response.username))
    if response.oauth_scope not in args.oauth_scope:
        raise ValueError(
            'expected to find oauth scope "{}" in received "{}"'.format(
                response.oauth_scope, args.oauth_scope))


def _jwt_token_creds(stub, args):
    """Checks that the server reports the key file's client email as user.

    The expected username is the `client_email` of the JSON key file named by
    the Google credentials environment variable.

    Args:
      stub: A TestService stub.
      args: Parsed arguments (unused by this test).

    Raises:
      ValueError: If the username in the response differs from the expected
        client email.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Read the key file inside a context manager so the handle is closed.
    with open(json_key_filename, 'rb') as json_key_file:
        wanted_email = json.load(json_key_file)['client_email']
    response = _large_unary_common_behavior(stub, True, False, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' % (wanted_email,
                                                           response.username))


def _per_rpc_creds(stub, args):
    """Checks authentication using credentials attached to a single RPC.

    Default Google credentials scoped to `args.oauth_scope` are wrapped as
    gRPC call credentials and passed to the large unary call. The expected
    username is the `client_email` of the JSON key file named by the Google
    credentials environment variable.

    Args:
      stub: A TestService stub.
      args: Parsed arguments providing `oauth_scope`.

    Raises:
      ValueError: If the username in the response differs from the expected
        client email.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Read the key file inside a context manager so the handle is closed.
    with open(json_key_filename, 'rb') as json_key_file:
        wanted_email = json.load(json_key_file)['client_email']
    google_credentials, unused_project_id = google_auth.default(
        scopes=[args.oauth_scope])
    call_credentials = grpc.metadata_call_credentials(
        google_auth_transport_grpc.AuthMetadataPlugin(
            credentials=google_credentials,
            request=google_auth_transport_requests.Request()))
    response = _large_unary_common_behavior(stub, True, False, call_credentials)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' % (wanted_email,
                                                           response.username))


@enum.unique
class TestCase(enum.Enum):
    """The interoperability test cases, keyed by their string names."""

    EMPTY_UNARY = 'empty_unary'
    LARGE_UNARY = 'large_unary'
    SERVER_STREAMING = 'server_streaming'
    CLIENT_STREAMING = 'client_streaming'
    PING_PONG = 'ping_pong'
    CANCEL_AFTER_BEGIN = 'cancel_after_begin'
    CANCEL_AFTER_FIRST_RESPONSE = 'cancel_after_first_response'
    EMPTY_STREAM = 'empty_stream'
    STATUS_CODE_AND_MESSAGE = 'status_code_and_message'
    UNIMPLEMENTED_METHOD = 'unimplemented_method'
    UNIMPLEMENTED_SERVICE = 'unimplemented_service'
    CUSTOM_METADATA = "custom_metadata"
    COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
    OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
    JWT_TOKEN_CREDS = 'jwt_token_creds'
    PER_RPC_CREDS = 'per_rpc_creds'
    TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server'

    def test_interoperability(self, stub, args):
        """Runs the test implementation that belongs to this test case.

        Args:
          stub: The stub handed to the test implementation.
          args: Parsed arguments; only forwarded to the credential tests.

        Raises:
          NotImplementedError: If no implementation exists for this case.
        """
        # Tests that only need the stub.
        stub_only_tests = {
            TestCase.EMPTY_UNARY: _empty_unary,
            TestCase.LARGE_UNARY: _large_unary,
            TestCase.SERVER_STREAMING: _server_streaming,
            TestCase.CLIENT_STREAMING: _client_streaming,
            TestCase.PING_PONG: _ping_pong,
            TestCase.CANCEL_AFTER_BEGIN: _cancel_after_begin,
            TestCase.CANCEL_AFTER_FIRST_RESPONSE: _cancel_after_first_response,
            TestCase.TIMEOUT_ON_SLEEPING_SERVER: _timeout_on_sleeping_server,
            TestCase.EMPTY_STREAM: _empty_stream,
            TestCase.STATUS_CODE_AND_MESSAGE: _status_code_and_message,
            TestCase.UNIMPLEMENTED_METHOD: _unimplemented_method,
            TestCase.UNIMPLEMENTED_SERVICE: _unimplemented_service,
            TestCase.CUSTOM_METADATA: _custom_metadata,
        }
        # Tests that additionally receive the parsed arguments.
        stub_and_args_tests = {
            TestCase.COMPUTE_ENGINE_CREDS: _compute_engine_creds,
            TestCase.OAUTH2_AUTH_TOKEN: _oauth2_auth_token,
            TestCase.JWT_TOKEN_CREDS: _jwt_token_creds,
            TestCase.PER_RPC_CREDS: _per_rpc_creds,
        }

        if self in stub_only_tests:
            stub_only_tests[self](stub)
        elif self in stub_and_args_tests:
            stub_and_args_tests[self](stub, args)
        else:
            raise NotImplementedError(
                'Test case "%s" not implemented!' % self.name)