# Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
import dpkt
import os
import select
import struct
import sys
import threading
import time
import traceback
class SimulatorError(Exception):
"A Simulator generic error."
class NullContext(object):
"""A context manager without any functionality."""
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
return False # raises the exception if passed.
class Simulator(object):
"""A TUN/TAP network interface simulator class.
This class allows several implementations of different fake hosts to
coexists on the same TUN/TAP interface. It will dispatch the same packet
to each one of the registered hosts, providing some basic filtering
to simplify these implementations.
"""
def __init__(self, iface):
"""Initialize the instance.
@param tuntap.TunTap iface: the interface over which this interface
runs. Should not be shared with other modules.
"""
self._iface = iface
self._rules = []
# _events holds a lists of events that need to be fired for each
# timestamp stored on the key. The event list is a list of callback
# functions that will be called if the simulation reaches that
# timestamp. This is used to fire time-based events.
self._events = {}
self._write_queue = []
# A pipe used to wake up the run() method from a diffent thread calling
# stop(). See the stop() method for details.
self._pipe_rd, self._pipe_wr = os.pipe()
self._running = False
# Lock object used for _events if multithreading is required.
self._lock = NullContext()
def __del__(self):
os.close(self._pipe_rd)
os.close(self._pipe_wr)
def add_match(self, rule, callback):
"""Add a new match rule to the outbound traffic.
This function adds a new rule that will be matched against each packet
that the host sends through the interface and will call a callback if
it matches. The rule can be specified in the following ways:
* A python function that takes a packet as a single argument and
returns True when the packet matches.
* A dictionary of key=value pairs that all of them need to be matched.
A pair matches when the packet has the provided chain of attributes
and its value is equal to the provided value. For example, this will
match any DNS traffic sent to the host 192.168.0.1:
{"ip.dst": socket.inet_aton("192.168.0.1"),
"ip.upd.dport": 53}
@param rule: The rule description.
@param callback: A callback function that receives the dpkt packet as
the only argument.
"""
if not callable(callback):
raise SimulatorError("|callback| must be a callable object.")
if callable(rule):
self._rules.append((rule, callback))
if isinstance(rule, dict):
rule = dict(rule) # Makes a copy of the dict, but not the contents.
self._rules.append((lambda p: self._dict_rule(rule, p), callback))
else:
raise SimulatorError("Unknown rule format: %r" % rule)
def add_timeout(self, timeout, callback):
"""Add a new callback function to be called after a timeout.
This method schedules the given |callback| to be called after |timeout|
seconds. The callback will be called at most once while the simulator
is running (see the run() method). To have a repetitive event call again
add_timeout() from the callback.
@param timeout: The rule description.
@param callback: A callback function that doesn't receive any argument.
"""
if not callable(callback):
raise SimulatorError("|callback| must be a callable object.")
timestamp = time.time() + timeout
with self._lock:
if timestamp not in self._events:
self._events[timestamp] = [callback]
else:
self._events[timestamp].append(callback)
def remove_timeout(self, callback):
"""Removes the every scheduled timeout call to the passed callback.
When a callable object is passed to add_timeout() it is scheduled to be
called once the timeout is reached. This method removes all the
scheduled calls to that object.
@param callback: The callable object passed to add_timeout().
@return: Wether the callback was found and removed at least once.
"""
removed = False
for _ts, ev_list in self._events.iteritems():
try:
while True:
ev_list.remove(callback)
removed = True
except ValueError:
pass
return removed
def _dict_rule(self, rules, pkt):
"""Returns wether a given packet matches a set of rules.
The maching rules passed in |rules| need to be a dict() as described
on the add_match() method. The packet |pkt| is any dpkt packet.
"""
for key, value in rules.iteritems():
p = pkt
for member in key.split('.'):
if not hasattr(p, member):
return False
p = getattr(p, member)
if p != value:
return False
return True
def write(self, pkt):
"""Writes a packet to the network interface.
@param pkt: The dpkt.Packet to be received on the network interface.
"""
# Converts the dpkt packet to: flags, proto, buffer.
self._write_queue.append(struct.pack("!HH", 0, pkt.type) + str(pkt))
def run(self, timeout=None, until=None):
"""Runs the Simulator.
This method blocks the caller thread until the timeout is reached (if
a timeout is passed), until stop() is called or until the function
passed in until returns a True value (if a function is passed);
whichever occurs first. stop() can be called from any other thread or
from a callback called from this thread.
@param timeout: The timeout in seconds. Can be a float value, or None
for no timeout.
@param until: A callable object called during the loop returning True
when the loop should stop.
"""
if not self._iface.is_up():
raise SimulatorError("Interface is down.")
stop_callback = None
if timeout != None:
# We use a newly created callable object to avoid remove another
# scheduled call to self.stop.
stop_callback = lambda: self.stop()
self.add_timeout(timeout, stop_callback)
self._running = True
iface_fd = self._iface.fileno()
# Check the until function.
while not (until and until()):
# The main purpose of this loop is to wait (block) until the next
# event is required to be fired. There are four kinds of events:
# * a packet is received.
# * a packet waiting to be sent can now be sent.
# * a time-based event needs to be fired.
# * the simulator was stopped from a different thread.
# To achieve this we use select.select() to wait simultaneously on
# all those event sources.
# Fires all the time-based events that need to be fired and computes
# the timeout for the next event if there's one.
timeout = None
cur_time = time.time()
with self._lock:
if self._events:
# Check events that should be fired.
while self._events and min(self._events) <= cur_time:
key = min(self._events)
lst = self._events[key]
del self._events[key]
for callback in lst:
callback()
cur_time = time.time()
# Check if there is an event to attend. Here we know that
# min(self._events) > cur_time because the previous while
# finished.
if self._events:
timeout = min(self._events) - cur_time # in seconds
# Pool the until() function at least once a second.
if timeout is None or timeout > 1.0:
timeout = 1.0
# Compute the list of file descriptors that select.select() needs to
# monitor to attend the required events. select() will return when
# any of the following occurs:
# * rlist: is possible to read from the interface or another
# thread want's to wake up the simulator loop.
# * wlist: is possible to write to network, if there's a packet
# pending.
# * xlist: an error on the network fd occured. Likely the TAP
# interface was closed.
# * timeout: The previously computed timeout was reached.
rlist = iface_fd, self._pipe_rd
wlist = tuple()
if self._write_queue:
wlist = iface_fd,
xlist = iface_fd,
rlist, wlist, xlist = select.select(rlist, wlist, xlist, timeout)
if self._pipe_rd in rlist:
msg = os.read(self._pipe_rd, 1)
# stop() breaks the loop sending a '*'.
if '*' in msg:
break
# Other messages are ignored.
if xlist:
break
if iface_fd in wlist:
self._iface.write(self._write_queue.pop(0))
# Attempt to send all the scheduled packets before reading more
continue
# Process the given packet:
if iface_fd in rlist:
raw = self._iface.read()
flag, proto = struct.unpack("!HH", raw[:4])
pkt = dpkt.ethernet.Ethernet(raw[4:])
for rule, callback in self._rules:
if rule(pkt):
# Parse again the packet to allow callbacks to modify
# it.
callback(dpkt.ethernet.Ethernet(raw[4:]))
if stop_callback:
self.remove_timeout(stop_callback)
self._running = False
def stop(self):
"""Stops the run() method if it is running."""
os.write(self._pipe_wr, '*')
class SimulatorThread(threading.Thread, Simulator):
"""A threaded version of the Simulator.
This class exposses a similar interface as the Simulator class with the
difference that it runs on its own thread. This exposes an extra method
start() that should be called instead of Simulator.run(). start() will make
the process run continuosly until stop() is called, after which the
simulator can't be restarted.
The methods used to add new matches can be called from any thread *before*
the method start() is caller. After that point, only the callbacks, running
from this thread, are allowed to create new matches and timeouts.
Example:
simu = SimulatorThread(tap_interface)
simu.add_match({"ip.tcp.dport": 80}, some_callback)
simu.start()
time.sleep(100)
simu.stop()
simu.join() # Optional
"""
def __init__(self, iface, timeout=None):
threading.Thread.__init__(self)
Simulator.__init__(self, iface)
self._timeout = timeout
# We allow the same thread to acquire the lock more than once. This is
# useful if a callback want's to add itself.
self._lock = threading.RLock()
self.error = None
def run_on_simulator(self, callback):
"""Runs the given callback on the SimulatorThread thread.
Before calling start() on the SimulatorThread, all the calls seting up
the simulator are allowed, but once the thread is running, concurrency
problems should be considered. This method runs the provided callback
on the simulator.
@param callback: A callback function without arguments.
"""
self.add_timeout(0, callback)
# Wake up the main loop with an ignored message.
os.write(self._pipe_wr, ' ')
def wait_for_condition(self, condition, timeout=None):
"""Blocks until the condition is met or timeout is exceeded.
This method should be called from a different thread while the simulator
thread is running as it blocks the calling thread's execution until a
condition is met. The condition function is evaluated in a callback
running on the simulator thread and thus can safely access objects owned
by the simulator.
@param condition: A function called on the simulator thread that returns
a value indicating if the condition is met.
@param timeout: The timeout in seconds. None for no timeout.
@return: The value returned by condition the last time it was called.
This means that in the event of a timeout, this function will return a
value that evaluates to False since the condition wasn't met the last
time it was checked.
"""
# Lock and Condition used to wait until the passed condition is met.
lock_cond = threading.Lock()
cond_var = threading.Condition(lock_cond)
# We use a mutable object like the [] to pass the reference by value
# to the simulator's callback and let it modify the contents.
ret = [None]
# Create the actual callback that will be running on the simulator
# thread and pass a reference to it to keep including it
callback = lambda: self._condition_poller(
callback, ret, cond_var, condition)
# Let the simulator keep calling our function, it will keep calling
# itself until the condition is met (or we remove it).
self.run_on_simulator(callback)
# Condition variable waiting loop.
cur_time = time.time()
start_time = cur_time
with cond_var:
while not ret[0]:
if timeout is None:
cond_var.wait()
else:
cur_timeout = timeout - (cur_time - start_time)
if cur_timeout < 0:
break
cond_var.wait(cur_timeout)
cur_time = time.time()
self.remove_timeout(callback)
return ret[0]
def _condition_poller(self, callback, ref_value, cond_var, func):
"""Callback function used to poll for a condition.
This method keeps scheduling itself in the simulator until the passed
condition evaluates to a True value. This effectivelly implements a
polling mechanism. See wait_for_condition() for details.
"""
with cond_var:
ref_value[0] = func()
if ref_value[0]:
cond_var.notify()
else:
self.add_timeout(1., callback)
def run(self):
"""Runs the simulation on the thread, called by start().
This method wraps the Simulator.run() to pass the timeout value passed
during construction.
"""
try:
Simulator.run(self, self._timeout)
except Exception, e:
self.error = e
exc_type, exc_value, exc_traceback = sys.exc_info()
self.traceback = ''.join(traceback.format_exception(
exc_type, exc_value, exc_traceback))