// Go program  |  127 lines  |  2.91 KB

// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// +build ignore

// This program is run via "go generate" (via a directive in sort.go)
// to generate zfuncversion.go.
//
// It copies sort.go to zfuncversion.go, only retaining funcs which
// take a "data Interface" parameter, and renaming each to have a
// "_func" suffix and taking a "data lessSwap" instead. It then rewrites
// each internal function call to the appropriate _func variants.

package main

import (
	"bytes"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"io/ioutil"
	"log"
	"regexp"
)

// fset is the file set handed to parser.ParseFile and format.Node in main.
var fset = token.NewFileSet()

// main generates zfuncversion.go from sort.go.
//
// It parses sort.go, keeps only the unexported, receiver-less functions whose
// first parameter is "data Interface", retypes that parameter to lessSwap,
// rewrites calls through rewriteCalls, gives every retained function a
// "_func" suffix plus a provenance comment, and writes the gofmt-ed result to
// zfuncversion.go. Any failure terminates the program via log.Fatal.
func main() {
	af, err := parser.ParseFile(fset, "sort.go", nil, 0)
	if err != nil {
		log.Fatal(err)
	}
	// Strip the file-level doc, imports and comments; af.Decls is replaced
	// below with only the retained functions.
	af.Doc = nil
	af.Imports = nil
	af.Comments = nil

	var newDecl []ast.Decl
	for _, d := range af.Decls {
		fd, ok := d.(*ast.FuncDecl)
		if !ok {
			continue
		}
		if fd.Recv != nil || fd.Name.IsExported() {
			continue
		}
		typ := fd.Type
		if len(typ.Params.List) < 1 {
			continue
		}
		arg0 := typ.Params.List[0]
		// An unnamed parameter (e.g. "func f(int)") has no Names; it cannot
		// be the "data" parameter, so skip instead of indexing out of range.
		if len(arg0.Names) == 0 {
			continue
		}
		// The parameter type may be any expression ([]int, *T, pkg.T, ...).
		// Only a bare identifier can be "Interface", so skip the rest rather
		// than panic on a failed type assertion.
		arg0Type, ok := arg0.Type.(*ast.Ident)
		if !ok {
			continue
		}
		if arg0.Names[0].Name != "data" || arg0Type.Name != "Interface" {
			continue
		}
		arg0Type.Name = "lessSwap"

		newDecl = append(newDecl, fd)
	}
	af.Decls = newDecl
	ast.Walk(visitFunc(rewriteCalls), af)

	var out bytes.Buffer
	if err := format.Node(&out, fset, af); err != nil {
		log.Fatalf("format.Node: %v", err)
	}

	// Get rid of blank lines after removal of comments.
	src := regexp.MustCompile(`\n{2,}`).ReplaceAll(out.Bytes(), []byte("\n"))

	// Add comments to each func, for the lost reader.
	// This is so much easier than adding comments via the AST
	// and trying to get position info correct.
	// The same substitution also appends the "_func" suffix to each
	// function's declared name.
	src = regexp.MustCompile(`(?m)^func (\w+)`).ReplaceAll(src, []byte("\n// Auto-generated variant of sort.go:$1\nfunc ${1}_func"))

	// Final gofmt.
	src, err = format.Source(src)
	if err != nil {
		log.Fatalf("format.Source: %v on\n%s", err, src)
	}

	// Reuse the buffer for the final file: fixed header first, then the code.
	out.Reset()
	out.WriteString(`// Code generated from sort.go using genzfunc.go; DO NOT EDIT.

// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

`)
	out.Write(src)

	const target = "zfuncversion.go"
	if err := ioutil.WriteFile(target, out.Bytes(), 0644); err != nil {
		log.Fatal(err)
	}
}

// visitFunc adapts a plain function to the ast.Visitor interface so it can be
// handed directly to ast.Walk.
type visitFunc func(ast.Node) ast.Visitor

// Visit implements ast.Visitor by invoking the underlying function on node
// and returning whatever visitor it yields.
func (vf visitFunc) Visit(node ast.Node) ast.Visitor {
	next := vf(node)
	return next
}

// rewriteCalls is the visitor callback used with ast.Walk: every call
// expression it encounters is passed to rewriteCall, and traversal always
// continues with the same callback.
func rewriteCalls(n ast.Node) ast.Visitor {
	if call, isCall := n.(*ast.CallExpr); isCall {
		rewriteCall(call)
	}
	return visitFunc(rewriteCalls)
}

// rewriteCall appends "_func" to the callee name of ce when the call looks
// like a call to one of the generated variants. A call qualifies only if:
//
//   - the callee is a plain identifier (selector calls such as data.Less
//     are left alone),
//   - the callee is not an int or uint conversion,
//   - the first argument is the identifier "data".
//
// The last condition keeps builtins, other conversions and helpers that do
// not take the data parameter (len(x), int64(n), ...) from being renamed to
// "_func" names that are never generated.
func rewriteCall(ce *ast.CallExpr) {
	ident, ok := ce.Fun.(*ast.Ident)
	if !ok {
		// e.g. skip SelectorExpr (data.Less(..) calls)
		return
	}
	// skip casts
	if ident.Name == "int" || ident.Name == "uint" {
		return
	}
	if len(ce.Args) < 1 {
		return
	}
	// Only functions taking "data" as their first parameter are copied to
	// the generated file, so only calls passing data first have a variant.
	arg0, ok := ce.Args[0].(*ast.Ident)
	if !ok || arg0.Name != "data" {
		return
	}
	ident.Name += "_func"
}