[qfcc] Use an implementation function for products

Another step towards proper handling of nested scales.
This commit is contained in:
Bill Currie 2023-09-29 13:29:19 +09:00
parent 355d3d76b4
commit 55ee8562df

View file

@ -640,8 +640,10 @@ component_sum (int op, const expr_t **c, const expr_t **a, const expr_t **b,
}
static const expr_t *
distribute_product (type_t *type, int op, const expr_t *a, const expr_t *b,
bool (*reject) (const expr_t *a, const expr_t *b))
distribute_product (type_t *type, const expr_t *a, const expr_t *b,
const expr_t *(*product) (type_t *type,
const expr_t *a, const expr_t *b),
bool anti_com)
{
bool neg = false;
if (is_neg (a)) {
@ -652,7 +654,7 @@ distribute_product (type_t *type, int op, const expr_t *a, const expr_t *b,
neg = !neg;
b = neg_expr (b);
}
if (op_anti_com (op) && neg) {
if (anti_com && neg) {
auto t = a;
a = b;
b = t;
@ -683,8 +685,8 @@ distribute_product (type_t *type, int op, const expr_t *a, const expr_t *b,
for (auto i = a_adds; *i; i++) {
for (auto j = b_adds; *j; j++) {
if (!reject || !reject(*i, *j)) {
auto p = typed_binary_expr (type, op, *i, *j);
auto p = product (type, *i, *j);
if (p) {
p = fold_constants (p);
p = edag_add_expr (p);
a = sum_expr (type, a, p);
@ -693,8 +695,8 @@ distribute_product (type_t *type, int op, const expr_t *a, const expr_t *b,
}
for (auto i = a_subs; *i; i++) {
for (auto j = b_subs; *j; j++) {
if (!reject || !reject(*i, *j)) {
auto p = typed_binary_expr (type, op, *i, *j);
auto p = product (type, *i, *j);
if (p) {
p = fold_constants (p);
p = edag_add_expr (p);
a = sum_expr (type, a, p);
@ -703,8 +705,8 @@ distribute_product (type_t *type, int op, const expr_t *a, const expr_t *b,
}
for (auto i = a_adds; *i; i++) {
for (auto j = b_subs; *j; j++) {
if (!reject || !reject(*i, *j)) {
auto p = typed_binary_expr (type, op, *i, *j);
auto p = product (type, *i, *j);
if (p) {
p = fold_constants (p);
p = edag_add_expr (p);
b = sum_expr (type, b, p);
@ -713,8 +715,8 @@ distribute_product (type_t *type, int op, const expr_t *a, const expr_t *b,
}
for (auto i = a_subs; *i; i++) {
for (auto j = b_adds; *j; j++) {
if (!reject || !reject(*i, *j)) {
auto p = typed_binary_expr (type, op, *i, *j);
auto p = product (type, *i, *j);
if (p) {
p = fold_constants (p);
p = edag_add_expr (p);
b = sum_expr (type, b, p);
@ -755,6 +757,18 @@ apply_scale (type_t *type, const expr_t *expr, const expr_t *prod)
return expr;
}
static const expr_t *
do_mult (type_t *type, const expr_t *a, const expr_t *b)
{
return typed_binary_expr (type, '*', a, b);
}
static const expr_t *
do_scale (type_t *type, const expr_t *a, const expr_t *b)
{
return typed_binary_expr (type, SCALE, a, b);
}
static const expr_t *
scale_expr (type_t *type, const expr_t *a, const expr_t *b)
{
@ -774,8 +788,8 @@ scale_expr (type_t *type, const expr_t *a, const expr_t *b)
b = prod;
}
int op = is_scalar (get_type (a)) ? '*' : SCALE;
auto scale = distribute_product (type, op, a, b, 0);
auto op = is_scalar (get_type (a)) ? do_mult : do_scale;
auto scale = distribute_product (type, a, b, op, false);
if (!scale) {
return 0;
}
@ -804,6 +818,15 @@ reject_dot (const expr_t *a, const expr_t *b)
return false;
}
static const expr_t *
do_dot (type_t *type, const expr_t *a, const expr_t *b)
{
if (reject_dot (a, b)) {
return 0;
}
return typed_binary_expr (type, DOT, a, b);
}
static const expr_t *
dot_expr (type_t *type, const expr_t *a, const expr_t *b)
{
@ -816,7 +839,7 @@ dot_expr (type_t *type, const expr_t *a, const expr_t *b)
prod = extract_scale (&a, prod);
prod = extract_scale (&b, prod);
auto dot = distribute_product (type, DOT, a, b, reject_dot);
auto dot = distribute_product (type, a, b, do_dot, false);
dot = apply_scale (type, dot, prod);
return dot;
}
@ -827,6 +850,15 @@ reject_cross (const expr_t *a, const expr_t *b)
return traverse_scale (a) == traverse_scale (b);
}
static const expr_t *
do_cross (type_t *type, const expr_t *a, const expr_t *b)
{
if (reject_cross (a, b)) {
return 0;
}
return typed_binary_expr (type, CROSS, a, b);
}
static const expr_t *
cross_expr (type_t *type, const expr_t *a, const expr_t *b)
{
@ -839,7 +871,7 @@ cross_expr (type_t *type, const expr_t *a, const expr_t *b)
prod = extract_scale (&a, prod);
prod = extract_scale (&b, prod);
auto cross = distribute_product (type, CROSS, a, b, reject_cross);
auto cross = distribute_product (type, a, b, do_cross, true);
cross = apply_scale (type, cross, prod);
return cross;
}
@ -850,6 +882,15 @@ reject_wedge (const expr_t *a, const expr_t *b)
return traverse_scale (a) == traverse_scale (b);
}
static const expr_t *
do_wedge (type_t *type, const expr_t *a, const expr_t *b)
{
if (reject_wedge (a, b)) {
return 0;
}
return typed_binary_expr (type, WEDGE, a, b);
}
static const expr_t *
wedge_expr (type_t *type, const expr_t *a, const expr_t *b)
{
@ -862,7 +903,7 @@ wedge_expr (type_t *type, const expr_t *a, const expr_t *b)
prod = extract_scale (&a, prod);
prod = extract_scale (&b, prod);
auto wedge = distribute_product (type, WEDGE, a, b, reject_wedge);
auto wedge = distribute_product (type, a, b, do_wedge, true);
wedge = apply_scale (type, wedge, prod);
return wedge;
}