/*
	function.c

	QC function support code

	Copyright (C) 2002 Bill Currie

	Author: Bill Currie <bill@taniwha.org>
	Date: 2002/5/7

	This program is free software; you can redistribute it and/or
	modify it under the terms of the GNU General Public License
	as published by the Free Software Foundation; either version 2
	of the License, or (at your option) any later version.

	This program is distributed in the hope that it will be useful,
	but WITHOUT ANY WARRANTY; without even the implied warranty of
	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

	See the GNU General Public License for more details.

	You should have received a copy of the GNU General Public License
	along with this program; if not, write to:

		Free Software Foundation, Inc.
		59 Temple Place - Suite 330
		Boston, MA  02111-1307, USA

*/
#ifdef HAVE_CONFIG_H
# include "config.h"
#endif

#ifdef HAVE_STRING_H
# include <string.h>
#endif
#ifdef HAVE_STRINGS_H
# include <strings.h>
#endif
#include <stdint.h>
#include <stdlib.h>

#include "QF/alloc.h"
#include "QF/dstring.h"
#include "QF/hash.h"
#include "QF/va.h"

#include "tools/qfcc/include/qfcc.h"

#include "tools/qfcc/include/class.h"
#include "tools/qfcc/include/codespace.h"
#include "tools/qfcc/include/debug.h"
#include "tools/qfcc/include/def.h"
#include "tools/qfcc/include/defspace.h"
#include "tools/qfcc/include/diagnostic.h"
#include "tools/qfcc/include/emit.h"
#include "tools/qfcc/include/expr.h"
#include "tools/qfcc/include/flow.h"
#include "tools/qfcc/include/function.h"
#include "tools/qfcc/include/opcodes.h"
#include "tools/qfcc/include/options.h"
#include "tools/qfcc/include/reloc.h"
#include "tools/qfcc/include/shared.h"
#include "tools/qfcc/include/statements.h"
#include "tools/qfcc/include/strpool.h"
#include "tools/qfcc/include/symtab.h"
#include "tools/qfcc/include/type.h"
#include "tools/qfcc/include/value.h"

// free list of param_t nodes; presumably consumed by the ALLOC() call in
// new_param() via the "params" name it is given -- TODO confirm in QF/alloc.h
static param_t *params_freelist;
// free list of function_t nodes; presumably consumed by the ALLOC() call in
// new_function() via the "functions" name -- TODO confirm in QF/alloc.h
static function_t *functions_freelist;
// overloaded_function_t records keyed by full_name ("name|encoded params")
static hashtab_t *overloaded_functions;
// the same records keyed by their plain name
static hashtab_t *function_map;

// standardized base register to use for all locals (arguments, local defs,
// params)
#define LOCALS_REG 1

/*
	Hash key callback for the overloaded_functions table: an entry is keyed
	by its full name. The second argument is not used.
*/
static const char *
ol_func_get_key (const void *_f, void *unused)
{
	const overloaded_function_t *entry = _f;

	return entry->full_name;
}

/*
	Hash key callback for the function_map table: an entry is keyed by its
	plain name. The second argument is not used.
*/
static const char *
func_map_get_key (const void *_f, void *unused)
{
	const overloaded_function_t *entry = _f;

	return entry->name;
}

/*
	Allocate a parameter node and fill in its selector, type and name.
	The node is returned unlinked (next is null).
*/
param_t *
new_param (const char *selector, type_t *type, const char *name)
{
	param_t    *node;

	ALLOC (4096, param_t, params, node);
	node->name = name;
	node->type = type;
	node->selector = selector;
	node->next = 0;

	return node;
}

/*
	Append one parameter per identifier in idents to the end of params,
	giving every identifier (and its parameter) the supplied type.

	When idents is null a single empty parameter (no selector, type or
	name) is appended instead.

	Returns the head of the resulting list.
*/
param_t *
param_append_identifiers (param_t *params, symbol_t *idents, type_t *type)
{
	param_t   **tail = &params;
	symbol_t   *id;

	// find the link that terminates the existing list
	while (*tail) {
		tail = &(*tail)->next;
	}
	if (!idents) {
		*tail = new_param (0, 0, 0);
		return params;
	}
	for (id = idents; id; id = id->next) {
		param_t    *node;

		id->type = type;
		node = new_param (0, type, id->name);
		node->symbol = id;
		*tail = node;
		tail = &node->next;
	}
	return params;
}

/*
	Reverse the (non-empty) list starting at params in place. The original
	first node ends up last and is linked to next. Returns the new head.
*/
static param_t *
_reverse_params (param_t *params, param_t *next)
{
	param_t    *reversed = next;
	param_t    *node = params;

	do {
		param_t    *following = node->next;

		node->next = reversed;
		reversed = node;
		node = following;
	} while (node);
	return reversed;
}

/*
	Reverse a parameter list in place and return its new head. A null list
	yields null.
*/
param_t *
reverse_params (param_t *params)
{
	return params ? _reverse_params (params, 0) : 0;
}

/*
	Link more_params onto the end of params and return the head of the
	combined list. When params is empty, more_params is the result.
*/
param_t *
append_params (param_t *params, param_t *more_params)
{
	param_t    *last;

	if (!params)
		return more_params;

	last = params;
	while (last->next)
		last = last->next;
	last->next = more_params;
	return params;
}

/*
	Build a new list with one freshly allocated node per node of params,
	copying selector, type and name (in the same order). Returns the head
	of the copy, or null for an empty list.
*/
param_t *
copy_params (param_t *params)
{
	param_t    *head = 0;
	param_t   **link = &head;
	param_t    *src;

	for (src = params; src; src = src->next) {
		*link = new_param (src->selector, src->type, src->name);
		link = &(*link)->next;
	}
	return head;
}

/*
	Build a function type from a return type and a parameter list.

	type is the return type; a class type is rejected with an error and
	replaced by type_id. Only list entries that carry a type become
	parameters of the function type. An entry with no selector, type or
	name marks a variable argument list: it must be the last entry, and it
	is recorded by storing num_params as -(fixed parameter count + 1).

	Returns the newly created ev_func type.
*/
type_t *
parse_params (type_t *type, param_t *parms)
{
	param_t    *p;
	type_t     *new;
	type_t     *ptype;
	int         count = 0;

	if (type && is_class (type)) {
		error (0, "cannot return an object (forgot *?)");
		type = &type_id;
	}

	new = new_type ();
	new->type = ev_func;
	new->alignment = 1;
	new->t.func.type = type;
	new->t.func.num_params = 0;

	for (p = parms; p; p = p->next) {
		if (p->type) {
			count++;
		}
	}
	if (count) {
		// param_types holds one type pointer per parameter, so size the
		// array by its element rather than by a whole type_t
		new->t.func.param_types = malloc (count
										  * sizeof (*new->t.func.param_types));
	}
	for (p = parms; p; p = p->next) {
		if (!p->selector && !p->type && !p->name) {
			// variable argument marker: nothing may follow it
			if (p->next)
				internal_error (0, 0);
			new->t.func.num_params = -(new->t.func.num_params + 1);
		} else if (p->type) {
			if (is_class (p->type)) {
				error (0, "cannot use an object as a parameter (forgot *?)");
				p->type = &type_id;
			}
			ptype = (type_t *) unalias_type (p->type); //FIXME cast
			new->t.func.param_types[new->t.func.num_params] = ptype;
			new->t.func.num_params++;
		}
	}
	return new;
}

/*
	Validate the use of void in a parameter list.

	- A named void parameter is an error; its type is replaced by
	  type_default.
	- An unnamed void that is not the sole parameter is an error; it is
	  given the name "void".
	- A lone unnamed void means the function takes no parameters, so null
	  is returned.

	Otherwise the (possibly patched) list is returned.
*/
param_t *
check_params (param_t *params)
{
	int         num;
	param_t    *p;

	if (!params)
		return 0;
	// num is the 1-based position of p; it has to advance with p so that
	// both the diagnostics and the "only parameter" test see the real
	// position
	for (p = params, num = 1; p; p = p->next, num++) {
		if (p->type && is_void(p->type)) {
			if (p->name) {
				error (0, "parameter %d ('%s') has incomplete type", num,
					   p->name);
				p->type = type_default;
			} else if (num > 1 || p->next) {
				error (0, "'void' must be the only parameter");
				p->name = "void";
			} else {
				// this is a void function
				return 0;
			}
		}
	}
	return params;
}

/*
	Look up, and optionally create, the overload record for a function.

	The record is identified by its full name, "name|encoded params". If a
	record with that full name exists it is returned; an error is reported
	first when its type differs from the requested one. When no record
	exists, null is returned unless create is set.

	When creating and another function with the same plain name is already
	known, the new record is marked overloaded; a pair of warnings is
	issued if neither the caller (overload) nor the earlier record asked
	for overloading.

	Both hash tables are created on first use.
*/
static overloaded_function_t *
get_function (const char *name, const type_t *type, int overload, int create)
{
	const char *full_name;
	overloaded_function_t *known;
	overloaded_function_t *func;

	if (!overloaded_functions) {
		overloaded_functions = Hash_NewTable (1021, ol_func_get_key, 0, 0, 0);
		function_map = Hash_NewTable (1021, func_map_get_key, 0, 0, 0);
	}

	name = save_string (name);
	full_name = save_string (va (0, "%s|%s", name, encode_params (type)));

	known = Hash_Find (overloaded_functions, full_name);
	if (known) {
		if (known->type != type)
			error (0, "can't overload on return types");
		return known;
	}

	if (!create)
		return 0;

	known = Hash_Find (function_map, name);
	if (known) {
		if (!overload && !known->overloaded) {
			// dummy expression carrying the earlier function's location
			expr_t     *where = new_expr ();

			where->line = known->line;
			where->file = known->file;
			warning (0, "creating overloaded function %s without @overload",
					 full_name);
			warning (where, "(previous function is %s)", known->full_name);
		}
		overload = 1;
	}

	func = calloc (1, sizeof (*func));
	func->name = name;
	func->full_name = full_name;
	func->type = type;
	func->overloaded = overload;
	func->file = pr.source_file;
	func->line = pr.source_line;

	Hash_Add (overloaded_functions, func);
	Hash_Add (function_map, func);
	return func;
}

/*
	Find the function symbol corresponding to sym in the current symbol
	table, optionally creating it.

	The lookup name is the overload record's full name when the function
	is overloaded, otherwise sym's own name. With create set, a new sy_func
	symbol (not yet defined) is added to current_symtab when the lookup
	finds nothing or finds a symbol belonging to another table.

	Returns the symbol found or created; may be null when create is clear.
*/
symbol_t *
function_symbol (symbol_t *sym, int overload, int create)
{
	overloaded_function_t *func;
	const char *name;
	symbol_t   *s;

	func = get_function (sym->name, unalias_type (sym->type), overload,
						 create);
	name = (func && func->overloaded) ? func->full_name : sym->name;

	s = symtab_lookup (current_symtab, name);
	if (create && (!s || s->table != current_symtab)) {
		s = new_symbol (name);
		s->sy_type = sy_func;
		s->type = (type_t *) unalias_type (sym->type); // FIXME cast
		s->params = sym->params;
		s->s.func = 0;				// function not yet defined
		symtab_addsymbol (current_symtab, s);
	}
	return s;
}

/*
	qsort/bsearch comparator for arrays of overloaded_function_t pointers.

	NOTE sorts the list in /reverse/ order.

	Functions are ordered by their number of fixed parameters, then by the
	raw num_params value (which separates a varargs function from a
	non-varargs one with the same fixed count), then by the identity of
	each parameter type in turn.

	Parameter types are compared by address. The addresses are compared as
	integers: subtracting pointers to unrelated objects is undefined, and
	squeezing the difference into an int can produce an inconsistent
	ordering.
*/
static int
func_compare (const void *a, const void *b)
{
	overloaded_function_t *fa = *(overloaded_function_t **) a;
	overloaded_function_t *fb = *(overloaded_function_t **) b;
	const type_t *ta = fa->type;
	const type_t *tb = fb->type;
	int         na = ta->t.func.num_params;
	int         nb = tb->t.func.num_params;
	int         i;

	// a negative count encodes varargs: ~count is the fixed parameter count
	if (na < 0)
		na = ~na;
	if (nb < 0)
		nb = ~nb;
	if (na != nb)
		return nb - na;
	if (ta->t.func.num_params != tb->t.func.num_params)
		return tb->t.func.num_params - ta->t.func.num_params;
	for (i = 0; i < na; i++) {
		uintptr_t   pa = (uintptr_t) ta->t.func.param_types[i];
		uintptr_t   pb = (uintptr_t) tb->t.func.param_types[i];

		if (pa != pb)
			return pb > pa ? 1 : -1;
	}
	return 0;
}

/*
	Resolve a call to a possibly overloaded function.

	fexpr is the function expression and params the argument expressions.
	When fexpr is a symbol naming an overloaded function, fexpr->e.symbol
	is replaced with the symbol of the overload selected by the argument
	types: first an exact match is looked for, then the candidates are
	filtered by parameter count and assignability of each argument.

	Returns fexpr (possibly retargeted), or the first argument expression
	that is an error expression. Ambiguous calls and calls with no matching
	overload are reported with error() and fexpr is returned unchanged.
*/
expr_t *
find_function (expr_t *fexpr, expr_t *params)
{
	expr_t     *e;
	int         i, j, func_count, parm_count, reported = 0;
	overloaded_function_t *f, dummy, *best = 0;
	type_t      type;
	void      **funcs, *dummy_p = &dummy;

	if (fexpr->type != ex_symbol)
		return fexpr;

	// build a temporary function type describing the call's arguments
	memset (&type, 0, sizeof (type));
	type.type = ev_func;

	for (e = params; e; e = e->next) {
		if (e->type == ex_error)
			return e;
		type.t.func.num_params++;
	}
	// one type pointer per argument
	i = type.t.func.num_params * sizeof (*type.t.func.param_types);
	type.t.func.param_types = alloca(i);
	memset (type.t.func.param_types, 0, i);
	// the argument types are stored in the opposite order to the list
	// (presumably the list is linked last argument first -- TODO confirm)
	for (i = 0, e = params; e; i++, e = e->next) {
		type.t.func.param_types[type.t.func.num_params - 1 - i] = get_type (e);
		if (e->type == ex_error)
			return e;
	}
	funcs = Hash_FindList (function_map, fexpr->e.symbol->name);
	if (!funcs)
		return fexpr;
	for (func_count = 0; funcs[func_count]; func_count++)
		;
	if (func_count < 2) {
		f = (overloaded_function_t *) funcs[0];
		if (func_count && !f->overloaded) {
			// the only function of that name is not overloaded
			free (funcs);
			return fexpr;
		}
	}
	type.t.func.type = ((overloaded_function_t *) funcs[0])->type->t.func.type;
	dummy.type = find_type (&type);

	// look for an exact match first; func_compare only examines the type
	qsort (funcs, func_count, sizeof (void *), func_compare);
	dummy.full_name = save_string (va (0, "%s|%s", fexpr->e.symbol->name,
									   encode_params (&type)));
	dummy_p = bsearch (&dummy_p, funcs, func_count, sizeof (void *),
					   func_compare);
	if (dummy_p) {
		f = (overloaded_function_t *) *(void **) dummy_p;
		if (f->overloaded) {
			fexpr->e.symbol = symtab_lookup (current_symtab, f->full_name);
			// best is still null on this path: report the match itself
			if (!fexpr->e.symbol)
				internal_error (fexpr, "overloaded function %s not found",
								f->full_name);
		}
		free (funcs);
		return fexpr;
	}
	// no exact match: drop every candidate that cannot accept the arguments
	for (i = 0; i < func_count; i++) {
		f = (overloaded_function_t *) funcs[i];
		parm_count = f->type->t.func.num_params;
		if ((parm_count >= 0 && parm_count != type.t.func.num_params)
			|| (parm_count < 0 && ~parm_count > type.t.func.num_params)) {
			funcs[i] = 0;
			continue;
		}
		if (parm_count < 0)
			parm_count = ~parm_count;
		for (j = 0; j < parm_count; j++) {
			if (!type_assignable (f->type->t.func.param_types[j],
								  type.t.func.param_types[j])) {
				funcs[i] = 0;
				break;
			}
		}
	}
	// exactly one survivor is a match; more than one is ambiguous
	for (i = 0; i < func_count; i++) {
		f = (overloaded_function_t *) funcs[i];
		if (f) {
			if (!best) {
				best = f;
			} else {
				if (!reported) {
					reported = 1;
					error (fexpr, "unable to disambiguate %s",
						   dummy.full_name);
					error (fexpr, "possible match: %s", best->full_name);
				}
				error (fexpr, "possible match: %s", f->full_name);
			}
		}
	}
	if (reported) {
		free (funcs);
		return fexpr;
	}
	if (best) {
		if (best->overloaded) {
			fexpr->e.symbol = symtab_lookup (current_symtab,
											 best->full_name);
			if (!fexpr->e.symbol)
				internal_error (fexpr, "overloaded function %s not found",
								best->full_name);
		}
		free (funcs);
		return fexpr;
	}
	error (fexpr, "unable to find function matching %s", dummy.full_name);
	free (funcs);
	return fexpr;
}

/*
	Check that a function's return type and parameters can be passed by
	value.

	A return type of size zero, or one larger than type_param, is reported
	and replaced with type_void. Each typed parameter is reported if its
	type has size zero or is larger than type_param. Parameter numbers in
	the diagnostics are 1-based positions in fsym->params.
*/
static void
check_function (symbol_t *fsym)
{
	param_t    *params = fsym->params;
	param_t    *p;
	int         i;

	if (!type_size (fsym->type->t.func.type)) {
		error (0, "return type is an incomplete type");
		fsym->type->t.func.type = &type_void;//FIXME better type?
	}
	if (type_size (fsym->type->t.func.type) > type_size (&type_param)) {
		error (0, "return value too large to be passed by value (%d)",
			   type_size (&type_param));
		fsym->type->t.func.type = &type_void;//FIXME better type?
	}
	for (p = params, i = 0; p; p = p->next, i++) {
		// a typed parameter may still be unnamed at this point, and a null
		// pointer must not be handed to %s
		const char *name = p->name ? p->name : "";

		if (!p->selector && !p->type && !p->name)
			continue;					// ellipsis marker
		if (!p->type)
			continue;					// non-param selector
		if (!type_size (p->type))
			error (0, "parameter %d ('%s') has incomplete type",
				   i + 1, name);
		if (type_size (p->type) > type_size (&type_param))
			error (0, "param %d ('%s') is too large to be passed by value",
				   i + 1, name);
	}
}

/*
	Create the symbol tables for a function and populate its parameters.

	Three local-scope tables are created for fsym->s.func: a label scope,
	a parameters table (child of parent) and a locals table (child of the
	parameters table), the latter two each with their own virtual defspace.

	A def is created in the parameters space for every typed parameter. For
	a varargs function (negative num_params) a ".args" def is created
	first, and ".parN" defs of type_param are added after the named
	parameters until MAX_PARMS is reached.

	current_symtab is pointed at the locals table while the defs are
	created and restored before returning.
*/
static void
build_scope (symbol_t *fsym, symtab_t *parent)
{
	int         i;
	param_t    *p;
	symbol_t   *args = 0;
	symbol_t   *param;
	symtab_t   *parameters;
	symtab_t   *locals;
	symtab_t   *cs = current_symtab;//FIXME

	check_function (fsym);

	// these checks must come before the first use of fsym->s.func below
	if (!fsym->s.func) {
		internal_error (0, "function %s not defined", fsym->name);
	}
	if (!is_func (fsym->s.func->type)) {
		internal_error (0, "function type %s not a funciton", fsym->name);
	}

	fsym->s.func->label_scope = new_symtab (0, stab_local);

	parameters = new_symtab (parent, stab_local);
	parameters->space = defspace_new (ds_virtual);
	fsym->s.func->parameters = parameters;

	locals = new_symtab (parameters, stab_local);
	locals->space = defspace_new (ds_virtual);
	fsym->s.func->locals = locals;

	current_symtab = locals;//FIXME

	if (fsym->s.func->type->t.func.num_params < 0) {
		args = new_symbol_type (".args", &type_va_list);
		initialize_def (args, 0, parameters->space, sc_param);
	}

	for (p = fsym->params, i = 0; p; p = p->next) {
		if (!p->selector && !p->type && !p->name)
			continue;					// ellipsis marker
		if (!p->type)
			continue;					// non-param selector
		if (!p->name) {
			error (0, "parameter name omitted");
			p->name = save_string ("");
		}
		param = new_symbol_type (p->name, p->type);
		initialize_def (param, 0, parameters->space, sc_param);
		i++;
	}

	if (args) {
		// varargs: fill the remaining parameter slots
		while (i < MAX_PARMS) {
			param = new_symbol_type (va (0, ".par%d", i), &type_param);
			initialize_def (param, 0, parameters->space, sc_param);
			i++;
		}
	}
	current_symtab = cs;
}

/*
	Allocate a function record for the current source file.

	s_name is set from name via ReuseString. The record's display name is
	nice_name when one is given, otherwise name.
*/
function_t *
new_function (const char *name, const char *nice_name)
{
	function_t	*func;

	ALLOC (1024, function_t, functions, func);
	func->s_name = ReuseString (name);
	func->s_file = pr.source_file;
	func->name = nice_name ? nice_name : name;
	return func;
}

/*
	Ensure a function symbol has a function record and a def.

	Nothing is done for an extern declaration of a symbol that already has
	a function record. Otherwise the record is created if missing, and a
	def is created in space with the given storage class if there is none.
	An existing external def is replaced when storage is not sc_extern; the
	relocations that were attached to it are moved to the new def.
*/
void
make_function (symbol_t *sym, const char *nice_name, defspace_t *space,
			   storage_class_t storage)
{
	reloc_t    *pending = 0;
	function_t *func;

	if (sym->sy_type != sy_func)
		internal_error (0, "%s is not a function", sym->name);
	if (storage == sc_extern && sym->s.func)
		return;

	if (!sym->s.func) {
		sym->s.func = new_function (sym->name, nice_name);
		sym->s.func->sym = sym;
		sym->s.func->type = unalias_type (sym->type);
	}
	func = sym->s.func;

	if (func->def && func->def->external && storage != sc_extern) {
		//FIXME this really is not the right way
		pending = func->def->relocs;
		free_def (func->def);
		func->def = 0;
	}
	if (!func->def) {
		func->def = new_def (sym->name, sym->type, space, storage);
		reloc_attach_relocs (pending, &func->def->relocs);
	}
}

/*
	Append f to the program's function list and give it the next function
	number.
*/
void
add_function (function_t *f)
{
	f->function_num = pr.num_functions++;

	*pr.func_tail = f;
	pr.func_tail = &f->next;
}

/*
	Start the definition of a function.

	If sym is not a function symbol, or names a function whose def is
	already initialized, an error is reported and a replacement function
	symbol is created so processing can continue. The function's def is
	placed in the symbol table's space, or in pr.far_data when far is set.
	Unless the def is external it is marked initialized, constant and
	nosave, the function is added to the program's list and a relocation
	is recorded for it. The function's code start, source file and (with
	debug enabled) line info index are set, then its scope is built.

	Returns the function record.
*/
function_t *
begin_function (symbol_t *sym, const char *nicename, symtab_t *parent,
				int far, storage_class_t storage)
{
	defspace_t *space;
	function_t *func;
	def_t      *def;

	if (sym->sy_type != sy_func) {
		error (0, "%s is not a function", sym->name);
		sym = new_symbol_type (sym->name, &type_func);
		sym = function_symbol (sym, 1, 1);
	}
	if (sym->s.func && sym->s.func->def && sym->s.func->def->initialized) {
		error (0, "%s redefined", sym->name);
		sym = new_symbol_type (sym->name, sym->type);
		sym = function_symbol (sym, 1, 1);
	}

	space = sym->table->space;
	if (far)
		space = pr.far_data;
	make_function (sym, nicename, space, storage);

	func = sym->s.func;
	def = func->def;
	if (!def->external) {
		def->initialized = 1;
		def->constant = 1;
		def->nosave = 1;
		add_function (func);
		reloc_def_func (func, def);
	}
	func->code = pr.code->size;
	func->s_file = pr.source_file;

	if (options.code.debug) {
		pr_lineno_t *lineno = new_lineno ();
		func->line_info = lineno - pr.linenos;
	}

	build_scope (sym, parent);
	return sym->s.func;
}

/*
	Report an error when the function's type has more than MAX_PARMS
	parameters.
*/
static void
build_function (symbol_t *fsym)
{
	if (fsym->s.func->type->t.func.num_params > MAX_PARMS)
		error (0, "too many params");
}

/*
	Move every def of src into dst and delete src.

	A block of src->size is allocated in dst, aligned to the larger of
	alignment and the strictest alignment among src's defs. Each def's
	offset is rebased by the block's offset and its space set to dst, then
	src's def list is spliced onto the end of dst's.
*/
static void
merge_spaces (defspace_t *dst, defspace_t *src, int alignment)
{
	def_t      *d;
	int         base;

	// the block must satisfy the strictest alignment among the defs
	d = src->defs;
	while (d) {
		if (alignment < d->type->alignment)
			alignment = d->type->alignment;
		d = d->next;
	}

	base = defspace_alloc_aligned_highwater (dst, src->size, alignment);
	d = src->defs;
	while (d) {
		d->space = dst;
		d->offset += base;
		d = d->next;
	}

	if (src->defs) {
		// hand the whole list over, leaving src with an empty one
		*dst->def_tail = src->defs;
		dst->def_tail = src->def_tail;
		src->def_tail = &src->defs;
		*src->def_tail = 0;
	}

	defspace_delete (src);
}

/*
	Finish a function that has a code body: emit its statements and lay
	out its parameter/local data.

	fsym is the function's symbol, state_expr an optional expression that
	is placed ahead of statements, and statements the body.

	Returns the function record, or null when fsym is not a function
	symbol.
*/
function_t *
build_code_function (symbol_t *fsym, expr_t *state_expr, expr_t *statements)
{
	if (fsym->sy_type != sy_func)	// probably in error recovery
		return 0;
	build_function (fsym);
	if (state_expr) {
		// run the state expression before the rest of the body
		state_expr->next = statements;
		statements = state_expr;
	}
	function_t *func = fsym->s.func;
	if (options.code.progsversion == PROG_VERSION) {
		// prepend a "with" expression and then a stack adjustment, so the
		// stack adjustment ends up first; both are given the location of
		// the function's def
		expr_t     *e;
		e = new_with_expr (2, LOCALS_REG, new_short_expr (0));
		e->next = statements;
		e->file = func->def->file;
		e->line = func->def->line;
		statements = e;

		// the adjustment amount is a placeholder here; it is patched
		// after the data layout is known (see OP_ADJSTK below)
		e = new_adjstk_expr (0, 0);
		e->next = statements;
		e->file = func->def->file;
		e->line = func->def->line;
		statements = e;

		// all local and param defs are addressed via LOCALS_REG
		func->temp_reg = LOCALS_REG;
		for (def_t *def = func->locals->space->defs; def; def = def->next) {
			if (def->local || def->param) {
				def->reg = LOCALS_REG;
			}
		}
		for (def_t *def = func->parameters->space->defs; def; def = def->next) {
			if (def->local || def->param) {
				def->reg = LOCALS_REG;
			}
		}
	}
	emit_function (func, statements);
	if (options.code.progsversion < PROG_VERSION) {
		// stitch parameter and locals data together with parameters coming
		// first
		defspace_t *space = defspace_new (ds_virtual);

		merge_spaces (space, func->parameters->space, 1);
		func->parameters->space = space;

		merge_spaces (space, func->locals->space, 1);
		func->locals->space = space;
	} else {
		// merge order here is arguments (if any), locals, then parameters,
		// each with a minimum alignment of 4
		defspace_t *space = defspace_new (ds_virtual);

		if (func->arguments) {
			func->arguments->size = func->arguments->max_size;
			merge_spaces (space, func->arguments, 4);
			func->arguments = 0;
		}

		merge_spaces (space, func->locals->space, 4);
		func->locals->space = space;

		// allocate 0 words to force alignment
		defspace_alloc_aligned_highwater (space, 0, 4);

		// patch the function's first statement, if it is the stack
		// adjustment, with the negated size of the space as it stands
		// before the parameters are merged in
		dstatement_t *st = &pr.code->code[func->code];
		if (st->op == OP_ADJSTK) {
			st->b = -space->size;
		}
		merge_spaces (space, func->parameters->space, 4);
		func->parameters->space = space;
	}
	return fsym->s.func;
}

/*
	Define sym as a builtin function whose number is given by bi_val.

	Errors are reported (and null returned) when sym is not a function,
	is already defined, bi_val is neither an int nor a float constant, or
	the builtin number is negative. Null is also returned, without an
	error, when the function's def is external.

	On success the def is marked initialized, constant and nosave, the
	function is added to the program's list, and a scope is built for it
	for debug info, with the parameter space's size reset to 0 and shared
	with the locals.

	Returns the function record on success.
*/
function_t *
build_builtin_function (symbol_t *sym, expr_t *bi_val, int far,
						storage_class_t storage)
{
	int         bi;
	defspace_t *space;
	function_t *func;

	if (sym->sy_type != sy_func) {
		error (bi_val, "%s is not a function", sym->name);
		return 0;
	}
	if (sym->s.func && sym->s.func->def && sym->s.func->def->initialized) {
		error (bi_val, "%s redefined", sym->name);
		return 0;
	}
	if (!is_int_val (bi_val) && !is_float_val (bi_val)) {
		error (bi_val, "invalid constant for = #");
		return 0;
	}

	space = sym->table->space;
	if (far)
		space = pr.far_data;
	make_function (sym, 0, space, storage);

	func = sym->s.func;
	if (func->def->external)
		return 0;

	func->def->initialized = 1;
	func->def->constant = 1;
	func->def->nosave = 1;
	add_function (func);

	// kept as if/else: a conditional expression would promote the int
	// branch to float before the assignment
	if (is_int_val (bi_val)) {
		bi = expr_int (bi_val);
	} else {
		bi = expr_float (bi_val);
	}
	if (bi < 0) {
		error (bi_val, "builtin functions must be positive or 0");
		return 0;
	}
	func->builtin = bi;
	reloc_def_func (func, func->def);
	build_function (sym);

	// for debug info
	build_scope (sym, current_symtab);
	sym->s.func->parameters->space->size = 0;
	sym->s.func->locals->space = sym->s.func->parameters->space;
	return sym->s.func;
}

/*
	Generate code for function f from the expression list e.

	Nothing is emitted once any error has been counted. Otherwise the
	function's code start is recorded, the expressions are turned into
	statement blocks, and those are emitted after either data flow
	analysis (when optimizing) or a temp count pass.
*/
void
emit_function (function_t *f, expr_t *e)
{
	if (!pr.error_count) {
		f->code = pr.code->size;
		lineno_base = f->def->line;
		f->sblock = make_statements (e);
		if (!options.code.optimize) {
			statements_count_temps (f->sblock);
		} else {
			flow_data_flow (f);
		}
		emit_statements (f->sblock);
	}
}

/*
	Store the size of each fixed parameter of f in parm_size.

	Returns the function type's num_params unchanged, i.e. negative for a
	varargs function (-(fixed count) - 1).
*/
int
function_parms (function_t *f, byte *parm_size)
{
	ty_func_t  *func = &f->sym->type->t.func;
	int         fixed = func->num_params;
	int         i;

	if (fixed < 0)
		fixed = -fixed - 1;

	for (i = 0; i < fixed; i++) {
		parm_size[i] = type_size (func->param_types[i]);
	}
	return func->num_params;
}

void
clear_functions (void)
{
	if (overloaded_functions)
		Hash_FlushTable (overloaded_functions);
	if (function_map)
		Hash_FlushTable (function_map);
}