"""Qnt primitive used by the GEMM function.
"""
import neon_emitter
class Error(Exception):
    """Base class for every exception raised by this module."""
class ConfigurationError(Error):
    """Raised when a generator is asked for a configuration it cannot emit."""
class QntLane(object):
    """Bundle of registers describing one row processed by the quantizer.

    The constructor stores its arguments unchanged:
      source: register used as the address the row's input is loaded from.
      output: register used as the address the row's result is stored to.
      offset: register holding the offset that is added to the row's values.
      load_1: first register the row's input is loaded into.
      load_2: second register the row's input is loaded into.
    """

    def __init__(self, source, output, offset, load_1, load_2):
        self.source, self.output = source, output
        self.offset = offset
        self.load_1, self.load_2 = load_1, load_2
def BuildName(lanes, leftovers, aligned):
    """Return the name of the qnt function for the given configuration.

    Args:
      lanes: number of rows handled at once; formatted with %d.
      leftovers: trailing element count; appended as '_<n>' when non-zero.
      aligned: when truthy the '_aligned' suffix is appended.

    Returns:
      A string such as 'qnt_2x8_3_aligned'.
    """
    pieces = ['qnt_%dx8' % lanes]
    if leftovers:
        pieces.append('_%d' % leftovers)
    if aligned:
        pieces.append('_aligned')
    return ''.join(pieces)
def LoadAndDuplicateOffsets(emitter, registers, lanes, offsets):
    """Allocate one quad register per lane and emit the load that fills it.

    For every lane a '1.32' load is emitted whose destinations are the
    all-lanes views of both halves of the new register, reading through
    `offsets` with a post-incrementing dereference.

    Args:
      emitter: code emitter receiving the load instructions.
      registers: register allocator.
      lanes: number of lanes; only 1, 2 and 3 are accepted.
      offsets: register holding the address of the per-lane offsets.

    Returns:
      List of the allocated quad registers, one per lane, in lane order.

    Raises:
      ConfigurationError: if lanes is not 1, 2 or 3.
    """
    if lanes not in (1, 2, 3):
        raise ConfigurationError('Unsupported number of lanes: %d' % lanes)

    duplicated = []
    for unused_lane in range(lanes):
        quad = registers.QuadRegister()
        low_all = emitter.AllLanes(registers.Low(quad))
        high_all = emitter.AllLanes(registers.High(quad))
        # 32 is forwarded as the second argument of DereferenceIncrement
        # (presumably the alignment of the access in bits -- TODO confirm).
        address = emitter.DereferenceIncrement(offsets, 32)
        emitter.EmitVLoadA('1.32', [low_all, high_all], address)
        duplicated.append(quad)
    return duplicated
def GenerateQntLanes(emitter, registers, qnt_lanes, source, stride, destination,
                     destination_stride, offsets):
    """Prepare lanes for reading unquantized multiplication results.

    The first lane uses `source` and `destination` directly. Every further
    lane gets two freshly allocated general registers, and additions are
    emitted that set them to the previous lane's registers plus `stride` and
    `destination_stride` respectively.

    Args:
      emitter: code emitter receiving the generated instructions.
      registers: register allocator.
      qnt_lanes: number of lanes to prepare.
      source: register with the input address of the first lane.
      stride: register added to step from one lane's input to the next.
      destination: register with the output address of the first lane.
      destination_stride: register added to step from one lane's output to
        the next.
      offsets: register with the address of the per-lane offsets.

    Returns:
      List of QntLane objects, one per lane, in lane order.

    Raises:
      ConfigurationError: propagated from LoadAndDuplicateOffsets when the
        lane count is unsupported.
    """
    # LoadAndDuplicateOffsets only returns for 1, 2 or 3 lanes, so the first
    # lane always exists from here on.
    offset_registers = LoadAndDuplicateOffsets(emitter, registers, qnt_lanes,
                                               offsets)

    lanes = [QntLane(source,
                     destination,
                     offset_registers[0],
                     registers.QuadRegister(),   # load 1
                     registers.QuadRegister())]  # load 2

    previous_source = source
    previous_output = destination
    for index in range(1, qnt_lanes):
        lane_source = registers.GeneralRegister()
        lane_output = registers.GeneralRegister()
        lanes.append(QntLane(lane_source,
                             lane_output,
                             offset_registers[index],
                             registers.QuadRegister(),   # load 1
                             registers.QuadRegister()))  # load 2
        # Each lane's pointers are derived from the lane just before it.
        emitter.EmitAdd(lane_source, previous_source, stride)
        emitter.EmitAdd(lane_output, previous_output, destination_stride)
        previous_source = lane_source
        previous_output = lane_output
    return lanes
def DuplicateRegister(emitter, registers, value):
    """Allocate a quad register and emit a '32' duplicate of value into it.

    Args:
      emitter: code emitter receiving the duplicate instruction.
      registers: register allocator providing the quad register.
      value: source operand handed to the duplicate instruction.

    Returns:
      The newly allocated quad register.
    """
    duplicated = registers.QuadRegister()
    emitter.EmitVDup('32', duplicated, value)
    return duplicated
def GenerateQuantize(emitter, registers, lanes, lane_temps,
                     multiplicative_offset, rounding_offset, shift):
    """Inner loop for quantization: add offsets, multiply, round, shift.

    Every stage is emitted for all lanes before the next stage starts.

    Args:
      emitter: code emitter receiving the instructions.
      registers: register allocator, used to address the low half of temps.
      lanes: sequence of [value, offset, narrowed] register triples; value is
        updated in place and finally narrowed into narrowed.
      lane_temps: registers that are narrowed once more into their own low
        half after all lanes were processed.
      multiplicative_offset: register multiplied into every value.
      rounding_offset: register added to every value after the multiply.
      shift: register used as the shift operand.
    """
    # Stages that update the value register in place with one shared operand.
    in_place_stages = (
        (emitter.EmitVMul, 'i32', multiplicative_offset),
        (emitter.EmitVAdd, 'i32', rounding_offset),
        (emitter.EmitVShl, 's32', shift),
    )

    for lane in lanes:
        emitter.EmitVAdd('i32', lane[0], lane[0], lane[1])

    for emit, kind, operand in in_place_stages:
        for lane in lanes:
            emit(kind, lane[0], lane[0], operand)

    for lane in lanes:
        emitter.EmitVQmovn('s32', lane[2], lane[0])

    for temp in lane_temps:
        emitter.EmitVQmovun('s16', registers.Low(temp), temp)
def GenerateLoadQuantizeStore(emitter, registers, lanes, multiplicative_offset,
                              rounding_offset, shift, alignment):
    """Load unquantized data from lanes, quantize, store final result.

    Args:
      emitter: code emitter receiving the instructions.
      registers: register allocator; one temporary quad register per lane is
        taken from it and released again before returning.
      lanes: QntLane objects to process.
      multiplicative_offset: register forwarded to GenerateQuantize.
      rounding_offset: register forwarded to GenerateQuantize.
      shift: register forwarded to GenerateQuantize.
      alignment: second argument of the store's post-incrementing
        dereference (callers in this file pass 64 or None).
    """
    temps = [registers.QuadRegister() for unused_lane in lanes]

    # Fill both load registers of every lane, advancing the source pointer.
    for lane in lanes:
        halves = [registers.Low(lane.load_1), registers.High(lane.load_1),
                  registers.Low(lane.load_2), registers.High(lane.load_2)]
        emitter.EmitVLoadA('1.32', halves,
                           emitter.DereferenceIncrement(lane.source, 64))

    for lane in lanes:
        emitter.EmitPld(lane.source)

    # Each load register is narrowed into one half of the lane's temporary.
    quantize_setup = []
    for temp, lane in zip(temps, lanes):
        quantize_setup.extend([
            [lane.load_1, lane.offset, registers.Low(temp)],
            [lane.load_2, lane.offset, registers.High(temp)],
        ])

    GenerateQuantize(emitter, registers, quantize_setup, temps,
                     multiplicative_offset, rounding_offset, shift)

    for temp, lane in zip(temps, lanes):
        emitter.EmitVStore('1.8', registers.Low(temp),
                           emitter.DereferenceIncrement(lane.output,
                                                        alignment))

    for temp in temps:
        registers.FreeRegister(temp)
def GenerateLoadLeftovers(emitter, registers, leftovers, lanes):
    """Handle loading of leftovers that are not a multiple of 8.

    The load registers of a lane are filled half by half in the order
    Low(load_1), High(load_1), Low(load_2), High(load_2). Every two leftover
    elements fill one whole half; an odd count additionally loads one element
    into lane 0 of the next half. When a single element follows, the bulk
    load uses a post-incrementing dereference so the element is read from
    the advanced address.

    Args:
      emitter: code emitter receiving the load instructions.
      registers: register allocator, used to address register halves.
      leftovers: number of trailing elements, 1 to 7.
      lanes: QntLane objects whose load registers are filled.

    Raises:
      ConfigurationError: if leftovers is not between 1 and 7.
    """
    # leftovers -> (number of whole halves, single trailing element?)
    plans = {
        1: (0, True),
        2: (1, False),
        3: (1, True),
        4: (2, False),
        5: (2, True),
        6: (3, False),
        7: (3, True),
    }
    plan = plans.get(leftovers)
    if plan is None:
        raise ConfigurationError('Unsupported leftover count: %d' % leftovers)
    whole, trailing = plan

    def half(lane, index):
        """Return the index-th half of the lane's two load registers."""
        quad = lane.load_1 if index < 2 else lane.load_2
        if index % 2:
            return registers.High(quad)
        return registers.Low(quad)

    if whole:
        for lane in lanes:
            targets = [half(lane, index) for index in range(whole)]
            if trailing:
                address = emitter.DereferenceIncrement(lane.source, 64)
            else:
                address = emitter.Dereference(lane.source, 64)
            if whole == 1:
                # A single half is passed as a bare register, not a list.
                emitter.EmitVLoad('1.32', targets[0], address)
            else:
                emitter.EmitVLoadA('1.32', targets, address)

    if trailing:
        for lane in lanes:
            emitter.EmitVLoad('1.32',
                              emitter.Lane(half(lane, whole), 0),
                              emitter.Dereference(lane.source, None))
def GenerateStoreLeftovers(emitter, registers, leftovers, lane_temps, lanes):
    """Handle storing of leftovers that are not a multiple of 8.

    The low half of every lane's temporary register is written to the lane's
    output in up to three steps; each step stores one element of the given
    size taken from the given lane index of that half.

    Args:
      emitter: code emitter receiving the store instructions.
      registers: register allocator, used to address the low half of temps.
      leftovers: number of trailing elements, 1 to 7.
      lane_temps: temporary registers holding the results, one per lane.
      lanes: QntLane objects providing the output registers.

    Raises:
      ConfigurationError: if leftovers is not between 1 and 7.
    """
    targets = [(registers.Low(temp), lane.output)
               for temp, lane in zip(lane_temps, lanes)]

    # leftovers -> steps of (element size, lane index, post-increment?)
    # NOTE(review): in the 7 case the last store also post-increments,
    # unlike the last store of every other case.
    plans = {
        1: (('1.8', 0, False),),
        2: (('1.16', 0, False),),
        3: (('1.16', 0, True), ('1.8', 2, False)),
        4: (('1.32', 0, False),),
        5: (('1.32', 0, True), ('1.8', 4, False)),
        6: (('1.32', 0, True), ('1.16', 2, False)),
        7: (('1.32', 0, True), ('1.16', 2, True), ('1.8', 6, True)),
    }
    steps = plans.get(leftovers)
    if steps is None:
        raise ConfigurationError('Unsupported leftovers count: %d' % leftovers)

    for size, index, advance in steps:
        for low_half, output in targets:
            element = emitter.Lane(low_half, index)
            if advance:
                address = emitter.DereferenceIncrement(output, None)
            else:
                address = emitter.Dereference(output, None)
            emitter.EmitVStore(size, element, address)
def GenerateLeftoverLoadQuantizeStore(emitter, registers, leftovers, lanes,
                                      multiplicative_offset, rounding_offset,
                                      shift):
    """Handle leftovers if row size is not a multiple of 8.

    Args:
      emitter: code emitter receiving the instructions.
      registers: register allocator; one temporary quad register per lane is
        taken from it and released again before returning.
      leftovers: number of trailing elements, 1 to 7.
      lanes: QntLane objects to process.
      multiplicative_offset: register forwarded to GenerateQuantize.
      rounding_offset: register forwarded to GenerateQuantize.
      shift: register forwarded to GenerateQuantize.

    Raises:
      ConfigurationError: propagated from the leftover load/store helpers
        when the leftover count is unsupported.
    """
    lane_temps = [registers.QuadRegister() for unused_lane in lanes]

    GenerateLoadLeftovers(emitter, registers, leftovers, lanes)

    quantize_setup = []
    for (lane_temp, lane) in zip(lane_temps, lanes):
        quantize_setup.append(
            [lane.load_1, lane.offset, registers.Low(lane_temp)])
        # The second load register is only quantized when more than four
        # elements are left.
        if leftovers > 4:
            quantize_setup.append(
                [lane.load_2, lane.offset, registers.High(lane_temp)])

    GenerateQuantize(emitter, registers, quantize_setup, lane_temps,
                     multiplicative_offset, rounding_offset, shift)

    GenerateStoreLeftovers(emitter, registers, leftovers, lane_temps, lanes)

    # Hand the temporaries back to the allocator, as
    # GenerateLoadQuantizeStore does, so they are not leaked.
    for lane_temp in lane_temps:
        registers.FreeRegister(lane_temp)
def GenerateQntNx8(emitter, qnt_lanes, leftovers, aligned):
    """Emits optimized quantization code for given lanes and row size.

    The emitted function asserts its preconditions, then runs a loop that
    handles 8 elements per lane per iteration, followed by an optional tail
    that handles the leftover elements.

    Args:
      emitter: code emitter receiving the function.
      qnt_lanes: number of rows processed at once; 1, 2 or 3.
      leftovers: value of count % 8 the emitted function handles; 0 to 7.
      aligned: when truthy, the function additionally asserts destination
        alignment and uses aligned stores in the main loop.

    Raises:
      ConfigurationError: if leftovers or qnt_lanes is out of range.
    """
    if leftovers < 0 or leftovers > 7:
        raise ConfigurationError(
            'Leftovers should be between 0 and 7 inclusive.')
    if qnt_lanes < 1 or qnt_lanes > 3:
        raise ConfigurationError('Qnt_lanes should be 1, 2 or 3.')

    name = BuildName(qnt_lanes, leftovers, aligned)

    emitter.EmitFunctionBeginA(
        name,
        [['const std::int32_t*', 'source'], ['std::int32_t', 'count'],
         ['std::int32_t', 'stride'], ['const std::int32_t*', 'offsets'],
         ['std::uint8_t*', 'destination'],
         ['std::int32_t', 'destination_stride'],
         ['std::int32_t', 'multiplicative_offset'],
         ['std::int32_t', 'rounding_offset'], ['std::int32_t', 'shift']],
        'void')
    emitter.EmitAssert('count %% 8 == %d' % leftovers)
    emitter.EmitAssert('count >= 8')
    emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(source) % 8 == 0')
    if aligned:
        emitter.EmitAssert(
            'reinterpret_cast<std::uintptr_t>(destination) % 8 == 0')
        # With several rows, every row's destination has to stay aligned.
        if qnt_lanes > 1:
            emitter.EmitAssert('destination_stride % 8 == 0')
    emitter.EmitAsmBegin()

    registers = neon_emitter.NeonRegisters()

    count = registers.MapParameter('count')

    multiplicative_offset = DuplicateRegister(
        emitter, registers, registers.MapParameter('multiplicative_offset'))
    rounding_offset = DuplicateRegister(
        emitter, registers, registers.MapParameter('rounding_offset'))
    shift = DuplicateRegister(emitter, registers,
                              registers.MapParameter('shift'))

    lanes = GenerateQntLanes(
        emitter, registers, qnt_lanes, registers.MapParameter('source'),
        registers.MapParameter('stride'),
        registers.MapParameter('destination'),
        registers.MapParameter('destination_stride'),
        registers.MapParameter('offsets'))

    if leftovers:
        # Take the tail out of count up front; the branch emitted here
        # targets numerical label 2, which marks the leftover code below.
        emitter.EmitSubs(count, count, emitter.ImmediateConstant(leftovers))
        emitter.EmitBeqFront(2)

    # Main loop (label 1): 8 elements per lane per iteration.
    emitter.EmitNewline()
    emitter.EmitNumericalLabel(1)
    emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))

    GenerateLoadQuantizeStore(emitter, registers, lanes, multiplicative_offset,
                              rounding_offset, shift, 64 if aligned else None)

    emitter.EmitNewline()
    emitter.EmitBneBack(1)

    if leftovers:
        emitter.EmitNumericalLabel(2)
        GenerateLeftoverLoadQuantizeStore(emitter, registers, leftovers, lanes,
                                          multiplicative_offset,
                                          rounding_offset, shift)

    emitter.EmitAsmEnd(registers.MappedParameters(), [],
                       registers.Clobbers() + ['cc', 'memory'])
    emitter.EmitFunctionEnd()
def BuildMultiQuantizeName(aligned, rows):
    """Return the name of the dispatching quantize function.

    Args:
      aligned: when truthy the '_aligned' suffix is appended.
      rows: number of rows handled at once; formatted with %d.

    Returns:
      A string such as 'multi_qnt_3x8_aligned'.
    """
    base = 'multi_qnt_%dx8' % rows
    return base + ('_aligned' if aligned else '')
def GenerateMultiQuantize(emitter, aligned, rows):
    """Emit main quantization code that switches between optimized versions.

    The emitted function switches on 'count % 8' and forwards all of its
    arguments to the qnt function built for that leftover count.

    Args:
      emitter: code emitter receiving the function.
      aligned: selects the aligned or unaligned family of qnt functions.
      rows: number of rows handled at once.
    """
    parameters = [
        ['const std::int32_t*', 'source'],
        ['std::int32_t', 'count'],
        ['std::int32_t', 'stride'],
        ['const std::int32_t*', 'offsets'],
        ['std::uint8_t*', 'destination'],
        ['std::int32_t', 'destination_stride'],
        ['std::int32_t', 'multiplicative_offset'],
        ['std::int32_t', 'rounding_offset'],
        ['std::int32_t', 'shift'],
    ]
    # Every parameter is forwarded unchanged, in declaration order.
    argument_names = [parameter[1] for parameter in parameters]

    emitter.EmitFunctionBeginA(BuildMultiQuantizeName(aligned, rows),
                               parameters, 'void')
    emitter.EmitSwitch('count % 8')

    for leftovers in range(8):
        emitter.EmitCase(leftovers)
        emitter.PushIndent()
        emitter.EmitCall(BuildName(rows, leftovers, aligned),
                         list(argument_names))
        emitter.EmitBreak()
        emitter.PopIndent()

    emitter.EmitSwitchEnd()
    emitter.EmitFunctionEnd()
def GenerateFunctions(neon, cc):
    """Emit every qnt function and every dispatching multi_qnt function.

    Args:
      neon: emitter receiving the qnt functions (aligned first, then
        unaligned; 1 to 3 lanes; 0 to 7 leftovers).
      cc: emitter receiving the multi_qnt dispatch functions (aligned first,
        then unaligned; 1 to 3 rows).
    """
    alignments = (True, False)

    qnt_configurations = [(lanes, leftovers, aligned)
                          for aligned in alignments
                          for lanes in range(1, 4)
                          for leftovers in range(8)]
    for lanes, leftovers, aligned in qnt_configurations:
        GenerateQntNx8(neon, lanes, leftovers, aligned)
        neon.EmitNewline()

    multi_configurations = [(aligned, rows)
                            for aligned in alignments
                            for rows in range(1, 4)]
    for aligned, rows in multi_configurations:
        GenerateMultiQuantize(cc, aligned, rows)
        cc.EmitNewline()