[qfcc] Simplify algebraic sums

This doesn't fix the motor bug, but it doesn't make it worse. However,
it does simplify the trees quite a bit, so it should be easier to debug.
It seems the problem has something to do with fold_constants messing up
dagged aliases: in particular, const-folding multiplication by e0123 in
3d PGA as fold_constants sees it as 1.
This commit is contained in:
Bill Currie 2023-09-26 14:37:22 +09:00
parent 198cec6df8
commit 153a19937f
2 changed files with 86 additions and 92 deletions

View file

@ -80,15 +80,6 @@ is_neg (const expr_t *e)
return e->type == ex_uexpr && e->expr.op == '-';
}
static expr_t *
ext_expr (expr_t *src, type_t *type, int extend, bool reverse)
{
if (!src) {
return 0;
}
return edag_add_expr (new_extend_expr (src, type, extend, reverse));
}
static bool __attribute__((const))
anti_com (const expr_t *e)
{
@ -115,6 +106,24 @@ neg_expr (expr_t *e)
return edag_add_expr (fold_constants (e));
}
static expr_t *
ext_expr (expr_t *src, type_t *type, int extend, bool reverse)
{
if (!src) {
return 0;
}
bool neg = false;
if (is_neg (src)) {
neg = true;
src = neg_expr (src);
}
auto ext = edag_add_expr (new_extend_expr (src, type, extend, reverse));
if (neg) {
ext = neg_expr (ext);
}
return ext;
}
static expr_t *
alias_expr (type_t *type, expr_t *e, int offset)
{
@ -353,7 +362,7 @@ mvec_gather (expr_t **components, algebra_t *algebra)
}
return mvec;
}
#if 0
static bool __attribute__((const))
ext_compat (const expr_t *a, const expr_t *b)
{
@ -368,7 +377,7 @@ extract_extended_neg (const expr_t *expr)
auto e = expr->extend;
return neg_expr (ext_expr (neg_expr (e.src), e.type, e.extend, e.reverse));
}
#endif
static __attribute__((pure)) expr_t *
traverse_scale (expr_t *expr)
{
@ -424,102 +433,43 @@ distribute_terms_core (expr_t *sum,
distribute_terms_core (e2, adds, addind, subs, subind, subtract);
}
if (!is_sum (e1)) {
if (negative) {
subs[(*subind)++] = sum->expr.e1;
} else {
adds[(*addind)++] = sum->expr.e1;
auto e = sum->expr.e1;
auto arr = negative ^ is_neg (e) ? subs : adds;
auto ind = negative ^ is_neg (e) ? subind : addind;
if (is_neg (e)) {
e = neg_expr (e);
}
arr[(*ind)++] = e;
}
if (!is_sum (e2)) {
if (subtract) {
subs[(*subind)++] = sum->expr.e2;
} else {
adds[(*addind)++] = sum->expr.e2;
auto e = sum->expr.e2;
auto arr = subtract ^ is_neg (e) ? subs : adds;
auto ind = subtract ^ is_neg (e) ? subind : addind;
if (is_neg (e)) {
e = neg_expr (e);
}
arr[(*ind)++] = e;
}
}
static expr_t *
sum_expr (type_t *type, expr_t *a, expr_t *b)
sum_expr_low (type_t *type, int op, expr_t *a, expr_t *b)
{
if (!a) {
return b;
return op == '-' ? neg_expr (b) : b;
}
if (!b) {
return a;
}
if (a->type != ex_extend && b->type == ex_extend) {
// ensure always ext + something
return sum_expr (type, b, a);
}
if (a->type == ex_extend) {
if (b->type == ex_extend) {
if (ext_compat (a, b)) {
// push the sum below the two extends making for a single
// extend of a sum instead of a sum of two extends
auto ext = a->extend;
a = a->extend.src;
b = b->extend.src;
auto sum = sum_expr (get_type (a), a, b);
return ext_expr (sum, type, ext.extend, ext.reverse);
}
} else if (b->type == ex_expr && b->expr.op == '+') {
auto c = b->expr.e1;
auto d = b->expr.e2;
if (ext_compat (a, c)) {
// d should not be compatible with a because it should have
// already been merged
auto ext = a->extend;
a = a->extend.src;
c = c->extend.src;
auto sum = sum_expr (get_type (a), a, c);
sum = ext_expr (sum, type, ext.extend, ext.reverse);
return sum_expr (type, sum, d);
}
if (ext_compat (a, d)) {
// c should not be compatible with a because it should have
// already been merged
auto ext = a->extend;
a = a->extend.src;
d = d->extend.src;
auto sum = sum_expr (get_type (a), a, d);
sum = ext_expr (sum, type, ext.extend, ext.reverse);
return sum_expr (type, sum, c);
}
}
}
if (a->type == ex_extend && is_neg (a->extend.src)) {
a = extract_extended_neg (a);
}
if (b->type == ex_extend && is_neg (b->extend.src)) {
b = extract_extended_neg (b);
}
bool neg = false;
if ((is_neg (a) && (is_neg (b) || anti_com (b)))
|| (anti_com (a) && is_neg (b))) {
neg = true;
a = neg_expr (a);
b = neg_expr (b);
}
int op = '+';
if (is_neg (a) && b->type != ex_extend) {
auto t = a;
a = b;
b = t;
op = '-';
b = neg_expr (b);
} else if (a->type != ex_extend && is_neg (b)) {
op = '-';
b = neg_expr (b);
if (op == '-' && a == b) {
return 0;
}
auto sum = new_binary_expr (op, a, b);
sum->expr.type = type;
sum->expr.commutative = op == '+';
sum->expr.anticommute = op == '-';
sum = edag_add_expr (sum);
if (neg) {
sum = neg_expr (sum);
}
return sum;
}
@ -538,14 +488,56 @@ distribute_terms (expr_t *sum, expr_t **adds, expr_t **subs)
static expr_t *
collect_terms (type_t *type, expr_t **adds, expr_t **subs)
{
expr_t *a = 0;
expr_t *b = 0;
for (auto s = adds; *s; s++) {
b = sum_expr (type, b, *s);
a = sum_expr_low (type, '+', a, *s);
}
for (auto s = subs; *s; s++) {
b = sum_expr (type, b, neg_expr (*s));
b = sum_expr_low (type, '+', b, *s);
}
return b;
auto sum = sum_expr_low (type, '-', a, b);
return sum;
}
static expr_t *
sum_expr (type_t *type, expr_t *a, expr_t *b)
{
if (!a) {
return b;
}
if (!b) {
return a;
}
auto sum = new_binary_expr ('+', a, b);
int num_terms = count_terms (sum);
expr_t *adds[num_terms + 1] = {};
expr_t *subs[num_terms + 1] = {};
distribute_terms (sum, adds, subs);
expr_t **dstadd, **srcadd;
expr_t **dstsub, **srcsub;
for (dstadd = adds, srcadd = adds; *srcadd; srcadd++) {
for (dstsub = subs, srcsub = subs; *srcsub; srcsub++) {
if (*srcadd == *srcsub) {
// found a-a
break;
}
}
if (*srcsub++) {
// found a-a
while (*srcsub) {
*dstsub++ = *srcsub++;
}
*dstsub = 0;
continue;
}
*dstadd++ = *srcadd;
}
*dstadd = 0;
sum = collect_terms (type, adds, subs);
return sum;
}
static void
@ -564,7 +556,9 @@ component_sum (int op, expr_t **c, expr_t **a, expr_t **b,
} else {
c[i] = sum_expr (sum_type, a[i], neg_expr (b[i]));
}
c[i] = fold_constants (c[i]);
if (c[i]) {
c[i] = fold_constants (c[i]);
}
} else if (a[i]) {
c[i] = a[i];
} else if (b[i]) {

View file

@ -347,7 +347,7 @@ test_geom (void)
d = { .mvec = tvec * bvec.bvec };
if ((vec4)d.vec != '22 -8 4 9' || (vec4)d.tvec != '-30 -85 34 0') {
printf ("vec * bvec != '22 -8 4 9' '-30 -85 34 0': %q %q\n",
printf ("tvec * bvec != '22 -8 4 9' '-30 -85 34 0': %q %q\n",
d.vec, d.tvec);
return 1;
}