Skip to content

Commit

Permalink
Better Arm64 input register loading and incrementation.
Browse files Browse the repository at this point in the history
Believe it or not, these small changes make kernels about 3% faster.

PiperOrigin-RevId: 707440293
  • Loading branch information
alankelly authored and xnnpack-bot committed Dec 18, 2024
1 parent 9c682e5 commit 2889ba7
Show file tree
Hide file tree
Showing 79 changed files with 1,485 additions and 1,694 deletions.
14 changes: 6 additions & 8 deletions gemm_compiler/aarch64_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def inner_loop(self, M, N):
w_step=self.register_bytes() * N_COUNT,
)
for l in self.weights_asm()['loop']:
if N_COUNT % 2 == 0:
if N_COUNT % 2 != 0:
asm_string += l.format(
W_ptr=self.w_ptr_register(),
W=self.w_registers()[nr],
Expand Down Expand Up @@ -276,17 +276,15 @@ def clamp_inputs_and_outputs(
def increment_ptr(self, ptr, step):
return f'add {ptr}, {ptr}, {step}\n'

def zero_gp_register(self, reg):
return f'eor {reg}, {reg}, {reg}\n'
def initialize_k_register(self, reg):
kc_register = self.kc_register()
return f'mov {reg}, {kc_register}\n'

def cmp_k_and_jump_if_less(self, label):
kc_register = self.kc_register()
k_register = self.k_register()
return """add {k_register}, {k_register}, 4
cmp {kc_register}, {k_register}
bne {label}\n""".format(
label=label, k_register=k_register, kc_register=kc_register
)
return f"""subs {k_register}, {k_register}, 4
bne {label}\n"""

def epilogue(self, M, N, isa):
restore_stack = """
Expand Down
4 changes: 2 additions & 2 deletions gemm_compiler/avx512f_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def store(
if pop_c:
asm_string += '\n' + '# Pop output pointers from the stack.\n'
c_reg_offset = 0
POP_C = 'mov {C_REG}, [rsp + {offset}]\n'
POP_C = 'mov {C_REG}, [rsp - {offset}]\n'
for mr in range(0, M):
sp_offset = 128 + (mr) * 16 + 8
asm_string += POP_C.format(C_REG=cm_registers[mr], offset=sp_offset)
Expand All @@ -208,7 +208,7 @@ def store(
)
if pop_c:
asm_string += '\n' + '# Write output pointers to the stack.\n'
POP_C = 'mov [rsp + {offset}], {C_REG}\n'
POP_C = 'mov [rsp - {offset}], {C_REG}\n'
for mr in range(0, M):
sp_offset = 128 + (mr) * 16 + 8
asm_string += POP_C.format(C_REG=cm_registers[mr], offset=sp_offset)
Expand Down
4 changes: 2 additions & 2 deletions gemm_compiler/base_architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ def increment_ptr(self, ptr, step):
raise NotImplementedError

@abstractmethod
def zero_gp_register(self, reg):
"""Zero the given general purpose register."""
def initialize_k_register(self, reg):
"""Initialized the given general purpose register for inner loop control."""
raise NotImplementedError

@abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions gemm_compiler/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def generate_gemm_microkernel(

# the outer loop label
asm_string += '\nouter_loop:\n'
asm_string += '# Zero k counter.\n'
asm_string += isa.zero_gp_register(k_register)
asm_string += '# Initialize k counter.\n'
asm_string += isa.initialize_k_register(k_register)

# Read a registers from the stack if required
asm_string += isa.read_a_registers(M=M)
Expand Down
20 changes: 0 additions & 20 deletions gemm_compiler/neondot_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,26 +74,6 @@ def quantization_params(self):
def quantization_params_register(self):
return 'x24'

def input_asm(self):
in_asm = {
'loop': [
'ldr d{AM}, [{AM_ptr}, {a_offset}]\n',
]
}
return in_asm

def weights_asm(self):
w_asm = {
'loop': [
'ldr q{W}, [{W_ptr}, {offset}]\n',
],
'loop_2': [
'ldp q{W}, q{W_1}, [{W_ptr}, {offset}]\n',
],
'after': 'add {W}, {W}, {w_step}\n',
}
return w_asm

def compute_asm(self):
c_asm = {
'loop': ['sdot v{ACC}.4s, v{W}.16b, v{A}.4b[0]\n'],
Expand Down
30 changes: 11 additions & 19 deletions gemm_compiler/neonfma_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,19 @@ def w_registers(self):
def input_asm(self):
in_asm = {
'loop': [
'ldr d{AM}, [{AM_ptr}, {a_offset}]\n',
'ldr s{AM}, [{AM_ptr}], 4\n',
]
}
return in_asm

def weights_asm(self):
w_asm = {
'loop': [
'ldr q{W}, [{W_ptr}, {offset}]\n',
'ldr q{W}, [{W_ptr}], 16\n',
],
'loop_2': [
'ldp q{W}, q{W_1}, [{W_ptr}, {offset}]\n',
'ldp q{W}, q{W_1}, [{W_ptr}], 32\n',
],
'after': 'add {W}, {W}, {w_step}\n',
}
return w_asm

Expand Down Expand Up @@ -139,22 +138,21 @@ def store(
cmp {nc}, {n_step}
b.lo tail_{N_2}\n""".format(n_step=N, N_2=N // 2, nc=nc_reg)
for mr in range(0, M):
asm_string += 'stp q{ACC}, q{ACC_1}, [{c_reg}]\n'.format(
asm_string += 'stp q{ACC}, q{ACC_1}, [{c_reg}], 32\n'.format(
ACC=accumulators[mr],
ACC_1=accumulators[M + mr],
c_reg=cm_registers[mr],
)
for nr in range(2, N_COUNT, 2):
asm_string += 'stp q{ACC}, q{ACC_1}, [{c_reg}, {offset}]\n'.format(
asm_string += 'stp q{ACC}, q{ACC_1}, [{c_reg}], 32\n'.format(
ACC=accumulators[M * 2 + mr],
ACC_1=accumulators[M * 3 + mr],
c_reg=cm_registers[mr],
offset=self.register_bytes() * nr,
)
for mr in range(0, M):
asm_string += 'add {cm}, {cm}, {cn_stride}\n'.format(
cn_stride=N_COUNT * 16, cm=cm_registers[mr]
)
AM_PTR = self.am_registers()[mr]
kc_register = self.kc_register()
asm_string += f'sub {AM_PTR}, {AM_PTR}, {kc_register}\n'
CHECK = """
sub {nc}, {nc}, {n_step}
b.ne outer_loop
Expand All @@ -167,7 +165,7 @@ def store(
\ntail_8:
tbz {nc_lo}, 3, tail_4\n""".format(nc_lo=nc_lo)
for mr in range(0, M):
asm_string += 'stp q{ACC}, q{ACC_1}, [{c_reg}]\n'.format(
asm_string += 'stp q{ACC}, q{ACC_1}, [{c_reg}], 32\n'.format(
ACC=accumulators[mr],
ACC_1=accumulators[mr + M],
c_reg=cm_registers[mr],
Expand All @@ -179,30 +177,24 @@ def store(
asm_string += 'mov v{ACC0}.16b, v{ACC1}.16b\n'.format(
ACC0=accumulators[mr + M], ACC1=accumulators[mr + 3 * M]
)
for mr in range(0, M):
asm_string += 'add {cm}, {cm}, 32\n'.format(cm=cm_registers[mr])
asm_string += """
\ntail_4:
tbz {nc_lo}, 2, tail_2\n""".format(nc_lo=nc_lo)
for mr in range(0, M):
asm_string += 'str q{ACC}, [{c_reg}]\n'.format(
asm_string += 'str q{ACC}, [{c_reg}], 16\n'.format(
ACC=accumulators[mr], c_reg=cm_registers[mr]
)
for mr in range(0, M):
asm_string += 'mov v{ACC0}.16b, v{ACC1}.16b\n'.format(
ACC0=accumulators[mr], ACC1=accumulators[mr + M]
)
for mr in range(0, M):
asm_string += 'add {cm}, {cm}, 16\n'.format(cm=cm_registers[mr])
asm_string += """
\ntail_2:
tbz {nc_lo}, 1, tail_1\n""".format(nc_lo=nc_lo)
for mr in range(0, M):
asm_string += 'str d{ACC}, [{c_reg}]\n'.format(
asm_string += 'str d{ACC}, [{c_reg}], 8\n'.format(
ACC=accumulators[mr], c_reg=cm_registers[mr]
)
for mr in range(0, M):
asm_string += 'add {c_reg}, {c_reg}, 8\n'.format(c_reg=cm_registers[mr])
for mr in range(0, M):
asm_string += 'dup d{ACC}, v{ACC}.d[1]\n'.format(ACC=accumulators[mr])
asm_string += """
Expand Down
14 changes: 7 additions & 7 deletions gemm_compiler/x64_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ def input_output_register_setup(self, M):
cmovle {aM}, {aM_1}
cmovle {cM}, {cM_1}\n"""
INPUT_OUTPUT_REGISTER_PUSH = """
mov [rsp + {a_rsp_offset}], {aM}
mov [rsp + {c_rsp_offset}], {cM}\n"""
mov [rsp - {a_rsp_offset}], {aM}
mov [rsp - {c_rsp_offset}], {cM}\n"""
ret = ''
if self.stack_size(M) != 0:
ret += """sub rsp, {stack_size}\n""".format(
Expand All @@ -208,11 +208,11 @@ def input_output_register_setup(self, M):
ret += (
'# Write rsi (a pointer) to the stack as we need the register.\n'
)
ret += 'mov [rsp + 128], rsi\n'
ret += 'mov [rsp - 128], rsi\n'
ret += (
'# Write r10 (c pointer) to the stack as we need the register.\n'
)
ret += 'mov [rsp + 136], r10\n'
ret += 'mov [rsp - 136], r10\n'
for mr in range(1, M):
# cycle size of 2 if required
if M > self.max_M_before_spilling():
Expand Down Expand Up @@ -262,7 +262,7 @@ def read_a_registers(self, M):
if M <= self.max_M_before_spilling():
return ''
ret = '# Read a pointers from stack into GP registers.\n'
POP_A = 'mov {aM}, [rsp + {a_rsp_offset}]\n'
POP_A = 'mov {aM}, [rsp - {a_rsp_offset}]\n'
for mr in range(0, M):
a_rsp_offset = 128 + mr * 16
ret += POP_A.format(aM=registers[mr], a_rsp_offset=a_rsp_offset)
Expand All @@ -272,7 +272,7 @@ def read_a_registers(self, M):
def increment_ptr(self, ptr, step):
return f'add {ptr}, {step}\n'

def zero_gp_register(self, reg):
def initialize_k_register(self, reg):
return f'mov {reg}, 0\n'

def cmp_k_and_jump_if_less(self, label):
Expand All @@ -287,7 +287,7 @@ def cmp_k_and_jump_if_less(self, label):

def load_from_stack(self, reg, offset):
"""Load 8 bytes from the given offset from the stack pointer to reg."""
return f'mov {reg}, [rsp + {offset}]\n'
return f'mov {reg}, [rsp - {offset}]\n'

def epilogue(self, M, N, isa):
restore_stack = '\nreturn:\n'
Expand Down
Loading

0 comments on commit 2889ba7

Please sign in to comment.