[qfcc] Collect common cross product terms

This reduces the number of cross products in `m * p * ~m` from 4 or 5 (4
after the old CSE went through the code) to 2 even before CSE.
This commit is contained in:
Bill Currie 2023-10-01 16:29:14 +09:00
parent afe6ea526b
commit ca1b455aa0
7 changed files with 322 additions and 30 deletions

View file

@ -122,5 +122,14 @@ const struct expr_s *algebra_assign_expr (const struct expr_s *dst,
const struct expr_s *src);
const struct expr_s *algebra_field_expr (const struct expr_s *mvec,
const struct expr_s *field_name);
const struct expr_s *algebra_optimize (const struct expr_s *e);
const struct expr_s *mvec_expr (const struct expr_s *expr, algebra_t *algebra);
void mvec_scatter (const struct expr_s **components, const struct expr_s *mvec,
algebra_t *algebra);
const struct expr_s *mvec_gather (const struct expr_s **components,
algebra_t *algebra);
#endif//__algebra_h

View file

@ -924,6 +924,26 @@ const expr_t *fold_constants (const expr_t *e);
void edag_flush (void);
const expr_t *edag_add_expr (const expr_t *e);
bool is_scale (const expr_t *expr) __attribute__((pure));
bool is_cross (const expr_t *expr) __attribute__((pure));
bool is_sum (const expr_t *expr) __attribute__((pure));
bool is_neg (const expr_t *expr) __attribute__((pure));
const expr_t *neg_expr (const expr_t *e);
const expr_t *ext_expr (const expr_t *src, struct type_s *type, int extend,
bool reverse);
const expr_t *scale_expr (struct type_s *type, const expr_t *a, const expr_t *b);
const expr_t *traverse_scale (const expr_t *expr) __attribute__((pure));
const expr_t *typed_binary_expr (struct type_s *type, int op,
const expr_t *e1, const expr_t *e2);
int count_terms (const expr_t *expr) __attribute__((pure));
void scatter_terms (const expr_t *sum,
const expr_t **adds, const expr_t **subs);
const expr_t *gather_terms (struct type_s *type,
const expr_t **adds, const expr_t **subs);
///@}

View file

@ -34,6 +34,7 @@ qfcc_SOURCES = \
tools/qfcc/source/expr_compound.c \
tools/qfcc/source/expr_dag.c \
tools/qfcc/source/expr_obj.c \
tools/qfcc/source/expr_optimize.c \
tools/qfcc/source/expr_vector.c \
tools/qfcc/source/evaluate.c \
tools/qfcc/source/flow.c \

View file

@ -2085,6 +2085,9 @@ build_function_call (const expr_t *fexpr, const type_t *ftype, const expr_t *par
if (e->type == ex_error) {
return e;
}
if (e->type != ex_compound) {
arguments[i] = algebra_optimize (e);
}
}
if (options.code.progsversion < PROG_VERSION
@ -2376,6 +2379,8 @@ return_expr (function_t *f, const expr_t *e)
if (e->type == ex_compound) {
e = expr_file_line (initialized_temp_expr (ret_type, e), e);
} else {
e = algebra_optimize (e);
}
t = get_type (e);

View file

@ -77,7 +77,7 @@ get_group_mask (const type_t *type, algebra_t *algebra)
}
}
static bool __attribute__((const))
bool
is_neg (const expr_t *e)
{
return e->type == ex_uexpr && e->expr.op == '-';
@ -101,7 +101,7 @@ op_anti_com (int op)
return (op == '-' || op == CROSS || op == WEDGE);
}
static const expr_t *
const expr_t *
typed_binary_expr (type_t *type, int op, const expr_t *e1, const expr_t *e2)
{
auto e = new_binary_expr (op, e1, e2);
@ -111,7 +111,7 @@ typed_binary_expr (type_t *type, int op, const expr_t *e1, const expr_t *e2)
return e;
}
static const expr_t *
const expr_t *
neg_expr (const expr_t *e)
{
if (!e) {
@ -132,7 +132,7 @@ neg_expr (const expr_t *e)
return edag_add_expr (fold_constants (neg));
}
static const expr_t *
const expr_t *
ext_expr (const expr_t *src, type_t *type, int extend, bool reverse)
{
if (!src) {
@ -283,7 +283,7 @@ promote_scalar (type_t *dst_type, const expr_t *scalar)
return edag_add_expr (scalar);
}
static const expr_t *
const expr_t *
mvec_expr (const expr_t *expr, algebra_t *algebra)
{
auto mvtype = get_type (expr);
@ -323,7 +323,7 @@ mvec_expr (const expr_t *expr, algebra_t *algebra)
return mvec;
}
static void
void
mvec_scatter (const expr_t **components, const expr_t *mvec, algebra_t *algebra)
{
auto layout = &algebra->layout;
@ -367,7 +367,7 @@ mvec_scatter (const expr_t **components, const expr_t *mvec, algebra_t *algebra)
}
}
static const expr_t *
const expr_t *
mvec_gather (const expr_t **components, algebra_t *algebra)
{
auto layout = &algebra->layout;
@ -403,13 +403,13 @@ mvec_gather (const expr_t **components, algebra_t *algebra)
return mvec;
}
static bool __attribute__((pure))
bool
is_scale (const expr_t *expr)
{
return expr && expr->type == ex_expr && expr->expr.op == SCALE;
}
static const expr_t * __attribute__((pure))
const expr_t *
traverse_scale (const expr_t *expr)
{
while (is_scale (expr)) {
@ -418,7 +418,7 @@ traverse_scale (const expr_t *expr)
return expr;
}
static bool __attribute__((pure))
bool
is_sum (const expr_t *expr)
{
return (expr && expr->type == ex_expr
@ -432,7 +432,7 @@ is_mult (const expr_t *expr)
&& (expr->expr.op == '*' || expr->expr.op == HADAMARD));
}
static int __attribute__((pure))
int
count_terms (const expr_t *expr)
{
if (!is_sum (expr)) {
@ -468,7 +468,7 @@ count_factors (const expr_t *expr)
return terms;
}
static bool __attribute__((pure))
bool __attribute__((pure))
is_cross (const expr_t *expr)
{
return (expr && expr->type == ex_expr && (expr->expr.op == CROSS));
@ -577,7 +577,7 @@ sum_expr_low (type_t *type, int op, const expr_t *a, const expr_t *b)
}
static void
distribute_terms_core (const expr_t *sum,
scatter_terms_core (const expr_t *sum,
const expr_t **adds, int *addind,
const expr_t **subs, int *subind, bool negative)
{
@ -585,10 +585,10 @@ distribute_terms_core (const expr_t *sum,
auto e1 = sum->expr.e1;
auto e2 = sum->expr.e2;
if (is_sum (e1)) {
distribute_terms_core (e1, adds, addind, subs, subind, negative);
scatter_terms_core (e1, adds, addind, subs, subind, negative);
}
if (is_sum (e2)) {
distribute_terms_core (e2, adds, addind, subs, subind, subtract);
scatter_terms_core (e2, adds, addind, subs, subind, subtract);
}
if (!is_sum (e1)) {
auto e = sum->expr.e1;
@ -610,20 +610,20 @@ distribute_terms_core (const expr_t *sum,
}
}
static void
distribute_terms (const expr_t *sum, const expr_t **adds, const expr_t **subs)
void
scatter_terms (const expr_t *sum, const expr_t **adds, const expr_t **subs)
{
if (!is_sum (sum)) {
internal_error (sum, "distribute_terms with no sum");
internal_error (sum, "scatter_terms with no sum");
}
int addind = 0;
int subind = 0;
distribute_terms_core (sum, adds, &addind, subs, &subind, false);
scatter_terms_core (sum, adds, &addind, subs, &subind, false);
}
static const expr_t *
collect_terms (type_t *type, const expr_t **adds, const expr_t **subs)
const expr_t *
gather_terms (type_t *type, const expr_t **adds, const expr_t **subs)
{
const expr_t *a = 0;
const expr_t *b = 0;
@ -688,7 +688,7 @@ sum_expr (type_t *type, const expr_t *a, const expr_t *b)
int num_terms = count_terms (sum);
const expr_t *adds[num_terms + 1] = {};
const expr_t *subs[num_terms + 1] = {};
distribute_terms (sum, adds, subs);
scatter_terms (sum, adds, subs);
const expr_t **dstadd, **srcadd;
const expr_t **dstsub, **srcsub;
@ -725,7 +725,7 @@ sum_expr (type_t *type, const expr_t *a, const expr_t *b)
*dstadd++ = *srcadd;
}
*dstadd = 0;
sum = collect_terms (type, adds, subs);
sum = gather_terms (type, adds, subs);
return sum;
}
@ -790,13 +790,13 @@ distribute_product (type_t *type, const expr_t *a, const expr_t *b,
const expr_t *b_subs[a_terms + 2] = {};
if (a_terms) {
distribute_terms (a, a_adds, a_subs);
scatter_terms (a, a_adds, a_subs);
} else {
a_adds[0] = a;
}
if (b_terms) {
distribute_terms (b, b_adds, b_subs);
scatter_terms (b, b_adds, b_subs);
} else {
b_adds[0] = b;
}
@ -854,9 +854,6 @@ distribute_product (type_t *type, const expr_t *a, const expr_t *b,
}
}
static const expr_t *scale_expr (type_t *type,
const expr_t *a, const expr_t *b);
static const expr_t *
extract_scale (const expr_t **expr, const expr_t *prod)
{
@ -895,7 +892,7 @@ do_scale (type_t *type, const expr_t *a, const expr_t *b)
return typed_binary_expr (type, SCALE, a, b);
}
static const expr_t *
const expr_t *
scale_expr (type_t *type, const expr_t *a, const expr_t *b)
{
if (!a || !b) {

View file

@ -0,0 +1,260 @@
/*
expr_optimize.c
algebraic expression optimization
Copyright (C) 2023 Bill Currie <bill@taniwha.org>
This program is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public License
as published by the Free Software Foundation; either version 2
of the License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program; if not, write to:
Free Software Foundation, Inc.
59 Temple Place - Suite 330
Boston, MA 02111-1307, USA
*/
#ifdef HAVE_CONFIG_H
# include "config.h"
#endif
#include <string.h>
#include "QF/fbsearch.h"
#include "QF/heapsort.h"
#include "QF/math/bitop.h"
#include "tools/qfcc/include/algebra.h"
#include "tools/qfcc/include/diagnostic.h"
#include "tools/qfcc/include/expr.h"
#include "tools/qfcc/include/symtab.h"
#include "tools/qfcc/include/type.h"
#include "tools/qfcc/include/value.h"
#include "tools/qfcc/source/qc-parse.h"
// Sentinel: the address of this object is stored in a term-list slot to mark
// the term as already consumed. Such slots are ignored while searching for
// cross products and are removed from the lists by clean_skips().
static const expr_t skip;
/*
	Rebuild the chain of scale expressions that wraps target, but with target
	replaced by one of its own operands.

	expr    the expression to rebuild: either target itself or a scale
	        expression whose e1 chain leads to target
	target  the binary expression at the bottom of the scale chain
	remove  the operand of target to drop; if it is target's e1, target's e2
	        is kept, otherwise target's e1 is kept

	Returns the kept operand with every scale factor found between expr and
	target re-applied via scale_expr(). Raises an internal error when a node
	on the way down to target is not a scale expression.
*/
static const expr_t *
rescale (const expr_t *expr, const expr_t *target, const expr_t *remove)
{
	if (expr != target) {
		if (!is_scale (expr)) {
			internal_error (expr, "not a scale expression");
		}
		auto type = get_type (expr);
		// peel this scale off, rebuild what is underneath, then re-apply it
		auto inner = rescale (expr->expr.e1, target, remove);
		return scale_expr (type, inner, expr->expr.e2);
	}
	// bottom of the chain: keep whichever operand is not being removed
	return target->expr.e1 == remove ? target->expr.e2 : target->expr.e1;
}
/*
	Tally how often l and r occur as an operand of the cross products found in
	a null-terminated term list.

	Slots holding the skip sentinel and terms that are not (possibly scaled)
	cross products are ignored. The tallies are accumulated into *l_count and
	*r_count, so the same counters can be fed several lists.
*/
static void
count_cross_operands (const expr_t **terms, const expr_t *l, const expr_t *r,
					  int *l_count, int *r_count)
{
	for (auto search = terms; *search; search++) {
		if (*search == &skip) {
			continue;
		}
		auto c = traverse_scale (*search);
		if (!is_cross (c)) {
			continue;
		}
		*l_count += c->expr.e1 == l;
		*l_count += c->expr.e2 == l;
		*r_count += c->expr.e1 == r;
		*r_count += c->expr.e2 == r;
	}
}

/*
	Pull every (possibly scaled) cross product that has com as an operand out
	of a null-terminated term list.

	For each such term, the operand that is not com (with the term's scale
	factors re-applied by rescale()) is appended through one of two output
	cursors and the term's slot is overwritten with the skip sentinel.

	right  true when com is to end up as the right operand of the combined
	       cross product, false when it is to end up on the left
	same   cursor receiving operands whose term has com on the wanted side
	flip   cursor receiving operands whose term has com on the other side
	       (swapping the operands of a cross product negates it)
*/
static void
collect_cross_terms (const expr_t **terms, const expr_t *com, bool right,
					 const expr_t ***same, const expr_t ***flip)
{
	for (auto search = terms; *search; search++) {
		if (*search == &skip) {
			continue;
		}
		auto c = traverse_scale (*search);
		if (!is_cross (c)) {
			continue;
		}
		const expr_t *scale = 0;
		bool neg = false;
		if (c->expr.e1 == com) {
			neg = right;
			scale = rescale (*search, c, com);
		} else if (c->expr.e2 == com) {
			neg = !right;
			scale = rescale (*search, c, com);
		}
		if (!scale) {
			// com is not an operand (or nothing was left after rescaling):
			// leave the term where it is
			continue;
		}
		*search = &skip;
		if (neg) {
			*(*flip)++ = scale;
		} else {
			*(*same)++ = scale;
		}
	}
}

/*
	Merge expr with the other cross products that share one of its operands.

	expr  a (possibly scaled) cross product; the caller is expected to have
	      already replaced its slot in the term lists with the skip sentinel
	adds  null-terminated list of terms treated as having the same sign as
	      expr
	subs  null-terminated list of terms treated as having the opposite sign

	Whichever operand of expr occurs more often among the cross products in
	both lists is chosen as the common operand (the left one on a tie). Every
	cross product using it is removed from the lists (its slot becomes the
	skip sentinel) and its other operand is summed into a single expression,
	with signs adjusted for list membership and operand order, so that one
	cross product of the common operand and that sum replaces them all.

	Returns expr unchanged when no other cross product shares an operand with
	it, otherwise the combined cross product.
*/
static const expr_t *
optimize_cross (const expr_t *expr, const expr_t **adds, const expr_t **subs)
{
	auto base = traverse_scale (expr);
	auto l = base->expr.e1;
	auto r = base->expr.e2;
	int l_count = 0;
	int r_count = 0;
	count_cross_operands (adds, l, r, &l_count, &r_count);
	count_cross_operands (subs, l, r, &l_count, &r_count);
	if (!(l_count + r_count)) {
		return expr;
	}

	bool right = r_count > l_count;
	int count = right ? r_count : l_count;
	auto com = right ? r : l;
	// room for expr's own operand, up to count collected operands, and the
	// terminating null
	const expr_t *com_adds[count + 2] = {};
	const expr_t *com_subs[count + 2] = {};
	auto adst = com_adds;
	auto sdst = com_subs;
	*adst++ = rescale (expr, base, com);
	// terms from adds keep their sign, terms from subs have it inverted
	collect_cross_terms (adds, com, right, &adst, &sdst);
	collect_cross_terms (subs, com, right, &sdst, &adst);

	auto type = get_type (com);
	auto col = gather_terms (type, com_adds, com_subs);
	if (is_neg (col)) {
		// absorb the negation by swapping the operand order instead
		col = neg_expr (col);
		right = !right;
	}
	const expr_t *cross;
	if (right) {
		cross = typed_binary_expr (type, CROSS, col, com);
	} else {
		cross = typed_binary_expr (type, CROSS, com, col);
	}
	cross = edag_add_expr (cross);
	return cross;
}
/*
	Compact a null-terminated expression list in place by dropping every slot
	that holds the skip sentinel. The surviving entries keep their relative
	order and the list is re-terminated with a null pointer.
*/
static void
clean_skips (const expr_t **expr_list)
{
	const expr_t **keep = expr_list;
	for (const expr_t **cur = expr_list; *cur; cur++) {
		if (*cur == &skip) {
			continue;
		}
		*keep++ = *cur;
	}
	*keep = 0;
}
static const expr_t *optimize_core (const expr_t *expr);
static void
optimize_extends (const expr_t **expr_list)
{
for (auto scan = expr_list; *scan; scan++) {
if ((*scan)->type == ex_extend) {
auto ext = (*scan)->extend;
ext.src = optimize_core (ext.src);
*scan = ext_expr (ext.src, ext.type, ext.extend, ext.reverse);
}
}
}
/*
	Combine cross products that share an operand across a pair of term lists.

	adds  null-terminated list of the terms being added
	subs  null-terminated list of the terms being subtracted

	Every (possibly scaled) cross product in either list is used in turn as
	the seed for optimize_cross(), which folds all other cross products with
	a common operand into it. The seed's slot is set to the skip sentinel for
	the duration of the call so the seed does not match itself, and then
	receives the result. Slots consumed along the way are removed from both
	lists before returning.

	optimize_cross() treats its first list as having the same sign as the
	seed and its second list as having the opposite sign. A seed taken from
	subs is itself subtracted, so relative to it the other subtracted terms
	have the same sign and the added terms the opposite one: the lists must
	be passed swapped. Passing them unswapped would turn
	-(l x r) - (l x s) into -(l x (r - s)), flipping the sign of s.
*/
static void
optimize_cross_products (const expr_t **adds, const expr_t **subs)
{
	for (auto scan = adds; *scan; scan++) {
		if (is_cross (traverse_scale (*scan))) {
			auto e = *scan;
			*scan = &skip;
			*scan = optimize_cross (e, adds, subs);
		}
	}
	for (auto scan = subs; *scan; scan++) {
		if (is_cross (traverse_scale (*scan))) {
			auto e = *scan;
			*scan = &skip;
			// seed is subtracted: subs are same-sign, adds opposite-sign
			*scan = optimize_cross (e, subs, adds);
		}
	}
	clean_skips (adds);
	clean_skips (subs);
}
/*
	Optimize a single (non-multivector) expression.

	Only sums are touched: the sum is split into its added and subtracted
	terms, the sources of extend terms are optimized recursively, cross
	products sharing an operand are merged, and the terms are gathered back
	into an expression of the original type. Any other expression is returned
	unchanged.
*/
static const expr_t *
optimize_core (const expr_t *expr)
{
	if (!is_sum (expr)) {
		return expr;
	}

	auto type = get_type (expr);
	int num_terms = count_terms (expr);
	// one extra slot in each list for the null terminator
	const expr_t *adds[num_terms + 1] = {};
	const expr_t *subs[num_terms + 1] = {};

	scatter_terms (expr, adds, subs);
	optimize_extends (adds);
	optimize_extends (subs);
	optimize_cross_products (adds, subs);
	return gather_terms (type, adds, subs);
}
/*
	Entry point of the algebraic expression optimizer.

	An expression that is not a multivector is handed straight to
	optimize_core(). A multivector is scattered into its per-group component
	expressions according to the layout of its algebra, each component that
	is present is optimized on its own, and the components are gathered back
	into a multivector expression.

	expr must not be null. Returns the optimized expression.
*/
const expr_t *
algebra_optimize (const expr_t *expr)
{
	if (expr->type != ex_multivec) {
		return optimize_core (expr);
	}

	auto algebra = algebra_get (get_type (expr));
	auto layout = &algebra->layout;
	const expr_t *groups[layout->count] = {};

	expr = mvec_expr (expr, algebra);
	mvec_scatter (groups, expr, algebra);
	for (const expr_t **g = groups; g < groups + layout->count; g++) {
		// groups left null by the scatter are skipped
		if (*g) {
			*g = optimize_core (*g);
		}
	}
	return mvec_gather (groups, algebra);
}

View file

@ -23,7 +23,7 @@ main (void)
point_t p = (point_t)'10 4 -1.5 1'f;
point_t n = apply_motor (m, p);
printf ("n: %.9q\n", n);
if ((vec4)n != '9.99999905 -4.00000048 -1.49999988 0.99999994'f) {
if ((vec4)n != '10 -3.99999952 -1.49999988 0.99999994'f) {
return 1;
}
return 0;