Skip to content

Commit

Permalink
Implement pairwise add and extmul for RISCV
Browse files Browse the repository at this point in the history
Signed-off-by: Zoltan Herczeg [email protected]
  • Loading branch information
zherczeg authored and clover2123 committed Jan 26, 2025
1 parent 87cf500 commit 3a6517e
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 37 deletions.
1 change: 1 addition & 0 deletions src/jit/ByteCodeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ static bool isFloatGlobal(uint32_t globalIndex, Module* module)
#define OTOp3DotAddV128 OTOp3V128

#elif (defined SLJIT_CONFIG_RISCV && SLJIT_CONFIG_RISCV)

#define OPERAND_TYPE_LIST_SIMD_ARCH \
OL2(OTOp1V128CB, /* SD */ V128 | NOTMP, V128 | NOTMP) \
OL3(OTOp2V128, /* SSD */ V128 | TMP, V128 | TMP, V128 | TMP | S0 | S1) \
Expand Down
116 changes: 79 additions & 37 deletions src/jit/SimdRiscvInl.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ enum TypeOpcode : uint32_t {
vand_vv = InstructionType::opivv | OPCODE(0x9),
vcompress_vm = InstructionType::opmvv | OPCODE(0x17),
#if defined(__riscv_zvbb)
vcpop_v = InstructionType::opmvv | OPCODE(0x12) | (0xE << 15),
vcpop_v = InstructionType::opmvv | OPCODE(0x12) | (0xe << 15),
#endif
vfadd_vf = InstructionType::opfvf | OPCODE(0x0),
vfadd_vv = InstructionType::opfvv | OPCODE(0x0),
Expand All @@ -56,7 +56,7 @@ enum TypeOpcode : uint32_t {
vfmul_vv = InstructionType::opfvv | OPCODE(0x24),
vfsgnj_vv = InstructionType::opfvv | OPCODE(0x8),
vfsgnjn_vv = InstructionType::opfvv | OPCODE(0x9),
vfsgnjx_vv = InstructionType::opfvv | OPCODE(0xA),
vfsgnjx_vv = InstructionType::opfvv | OPCODE(0xa),
vfsqrt_v = InstructionType::opfvv | OPCODE(0x13),
vfsub_vv = InstructionType::opfvv | OPCODE(0x2),
vmax_vv = InstructionType::opivv | OPCODE(0x7),
Expand All @@ -65,16 +65,16 @@ enum TypeOpcode : uint32_t {
vmerge_vv = (InstructionType::opivv ^ InstructionType::vm) | OPCODE(0x17),
vmfeq_vv = InstructionType::opfvv | OPCODE(0x18),
vmfle_vv = InstructionType::opfvv | OPCODE(0x19),
vmflt_vv = InstructionType::opfvv | OPCODE(0x1B),
vmfne_vv = InstructionType::opfvv | OPCODE(0x1C),
vmflt_vv = InstructionType::opfvv | OPCODE(0x1b),
vmfne_vv = InstructionType::opfvv | OPCODE(0x1c),
vmin_vv = InstructionType::opivv | OPCODE(0x5),
vminu_vv = InstructionType::opivv | OPCODE(0x4),
vmseq_vv = InstructionType::opivv | OPCODE(0x18),
vmsle_vv = InstructionType::opivv | OPCODE(0x1D),
vmsleu_vv = InstructionType::opivv | OPCODE(0x1C),
vmslt_vv = InstructionType::opivv | OPCODE(0x1B),
vmslt_vx = InstructionType::opivx | OPCODE(0x1B),
vmsltu_vv = InstructionType::opivv | OPCODE(0x1A),
vmsle_vv = InstructionType::opivv | OPCODE(0x1d),
vmsleu_vv = InstructionType::opivv | OPCODE(0x1c),
vmslt_vv = InstructionType::opivv | OPCODE(0x1b),
vmslt_vx = InstructionType::opivx | OPCODE(0x1b),
vmsltu_vv = InstructionType::opivv | OPCODE(0x1a),
vmsne_vi = InstructionType::opivi | OPCODE(0x19),
vmsne_vv = InstructionType::opivv | OPCODE(0x19),
vmul_vv = InstructionType::opmvv | OPCODE(0x25),
Expand All @@ -83,16 +83,16 @@ enum TypeOpcode : uint32_t {
vmv_vv = InstructionType::opivv | OPCODE(0x17),
vmv_vx = InstructionType::opivx | OPCODE(0x17),
vmv_xs = InstructionType::opmvv | OPCODE(0x10),
vor_vv = InstructionType::opivv | OPCODE(0xA),
vor_vv = InstructionType::opivv | OPCODE(0xa),
vredmaxu_vs = InstructionType::opmvv | OPCODE(0x6),
vredminu_vs = InstructionType::opmvv | OPCODE(0x4),
vredsum_vs = InstructionType::opmvv | OPCODE(0x0),
vrgather_vv = InstructionType::opivv | OPCODE(0xC),
vrgather_vv = InstructionType::opivv | OPCODE(0xc),
vrsub_vi = InstructionType::opivi | OPCODE(0x3),
vsadd_vv = InstructionType::opivv | OPCODE(0x21),
vsaddu_vv = InstructionType::opivv | OPCODE(0x20),
vsext_vf2 = InstructionType::opmvv | OPCODE(0x12) | (0x7 << 15),
vslidedown_vi = InstructionType::opivi | OPCODE(0xF),
vslidedown_vi = InstructionType::opivi | OPCODE(0xf),
vsll_vi = InstructionType::opivi | OPCODE(0x25),
vsll_vx = InstructionType::opivx | OPCODE(0x25),
vsra_vi = InstructionType::opivi | OPCODE(0x29),
Expand All @@ -102,9 +102,10 @@ enum TypeOpcode : uint32_t {
vssub_vv = InstructionType::opivv | OPCODE(0x23),
vssubu_vv = InstructionType::opivv | OPCODE(0x22),
vsub_vv = InstructionType::opivv | OPCODE(0x2),
vwmul_vv = InstructionType::opmvv | OPCODE(0x3B),
vxor_vi = InstructionType::opivi | OPCODE(0xB),
vxor_vv = InstructionType::opivv | OPCODE(0xB),
vwmul_vv = InstructionType::opmvv | OPCODE(0x3b),
vwmulu_vv = InstructionType::opmvv | OPCODE(0x38),
vxor_vi = InstructionType::opivi | OPCODE(0xb),
vxor_vv = InstructionType::opivv | OPCODE(0xb),
vzext_vf2 = InstructionType::opmvv | OPCODE(0x12) | (0x6 << 15),
};

Expand All @@ -115,9 +116,20 @@ enum OperandTypes : uint32_t {
rmIsGpr = 1 << 4,
rdIsGpr = 1 << 5
};

// Register-group multiplier selectors handed to simdEmitVsetivli() and
// simdEmitTypedOp() through their `vlmul` argument.
// The numeric values look like the 3-bit vlmul field of the RISC-V vector
// `vtype` register (F2 / F4 / F8 presumably meaning 1/2, 1/4 and 1/8) -
// TODO confirm against the encoding done inside simdEmitVsetivli().
enum VectorLengthMultiplyTypes : uint32_t {
    // Whole register groups.
    vlMul1 = 0x0,
    vlMul2 = 0x1,
    vlMul4 = 0x2,
    vlMul8 = 0x3,
    // Value 4 is intentionally not defined here.
    // Fractional groups, listed by ascending encoding.
    vlMulF8 = 0x5,
    vlMulF4 = 0x6,
    vlMulF2 = 0x7,
};

}; // namespace SimdOp

static void simdEmitVsetivli(struct sljit_compiler* compiler, sljit_s32 type, sljit_ins vlmul)
static void simdEmitVsetivli(struct sljit_compiler* compiler, sljit_s32 type, uint32_t vlmul)
{
uint32_t elem_size = (uint32_t)(((type) >> 18) & 0x3f);
uint32_t avl = (uint32_t)1 << (4 - elem_size);
Expand Down Expand Up @@ -151,7 +163,7 @@ static void simdEmitOp(sljit_compiler* compiler, uint32_t opcode, sljit_s32 rd,
sljit_emit_op_custom(compiler, &opcode, sizeof(uint32_t));
}

static void simdEmitTypedOp(sljit_compiler* compiler, sljit_s32 type, uint32_t opcode, sljit_s32 rd, sljit_s32 rn, sljit_s32 rm, uint32_t optype = 0, sljit_ins vlmul = 0)
static void simdEmitTypedOp(sljit_compiler* compiler, sljit_s32 type, uint32_t opcode, sljit_s32 rd, sljit_s32 rn, sljit_s32 rm, uint32_t optype = 0, uint32_t vlmul = 0)
{
simdEmitVsetivli(compiler, type, vlmul);
simdEmitOp(compiler, opcode, rd, rn, rm, optype);
Expand Down Expand Up @@ -186,6 +198,17 @@ static void simdEmitAbs(sljit_compiler* compiler, sljit_s32 type, sljit_s32 rd,
}
}

// Emits the instruction sequence for an extending pairwise add.
//
// Each lane of `type` in rn is treated as two packed half-width values; the
// matching lane of rd receives their sum.
//
//   compiler - sljit compiler the instructions are appended to
//   type     - element type of the (wide) result lanes
//   isSigned - true: right shifts use vsra (arithmetic), false: vsrl (logical)
//   rd       - destination vector register
//   rn       - source vector register
//
// SLJIT_TMP_DEST_VREG is overwritten by the sequence.
static void simdEmitPairwiseAdd(sljit_compiler* compiler, sljit_s32 type, bool isSigned, sljit_s32 rd, sljit_s32 rn)
{
    const sljit_s32 upperHalves = SLJIT_TMP_DEST_VREG;
    const uint32_t extendingShift = isSigned ? SimdOp::vsra_vi : SimdOp::vsrl_vi;

    // Half of the lane width: 8 bits for 16 bit lanes, 16 bits for every other type.
    sljit_s32 halfBits = 16;
    if (type == SLJIT_SIMD_ELEM_16) {
        halfBits = 8;
    }

    // upperHalves = rn >> halfBits (upper half of every lane, extended).
    // Only this first emit goes through simdEmitTypedOp.
    simdEmitTypedOp(compiler, type, extendingShift, upperHalves, rn, halfBits, SimdOp::rmIsImm);

    // rd = (rn << halfBits) >> halfBits (lower half of every lane, extended).
    simdEmitOp(compiler, SimdOp::vsll_vi, rd, rn, halfBits, SimdOp::rmIsImm);
    simdEmitOp(compiler, extendingShift, rd, rd, halfBits, SimdOp::rmIsImm);

    // rd = lower halves + upper halves.
    simdEmitOp(compiler, SimdOp::vadd_vv, rd, rd, upperHalves);
}

static void simdEmitAllTrue(sljit_compiler* compiler, sljit_s32 type, sljit_s32 rd, sljit_s32 rn)
{
sljit_s32 tmp = SLJIT_TMP_DEST_VREG;
Expand Down Expand Up @@ -530,8 +553,12 @@ static void emitUnarySIMD(sljit_compiler* compiler, Instruction* instr)
simdEmitPopcnt(compiler, srcType, dst, args[0].arg, instr->requiredReg(1));
break;
case ByteCode::I16X8ExtaddPairwiseI8X16SOpcode:
case ByteCode::I32X4ExtaddPairwiseI16X8SOpcode:
simdEmitPairwiseAdd(compiler, dstType, true, dst, args[0].arg);
break;
case ByteCode::I16X8ExtaddPairwiseI8X16UOpcode:
case ByteCode::I32X4ExtaddPairwiseI16X8UOpcode:
simdEmitPairwiseAdd(compiler, dstType, false, dst, args[0].arg);
break;
case ByteCode::I16X8ExtendLowI8X16SOpcode:
case ByteCode::I32X4ExtendLowI16X8SOpcode:
Expand All @@ -553,10 +580,6 @@ static void emitUnarySIMD(sljit_compiler* compiler, Instruction* instr)
case ByteCode::I64X2ExtendHighI32X4UOpcode:
simdEmitExtend(compiler, srcType, false, false, dst, args[0].arg);
break;
case ByteCode::I32X4ExtaddPairwiseI16X8SOpcode:
break;
case ByteCode::I32X4ExtaddPairwiseI16X8UOpcode:
break;
case ByteCode::I32X4TruncSatF32X4SOpcode:
simdEmitTruncSat(compiler, srcType, SimdOp::vfcvt_rtz_x_f_v, dst, args[0].arg);
break;
Expand Down Expand Up @@ -666,6 +689,29 @@ static bool emitUnaryCondSIMD(sljit_compiler* compiler, Instruction* instr)
return false;
}

// Emits an extended multiply of rn and rm into rd using `opcode`
// (the callers pass SimdOp::vwmul_vv or SimdOp::vwmulu_vv).
//
// The multiply is emitted with the vlMulF2 multiplier. When rd is the same
// register as one of the inputs, the product is first written to
// SLJIT_TMP_DEST_VREG and then copied to rd with vmv_vv - presumably because
// a widening destination must not overlap its sources (confirm against the
// RISC-V vector specification).
static void simdEmitExtmul(sljit_compiler* compiler, sljit_s32 type, uint32_t opcode, sljit_s32 rd, sljit_s32 rn, sljit_s32 rm)
{
    if (rd != rn && rd != rm) {
        // No aliasing: the result can be produced directly in rd.
        simdEmitTypedOp(compiler, type, opcode, rd, rn, rm, 0, SimdOp::vlMulF2);
        return;
    }

    const sljit_s32 scratch = SLJIT_TMP_DEST_VREG;

    simdEmitTypedOp(compiler, type, opcode, scratch, rn, rm, 0, SimdOp::vlMulF2);
    // Copy the product from the scratch register into the real destination.
    simdEmitTypedOp(compiler, type, SimdOp::vmv_vv, rd, 0, scratch, SimdOp::rnIsImm);
}

// Emits the "high" variant of the extended multiply.
//
// Both inputs are slid down by 8 elements of SLJIT_SIMD_ELEM_8 type into
// scratch registers, then `opcode` (SimdOp::vwmul_vv or SimdOp::vwmulu_vv at
// the call sites) is applied to the two scratch registers with the vlMulF2
// multiplier and the result is written to rd.
//
// Overwrites SLJIT_TMP_DEST_VREG and SLJIT_VR0.
// NOTE(review): SLJIT_VR0 is used as a scratch register unconditionally;
// this looks like it relies on the operand constraints keeping rd and any
// live value out of SLJIT_VR0 - confirm against the operand type list.
static void simdEmitExtmulHigh(sljit_compiler* compiler, sljit_s32 type, uint32_t opcode, sljit_s32 rd, sljit_s32 rn, sljit_s32 rm)
{
    const sljit_s32 slideAmount = 8;
    const sljit_s32 highOfRn = SLJIT_TMP_DEST_VREG;
    const sljit_s32 highOfRm = SLJIT_VR0;

    // The first slide goes through simdEmitTypedOp with the 8 bit element
    // type; the second one is emitted directly with simdEmitOp.
    simdEmitTypedOp(compiler, SLJIT_SIMD_ELEM_8, SimdOp::vslidedown_vi, highOfRn, rn, slideAmount, SimdOp::rmIsImm);
    simdEmitOp(compiler, SimdOp::vslidedown_vi, highOfRm, rm, slideAmount, SimdOp::rmIsImm);

    // Multiply the two shifted inputs using the caller supplied element type.
    simdEmitTypedOp(compiler, type, opcode, rd, highOfRn, highOfRm, 0, SimdOp::vlMulF2);
}

static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
{
Operand* operands = instr->operands();
Expand Down Expand Up @@ -959,27 +1005,31 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
simdEmitTypedOp(compiler, srcType, SimdOp::vmul_vv, dst, args[0].arg, args[1].arg);
break;
case ByteCode::I16X8ExtmulLowI8X16SOpcode:
case ByteCode::I32X4ExtmulLowI16X8SOpcode:
case ByteCode::I64X2ExtmulLowI32X4SOpcode:
simdEmitExtmul(compiler, srcType, SimdOp::vwmul_vv, dst, args[0].arg, args[1].arg);
break;
case ByteCode::I16X8ExtmulHighI8X16SOpcode:
case ByteCode::I32X4ExtmulHighI16X8SOpcode:
case ByteCode::I64X2ExtmulHighI32X4SOpcode:
simdEmitExtmulHigh(compiler, srcType, SimdOp::vwmul_vv, dst, args[0].arg, args[1].arg);
break;
case ByteCode::I16X8ExtmulLowI8X16UOpcode:
case ByteCode::I32X4ExtmulLowI16X8UOpcode:
case ByteCode::I64X2ExtmulLowI32X4UOpcode:
simdEmitExtmul(compiler, srcType, SimdOp::vwmulu_vv, dst, args[0].arg, args[1].arg);
break;
case ByteCode::I16X8ExtmulHighI8X16UOpcode:
case ByteCode::I32X4ExtmulHighI16X8UOpcode:
case ByteCode::I64X2ExtmulHighI32X4UOpcode:
simdEmitExtmulHigh(compiler, srcType, SimdOp::vwmulu_vv, dst, args[0].arg, args[1].arg);
break;
case ByteCode::I16X8NarrowI32X4SOpcode:
break;
case ByteCode::I16X8NarrowI32X4UOpcode:
break;
case ByteCode::I16X8Q15mulrSatSOpcode:
break;
case ByteCode::I32X4ExtmulLowI16X8SOpcode:
break;
case ByteCode::I32X4ExtmulHighI16X8SOpcode:
break;
case ByteCode::I32X4ExtmulLowI16X8UOpcode:
break;
case ByteCode::I32X4ExtmulHighI16X8UOpcode:
break;
case ByteCode::I32X4DotI16X8SOpcode:
simdEmitI32x4DotI16x8(compiler, srcType, dst, args[0].arg, args[1].arg);
break;
Expand Down Expand Up @@ -1039,14 +1089,6 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
case ByteCode::F64X2GeOpcode:
simdEmitCompare(compiler, srcType, SimdOp::vmfle_vv, dst, args[1].arg, args[0].arg);
break;
case ByteCode::I64X2ExtmulLowI32X4SOpcode:
break;
case ByteCode::I64X2ExtmulHighI32X4SOpcode:
break;
case ByteCode::I64X2ExtmulLowI32X4UOpcode:
break;
case ByteCode::I64X2ExtmulHighI32X4UOpcode:
break;
case ByteCode::V128AndOpcode:
simdEmitTypedOp(compiler, srcType, SimdOp::vand_vv, dst, args[0].arg, args[1].arg);
break;
Expand Down

0 comments on commit 3a6517e

Please sign in to comment.