# Plain text  |  232 lines  |  7.38 KB

# This test generates all variants of wmma intrinsics and verifies that LLVM
# generates correct instructions for them.

# RUN: python %s > %t.ll
# RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 | FileCheck %t.ll

from itertools import product
from string import Template

def make_wmma_slice_ty(abcd, itype):
  """Return the list of LLVM element types for one wmma fragment.

  Every entry is "<2 x half>" when itype is "f16" and "float" otherwise.
  The list has 4 entries when abcd is found in "cd" and itype is "f16";
  in every other case it has 8 entries.
  """
  is_f16 = itype == "f16"
  if is_f16:
    element = "<2 x half>"
  else:
    element = "float"

  count = 8
  if abcd in "cd" and is_f16:
    count = 4
  return [element for _ in range(count)]

def make_wmma_ld_ret_ty(abc, itype):
  """Return the fragment's element types as a brace-wrapped, comma-separated
  string, e.g. "{float, float, ...}".

  The element list comes from make_wmma_slice_ty(abc, itype).
  """
  fields = make_wmma_slice_ty(abc, itype)
  return "{" + ", ".join(fields) + "}"

def get_aspace(space):
  """Return the address space number for a space suffix.

  Recognised suffixes are ".global", ".shared", ".const", ".local",
  ".param", ".generic" and the empty string; the last two both map to 0.
  Any other value raises KeyError.
  """
  aspace_by_suffix = dict((
      ("", 0),
      (".generic", 0),
      (".global", 1),
      (".shared", 3),
      (".const", 4),
      (".local", 5),
      (".param", 101),
  ))
  return aspace_by_suffix[space]

def get_pspace(space):
  """Return "p<N>i8", where N is get_aspace(space)."""
  aspace = get_aspace(space)
  return "p" + ("%d" % aspace) + "i8"

# Regex fragments substituted into the CHECK lines of the generated tests.
# Each one is wrapped in {{...}} and matches a ", *"-separated run of
# 8 or 4 "%hh<digits>" operands, or 8 "%f<digits>" operands.
check_f16_8 = "{{" + ", *".join("%hh[0-9]+" for _ in range(8)) + "}}"
check_f16_4 = "{{" + ", *".join("%hh[0-9]+" for _ in range(4)) + "}}"
check_f32_8 = "{{" + ", *".join("%f[0-9]+" for _ in range(8)) + "}}"

# Geometry names iterated over by every test generator.
known_geoms = [
    "m16n16k16",
    "m8n32k16",
    "m32n8k16",
]

def gen_wmma_load_tests():
  """Print the IR and CHECK lines for every wmma load intrinsic variant.

  Iterates over all combinations of geometry, fragment ("a", "b", "c"),
  layout, space suffix, stride suffix and element type, skipping "f32" for
  fragments other than "c". For each remaining combination it prints one
  intrinsic declaration and two test functions: one that loads through the
  incoming pointer, and an "_o" variant that loads through a pointer
  produced by a getelementptr with index 128.
  """
  test_body = Template("""
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
  ret ${ret_ty} %v0;
}
""")
  intrinsic_name = Template(
      "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}")
  instruction_name = Template(
      "wmma.load.${abc}.sync.${layout}.${geom}${space}.${itype}")

  variants = product(
      known_geoms,
      "abc",
      ["row", "col"],
      ["", ".shared", ".global"],
      ["", ".stride"],
      ["f16", "f32"])

  for geom, abc, layout, space, stride, itype in variants:
    # f32 is only generated for the "c" fragment.
    if itype == "f32" and abc != "c":
      continue

    subst = {
        "geom": geom,
        "abc": abc,
        "layout": layout,
        "space": space,
        "stride": stride,
        "itype": itype,
        "pspace": get_pspace(space),
        "as": "addrspace(%d)" % get_aspace(space),
    }

    intrinsic = intrinsic_name.substitute(subst)
    subst.update({
        "intrinsic": intrinsic,
        # Dots are not usable in the generated function name.
        "function": intrinsic.replace(".", "_"),
        "instruction": instruction_name.substitute(subst),
        "ret_ty": make_wmma_ld_ret_ty(abc, itype),
    })

    if abc != "c":
      subst["check_result"] = check_f16_8
    elif itype == "f16":
      subst["check_result"] = check_f16_4
    else:
      subst["check_result"] = check_f32_8

    # Strided variants take one more i32 argument and one more operand.
    subst["extra_args"] = ", i32 %stride" if stride else ""
    subst["stride_pattern"] = ", %r{{[0-9]+}}" if stride else ""

    print(test_body.substitute(subst))

def make_wmma_slice_args(itype, abcd, prefix="v"):
  """Return a ", "-separated argument list for one fragment.

  Each element type from make_wmma_slice_ty(abcd, itype) becomes
  "<type> %<prefix><index>", with indices counting up from 0.
  """
  arg_list = []
  for index, elt_ty in enumerate(make_wmma_slice_ty(abcd, itype)):
    arg_list.append("%s %%%s%d" % (elt_ty, prefix, index))
  return ", ".join(arg_list)

def gen_wmma_store_tests():
  """Print the IR and CHECK lines for every wmma store intrinsic variant.

  Iterates over all combinations of geometry, layout, space suffix, stride
  suffix and element type for the "d" fragment. For each combination it
  prints one intrinsic declaration and two test functions: one that stores
  through the incoming pointer, and an "_o" variant that stores through a
  pointer produced by a getelementptr with index 128.
  """
  # The address register is matched with %rd{{[0-9]+}} (one or more digits),
  # the same form the load tests use, so registers with multi-digit numbers
  # still match when followed by "+128]".
  store_template = Template("""
declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
  call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
  ret void
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}+128]
; CHECK: ${check_args}
; CHECK: ${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
  ret void
}
""")
  intrinsic_template = Template(
      "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}")
  instruction_template = Template(
      "wmma.store.${abc}.sync.${layout}.${geom}${space}.${itype}")

  for geom, abc, layout, space, stride, itype in product(
      known_geoms,
      "d",
      ["row", "col"],
      ["", ".shared", ".global"],
      ["", ".stride"],
      ["f16", "f32"]):

    test_params = {
        "abc": abc,
        "layout": layout,
        "space": space,
        "stride": stride,
        "itype": itype,
        "pspace": get_pspace(space),
        "as": "addrspace(%d)" % get_aspace(space),
        "geom": geom,
    }

    intrinsic = intrinsic_template.substitute(test_params)
    test_params["intrinsic"] = intrinsic
    # Dots are not usable in the generated function name.
    test_params["function"] = intrinsic.replace(".", "_")
    test_params["instruction"] = instruction_template.substitute(test_params)
    test_params["check_args"] = check_f16_4 if itype == "f16" else check_f32_8
    test_params["args"] = make_wmma_slice_args(itype, "d")

    # Strided variants take one more i32 argument; the pattern also pins the
    # terminating ";" of the emitted instruction.
    if stride:
      test_params["extra_args"] = ", i32 %stride"
      test_params["stride_pattern"] = ", %r{{[0-9]+}};"
    else:
      test_params["extra_args"] = ""
      test_params["stride_pattern"] = ";"

    print(store_template.substitute(test_params))

def gen_wmma_mma_tests():
  """Print the IR and CHECK lines for every wmma mma intrinsic variant.

  Iterates over all combinations of geometry, a-layout, b-layout, c type,
  d type and saturation suffix. For each combination it prints one intrinsic
  declaration and one test function whose CHECK-NEXT lines expect the d
  operands, two groups of f16 a/b operands, and then the c operands.
  """
  test_body = Template("""
declare ${ret_ty} @${intrinsic}(
        ${args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(
        ${args}) {
; CHECK: ${instruction}
; CHECK-NEXT: ${check_d}
; CHECK-NEXT: ${check_ab}
; CHECK-NEXT: ${check_ab}
; CHECK-NEXT: ${check_c}
  %r = call ${ret_ty} @${intrinsic}(
        ${args});
  ret ${ret_ty} %r;
}
""")
  intrinsic_name = Template(
      "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${dtype}.${ctype}${satf}")
  instruction_name = Template(
      "wmma.mma.sync.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}")

  layouts = ["row", "col"]
  elt_types = ["f16", "f32"]
  variants = product(known_geoms, layouts, layouts, elt_types, elt_types,
                     [".satfinite", ""])

  for geom, alayout, blayout, ctype, dtype, satf in variants:
    subst = {
        "geom": geom,
        "alayout": alayout,
        "blayout": blayout,
        "ctype": ctype,
        "dtype": dtype,
        "satf": satf,
    }

    intrinsic = intrinsic_name.substitute(subst)
    subst["intrinsic"] = intrinsic
    subst["function"] = intrinsic.replace(".", "_")
    subst["instruction"] = instruction_name.substitute(subst)
    subst["ret_ty"] = make_wmma_ld_ret_ty("d", dtype)

    subst["check_ab"] = check_f16_8
    if ctype == "f16":
      subst["check_c"] = check_f16_4
    else:
      subst["check_c"] = check_f32_8
    if dtype == "f16":
      subst["check_d"] = check_f16_4
    else:
      subst["check_d"] = check_f32_8

    # One line of arguments per input fragment; a and b are always f16.
    operands = [("a", "f16"), ("b", "f16"), ("c", ctype)]
    arg_lines = [make_wmma_slice_args(elt, frag, prefix=frag)
                 for frag, elt in operands]
    subst["args"] = ",\n        ".join(arg_lines)

    print(test_body.substitute(subst))

def main():
  """Print the load, store and mma test groups, in that order."""
  generators = (
      gen_wmma_load_tests,
      gen_wmma_store_tests,
      gen_wmma_mma_tests,
  )
  for generate in generators:
    generate()

# Script entry point: runs unconditionally when the module is loaded, so the
# generated tests are printed as soon as the file is executed.
main()