[qfcc] Implement 2d PGA dot (inner) product

This has shown the need for more instructions, such as a 2d wedge
product and narrower swizzles. Also, making dot product produce a vector
instead of a scalar was a big mistake (works nicely in C, but not so
well in Ruamoko).
This commit is contained in:
Bill Currie 2023-08-24 00:19:58 +09:00
parent 469fdea0a1
commit 6d75d91de2
2 changed files with 167 additions and 15 deletions

View file

@ -225,18 +225,6 @@ scalar_product (expr_t *e1, expr_t *e2)
return mvec_gather (components, algebra);
}
static expr_t *
inner_product (expr_t *e1, expr_t *e2)
{
if (is_scalar (get_type (e1)) || is_scalar (get_type (e2))) {
auto scalar = is_scalar (get_type (e1)) ? e1 : e2;
notice (scalar,
"the inner product of a scalar with any other grade is 0");
return new_zero_expr (get_type (scalar));
}
internal_error (e1, "not implemented");
}
static void
component_sum (int op, expr_t **c, expr_t **a, expr_t **b,
algebra_t *algebra)
@ -297,7 +285,7 @@ dot_expr (type_t *type, expr_t *a, expr_t *b)
return dot;
}
typedef void (*pga3_wedge) (expr_t **c, expr_t *a, expr_t *b, algebra_t *alg);
typedef void (*pga_func) (expr_t **c, expr_t *a, expr_t *b, algebra_t *alg);
static void
scale_component (expr_t **c, expr_t *a, expr_t *b, algebra_t *alg)
{
@ -306,6 +294,168 @@ scale_component (expr_t **c, expr_t *a, expr_t *b, algebra_t *alg)
c[group] = scale;
}
static pga_func pga3_dot_funcs[6][6] = {
[2] = {
[2] = scale_component,
},
};
static void
pga2_yw_wx_xy_dot_yw_wx_xy (expr_t **c, expr_t *a, expr_t *b, algebra_t *alg)
{
auto stype = alg->type;
auto sa = new_offset_alias_expr (stype, a, 2);
auto sb = new_offset_alias_expr (stype, b, 2);
c[1] = unary_expr ('-', scale_expr (sa, sb, alg));
}
static void
pga2_yw_wx_xy_dot_x_y_w (expr_t **c, expr_t *a, expr_t *b, algebra_t *alg)
{
auto stype = alg->type;
auto dot_type = algebra_mvec_type (alg, 0x04);
auto va = new_offset_alias_expr (dot_type, new_swizzle_expr (a, "y-x0"), 0);
auto sb = new_offset_alias_expr (stype, b, 2);
auto tmp = new_binary_expr (CROSS, a, b);
tmp->e.expr.type = dot_type;
tmp = new_offset_alias_expr (dot_type, new_swizzle_expr (tmp, "00z"), 0);
c[2] = new_binary_expr ('+', scale_expr (va, sb, alg), tmp);
c[2]->e.expr.type = dot_type;
}
static void
pga2_x_y_w_dot_yw_wx_xy (expr_t **c, expr_t *a, expr_t *b, algebra_t *alg)
{
auto stype = alg->type;
auto dot_type = algebra_mvec_type (alg, 0x04);
auto va = new_offset_alias_expr (dot_type, new_swizzle_expr (a, "-yx0"), 0);
auto sb = new_offset_alias_expr (stype, b, 2);
auto tmp = new_binary_expr (CROSS, a, b);
tmp->e.expr.type = dot_type;
tmp = new_offset_alias_expr (dot_type, new_swizzle_expr (tmp, "00-z"), 0);
c[2] = new_binary_expr ('+', scale_expr (va, sb, alg), tmp);
c[2]->e.expr.type = dot_type;
}
static void
pga2_x_y_w_dot_x_y_w (expr_t **c, expr_t *a, expr_t *b, algebra_t *alg)
{
auto stype = alg->type;
auto vtype = vector_type (stype, 2);
auto va = new_offset_alias_expr (vtype, a, 0);
auto vb = new_offset_alias_expr (vtype, b, 0);
auto cs = dot_expr (stype, va, vb);
c[1] = cs;
}
static void
pga2_x_y_w_dot_wxy (expr_t **c, expr_t *a, expr_t *b, algebra_t *alg)
{
auto stype = alg->type;
auto vtype = vector_type (stype, 2);
auto dot_type = algebra_mvec_type (alg, 0x01);
auto va = new_offset_alias_expr (vtype, a, 0);
auto cv = scale_expr (va, b, alg);
auto tmp = new_temp_def_expr (dot_type);
auto vtmp = new_offset_alias_expr (vtype, tmp, 0);
auto block = new_block_expr ();
block->e.block.result = tmp;
append_expr (block, assign_expr (vtmp, cv));
c[0] = block;
}
static void
pga2_wxy_dot_x_y_w (expr_t **c, expr_t *a, expr_t *b, algebra_t *alg)
{
auto stype = alg->type;
auto vtype = vector_type (stype, 2);
auto dot_type = algebra_mvec_type (alg, 0x01);
auto vb = new_offset_alias_expr (vtype, b, 0);
auto cv = scale_expr (vb, a, alg);
auto tmp = new_temp_def_expr (dot_type);
auto vtmp = new_offset_alias_expr (vtype, tmp, 0);
auto block = new_block_expr ();
block->e.block.result = tmp;
append_expr (block, assign_expr (vtmp, unary_expr ('-', cv)));
c[0] = block;
}
static pga_func pga2_dot_funcs[4][4] = {
[0] = {
[0] = pga2_yw_wx_xy_dot_yw_wx_xy,
[2] = pga2_yw_wx_xy_dot_x_y_w,
},
[1] = {
[1] = scale_component,
},
[2] = {
[0] = pga2_x_y_w_dot_yw_wx_xy,
[2] = pga2_x_y_w_dot_x_y_w,
[3] = pga2_x_y_w_dot_wxy,
},
[3] = {
[2] = pga2_wxy_dot_x_y_w,
},
};
static void
component_dot (expr_t **c, expr_t *a, expr_t *b, algebra_t *algebra)
{
int p = algebra->plus;
int m = algebra->minus;
int z = algebra->zero;
if (p == 3 && m == 0 && z == 1) {
int ga = get_group (get_type (a), algebra);
int gb = get_group (get_type (b), algebra);
if (pga3_dot_funcs[ga][gb]) {
pga3_dot_funcs[ga][gb] (c, a, b, algebra);
}
} else if (p == 2 && m == 0 && z == 1) {
int ga = get_group (get_type (a), algebra);
int gb = get_group (get_type (b), algebra);
if (pga2_dot_funcs[ga][gb]) {
pga2_dot_funcs[ga][gb] (c, a, b, algebra);
}
} else {
}
}
static expr_t *
inner_product (expr_t *e1, expr_t *e2)
{
if (is_scalar (get_type (e1)) || is_scalar (get_type (e2))) {
auto scalar = is_scalar (get_type (e1)) ? e1 : e2;
notice (scalar,
"the inner product of a scalar with any other grade is 0");
return new_zero_expr (get_type (scalar));
}
auto t1 = get_type (e1);
auto t2 = get_type (e2);
auto algebra = is_algebra (t1) ? algebra_get (t1) : algebra_get (t2);
auto layout = &algebra->layout;
expr_t *a[layout->count] = {};
expr_t *b[layout->count] = {};
expr_t *c[layout->count] = {};
e1 = mvec_expr (e1, algebra);
e2 = mvec_expr (e2, algebra);
mvec_scatter (a, e1, algebra);
mvec_scatter (b, e2, algebra);
for (int i = 0; i < layout->count; i++) {
for (int j = 0; j < layout->count; j++) {
if (a[i] && b[j]) {
expr_t *w[layout->count] = {};
component_dot (w, a[i], b[j], algebra);
component_sum ('+', c, c, w, algebra);
}
}
}
return mvec_gather (c, algebra);
}
static void
pga3_x_y_z_w_wedge_x_y_z_w (expr_t **c, expr_t *a, expr_t *b, algebra_t *alg)
{
@ -387,7 +537,7 @@ pga3_wzy_wxz_wyx_xyz_wedge_x_y_z_w (expr_t **c, expr_t *a, expr_t *b,
c[4] = unary_expr ('-', dot_expr (algebra_mvec_type (alg, 0x10), a, b));
}
static void (*pga3_wedge_funcs[6][6])(expr_t**,expr_t*,expr_t*,algebra_t*) = {
static pga_func pga3_wedge_funcs[6][6] = {
[0] = {
[0] = pga3_x_y_z_w_wedge_x_y_z_w,
[1] = pga3_x_y_z_w_wedge_yz_zx_xy,
@ -438,7 +588,7 @@ pga2_x_y_w_wedge_x_y_w (expr_t **c, expr_t *a, expr_t *b, algebra_t *alg)
c[0]->e.expr.type = wedge_type;
}
static void (*pga2_wedge_funcs[4][4])(expr_t**,expr_t*,expr_t*,algebra_t*) = {
static pga_func pga2_wedge_funcs[4][4] = {
[0] = {
[1] = scale_component,
[2] = pga2_yw_wx_xy_wedge_x_y_w,

View file

@ -1,3 +1,4 @@
int foo[128];
@algebra(float) pgaf1;
//@algebra(double) pgad1;
//@algebra(float(3)) vgaf;
@ -42,6 +43,7 @@ main (void)
auto l2 = 3 * e1 - e2 + 10 * e0;
auto p = l1l2;
pga2 = p + (1 + p)l1;
pga2 = l1p;
}
return 0; // to survive and prevail :)
}