"""Multiply primitive optimized for the gemv operation."""
import neon_emitter
class Error(Exception):
    """Base class for the errors raised by this module."""
class ConfigurationError(Error):
    """Raised when a requested generator configuration is not supported."""
def GenerateLoadMultiplyAggregate(emitter, registers, lanes_count, aggregators,
                                  count, lhs, rhs_1, rhs_2):
    """Emit the inner loop of the 1 row x M cols multiplication.

    The emitted loop body subtracts 8 from `count`, loads one register from
    `lhs` and four from `rhs_1`, multiplies each rhs register with the lhs
    register and accumulates the products into aggregators[0..3].  It then
    loads `lanes_count` registers from `rhs_2`, multiplies them with the same
    lhs register and accumulates into aggregators[4..4 + lanes_count - 1].
    The loop ends with a backward branch to numerical label 1.

    Args:
      emitter: object providing the Emit* and operand-building helpers.
      registers: register allocator; every register allocated here is
        freed again before returning.
      lanes_count: number of registers to process from `rhs_2`.
      aggregators: accumulator registers; at least 4 + lanes_count of them
        are indexed.
      count: loop counter operand.
      lhs: operand used as the left hand side load address.
      rhs_1: operand used as the address of the first four rhs loads.
      rhs_2: operand used as the address of the remaining rhs loads.
    """
    emitter.EmitComment('General 1xM lanes loop.')
    emitter.EmitNumericalLabel(1)
    emitter.EmitNewline()
    emitter.EmitComment('Subtract counter.')
    emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
    emitter.EmitNewline()

    # The four rhs registers are allocated before the lhs register.
    rhs_regs = []
    for _ in range(4):
        rhs_regs.append(registers.DoubleRegister())
    lhs_reg = registers.DoubleRegister()

    lhs_address = emitter.DereferenceIncrement(lhs, 64)
    emitter.EmitVLoad('1.8', lhs_reg, lhs_address)
    rhs_1_address = emitter.DereferenceIncrement(rhs_1, 64)
    emitter.EmitVLoadA('1.8', rhs_regs, rhs_1_address)

    emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))
    emitter.EmitPldOffset(rhs_1, emitter.ImmediateConstant(128))

    products = []
    for _ in range(4):
        products.append(registers.QuadRegister())

    # First batch: the four rhs_1 registers times the lhs register.
    for product, rhs_reg in zip(products, rhs_regs):
        emitter.EmitVMull('u8', product, rhs_reg, lhs_reg)

    # The rhs registers are reused for the rhs_2 load, which is emitted
    # between the first batch of multiplies and their accumulation.
    rhs_2_regs = rhs_regs[:lanes_count]
    emitter.EmitVLoadA('1.8', rhs_2_regs,
                       emitter.DereferenceIncrement(rhs_2, 64))
    emitter.EmitPldOffset(rhs_2, emitter.ImmediateConstant(lanes_count * 32))

    for index, product in enumerate(products):
        emitter.EmitVPadal('u16', aggregators[index], product)

    # Second batch: the rhs_2 registers, accumulated past the first four.
    for lane in range(lanes_count):
        emitter.EmitVMull('u8', products[lane], rhs_regs[lane], lhs_reg)

    for lane in range(lanes_count):
        emitter.EmitVPadal('u16', aggregators[lane + 4], products[lane])

    emitter.EmitNewline()
    emitter.EmitComment('Loop break.')
    emitter.EmitBneBack(1)
    emitter.EmitNewline()

    registers.FreeRegister(lhs_reg)
    registers.FreeRegisters(rhs_regs)
    registers.FreeRegisters(products)
def ReadLeft(emitter, registers, lhs):
    """Emit a '1.32' all-lanes load from `lhs` into a new quad register.

    Args:
      emitter: object providing AllLanes, Dereference and EmitVLoadA.
      registers: register allocator providing QuadRegister, Low and High.
      lhs: operand that is dereferenced as the load address.

    Returns:
      The quad register allocated for the loaded value.
    """
    quad = registers.QuadRegister()
    # Both halves of the quad register are targeted with all-lanes operands.
    low_lanes = emitter.AllLanes(registers.Low(quad))
    high_lanes = emitter.AllLanes(registers.High(quad))
    source = emitter.Dereference(lhs, None)
    emitter.EmitVLoadA('1.32', [low_lanes, high_lanes], source)
    return quad
def ReadRight(emitter, registers, rhs, count):
    """Emit a '1.32' load from `rhs` into a register sized for `count`.

    Args:
      emitter: object providing Dereference and EmitVLoad.
      registers: register allocator providing DoubleRegister/QuadRegister.
      rhs: operand that is dereferenced (with argument 64) as the address.
      count: number of elements; 1 or 2 selects a double register, 3 or 4
        a quad register.

    Returns:
      The register allocated for the loaded value.

    Raises:
      ConfigurationError: if `count` is not 1, 2, 3 or 4.  Nothing is
        allocated or emitted in that case.
    """
    if count in (1, 2):
        allocate = registers.DoubleRegister
    elif count in (3, 4):
        allocate = registers.QuadRegister
    else:
        raise ConfigurationError('Unsupported elements no: %d' % count)
    target = allocate()
    emitter.EmitVLoad('1.32', target, emitter.Dereference(rhs, 64))
    return target
def DuplicateGeneralRegister(emitter, registers, general_register,
                             min_register):
    """Emit a '32' VDup of a general register into a new quad register.

    Args:
      emitter: object providing EmitVDup.
      registers: register allocator providing QuadRegister.
      general_register: source operand of the duplication.
      min_register: passed through to registers.QuadRegister; presumably
        the lowest register index that may be allocated -- confirm against
        neon_emitter.

    Returns:
      The quad register that receives the duplicated value.
    """
    target = registers.QuadRegister(min_register)
    emitter.EmitVDup('32', target, general_register)
    return target
def GenerateAggregatorReduceStore(emitter, registers, lanes_count, aggregators,
                                  result_type, lhs_add, rhs_add, lhs, rhs_1,
                                  rhs_2, results):
    """Generates assembly responsible for reducing the 4 way aggregators.

    The emitted code pairwise-adds ('u32' VPadd) every aggregator, then
    packs aggregators[0..3] into one quad register and the remaining
    `lanes_count` aggregators into a second register.  Offsets are added
    and a float conversion with scaling is applied when requested, and the
    two result registers are finally stored through `results`.

    aggregators[0] and aggregators[1] are reused as the two result
    registers, so no extra registers are allocated for the reduction.

    Args:
      emitter: object providing the Emit* and operand-building helpers.
      registers: register allocator providing Low, High and MapParameter.
      lanes_count: number of aggregators past the first four; 1 to 4.
      aggregators: accumulator registers, 4 + lanes_count of them are read.
      result_type: 'float' enables the conversion and scaling step; any
        other value skips it.
      lhs_add: if true, a value read via ReadLeft(lhs) is added to both
        result registers.
      rhs_add: if true, values read via ReadRight(rhs_1) / ReadRight(rhs_2)
        are added to the first / second result register.
      lhs: operand used as the address of the lhs offset.
      rhs_1: operand used as the address of the first rhs offsets.
      rhs_2: operand used as the address of the remaining rhs offsets.
      results: operand used as the store address.

    Raises:
      ConfigurationError: from ReadRight when `rhs_add` is true and
        `lanes_count` is not 1, 2, 3 or 4.
    """
    # Compare by value: an identity check against a string literal only
    # succeeds when the caller's string happens to be the interned object.
    is_float = result_type == 'float'
    # One or two extra lanes fit a double register, three or four need a
    # quad register.
    narrow = lanes_count in (1, 2)
    wide = lanes_count in (3, 4)

    if lhs_add:
        left_offset = ReadLeft(emitter, registers, lhs)
    else:
        left_offset = None

    if rhs_add:
        right_offset_1 = ReadRight(emitter, registers, rhs_1, 4)
        right_offset_2 = ReadRight(emitter, registers, rhs_2, lanes_count)
    else:
        right_offset_1 = None
        right_offset_2 = None

    if is_float:
        result_scale = DuplicateGeneralRegister(
            emitter, registers, registers.MapParameter('result_scale'), 4)
    else:
        result_scale = None

    emitter.EmitNewline()
    emitter.EmitComment('Horizontal reduce aggregators.')
    # First pass: fold the high half of every aggregator into its low half.
    for aggregator in aggregators:
        emitter.EmitVPadd('u32', registers.Low(aggregator),
                          registers.Low(aggregator),
                          registers.High(aggregator))

    # Second pass: pack the first four aggregators into aggregators[0].
    temp = aggregators[0]
    emitter.EmitVPadd('u32', registers.Low(temp),
                      registers.Low(aggregators[0]),
                      registers.Low(aggregators[1]))
    emitter.EmitVPadd('u32', registers.High(temp),
                      registers.Low(aggregators[2]),
                      registers.Low(aggregators[3]))

    # Pack the remaining aggregators into aggregators[1].  For an odd lane
    # count the last aggregator is passed as both VPadd operands.
    if lanes_count == 1:
        temp_2 = registers.Low(aggregators[1])
        emitter.EmitVPadd('u32', temp_2, registers.Low(aggregators[4]),
                          registers.Low(aggregators[4]))
    elif lanes_count == 2:
        temp_2 = registers.Low(aggregators[1])
        emitter.EmitVPadd('u32', temp_2, registers.Low(aggregators[4]),
                          registers.Low(aggregators[5]))
    elif lanes_count == 3:
        temp_2 = aggregators[1]
        emitter.EmitVPadd('u32', registers.Low(temp_2),
                          registers.Low(aggregators[4]),
                          registers.Low(aggregators[5]))
        emitter.EmitVPadd('u32', registers.High(temp_2),
                          registers.Low(aggregators[6]),
                          registers.Low(aggregators[6]))
    elif lanes_count == 4:
        temp_2 = aggregators[1]
        emitter.EmitVPadd('u32', registers.Low(temp_2),
                          registers.Low(aggregators[4]),
                          registers.Low(aggregators[5]))
        emitter.EmitVPadd('u32', registers.High(temp_2),
                          registers.Low(aggregators[6]),
                          registers.Low(aggregators[7]))
    else:
        temp_2 = None

    if lhs_add:
        emitter.EmitNewline()
        emitter.EmitComment('Add lhs offsets to aggregated rows.')
        emitter.EmitVAdd('s32', temp, temp, left_offset)
        if narrow:
            # temp_2 is a double register here, so only the low half of
            # the offset register is used.
            emitter.EmitVAdd('s32', temp_2, temp_2,
                             registers.Low(left_offset))
        elif wide:
            emitter.EmitVAdd('s32', temp_2, temp_2, left_offset)

    if rhs_add:
        emitter.EmitNewline()
        emitter.EmitComment('Add rhs offset to aggregated rows.')
        emitter.EmitVAdd('s32', temp, temp, right_offset_1)
        emitter.EmitVAdd('s32', temp_2, temp_2, right_offset_2)

    if is_float:
        emitter.EmitNewline()
        emitter.EmitComment('Convert to float and scale.')
        emitter.EmitVCvt('f32', 's32', temp, temp)
        emitter.EmitVCvt('f32', 's32', temp_2, temp_2)
        emitter.EmitVMul('f32', temp, temp, result_scale)
        if narrow:
            emitter.EmitVMul('f32', temp_2, temp_2,
                             registers.Low(result_scale))
        elif wide:
            emitter.EmitVMul('f32', temp_2, temp_2, result_scale)

    emitter.EmitNewline()
    emitter.EmitComment('Store results.')
    # Odd lane counts store the whole registers first and then a single
    # lane (index 0) of the last double register.
    if lanes_count == 1:
        emitter.EmitVStoreA('1.32',
                            [registers.Low(temp), registers.High(temp)],
                            emitter.DereferenceIncrement(results, None))
        emitter.EmitVStore('1.32', emitter.Lane(temp_2, 0),
                           emitter.Dereference(results, None))
    elif lanes_count == 2:
        emitter.EmitVStoreA('1.32',
                            [registers.Low(temp), registers.High(temp),
                             temp_2],
                            emitter.Dereference(results, None))
    elif lanes_count == 3:
        emitter.EmitVStoreA(
            '1.32',
            [registers.Low(temp), registers.High(temp),
             registers.Low(temp_2)],
            emitter.DereferenceIncrement(results, None))
        emitter.EmitVStore('1.32',
                           emitter.Lane(registers.High(temp_2), 0),
                           emitter.Dereference(results, None))
    elif lanes_count == 4:
        emitter.EmitVStoreA('1.32',
                            [registers.Low(temp), registers.High(temp),
                             registers.Low(temp_2), registers.High(temp_2)],
                            emitter.Dereference(results, None))
def BuildName(result_type, lhs_add, rhs_add, lanes):
    """Build the name of a generated multiplication function.

    Args:
      result_type: result type tag, appended verbatim to the name.
      lhs_add: if true, the '_lhsadd' suffix is appended.
      rhs_add: if true, the '_rhsadd' suffix is appended.
      lanes: lane count, formatted as an integer into the name.

    Returns:
      A string of the form 'mul_1x8_<lanes>x8_<result_type>' followed by
      the applicable suffixes, lhs before rhs.
    """
    pieces = ['mul_1x8_%dx8_%s' % (lanes, result_type)]
    if lhs_add:
        pieces.append('_lhsadd')
    if rhs_add:
        pieces.append('_rhsadd')
    return ''.join(pieces)
def CppResultType(result_type):
    """Return the C++ pointer type used for results of `result_type`.

    Args:
      result_type: 'int32' or 'float'.

    Returns:
      'std::int32_t*' for 'int32' and 'float*' for 'float'.

    Raises:
      ConfigurationError: for any other result type.
    """
    # Compare by value: `is` against a string literal only matches when the
    # argument happens to be the same interned object, so a string built at
    # runtime (e.g. parsed from a command line) was rejected.
    if result_type == 'int32':
        return 'std::int32_t*'
    elif result_type == 'float':
        return 'float*'
    else:
        raise ConfigurationError('Unsupported result type: %s' % result_type)
def GetParameters(result_type):
    """Return the parameter list of a generated multiplication function.

    Args:
      result_type: 'int32' or 'float'.

    Returns:
      A list of [c++ type, name] pairs: lhs, rhs_1, rhs_2, count and
      result, followed by a float 'result_scale' when `result_type` is
      'float'.

    Raises:
      ConfigurationError: from CppResultType for an unsupported type.
    """
    params = [
        ['const std::uint8_t*', 'lhs'],
        ['const std::uint8_t*', 'rhs_1'],
        ['const std::uint8_t*', 'rhs_2'],
        ['std::int32_t', 'count'],
        [CppResultType(result_type), 'result'],
    ]
    # Compare by value; identity against a literal depends on interning.
    if result_type == 'float':
        params.append(['float', 'result_scale'])
    return params
def GenerateAndClearAggregators(emitter, registers, aggregator_count):
    """Allocate the aggregator registers and emit code that clears them.

    The first three aggregators are cleared with a zero immediate; every
    later one is cleared by a move from the aggregator allocated three
    positions earlier.

    Args:
      emitter: object providing EmitVMov, ImmediateConstant and the
        comment/newline helpers.
      registers: register allocator providing QuadRegister.
      aggregator_count: number of quad registers to allocate.

    Returns:
      The list of allocated aggregator registers, in allocation order.
    """
    emitter.EmitNewline()
    emitter.EmitComment('Clear aggregators.')
    cleared = []
    for index in range(aggregator_count):
        current = registers.QuadRegister()
        cleared.append(current)
        if index >= 3:
            source = cleared[index - 3]
        else:
            source = emitter.ImmediateConstant(0)
        emitter.EmitVMov('i32', current, source)
    emitter.EmitNewline()
    return cleared
def GenerateMul1x8Mx8(emitter, result_type, lhs_add, rhs_add, lanes_count):
    """Generates the 1xN multiplication primitive.

    Emits one complete function: header, asserts on `count`, the asm block
    with aggregator setup, the multiply loop and the reduce/store code, and
    the function end.  The function handles 4 + lanes_count lanes in total.

    Args:
      emitter: object providing the Emit* helpers.
      result_type: result type tag, forwarded to the name, parameter list
        and reduce/store generators.
      lhs_add: whether lhs offsets are added to the results.
      rhs_add: whether rhs offsets are added to the results.
      lanes_count: number of lanes taken from rhs_2; 1 to 4.

    Raises:
      ConfigurationError: if `lanes_count` is outside 1..4, or when
        raised by the helpers for an unsupported configuration.
    """
    if lanes_count > 4 or lanes_count < 1:
        raise ConfigurationError('Lanes should be: 1, 2, 3 or 4.')

    total_lanes = lanes_count + 4
    function_name = BuildName(result_type, lhs_add, rhs_add, total_lanes)
    parameters = GetParameters(result_type)
    emitter.EmitFunctionBeginA(function_name, parameters, 'inline void')

    for condition in ('count % 8 == 0', 'count >= 8'):
        emitter.EmitAssert(condition)

    emitter.EmitAsmBegin()

    registers = neon_emitter.NeonRegisters()

    # Parameters are mapped in this fixed order; 'result' is mapped later,
    # right before the reduce/store code is generated.
    count, lhs, rhs_1, rhs_2 = [
        registers.MapParameter(name)
        for name in ('count', 'lhs', 'rhs_1', 'rhs_2')
    ]

    for pointer in (lhs, rhs_1, rhs_2):
        emitter.EmitPld(pointer)

    aggregators = GenerateAndClearAggregators(emitter, registers, total_lanes)

    GenerateLoadMultiplyAggregate(emitter, registers, lanes_count,
                                  aggregators, count, lhs, rhs_1, rhs_2)

    result = registers.MapParameter('result')
    GenerateAggregatorReduceStore(emitter, registers, lanes_count,
                                  aggregators, result_type, lhs_add, rhs_add,
                                  lhs, rhs_1, rhs_2, result)

    mapped = registers.MappedParameters()
    clobbers = registers.Clobbers() + ['cc', 'memory']
    emitter.EmitAsmEnd(mapped, [], clobbers)
    emitter.EmitFunctionEnd()
def GenerateFunctions(emitter, result_type, lhs_add, rhs_add):
    """Emit the multiplication primitives for lane counts 1 through 4.

    Each generated function is followed by an emitted newline.

    Args:
      emitter: object providing the Emit* helpers.
      result_type: result type tag forwarded to GenerateMul1x8Mx8.
      lhs_add: whether lhs offsets are added to the results.
      rhs_add: whether rhs offsets are added to the results.
    """
    for lanes_count in (1, 2, 3, 4):
        GenerateMul1x8Mx8(emitter, result_type, lhs_add, rhs_add, lanes_count)
        emitter.EmitNewline()
if __name__ == '__main__':
    # Script entry point: emit the int32 primitives with both the lhs and
    # the rhs offset additions enabled.
    GenerateFunctions(
        neon_emitter.NeonEmitter(), 'int32', lhs_add=True, rhs_add=True)