[qfcc] Use distributivity to cancel cross and dot products

For cross products: remove any a from a×(...+/-a...)
For dot products: remove any a×b from a•(...+/-a×b...) (or b×a)

This removed another 2 instructions :)
This commit is contained in:
Bill Currie 2023-09-25 19:40:29 +09:00
parent 7271d2d570
commit c01cbf4fc4

View file

@ -378,6 +378,67 @@ traverse_scale (expr_t *expr)
return expr;
}
static bool __attribute__((pure))
is_sum (const expr_t *expr)
{
return (expr && expr->type == ex_expr
&& (expr->expr.op == '+' || expr->expr.op == '-'));
}
static int __attribute__((pure))
count_terms (const expr_t *expr)
{
if (!is_sum (expr)) {
return 0;
}
auto e1 = expr->expr.e1;
auto e2 = expr->expr.e2;
int terms = !is_sum (e1) + !is_sum (e2);;
if (is_sum (e1)) {
terms += count_terms (expr->expr.e1);
}
if (is_sum (e2)) {
terms += count_terms (expr->expr.e2);
}
return terms;
}
static bool __attribute__((pure))
is_cross (const expr_t *expr)
{
return (expr && expr->type == ex_expr && (expr->expr.op == CROSS));
}
static void
distribute_terms_core (expr_t *sum,
expr_t **adds, int *addind,
expr_t **subs, int *subind, bool negative)
{
bool subtract = (sum->expr.op == '-') ^ negative;
auto e1 = sum->expr.e1;
auto e2 = sum->expr.e2;
if (is_sum (e1)) {
distribute_terms_core (e1, adds, addind, subs, subind, negative);
}
if (is_sum (e2)) {
distribute_terms_core (e2, adds, addind, subs, subind, subtract);
}
if (!is_sum (e1)) {
if (negative) {
subs[(*subind)++] = sum->expr.e1;
} else {
adds[(*addind)++] = sum->expr.e1;
}
}
if (!is_sum (e2)) {
if (subtract) {
subs[(*subind)++] = sum->expr.e2;
} else {
adds[(*addind)++] = sum->expr.e2;
}
}
}
static expr_t *
sum_expr (type_t *type, expr_t *a, expr_t *b)
{
@ -462,6 +523,31 @@ sum_expr (type_t *type, expr_t *a, expr_t *b)
return sum;
}
static void
distribute_terms (expr_t *sum, expr_t **adds, expr_t **subs)
{
int addind = 0;
int subind = 0;
if (!is_sum (sum)) {
internal_error (sum, "distribute_terms with no sum");
}
distribute_terms_core (sum, adds, &addind, subs, &subind, false);
}
static expr_t *
collect_terms (type_t *type, expr_t **adds, expr_t **subs)
{
expr_t *b = 0;
for (auto s = adds; *s; s++) {
b = sum_expr (type, b, *s);
}
for (auto s = subs; *s; s++) {
b = sum_expr (type, b, neg_expr (*s));
}
return b;
}
static void
component_sum (int op, expr_t **c, expr_t **a, expr_t **b,
algebra_t *algebra)
@ -533,6 +619,41 @@ scale_expr (type_t *type, expr_t *a, expr_t *b)
return scale;
}
static expr_t *
check_dot (expr_t *a, expr_t *b, int b_count)
{
expr_t *b_adds[b_count + 1] = {};
expr_t *b_subs[b_count + 1] = {};
a = traverse_scale (a);
distribute_terms (b, b_adds, b_subs);
expr_t **s, **d;
for (s = b_adds, d = s; *s; s++) {
auto c = traverse_scale (*s);
if (is_cross (c)) {
if (a == traverse_scale (c->expr.e1)
|| a == traverse_scale (c->expr.e2)) {
continue;
}
}
*d++ = *s;
}
*d = 0;
for (s = b_subs, d = s; *s; s++) {
auto c = traverse_scale (*s);
if (is_cross (c)) {
if (a == traverse_scale (c->expr.e1)
|| a == traverse_scale (c->expr.e2)) {
continue;
}
}
*d++ = *s;
}
*d = 0;
return collect_terms (get_type (b), b_adds, b_subs);
}
static expr_t *
dot_expr (type_t *type, expr_t *a, expr_t *b)
{
@ -550,6 +671,15 @@ dot_expr (type_t *type, expr_t *a, expr_t *b)
b = neg_expr (b);
}
int a_terms = count_terms (a);
int b_terms = count_terms (b);
if (a_terms && !b_terms) {
a = check_dot (b, a, a_terms);
} else if (!a_terms && b_terms) {
b = check_dot (a, b, b_terms);
}
auto dot = new_binary_expr (DOT, a, b);
dot->expr.type = type;
dot = edag_add_expr (dot);
@ -561,6 +691,31 @@ dot_expr (type_t *type, expr_t *a, expr_t *b)
return dot;
}
static expr_t *
check_cross (expr_t *a, expr_t *b, int b_count)
{
expr_t *b_adds[b_count + 1] = {};
expr_t *b_subs[b_count + 1] = {};
a = traverse_scale (a);
distribute_terms (b, b_adds, b_subs);
expr_t **s, **d;
for (s = b_adds, d = s; *s; s++) {
if (a != traverse_scale (*s)) {
*d++ = *s;
}
}
*d = 0;
for (s = b_subs, d = s; *s; s++) {
if (a != traverse_scale (*s)) {
*d++ = *s;
}
}
*d = 0;
return collect_terms (get_type (b), b_adds, b_subs);
}
static expr_t *
cross_expr (type_t *type, expr_t *a, expr_t *b)
{
@ -587,6 +742,15 @@ cross_expr (type_t *type, expr_t *a, expr_t *b)
return 0;
}
int a_terms = count_terms (a);
int b_terms = count_terms (b);
if (a_terms && !b_terms) {
a = check_cross (b, a, a_terms);
} else if (!a_terms && b_terms) {
b = check_cross (a, b, b_terms);
}
auto cross = new_binary_expr (CROSS, a, b);
cross->expr.type = type;
cross->expr.anticommute = true;