diff options
author | Daniel Schürmann <daniel@schuermann.dev> | 2020-08-24 20:36:06 +0100 |
---|---|---|
committer | Marge Bot <eric+marge@anholt.net> | 2021-01-13 17:46:56 +0000 |
commit | 454bbf8f230e44e54b1dfc04e87dff353fa3fd1f (patch) | |
tree | 17db203f94bc2b3c886046d8aa0fcbb475ba0361 | |
parent | 5ad52ac90630e344650cf9a1b48820432af22680 (diff) |
aco: emit packed 16bit instructions
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6680>
-rw-r--r-- | src/amd/compiler/aco_instruction_selection.cpp | 53 | ||||
-rw-r--r-- | src/amd/compiler/aco_instruction_selection_setup.cpp | 17 |
2 files changed, 66 insertions, 4 deletions
diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index 0dad702e419..1c47c45eacb 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -1406,6 +1406,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) emit_vop3a_instruction(ctx, instr, aco_opcode::v_max_i16_e64, dst); } else if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_max_i16, dst, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_max_i16, dst); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_max_i32, dst, true); } else if (dst.regClass() == s1) { @@ -1420,6 +1422,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) emit_vop3a_instruction(ctx, instr, aco_opcode::v_max_u16_e64, dst); } else if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_max_u16, dst, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_max_u16, dst); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_max_u32, dst, true); } else if (dst.regClass() == s1) { @@ -1434,6 +1438,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) emit_vop3a_instruction(ctx, instr, aco_opcode::v_min_i16_e64, dst); } else if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_min_i16, dst, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_min_i16, dst); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_min_i32, dst, true); } else if (dst.regClass() == s1) { @@ -1448,6 +1454,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) emit_vop3a_instruction(ctx, instr, aco_opcode::v_min_u16_e64, dst); } else if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_min_u16, dst, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_min_u16, dst); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_min_u32, dst, true); } else if (dst.regClass() == s1) { @@ -1510,6 +1518,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) emit_vop3a_instruction(ctx, instr, aco_opcode::v_lshrrev_b16_e64, dst, false, 2, true); } else if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_lshrrev_b16, dst, false, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_lshrrev_b16, dst, true); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_lshrrev_b32, dst, false, true); } else if (dst.regClass() == v2 && ctx->program->chip_class >= GFX8) { @@ -1531,6 +1541,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) emit_vop3a_instruction(ctx, instr, aco_opcode::v_lshlrev_b16_e64, dst, false, 2, true); } else if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_lshlrev_b16, dst, false, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_lshlrev_b16, dst, true); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_lshlrev_b32, dst, false, true, false, false, 1); } else if (dst.regClass() == v2 && ctx->program->chip_class >= GFX8) { @@ -1552,6 +1564,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) emit_vop3a_instruction(ctx, instr, aco_opcode::v_ashrrev_i16_e64, dst, false, 2, true); } else if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_ashrrev_i16, dst, false, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_ashrrev_i16, dst, true); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_ashrrev_i32, dst, false, true); } else if (dst.regClass() == v2 && ctx->program->chip_class >= GFX8) { @@ -1628,6 +1642,9 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } else if (dst.bytes() <= 2 && ctx->program->chip_class >= GFX8) { emit_vop2_instruction(ctx, instr, aco_opcode::v_add_u16, dst, true); break; + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_add_u16, dst); + break; } Temp src0 = get_alu_src(ctx, instr->src[0]); @@ -1737,6 +1754,9 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) if (dst.regClass() == s1) { emit_sop2_instruction(ctx, instr, aco_opcode::s_sub_i32, dst, true); break; + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_sub_u16, dst); + break; } Temp src0 = get_alu_src(ctx, instr->src[0]); @@ -1815,6 +1835,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) emit_vop3a_instruction(ctx, instr, aco_opcode::v_mul_lo_u16_e64, dst); } else if (dst.bytes() <= 2 && ctx->program->chip_class >= GFX8) { emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_lo_u16, dst, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_mul_lo_u16, dst); } else if (dst.type() == RegType::vgpr) { Temp src0 = get_alu_src(ctx, instr->src[0]); Temp src1 = get_alu_src(ctx, instr->src[1]); @@ -1896,6 +1918,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) case nir_op_fmul: { if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_f16, dst, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_mul_f16, dst); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_f32, dst, true); } else if (dst.regClass() == v2) { @@ -1908,6 +1932,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) case nir_op_fadd: { if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_add_f16, dst, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_add_f16, dst); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_add_f32, dst, true); } else if (dst.regClass() == v2) { @@ -1918,6 +1944,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fsub: { + if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + Instruction* add = emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_add_f16, dst); + VOP3P_instruction* sub = static_cast<VOP3P_instruction*>(add); + sub->neg_lo[1] = true; + sub->neg_hi[1] = true; + break; + } + Temp src0 = get_alu_src(ctx, instr->src[0]); Temp src1 = get_alu_src(ctx, instr->src[1]); if (dst.regClass() == v2b) { @@ -1944,6 +1978,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) if (dst.regClass() == v2b) { // TODO: check fp_mode.must_flush_denorms16_64 emit_vop2_instruction(ctx, instr, aco_opcode::v_max_f16, dst, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_max_f16, dst); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_max_f32, dst, true, false, ctx->block->fp_mode.must_flush_denorms32); } else if (dst.regClass() == v2) { @@ -1957,6 +1993,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) if (dst.regClass() == v2b) { // TODO: check fp_mode.must_flush_denorms16_64 emit_vop2_instruction(ctx, instr, aco_opcode::v_min_f16, dst, true); + } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_min_f16, dst, true); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_min_f32, dst, true, false, ctx->block->fp_mode.must_flush_denorms32); } else if (dst.regClass() == v2) { @@ -2011,6 +2049,13 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fneg: { + if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + Temp src = get_alu_src_vop3p(ctx, instr->src[0]); + bld.vop3p(aco_opcode::v_pk_mul_f16, Definition(dst), src, Operand(uint16_t(0xBC00)), + instr->src[0].swizzle[0] & 1, instr->src[0].swizzle[1] & 1); + emit_split_vector(ctx, dst, 2); + break; + } Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { if (ctx->block->fp_mode.must_flush_denorms16_64) @@ -2055,6 +2100,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fsat: { + if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) { + Temp src = get_alu_src_vop3p(ctx, instr->src[0]); + Instruction* vop3p = bld.vop3p(aco_opcode::v_pk_mul_f16, Definition(dst), src, Operand(uint16_t(0x3C00)), + instr->src[0].swizzle[0] & 1, instr->src[0].swizzle[1] & 1); + static_cast<VOP3P_instruction*>(vop3p)->clamp = true; + emit_split_vector(ctx, dst, 2); + break; + } Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { bld.vop3(aco_opcode::v_med3_f16, Definition(dst), Operand((uint16_t)0u), Operand((uint16_t)0x3c00), src); diff --git a/src/amd/compiler/aco_instruction_selection_setup.cpp b/src/amd/compiler/aco_instruction_selection_setup.cpp index e0c2b93d2b8..04aa7be7f3d 100644 --- a/src/amd/compiler/aco_instruction_selection_setup.cpp +++ b/src/amd/compiler/aco_instruction_selection_setup.cpp @@ -680,7 +680,7 @@ void init_context(isel_context *ctx, nir_shader *shader) switch(instr->type) { case nir_instr_type_alu: { nir_alu_instr *alu_instr = nir_instr_as_alu(instr); - RegType type = RegType::sgpr; + RegType type = nir_dest_is_divergent(alu_instr->dest.dest) ? RegType::vgpr : RegType::sgpr; switch(alu_instr->op) { case nir_op_fmul: case nir_op_fadd: @@ -745,10 +745,19 @@ void init_context(isel_context *ctx, nir_shader *shader) case nir_op_b2f16: case nir_op_b2f32: case nir_op_mov: - type = nir_dest_is_divergent(alu_instr->dest.dest) ? RegType::vgpr : RegType::sgpr; break; - case nir_op_bcsel: - type = nir_dest_is_divergent(alu_instr->dest.dest) ? RegType::vgpr : RegType::sgpr; + case nir_op_iadd: + case nir_op_isub: + case nir_op_imul: + case nir_op_imin: + case nir_op_imax: + case nir_op_umin: + case nir_op_umax: + case nir_op_ishl: + case nir_op_ishr: + case nir_op_ushr: + /* packed 16bit instructions have to be VGPR */ + type = alu_instr->dest.dest.ssa.num_components == 2 ? RegType::vgpr : type; FALLTHROUGH; default: for (unsigned i = 0; i < nir_op_infos[alu_instr->op].num_inputs; i++) { |