// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package net

import (
	"context"
	"internal/bytealg"
	"os"
	"syscall"
)

// probe detects the IPv4, IPv6 and IPv4-mapped IPv6 communication
// capabilities of this host.
//
// Plan 9 uses IPv6 natively, see ip(3). Each family is detected by
// searching the iproute file under netdir for its token ("4i" for IPv4,
// "6i" for IPv6). IPv4-mapped IPv6 is marked enabled only when both
// families were found; it is never cleared here.
func (p *ipStackCapabilities) probe() {
	routes := netdir + "/iproute"
	hasV4 := probe(routes, "4i")
	hasV6 := probe(routes, "6i")
	p.ipv4Enabled, p.ipv6Enabled = hasV4, hasV6
	if hasV4 && hasV6 {
		p.ipv4MappedIPv6Enabled = true
	}
}

func probe(filename, query string) bool {
	var file *file
	var err error
	if file, err = open(filename); err != nil {
		return false
	}
	defer file.close()

	r := false
	for line, ok := file.readLine(); ok && !r; line, ok = file.readLine() {
		f := getFields(line)
		if len(f) < 3 {
			continue
		}
		for i := 0; i < len(f); i++ {
			if query == f[i] {
				r = true
				break
			}
		}
	}
	return r
}

// parsePlan9Addr parses address of the form [ip!]port (e.g. 127.0.0.1!80).
func parsePlan9Addr(s string) (ip IP, iport int, err error) {
	addr := IPv4zero // address contains port only
	i := bytealg.IndexByteString(s, '!')
	if i >= 0 {
		addr = ParseIP(s[:i])
		if addr == nil {
			return nil, 0, &ParseError{Type: "IP address", Text: s}
		}
	}
	p, _, ok := dtoi(s[i+1:])
	if !ok {
		return nil, 0, &ParseError{Type: "port", Text: s}
	}
	if p < 0 || p > 0xFFFF {
		return nil, 0, &AddrError{Err: "invalid port", Addr: string(p)}
	}
	return addr, p, nil
}

// readPlan9Addr opens filename, reads a single chunk of up to 128 bytes,
// parses it as a Plan 9 [ip!]port address, and returns it as a *TCPAddr
// or *UDPAddr according to proto. Any other proto yields an
// UnknownNetworkError.
func readPlan9Addr(proto, filename string) (addr Addr, err error) {
	f, err := os.Open(filename)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	var raw [128]byte
	n, err := f.Read(raw[:])
	if err != nil {
		return nil, err
	}
	ip, port, err := parsePlan9Addr(string(raw[:n]))
	if err != nil {
		return nil, err
	}
	switch proto {
	case "tcp":
		return &TCPAddr{IP: ip, Port: port}, nil
	case "udp":
		return &UDPAddr{IP: ip, Port: port}, nil
	}
	return nil, UnknownNetworkError(proto)
}

// startPlan9 allocates a new Plan 9 connection for addr.
//
// Only *TCPAddr and *UDPAddr are accepted, and ports above 65535 are
// rejected. startPlan9 resolves the address with queryCS1, opens the
// returned clone file read-write, and reads up to 16 bytes from it as the
// connection name.
//
// On success it returns the open clone file (as ctl), the dial string
// dest, the protocol ("tcp" or "udp") and the connection name. On failure
// ctl is nil, and proto/dest hold whatever had been determined so far.
func startPlan9(ctx context.Context, net string, addr Addr) (ctl *os.File, dest, proto, name string, err error) {
	var ip IP
	var port int
	switch a := addr.(type) {
	case *TCPAddr:
		proto, ip, port = "tcp", a.IP, a.Port
	case *UDPAddr:
		proto, ip, port = "udp", a.IP, a.Port
	default:
		return nil, "", "", "", UnknownNetworkError(net)
	}

	if port > 65535 {
		return nil, "", proto, "", InvalidAddrError("port should be < 65536")
	}

	clone, dest, err := queryCS1(ctx, proto, ip, port)
	if err != nil {
		return nil, dest, proto, "", err
	}
	cloneFile, err := os.OpenFile(clone, os.O_RDWR, 0)
	if err != nil {
		return nil, dest, proto, "", err
	}
	var id [16]byte
	n, err := cloneFile.Read(id[:])
	if err != nil {
		cloneFile.Close()
		return nil, dest, proto, "", err
	}
	return cloneFile, dest, proto, string(id[:n]), nil
}

// fixErr normalizes err in place when it is an *OpError. It makes two
// adjustments:
//   - Source and Addr fields that hold a nil *TCPAddr, *UDPAddr or *IPAddr
//     (interface values that compare non-nil but wrap a nil pointer) are
//     reset to a true nil.
//   - An *os.PathError whose cause is a syscall.ErrorString is replaced by
//     that cause.
//
// Any other error is left untouched.
func fixErr(err error) {
	oe, ok := err.(*OpError)
	if !ok {
		return
	}
	isTypedNil := func(a Addr) bool {
		switch v := a.(type) {
		case *TCPAddr:
			return v == nil
		case *UDPAddr:
			return v == nil
		case *IPAddr:
			return v == nil
		}
		return false
	}
	if isTypedNil(oe.Source) {
		oe.Source = nil
	}
	if isTypedNil(oe.Addr) {
		oe.Addr = nil
	}
	pe, isPath := oe.Err.(*os.PathError)
	if !isPath {
		return
	}
	if cause, isStr := pe.Err.(syscall.ErrorString); isStr {
		oe.Err = cause
	}
}

// dialPlan9 dials raddr on net while honoring cancellation of ctx.
//
// The blocking dial runs in a separate goroutine. If ctx is done first,
// dialPlan9 returns mapErr(ctx.Err()) immediately. The goroutine then
// closes any netFD it subsequently obtains, because nobody will receive
// it. The returned error is passed through fixErr.
func dialPlan9(ctx context.Context, net string, laddr, raddr Addr) (fd *netFD, err error) {
	defer func() { fixErr(err) }()

	type dialResult struct {
		fd  *netFD
		err error
	}
	// Unbuffered: a send completes only while this function is still
	// waiting to receive.
	results := make(chan dialResult)
	go func() {
		testHookDialChannel()
		conn, dialErr := dialPlan9Blocking(ctx, net, laddr, raddr)
		select {
		case results <- dialResult{fd: conn, err: dialErr}:
		case <-ctx.Done():
			if conn != nil {
				conn.Close()
			}
		}
	}()

	select {
	case <-ctx.Done():
		return nil, mapErr(ctx.Err())
	case r := <-results:
		return r.fd, r.err
	}
}

// dialPlan9Blocking synchronously connects to raddr on net.
//
// A wildcard raddr is first replaced via toLocal. The laddr argument is
// not used for binding. The local address of the returned netFD is read
// back from the connection's "local" file after connecting. Every file
// opened here is closed again if a later step fails.
func dialPlan9Blocking(ctx context.Context, net string, laddr, raddr Addr) (fd *netFD, err error) {
	if isWildcard(raddr) {
		raddr = toLocal(raddr, net)
	}
	ctl, dest, proto, name, err := startPlan9(ctx, net, raddr)
	if err != nil {
		return nil, err
	}
	if _, err = ctl.WriteString("connect " + dest); err != nil {
		ctl.Close()
		return nil, err
	}
	connDir := netdir + "/" + proto + "/" + name
	data, err := os.OpenFile(connDir+"/data", os.O_RDWR, 0)
	if err != nil {
		ctl.Close()
		return nil, err
	}
	local, err := readPlan9Addr(proto, connDir+"/local")
	if err != nil {
		data.Close()
		ctl.Close()
		return nil, err
	}
	return newFD(proto, name, nil, ctl, data, local, raddr)
}

// listenPlan9 announces laddr on net and returns a listening netFD.
//
// The netFD's local address is read back from the connection's "local"
// file. No data file is opened for a listener. The control file is
// closed if a later step fails, and the returned error is passed through
// fixErr.
func listenPlan9(ctx context.Context, net string, laddr Addr) (fd *netFD, err error) {
	defer func() { fixErr(err) }()

	ctl, dest, proto, name, err := startPlan9(ctx, net, laddr)
	if err != nil {
		return nil, err
	}
	if _, err = ctl.WriteString("announce " + dest); err != nil {
		ctl.Close()
		return nil, err
	}
	localFile := netdir + "/" + proto + "/" + name + "/local"
	bound, err := readPlan9Addr(proto, localFile)
	if err != nil {
		ctl.Close()
		return nil, err
	}
	return newFD(proto, name, nil, ctl, nil, bound, nil)
}

// netFD builds a new netFD from the same protocol, connection name, open
// files and addresses as fd.
func (fd *netFD) netFD() (*netFD, error) {
	proto, name := fd.net, fd.n
	listen, ctl, data := fd.listen, fd.ctl, fd.data
	return newFD(proto, name, listen, ctl, data, fd.laddr, fd.raddr)
}

// acceptPlan9 waits for an incoming connection on the listening fd.
//
// It holds fd's read lock, then:
//  1. opens the "listen" file in fd.dir;
//  2. reads up to 16 bytes of it as the new connection's name;
//  3. opens that connection's ctl and data files read-write;
//  4. reads its remote address.
//
// On failure, every file opened so far is closed in the order it was
// opened. The returned error is passed through fixErr.
func (fd *netFD) acceptPlan9() (nfd *netFD, err error) {
	defer func() { fixErr(err) }()
	if err := fd.pfd.ReadLock(); err != nil {
		return nil, err
	}
	defer fd.pfd.ReadUnlock()

	listen, err := os.Open(fd.dir + "/listen")
	if err != nil {
		return nil, err
	}
	opened := []*os.File{listen}
	fail := func(cause error) (*netFD, error) {
		for _, f := range opened {
			f.Close()
		}
		return nil, cause
	}

	var nameBuf [16]byte
	n, err := listen.Read(nameBuf[:])
	if err != nil {
		return fail(err)
	}
	name := string(nameBuf[:n])
	connDir := netdir + "/" + fd.net + "/" + name

	ctl, err := os.OpenFile(connDir+"/ctl", os.O_RDWR, 0)
	if err != nil {
		return fail(err)
	}
	opened = append(opened, ctl)

	data, err := os.OpenFile(connDir+"/data", os.O_RDWR, 0)
	if err != nil {
		return fail(err)
	}
	opened = append(opened, data)

	raddr, err := readPlan9Addr(fd.net, connDir+"/remote")
	if err != nil {
		return fail(err)
	}
	return newFD(fd.net, name, listen, ctl, data, fd.laddr, raddr)
}

// isWildcard reports whether a is a *TCPAddr, *UDPAddr or *IPAddr whose
// own isWildcard method returns true. Any other Addr is not a wildcard.
func isWildcard(a Addr) bool {
	switch v := a.(type) {
	case *TCPAddr:
		return v.isWildcard()
	case *UDPAddr:
		return v.isWildcard()
	case *IPAddr:
		return v.isWildcard()
	default:
		return false
	}
}

// toLocal returns a copy of a whose IP is replaced by loopbackIP(net).
// It is used to turn a wildcard destination into a local one.
//
// The argument itself is never modified. dialPlan9 runs the dial in a
// separate goroutine and can return (on context cancellation) while that
// goroutine is still working, so the caller may still be reading the
// address it passed in. Writing a.IP in place was therefore a data race,
// and it also silently rewrote user-supplied addresses.
//
// A nil *TCPAddr, *UDPAddr or *IPAddr yields a fresh address holding only
// the loopback IP instead of dereferencing nil. Addr values of any other
// type are returned unchanged.
func toLocal(a Addr, net string) Addr {
	switch a := a.(type) {
	case *TCPAddr:
		var c TCPAddr
		if a != nil {
			c = *a
		}
		c.IP = loopbackIP(net)
		return &c
	case *UDPAddr:
		var c UDPAddr
		if a != nil {
			c = *a
		}
		c.IP = loopbackIP(net)
		return &c
	case *IPAddr:
		var c IPAddr
		if a != nil {
			c = *a
		}
		c.IP = loopbackIP(net)
		return &c
	}
	return a
}