// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file contains the check for http.Response values being used before
// checking for errors.
package main
import (
"go/ast"
"go/types"
)
// Types from package net/http that this check compares against.
// Both are assigned in init and stay nil when importType yields nil for them.
var (
// httpResponseType is the type looked up as "Response" in "net/http".
httpResponseType types.Type
// httpClientType is the type looked up as "Client" in "net/http".
httpClientType types.Type
)
// init looks up the net/http types this check depends on and registers the
// "httpresponse" check for call expressions. Registration is skipped when
// either http.Response or http.Client could not be resolved.
func init() {
	responseTyp := importType("net/http", "Response")
	if responseTyp != nil {
		httpResponseType = responseTyp
	}

	clientTyp := importType("net/http", "Client")
	if clientTyp != nil {
		httpClientType = clientTyp
	}

	// Only register the check when both types are known.
	if httpResponseType != nil && httpClientType != nil {
		register("httpresponse",
			"check errors are checked before using an http Response",
			checkHTTPResponse, callExpr)
	}
}
// checkHTTPResponse reports a use of an HTTP response in a defer statement
// that directly follows the assignment of the response, i.e. before the
// accompanying error could have been checked.
//
// node must be an *ast.CallExpr. The check only applies when the call returns
// (*http.Response, error) and is either a function of net/http or a method on
// http.Client (see isHTTPFuncOrMethodOnClient). It then inspects the statement
// containing the call and the statement right after it in the same block:
//
//	resp, err := http.Get(url)
//	defer resp.Body.Close() // reported: resp used before checking err
func checkHTTPResponse(f *File, node ast.Node) {
	call := node.(*ast.CallExpr)
	if !isHTTPFuncOrMethodOnClient(f, call) {
		return // the function call is not related to this check.
	}

	finder := &blockStmtFinder{node: call}
	ast.Walk(finder, f.file)
	if finder.block == nil {
		// The call is not enclosed by any block statement, e.g. it is the
		// initializer of a package-level variable; there is nothing to inspect.
		return
	}
	stmts := finder.stmts()
	if len(stmts) < 2 {
		return // the call to the http function is the last statement of the block.
	}

	asg, ok := stmts[0].(*ast.AssignStmt)
	if !ok {
		return // the first statement is not assignment.
	}
	resp := rootIdent(asg.Lhs[0])
	if resp == nil {
		return // could not find the http.Response in the assignment.
	}

	def, ok := stmts[1].(*ast.DeferStmt)
	if !ok {
		return // the following statement is not a defer.
	}
	root := rootIdent(def.Call.Fun)
	if root == nil {
		return // could not find the receiver of the defer call.
	}

	// Identifiers the parser could not resolve (the blank identifier on the
	// left of "=", package names, objects declared in other files) all have a
	// nil Obj, so comparing Obj alone would treat unrelated identifiers such as
	// "_" and "fmt" as the same variable. Two identifiers can only denote the
	// same variable if they are non-blank and spelled identically.
	if resp.Name == "_" || resp.Name != root.Name {
		return
	}
	if resp.Obj == root.Obj {
		f.Badf(root.Pos(), "using %s before checking for errors", resp.Name)
	}
}
// isHTTPFuncOrMethodOnClient reports whether expr is a call of the form x.f()
// that returns (*http.Response, error), where x is either the identifier
// "http" (a function of package net/http) or a value of type http.Client or
// *http.Client (a method on the client).
func isHTTPFuncOrMethodOnClient(f *File, expr *ast.CallExpr) bool {
	// When expr.Fun is not a selector, sel is a nil pointer; the lookup below
	// then finds no signature and the function returns false.
	sel, _ := expr.Fun.(*ast.SelectorExpr)
	signature, _ := f.pkg.types[sel].Type.(*types.Signature)
	if signature == nil {
		return false // the call is not of the form x.f()
	}

	results := signature.Results()
	if results.Len() != 2 {
		return false // the function called does not return two values.
	}
	first, isPtr := results.At(0).Type().(*types.Pointer)
	if !isPtr || !types.Identical(first.Elem(), httpResponseType) {
		return false // the first return type is not *http.Response.
	}
	second := results.At(1).Type().Underlying()
	if !types.Identical(second, errorType) {
		return false // the second return type is not error
	}

	recv := f.pkg.types[sel.X].Type
	if recv == nil {
		// No type recorded for x: accept only the package qualifier "http".
		pkgIdent, isIdent := sel.X.(*ast.Ident)
		return isIdent && pkgIdent.Name == "http" // function in net/http package.
	}

	if types.Identical(recv, httpClientType) {
		return true // method on http.Client.
	}
	if ptr, isPtr := recv.(*types.Pointer); isPtr {
		return types.Identical(ptr.Elem(), httpClientType) // method on *http.Client.
	}
	return false
}
// blockStmtFinder is an ast.Visitor that, given any ast node, can find the
// statement containing it and its succeeding statements in the same block.
//
// Set node to the search target, pass the finder to ast.Walk, then call stmts.
type blockStmtFinder struct {
	node  ast.Node       // target of search
	stmt  ast.Stmt       // innermost statement enclosing argument to Visit
	block *ast.BlockStmt // innermost block enclosing argument to Visit.
}

// Visit finds f.node performing a search down the ast tree.
// It keeps the last block statement and statement seen for later use.
// Subtrees whose source range does not contain f.node are pruned.
func (f *blockStmtFinder) Visit(node ast.Node) ast.Visitor {
	if node == nil || f.node.Pos() < node.Pos() || f.node.End() > node.End() {
		return nil // not here
	}
	// A block is recorded as the enclosing block only; it does not replace
	// the innermost statement seen so far.
	switch n := node.(type) {
	case *ast.BlockStmt:
		f.block = n
	case ast.Stmt:
		f.stmt = n
	}
	if f.node.Pos() == node.Pos() && f.node.End() == node.End() {
		return nil // found
	}
	return f // keep looking
}

// stmts returns the statements of f.block starting from the one including f.node.
// It returns nil when no enclosing block was recorded (for example when
// f.node sits in a package-level declaration) or when the recorded statement
// is not a direct child of the recorded block.
func (f *blockStmtFinder) stmts() []ast.Stmt {
	if f.block == nil {
		return nil // f.node is not inside any block statement.
	}
	for i, v := range f.block.List {
		if f.stmt == v {
			return f.block.List[i:]
		}
	}
	return nil
}
// rootIdent returns the leftmost identifier x of a selector chain x.y.z.
// A bare identifier is returned as is; any other kind of node (including one
// found while descending the chain) yields nil.
func rootIdent(n ast.Node) *ast.Ident {
	for {
		switch cur := n.(type) {
		case *ast.Ident:
			return cur
		case *ast.SelectorExpr:
			// Step to the operand left of the dot and examine it next.
			n = cur.X
		default:
			return nil
		}
	}
}