summaryrefslogtreecommitdiff
path: root/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c')
-rw-r--r--src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c33
1 files changed, 26 insertions, 7 deletions
diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
index 75d75bab541..b55e026a950 100644
--- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
+++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
@@ -220,7 +220,7 @@ get_atomic_op(nir_intrinsic_op op)
static SpvId
emit_float_const(struct ntv_context *ctx, int bit_size, double value)
{
- assert(bit_size == 32 || bit_size == 64);
+ assert(bit_size == 16 || bit_size == 32 || bit_size == 64);
return spirv_builder_const_float(&ctx->builder, bit_size, value);
}
@@ -241,7 +241,7 @@ emit_int_const(struct ntv_context *ctx, int bit_size, int64_t value)
static SpvId
get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
{
- assert(bit_size == 32 || bit_size == 64);
+ assert(bit_size == 16 || bit_size == 32 || bit_size == 64);
SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
if (num_components > 1)
@@ -312,6 +312,9 @@ get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
case GLSL_TYPE_BOOL:
return spirv_builder_type_bool(&ctx->builder);
+ case GLSL_TYPE_FLOAT16:
+ return spirv_builder_type_float(&ctx->builder, 16);
+
case GLSL_TYPE_FLOAT:
return spirv_builder_type_float(&ctx->builder, 32);
@@ -1389,7 +1392,7 @@ static SpvId
get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
unsigned num_components, double value)
{
- assert(bit_size == 32 || bit_size == 64);
+ assert(bit_size == 16 || bit_size == 32 || bit_size == 64);
SpvId result = emit_float_const(ctx, bit_size, value);
if (num_components == 1)
@@ -1578,12 +1581,15 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
UNOP(nir_op_f2u16, SpvOpConvertFToU)
UNOP(nir_op_f2i32, SpvOpConvertFToS)
UNOP(nir_op_f2u32, SpvOpConvertFToU)
+ UNOP(nir_op_i2f16, SpvOpConvertSToF)
UNOP(nir_op_i2f32, SpvOpConvertSToF)
+ UNOP(nir_op_u2f16, SpvOpConvertUToF)
UNOP(nir_op_u2f32, SpvOpConvertUToF)
UNOP(nir_op_i2i16, SpvOpSConvert)
UNOP(nir_op_i2i32, SpvOpSConvert)
UNOP(nir_op_u2u16, SpvOpUConvert)
UNOP(nir_op_u2u32, SpvOpUConvert)
+ UNOP(nir_op_f2f16, SpvOpFConvert)
UNOP(nir_op_f2f32, SpvOpFConvert)
UNOP(nir_op_f2i64, SpvOpConvertFToS)
UNOP(nir_op_f2u64, SpvOpConvertFToU)
@@ -1612,6 +1618,7 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
get_ivec_constant(ctx, bit_size, num_components, 0));
break;
+ case nir_op_b2f16:
case nir_op_b2f32:
case nir_op_b2f64:
assert(nir_op_infos[alu->op].num_inputs == 1);
@@ -1809,10 +1816,19 @@ emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
load_const->value[i].b);
} else {
- for (int i = 0; i < num_components; i++)
- components[i] = emit_uint_const(ctx, bit_size,
- bit_size == 64 ? load_const->value[i].u64 : load_const->value[i].u32);
-
+ for (int i = 0; i < num_components; i++) {
+ if (bit_size == 16)
+ components[i] = emit_uint_const(ctx, bit_size,
+ load_const->value[i].u16);
+ else if (bit_size == 32)
+ components[i] = emit_uint_const(ctx, bit_size,
+ load_const->value[i].u32);
+ else if (bit_size == 64)
+ components[i] = emit_uint_const(ctx, bit_size,
+ load_const->value[i].u64);
+ else
+ unreachable("unhandled constant bit size!");
+ }
}
constant = spirv_builder_const_composite(&ctx->builder, type,
components, num_components);
@@ -3587,6 +3603,9 @@ nir_to_spirv(struct nir_shader *s, const struct zink_so_info *so_info, bool spir
spirv_builder_emit_cap(&ctx.builder, SpvCapabilityInt16);
if (s->info.bit_sizes_int & 64)
spirv_builder_emit_cap(&ctx.builder, SpvCapabilityInt64);
+
+ if (s->info.bit_sizes_float & 16)
+ spirv_builder_emit_cap(&ctx.builder, SpvCapabilityFloat16);
if (s->info.bit_sizes_float & 64)
spirv_builder_emit_cap(&ctx.builder, SpvCapabilityFloat64);