[qfcc] Move nested scale handling to distribute_product

Or really, the implementers. This gets my fancy zero test down to just
unrecognized permutations of commutative multiplies and dot products
(with the multiplies above the dot products).
This commit is contained in:
Bill Currie 2023-09-29 15:06:07 +09:00
parent 55ee8562df
commit f1f87527aa

View file

@ -766,6 +766,10 @@ do_mult (type_t *type, const expr_t *a, const expr_t *b)
static const expr_t *
do_scale (type_t *type, const expr_t *a, const expr_t *b)
{
const expr_t *prod = extract_scale (&a, b);
if (prod) {
b = prod;
}
return typed_binary_expr (type, SCALE, a, b);
}
@ -783,11 +787,6 @@ scale_expr (type_t *type, const expr_t *a, const expr_t *b)
internal_error (b, "not a real scalar type");
}
const expr_t *prod = extract_scale (&a, b);
if (prod) {
b = prod;
}
auto op = is_scalar (get_type (a)) ? do_mult : do_scale;
auto scale = distribute_product (type, a, b, op, false);
if (!scale) {
@ -824,7 +823,14 @@ do_dot (type_t *type, const expr_t *a, const expr_t *b)
if (reject_dot (a, b)) {
return 0;
}
return typed_binary_expr (type, DOT, a, b);
const expr_t *prod = 0;
prod = extract_scale (&a, prod);
prod = extract_scale (&b, prod);
auto dot = typed_binary_expr (type, DOT, a, b);
dot = apply_scale (type, dot, prod);
return dot;
}
static const expr_t *
@ -835,12 +841,7 @@ dot_expr (type_t *type, const expr_t *a, const expr_t *b)
return 0;
}
const expr_t *prod = 0;
prod = extract_scale (&a, prod);
prod = extract_scale (&b, prod);
auto dot = distribute_product (type, a, b, do_dot, false);
dot = apply_scale (type, dot, prod);
return dot;
}
@ -856,7 +857,14 @@ do_cross (type_t *type, const expr_t *a, const expr_t *b)
if (reject_cross (a, b)) {
return 0;
}
return typed_binary_expr (type, CROSS, a, b);
const expr_t *prod = 0;
prod = extract_scale (&a, prod);
prod = extract_scale (&b, prod);
auto cross = typed_binary_expr (type, CROSS, a, b);
cross = apply_scale (type, cross, prod);
return cross;
}
static const expr_t *
@ -867,12 +875,7 @@ cross_expr (type_t *type, const expr_t *a, const expr_t *b)
return 0;
}
const expr_t *prod = 0;
prod = extract_scale (&a, prod);
prod = extract_scale (&b, prod);
auto cross = distribute_product (type, a, b, do_cross, true);
cross = apply_scale (type, cross, prod);
return cross;
}
@ -888,7 +891,14 @@ do_wedge (type_t *type, const expr_t *a, const expr_t *b)
if (reject_wedge (a, b)) {
return 0;
}
return typed_binary_expr (type, WEDGE, a, b);
const expr_t *prod = 0;
prod = extract_scale (&a, prod);
prod = extract_scale (&b, prod);
auto wedge = typed_binary_expr (type, WEDGE, a, b);
wedge = apply_scale (type, wedge, prod);
return wedge;
}
static const expr_t *
@ -898,13 +908,7 @@ wedge_expr (type_t *type, const expr_t *a, const expr_t *b)
// propagated zero
return 0;
}
const expr_t *prod = 0;
prod = extract_scale (&a, prod);
prod = extract_scale (&b, prod);
auto wedge = distribute_product (type, a, b, do_wedge, true);
wedge = apply_scale (type, wedge, prod);
return wedge;
}