diff options
-rw-r--r-- | src/amd/vulkan/radv_nir_to_llvm.c | 8 | ||||
-rw-r--r-- | src/amd/vulkan/radv_private.h | 1 | ||||
-rw-r--r-- | src/amd/vulkan/radv_shader.c | 2 |
3 files changed, 8 insertions, 3 deletions
diff --git a/src/amd/vulkan/radv_nir_to_llvm.c b/src/amd/vulkan/radv_nir_to_llvm.c index 5a34a85d7b4..5201f46b3a8 100644 --- a/src/amd/vulkan/radv_nir_to_llvm.c +++ b/src/amd/vulkan/radv_nir_to_llvm.c @@ -3610,9 +3610,10 @@ ac_setup_rings(struct radv_shader_context *ctx) unsigned radv_nir_get_max_workgroup_size(enum chip_class chip_class, + gl_shader_stage stage, const struct nir_shader *nir) { - switch (nir->info.stage) { + switch (stage) { case MESA_SHADER_TESS_CTRL: return chip_class >= CIK ? 128 : 64; case MESA_SHADER_GEOMETRY: @@ -3623,6 +3624,8 @@ radv_nir_get_max_workgroup_size(enum chip_class chip_class, return 0; } + if (!nir) + return chip_class >= GFX9 ? 128 : 64; unsigned max_workgroup_size = nir->info.cs.local_size[0] * nir->info.cs.local_size[1] * nir->info.cs.local_size[2]; @@ -3689,7 +3692,8 @@ LLVMModuleRef ac_translate_nir_to_llvm(struct ac_llvm_compiler *ac_llvm, for (int i = 0; i < shader_count; ++i) { ctx.max_workgroup_size = MAX2(ctx.max_workgroup_size, radv_nir_get_max_workgroup_size(ctx.options->chip_class, - shaders[i])); + shaders[i]->info.stage, + shaders[i])); } create_function(&ctx, shaders[shader_count - 1]->info.stage, shader_count >= 2, diff --git a/src/amd/vulkan/radv_private.h b/src/amd/vulkan/radv_private.h index b0bcb57e7e4..31c829d345b 100644 --- a/src/amd/vulkan/radv_private.h +++ b/src/amd/vulkan/radv_private.h @@ -1994,6 +1994,7 @@ void radv_compile_nir_shader(struct ac_llvm_compiler *ac_llvm, const struct radv_nir_compiler_options *options); unsigned radv_nir_get_max_workgroup_size(enum chip_class chip_class, + gl_shader_stage stage, const struct nir_shader *nir); /* radv_shader_info.h */ diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c index ea9f3d97ac9..1f9fa487688 100644 --- a/src/amd/vulkan/radv_shader.c +++ b/src/amd/vulkan/radv_shader.c @@ -765,7 +765,7 @@ generate_shader_stats(struct radv_device *device, lds_increment); } else if (stage == MESA_SHADER_COMPUTE) { unsigned max_workgroup_size = - radv_nir_get_max_workgroup_size(chip_class, variant->nir); + radv_nir_get_max_workgroup_size(chip_class, stage, variant->nir); lds_per_wave = (conf->lds_size * lds_increment) / DIV_ROUND_UP(max_workgroup_size, 64); } |