[qfcc] Move product negation into distribute_product

Now all the products are handled consistently, and distribute_product
takes advantage of anti-commutativity.
This commit is contained in:
Bill Currie 2023-09-29 11:00:44 +09:00
parent 14d1148523
commit 26dec9ee21

View file

@ -86,17 +86,25 @@ anti_com (const expr_t *e)
return e && e->type == ex_expr && e->expr.anticommute;
}
static bool __attribute__((const))
op_commute (int op)
{
return (op == '+' || op == '*' || op == HADAMARD || op == DOT);
}
static bool __attribute__((const))
op_anti_com (int op)
{
return (op == '-' || op == CROSS || op == WEDGE);
}
static const expr_t *
typed_binary_expr (type_t *type, int op, const expr_t *e1, const expr_t *e2)
{
auto e = new_binary_expr (op, e1, e2);
e->expr.type = type;
if (op == '+' || op == '*' || op == HADAMARD || op == DOT) {
e->expr.commutative = true;
}
if (op == '-' || op == CROSS || op == WEDGE) {
e->expr.anticommute = true;
}
e->expr.commutative = op_commute (op);
e->expr.anticommute = op_anti_com (op);
return e;
}
@ -634,6 +642,22 @@ static const expr_t *
distribute_product (type_t *type, int op, const expr_t *a, const expr_t *b,
bool (*reject) (const expr_t *a, const expr_t *b))
{
bool neg = false;
if (is_neg (a)) {
neg = !neg;
a = neg_expr (a);
}
if (is_neg (b)) {
neg = !neg;
b = neg_expr (b);
}
if (op_anti_com (op) && neg) {
auto t = a;
a = b;
b = t;
neg = false;
}
int a_terms = count_terms (a);
int b_terms = count_terms (b);
@ -696,8 +720,15 @@ distribute_product (type_t *type, int op, const expr_t *a, const expr_t *b,
}
}
}
auto sum = sum_expr_low (type, '-', a, b);
return sum;
if (neg) {
// note order ----------------------V--V
auto sum = sum_expr_low (type, '-', b, a);
return sum;
} else {
// note order ----------------------V--V
auto sum = sum_expr_low (type, '-', a, b);
return sum;
}
}
static const expr_t *
@ -715,16 +746,6 @@ scale_expr (type_t *type, const expr_t *a, const expr_t *b)
}
int op = is_scalar (get_type (a)) ? '*' : SCALE;
bool neg = false;
if (is_neg (a)) {
neg = !neg;
a = neg_expr (a);
}
if (is_neg (b)) {
neg = !neg;
b = neg_expr (b);
}
if (is_scale (a)) {
// covert scale (scale (X, y), z) to scale (X, y*z)
b = scale_expr (get_type (b), b, a->expr.e2);
@ -737,11 +758,6 @@ scale_expr (type_t *type, const expr_t *a, const expr_t *b)
}
scale = fold_constants (scale);
scale = edag_add_expr (scale);
if (neg) {
scale = neg_expr (scale);
scale = fold_constants (scale);
scale = edag_add_expr (scale);
}
scale = cast_expr (type, scale);
scale = edag_add_expr (scale);
return scale;
@ -809,15 +825,6 @@ dot_expr (type_t *type, const expr_t *a, const expr_t *b)
// propagated zero
return 0;
}
bool neg = false;
if (is_neg (a)) {
neg = !neg;
a = neg_expr (a);
}
if (is_neg (b)) {
neg = !neg;
b = neg_expr (b);
}
int a_terms = count_terms (a);
int b_terms = count_terms (b);
@ -850,10 +857,6 @@ dot_expr (type_t *type, const expr_t *a, const expr_t *b)
if (prod) {
dot = scale_expr (type, dot, prod);
}
if (neg) {
dot = neg_expr (dot);
dot = edag_add_expr (dot);
}
return dot;
}
@ -895,20 +898,6 @@ cross_expr (type_t *type, const expr_t *a, const expr_t *b)
// propagated zero
return 0;
}
bool neg = false;
if (is_neg (a)) {
neg = !neg;
a = neg_expr (a);
}
if (is_neg (b)) {
neg = !neg;
b = neg_expr (b);
}
if (neg) {
auto t = a;
a = b;
b = t;
}
if (traverse_scale (a) == traverse_scale (b)) {
return 0;
@ -944,20 +933,6 @@ wedge_expr (type_t *type, const expr_t *a, const expr_t *b)
// propagated zero
return 0;
}
bool neg = false;
if (is_neg (a)) {
neg = !neg;
a = neg_expr (a);
}
if (is_neg (b)) {
neg = !neg;
b = neg_expr (b);
}
if (neg) {
auto t = a;
a = b;
b = t;
}
auto wedge = distribute_product (type, WEDGE, a, b, reject_wedge);
if (!wedge) {
return 0;