// Copyright 2017 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.

package ast

import (
	"bytes"
	"fmt"
	"io"
)

func Format(desc *Description) []byte {
	buf := new(bytes.Buffer)
	FormatWriter(buf, desc)
	return buf.Bytes()
}

func FormatWriter(w io.Writer, desc *Description) {
	for _, n := range desc.Nodes {
		s, ok := n.(serializer)
		if !ok {
			panic(fmt.Sprintf("unknown top level decl: %#v", n))
		}
		s.serialize(w)
	}
}

func SerializeNode(n Node) string {
	s, ok := n.(serializer)
	if !ok {
		panic(fmt.Sprintf("unknown node: %#v", n))
	}
	buf := new(bytes.Buffer)
	s.serialize(buf)
	return buf.String()
}

func FormatInt(v uint64, format IntFmt) string {
	switch format {
	case IntFmtDec:
		return fmt.Sprint(v)
	case IntFmtNeg:
		return fmt.Sprint(int64(v))
	case IntFmtHex:
		return fmt.Sprintf("0x%x", v)
	case IntFmtChar:
		return fmt.Sprintf("'%c'", v)
	default:
		panic(fmt.Sprintf("unknown int format %v", format))
	}
}

type serializer interface {
	serialize(w io.Writer)
}

func (nl *NewLine) serialize(w io.Writer) {
	fmt.Fprintf(w, "\n")
}

func (com *Comment) serialize(w io.Writer) {
	fmt.Fprintf(w, "#%v\n", com.Text)
}

func (incl *Include) serialize(w io.Writer) {
	fmt.Fprintf(w, "include <%v>\n", incl.File.Value)
}

func (inc *Incdir) serialize(w io.Writer) {
	fmt.Fprintf(w, "incdir <%v>\n", inc.Dir.Value)
}

func (def *Define) serialize(w io.Writer) {
	fmt.Fprintf(w, "define %v\t%v\n", def.Name.Name, fmtInt(def.Value))
}

func (res *Resource) serialize(w io.Writer) {
	fmt.Fprintf(w, "resource %v[%v]", res.Name.Name, fmtType(res.Base))
	for i, v := range res.Values {
		fmt.Fprintf(w, "%v%v", comma(i, ": "), fmtInt(v))
	}
	fmt.Fprintf(w, "\n")
}

func (typedef *TypeDef) serialize(w io.Writer) {
	fmt.Fprintf(w, "type %v%v", typedef.Name.Name, fmtIdentList(typedef.Args))
	if typedef.Type != nil {
		fmt.Fprintf(w, " %v\n", fmtType(typedef.Type))
	}
	if typedef.Struct != nil {
		typedef.Struct.serialize(w)
	}
}

func (c *Call) serialize(w io.Writer) {
	fmt.Fprintf(w, "%v(", c.Name.Name)
	for i, a := range c.Args {
		fmt.Fprintf(w, "%v%v", comma(i, ""), fmtField(a))
	}
	fmt.Fprintf(w, ")")
	if c.Ret != nil {
		fmt.Fprintf(w, " %v", fmtType(c.Ret))
	}
	fmt.Fprintf(w, "\n")
}

func (str *Struct) serialize(w io.Writer) {
	opening, closing := '{', '}'
	if str.IsUnion {
		opening, closing = '[', ']'
	}
	fmt.Fprintf(w, "%v %c\n", str.Name.Name, opening)
	// Align all field types to the same column.
	const tabWidth = 8
	maxTabs := 0
	for _, f := range str.Fields {
		tabs := (len(f.Name.Name) + tabWidth) / tabWidth
		if maxTabs < tabs {
			maxTabs = tabs
		}
	}
	for _, f := range str.Fields {
		if f.NewBlock {
			fmt.Fprintf(w, "\n")
		}
		for _, com := range f.Comments {
			fmt.Fprintf(w, "#%v\n", com.Text)
		}
		fmt.Fprintf(w, "\t%v\t", f.Name.Name)
		for tabs := len(f.Name.Name)/tabWidth + 1; tabs < maxTabs; tabs++ {
			fmt.Fprintf(w, "\t")
		}
		fmt.Fprintf(w, "%v\n", fmtType(f.Type))
	}
	for _, com := range str.Comments {
		fmt.Fprintf(w, "#%v\n", com.Text)
	}
	fmt.Fprintf(w, "%c", closing)
	if attrs := fmtTypeList(str.Attrs); attrs != "" {
		fmt.Fprintf(w, " %v", attrs)
	}
	fmt.Fprintf(w, "\n")
}

func (flags *IntFlags) serialize(w io.Writer) {
	fmt.Fprintf(w, "%v = ", flags.Name.Name)
	for i, v := range flags.Values {
		fmt.Fprintf(w, "%v%v", comma(i, ""), fmtInt(v))
	}
	fmt.Fprintf(w, "\n")
}

func (flags *StrFlags) serialize(w io.Writer) {
	fmt.Fprintf(w, "%v = ", flags.Name.Name)
	for i, v := range flags.Values {
		fmt.Fprintf(w, "%v\"%v\"", comma(i, ""), v.Value)
	}
	fmt.Fprintf(w, "\n")
}

func fmtField(f *Field) string {
	return fmt.Sprintf("%v %v", f.Name.Name, fmtType(f.Type))
}

func (n *Type) serialize(w io.Writer) {
	w.Write([]byte(fmtType(n)))
}

func fmtType(t *Type) string {
	v := ""
	switch {
	case t.Ident != "":
		v = t.Ident
	case t.HasString:
		v = fmt.Sprintf("\"%v\"", t.String)
	default:
		v = FormatInt(t.Value, t.ValueFmt)
	}
	if t.HasColon {
		switch {
		case t.Ident2 != "":
			v += fmt.Sprintf(":%v", t.Ident2)
		default:
			v += fmt.Sprintf(":%v", FormatInt(t.Value2, t.Value2Fmt))
		}
	}
	v += fmtTypeList(t.Args)
	return v
}

func fmtTypeList(args []*Type) string {
	if len(args) == 0 {
		return ""
	}
	w := new(bytes.Buffer)
	fmt.Fprintf(w, "[")
	for i, t := range args {
		fmt.Fprintf(w, "%v%v", comma(i, ""), fmtType(t))
	}
	fmt.Fprintf(w, "]")
	return w.String()
}

func fmtIdentList(args []*Ident) string {
	if len(args) == 0 {
		return ""
	}
	w := new(bytes.Buffer)
	fmt.Fprintf(w, "[")
	for i, arg := range args {
		fmt.Fprintf(w, "%v%v", comma(i, ""), arg.Name)
	}
	fmt.Fprintf(w, "]")
	return w.String()
}

func fmtInt(i *Int) string {
	switch {
	case i.Ident != "":
		return i.Ident
	case i.CExpr != "":
		return fmt.Sprintf("%v", i.CExpr)
	default:
		return FormatInt(i.Value, i.ValueFmt)
	}
}

func comma(i int, or string) string {
	if i == 0 {
		return or
	}
	return ", "
}