[qfcc] Clean up nested scale handling

This is incomplete in that the handling needs to be moved into
distribute_product, but the infrastructure is there without breaking
anything.
This commit is contained in:
Bill Currie 2023-09-29 13:13:35 +09:00
parent a21c857579
commit 355d3d76b4

View file

@ -732,6 +732,29 @@ distribute_product (type_t *type, int op, const expr_t *a, const expr_t *b,
} }
} }
static const expr_t *scale_expr (type_t *type,
const expr_t *a, const expr_t *b);
static const expr_t *
extract_scale (const expr_t **expr, const expr_t *prod)
{
if (is_scale (*expr)) {
auto s = (*expr)->expr.e2;
prod = prod ? scale_expr (get_type (prod), prod, s) : s;
*expr = (*expr)->expr.e1;
}
return prod;
}
static const expr_t *
apply_scale (type_t *type, const expr_t *expr, const expr_t *prod)
{
if (expr && prod) {
expr = scale_expr (type, expr, prod);
}
return expr;
}
static const expr_t * static const expr_t *
scale_expr (type_t *type, const expr_t *a, const expr_t *b) scale_expr (type_t *type, const expr_t *a, const expr_t *b)
{ {
@ -745,14 +768,13 @@ scale_expr (type_t *type, const expr_t *a, const expr_t *b)
if (!is_real (get_type (b))) { if (!is_real (get_type (b))) {
internal_error (b, "not a real scalar type"); internal_error (b, "not a real scalar type");
} }
int op = is_scalar (get_type (a)) ? '*' : SCALE;
if (is_scale (a)) { const expr_t *prod = extract_scale (&a, b);
// covert scale (scale (X, y), z) to scale (X, y*z) if (prod) {
b = scale_expr (get_type (b), b, a->expr.e2); b = prod;
a = a->expr.e1;
} }
int op = is_scalar (get_type (a)) ? '*' : SCALE;
auto scale = distribute_product (type, op, a, b, 0); auto scale = distribute_product (type, op, a, b, 0);
if (!scale) { if (!scale) {
return 0; return 0;
@ -791,23 +813,11 @@ dot_expr (type_t *type, const expr_t *a, const expr_t *b)
} }
const expr_t *prod = 0; const expr_t *prod = 0;
if (is_scale (a)) { prod = extract_scale (&a, prod);
prod = a->expr.e2; prod = extract_scale (&b, prod);
a = a->expr.e1;
}
if (is_scale (b)) {
auto s = b->expr.e2;
prod = prod ? scale_expr (get_type (prod), prod, s) : s;
b = b->expr.e1;
}
auto dot = distribute_product (type, DOT, a, b, reject_dot); auto dot = distribute_product (type, DOT, a, b, reject_dot);
if (!dot) { dot = apply_scale (type, dot, prod);
return 0;
}
if (prod) {
dot = scale_expr (type, dot, prod);
}
return dot; return dot;
} }
@ -825,7 +835,12 @@ cross_expr (type_t *type, const expr_t *a, const expr_t *b)
return 0; return 0;
} }
const expr_t *prod = 0;
prod = extract_scale (&a, prod);
prod = extract_scale (&b, prod);
auto cross = distribute_product (type, CROSS, a, b, reject_cross); auto cross = distribute_product (type, CROSS, a, b, reject_cross);
cross = apply_scale (type, cross, prod);
return cross; return cross;
} }
@ -842,7 +857,13 @@ wedge_expr (type_t *type, const expr_t *a, const expr_t *b)
// propagated zero // propagated zero
return 0; return 0;
} }
const expr_t *prod = 0;
prod = extract_scale (&a, prod);
prod = extract_scale (&b, prod);
auto wedge = distribute_product (type, WEDGE, a, b, reject_wedge); auto wedge = distribute_product (type, WEDGE, a, b, reject_wedge);
wedge = apply_scale (type, wedge, prod);
return wedge; return wedge;
} }