Golang程序  |  154行  |  4.21 KB

// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// This file contains the check for http.Response values being used before
// checking for errors.

package main

import (
	"go/ast"
	"go/types"
)

var (
	httpResponseType types.Type
	httpClientType   types.Type
)

func init() {
	if typ := importType("net/http", "Response"); typ != nil {
		httpResponseType = typ
	}
	if typ := importType("net/http", "Client"); typ != nil {
		httpClientType = typ
	}
	// if http.Response or http.Client are not defined don't register this check.
	if httpResponseType == nil || httpClientType == nil {
		return
	}

	register("httpresponse",
		"check errors are checked before using an http Response",
		checkHTTPResponse, callExpr)
}

func checkHTTPResponse(f *File, node ast.Node) {
	call := node.(*ast.CallExpr)
	if !isHTTPFuncOrMethodOnClient(f, call) {
		return // the function call is not related to this check.
	}

	finder := &blockStmtFinder{node: call}
	ast.Walk(finder, f.file)
	stmts := finder.stmts()
	if len(stmts) < 2 {
		return // the call to the http function is the last statement of the block.
	}

	asg, ok := stmts[0].(*ast.AssignStmt)
	if !ok {
		return // the first statement is not assignment.
	}
	resp := rootIdent(asg.Lhs[0])
	if resp == nil {
		return // could not find the http.Response in the assignment.
	}

	def, ok := stmts[1].(*ast.DeferStmt)
	if !ok {
		return // the following statement is not a defer.
	}
	root := rootIdent(def.Call.Fun)
	if root == nil {
		return // could not find the receiver of the defer call.
	}

	if resp.Obj == root.Obj {
		f.Badf(root.Pos(), "using %s before checking for errors", resp.Name)
	}
}

// isHTTPFuncOrMethodOnClient checks whether the given call expression is on
// either a function of the net/http package or a method of http.Client that
// returns (*http.Response, error).
func isHTTPFuncOrMethodOnClient(f *File, expr *ast.CallExpr) bool {
	fun, _ := expr.Fun.(*ast.SelectorExpr)
	sig, _ := f.pkg.types[fun].Type.(*types.Signature)
	if sig == nil {
		return false // the call is not on of the form x.f()
	}

	res := sig.Results()
	if res.Len() != 2 {
		return false // the function called does not return two values.
	}
	if ptr, ok := res.At(0).Type().(*types.Pointer); !ok || !types.Identical(ptr.Elem(), httpResponseType) {
		return false // the first return type is not *http.Response.
	}
	if !types.Identical(res.At(1).Type().Underlying(), errorType) {
		return false // the second return type is not error
	}

	typ := f.pkg.types[fun.X].Type
	if typ == nil {
		id, ok := fun.X.(*ast.Ident)
		return ok && id.Name == "http" // function in net/http package.
	}

	if types.Identical(typ, httpClientType) {
		return true // method on http.Client.
	}
	ptr, ok := typ.(*types.Pointer)
	return ok && types.Identical(ptr.Elem(), httpClientType) // method on *http.Client.
}

// blockStmtFinder is an ast.Visitor that given any ast node can find the
// statement containing it and its succeeding statements in the same block.
type blockStmtFinder struct {
	node  ast.Node       // target of search
	stmt  ast.Stmt       // innermost statement enclosing argument to Visit
	block *ast.BlockStmt // innermost block enclosing argument to Visit.
}

// Visit finds f.node performing a search down the ast tree.
// It keeps the last block statement and statement seen for later use.
func (f *blockStmtFinder) Visit(node ast.Node) ast.Visitor {
	if node == nil || f.node.Pos() < node.Pos() || f.node.End() > node.End() {
		return nil // not here
	}
	switch n := node.(type) {
	case *ast.BlockStmt:
		f.block = n
	case ast.Stmt:
		f.stmt = n
	}
	if f.node.Pos() == node.Pos() && f.node.End() == node.End() {
		return nil // found
	}
	return f // keep looking
}

// stmts returns the statements of f.block starting from the one including f.node.
func (f *blockStmtFinder) stmts() []ast.Stmt {
	for i, v := range f.block.List {
		if f.stmt == v {
			return f.block.List[i:]
		}
	}
	return nil
}

// rootIdent finds the root identifier x in a chain of selections x.y.z, or nil if not found.
func rootIdent(n ast.Node) *ast.Ident {
	switch n := n.(type) {
	case *ast.SelectorExpr:
		return rootIdent(n.X)
	case *ast.Ident:
		return n
	default:
		return nil
	}
}