// Copyright 2017 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
// Package compiler generates sys descriptions of syscalls, types and resources
// from textual descriptions.
package compiler
import (
"fmt"
"sort"
"strconv"
"strings"
"github.com/google/syzkaller/pkg/ast"
"github.com/google/syzkaller/prog"
"github.com/google/syzkaller/sys/targets"
)
// Overview of compilation process:
// 1. ast.Parse on text file does tokenization and builds AST.
// This step catches basic syntax errors. AST contains full debug info.
// 2. ExtractConsts on AST returns set of constant identifiers.
// This step also does verification of include/incdir/define AST nodes.
// 3. User translates constants to values.
// 4. Compile on AST and const values does the rest of the work and returns Prog
// containing generated prog objects.
// 4.1. assignSyscallNumbers: uses consts to assign syscall numbers.
// This step also detects unsupported syscalls and discards no longer
// needed AST nodes (include, define, comments, etc).
// 4.2. patchConsts: patches Int nodes referring to consts with corresponding values.
// Also detects unsupported syscalls, structs, resources due to missing consts.
// 4.3. check: does extensive semantical checks of AST.
// 4.4. gen: generates prog objects from AST.
// Prog is description compilation result.
//
// When Compile is called with non-nil consts, the exported fields are filled in
// from the generated prog objects. When Compile is called with nil consts,
// only fileConsts is set.
type Prog struct {
	// Resources holds the result of the compiler's genResources step.
	Resources []*prog.ResourceDesc
	// Syscalls holds the result of the compiler's genSyscalls step.
	Syscalls []*prog.Syscall
	// StructDescs holds the result of genStructDescs for the generated syscalls.
	StructDescs []*prog.KeyedStruct
	// Set of unsupported syscalls/flags.
	Unsupported map[string]bool
	// Returned if consts was nil.
	fileConsts map[string]*ConstInfo
}
// Compile compiles sys description.
//
// The compiler works on desc.Clone() rather than on desc itself. If eh is nil,
// ast.LoggingHandler is used to report errors and warnings.
//
// If consts is nil, Compile only type-checks the description and extracts
// constants; the returned Prog then has only its fileConsts field set.
// Otherwise Compile runs the full pipeline (syscall number assignment when the
// target uses syscall numbers, const patching, checks, generation).
//
// Compile returns nil if any error was reported at any stage.
func Compile(desc *ast.Description, consts map[string]uint64, target *targets.Target, eh ast.ErrorHandler) *Prog {
	// Fall back to the logging handler when the caller did not supply one.
	if eh == nil {
		eh = ast.LoggingHandler
	}
	comp := &compiler{
		desc: desc.Clone(),
		target: target,
		eh: eh,
		ptrSize: target.PtrSize,
		unsupported: make(map[string]bool),
		resources: make(map[string]*ast.Resource),
		typedefs: make(map[string]*ast.TypeDef),
		structs: make(map[string]*ast.Struct),
		intFlags: make(map[string]*ast.IntFlags),
		strFlags: make(map[string]*ast.StrFlags),
		used: make(map[string]bool),
		usedTypedefs: make(map[string]bool),
		structDescs: make(map[prog.StructKey]*prog.StructDesc),
		structNodes: make(map[*prog.StructDesc]*ast.Struct),
		structVarlen: make(map[string]bool),
	}
	// Register builtin typedefs and mark each of them as used up front.
	for name, n := range builtinTypedefs {
		comp.typedefs[name] = n
		comp.usedTypedefs[name] = true
	}
	// Register builtin string flags.
	for name, n := range builtinStrFlags {
		comp.strFlags[name] = n
	}
	comp.typecheck()
	// The subsequent, more complex, checks expect basic validity of the tree,
	// in particular correct number of type arguments. If there were errors,
	// don't proceed to avoid out-of-bounds references to type arguments.
	if comp.errors != 0 {
		return nil
	}
	// Const-extraction mode: no const values were provided, so only collect
	// the constants and return them via the unexported fileConsts field.
	if consts == nil {
		fileConsts := comp.extractConsts()
		if comp.errors != 0 {
			return nil
		}
		return &Prog{fileConsts: fileConsts}
	}
	if comp.target.SyscallNumbers {
		comp.assignSyscallNumbers(consts)
	}
	comp.patchConsts(consts)
	comp.check()
	if comp.errors != 0 {
		return nil
	}
	// Accumulated warnings are reported only once the checks above passed
	// without errors.
	for _, w := range comp.warnings {
		eh(w.pos, w.msg)
	}
	syscalls := comp.genSyscalls()
	prg := &Prog{
		Resources: comp.genResources(),
		Syscalls: syscalls,
		StructDescs: comp.genStructDescs(syscalls),
		Unsupported: comp.unsupported,
	}
	// Generation may report errors as well; discard the result in that case.
	if comp.errors != 0 {
		return nil
	}
	return prg
}
// compiler holds the state of a single Compile invocation.
type compiler struct {
	// desc is the description being compiled; Compile sets it to a clone of its input.
	desc *ast.Description
	target *targets.Target
	// eh receives error messages (via the error method) and, from Compile, warnings.
	eh ast.ErrorHandler
	// errors counts the errors reported through the error method.
	errors int
	// warnings accumulates messages recorded by the warning method.
	warnings []warn
	// ptrSize is initialized from target.PtrSize in Compile.
	ptrSize uint64

	// unsupported is returned to the caller as Prog.Unsupported.
	unsupported map[string]bool
	// Lookup tables keyed by name. typedefs and strFlags are pre-populated
	// with builtinTypedefs and builtinStrFlags in Compile.
	resources map[string]*ast.Resource
	typedefs map[string]*ast.TypeDef
	structs map[string]*ast.Struct
	intFlags map[string]*ast.IntFlags
	strFlags map[string]*ast.StrFlags
	used map[string]bool // contains used structs/resources
	// usedTypedefs has every builtin typedef marked as used from the start.
	usedTypedefs map[string]bool

	structDescs map[prog.StructKey]*prog.StructDesc
	structNodes map[*prog.StructDesc]*ast.Struct
	// structVarlen caches structIsVarlen results by struct name.
	structVarlen map[string]bool
}
// warn is a warning recorded by compiler.warning. Warnings are buffered and
// passed to the error handler by Compile only if the checks reported no errors.
type warn struct {
	// pos is the source position the warning refers to.
	pos ast.Pos
	// msg is the already formatted warning text.
	msg string
}
// error bumps the compiler's error counter and passes the formatted message
// for position pos to the error handler.
func (comp *compiler) error(pos ast.Pos, msg string, args ...interface{}) {
	comp.errors++
	text := fmt.Sprintf(msg, args...)
	comp.eh(pos, text)
}
// warning formats the message and queues it in comp.warnings. Nothing is
// reported here and the error counter is not touched.
func (comp *compiler) warning(pos ast.Pos, msg string, args ...interface{}) {
	w := warn{
		pos: pos,
		msg: fmt.Sprintf(msg, args...),
	}
	comp.warnings = append(comp.warnings, w)
}
// structIsVarlen reports whether the struct/union called name is variable-length.
// A union carrying the varlen attribute is varlen; otherwise the struct/union is
// varlen if any of its fields is. Results are memoized in comp.structVarlen.
func (comp *compiler) structIsVarlen(name string) bool {
	if cached, known := comp.structVarlen[name]; known {
		return cached
	}
	node := comp.structs[name]
	if node.IsUnion {
		if attrVarlen, _ := comp.parseUnionAttrs(node); attrVarlen {
			comp.structVarlen[name] = true
			return true
		}
	}
	// Provisionally record false so that a recursive type, which reaches this
	// struct again while its fields are being inspected, does not loop forever.
	comp.structVarlen[name] = false
	result := false
	for i := 0; i < len(node.Fields) && !result; i++ {
		result = comp.isVarlen(node.Fields[i].Type)
	}
	comp.structVarlen[name] = result
	return result
}
// parseUnionAttrs walks the attributes of union n.
// It returns whether the varlen attribute is present and the value of the size
// attribute (sizeUnassigned if there is none). An error is reported for varlen
// given with arguments and for any attribute other than varlen and size.
func (comp *compiler) parseUnionAttrs(n *ast.Struct) (varlen bool, size uint64) {
	size = sizeUnassigned
	for _, attr := range n.Attrs {
		if attr.Ident == "varlen" {
			if len(attr.Args) > 0 {
				comp.error(attr.Pos, "%v attribute has args", attr.Ident)
			}
			varlen = true
			continue
		}
		if attr.Ident == "size" {
			size = comp.parseSizeAttr(attr)
			continue
		}
		comp.error(attr.Pos, "unknown union %v attribute %v",
			n.Name.Name, attr.Ident)
	}
	return varlen, size
}
// parseStructAttrs walks the attributes of struct n.
// It returns whether the struct is packed, the value of the size attribute
// (sizeUnassigned if there is none) and the requested alignment (0 if none).
// Recognized attributes are packed, align_ptr, align_N and size; anything else
// is reported as an error.
func (comp *compiler) parseStructAttrs(n *ast.Struct) (packed bool, size, align uint64) {
	const alignPrefix = "align_"
	// rejectArgs reports an error if a flag-like attribute was given arguments.
	rejectArgs := func(attr *ast.Type) {
		if len(attr.Args) > 0 {
			comp.error(attr.Pos, "%v attribute has args", attr.Ident)
		}
	}
	size = sizeUnassigned
	for _, attr := range n.Attrs {
		switch ident := attr.Ident; {
		case ident == "packed":
			rejectArgs(attr)
			packed = true
		case ident == "align_ptr":
			// Must be tested before the generic align_ prefix below.
			rejectArgs(attr)
			align = comp.ptrSize
		case strings.HasPrefix(ident, alignPrefix):
			rejectArgs(attr)
			valueText := strings.TrimPrefix(ident, alignPrefix)
			value, err := strconv.ParseUint(valueText, 10, 64)
			if err != nil {
				comp.error(attr.Pos, "bad struct %v alignment %v",
					n.Name.Name, valueText)
				continue
			}
			isPow2 := value != 0 && value&(value-1) == 0
			if !isPow2 || value > 1<<30 {
				comp.error(attr.Pos, "bad struct %v alignment %v (must be a sane power of 2)",
					n.Name.Name, value)
			}
			// The value is stored even when it was just rejected; the error
			// has already been counted.
			align = value
		case ident == "size":
			size = comp.parseSizeAttr(attr)
		default:
			comp.error(attr.Pos, "unknown struct %v attribute %v",
				n.Name.Name, ident)
		}
	}
	return packed, size, align
}
// parseSizeAttr extracts the value of a size attribute.
// The attribute must have exactly one argument, and that argument must be a
// plain int without colon or nested args. On any violation an error is
// reported and sizeUnassigned is returned.
func (comp *compiler) parseSizeAttr(attr *ast.Type) uint64 {
	if len(attr.Args) != 1 {
		comp.error(attr.Pos, "%v attribute is expected to have 1 argument", attr.Ident)
		return sizeUnassigned
	}
	arg := attr.Args[0]
	unexpected, _, isInt := checkTypeKind(arg, kindInt)
	if !isInt {
		comp.error(arg.Pos, "unexpected %v, expect int", unexpected)
		return sizeUnassigned
	}
	if arg.HasColon || len(arg.Args) > 0 {
		comp.error(arg.Pos, "size attribute has colon or args")
		return sizeUnassigned
	}
	return arg.Value
}
// getTypeDesc resolves the descriptor for type t by its identifier.
// Builtin types take precedence, followed by resources, structs and typedefs
// known to the compiler. It returns nil if the identifier matches none of them.
func (comp *compiler) getTypeDesc(t *ast.Type) *typeDesc {
	name := t.Ident
	if builtin := builtinTypes[name]; builtin != nil {
		return builtin
	}
	switch {
	case comp.resources[name] != nil:
		return typeResource
	case comp.structs[name] != nil:
		return typeStruct
	case comp.typedefs[name] != nil:
		return typeTypedef
	}
	return nil
}
// getArgsBase returns the descriptor of type t, its type arguments with a
// trailing opt removed, and the common int part for the type.
//
// For descriptors with NeedBase set, the base size defaults to the pointer
// size. Outside of syscall arguments (isArg == false) the last type argument
// is taken to be the base int type: it is stripped from the returned arguments
// and used to generate the base.
//
// Panics if t has no type descriptor.
func (comp *compiler) getArgsBase(t *ast.Type, field string, dir prog.Dir, isArg bool) (
	*typeDesc, []*ast.Type, prog.IntTypeCommon) {
	desc := comp.getTypeDesc(t)
	if desc == nil {
		panic(fmt.Sprintf("no type desc for %#v", *t))
	}
	args, opt := removeOpt(t)
	isOptional := opt != nil
	common := genCommon(t.Ident, field, sizeUnassigned, dir, isOptional)
	base := genIntCommon(common, 0, false)
	if !desc.NeedBase {
		return desc, args, base
	}
	base.TypeSize = comp.ptrSize
	if isArg {
		return desc, args, base
	}
	last := len(args) - 1
	baseType := args[last]
	args = args[:last]
	base = typeInt.Gen(comp, baseType, nil, base).(*prog.IntType).IntTypeCommon
	return desc, args, base
}
// foreachType invokes cb for every type (including nested type arguments)
// referenced by node n0, which must be a call, resource, struct or typedef.
// Call arguments and return types are visited as syscall arguments; everything
// else is not. Panics on any other node kind.
func (comp *compiler) foreachType(n0 ast.Node,
	cb func(*ast.Type, *typeDesc, []*ast.Type, prog.IntTypeCommon)) {
	visit := func(t *ast.Type, isArg bool) {
		comp.foreachSubType(t, isArg, cb)
	}
	switch node := n0.(type) {
	case *ast.Call:
		for _, arg := range node.Args {
			visit(arg.Type, true)
		}
		if ret := node.Ret; ret != nil {
			visit(ret, true)
		}
	case *ast.Resource:
		visit(node.Base, false)
	case *ast.Struct:
		for _, fld := range node.Fields {
			visit(fld.Type, false)
		}
	case *ast.TypeDef:
		// Only typedefs without Args are walked (presumably the ones with
		// Args are templates that are checked at instantiation - TODO confirm).
		if len(node.Args) == 0 {
			visit(node.Type, false)
		}
	default:
		panic(fmt.Sprintf("unexpected node %#v", n0))
	}
}
// foreachSubType calls cb for t and then recurses into those of its arguments
// whose descriptor slot is itself a type (typeArgType), using the slot's IsArg
// flag for the nested call.
func (comp *compiler) foreachSubType(t *ast.Type, isArg bool,
	cb func(*ast.Type, *typeDesc, []*ast.Type, prog.IntTypeCommon)) {
	desc, args, base := comp.getArgsBase(t, "", prog.DirIn, isArg)
	cb(t, desc, args, base)
	for i := range args {
		slot := desc.Args[i]
		if slot.Type != typeArgType {
			continue
		}
		comp.foreachSubType(args[i], slot.IsArg, cb)
	}
}
// removeOpt splits a trailing "opt" argument off the arguments of t.
// It returns the arguments without the trailing opt plus the opt node itself,
// or the arguments unchanged and nil when the last argument is not opt.
// The returned slice shares its backing array with t.Args.
func removeOpt(t *ast.Type) ([]*ast.Type, *ast.Type) {
	count := len(t.Args)
	if count == 0 {
		return t.Args, nil
	}
	tail := t.Args[count-1]
	if tail.Ident != "opt" {
		return t.Args, nil
	}
	return t.Args[:count-1], tail
}
// parseIntType decodes an int type name such as int8, int32be or intptr.
// It returns the size in bytes (the pointer size for intptr) and whether the
// name carries the big-endian "be" suffix.
func (comp *compiler) parseIntType(name string) (size uint64, bigEndian bool) {
	const beSuffix = "be"
	bigEndian = strings.HasSuffix(name, beSuffix)
	base := strings.TrimSuffix(name, beSuffix)
	if base == "intptr" {
		return comp.ptrSize, bigEndian
	}
	// Skip the 3-byte "int" prefix and read the bit width. The parse error is
	// deliberately ignored: presumably callers pass only validated builtin int
	// names - verify against callers.
	bits, _ := strconv.ParseUint(base[3:], 10, 64)
	return bits / 8, bigEndian
}
// toArray returns the keys of m in sorted order, regardless of the values
// stored under them. The empty key is never part of the result.
//
// Note: as a side effect the empty key is deleted from m itself.
// The result is nil when there are no keys to return.
func toArray(m map[string]bool) []string {
	delete(m, "")
	if len(m) == 0 {
		return nil
	}
	var keys []string
	for key := range m {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	return keys
}
// arrayContains reports whether v is one of the elements of a.
func arrayContains(a []string, v string) bool {
	for i := 0; i < len(a); i++ {
		if a[i] == v {
			return true
		}
	}
	return false
}