// Copyright 2017 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.

package main

import (
	"bytes"
	"fmt"
	"os"
	"regexp"
	"strconv"
	"strings"
	"text/template"

	"github.com/google/syzkaller/pkg/compiler"
	"github.com/google/syzkaller/pkg/osutil"
)

// extractUndeclaredRes match compiler diagnostics that name an identifier the
// compiler could not resolve; capture group 1 is the identifier. They cover
// GCC-style messages (typographic and ASCII quotes) and the clang-style
// "use of undeclared identifier". Compiled once rather than on every call.
var extractUndeclaredRes = []*regexp.Regexp{
	regexp.MustCompile("error: ‘([a-zA-Z0-9_]+)’ undeclared"),
	regexp.MustCompile("error: '([a-zA-Z0-9_]+)' undeclared"),
	regexp.MustCompile("note: in expansion of macro ‘([a-zA-Z0-9_]+)’"),
	regexp.MustCompile("error: use of undeclared identifier '([a-zA-Z0-9_]+)'"),
}

// extract compiles and runs a small C program that prints the numeric value
// of every const in info.Consts, and returns those values keyed by name.
//
// cc is the compiler binary and args are extra flags passed to it. addSource
// is pasted verbatim into the generated program, and declarePrintf adds an
// explicit printf prototype to it.
//
// If the first compilation fails, consts that the compiler output reports as
// undeclared are dropped and compilation is retried once. The second result
// holds the consts dropped this way; they are absent from the first result.
// It is empty (non-nil) when the first compilation succeeds.
func extract(info *compiler.ConstInfo, cc string, args []string, addSource string, declarePrintf bool) (
	map[string]uint64, map[string]bool, error) {
	data := &CompileData{
		AddSource:     addSource,
		Defines:       info.Defines,
		Includes:      info.Includes,
		Values:        info.Consts,
		DeclarePrintf: declarePrintf,
	}
	undeclared := make(map[string]bool)
	bin, out, err := compile(cc, args, data)
	if err != nil {
		// Some consts and syscall numbers are not defined on some archs.
		// Figure out from compiler output undefined consts,
		// and try to compile again without them.
		undeclared = findUndeclaredConsts(out, info.Consts)
		if len(undeclared) == 0 {
			// Nothing to drop: recompiling the same source would fail again.
			return nil, nil, fmt.Errorf("failed to run compiler: %v\n%v", err, string(out))
		}
		data.Values = nil
		for _, v := range info.Consts {
			if !undeclared[v] {
				data.Values = append(data.Values, v)
			}
		}
		bin, out, err = compile(cc, args, data)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to run compiler: %v\n%v", err, string(out))
		}
	}
	defer os.Remove(bin)

	out, err = osutil.Command(bin).CombinedOutput()
	if err != nil {
		return nil, nil, fmt.Errorf("failed to run flags binary: %v\n%v", err, string(out))
	}
	res, err := parseExtractedValues(out, data.Values)
	if err != nil {
		return nil, nil, err
	}
	return res, undeclared, nil
}

// findUndeclaredConsts returns the subset of consts that the compiler output
// reports as unknown, according to extractUndeclaredRes.
func findUndeclaredConsts(out []byte, consts []string) map[string]bool {
	known := make(map[string]bool, len(consts))
	for _, c := range consts {
		known[c] = true
	}
	undeclared := make(map[string]bool)
	for _, re := range extractUndeclaredRes {
		for _, match := range re.FindAllSubmatch(out, -1) {
			// Only consts we asked for count; other unresolved names are
			// not ours to drop.
			if name := string(match[1]); known[name] {
				undeclared[name] = true
			}
		}
	}
	return undeclared
}

// parseExtractedValues parses the whitespace-separated decimal numbers in out
// and pairs them, in order, with names. It fails if the number of values does
// not match len(names) or if any value is not a valid uint64.
func parseExtractedValues(out []byte, names []string) (map[string]uint64, error) {
	// strings.Fields yields no fields for empty output and tolerates stray
	// newlines or repeated blanks, unlike splitting on a single space.
	flagVals := strings.Fields(string(out))
	if len(flagVals) != len(names) {
		return nil, fmt.Errorf("fetched wrong number of values %v, want %v",
			len(flagVals), len(names))
	}
	res := make(map[string]uint64, len(names))
	for i, name := range names {
		val := flagVals[i]
		n, err := strconv.ParseUint(val, 10, 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse value: %v (%v)", err, val)
		}
		res[name] = n
	}
	return res, nil
}

// CompileData is the input that compile renders through srcTemplate to
// produce the C program printing const values.
type CompileData struct {
	// AddSource is pasted verbatim into the program after the defines.
	AddSource string
	// Defines are emitted as "#define name value", each guarded by #ifndef.
	Defines map[string]string
	// Includes are emitted as "#include <...>" lines.
	Includes []string
	// Values are the expressions whose values the program prints, in order.
	Values []string
	// DeclarePrintf adds an explicit printf prototype to the program.
	DeclarePrintf bool
}

// compile renders srcTemplate with data and feeds the result to cc on stdin
// as C source ("-x c -"). The output goes to a fresh temporary file, and
// warnings are silenced with -w.
//
// args are extra compiler flags; the caller's slice is never modified.
// On success it returns the binary path (the caller must remove it) and a nil
// output. If the compiler fails, the temporary file is removed and out holds
// the compiler's combined stdout/stderr. If rendering the source or creating
// the temporary file fails, out is nil.
func compile(cc string, args []string, data *CompileData) (bin string, out []byte, err error) {
	src := new(bytes.Buffer)
	if err = srcTemplate.Execute(src, data); err != nil {
		return "", nil, fmt.Errorf("failed to generate source: %v", err)
	}
	binFile, err := osutil.TempFile("syz-extract-bin")
	if err != nil {
		return "", nil, err
	}
	// Build a fresh slice instead of appending to args. Appending in place
	// could write into spare capacity of the caller's backing array, e.g. when
	// the same args slice is reused across calls or shared between goroutines.
	cmdArgs := make([]string, 0, len(args)+6)
	cmdArgs = append(cmdArgs, args...)
	cmdArgs = append(cmdArgs,
		"-x", "c", "-",
		"-o", binFile,
		"-w",
	)
	cmd := osutil.Command(cc, cmdArgs...)
	cmd.Stdin = src
	out, err = cmd.CombinedOutput()
	if err != nil {
		os.Remove(binFile)
		return "", out, err
	}
	return binFile, nil, nil
}

// srcTemplate renders the C program built by compile from a *CompileData.
//
// The program first defines __asm__ as a function-like macro that expands to
// nothing, so inline-asm statements in the included headers vanish.
// Presumably this lets the headers compile for a target other than the host;
// TODO confirm. It then includes .Includes, #defines each .Defines entry that
// is not already defined, appends .AddSource, and adds a printf prototype if
// .DeclarePrintf is set.
//
// main prints every expression in .Values, cast to unsigned long long. The
// values are written as decimal numbers separated by single spaces, in
// .Values order, with no trailing newline.
var srcTemplate = template.Must(template.New("").Parse(`
#define __asm__(...)

{{range $incl := $.Includes}}
#include <{{$incl}}>
{{end}}

{{range $name, $val := $.Defines}}
#ifndef {{$name}}
#	define {{$name}} {{$val}}
#endif
{{end}}

{{.AddSource}}

{{if .DeclarePrintf}}
int printf(const char *format, ...);
{{end}}

int main() {
	int i;
	unsigned long long vals[] = {
		{{range $val := $.Values}}(unsigned long long){{$val}},
		{{end}}
	};
	for (i = 0; i < sizeof(vals)/sizeof(vals[0]); i++) {
		if (i != 0)
			printf(" ");
		printf("%llu", vals[i]);
	}
	return 0;
}
`))