constant folding for +, - and *

This commit is contained in:
Wolfgang (Blub) Bumiller 2012-08-14 22:51:05 +02:00
parent 9ed9a0c669
commit b8d92dfa01
2 changed files with 106 additions and 8 deletions

View file

@ -780,6 +780,11 @@ typedef struct {
float x, y, z;
} vector;
vector vec3_add (vector, vector);
vector vec3_sub (vector, vector);
qcfloat vec3_mulvv(vector, vector);
vector vec3_mulvf(vector, float);
/*
* A shallow copy of a lex_file to remember where which ast node
* came from.

109
parser.c
View file

@ -82,6 +82,46 @@ bool GMQCC_WARN parsewarning(parser_t *parser, int warntype, const char *fmt, ..
return OPTS_WARN(WARN_ERROR);
}
/**********************************************************************
* some maths used for constant folding
*/
vector vec3_add(vector a, vector b)
{
vector out;
out.x = a.x + b.x;
out.y = a.y + b.y;
out.z = a.z + b.z;
return out;
}
vector vec3_sub(vector a, vector b)
{
vector out;
out.x = a.x - b.x;
out.y = a.y - b.y;
out.z = a.z - b.z;
return out;
}
qcfloat vec3_mulvv(vector a, vector b)
{
return (a.x * b.x + a.y * b.y + a.z * b.z);
}
vector vec3_mulvf(vector a, float b)
{
vector out;
out.x = a.x * b;
out.y = a.y * b;
out.z = a.z * b;
return out;
}
/**********************************************************************
* parsing
*/
bool parser_next(parser_t *parser)
{
/* lex_do kills the previous token */
@ -349,6 +389,7 @@ static bool parser_sy_pop(parser_t *parser, shunt *sy)
ast_expression *out = NULL;
ast_expression *exprs[3];
ast_block *blocks[3];
ast_value *asvalue[3];
size_t i, assignop;
qcint generated_op = 0;
@ -378,6 +419,7 @@ static bool parser_sy_pop(parser_t *parser, shunt *sy)
for (i = 0; i < op->operands; ++i) {
exprs[i] = sy->out[sy->out_count+i].out;
blocks[i] = sy->out[sy->out_count+i].block;
asvalue[i] = (ast_value*)exprs[i];
}
if (blocks[0] && !blocks[0]->exprs_count && op->id != opid1(',')) {
@ -388,6 +430,9 @@ static bool parser_sy_pop(parser_t *parser, shunt *sy)
#define NotSameType(T) \
(exprs[0]->expression.vtype != exprs[1]->expression.vtype || \
exprs[0]->expression.vtype != T)
#define CanConstFold(A, B) \
(ast_istype((A), ast_value) && ast_istype((B), ast_value) && \
((ast_value*)(A))->isconst && ((ast_value*)(B))->isconst)
switch (op->id)
{
default:
@ -441,10 +486,22 @@ static bool parser_sy_pop(parser_t *parser, shunt *sy)
}
switch (exprs[0]->expression.vtype) {
case TYPE_FLOAT:
out = (ast_expression*)ast_binary_new(ctx, INSTR_ADD_F, exprs[0], exprs[1]);
if (CanConstFold(exprs[0], exprs[1]))
{
out = (ast_expression*)parser_const_float(parser,
asvalue[0]->constval.vfloat + asvalue[1]->constval.vfloat);
}
else
out = (ast_expression*)ast_binary_new(ctx, INSTR_ADD_F, exprs[0], exprs[1]);
break;
case TYPE_VECTOR:
out = (ast_expression*)ast_binary_new(ctx, INSTR_ADD_V, exprs[0], exprs[1]);
if (CanConstFold(exprs[0], exprs[1]))
{
out = (ast_expression*)parser_const_vector(parser,
vec3_add(asvalue[0]->constval.vvec, asvalue[1]->constval.vvec));
}
else
out = (ast_expression*)ast_binary_new(ctx, INSTR_ADD_V, exprs[0], exprs[1]);
break;
default:
parseerror(parser, "invalid types used in expression: cannot add type %s and %s",
@ -464,10 +521,22 @@ static bool parser_sy_pop(parser_t *parser, shunt *sy)
}
switch (exprs[0]->expression.vtype) {
case TYPE_FLOAT:
out = (ast_expression*)ast_binary_new(ctx, INSTR_SUB_F, exprs[0], exprs[1]);
if (CanConstFold(exprs[0], exprs[1]))
{
out = (ast_expression*)parser_const_float(parser,
asvalue[0]->constval.vfloat - asvalue[1]->constval.vfloat);
}
else
out = (ast_expression*)ast_binary_new(ctx, INSTR_SUB_F, exprs[0], exprs[1]);
break;
case TYPE_VECTOR:
out = (ast_expression*)ast_binary_new(ctx, INSTR_SUB_V, exprs[0], exprs[1]);
if (CanConstFold(exprs[0], exprs[1]))
{
out = (ast_expression*)parser_const_vector(parser,
vec3_sub(asvalue[0]->constval.vvec, asvalue[1]->constval.vvec));
}
else
out = (ast_expression*)ast_binary_new(ctx, INSTR_SUB_V, exprs[0], exprs[1]);
break;
default:
parseerror(parser, "invalid types used in expression: cannot subtract type %s from %s",
@ -491,15 +560,39 @@ static bool parser_sy_pop(parser_t *parser, shunt *sy)
switch (exprs[0]->expression.vtype) {
case TYPE_FLOAT:
if (exprs[1]->expression.vtype == TYPE_VECTOR)
out = (ast_expression*)ast_binary_new(ctx, INSTR_MUL_FV, exprs[0], exprs[1]);
{
if (CanConstFold(exprs[0], exprs[1]))
out = (ast_expression*)parser_const_vector(parser,
vec3_mulvf(asvalue[1]->constval.vvec, asvalue[0]->constval.vfloat));
else
out = (ast_expression*)ast_binary_new(ctx, INSTR_MUL_FV, exprs[0], exprs[1]);
}
else
out = (ast_expression*)ast_binary_new(ctx, INSTR_MUL_F, exprs[0], exprs[1]);
{
if (CanConstFold(exprs[0], exprs[1]))
out = (ast_expression*)parser_const_float(parser,
asvalue[0]->constval.vfloat * asvalue[1]->constval.vfloat);
else
out = (ast_expression*)ast_binary_new(ctx, INSTR_MUL_F, exprs[0], exprs[1]);
}
break;
case TYPE_VECTOR:
if (exprs[1]->expression.vtype == TYPE_FLOAT)
out = (ast_expression*)ast_binary_new(ctx, INSTR_MUL_VF, exprs[0], exprs[1]);
{
if (CanConstFold(exprs[0], exprs[1]))
out = (ast_expression*)parser_const_vector(parser,
vec3_mulvf(asvalue[0]->constval.vvec, asvalue[1]->constval.vfloat));
else
out = (ast_expression*)ast_binary_new(ctx, INSTR_MUL_VF, exprs[0], exprs[1]);
}
else
out = (ast_expression*)ast_binary_new(ctx, INSTR_MUL_V, exprs[0], exprs[1]);
{
if (CanConstFold(exprs[0], exprs[1]))
out = (ast_expression*)parser_const_float(parser,
vec3_mulvv(asvalue[0]->constval.vvec, asvalue[1]->constval.vvec));
else
out = (ast_expression*)ast_binary_new(ctx, INSTR_MUL_V, exprs[0], exprs[1]);
}
break;
default:
parseerror(parser, "invalid types used in expression: cannot multiply types %s and %s",