# Plain text  |  532 lines  |  20.3 KB

# Copyright 2015 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for the Flask utilities"""

import datetime
import json
import logging

import flask
import httplib2
import mock
import six.moves.http_client as httplib
import six.moves.urllib.parse as urlparse
import unittest2

import oauth2client
from oauth2client import client
from oauth2client import clientsecrets
from oauth2client.contrib import flask_util


__author__ = 'jonwayne@google.com (Jon Wayne Parrott)'


class Http2Mock(object):
    """Stand-in for ``httplib2.Http`` used by code-exchange / refresh tests.

    The instance doubles as the HTTP response object: ``request`` hands back
    the instance itself (which carries ``status``) together with a
    UTF-8 encoded JSON token payload. When used as a context manager it
    replaces ``httplib2.Http`` with itself for the duration of the ``with``
    block. Because calling the instance returns the instance, code that does
    ``httplib2.Http()`` receives this mock.

    Args:
        status: HTTP status reported on the fake response.
        **kwargs: Fields added to, or overriding, the default token payload.
    """

    def __init__(self, status=httplib.OK, **kwargs):
        defaults = {
            'access_token': 'foo_access_token',
            'refresh_token': 'foo_refresh_token',
            'expires_in': 3600,
            'extra': 'value',
        }
        self.status = status
        # Caller-supplied fields win over the defaults.
        self.content = dict(defaults, **kwargs)

    def request(self, token_uri, method, body, headers, *args, **kwargs):
        """Record the outgoing request and reply with the canned payload.

        Returns:
            A ``(response, content)`` pair. ``response`` is this mock and
            ``content`` is ``self.content`` serialized as UTF-8 JSON bytes.
        """
        self.body, self.headers = body, headers
        payload = json.dumps(self.content)
        return self, payload.encode('utf-8')

    def __enter__(self):
        """Swap this mock in for ``httplib2.Http`` and return the mock."""
        self.httplib2_orig, httplib2.Http = httplib2.Http, self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Put the original ``httplib2.Http`` back; exceptions propagate."""
        httplib2.Http = self.httplib2_orig

    def __call__(self, *args, **kwargs):
        """Act like the ``httplib2.Http`` constructor by returning the mock."""
        return self


class FlaskOAuth2Tests(unittest2.TestCase):
    """Tests for ``flask_util.UserOAuth2`` configuration, views and storage."""

    def setUp(self):
        """Create a testing Flask app with a UserOAuth2 extension attached."""
        self.app = flask.Flask(__name__)
        self.app.testing = True
        self.app.config['SECRET_KEY'] = 'notasecert'
        # Silence the app logger so the expected error paths stay quiet.
        self.app.logger.setLevel(logging.CRITICAL)
        self.oauth2 = flask_util.UserOAuth2(
            self.app,
            client_id='client_idz',
            client_secret='client_secretz')

    def _generate_credentials(self, scopes=None):
        """Build credentials for user '123' / 'user@example.com'.

        Args:
            scopes: Optional scopes to attach to the credentials.

        Returns:
            A ``client.OAuth2Credentials`` expiring one hour from now, with a
            refresh token and an ID token containing ``sub`` and ``email``.
        """
        return client.OAuth2Credentials(
            'access_tokenz',
            'client_idz',
            'client_secretz',
            'refresh_tokenz',
            datetime.datetime.utcnow() + datetime.timedelta(seconds=3600),
            oauth2client.GOOGLE_TOKEN_URI,
            'Test',
            id_token={
                'sub': '123',
                'email': 'user@example.com'
            },
            scopes=scopes)

    def test_explicit_configuration(self):
        """Client id/secret may be passed directly or via a secrets file."""
        oauth2 = flask_util.UserOAuth2(
            flask.Flask(__name__), client_id='id', client_secret='secret')

        self.assertEqual(oauth2.client_id, 'id')
        self.assertEqual(oauth2.client_secret, 'secret')

        return_val = (
            clientsecrets.TYPE_WEB,
            {'client_id': 'id', 'client_secret': 'secret'})

        with mock.patch('oauth2client.clientsecrets.loadfile',
                        return_value=return_val):

            oauth2 = flask_util.UserOAuth2(
                flask.Flask(__name__), client_secrets_file='file.json')

            self.assertEqual(oauth2.client_id, 'id')
            self.assertEqual(oauth2.client_secret, 'secret')

    def test_delayed_configuration(self):
        """init_app() binds an extension constructed without an app."""
        app = flask.Flask(__name__)
        oauth2 = flask_util.UserOAuth2()
        oauth2.init_app(app, client_id='id', client_secret='secret')
        self.assertEqual(oauth2.app, app)

    def test_explicit_storage(self):
        """A storage object given to the constructor is used as-is."""
        storage_mock = mock.Mock()
        oauth2 = flask_util.UserOAuth2(
            flask.Flask(__name__), storage=storage_mock, client_id='id',
            client_secret='secret')
        self.assertEqual(oauth2.storage, storage_mock)

    def test_explicit_scopes(self):
        """Scopes given to the constructor are exposed on the extension."""
        oauth2 = flask_util.UserOAuth2(
            flask.Flask(__name__), scopes=['1', '2'], client_id='id',
            client_secret='secret')
        self.assertEqual(oauth2.scopes, ['1', '2'])

    def test_bad_client_secrets(self):
        """A secrets file that is not of the web type raises ValueError."""
        return_val = (
            'other',
            {'client_id': 'id', 'client_secret': 'secret'})

        with mock.patch('oauth2client.clientsecrets.loadfile',
                        return_value=return_val):
            with self.assertRaises(ValueError):
                flask_util.UserOAuth2(flask.Flask(__name__),
                                      client_secrets_file='file.json')

    def test_app_configuration(self):
        """Client id/secret may come from the app's GOOGLE_OAUTH2_* config."""
        app = flask.Flask(__name__)
        app.config['GOOGLE_OAUTH2_CLIENT_ID'] = 'id'
        app.config['GOOGLE_OAUTH2_CLIENT_SECRET'] = 'secret'

        oauth2 = flask_util.UserOAuth2(app)

        self.assertEqual(oauth2.client_id, 'id')
        self.assertEqual(oauth2.client_secret, 'secret')

        return_val = (
            clientsecrets.TYPE_WEB,
            {'client_id': 'id2', 'client_secret': 'secret2'})

        with mock.patch('oauth2client.clientsecrets.loadfile',
                        return_value=return_val):

            app = flask.Flask(__name__)
            app.config['GOOGLE_OAUTH2_CLIENT_SECRETS_FILE'] = 'file.json'
            oauth2 = flask_util.UserOAuth2(app)

            self.assertEqual(oauth2.client_id, 'id2')
            self.assertEqual(oauth2.client_secret, 'secret2')

    def test_no_configuration(self):
        """Constructing with no client configuration raises ValueError."""
        with self.assertRaises(ValueError):
            flask_util.UserOAuth2(flask.Flask(__name__))

    def test_create_flow(self):
        """_make_flow() ties state to a session CSRF token and passes args."""
        with self.app.test_request_context():
            flow = self.oauth2._make_flow()
            state = json.loads(flow.params['state'])
            self.assertIn('google_oauth2_csrf_token', flask.session)
            self.assertEqual(
                flask.session['google_oauth2_csrf_token'], state['csrf_token'])
            self.assertEqual(flow.client_id, self.oauth2.client_id)
            self.assertEqual(flow.client_secret, self.oauth2.client_secret)
            self.assertIn('http', flow.redirect_uri)
            self.assertIn('oauth2callback', flow.redirect_uri)

            flow = self.oauth2._make_flow(return_url='/return_url')
            state = json.loads(flow.params['state'])
            self.assertEqual(state['return_url'], '/return_url')

            flow = self.oauth2._make_flow(extra_arg='test')
            self.assertEqual(flow.params['extra_arg'], 'test')

        # Test extra args specified in the constructor.
        app = flask.Flask(__name__)
        app.config['SECRET_KEY'] = 'notasecert'
        oauth2 = flask_util.UserOAuth2(
            app, client_id='client_id', client_secret='secret',
            extra_arg='test')

        with app.test_request_context():
            flow = oauth2._make_flow()
            self.assertEqual(flow.params['extra_arg'], 'test')

    def test_authorize_view(self):
        """/oauth2authorize redirects to the auth URI with id and state."""
        with self.app.test_client() as client:
            response = client.get('/oauth2authorize')
            location = response.headers['Location']
            q = urlparse.parse_qs(location.split('?', 1)[1])
            state = json.loads(q['state'][0])

            self.assertIn(oauth2client.GOOGLE_AUTH_URI, location)
            # The secret must never leak into the browser redirect.
            self.assertNotIn(self.oauth2.client_secret, location)
            self.assertIn(self.oauth2.client_id, q['client_id'])
            self.assertEqual(
                flask.session['google_oauth2_csrf_token'], state['csrf_token'])
            self.assertEqual(state['return_url'], '/')

        with self.app.test_client() as client:
            response = client.get('/oauth2authorize?return_url=/test')
            location = response.headers['Location']
            q = urlparse.parse_qs(location.split('?', 1)[1])
            state = json.loads(q['state'][0])
            self.assertEqual(state['return_url'], '/test')

        with self.app.test_client() as client:
            response = client.get('/oauth2authorize?extra_param=test')
            location = response.headers['Location']
            self.assertIn('extra_param=test', location)

    def _setup_callback_state(self, client, **kwargs):
        """Put a pending flow into ``client``'s session and return its state.

        Args:
            client: Flask test client whose session receives the CSRF token.
            **kwargs: Forwarded to ``_make_flow``; when empty,
                ``return_url='/return_url'`` is used.

        Returns:
            The flow's JSON-encoded ``state`` parameter.
        """
        with self.app.test_request_context():
            # Flask doesn't create a request context with a session
            # transaction for some reason, so, set up the flow here,
            # then apply it to the session in the transaction.
            if not kwargs:
                self.oauth2._make_flow(return_url='/return_url')
            else:
                self.oauth2._make_flow(**kwargs)

            with client.session_transaction() as session:
                session.update(flask.session)
                csrf_token = session['google_oauth2_csrf_token']
                flow = flask_util._get_flow_for_token(csrf_token)
                state = flow.params['state']

        return state

    def test_callback_view(self):
        """A valid callback exchanges the code, stores creds and redirects."""
        self.oauth2.storage = mock.Mock()
        with self.app.test_client() as client:
            with Http2Mock() as http:
                state = self._setup_callback_state(client)

                response = client.get(
                    '/oauth2callback?state={0}&code=codez'.format(state))

                self.assertEqual(response.status_code, httplib.FOUND)
                self.assertIn('/return_url', response.headers['Location'])
                # The token request body must carry the secret and the code.
                self.assertIn(self.oauth2.client_secret, http.body)
                self.assertIn('codez', http.body)
                self.assertTrue(self.oauth2.storage.put.called)

    def test_authorize_callback(self):
        """The authorize_callback hook fires after a successful callback."""
        self.oauth2.authorize_callback = mock.Mock()
        self.test_callback_view()
        self.assertTrue(self.oauth2.authorize_callback.called)

    def test_callback_view_errors(self):
        """Every malformed or failed callback yields 400 Bad Request."""
        # Error supplied to callback. 'state={}' below is a literal empty
        # JSON object, not a missing str.format() placeholder.
        with self.app.test_client() as client:
            with client.session_transaction() as session:
                session['google_oauth2_csrf_token'] = 'tokenz'

            response = client.get('/oauth2callback?state={}&error=something')
            self.assertEqual(response.status_code, httplib.BAD_REQUEST)
            self.assertIn('something', response.data.decode('utf-8'))

        # CSRF mismatch
        with self.app.test_client() as client:
            with client.session_transaction() as session:
                session['google_oauth2_csrf_token'] = 'goodstate'

            state = json.dumps({
                'csrf_token': 'badstate',
                'return_url': '/return_url'
            })

            response = client.get(
                '/oauth2callback?state={0}&code=codez'.format(state))
            self.assertEqual(response.status_code, httplib.BAD_REQUEST)

        # KeyError, no CSRF state (again a literal empty JSON state).
        with self.app.test_client() as client:
            response = client.get('/oauth2callback?state={}&code=codez')
            self.assertEqual(response.status_code, httplib.BAD_REQUEST)

        # Code exchange error
        with self.app.test_client() as client:
            state = self._setup_callback_state(client)

            with Http2Mock(status=httplib.INTERNAL_SERVER_ERROR):
                response = client.get(
                    '/oauth2callback?state={0}&code=codez'.format(state))
                self.assertEqual(response.status_code, httplib.BAD_REQUEST)

        # Invalid state json
        with self.app.test_client() as client:
            with client.session_transaction() as session:
                session['google_oauth2_csrf_token'] = 'tokenz'

            state = '[{'
            response = client.get(
                '/oauth2callback?state={0}&code=codez'.format(state))
            self.assertEqual(response.status_code, httplib.BAD_REQUEST)

        # Missing flow.
        with self.app.test_client() as client:
            with client.session_transaction() as session:
                session['google_oauth2_csrf_token'] = 'tokenz'

            state = json.dumps({
                'csrf_token': 'tokenz',
                'return_url': '/return_url'
            })

            response = client.get(
                '/oauth2callback?state={0}&code=codez'.format(state))
            self.assertEqual(response.status_code, httplib.BAD_REQUEST)

    def test_no_credentials(self):
        """Without stored credentials every accessor reports absence."""
        with self.app.test_request_context():
            self.assertFalse(self.oauth2.has_credentials())
            self.assertIsNone(self.oauth2.credentials)
            self.assertIsNone(self.oauth2.user_id)
            self.assertIsNone(self.oauth2.email)
            with self.assertRaises(ValueError):
                self.oauth2.http()
            self.assertFalse(self.oauth2.storage.get())
            # Deleting when nothing is stored must not raise.
            self.oauth2.storage.delete()

    def test_with_credentials(self):
        """Stored credentials expose tokens, identity and an http object."""
        credentials = self._generate_credentials()
        with self.app.test_request_context():
            self.oauth2.storage.put(credentials)
            self.assertEqual(
                self.oauth2.credentials.access_token, credentials.access_token)
            self.assertEqual(
                self.oauth2.credentials.refresh_token,
                credentials.refresh_token)
            self.assertEqual(self.oauth2.user_id, '123')
            self.assertEqual(self.oauth2.email, 'user@example.com')
            self.assertTrue(self.oauth2.http())

    @mock.patch('oauth2client.client._UTCNOW')
    def test_with_expired_credentials(self, utcnow):
        """Expired credentials count only when they carry a refresh token."""
        utcnow.return_value = datetime.datetime(1990, 5, 29)

        credentials = self._generate_credentials()
        credentials.token_expiry = datetime.datetime(1990, 5, 28)

        # Has a refresh token, so this should be fine.
        with self.app.test_request_context():
            self.oauth2.storage.put(credentials)
            self.assertTrue(self.oauth2.has_credentials())

        # Without a refresh token this should return false.
        credentials.refresh_token = None
        with self.app.test_request_context():
            self.oauth2.storage.put(credentials)
            self.assertFalse(self.oauth2.has_credentials())

    def test_bad_id_token(self):
        """An empty ID token yields None for user_id and email."""
        credentials = self._generate_credentials()
        credentials.id_token = {}
        with self.app.test_request_context():
            self.oauth2.storage.put(credentials)
            self.assertIsNone(self.oauth2.user_id)
            self.assertIsNone(self.oauth2.email)

    def test_required(self):
        """@required redirects unless usable credentials are in the session."""
        @self.app.route('/protected')
        @self.oauth2.required
        def index():
            return 'Hello'

        # No credentials, should redirect
        with self.app.test_client() as client:
            response = client.get('/protected')
            self.assertEqual(response.status_code, httplib.FOUND)
            self.assertIn('oauth2authorize', response.headers['Location'])
            self.assertIn('protected', response.headers['Location'])

        credentials = self._generate_credentials(scopes=self.oauth2.scopes)

        # With credentials, should allow
        with self.app.test_client() as client:
            with client.session_transaction() as session:
                session['google_oauth2_credentials'] = credentials.to_json()

            response = client.get('/protected')
            self.assertEqual(response.status_code, httplib.OK)
            self.assertIn('Hello', response.data.decode('utf-8'))

        # Expired credentials with refresh token, should allow.
        credentials.token_expiry = datetime.datetime(1990, 5, 28)
        with mock.patch('oauth2client.client._UTCNOW') as utcnow:
            utcnow.return_value = datetime.datetime(1990, 5, 29)

            with self.app.test_client() as client:
                with client.session_transaction() as session:
                    session['google_oauth2_credentials'] = (
                        credentials.to_json())

                response = client.get('/protected')
                self.assertEqual(response.status_code, httplib.OK)
                self.assertIn('Hello', response.data.decode('utf-8'))

        # Expired credentials without a refresh token, should redirect.
        credentials.refresh_token = None
        with mock.patch('oauth2client.client._UTCNOW') as utcnow:
            utcnow.return_value = datetime.datetime(1990, 5, 29)

            with self.app.test_client() as client:
                with client.session_transaction() as session:
                    session['google_oauth2_credentials'] = (
                        credentials.to_json())

                response = client.get('/protected')
                self.assertEqual(response.status_code, httplib.FOUND)
                self.assertIn(
                    'oauth2authorize', response.headers['Location'])
                self.assertIn('protected', response.headers['Location'])

    def _create_incremental_auth_app(self):
        """Rebuild self.app with include_granted_scopes and scoped routes.

        ``/one`` requires scope ``one``; ``/two`` requires ``two`` and
        ``three``. Both return 'Hello' when allowed.
        """
        self.app = flask.Flask(__name__)
        self.app.testing = True
        self.app.config['SECRET_KEY'] = 'notasecert'
        self.oauth2 = flask_util.UserOAuth2(
            self.app,
            client_id='client_idz',
            client_secret='client_secretz',
            include_granted_scopes=True)

        @self.app.route('/one')
        @self.oauth2.required(scopes=['one'])
        def one():
            return 'Hello'

        @self.app.route('/two')
        @self.oauth2.required(scopes=['two', 'three'])
        def two():
            return 'Hello'

    def test_incremental_auth(self):
        """Routes needing extra scopes redirect and request the scope union."""
        self._create_incremental_auth_app()

        # No credentials, should redirect
        with self.app.test_client() as client:
            response = client.get('/one')
            self.assertIn('one', response.headers['Location'])
            self.assertEqual(response.status_code, httplib.FOUND)

        # Credentials for one. /one should allow, /two should redirect.
        credentials = self._generate_credentials(scopes=['email', 'one'])

        with self.app.test_client() as client:
            with client.session_transaction() as session:
                session['google_oauth2_credentials'] = credentials.to_json()

            response = client.get('/one')
            self.assertEqual(response.status_code, httplib.OK)

            response = client.get('/two')
            self.assertIn('two', response.headers['Location'])
            self.assertEqual(response.status_code, httplib.FOUND)

            # Starting the authorization flow should include the
            # include_granted_scopes parameter as well as the scopes.
            # Follow the redirect using only its path and query, so this
            # works whether Location is absolute ('http://localhost/...')
            # or relative, rather than slicing off a fixed-length prefix.
            redirect = urlparse.urlsplit(response.headers['Location'])
            response = client.get(urlparse.urlunsplit(
                ('', '', redirect.path, redirect.query, '')))
            q = urlparse.parse_qs(
                response.headers['Location'].split('?', 1)[1])
            self.assertIn('include_granted_scopes', q)
            self.assertEqual(
                set(q['scope'][0].split(' ')),
                set(['one', 'email', 'two', 'three']))

        # Actually call two() without a redirect.
        credentials2 = self._generate_credentials(
            scopes=['email', 'two', 'three'])

        with self.app.test_client() as client:
            with client.session_transaction() as session:
                session['google_oauth2_credentials'] = credentials2.to_json()

            response = client.get('/two')
            self.assertEqual(response.status_code, httplib.OK)

    def test_incremental_auth_exchange(self):
        """A code exchange for extra scopes yields credentials with them."""
        self._create_incremental_auth_app()

        with Http2Mock():
            with self.app.test_client() as client:
                state = self._setup_callback_state(
                    client,
                    return_url='/return_url',
                    # Incremental auth scopes.
                    scopes=['one', 'two'])

                response = client.get(
                    '/oauth2callback?state={0}&code=codez'.format(state))
                self.assertEqual(response.status_code, httplib.FOUND)

                credentials = self.oauth2.credentials
                self.assertTrue(
                    credentials.has_scopes(['email', 'one', 'two']))

    def test_refresh(self):
        """Refreshing the credentials persists the new access token."""
        with self.app.test_request_context():
            with mock.patch('flask.session'):
                self.oauth2.storage.put(self._generate_credentials())

                self.oauth2.credentials.refresh(
                    Http2Mock(access_token='new_token'))

                self.assertEqual(
                    self.oauth2.storage.get().access_token, 'new_token')

    def test_delete(self):
        """Deleting from storage removes the credentials from the session."""
        with self.app.test_request_context():

            self.oauth2.storage.put(self._generate_credentials())
            self.oauth2.storage.delete()

            self.assertNotIn('google_oauth2_credentials', flask.session)