#!/usr/bin/python # # Copyright 2015 The Android Open Source Project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import from errno import * # pylint: disable=wildcard-import import os import random import re from socket import * # pylint: disable=wildcard-import import threading import time import unittest import multinetwork_base import net_test import packets import sock_diag import tcp_test NUM_SOCKETS = 30 NO_BYTECODE = "" class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest): @staticmethod def _CreateLotsOfSockets(): # Dict mapping (addr, sport, dport) tuples to socketpairs. socketpairs = {} for _ in xrange(NUM_SOCKETS): family, addr = random.choice([ (AF_INET, "127.0.0.1"), (AF_INET6, "::1"), (AF_INET6, "::ffff:127.0.0.1")]) socketpair = net_test.CreateSocketPair(family, SOCK_STREAM, addr) sport, dport = (socketpair[0].getsockname()[1], socketpair[1].getsockname()[1]) socketpairs[(addr, sport, dport)] = socketpair return socketpairs def assertSocketClosed(self, sock): self.assertRaisesErrno(ENOTCONN, sock.getpeername) def assertSocketConnected(self, sock): sock.getpeername() # No errors? Socket is alive and connected. def assertSocketsClosed(self, socketpair): for sock in socketpair: self.assertSocketClosed(sock) def setUp(self): super(SockDiagBaseTest, self).setUp() self.sock_diag = sock_diag.SockDiag() self.socketpairs = {} def tearDown(self): for socketpair in self.socketpairs.values(): for s in socketpair: s.close() super(SockDiagBaseTest, self).tearDown() class SockDiagTest(SockDiagBaseTest): def assertSockDiagMatchesSocket(self, s, diag_msg): family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN) self.assertEqual(diag_msg.family, family) src, sport = s.getsockname()[0:2] self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src)) self.assertEqual(diag_msg.id.sport, sport) if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]: dst, dport = s.getpeername()[0:2] self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst)) self.assertEqual(diag_msg.id.dport, dport) else: self.assertRaisesErrno(ENOTCONN, s.getpeername) def testFindsMappedSockets(self): """Tests that inet_diag_find_one_icsk can find mapped sockets. Relevant kernel commits: android-3.10: f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk() """ socketpair = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::ffff:127.0.0.1") for sock in socketpair: diag_msg = self.sock_diag.FindSockDiagFromFd(sock) diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP) self.sock_diag.GetSockDiag(diag_req) # No errors? Good. def testFindsAllMySockets(self): """Tests that basic socket dumping works. Relevant commits: android-3.4: ab4a727 net: inet_diag: zero out uninitialized idiag_{src,dst} fields android-3.10 3eb409b net: inet_diag: zero out uninitialized idiag_{src,dst} fields """ self.socketpairs = self._CreateLotsOfSockets() sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE) self.assertGreaterEqual(len(sockets), NUM_SOCKETS) # Find the cookies for all of our sockets. cookies = {} for diag_msg, unused_attrs in sockets: addr = self.sock_diag.GetSourceAddress(diag_msg) sport = diag_msg.id.sport dport = diag_msg.id.dport if (addr, sport, dport) in self.socketpairs: cookies[(addr, sport, dport)] = diag_msg.id.cookie elif (addr, dport, sport) in self.socketpairs: cookies[(addr, sport, dport)] = diag_msg.id.cookie # Did we find all the cookies? self.assertEquals(2 * NUM_SOCKETS, len(cookies)) socketpairs = self.socketpairs.values() random.shuffle(socketpairs) for socketpair in socketpairs: for sock in socketpair: # Check that we can find a diag_msg by scanning a dump. self.assertSockDiagMatchesSocket( sock, self.sock_diag.FindSockDiagFromFd(sock)) cookie = self.sock_diag.FindSockDiagFromFd(sock).id.cookie # Check that we can find a diag_msg once we know the cookie. req = self.sock_diag.DiagReqFromSocket(sock) req.id.cookie = cookie diag_msg = self.sock_diag.GetSockDiag(req) req.states = 1 << diag_msg.state self.assertSockDiagMatchesSocket(sock, diag_msg) def testBytecodeCompilation(self): # pylint: disable=bad-whitespace instructions = [ (sock_diag.INET_DIAG_BC_S_GE, 1, 8, 0), # 0 (sock_diag.INET_DIAG_BC_D_LE, 1, 7, 0xffff), # 8 (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)), # 16 (sock_diag.INET_DIAG_BC_JMP, 1, 3, None), # 44 (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)), # 48 (sock_diag.INET_DIAG_BC_D_LE, 1, 3, 0x6665), # not used # 64 (sock_diag.INET_DIAG_BC_NOP, 1, 1, None), # 72 # 76 acc # 80 rej ] # pylint: enable=bad-whitespace bytecode = self.sock_diag.PackBytecode(instructions) expected = ( "0208500000000000" "050848000000ffff" "071c20000a800000ffffffff00000000000000000000000000000001" "01041c00" "0718200002200000ffffffff7f000001" "0508100000006566" "00040400" ) self.assertMultiLineEqual(expected, bytecode.encode("hex")) self.assertEquals(76, len(bytecode)) self.socketpairs = self._CreateLotsOfSockets() filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode) allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE) self.assertItemsEqual(allsockets, filteredsockets) # Pick a few sockets in hash table order, and check that the bytecode we # compiled selects them properly. for socketpair in self.socketpairs.values()[:20]: for s in socketpair: diag_msg = self.sock_diag.FindSockDiagFromFd(s) instructions = [ (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport), (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport), (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport), (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport), ] bytecode = self.sock_diag.PackBytecode(instructions) self.assertEquals(32, len(bytecode)) sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode) self.assertEquals(1, len(sockets)) # TODO: why doesn't comparing the cstructs work? self.assertEquals(diag_msg.Pack(), sockets[0][0].Pack()) def testCrossFamilyBytecode(self): """Checks for a cross-family bug in inet_diag_hostcond matching. Relevant kernel commits: android-3.4: f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run() """ # TODO: this is only here because the test fails if there are any open # sockets other than the ones it creates itself. Make the bytecode more # specific and remove it. self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, "")) unused_pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1") unused_pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1") bytecode4 = self.sock_diag.PackBytecode([ (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))]) bytecode6 = self.sock_diag.PackBytecode([ (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))]) # IPv4/v6 filters must never match IPv6/IPv4 sockets... v4sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4) self.assertTrue(v4sockets) self.assertTrue(all(d.family == AF_INET for d, _ in v4sockets)) v6sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6) self.assertTrue(v6sockets) self.assertTrue(all(d.family == AF_INET6 for d, _ in v6sockets)) # Except for mapped addresses, which match both IPv4 and IPv6. pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::ffff:127.0.0.1") diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5] v4sockets = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4)] v6sockets = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6)] self.assertTrue(all(d in v4sockets for d in diag_msgs)) self.assertTrue(all(d in v6sockets for d in diag_msgs)) def testPortComparisonValidation(self): """Checks for a bug in validating port comparison bytecode. Relevant kernel commits: android-3.4: 5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads """ bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8)) self.assertRaisesErrno( EINVAL, self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack()) def testNonSockDiagCommand(self): def DiagDump(code): sock_id = self.sock_diag._EmptyInetDiagSockId() req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff, sock_id)) self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg, "") op = sock_diag.SOCK_DIAG_BY_FAMILY DiagDump(op) # No errors? Good. self.assertRaisesErrno(EINVAL, DiagDump, op + 17) class SockDestroyTest(SockDiagBaseTest): """Tests that SOCK_DESTROY works correctly. Relevant kernel commits: net-next: b613f56 net: diag: split inet_diag_dump_one_icsk into two 64be0ae net: diag: Add the ability to destroy a socket. 6eb5d2e net: diag: Support SOCK_DESTROY for inet sockets. c1e64e2 net: diag: Support destroying TCP sockets. 2010b93 net: tcp: deal with listen sockets properly in tcp_abort. android-3.4: d48ec88 net: diag: split inet_diag_dump_one_icsk into two 2438189 net: diag: Add the ability to destroy a socket. 7a2ddbc net: diag: Support SOCK_DESTROY for inet sockets. 44047b2 net: diag: Support destroying TCP sockets. 200dae7 net: tcp: deal with listen sockets properly in tcp_abort. android-3.10: 9eaff90 net: diag: split inet_diag_dump_one_icsk into two d60326c net: diag: Add the ability to destroy a socket. 3d4ce85 net: diag: Support SOCK_DESTROY for inet sockets. 529dfc6 net: diag: Support destroying TCP sockets. 9c712fe net: tcp: deal with listen sockets properly in tcp_abort. android-3.18: 100263d net: diag: split inet_diag_dump_one_icsk into two 194c5f3 net: diag: Add the ability to destroy a socket. 8387ea2 net: diag: Support SOCK_DESTROY for inet sockets. b80585a net: diag: Support destroying TCP sockets. 476c6ce net: tcp: deal with listen sockets properly in tcp_abort. """ def testClosesSockets(self): self.socketpairs = self._CreateLotsOfSockets() for _, socketpair in self.socketpairs.iteritems(): # Close one of the sockets. # This will send a RST that will close the other side as well. s = random.choice(socketpair) if random.randrange(0, 2) == 1: self.sock_diag.CloseSocketFromFd(s) else: diag_msg = self.sock_diag.FindSockDiagFromFd(s) # Get the cookie wrong and ensure that we get an error and the socket # is not closed. real_cookie = diag_msg.id.cookie diag_msg.id.cookie = os.urandom(len(real_cookie)) req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP) self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req) self.assertSocketConnected(s) # Now close it with the correct cookie. req.id.cookie = real_cookie self.sock_diag.CloseSocket(req) # Check that both sockets in the pair are closed. self.assertSocketsClosed(socketpair) def testNonTcpSockets(self): s = socket(AF_INET6, SOCK_DGRAM, 0) s.connect(("::1", 53)) self.sock_diag.FindSockDiagFromFd(s) # No exceptions? Good. self.assertRaisesErrno(EOPNOTSUPP, self.sock_diag.CloseSocketFromFd, s) # TODO: # Test that killing unix sockets returns EOPNOTSUPP. class SocketExceptionThread(threading.Thread): def __init__(self, sock, operation): self.exception = None super(SocketExceptionThread, self).__init__() self.daemon = True self.sock = sock self.operation = operation def run(self): try: self.operation(self.sock) except IOError, e: self.exception = e class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest): def testIpv4MappedSynRecvSocket(self): """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets. Relevant kernel commits: android-3.4: 457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state """ netid = random.choice(self.tuns.keys()) self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid) sock_id = self.sock_diag._EmptyInetDiagSockId() sock_id.sport = self.port states = 1 << tcp_test.TCP_SYN_RECV req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id)) children = self.sock_diag.Dump(req, NO_BYTECODE) self.assertTrue(children) for child, unused_args in children: self.assertEqual(tcp_test.TCP_SYN_RECV, child.state) self.assertEqual(self.sock_diag.PaddedAddress(self.remoteaddr), child.id.dst) self.assertEqual(self.sock_diag.PaddedAddress(self.myaddr), child.id.src) class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest): def setUp(self): super(SockDestroyTcpTest, self).setUp() self.netid = random.choice(self.tuns.keys()) def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True): """Closes the socket and checks whether a RST is sent or not.""" if sock is not None: self.assertIsNone(req, "Must specify sock or req, not both") self.sock_diag.CloseSocketFromFd(sock) self.assertRaisesErrno(EINVAL, sock.accept) else: self.assertIsNone(sock, "Must specify sock or req, not both") self.sock_diag.CloseSocket(req) if expect_reset: desc, rst = self.RstPacket() msg = "%s: expecting %s: " % (msg, desc) self.ExpectPacketOn(self.netid, msg, rst) else: msg = "%s: " % msg self.ExpectNoPacketsOn(self.netid, msg) if sock is not None and do_close: sock.close() def CheckTcpReset(self, state, statename): for version in [4, 5, 6]: msg = "Closing incoming IPv%d %s socket" % (version, statename) self.IncomingConnection(version, state, self.netid) self.CheckRstOnClose(self.s, None, False, msg) if state != tcp_test.TCP_LISTEN: msg = "Closing accepted IPv%d %s socket" % (version, statename) self.CheckRstOnClose(self.accepted, None, True, msg) def testTcpResets(self): """Checks that closing sockets in appropriate states sends a RST.""" self.CheckTcpReset(tcp_test.TCP_LISTEN, "TCP_LISTEN") self.CheckTcpReset(tcp_test.TCP_ESTABLISHED, "TCP_ESTABLISHED") self.CheckTcpReset(tcp_test.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT") def FindChildSockets(self, s): """Finds the SYN_RECV child sockets of a given listening socket.""" d = self.sock_diag.FindSockDiagFromFd(self.s) req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED req.id.cookie = "\x00" * 8 children = self.sock_diag.Dump(req, NO_BYTECODE) return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) for d, _ in children] def CheckChildSocket(self, version, statename, parent_first): state = getattr(tcp_test, statename) self.IncomingConnection(version, state, self.netid) d = self.sock_diag.FindSockDiagFromFd(self.s) parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) children = self.FindChildSockets(self.s) self.assertEquals(1, len(children)) is_established = (state == tcp_test.TCP_NOT_YET_ACCEPTED) # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the # regular TCP hash tables, and inet_diag_find_one_icsk can find them. # Before 4.4, we can see those sockets in dumps, but we can't fetch # or close them. can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4) for child in children: if can_close_children: self.sock_diag.GetSockDiag(child) # No errors? Good, child found. else: self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child) def CloseParent(expect_reset): msg = "Closing parent IPv%d %s socket %s child" % ( version, statename, "before" if parent_first else "after") self.CheckRstOnClose(self.s, None, expect_reset, msg) self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, parent) def CheckChildrenClosed(): for child in children: self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child) def CloseChildren(): for child in children: msg = "Closing child IPv%d %s socket %s parent" % ( version, statename, "after" if parent_first else "before") self.sock_diag.GetSockDiag(child) self.CheckRstOnClose(None, child, is_established, msg) self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child) CheckChildrenClosed() if parent_first: # Closing the parent will close child sockets, which will send a RST, # iff they are already established. CloseParent(is_established) if is_established: CheckChildrenClosed() elif can_close_children: CloseChildren() CheckChildrenClosed() self.s.close() else: if can_close_children: CloseChildren() CloseParent(False) self.s.close() def testChildSockets(self): for version in [4, 5, 6]: self.CheckChildSocket(version, "TCP_SYN_RECV", False) self.CheckChildSocket(version, "TCP_SYN_RECV", True) self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", False) self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", True) def CloseDuringBlockingCall(self, sock, call, expected_errno): thread = SocketExceptionThread(sock, call) thread.start() time.sleep(0.1) self.sock_diag.CloseSocketFromFd(sock) thread.join(1) self.assertFalse(thread.is_alive()) self.assertIsNotNone(thread.exception) self.assertTrue(isinstance(thread.exception, IOError), "Expected IOError, got %s" % thread.exception) self.assertEqual(expected_errno, thread.exception.errno) self.assertSocketClosed(sock) def testAcceptInterrupted(self): """Tests that accept() is interrupted by SOCK_DESTROY.""" for version in [4, 5, 6]: self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid) self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL) self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo") self.assertRaisesErrno(EINVAL, self.s.accept) def testReadInterrupted(self): """Tests that read() is interrupted by SOCK_DESTROY.""" for version in [4, 5, 6]: self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096), ECONNABORTED) self.assertRaisesErrno(EPIPE, self.accepted.send, "foo") def testConnectInterrupted(self): """Tests that connect() is interrupted by SOCK_DESTROY.""" for version in [4, 5, 6]: family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version] s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP) self.SelectInterface(s, self.netid, "mark") if version == 5: remoteaddr = "::ffff:" + self.GetRemoteAddress(4) version = 4 else: remoteaddr = self.GetRemoteAddress(version) s.bind(("", 0)) _, sport = s.getsockname()[:2] self.CloseDuringBlockingCall( s, lambda sock: sock.connect((remoteaddr, 53)), ECONNABORTED) desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid), remoteaddr, sport=sport, seq=None) self.ExpectPacketOn(self.netid, desc, syn) msg = "SOCK_DESTROY of socket in connect, expected no RST" self.ExpectNoPacketsOn(self.netid, msg) if __name__ == "__main__": unittest.main()