diff options
Diffstat (limited to 'src/amd/compiler/aco_instruction_selection.cpp')
-rw-r--r-- | src/amd/compiler/aco_instruction_selection.cpp | 53 |
1 files changed, 53 insertions, 0 deletions
diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index 0dad702e419..1c47c45eacb 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -1406,6 +1406,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) emit_vop3a_instruction(ctx, instr, aco_opcode::v_max_i16_e64, dst); } else if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_max_i16, dst, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_max_i16, dst); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_max_i32, dst, true); } else if (dst.regClass() == s1) { @@ -1420,6 +1422,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) emit_vop3a_instruction(ctx, instr, aco_opcode::v_max_u16_e64, dst); } else if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_max_u16, dst, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_max_u16, dst); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_max_u32, dst, true); } else if (dst.regClass() == s1) { @@ -1434,6 +1438,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) emit_vop3a_instruction(ctx, instr, aco_opcode::v_min_i16_e64, dst); } else if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_min_i16, dst, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_min_i16, dst); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_min_i32, dst, true); } else if (dst.regClass() == s1) { @@ -1448,6 +1454,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) emit_vop3a_instruction(ctx, instr, aco_opcode::v_min_u16_e64, dst); } else if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_min_u16, dst, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_min_u16, dst); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_min_u32, dst, true); } else if (dst.regClass() == s1) { @@ -1510,6 +1518,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) emit_vop3a_instruction(ctx, instr, aco_opcode::v_lshrrev_b16_e64, dst, false, 2, true); } else if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_lshrrev_b16, dst, false, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_lshrrev_b16, dst, true); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_lshrrev_b32, dst, false, true); } else if (dst.regClass() == v2 && ctx->program->chip_class >= GFX8) { @@ -1531,6 +1541,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) emit_vop3a_instruction(ctx, instr, aco_opcode::v_lshlrev_b16_e64, dst, false, 2, true); } else if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_lshlrev_b16, dst, false, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_lshlrev_b16, dst, true); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_lshlrev_b32, dst, false, true, false, false, 1); } else if (dst.regClass() == v2 && ctx->program->chip_class >= GFX8) { @@ -1552,6 +1564,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) emit_vop3a_instruction(ctx, instr, aco_opcode::v_ashrrev_i16_e64, dst, false, 2, true); } else if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_ashrrev_i16, dst, false, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_ashrrev_i16, dst, true); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_ashrrev_i32, dst, false, true); } else if (dst.regClass() == v2 && ctx->program->chip_class >= GFX8) { @@ -1628,6 +1642,9 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } else if (dst.bytes() <= 2 && ctx->program->chip_class >= GFX8) { emit_vop2_instruction(ctx, instr, aco_opcode::v_add_u16, dst, true); break; + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_add_u16, dst); + break; } Temp src0 = get_alu_src(ctx, instr->src[0]); @@ -1737,6 +1754,9 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) if (dst.regClass() == s1) { emit_sop2_instruction(ctx, instr, aco_opcode::s_sub_i32, dst, true); break; + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_sub_u16, dst); + break; } Temp src0 = get_alu_src(ctx, instr->src[0]); @@ -1815,6 +1835,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) emit_vop3a_instruction(ctx, instr, aco_opcode::v_mul_lo_u16_e64, dst); } else if (dst.bytes() <= 2 && ctx->program->chip_class >= GFX8) { emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_lo_u16, dst, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_mul_lo_u16, dst); } else if (dst.type() == RegType::vgpr) { Temp src0 = get_alu_src(ctx, instr->src[0]); Temp src1 = get_alu_src(ctx, instr->src[1]); @@ -1896,6 +1918,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) case nir_op_fmul: { if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_f16, dst, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_mul_f16, dst); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_f32, dst, true); } else if (dst.regClass() == v2) { @@ -1908,6 +1932,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) case nir_op_fadd: { if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_add_f16, dst, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_add_f16, dst); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_add_f32, dst, true); } else if (dst.regClass() == v2) { @@ -1918,6 +1944,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fsub: { + if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + Instruction* add = emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_add_f16, dst); + VOP3P_instruction* sub = static_cast<VOP3P_instruction*>(add); + sub->neg_lo[1] = true; + sub->neg_hi[1] = true; + break; + } + Temp src0 = get_alu_src(ctx, instr->src[0]); Temp src1 = get_alu_src(ctx, instr->src[1]); if (dst.regClass() == v2b) { @@ -1944,6 +1978,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) if (dst.regClass() == v2b) { // TODO: check fp_mode.must_flush_denorms16_64 emit_vop2_instruction(ctx, instr, aco_opcode::v_max_f16, dst, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_max_f16, dst); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_max_f32, dst, true, false, ctx->block->fp_mode.must_flush_denorms32); } else if (dst.regClass() == v2) { @@ -1957,6 +1993,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) if (dst.regClass() == v2b) { // TODO: check fp_mode.must_flush_denorms16_64 emit_vop2_instruction(ctx, instr, aco_opcode::v_min_f16, dst, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_min_f16, dst, true); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_min_f32, dst, true, false, ctx->block->fp_mode.must_flush_denorms32); } else if (dst.regClass() == v2) { @@ -2011,6 +2049,13 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fneg: { + if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + Temp src = get_alu_src_vop3p(ctx, instr->src[0]); + bld.vop3p(aco_opcode::v_pk_mul_f16, Definition(dst), src, Operand(uint16_t(0xBC00)), + instr->src[0].swizzle[0] & 1, instr->src[0].swizzle[1] & 1); + emit_split_vector(ctx, dst, 2); + break; + } Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { if (ctx->block->fp_mode.must_flush_denorms16_64) @@ -2055,6 +2100,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fsat: { + if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + Temp src = get_alu_src_vop3p(ctx, instr->src[0]); + Instruction* vop3p = bld.vop3p(aco_opcode::v_pk_mul_f16, Definition(dst), src, Operand(uint16_t(0x3C00)), + instr->src[0].swizzle[0] & 1, instr->src[0].swizzle[1] & 1); + static_cast<VOP3P_instruction*>(vop3p)->clamp = true; + emit_split_vector(ctx, dst, 2); + break; + } Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { bld.vop3(aco_opcode::v_med3_f16, Definition(dst), Operand((uint16_t)0u), Operand((uint16_t)0x3c00), src); |