From 4b7025fa0b9e529a2ad5d6438de34c5c4f405082 Mon Sep 17 00:00:00 2001 From: Bill Currie Date: Thu, 5 Dec 2024 02:00:35 +0900 Subject: [PATCH] [qfcc] Implement matrix ops in spirv And they pass spirv-val. --- tools/qfcc/source/target_spirv.c | 86 ++++++++++++++++++++++++++------ 1 file changed, 72 insertions(+), 14 deletions(-) diff --git a/tools/qfcc/source/target_spirv.c b/tools/qfcc/source/target_spirv.c index 4c9271e44..6ecbe5ae1 100644 --- a/tools/qfcc/source/target_spirv.c +++ b/tools/qfcc/source/target_spirv.c @@ -801,6 +801,8 @@ typedef struct { SpvOp op; unsigned types1; unsigned types2; + bool mat1; + bool mat2; spirv_expr_f generate; extinst_t *extinst; } spvop_t; @@ -829,6 +831,39 @@ spirv_generate_vqmul (const expr_t *e, spirvctx_t *ctx) return 0; } +static unsigned +spirv_generate_matrix (const expr_t *e, spirvctx_t *ctx) +{ + auto mat_type = get_type (e); + int count = type_cols (mat_type); + scoped_src_loc (e); + unsigned columns[count]; + auto col_type = column_type (mat_type); + auto e1 = e->expr.e1; + auto e2 = e->expr.e2; + for (int i = 0; i < count; i++) { + auto ind = new_int_expr (i, false); + auto a = new_array_expr (e1, ind); + auto b = new_array_expr (e2, ind); + a->array.type = col_type; + b->array.type = col_type; + auto c = typed_binary_expr (col_type, e->expr.op, a, b); + + columns[i] = spirv_emit_expr (c, ctx); + } + + int tid = type_id (mat_type, ctx); + int id = spirv_id (ctx); + auto insn = spirv_new_insn (SpvOpCompositeConstruct, 3 + count, + ctx->code_space); + INSN (insn, 1) = tid; + INSN (insn, 2) = id; + for (int i = 0; i < count; i++) { + INSN (insn, 3 + i) = columns[i]; + } + return id; +} + #define SPV_type(t) (1<<(t)) #define SPV_SINT (SPV_type(ev_int)|SPV_type(ev_long)) #define SPV_UINT (SPV_type(ev_uint)|SPV_type(ev_ulong)) @@ -868,10 +903,22 @@ static spvop_t spv_ops[] = { {"add", SpvOpIAdd, SPV_INT, SPV_INT }, {"add", SpvOpFAdd, SPV_FLOAT, SPV_FLOAT }, + {"add", .types1 = SPV_FLOAT, .types2 = SPV_FLOAT, + .mat1 = true, .mat2 = true, + .generate = spirv_generate_matrix }, {"sub", SpvOpISub, SPV_INT, SPV_INT }, {"sub", SpvOpFSub, SPV_FLOAT, SPV_FLOAT }, + {"add", .types1 = SPV_FLOAT, .types2 = SPV_FLOAT, + .mat1 = true, .mat2 = true, + .generate = spirv_generate_matrix }, {"mul", SpvOpIMul, SPV_INT, SPV_INT }, {"mul", SpvOpFMul, SPV_FLOAT, SPV_FLOAT }, + {"mul", SpvOpMatrixTimesVector, SPV_FLOAT, SPV_FLOAT, + .mat1 = true, .mat2 = false }, + {"mul", SpvOpVectorTimesMatrix, SPV_FLOAT, SPV_FLOAT, + .mat1 = false, .mat2 = true }, + {"mul", SpvOpMatrixTimesMatrix, SPV_FLOAT, SPV_FLOAT, + .mat1 = true, .mat2 = true }, {"div", SpvOpUDiv, SPV_UINT, SPV_UINT }, {"div", SpvOpSDiv, SPV_SINT, SPV_SINT }, {"div", SpvOpFDiv, SPV_FLOAT, SPV_FLOAT }, @@ -894,6 +941,8 @@ static spvop_t spv_ops[] = { {"dot", SpvOpDot, SPV_FLOAT, SPV_FLOAT }, {"scale", SpvOpVectorTimesScalar, SPV_FLOAT, SPV_FLOAT }, + {"scale", SpvOpMatrixTimesScalar, SPV_FLOAT, SPV_FLOAT, + .mat1 = true, .mat2 = false }, {"cross", GLSLstd450Cross, SPV_FLOAT, SPV_FLOAT, .extinst = &glsl_450 }, @@ -908,15 +957,21 @@ static spvop_t spv_ops[] = { }; static const spvop_t * -spirv_find_op (const char *op_name, etype_t type1, etype_t type2) +spirv_find_op (const char *op_name, const type_t *type1, const type_t *type2) { constexpr int num_ops = sizeof (spv_ops) / sizeof (spv_ops[0]); + etype_t t1 = type1->type; + etype_t t2 = type2 ? type2->type : ev_void; + bool mat1 = is_matrix (type1); + bool mat2 = is_matrix (type2); + for (int i = 0; i < num_ops; i++) { if (strcmp (spv_ops[i].op_name, op_name) == 0 - && spv_ops[i].types1 & SPV_type(type1) - && ((!spv_ops[i].types2 && type2 == ev_void) + && spv_ops[i].types1 & SPV_type(t1) + && ((!spv_ops[i].types2 && t2 == ev_void) || (spv_ops[i].types2 - && (spv_ops[i].types2 & SPV_type(type2))))) { + && (spv_ops[i].types2 & SPV_type(t2)))) + && spv_ops[i].mat1 == mat1 && spv_ops[i].mat2 == mat2) { return &spv_ops[i]; } } @@ -983,14 +1038,14 @@ spirv_uexpr (const expr_t *e, spirvctx_t *ctx) } } auto t = get_type (e->expr.e1); - auto spv_op = spirv_find_op (op_name, t->type, 0); + auto spv_op = spirv_find_op (op_name, t, nullptr); if (!spv_op) { internal_error (e, "unexpected unary op_name: %s %s\n", op_name, - pr_type_name[t->type]); + get_type_string(t)); } if (!spv_op->op) { internal_error (e, "unimplemented op: %s %s\n", op_name, - pr_type_name[t->type]); + get_type_string(t)); } unsigned uid = spirv_emit_expr (e->expr.e1, ctx); @@ -1034,17 +1089,20 @@ spirv_expr (const expr_t *e, spirvctx_t *ctx) internal_error (e, "unexpected binary op: %d\n", e->expr.op); } auto t1 = get_type (e->expr.e1); - auto t2 = get_type (e->expr.e1); - auto spv_op = spirv_find_op (op_name, t1->type, t2->type); + auto t2 = get_type (e->expr.e2); + auto spv_op = spirv_find_op (op_name, t1, t2); if (!spv_op) { internal_error (e, "unexpected binary op_name: %s %s %s\n", op_name, - pr_type_name[t1->type], - pr_type_name[t2->type]); + get_type_string(t1), + get_type_string(t2)); } - if (!spv_op->op) { + if (!spv_op->op && !spv_op->generate) { internal_error (e, "unimplemented op: %s %s %s\n", op_name, - pr_type_name[t1->type], - pr_type_name[t2->type]); + get_type_string(t1), + get_type_string(t2)); + } + if (spv_op->generate) { + return spv_op->generate (e, ctx); } unsigned bid1 = spirv_emit_expr (e->expr.e1, ctx); unsigned bid2 = spirv_emit_expr (e->expr.e2, ctx);