普通文本  |  422行  |  12.87 KB

#!/usr/bin/env python
# Copyright 2013 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Tests exercising the various classes in xmppserver.py."""

import unittest

import base64
import xmppserver

class XmlUtilsTest(unittest.TestCase):
  """Tests for the module-level XML helpers ParseXml() and CloneXml()."""

  def testParseXml(self):
    """Parsed XML serializes back to exactly the text it came from."""
    xml_text = """<foo xmlns=""><bar xmlns=""><baz/></bar></foo>"""
    xml = xmppserver.ParseXml(xml_text)
    self.assertEqual(xml.toxml(), xml_text)

  def testCloneXml(self):
    """Changing a clone must leave the element it was cloned from alone."""
    xml = xmppserver.ParseXml('<foo/>')
    xml_clone = xmppserver.CloneXml(xml)
    xml_clone.setAttribute('bar', 'baz')
    self.assertNotEqual(xml, xml_clone)
    # Comparing an element with itself proves nothing, so check the
    # attribute directly: it must exist on the clone only.
    self.assertEqual(xml_clone.getAttribute('bar'), 'baz')
    self.assertFalse(xml.hasAttribute('bar'))

  def testCloneXmlUnlink(self):
    """A clone stays attached and intact after the original is unlinked."""
    xml_text = '<foo/>'
    xml = xmppserver.ParseXml(xml_text)
    xml_clone = xmppserver.CloneXml(xml)
    xml.unlink()
    self.assertEqual(xml.parentNode, None)
    self.assertNotEqual(xml_clone.parentNode, None)
    self.assertEqual(xml_clone.toxml(), xml_text)

class StanzaParserTest(unittest.TestCase):
  """Tests for xmppserver.StanzaParser.

  The test case passes itself to StanzaParser and collects, via
  FeedStanza(), the serialized form of each stanza it is handed.
  """

  def setUp(self):
    # Serialized stanzas, in the order FeedStanza() received them.
    self.stanzas = []

  def FeedStanza(self, stanza):
    # The stanza is unlinked once this callback returns, so keep its
    # serialized text rather than the object itself.
    serialized = stanza.toxml()
    self.stanzas.append(serialized)

  def testBasic(self):
    """Nothing is emitted until an element has been fully received."""
    stanza_parser = xmppserver.StanzaParser(self)
    stanza_parser.FeedString('<foo')
    self.assertEqual(len(self.stanzas), 0)
    stanza_parser.FeedString('/><bar></bar>')
    first, second = self.stanzas[0], self.stanzas[1]
    self.assertEqual(first, '<foo/>')
    self.assertEqual(second, '<bar/>')

  def testStream(self):
    """An opening <stream:stream> tag is emitted without its close tag."""
    stanza_parser = xmppserver.StanzaParser(self)
    stanza_parser.FeedString('<stream')
    self.assertEqual(len(self.stanzas), 0)
    stanza_parser.FeedString(':stream foo="bar" xmlns:stream="baz">')
    expected = '<stream:stream foo="bar" xmlns:stream="baz"/>'
    self.assertEqual(self.stanzas[0], expected)

  def testNested(self):
    """A stanza with child elements arrives as a single unit."""
    stanza_parser = xmppserver.StanzaParser(self)
    stanza_parser.FeedString('<foo')
    self.assertEqual(len(self.stanzas), 0)
    for fragment in (' bar="baz"', '><baz/><blah>meh</blah></foo>'):
      stanza_parser.FeedString(fragment)
    expected = '<foo bar="baz"><baz/><blah>meh</blah></foo>'
    self.assertEqual(self.stanzas[0], expected)


class JidTest(unittest.TestCase):
  """Tests for the string form of xmppserver.Jid."""

  def testBasic(self):
    """A JID without a resource renders as username@domain."""
    bare = xmppserver.Jid('foo', 'bar.com')
    self.assertEqual(str(bare), 'foo@bar.com')

  def testResource(self):
    """A JID with a resource renders as username@domain/resource."""
    full = xmppserver.Jid('foo', 'bar.com', 'resource')
    self.assertEqual(str(full), 'foo@bar.com/resource')

  def testGetBareJid(self):
    """GetBareJid() yields a JID that renders without the resource."""
    full = xmppserver.Jid('foo', 'bar.com', 'resource')
    bare = full.GetBareJid()
    self.assertEqual(str(bare), 'foo@bar.com')


class IdGeneratorTest(unittest.TestCase):
  """Tests for xmppserver.IdGenerator."""

  def testBasic(self):
    """Successive ids are '<prefix>.<n>' with n counting up from zero."""
    generator = xmppserver.IdGenerator('foo')
    for index in xrange(100):
      expected_id = 'foo.%d' % index
      self.assertEqual(expected_id, generator.GetNextId())


class HandshakeTaskTest(unittest.TestCase):
  """Tests for xmppserver.HandshakeTask.

  The test case is handed to HandshakeTask in place of a connection.  It
  counts each SendData()/SendStanza() call in data_received and records
  the JID given to HandshakeDone().
  """

  def setUp(self):
    self.Reset()

  def Reset(self):
    """Clears whatever a previous handshake recorded on this object."""
    self.jid = None
    self.handshake_done = False
    self.data_received = 0

  # Methods made available to HandshakeTask.
  def SendData(self, _):
    self.data_received += 1

  def SendStanza(self, _, unused=True):
    self.data_received += 1

  def HandshakeDone(self, jid):
    self.jid = jid
    self.handshake_done = True

  # Private helpers shared by the two handshake drivers below.
  def _MakeStreamStanza(self, to_domain):
    """Returns a stream-open stanza addressed to to_domain."""
    stanza = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>')
    stanza.setAttribute('to', to_domain)
    return stanza

  def _MakeAuthStanza(self, identity):
    """Returns an <auth> stanza carrying base64 credentials for identity."""
    credentials = base64.b64encode('\0%s\0bar' % identity)
    return xmppserver.ParseXml('<auth>%s</auth>' % credentials)

  def _FeedAndExpect(self, task, stanza, expected_count):
    """Feeds stanza to task, then checks the running send count."""
    task.FeedStanza(stanza)
    self.assertEqual(self.data_received, expected_count)

  def DoHandshake(self, resource_prefix, resource, username,
                  initial_stream_domain, auth_domain, auth_stream_domain):
    """Runs a full authenticated handshake and checks the resulting JID.

    The expected JID domain is the first non-empty value among
    auth_stream_domain, auth_domain and initial_stream_domain.
    """
    self.Reset()
    task = xmppserver.HandshakeTask(self, resource_prefix, True)

    # Initial stream open.
    opening = self._MakeStreamStanza(initial_stream_domain)
    self.assertEqual(self.data_received, 0)
    self._FeedAndExpect(task, opening, 2)

    # Authentication; the domain is attached to the username only when
    # one is given.
    identity = '%s@%s' % (username, auth_domain) if auth_domain else username
    self._FeedAndExpect(task, self._MakeAuthStanza(identity), 3)

    # Stream open following authentication.
    reopening = self._MakeStreamStanza(auth_stream_domain)
    self._FeedAndExpect(task, reopening, 5)

    # Resource binding.
    bind_request = xmppserver.ParseXml(
      '<iq type="set"><bind><resource>%s</resource></bind></iq>' % resource)
    self._FeedAndExpect(task, bind_request, 6)

    self.assertFalse(self.handshake_done)

    # Session establishment finishes the handshake.
    session_request = xmppserver.ParseXml(
      '<iq type="set"><session></session></iq>')
    self._FeedAndExpect(task, session_request, 7)

    self.assertTrue(self.handshake_done)

    expected_domain = (
      auth_stream_domain or auth_domain or initial_stream_domain)
    expected_resource = '%s.%s' % (resource_prefix, resource)
    self.assertEqual(self.jid.username, username)
    self.assertEqual(self.jid.domain, expected_domain)
    self.assertEqual(self.jid.resource, expected_resource)

    # Anything fed after completion must not trigger further sends.
    self._FeedAndExpect(task, '<ignored/>', 7)

  def DoHandshakeUnauthenticated(self, resource_prefix, resource, username,
                                 initial_stream_domain):
    """Runs a handshake with authentication off; expects no JID at the end."""
    self.Reset()
    task = xmppserver.HandshakeTask(self, resource_prefix, False)

    opening = self._MakeStreamStanza(initial_stream_domain)
    self.assertEqual(self.data_received, 0)
    self._FeedAndExpect(task, opening, 2)

    self.assertFalse(self.handshake_done)

    self._FeedAndExpect(task, self._MakeAuthStanza(username), 3)

    self.assertTrue(self.handshake_done)

    self.assertEqual(self.jid, None)

    # Anything fed after completion must not trigger further sends.
    self._FeedAndExpect(task, '<ignored/>', 3)

  def testBasic(self):
    self.DoHandshake('resource_prefix', 'resource',
                     'foo', 'bar.com', 'baz.com', 'quux.com')

  def testDomainBehavior(self):
    # Each row is (initial_stream_domain, auth_domain, auth_stream_domain).
    domain_combinations = (
      ('bar.com', 'baz.com', 'quux.com'),
      ('bar.com', 'baz.com', ''),
      ('bar.com', '', ''),
      ('', '', ''),
    )
    for initial_domain, auth_domain, auth_stream_domain in domain_combinations:
      self.DoHandshake('resource_prefix', 'resource', 'foo',
                       initial_domain, auth_domain, auth_stream_domain)

  def testBasicUnauthenticated(self):
    self.DoHandshakeUnauthenticated('resource_prefix', 'resource',
                                    'foo', 'bar.com')


class FakeSocket(object):
  """Stand-in for a socket that records everything passed to send().

  Only the few socket methods these tests need are provided.  Apart from
  send(), each one does nothing or returns a fixed placeholder value.
  """

  def __init__(self):
    # Payloads handed to send(), oldest first.
    self._sent_data = []

  def GetSentData(self):
    """Returns the live list of payloads recorded by send()."""
    return self._sent_data

  # socket-like methods.
  def fileno(self):
    return 0

  def setblocking(self, int):
    # The argument is accepted and ignored.
    return None

  def getpeername(self):
    host, port = '', 0
    return (host, port)

  def send(self, data):
    # Records the payload; like the real call site expects nothing back,
    # this returns None rather than a byte count.
    self._sent_data.append(data)

  def close(self):
    return None


class XmppConnectionTest(unittest.TestCase):
  """Tests for xmppserver.XmppConnection running over a FakeSocket.

  The test case is passed to XmppConnection as its delegate and keeps the
  set of connections reported through the delegate methods.
  """

  def setUp(self):
    self.fake_socket = FakeSocket()
    self.connections = set()

  # XmppConnection delegate methods.
  def OnXmppHandshakeDone(self, xmpp_connection):
    self.connections.add(xmpp_connection)

  def OnXmppConnectionClosed(self, xmpp_connection):
    self.connections.discard(xmpp_connection)

  def ForwardNotification(self, unused_xmpp_connection, notification_stanza):
    for peer in self.connections:
      peer.ForwardNotification(notification_stanza)

  def _AssertSizes(self, socket_map, num_sockets, num_connections):
    """Checks the socket map size and the number of tracked connections."""
    self.assertEqual(len(socket_map), num_sockets)
    self.assertEqual(len(self.connections), num_connections)

  def testBasic(self):
    socket_map = {}
    xmpp_connection = xmppserver.XmppConnection(
      self.fake_socket, socket_map, self, ('', 0), True)
    self._AssertSizes(socket_map, 1, 0)
    xmpp_connection.HandshakeDone(xmppserver.Jid('foo', 'bar'))
    self._AssertSizes(socket_map, 1, 1)

    sent_data = self.fake_socket.GetSentData()
    feed = xmpp_connection.collect_incoming_data

    # Test subscription request.
    self.assertEqual(len(sent_data), 0)
    feed('<iq><subscribe xmlns="google:push"></subscribe></iq>')
    self.assertEqual(len(sent_data), 1)

    # Test acks.
    feed('<iq type="result"/>')
    self.assertEqual(len(sent_data), 1)

    # Test notification.
    feed('<message><push xmlns="google:push"/></message>')
    self.assertEqual(len(sent_data), 2)

    # Test unexpected stanza.
    self.assertRaises(xmppserver.UnexpectedXml, feed, '<foo/>')

    # Test unexpected notifier command.
    self.assertRaises(xmppserver.UnexpectedXml, feed,
                      '<iq><foo xmlns="google:notifier"/></iq>')

    # Test close.
    xmpp_connection.close()
    self._AssertSizes(socket_map, 0, 0)

  def testBasicUnauthenticated(self):
    socket_map = {}
    xmpp_connection = xmppserver.XmppConnection(
      self.fake_socket, socket_map, self, ('', 0), False)
    self._AssertSizes(socket_map, 1, 0)
    xmpp_connection.HandshakeDone(None)
    self._AssertSizes(socket_map, 0, 0)

    # Test unexpected stanza.
    self.assertRaises(xmppserver.UnexpectedXml,
                      xmpp_connection.collect_incoming_data, '<foo/>')

    # Test redundant close.
    xmpp_connection.close()
    self._AssertSizes(socket_map, 0, 0)


class FakeXmppServer(xmppserver.XmppServer):
  """XmppServer subclass whose accept() hands out FakeSocket objects.
  """

  def __init__(self):
    # These are set before the base initializer runs, as in the original.
    self._next_jid_suffix = 1
    self._fake_sockets = set()
    self._socket_map = {}
    xmppserver.XmppServer.__init__(self, self._socket_map, ('', 0))

  def GetSocketMap(self):
    """Returns the socket map given to the base class at construction."""
    return self._socket_map

  def GetFakeSockets(self):
    """Returns the set of FakeSocket objects created by accept()."""
    return self._fake_sockets

  def AddHandshakeCompletedConnection(self):
    """Accepts one new connection and marks its handshake as done.

    Each call uses a fresh JID of the form user<N>@domain.com, with N
    counting up from 1.
    """
    connection = self.handle_accept()
    suffix = self._next_jid_suffix
    self._next_jid_suffix = suffix + 1
    connection.HandshakeDone(xmppserver.Jid('user%s' % suffix, 'domain.com'))

  # XmppServer overrides.
  def accept(self):
    new_socket = FakeSocket()
    self._fake_sockets.add(new_socket)
    peer_address = ('', 0)
    return (new_socket, peer_address)

  def close(self):
    # Forget the fake sockets first, then let the base class close.
    self._fake_sockets.clear()
    xmppserver.XmppServer.close(self)


class XmppServerTest(unittest.TestCase):
  """Tests for xmppserver.XmppServer, driven through FakeXmppServer."""

  def setUp(self):
    self.xmpp_server = FakeXmppServer()

  def AssertSentDataLength(self, expected_length):
    """Asserts each fake socket has recorded expected_length sends."""
    for sock in self.xmpp_server.GetFakeSockets():
      sent = sock.GetSentData()
      self.assertEqual(len(sent), expected_length)

  def _AddConnections(self, count):
    """Adds count handshake-completed connections and checks the total."""
    for _ in xrange(count):
      self.xmpp_server.AddHandshakeCompletedConnection()

    self.assertEqual(len(self.xmpp_server.GetFakeSockets()), count)

  def testBasic(self):
    server = self.xmpp_server
    socket_map = server.GetSocketMap()
    # The server alone accounts for one socket map entry.
    self.assertEqual(len(socket_map), 1)
    server.AddHandshakeCompletedConnection()
    self.assertEqual(len(socket_map), 2)
    server.close()
    self.assertEqual(len(socket_map), 0)

  def testMakeNotification(self):
    notification = self.xmpp_server.MakeNotification('channel', 'data')
    template = (
      '<message>'
      '  <push channel="channel" xmlns="google:push">'
      '    <data>%s</data>'
      '  </push>'
      '</message>')
    expected_xml = template % base64.b64encode('data')
    self.assertEqual(notification.toxml(), expected_xml)

  def testSendNotification(self):
    # Add a few connections.
    self._AddConnections(7)

    self.AssertSentDataLength(0)
    self.xmpp_server.SendNotification('channel', 'data')
    self.AssertSentDataLength(1)

  def testEnableDisableNotifications(self):
    server = self.xmpp_server

    # Add a few connections.
    self._AddConnections(5)

    self.AssertSentDataLength(0)
    server.SendNotification('channel', 'data')
    self.AssertSentDataLength(1)

    # Each row is a toggle to apply before sending one more notification
    # and the per-socket send count expected afterwards.  Repeating a
    # toggle is deliberate.
    steps = (
      (server.EnableNotifications, 2),
      (server.DisableNotifications, 2),
      (server.DisableNotifications, 2),
      (server.EnableNotifications, 3),
    )
    for toggle, expected_length in steps:
      toggle()
      server.SendNotification('channel', 'data')
      self.AssertSentDataLength(expected_length)


# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
  unittest.main()