#!/usr/bin/env python
#
# This script generates a BPF program with structure inspired by trace.py. The
# generated program operates on PID-indexed stacks. Generally speaking,
# bookkeeping is done at every intermediate function kprobe/kretprobe to enforce
# the goal of "fail iff this call chain and these predicates".
#
# Top-level functions (the ones at the end of the call chain) are responsible
# for creating the pid_struct and deleting it from the map in kprobe and
# kretprobe respectively.
#
# Intermediate functions (between should_fail_whatever and the top-level
# functions) are responsible for updating the stack to indicate "I have been
# called and one of my predicate(s) passed" in their entry probes. In their exit
# probes, they do the opposite, popping their stack to maintain correctness.
# This implementation aims to ensure correctness in edge cases like recursive
# calls, so there's some additional information stored in pid_struct for that.
#
# At the bottom-level function (should_fail_whatever), we do a simple check to
# ensure all necessary calls/predicates have passed before error injection.
#
# Note: presently there are a few hacks to get around various rewriter/verifier
# issues.
#
# Note: this tool requires:
# - CONFIG_BPF_KPROBE_OVERRIDE
#
# USAGE: inject [-h] [-I header] [-P probability] [-v] mode spec
#
# Copyright (c) 2018 Facebook, Inc.
# Licensed under the Apache License, Version 2.0 (the "License")
#
# 16-Mar-2018 Howard McLauchlan Created this.
import argparse
import re

from bcc import BPF


class Probe:
    """One kprobe or kretprobe of the generated error-injection program.

    Every function named in the call-chain spec gets an entry probe; every
    function except the error-injection function itself also gets an exit
    probe. Injection mode and failure probability are shared by all probes
    and are set once through configure().
    """

    # Error code passed to bpf_override_return() for each injection mode.
    errno_mapping = {
        "kmalloc": "-ENOMEM",
        "bio": "-EIO",
    }

    @classmethod
    def configure(cls, mode, probability):
        """Set the injection mode and failure probability for all probes."""
        cls.mode = mode
        cls.probability = probability

    def __init__(self, func, preds, length, entry):
        """Create a probe.

        func   -- function text from the spec, e.g. "foo(int x)"
        preds  -- list of (predicate, place) tuples for this function, where
                  place is the function's position in the call chain
        length -- number of functions in the whole call chain
        entry  -- True for the kprobe, False for the kretprobe
        """
        # length of call chain
        self.length = length
        self.func = func
        # Take a private copy: the entry and exit probes of one function are
        # built from the same list, and _prepare_pred() rewrites predicates.
        self.preds = list(preds)
        self.is_entry = entry

    def _bail(self, err):
        """Raise ValueError describing err for this probe."""
        raise ValueError("error in probe '%s': %s" % (self.func, err))

    def _get_err(self):
        """Return the C error code to inject for the configured mode."""
        return Probe.errno_mapping[Probe.mode]

    @staticmethod
    def _get_early_pred():
        """Return a C condition that is true when this attempt must be skipped.

        Used to implement the -P option: with probability 1 nothing is
        skipped; otherwise a random u32 above probability * 2**32 skips.
        """
        if Probe.probability == 1:
            return "false"
        return "bpf_get_prandom_u32() > %s" % str(
            int((1 << 32) * Probe.probability))

    def _get_if_top(self):
        """Return map setup/teardown C code if this is the top-level function.

        The entry probe applies the probability check and inserts a fresh
        pid_struct; the exit probe deletes it. Non-top functions get "".
        """
        # preds are appended in increasing place order, so if this function
        # is top (place 0) it is the first tuple
        if self.preds[0][1] != 0:
            return ""

        # init the map
        # dont do an early exit here so the singular case works automatically
        # have an early exit for probability option
        enter = """
        /*
         * Early exit for probability case
         */
        if (%s)
                return 0;
        /*
         * Top level function init map
         */
        struct pid_struct p_struct = {0, 0};
        m.insert(&pid, &p_struct);
""" % self._get_early_pred()

        # kill the entry
        exit_text = """
        /*
         * Top level function clean up map
         */
        m.delete(&pid);
"""

        return enter if self.is_entry else exit_text

    def _get_heading(self):
        """Return the C function signature for this probe.

        Also sets self.event (kernel function name) and self.func_name
        (generated BPF function name) used later by attach().
        """
        # we need to insert identifier and ctx into self.func
        # gonna make a lot of formatting assumptions to make this work
        left = self.func.find("(")
        right = self.func.rfind(")")

        # self.event and self.func_name need to be accessible
        self.event = self.func[0:left]
        self.func_name = self.event + ("_entry" if self.is_entry else "_exit")
        func_sig = "struct pt_regs *ctx"

        # assume theres something in there, no guarantee its well formed
        if right > left + 1 and self.is_entry:
            func_sig += ", " + self.func[left + 1:right]

        return "int %s(%s)" % (self.func_name, func_sig)

    def _get_entry_logic(self):
        """Return C code that pushes onto the stack when a predicate passes.

        Each (pred, place) tuple becomes a branch that fires only when
        exactly `place` earlier conditions have been met.
        """
        # there is at least one tup(pred, place) for this function
        text = """

        if (p->conds_met >= %s)
                return 0;
        if (p->conds_met == %s && %s) {
                p->stack[%s] = p->curr_call;
                p->conds_met++;
        }"""
        text = text % (self.length, self.preds[0][1], self.preds[0][0],
                       self.preds[0][1])

        # for each additional pred
        for pred, place in self.preds[1:]:
            text += """
        else if (p->conds_met == %s && %s) {
                p->stack[%s] = p->curr_call;
                p->conds_met++;
        }
            """ % (place, pred, place)
        return text

    def _generate_entry(self):
        """Return the complete C source of this function's kprobe."""
        prog = self._get_heading() + """
{
        u32 pid = bpf_get_current_pid_tgid();
        %s

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        /*
         * preparation for predicate, if necessary
         */
         %s
        /*
         * Generate entry logic
         */
        %s

        p->curr_call++;

        return 0;
}"""

        prog = prog % (self._get_if_top(), self.prep, self._get_entry_logic())
        return prog

    # only need to check top of stack
    def _get_exit_logic(self):
        """Return C code that pops the stack if this call pushed onto it."""
        text = """
        if (p->conds_met < 1 || p->conds_met >= %s)
                return 0;

        if (p->stack[p->conds_met - 1] == p->curr_call)
                p->conds_met--;
        """
        return text % str(self.length + 1)

    def _generate_exit(self):
        """Return the complete C source of this function's kretprobe."""
        prog = self._get_heading() + """
{
        u32 pid = bpf_get_current_pid_tgid();

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        p->curr_call--;

        /*
         * Generate exit logic
         */
        %s
        %s
        return 0;
}"""

        prog = prog % (self._get_exit_logic(), self._get_if_top())

        return prog

    # Special case for should_fail_whatever
    def _generate_bottom(self):
        """Return the C source of the error-injection (bottom) probe.

        It overrides the return value when either the chain consists of this
        function alone, or all earlier conditions have been met; in both
        cases this function's own predicate must also pass.
        """
        pred = self.preds[0][0]
        text = self._get_heading() + """
{
        /*
         * preparation for predicate, if necessary
         */
         %s
        /*
         * If this is the only call in the chain and predicate passes
         */
        if (%s == 1 && %s) {
                /*
                 * No top-level probe exists to apply the probability check
                 * in a single-call chain, so apply it here
                 */
                if (%s)
                        return 0;
                bpf_override_return(ctx, %s);
                return 0;
        }
        u32 pid = bpf_get_current_pid_tgid();

        struct pid_struct *p = m.lookup(&pid);

        if (!p)
                return 0;

        /*
         * If all conds have been met and predicate passes
         */
        if (p->conds_met == %s && %s)
                bpf_override_return(ctx, %s);
        return 0;
}"""
        return text % (self.prep, self.length, pred, self._get_early_pred(),
                       self._get_err(), self.length - 1, pred, self._get_err())

    @staticmethod
    def _escape_char(ch):
        """Escape ch so it can appear inside a C character literal."""
        return "\\" + ch if ch in "\\'" else ch

    # presently parses and replaces STRCMP
    # STRCMP exists because string comparison is inconvenient and somewhat buggy
    # https://github.com/iovisor/bcc/issues/1617
    def _prepare_pred(self):
        """Rewrite STRCMP(ptr, literal) calls in every predicate.

        Each call is replaced by a boolean variable; the C statements that
        compute it character by character are accumulated in self.prep.
        self.preds is replaced with the rewritten predicates.

        Raises ValueError if a STRCMP call is unterminated or does not have
        two arguments.
        """
        self.prep = ""
        prepared = []
        for pred, place in self.preds:
            new_pred = ""
            start = 0
            while start < len(pred):
                ind = pred.find("STRCMP(", start)
                if ind == -1:
                    break
                new_pred += pred[start:ind]
                # 7 is len("STRCMP("); the call ends at the first ')' after it
                close = pred.find(")", ind + 7)
                if close == -1:
                    self._bail("unterminated STRCMP in predicate %s" % pred)
                start = close + 1
                # ind ... start is STRCMP(...); split on the first comma only
                # so the literal may itself contain commas
                args = pred[ind + 7:close].split(",", 1)
                if len(args) != 2:
                    self._bail("STRCMP expects two arguments in %s" % pred)
                ptr, literal = args[0].strip(), args[1].strip()

                # x->y->z, some string literal
                # we make unique id with place_ind
                uuid = "%s_%s" % (place, ind)
                unique_bool = "is_true_%s" % uuid
                self.prep += """
        char *str_%s = %s;
        bool %s = true;\n""" % (uuid, ptr, unique_bool)
                check = "\t%s &= *(str_%s++) == '%%s';\n" % (unique_bool, uuid)

                for ch in literal:
                    self.prep += check % self._escape_char(ch)
                self.prep += check % r'\0'
                new_pred += unique_bool

            new_pred += pred[start:]
            prepared.append((new_pred, place))
        self.preds = prepared

    def generate_program(self):
        """Return the C source for this probe."""
        # generate code to work around various rewriter issues
        self._prepare_pred()

        # special case for bottom
        if self.preds[-1][1] == self.length - 1:
            return self._generate_bottom()

        return self._generate_entry() if self.is_entry else self._generate_exit()

    def attach(self, bpf):
        """Attach the generated function to self.event on the BPF object."""
        if self.is_entry:
            bpf.attach_kprobe(event=self.event,
                              fn_name=self.func_name)
        else:
            bpf.attach_kretprobe(event=self.event,
                                 fn_name=self.func_name)


class Tool:
    """Command-line front end.

    Parses the call-chain spec, builds Probe objects, generates and loads
    the BPF program, attaches the probes and then waits.
    """

    examples = """
EXAMPLES:
# ./inject.py kmalloc -v 'SyS_mount()'
    Fails all calls to syscall mount
# ./inject.py kmalloc -v '(true) => SyS_mount()(true)'
    Explicit rewriting of above
# ./inject.py kmalloc -v 'mount_subtree() => btrfs_mount()'
    Fails btrfs mounts only
# ./inject.py kmalloc -v 'd_alloc_parallel(struct dentry *parent, const struct \\
    qstr *name)(STRCMP(name->name, 'bananas'))'
    Fails dentry allocations of files named 'bananas'
# ./inject.py kmalloc -v -P 0.01 'SyS_mount()'
    Fails calls to syscall mount with 1% probability
    """
    # add cases as necessary
    error_injection_mapping = {
        "kmalloc": "should_failslab(struct kmem_cache *s, gfp_t gfpflags)",
        "bio": "should_fail_bio(struct bio *bio)",
    }

    def __init__(self):
        """Parse and validate command-line arguments."""
        parser = argparse.ArgumentParser(description="Fail specified kernel" +
                " functionality when call chain and predicates are met",
                formatter_class=argparse.RawDescriptionHelpFormatter,
                epilog=Tool.examples)
        parser.add_argument(dest="mode", choices=['kmalloc', 'bio'],
                help="indicate which base kernel function to fail")
        parser.add_argument(metavar="spec", dest="spec",
                help="specify call chain")
        parser.add_argument("-I", "--include", action="append",
                metavar="header",
                help="additional header files to include in the BPF program")
        parser.add_argument("-P", "--probability", default=1,
                metavar="probability", type=float,
                help="probability that this call chain will fail")
        parser.add_argument("-v", "--verbose", action="store_true",
                help="print BPF program")
        self.args = parser.parse_args()
        # Reject out-of-range values (and NaN) instead of silently
        # generating a comparison that always or never fires.
        if not 0 <= self.args.probability <= 1:
            parser.error("probability must be between 0 and 1")

        self.program = ""
        self.spec = self.args.spec
        self.map = {}
        self.probes = []
        self.key = Tool.error_injection_mapping[self.args.mode]

    # create_probes and associated stuff
    def _create_probes(self):
        """Parse the spec and build an entry (and usually exit) probe per function."""
        self._parse_spec()
        Probe.configure(self.args.mode, self.args.probability)
        # self, func, preds, total, entry

        # create all the pair probes
        for fx, preds in self.map.items():

            # do the enter
            self.probes.append(Probe(fx, preds, self.length, True))

            if self.key == fx:
                continue

            # do the exit
            self.probes.append(Probe(fx, preds, self.length, False))

    def _parse_frames(self):
        """Split the spec into (function, predicate) frames, left to right.

        A frame is "func(args)", "func(args)(pred)" or "(pred)"; the last
        form applies the predicate to the error-injection function. Frames
        are separated by "=>".

        Raises ValueError on unbalanced parentheses, empty frames, frames
        with more than two parenthesized groups, or stray characters.
        """
        # sentinel
        data = self.spec + '\0'
        start, count = 0, 0

        frames = []
        cur_frame = []
        i = 0
        last_frame_added = 0

        while i < len(data):
            # improper input
            if count < 0:
                raise ValueError("Check your parentheses")
            c = data[i]
            count += c == '('
            count -= c == ')'
            if not count:
                if c == '\0' or (c == '=' and data[i + 1] == '>'):
                    # This block is closing a chunk. This means cur_frame must
                    # have something in it.
                    if not cur_frame:
                        raise ValueError("Cannot parse spec, missing parens")
                    # Text between the last group and "=>" would otherwise
                    # be silently discarded.
                    if c == '=' and data[start:i].strip():
                        raise ValueError(
                            "Invalid characters found between frames")
                    # A frame holds at most a function and a predicate;
                    # extra groups used to be dropped silently.
                    if len(cur_frame) > 2:
                        raise ValueError(
                            "Too many parenthesized groups in frame")
                    if len(cur_frame) == 2:
                        frame = tuple(cur_frame)
                    elif cur_frame[0][0] == '(':
                        frame = self.key, cur_frame[0]
                    else:
                        frame = cur_frame[0], '(true)'
                    frames.append(frame)
                    del cur_frame[:]
                    i += 1
                    start = i + 1
                elif c == ')':
                    cur_frame.append(data[start:i + 1].strip())
                    start = i + 1
                    last_frame_added = start
            i += 1

        # We only permit spaces after the last frame
        if self.spec[last_frame_added:].strip():
            raise ValueError("Invalid characters found after last frame")
        # improper input
        if count:
            raise ValueError("Check your parentheses")
        return frames

    def _parse_spec(self):
        """Fill self.map and set self.length from the spec.

        self.map maps each function to its list of (predicate, place)
        tuples. Frames are numbered right to left, so the right-most frame
        gets place 0; the error-injection function is appended last if the
        spec did not name it.

        Raises ValueError on an invalid predicate or function identifier.
        """
        frames = self._parse_frames()
        frames.reverse()

        absolute_order = 0
        for f in frames:
            # default case
            func, pred = f[0], f[1]

            if not self._validate_predicate(pred):
                raise ValueError("Invalid predicate")
            if not self._validate_identifier(func):
                raise ValueError("Invalid function identifier")
            tup = (pred, absolute_order)

            if func not in self.map:
                self.map[func] = [tup]
            else:
                self.map[func].append(tup)

            absolute_order += 1

        if self.key not in self.map:
            self.map[self.key] = [('(true)', absolute_order)]
            absolute_order += 1

        self.length = absolute_order

    def _validate_identifier(self, func):
        """Return True if the text before the first '(' is a C identifier."""
        # We've already established paren balancing. We will only look for
        # identifier validity here.
        paren_index = func.find("(")
        potential_id = func[:paren_index]
        # A-Z, not A-z: the latter range also admits [ \ ] ^ and backtick.
        pattern = r'[_a-zA-Z][_a-zA-Z0-9]*$'
        if re.match(pattern, potential_id):
            return True
        return False

    def _validate_predicate(self, pred):
        """Return False if pred starts with '(' but its parens are unbalanced."""
        if len(pred) > 0 and pred[0] == "(":
            depth = 1
            for ch in pred[1:]:
                if ch == "(":
                    depth += 1
                elif ch == ")":
                    depth -= 1
            if depth != 0:
                # not well formed, break
                return False

        return True

    def _def_pid_struct(self):
        """Return the C definition of the per-PID bookkeeping struct."""
        text = """
struct pid_struct {
    u64 curr_call; /* book keeping to handle recursion */
    u64 conds_met; /* stack pointer */
    u64 stack[%s];
};
""" % self.length
        return text

    def _attach_probes(self):
        """Compile self.program and attach every probe."""
        self.bpf = BPF(text=self.program)
        for p in self.probes:
            p.attach(self.bpf)

    def _generate_program(self):
        """Assemble the full BPF C program into self.program."""
        # leave out auto includes for now
        self.program += '#include <linux/mm.h>\n'
        for include in (self.args.include or []):
            self.program += "#include <%s>\n" % include

        self.program += self._def_pid_struct()
        self.program += "BPF_HASH(m, u32, struct pid_struct);\n"
        for p in self.probes:
            self.program += p.generate_program() + "\n"

        if self.args.verbose:
            print(self.program)

    def _main_loop(self):
        """Keep the process alive with the probes attached until Ctrl-C."""
        try:
            while True:
                self.bpf.perf_buffer_poll()
        except KeyboardInterrupt:
            # Normal way to stop the tool; exit without a traceback.
            pass

    def run(self):
        """Build, load and attach the program, then wait for interruption."""
        self._create_probes()
        self._generate_program()
        self._attach_probes()
        self._main_loop()


if __name__ == "__main__":
    Tool().run()