[qfcc] Use an implementation function for products

Another step towards proper handling of nested scales.
This commit is contained in:
Bill Currie 2023-09-29 13:29:19 +09:00
parent 355d3d76b4
commit 55ee8562df

View file

@ -640,8 +640,10 @@ component_sum (int op, const expr_t **c, const expr_t **a, const expr_t **b,
} }
static const expr_t * static const expr_t *
distribute_product (type_t *type, int op, const expr_t *a, const expr_t *b, distribute_product (type_t *type, const expr_t *a, const expr_t *b,
bool (*reject) (const expr_t *a, const expr_t *b)) const expr_t *(*product) (type_t *type,
const expr_t *a, const expr_t *b),
bool anti_com)
{ {
bool neg = false; bool neg = false;
if (is_neg (a)) { if (is_neg (a)) {
@ -652,7 +654,7 @@ distribute_product (type_t *type, int op, const expr_t *a, const expr_t *b,
neg = !neg; neg = !neg;
b = neg_expr (b); b = neg_expr (b);
} }
if (op_anti_com (op) && neg) { if (anti_com && neg) {
auto t = a; auto t = a;
a = b; a = b;
b = t; b = t;
@ -683,8 +685,8 @@ distribute_product (type_t *type, int op, const expr_t *a, const expr_t *b,
for (auto i = a_adds; *i; i++) { for (auto i = a_adds; *i; i++) {
for (auto j = b_adds; *j; j++) { for (auto j = b_adds; *j; j++) {
if (!reject || !reject(*i, *j)) { auto p = product (type, *i, *j);
auto p = typed_binary_expr (type, op, *i, *j); if (p) {
p = fold_constants (p); p = fold_constants (p);
p = edag_add_expr (p); p = edag_add_expr (p);
a = sum_expr (type, a, p); a = sum_expr (type, a, p);
@ -693,8 +695,8 @@ distribute_product (type_t *type, int op, const expr_t *a, const expr_t *b,
} }
for (auto i = a_subs; *i; i++) { for (auto i = a_subs; *i; i++) {
for (auto j = b_subs; *j; j++) { for (auto j = b_subs; *j; j++) {
if (!reject || !reject(*i, *j)) { auto p = product (type, *i, *j);
auto p = typed_binary_expr (type, op, *i, *j); if (p) {
p = fold_constants (p); p = fold_constants (p);
p = edag_add_expr (p); p = edag_add_expr (p);
a = sum_expr (type, a, p); a = sum_expr (type, a, p);
@ -703,8 +705,8 @@ distribute_product (type_t *type, int op, const expr_t *a, const expr_t *b,
} }
for (auto i = a_adds; *i; i++) { for (auto i = a_adds; *i; i++) {
for (auto j = b_subs; *j; j++) { for (auto j = b_subs; *j; j++) {
if (!reject || !reject(*i, *j)) { auto p = product (type, *i, *j);
auto p = typed_binary_expr (type, op, *i, *j); if (p) {
p = fold_constants (p); p = fold_constants (p);
p = edag_add_expr (p); p = edag_add_expr (p);
b = sum_expr (type, b, p); b = sum_expr (type, b, p);
@ -713,8 +715,8 @@ distribute_product (type_t *type, int op, const expr_t *a, const expr_t *b,
} }
for (auto i = a_subs; *i; i++) { for (auto i = a_subs; *i; i++) {
for (auto j = b_adds; *j; j++) { for (auto j = b_adds; *j; j++) {
if (!reject || !reject(*i, *j)) { auto p = product (type, *i, *j);
auto p = typed_binary_expr (type, op, *i, *j); if (p) {
p = fold_constants (p); p = fold_constants (p);
p = edag_add_expr (p); p = edag_add_expr (p);
b = sum_expr (type, b, p); b = sum_expr (type, b, p);
@ -755,6 +757,18 @@ apply_scale (type_t *type, const expr_t *expr, const expr_t *prod)
return expr; return expr;
} }
static const expr_t *
do_mult (type_t *type, const expr_t *a, const expr_t *b)
{
return typed_binary_expr (type, '*', a, b);
}
static const expr_t *
do_scale (type_t *type, const expr_t *a, const expr_t *b)
{
return typed_binary_expr (type, SCALE, a, b);
}
static const expr_t * static const expr_t *
scale_expr (type_t *type, const expr_t *a, const expr_t *b) scale_expr (type_t *type, const expr_t *a, const expr_t *b)
{ {
@ -774,8 +788,8 @@ scale_expr (type_t *type, const expr_t *a, const expr_t *b)
b = prod; b = prod;
} }
int op = is_scalar (get_type (a)) ? '*' : SCALE; auto op = is_scalar (get_type (a)) ? do_mult : do_scale;
auto scale = distribute_product (type, op, a, b, 0); auto scale = distribute_product (type, a, b, op, false);
if (!scale) { if (!scale) {
return 0; return 0;
} }
@ -804,6 +818,15 @@ reject_dot (const expr_t *a, const expr_t *b)
return false; return false;
} }
static const expr_t *
do_dot (type_t *type, const expr_t *a, const expr_t *b)
{
if (reject_dot (a, b)) {
return 0;
}
return typed_binary_expr (type, DOT, a, b);
}
static const expr_t * static const expr_t *
dot_expr (type_t *type, const expr_t *a, const expr_t *b) dot_expr (type_t *type, const expr_t *a, const expr_t *b)
{ {
@ -816,7 +839,7 @@ dot_expr (type_t *type, const expr_t *a, const expr_t *b)
prod = extract_scale (&a, prod); prod = extract_scale (&a, prod);
prod = extract_scale (&b, prod); prod = extract_scale (&b, prod);
auto dot = distribute_product (type, DOT, a, b, reject_dot); auto dot = distribute_product (type, a, b, do_dot, false);
dot = apply_scale (type, dot, prod); dot = apply_scale (type, dot, prod);
return dot; return dot;
} }
@ -827,6 +850,15 @@ reject_cross (const expr_t *a, const expr_t *b)
return traverse_scale (a) == traverse_scale (b); return traverse_scale (a) == traverse_scale (b);
} }
static const expr_t *
do_cross (type_t *type, const expr_t *a, const expr_t *b)
{
if (reject_cross (a, b)) {
return 0;
}
return typed_binary_expr (type, CROSS, a, b);
}
static const expr_t * static const expr_t *
cross_expr (type_t *type, const expr_t *a, const expr_t *b) cross_expr (type_t *type, const expr_t *a, const expr_t *b)
{ {
@ -839,7 +871,7 @@ cross_expr (type_t *type, const expr_t *a, const expr_t *b)
prod = extract_scale (&a, prod); prod = extract_scale (&a, prod);
prod = extract_scale (&b, prod); prod = extract_scale (&b, prod);
auto cross = distribute_product (type, CROSS, a, b, reject_cross); auto cross = distribute_product (type, a, b, do_cross, true);
cross = apply_scale (type, cross, prod); cross = apply_scale (type, cross, prod);
return cross; return cross;
} }
@ -850,6 +882,15 @@ reject_wedge (const expr_t *a, const expr_t *b)
return traverse_scale (a) == traverse_scale (b); return traverse_scale (a) == traverse_scale (b);
} }
static const expr_t *
do_wedge (type_t *type, const expr_t *a, const expr_t *b)
{
if (reject_wedge (a, b)) {
return 0;
}
return typed_binary_expr (type, WEDGE, a, b);
}
static const expr_t * static const expr_t *
wedge_expr (type_t *type, const expr_t *a, const expr_t *b) wedge_expr (type_t *type, const expr_t *a, const expr_t *b)
{ {
@ -862,7 +903,7 @@ wedge_expr (type_t *type, const expr_t *a, const expr_t *b)
prod = extract_scale (&a, prod); prod = extract_scale (&a, prod);
prod = extract_scale (&b, prod); prod = extract_scale (&b, prod);
auto wedge = distribute_product (type, WEDGE, a, b, reject_wedge); auto wedge = distribute_product (type, a, b, do_wedge, true);
wedge = apply_scale (type, wedge, prod); wedge = apply_scale (type, wedge, prod);
return wedge; return wedge;
} }