diff --git a/tools/qfcc/source/expr_algebra.c b/tools/qfcc/source/expr_algebra.c index 02e071200..f179ddc3d 100644 --- a/tools/qfcc/source/expr_algebra.c +++ b/tools/qfcc/source/expr_algebra.c @@ -378,6 +378,67 @@ traverse_scale (expr_t *expr) return expr; } +static bool __attribute__((pure)) +is_sum (const expr_t *expr) +{ + return (expr && expr->type == ex_expr + && (expr->expr.op == '+' || expr->expr.op == '-')); +} + +static int __attribute__((pure)) +count_terms (const expr_t *expr) +{ + if (!is_sum (expr)) { + return 0; + } + auto e1 = expr->expr.e1; + auto e2 = expr->expr.e2; + int terms = !is_sum (e1) + !is_sum (e2);; + if (is_sum (e1)) { + terms += count_terms (expr->expr.e1); + } + if (is_sum (e2)) { + terms += count_terms (expr->expr.e2); + } + return terms; +} + +static bool __attribute__((pure)) +is_cross (const expr_t *expr) +{ + return (expr && expr->type == ex_expr && (expr->expr.op == CROSS)); +} + +static void +distribute_terms_core (expr_t *sum, + expr_t **adds, int *addind, + expr_t **subs, int *subind, bool negative) +{ + bool subtract = (sum->expr.op == '-') ^ negative; + auto e1 = sum->expr.e1; + auto e2 = sum->expr.e2; + if (is_sum (e1)) { + distribute_terms_core (e1, adds, addind, subs, subind, negative); + } + if (is_sum (e2)) { + distribute_terms_core (e2, adds, addind, subs, subind, subtract); + } + if (!is_sum (e1)) { + if (negative) { + subs[(*subind)++] = sum->expr.e1; + } else { + adds[(*addind)++] = sum->expr.e1; + } + } + if (!is_sum (e2)) { + if (subtract) { + subs[(*subind)++] = sum->expr.e2; + } else { + adds[(*addind)++] = sum->expr.e2; + } + } +} + static expr_t * sum_expr (type_t *type, expr_t *a, expr_t *b) { @@ -462,6 +523,31 @@ sum_expr (type_t *type, expr_t *a, expr_t *b) return sum; } +static void +distribute_terms (expr_t *sum, expr_t **adds, expr_t **subs) +{ + int addind = 0; + int subind = 0; + + if (!is_sum (sum)) { + internal_error (sum, "distribute_terms with no sum"); + } + distribute_terms_core (sum, adds, &addind, subs, &subind, false); +} + +static expr_t * +collect_terms (type_t *type, expr_t **adds, expr_t **subs) +{ + expr_t *b = 0; + for (auto s = adds; *s; s++) { + b = sum_expr (type, b, *s); + } + for (auto s = subs; *s; s++) { + b = sum_expr (type, b, neg_expr (*s)); + } + return b; +} + static void component_sum (int op, expr_t **c, expr_t **a, expr_t **b, algebra_t *algebra) @@ -533,6 +619,41 @@ scale_expr (type_t *type, expr_t *a, expr_t *b) return scale; } +static expr_t * +check_dot (expr_t *a, expr_t *b, int b_count) +{ + expr_t *b_adds[b_count + 1] = {}; + expr_t *b_subs[b_count + 1] = {}; + + a = traverse_scale (a); + + distribute_terms (b, b_adds, b_subs); + expr_t **s, **d; + for (s = b_adds, d = s; *s; s++) { + auto c = traverse_scale (*s); + if (is_cross (c)) { + if (a == traverse_scale (c->expr.e1) + || a == traverse_scale (c->expr.e2)) { + continue; + } + } + *d++ = *s; + } + *d = 0; + for (s = b_subs, d = s; *s; s++) { + auto c = traverse_scale (*s); + if (is_cross (c)) { + if (a == traverse_scale (c->expr.e1) + || a == traverse_scale (c->expr.e2)) { + continue; + } + } + *d++ = *s; + } + *d = 0; + return collect_terms (get_type (b), b_adds, b_subs); +} + static expr_t * dot_expr (type_t *type, expr_t *a, expr_t *b) { @@ -550,6 +671,15 @@ dot_expr (type_t *type, expr_t *a, expr_t *b) b = neg_expr (b); } + int a_terms = count_terms (a); + int b_terms = count_terms (b); + + if (a_terms && !b_terms) { + a = check_dot (b, a, a_terms); + } else if (!a_terms && b_terms) { + b = check_dot (a, b, b_terms); + } + auto dot = new_binary_expr (DOT, a, b); dot->expr.type = type; dot = edag_add_expr (dot); @@ -561,6 +691,31 @@ dot_expr (type_t *type, expr_t *a, expr_t *b) return dot; } +static expr_t * +check_cross (expr_t *a, expr_t *b, int b_count) +{ + expr_t *b_adds[b_count + 1] = {}; + expr_t *b_subs[b_count + 1] = {}; + + a = traverse_scale (a); + + distribute_terms (b, b_adds, b_subs); + expr_t **s, **d; + for (s = b_adds, d = s; *s; s++) { + if (a != traverse_scale (*s)) { + *d++ = *s; + } + } + *d = 0; + for (s = b_subs, d = s; *s; s++) { + if (a != traverse_scale (*s)) { + *d++ = *s; + } + } + *d = 0; + return collect_terms (get_type (b), b_adds, b_subs); +} + static expr_t * cross_expr (type_t *type, expr_t *a, expr_t *b) { @@ -587,6 +742,15 @@ cross_expr (type_t *type, expr_t *a, expr_t *b) return 0; } + int a_terms = count_terms (a); + int b_terms = count_terms (b); + + if (a_terms && !b_terms) { + a = check_cross (b, a, a_terms); + } else if (!a_terms && b_terms) { + b = check_cross (a, b, b_terms); + } + auto cross = new_binary_expr (CROSS, a, b); cross->expr.type = type; cross->expr.anticommute = true;