// Copyright 2017 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.

package prog

import (
	"fmt"
)

// CsumChunkKind tells which payload fields of a CsumChunk are meaningful.
type CsumChunkKind int

const (
	// CsumChunkArg marks a chunk whose data comes from an argument (CsumChunk.Arg).
	CsumChunkArg CsumChunkKind = iota
	// CsumChunkConst marks a chunk that is a constant (CsumChunk.Value, CsumChunk.Size).
	CsumChunkConst
)

// CsumInfo describes how the value of one checksum field is computed:
// the kind of checksum and the sequence of chunks it is calculated over.
type CsumInfo struct {
	// Kind is the checksum kind; the code in this file only ever stores CsumInet here.
	Kind   CsumKind
	// Chunks lists the pieces of data covered by the checksum, in the order they were added.
	Chunks []CsumChunk
}

// CsumChunk is one piece of data covered by a checksum. Depending on Kind it is
// either a reference to an argument (Arg) or a constant (Value and Size).
// Fields that do not apply to the chunk's Kind are left as zero values by the
// constructors in this file.
type CsumChunk struct {
	Kind  CsumChunkKind
	Arg   Arg    // for CsumChunkArg
	Value uint64 // for CsumChunkConst
	Size  uint64 // for CsumChunkConst
}

// calcChecksumsCall finds all checksum fields in the arguments of call c and
// describes how each of them is computed.
//
// It returns two maps (both nil if c has no checksum fields):
//   - a map from each checksum field to the CsumInfo listing the chunks the
//     checksum is calculated over;
//   - the set of args that are referenced by CsumChunkArg chunks of those
//     CsumInfo values, i.e. the args the checksums depend on.
//
// Inet checksums cover a single arg: the struct named by the field's Buf.
// Pseudo checksums additionally cover the source and destination addresses of
// the ipv4/ipv6 header found in the same call, plus protocol and size constants.
//
// It panics if a checksum field has an unknown kind, if its Buf cannot be
// resolved to an enclosing struct, or if a pseudo checksum is present but no
// ipv4/ipv6 header is found in the call.
func calcChecksumsCall(c *Call) (map[Arg]CsumInfo, map[Arg]struct{}) {
	var inetCsumFields, pseudoCsumFields []Arg

	// Find all csum fields.
	ForeachArg(c, func(arg Arg, _ *ArgCtx) {
		if typ, ok := arg.Type().(*CsumType); ok {
			switch typ.Kind {
			case CsumInet:
				inetCsumFields = append(inetCsumFields, arg)
			case CsumPseudo:
				pseudoCsumFields = append(pseudoCsumFields, arg)
			default:
				panic(fmt.Sprintf("unknown csum kind %v", typ.Kind))
			}
		}
	})

	if len(inetCsumFields) == 0 && len(pseudoCsumFields) == 0 {
		return nil, nil
	}

	// Build map of each field to its parent struct.
	parentsMap := make(map[Arg]Arg)
	ForeachArg(c, func(arg Arg, _ *ArgCtx) {
		if _, ok := arg.Type().(*StructType); ok {
			for _, field := range arg.(*GroupArg).Inner {
				parentsMap[InnerArg(field)] = arg
			}
		}
	})

	csumMap := make(map[Arg]CsumInfo)
	csumUses := make(map[Arg]struct{})

	// Calculate generic inet checksums.
	for _, arg := range inetCsumFields {
		// The assertion cannot fail: the field was selected above by this very type.
		typ, _ := arg.Type().(*CsumType)
		csummedArg := findCsummedArg(arg, typ, parentsMap)
		csumUses[csummedArg] = struct{}{}
		chunk := CsumChunk{CsumChunkArg, csummedArg, 0, 0}
		csumMap[arg] = CsumInfo{Kind: CsumInet, Chunks: []CsumChunk{chunk}}
	}

	// No need to continue if there are no pseudo csum fields.
	if len(pseudoCsumFields) == 0 {
		return csumMap, csumUses
	}

	// Extract ipv4 or ipv6 source and destination addresses.
	// If several headers are present, the last one visited wins.
	var ipSrcAddr, ipDstAddr Arg
	ForeachArg(c, func(arg Arg, _ *ArgCtx) {
		groupArg, ok := arg.(*GroupArg)
		if !ok {
			return
		}
		// syz_csum_* structs are used in tests
		switch groupArg.Type().Name() {
		case "ipv4_header", "syz_csum_ipv4_header":
			ipSrcAddr, ipDstAddr = extractHeaderParams(groupArg, 4)
		case "ipv6_packet", "syz_csum_ipv6_header":
			ipSrcAddr, ipDstAddr = extractHeaderParams(groupArg, 16)
		}
	})
	if ipSrcAddr == nil || ipDstAddr == nil {
		panic("no ipv4 nor ipv6 header found")
	}

	// Calculate pseudo checksums.
	for _, arg := range pseudoCsumFields {
		// The assertion cannot fail: the field was selected above by this very type.
		typ, _ := arg.Type().(*CsumType)
		csummedArg := findCsummedArg(arg, typ, parentsMap)
		protocol := uint8(typ.Protocol)
		var info CsumInfo
		// The address size distinguishes ipv4 (4 bytes) from ipv6 (16 bytes);
		// extractHeaderParams has already validated it.
		if ipSrcAddr.Size() == 4 {
			info = composePseudoCsumIPv4(csummedArg, ipSrcAddr, ipDstAddr, protocol)
		} else {
			info = composePseudoCsumIPv6(csummedArg, ipSrcAddr, ipDstAddr, protocol)
		}
		csumMap[arg] = info
		// Record every arg the pseudo checksum refers to (the addresses and the
		// checksummed packet), same as is done for inet checksums above.
		// Deriving the set from the chunks keeps it in sync with the compose helpers.
		for _, chunk := range info.Chunks {
			if chunk.Kind == CsumChunkArg {
				csumUses[chunk.Arg] = struct{}{}
			}
		}
	}

	return csumMap, csumUses
}

// findCsummedArg resolves the struct that checksum field arg (of type typ) is
// calculated over, using parentsMap (field -> enclosing struct).
//
// If typ.Buf is "parent", the result is the direct parent of arg. Otherwise the
// chain of enclosing structs is walked outwards and the first one whose type
// name equals typ.Buf is returned. Panics if no suitable struct is found.
func findCsummedArg(arg Arg, typ *CsumType, parentsMap map[Arg]Arg) Arg {
	if typ.Buf == "parent" {
		enclosing, found := parentsMap[arg]
		if !found {
			panic(fmt.Sprintf("parent for %v is not in parents map", typ.Name()))
		}
		return enclosing
	}

	// Walk up through the enclosing structs until the name matches or the
	// chain ends (a missing map entry yields a nil Arg).
	ancestor := parentsMap[arg]
	for ancestor != nil {
		if ancestor.Type().Name() == typ.Buf {
			return ancestor
		}
		ancestor = parentsMap[ancestor]
	}
	panic(fmt.Sprintf("csum field '%v' references non existent field '%v'", typ.FieldName(), typ.Buf))
}

// composePseudoCsumIPv4 builds the CsumInfo for a pseudo checksum over an ipv4
// packet. The chunks are, in order: source address, destination address,
// protocol as a 2-byte constant, size of tcpPacket as a 2-byte constant, and
// tcpPacket itself.
//
// tcpPacket is the checksummed arg; despite the name, the actual protocol is
// whatever the caller passes in protocol. The constants are passed through
// swap16 (defined elsewhere; presumably converts to network byte order).
func composePseudoCsumIPv4(tcpPacket, srcAddr, dstAddr Arg, protocol uint8) CsumInfo {
	return CsumInfo{
		Kind: CsumInet,
		Chunks: []CsumChunk{
			{Kind: CsumChunkArg, Arg: srcAddr},
			{Kind: CsumChunkArg, Arg: dstAddr},
			{Kind: CsumChunkConst, Value: uint64(swap16(uint16(protocol))), Size: 2},
			{Kind: CsumChunkConst, Value: uint64(swap16(uint16(tcpPacket.Size()))), Size: 2},
			{Kind: CsumChunkArg, Arg: tcpPacket},
		},
	}
}

// composePseudoCsumIPv6 builds the CsumInfo for a pseudo checksum over an ipv6
// packet. The chunks are, in order: source address, destination address, size
// of tcpPacket as a 4-byte constant, protocol as a 4-byte constant, and
// tcpPacket itself.
//
// tcpPacket is the checksummed arg; despite the name, the actual protocol is
// whatever the caller passes in protocol. The constants are passed through
// swap32 (defined elsewhere; presumably converts to network byte order).
func composePseudoCsumIPv6(tcpPacket, srcAddr, dstAddr Arg, protocol uint8) CsumInfo {
	return CsumInfo{
		Kind: CsumInet,
		Chunks: []CsumChunk{
			{Kind: CsumChunkArg, Arg: srcAddr},
			{Kind: CsumChunkArg, Arg: dstAddr},
			{Kind: CsumChunkConst, Value: uint64(swap32(uint32(tcpPacket.Size()))), Size: 4},
			{Kind: CsumChunkConst, Value: uint64(swap32(uint32(protocol))), Size: 4},
			{Kind: CsumChunkArg, Arg: tcpPacket},
		},
	}
}

// extractHeaderParams returns the "src_ip" and "dst_ip" fields of the header
// struct arg, in that order. It panics if either field is missing (see
// getFieldByName) or if the Size() of either field differs from size.
func extractHeaderParams(arg *GroupArg, size uint64) (Arg, Arg) {
	src := getFieldByName(arg, "src_ip")
	dst := getFieldByName(arg, "dst_ip")
	// Check the source address first, then the destination.
	for _, addr := range [...]Arg{src, dst} {
		if addr.Size() != size {
			panic(fmt.Sprintf("src/dst_ip fields in %v must be %v bytes", arg.Type().Name(), size))
		}
	}
	return src, dst
}

// getFieldByName returns the first direct inner arg of arg whose type has the
// field name name. It panics if there is no such field.
func getFieldByName(arg *GroupArg, name string) Arg {
	for i := range arg.Inner {
		if candidate := arg.Inner[i]; candidate.Type().FieldName() == name {
			return candidate
		}
	}
	panic(fmt.Sprintf("failed to find %v field in %v", name, arg.Type().Name()))
}