[qfcc] Collect common multiplication terms

This makes my `m * p * ~m` code as optimal as possible, if my counting is
correct (the count does not include the extra extends and add needed to
merge the values). There may also be a possibility of recombining some ops
into a vector op, but I'm happy with this.
This commit is contained in:
Bill Currie 2023-10-01 23:17:51 +09:00
parent f3edc06c45
commit 2134c85a47

View file

@ -321,6 +321,80 @@ optimize_scale (const expr_t *expr, const expr_t **adds, const expr_t **subs)
return scale;
}
static const expr_t *
optimize_mult (const expr_t *expr, const expr_t **adds, const expr_t **subs)
{
int num_factors = count_factors (expr);
int total = 0;
int fac_counts[num_factors + 1] = {};
const expr_t *factors[num_factors + 2] = {};
if (is_mult (expr)) {
scatter_factors (expr, factors);
} else {
factors[0] = expr;
}
for (auto search = adds; *search; search++) {
if (is_mult (*search)) {
for (auto f = factors; *f; f++) {
if (mult_has_factor (*search, *f)) {
fac_counts[f - factors]++;
total++;
}
}
}
}
for (auto search = subs; *search; search++) {
if (is_mult (*search)) {
for (auto f = factors; *f; f++) {
if (mult_has_factor (*search, *f)) {
fac_counts[f - factors]++;
total++;
}
}
}
}
if (!total) {
return expr;
}
const expr_t *common = 0;
int count = 0;
for (auto f = factors; *f; f++) {
if (fac_counts[f - factors] > count
|| (fac_counts[f - factors] == count && is_constant (*f))) {
common = *f;
count = fac_counts[f - factors];
}
}
const expr_t *com_adds[count + 2] = {};
const expr_t *com_subs[count + 2] = {};
auto dst = com_adds;
*dst++ = remult (expr, common);
for (auto src = adds; *src; src++) {
if (is_mult (*src) && mult_has_factor (*src, common)) {
*dst++ = remult (*src, common);
*src = &skip;
}
}
dst = com_subs;
for (auto src = subs; *src; src++) {
if (is_mult (*src) && mult_has_factor (*src, common)) {
*dst++ = remult (*src, common);
*src = &skip;
}
}
auto type = get_type (expr);
auto col = gather_terms (type, com_adds, com_subs);
col = optimize_core (col);
auto mult = typed_binary_expr (type, expr->expr.op, col, common);
mult = edag_add_expr (mult);
return mult;
}
static void
optimize_extends (const expr_t **expr_list)
{
@ -391,6 +465,42 @@ optimize_scale_products (const expr_t **adds, const expr_t **subs)
clean_skips (subs);
}
/*
 * Run optimize_mult() over every product term in adds and subs.
 *
 * Two passes are made. The first only visits products for which
 * has_const_factor() is true; the second visits every product. Within
 * each pass adds is processed before subs. When a term from subs is
 * processed, the two lists are handed to optimize_mult() in swapped order
 * so that the term's own list is always the first list argument.
 *
 * Both lists are null-terminated and are rewritten in place; any &skip
 * entries left behind are removed by clean_skips() at the end.
 */
static void
optimize_mult_products (const expr_t **adds, const expr_t **subs)
{
	const expr_t **lists[2] = { adds, subs };

	for (int pass = 0; pass < 2; pass++) {
		int         const_only = pass == 0;

		for (int side = 0; side < 2; side++) {
			const expr_t **own = lists[side];
			const expr_t **other = lists[1 - side];

			for (const expr_t **slot = own; *slot; slot++) {
				if (!is_mult (*slot)) {
					continue;
				}
				if (const_only && !has_const_factor (*slot)) {
					continue;
				}
				const expr_t *term = *slot;
				/*
				 * Blank the slot while optimize_mult() scans the lists so
				 * the term is not found as a match for itself.
				 */
				*slot = &skip;
				*slot = optimize_mult (term, own, other);
			}
		}
	}
	clean_skips (adds);
	clean_skips (subs);
}
static int
expr_ptr_cmp (const void *_a, const void *_b)
{
@ -445,6 +555,7 @@ optimize_core (const expr_t *expr)
optimize_cross_products (adds, subs);
optimize_scale_products (adds, subs);
optimize_mult_products (adds, subs);
expr = gather_terms (type, adds, subs);
}