[qfcc] Implement matrix ops in spirv

And they pass spirv-val.
This commit is contained in:
Bill Currie 2024-12-05 02:00:35 +09:00
parent a9bff23a6a
commit 4b7025fa0b

View file

@ -801,6 +801,8 @@ typedef struct {
SpvOp op;
unsigned types1;
unsigned types2;
bool mat1;
bool mat2;
spirv_expr_f generate;
extinst_t *extinst;
} spvop_t;
@ -829,6 +831,39 @@ spirv_generate_vqmul (const expr_t *e, spirvctx_t *ctx)
return 0;
}
static unsigned
spirv_generate_matrix (const expr_t *e, spirvctx_t *ctx)
{
auto mat_type = get_type (e);
int count = type_cols (mat_type);
scoped_src_loc (e);
unsigned columns[count];
auto col_type = column_type (mat_type);
auto e1 = e->expr.e1;
auto e2 = e->expr.e2;
for (int i = 0; i < count; i++) {
auto ind = new_int_expr (i, false);
auto a = new_array_expr (e1, ind);
auto b = new_array_expr (e2, ind);
a->array.type = col_type;
b->array.type = col_type;
auto c = typed_binary_expr (col_type, e->expr.op, a, b);
columns[i] = spirv_emit_expr (c, ctx);
}
int tid = type_id (mat_type, ctx);
int id = spirv_id (ctx);
auto insn = spirv_new_insn (SpvOpCompositeConstruct, 3 + count,
ctx->code_space);
INSN (insn, 1) = tid;
INSN (insn, 2) = id;
for (int i = 0; i < count; i++) {
INSN (insn, 3 + i) = columns[i];
}
return id;
}
#define SPV_type(t) (1<<(t))
#define SPV_SINT (SPV_type(ev_int)|SPV_type(ev_long))
#define SPV_UINT (SPV_type(ev_uint)|SPV_type(ev_ulong))
@ -868,10 +903,22 @@ static spvop_t spv_ops[] = {
{"add", SpvOpIAdd, SPV_INT, SPV_INT },
{"add", SpvOpFAdd, SPV_FLOAT, SPV_FLOAT },
{"add", .types1 = SPV_FLOAT, .types2 = SPV_FLOAT,
.mat1 = true, .mat2 = true,
.generate = spirv_generate_matrix },
{"sub", SpvOpISub, SPV_INT, SPV_INT },
{"sub", SpvOpFSub, SPV_FLOAT, SPV_FLOAT },
{"add", .types1 = SPV_FLOAT, .types2 = SPV_FLOAT,
.mat1 = true, .mat2 = true,
.generate = spirv_generate_matrix },
{"mul", SpvOpIMul, SPV_INT, SPV_INT },
{"mul", SpvOpFMul, SPV_FLOAT, SPV_FLOAT },
{"mul", SpvOpMatrixTimesVector, SPV_FLOAT, SPV_FLOAT,
.mat1 = true, .mat2 = false },
{"mul", SpvOpVectorTimesMatrix, SPV_FLOAT, SPV_FLOAT,
.mat1 = false, .mat2 = true },
{"mul", SpvOpMatrixTimesMatrix, SPV_FLOAT, SPV_FLOAT,
.mat1 = true, .mat2 = true },
{"div", SpvOpUDiv, SPV_UINT, SPV_UINT },
{"div", SpvOpSDiv, SPV_SINT, SPV_SINT },
{"div", SpvOpFDiv, SPV_FLOAT, SPV_FLOAT },
@ -894,6 +941,8 @@ static spvop_t spv_ops[] = {
{"dot", SpvOpDot, SPV_FLOAT, SPV_FLOAT },
{"scale", SpvOpVectorTimesScalar, SPV_FLOAT, SPV_FLOAT },
{"scale", SpvOpMatrixTimesScalar, SPV_FLOAT, SPV_FLOAT,
.mat1 = true, .mat2 = false },
{"cross", GLSLstd450Cross, SPV_FLOAT, SPV_FLOAT,
.extinst = &glsl_450 },
@ -908,15 +957,21 @@ static spvop_t spv_ops[] = {
};
static const spvop_t *
spirv_find_op (const char *op_name, etype_t type1, etype_t type2)
spirv_find_op (const char *op_name, const type_t *type1, const type_t *type2)
{
constexpr int num_ops = sizeof (spv_ops) / sizeof (spv_ops[0]);
etype_t t1 = type1->type;
etype_t t2 = type2 ? type2->type : ev_void;
bool mat1 = is_matrix (type1);
bool mat2 = is_matrix (type2);
for (int i = 0; i < num_ops; i++) {
if (strcmp (spv_ops[i].op_name, op_name) == 0
&& spv_ops[i].types1 & SPV_type(type1)
&& ((!spv_ops[i].types2 && type2 == ev_void)
&& spv_ops[i].types1 & SPV_type(t1)
&& ((!spv_ops[i].types2 && t2 == ev_void)
|| (spv_ops[i].types2
&& (spv_ops[i].types2 & SPV_type(type2))))) {
&& (spv_ops[i].types2 & SPV_type(t2))))
&& spv_ops[i].mat1 == mat1 && spv_ops[i].mat2 == mat2) {
return &spv_ops[i];
}
}
@ -983,14 +1038,14 @@ spirv_uexpr (const expr_t *e, spirvctx_t *ctx)
}
}
auto t = get_type (e->expr.e1);
auto spv_op = spirv_find_op (op_name, t->type, 0);
auto spv_op = spirv_find_op (op_name, t, nullptr);
if (!spv_op) {
internal_error (e, "unexpected unary op_name: %s %s\n", op_name,
pr_type_name[t->type]);
get_type_string(t));
}
if (!spv_op->op) {
internal_error (e, "unimplemented op: %s %s\n", op_name,
pr_type_name[t->type]);
get_type_string(t));
}
unsigned uid = spirv_emit_expr (e->expr.e1, ctx);
@ -1034,17 +1089,20 @@ spirv_expr (const expr_t *e, spirvctx_t *ctx)
internal_error (e, "unexpected binary op: %d\n", e->expr.op);
}
auto t1 = get_type (e->expr.e1);
auto t2 = get_type (e->expr.e1);
auto spv_op = spirv_find_op (op_name, t1->type, t2->type);
auto t2 = get_type (e->expr.e2);
auto spv_op = spirv_find_op (op_name, t1, t2);
if (!spv_op) {
internal_error (e, "unexpected binary op_name: %s %s %s\n", op_name,
pr_type_name[t1->type],
pr_type_name[t2->type]);
get_type_string(t1),
get_type_string(t2));
}
if (!spv_op->op) {
if (!spv_op->op && !spv_op->generate) {
internal_error (e, "unimplemented op: %s %s %s\n", op_name,
pr_type_name[t1->type],
pr_type_name[t2->type]);
get_type_string(t1),
get_type_string(t2));
}
if (spv_op->generate) {
return spv_op->generate (e, ctx);
}
unsigned bid1 = spirv_emit_expr (e->expr.e1, ctx);
unsigned bid2 = spirv_emit_expr (e->expr.e2, ctx);