diff options
-rw-r--r-- | src/amd/common/ac_nir_to_llvm.c | 33 |
1 files changed, 21 insertions, 12 deletions
diff --git a/src/amd/common/ac_nir_to_llvm.c b/src/amd/common/ac_nir_to_llvm.c
index 82def9ab20f..37ef50fabaf 100644
--- a/src/amd/common/ac_nir_to_llvm.c
+++ b/src/amd/common/ac_nir_to_llvm.c
@@ -671,22 +671,31 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
 		result = LLVMBuildXor(ctx->ac.builder, src[0], src[1], "");
 		break;
 	case nir_op_ishl:
-		result = LLVMBuildShl(ctx->ac.builder, src[0],
-				      LLVMBuildZExt(ctx->ac.builder, src[1],
-						    LLVMTypeOf(src[0]), ""),
-				      "");
+		if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) < ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])))
+			src[1] = LLVMBuildZExt(ctx->ac.builder, src[1],
+					       LLVMTypeOf(src[0]), "");
+		else if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) > ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])))
+			src[1] = LLVMBuildTrunc(ctx->ac.builder, src[1],
+						LLVMTypeOf(src[0]), "");
+		result = LLVMBuildShl(ctx->ac.builder, src[0], src[1], "");
 		break;
 	case nir_op_ishr:
-		result = LLVMBuildAShr(ctx->ac.builder, src[0],
-				       LLVMBuildZExt(ctx->ac.builder, src[1],
-						     LLVMTypeOf(src[0]), ""),
-				       "");
+		if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) < ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])))
+			src[1] = LLVMBuildZExt(ctx->ac.builder, src[1],
+					       LLVMTypeOf(src[0]), "");
+		else if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) > ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])))
+			src[1] = LLVMBuildTrunc(ctx->ac.builder, src[1],
+						LLVMTypeOf(src[0]), "");
+		result = LLVMBuildAShr(ctx->ac.builder, src[0], src[1], "");
 		break;
 	case nir_op_ushr:
-		result = LLVMBuildLShr(ctx->ac.builder, src[0],
-				       LLVMBuildZExt(ctx->ac.builder, src[1],
-						     LLVMTypeOf(src[0]), ""),
-				       "");
+		if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) < ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])))
+			src[1] = LLVMBuildZExt(ctx->ac.builder, src[1],
+					       LLVMTypeOf(src[0]), "");
+		else if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[1])) > ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])))
+			src[1] = LLVMBuildTrunc(ctx->ac.builder, src[1],
+						LLVMTypeOf(src[0]), "");
+		result = LLVMBuildLShr(ctx->ac.builder, src[0], src[1], "");
 		break;
 	case nir_op_ilt32:
 		result = emit_int_cmp(&ctx->ac, LLVMIntSLT, src[0], src[1]);