- add support in the jit compiler to do direct native calls using the x64 calling convention

This commit is contained in:
Magnus Norddahl 2018-11-23 04:47:18 +01:00
parent a5ee673c91
commit 3ba6290419
4 changed files with 234 additions and 274 deletions

View file

@ -1,7 +1,6 @@
#include "jit.h"
#include "jitintern.h"
#include <map>
extern PString *TypeString;
extern PStruct *TypeVector2;
@ -826,242 +825,3 @@ void JitCompiler::EmitNOP()
{
cc.nop();
}
#if 0
void JitCompiler::SetupNative()
{
using namespace asmjit;
ResetTemp();
static const char *marks = "=======================================================";
cc.comment("", 0);
cc.comment(marks, 56);
FString funcname;
funcname.Format("Function: %s", sfunc->PrintableName.GetChars());
cc.comment(funcname.GetChars(), funcname.Len());
cc.comment(marks, 56);
cc.comment("", 0);
konstd = sfunc->KonstD;
konstf = sfunc->KonstF;
konsts = sfunc->KonstS;
konsta = sfunc->KonstA;
CreateRegisters();
func = cc.addFunc(CreateFuncSignature(sfunc));
int argsPos = 0;
int regd = 0, regf = 0, regs = 0, rega = 0;
for (unsigned int i = 0; i < sfunc->Proto->ArgumentTypes.Size(); i++)
{
const PType *type = sfunc->Proto->ArgumentTypes[i];
if (sfunc->ArgFlags[i] & (VARF_Out | VARF_Ref))
{
cc.setArg(argsPos++, regA[rega++]);
}
else if (type == TypeVector2)
{
cc.setArg(argsPos++, regF[regf++]);
cc.setArg(argsPos++, regF[regf++]);
}
else if (type == TypeVector3)
{
cc.setArg(argsPos++, regF[regf++]);
cc.setArg(argsPos++, regF[regf++]);
cc.setArg(argsPos++, regF[regf++]);
}
else if (type == TypeFloat64)
{
cc.setArg(argsPos++, regF[regf++]);
}
else if (type == TypeString)
{
cc.setArg(argsPos++, regS[regs++]);
}
else if (type->isIntCompatible())
{
cc.setArg(argsPos++, regA[regd++]);
}
else
{
cc.setArg(argsPos++, regA[rega++]);
}
}
if (sfunc->NumArgs != argsPos || regd > sfunc->NumRegD || regf > sfunc->NumRegF || regs > sfunc->NumRegS || rega > sfunc->NumRegA)
I_FatalError("JIT: sfunc->NumArgs != argsPos || regd > sfunc->NumRegD || regf > sfunc->NumRegF || regs > sfunc->NumRegS || rega > sfunc->NumRegA");
for (int i = regd; i < sfunc->NumRegD; i++)
cc.xor_(regD[i], regD[i]);
for (int i = regf; i < sfunc->NumRegF; i++)
cc.xorpd(regF[i], regF[i]);
for (int i = regs; i < sfunc->NumRegS; i++)
cc.xor_(regS[i], regS[i]);
for (int i = rega; i < sfunc->NumRegA; i++)
cc.xor_(regA[i], regA[i]);
labels.Resize(sfunc->CodeSize);
IncrementVMCalls();
}
asmjit::CCFunc *JitCompiler::CodegenThunk(asmjit::X86Compiler &cc, VMScriptFunction *sfunc, void *nativefunc)
{
using namespace asmjit;
static const char *marks = "=======================================================";
cc.comment("", 0);
cc.comment(marks, 56);
FString funcname;
funcname.Format("Thunk: %s", sfunc->PrintableName.GetChars());
cc.comment(funcname.GetChars(), funcname.Len());
cc.comment(marks, 56);
cc.comment("", 0);
auto unusedFunc = cc.newIntPtr("func"); // VMFunction*
auto args = cc.newIntPtr("args"); // VMValue *params
auto numargs = cc.newInt32("numargs"); // int numargs
auto ret = cc.newIntPtr("ret"); // VMReturn *ret
auto numret = cc.newInt32("numret"); // int numret
CCFunc *func = cc.addFunc(FuncSignature5<int, VMFunction *, void *, int, void *, int>());
cc.setArg(0, unusedFunc);
cc.setArg(1, args);
cc.setArg(2, numargs);
cc.setArg(3, ret);
cc.setArg(4, numret);
TArray<Reg> callArgs;
int argsPos = 0;
for (unsigned int i = 0; i < sfunc->Proto->ArgumentTypes.Size(); i++)
{
const PType *type = sfunc->Proto->ArgumentTypes[i];
if (sfunc->ArgFlags[i] & (VARF_Out | VARF_Ref))
{
auto reg = cc.newIntPtr();
cc.mov(reg, x86::ptr(args, argsPos++ * sizeof(VMValue) + offsetof(VMValue, a)));
callArgs.Push(reg);
}
else if (type == TypeVector2)
{
for (int j = 0; j < 2; j++)
{
auto reg = cc.newXmmSd();
cc.movsd(reg, x86::qword_ptr(args, argsPos++ * sizeof(VMValue) + offsetof(VMValue, f)));
callArgs.Push(reg);
}
}
else if (type == TypeVector3)
{
for (int j = 0; j < 3; j++)
{
auto reg = cc.newXmmSd();
cc.movsd(reg, x86::qword_ptr(args, argsPos++ * sizeof(VMValue) + offsetof(VMValue, f)));
callArgs.Push(reg);
}
}
else if (type == TypeFloat64)
{
auto reg = cc.newXmmSd();
cc.movsd(reg, x86::qword_ptr(args, argsPos++ * sizeof(VMValue) + offsetof(VMValue, f)));
callArgs.Push(reg);
}
else if (type == TypeString)
{
auto reg = cc.newIntPtr();
cc.mov(reg, x86::ptr(args, argsPos++ * sizeof(VMValue) + offsetof(VMValue, a)));
callArgs.Push(reg);
}
else if (type->isIntCompatible())
{
auto reg = cc.newInt32();
cc.mov(reg, x86::dword_ptr(args, argsPos++ * sizeof(VMValue) + offsetof(VMValue, i)));
callArgs.Push(reg);
}
else
{
auto reg = cc.newIntPtr();
cc.mov(reg, x86::ptr(args, argsPos++ * sizeof(VMValue) + offsetof(VMValue, a)));
callArgs.Push(reg);
}
}
auto call = cc.call(imm_ptr(nativefunc), CreateFuncSignature(sfunc));
for (unsigned int i = 0; i < callArgs.Size(); i++)
call->setArg(i, callArgs[i]);
cc.ret(numret);
return func;
}
asmjit::FuncSignature JitCompiler::CreateFuncSignature(VMScriptFunction *sfunc)
{
using namespace asmjit;
TArray<uint8_t> args;
FString key;
for (unsigned int i = 0; i < sfunc->Proto->ArgumentTypes.Size(); i++)
{
const PType *type = sfunc->Proto->ArgumentTypes[i];
if (sfunc->ArgFlags[i] & (VARF_Out | VARF_Ref))
{
args.Push(TypeIdOf<void*>::kTypeId);
key += "v";
}
else if (type == TypeVector2)
{
args.Push(TypeIdOf<double>::kTypeId);
args.Push(TypeIdOf<double>::kTypeId);
key += "ff";
}
else if (type == TypeVector3)
{
args.Push(TypeIdOf<double>::kTypeId);
args.Push(TypeIdOf<double>::kTypeId);
args.Push(TypeIdOf<double>::kTypeId);
key += "fff";
}
else if (type == TypeFloat64)
{
args.Push(TypeIdOf<double>::kTypeId);
key += "f";
}
else if (type == TypeString)
{
args.Push(TypeIdOf<void*>::kTypeId);
key += "s";
}
else if (type->isIntCompatible())
{
args.Push(TypeIdOf<int>::kTypeId);
key += "i";
}
else
{
args.Push(TypeIdOf<void*>::kTypeId);
key += "v";
}
}
// FuncSignature only keeps a pointer to its args array. Keep a copy of each args array variant.
static std::map<FString, std::unique_ptr<TArray<uint8_t>>> argsCache;
std::unique_ptr<TArray<uint8_t>> &cachedArgs = argsCache[key];
if (!cachedArgs) cachedArgs.reset(new TArray<uint8_t>(args));
FuncSignature signature;
signature.init(CallConv::kIdHost, TypeIdOf<void>::kTypeId, cachedArgs->Data(), cachedArgs->Size());
return signature;
}
#endif

View file

@ -1,5 +1,6 @@
#include "jitintern.h"
#include <map>
void JitCompiler::EmitPARAM()
{
@ -50,30 +51,39 @@ void JitCompiler::EmitVtbl(const VMOP *op)
void JitCompiler::EmitCALL()
{
EmitDoCall(regA[A], nullptr);
EmitVMCall(regA[A]);
pc += C; // Skip RESULTs
}
void JitCompiler::EmitCALL_K()
{
VMFunction *target = static_cast<VMFunction*>(konsta[A].v);
VMNativeFunction *ntarget = nullptr;
if (target && (target->VarFlags & VARF_Native))
ntarget = static_cast<VMNativeFunction *>(target);
if (ntarget && ntarget->DirectNativeCall)
{
EmitNativeCall(ntarget);
}
else
{
auto ptr = newTempIntPtr();
cc.mov(ptr, asmjit::imm_ptr(konsta[A].v));
EmitDoCall(ptr, static_cast<VMFunction*>(konsta[A].v));
cc.mov(ptr, asmjit::imm_ptr(target));
EmitVMCall(ptr);
}
void JitCompiler::EmitDoCall(asmjit::X86Gp vmfunc, VMFunction *target)
pc += C; // Skip RESULTs
}
void JitCompiler::EmitVMCall(asmjit::X86Gp vmfunc)
{
using namespace asmjit;
bool simpleFrameTarget = false;
if (target && (target->VarFlags & VARF_Native))
{
VMScriptFunction *starget = static_cast<VMScriptFunction*>(target);
simpleFrameTarget = starget->SpecialInits.Size() == 0 && starget->NumRegS == 0;
}
CheckVMFrame();
int numparams = StoreCallParams(simpleFrameTarget);
int numparams = StoreCallParams();
if (numparams != B)
I_FatalError("OP_CALL parameter count does not match the number of preceding OP_PARAM instructions");
@ -85,20 +95,6 @@ void JitCompiler::EmitDoCall(asmjit::X86Gp vmfunc, VMFunction *target)
X86Gp paramsptr = newTempIntPtr();
cc.lea(paramsptr, x86::ptr(vmframe, offsetParams));
EmitScriptCall(vmfunc, paramsptr);
LoadInOuts();
LoadReturns(pc + 1, C);
ParamOpcodes.Clear();
pc += C; // Skip RESULTs
}
void JitCompiler::EmitScriptCall(asmjit::X86Gp vmfunc, asmjit::X86Gp paramsptr)
{
using namespace asmjit;
auto scriptcall = newTempIntPtr();
cc.mov(scriptcall, x86::ptr(vmfunc, myoffsetof(VMScriptFunction, ScriptCall)));
@ -110,9 +106,14 @@ void JitCompiler::EmitScriptCall(asmjit::X86Gp vmfunc, asmjit::X86Gp paramsptr)
call->setArg(2, Imm(B));
call->setArg(3, GetCallReturns());
call->setArg(4, Imm(C));
LoadInOuts();
LoadReturns(pc + 1, C);
ParamOpcodes.Clear();
}
int JitCompiler::StoreCallParams(bool simpleFrameTarget)
int JitCompiler::StoreCallParams()
{
using namespace asmjit;
@ -320,3 +321,201 @@ void JitCompiler::FillReturns(const VMOP *retval, int numret)
cc.mov(x86::byte_ptr(GetCallReturns(), i * sizeof(VMReturn) + myoffsetof(VMReturn, RegType)), type);
}
}
void JitCompiler::EmitNativeCall(VMNativeFunction *target)
{
using namespace asmjit;
auto call = cc.call(imm_ptr(target->NativeCall), CreateFuncSignature(target));
if ((pc - 1)->op == OP_VTBL)
{
I_FatalError("Native direct member function calls not implemented\n");
}
X86Gp tmp;
X86Xmm tmp2;
int numparams = 0;
for (unsigned int i = 0; i < ParamOpcodes.Size(); i++)
{
int slot = numparams++;
if (ParamOpcodes[i]->op == OP_PARAMI)
{
int abcs = ParamOpcodes[i]->i24;
call->setArg(slot, imm(abcs));
}
else // OP_PARAM
{
int bc = ParamOpcodes[i]->i16u;
switch (ParamOpcodes[i]->a)
{
case REGT_NIL:
call->setArg(slot, imm(0));
break;
case REGT_INT:
call->setArg(slot, regD[bc]);
break;
case REGT_INT | REGT_KONST:
call->setArg(slot, imm(konstd[bc]));
break;
case REGT_STRING:
call->setArg(slot, regS[bc]);
break;
case REGT_STRING | REGT_KONST:
call->setArg(slot, imm_ptr(&konsts[bc]));
break;
case REGT_POINTER:
call->setArg(slot, regA[bc]);
break;
case REGT_POINTER | REGT_KONST:
call->setArg(slot, asmjit::imm_ptr(konsta[bc].v));
break;
case REGT_FLOAT:
call->setArg(slot, regF[bc]);
break;
case REGT_FLOAT | REGT_MULTIREG2:
for (int j = 0; j < 2; j++)
call->setArg(slot + j, regF[bc + j]);
numparams++;
break;
case REGT_FLOAT | REGT_MULTIREG3:
for (int j = 0; j < 3; j++)
call->setArg(slot + j, regF[bc + j]);
numparams += 2;
break;
case REGT_FLOAT | REGT_KONST:
tmp = newTempIntPtr();
tmp2 = newTempXmmSd();
cc.mov(tmp, asmjit::imm_ptr(konstf + bc));
cc.movsd(tmp2, asmjit::x86::qword_ptr(tmp));
call->setArg(slot, tmp2);
break;
case REGT_STRING | REGT_ADDROF:
case REGT_INT | REGT_ADDROF:
case REGT_POINTER | REGT_ADDROF:
case REGT_FLOAT | REGT_ADDROF:
I_FatalError("REGT_ADDROF not implemented for native direct calls\n");
break;
default:
I_FatalError("Unknown REGT value passed to EmitPARAM\n");
break;
}
}
}
if (numparams != B)
I_FatalError("OP_CALL parameter count does not match the number of preceding OP_PARAM instructions\n");
int numret = C;
if (numret > 1)
I_FatalError("Only one return parameter is supported for direct native calls\n");
if (numret == 1)
{
const auto &retval = pc[1];
if (retval.op != OP_RESULT)
{
I_FatalError("Expected OP_RESULT to follow OP_CALL\n");
}
int type = retval.b;
int regnum = retval.c;
if (type & REGT_KONST)
{
I_FatalError("OP_RESULT with REGT_KONST is not allowed\n");
}
// Note: the usage of newResultXX is intentional. Asmjit has a register allocation bug
// if the return virtual register is already allocated in an argument slot.
switch (type & REGT_TYPE)
{
case REGT_INT:
tmp = newResultInt32();
call->setRet(0, tmp);
cc.mov(regD[regnum], tmp);
break;
case REGT_FLOAT:
tmp2 = newResultXmmSd();
call->setRet(0, tmp2);
cc.movsd(regF[regnum], tmp2);
break;
case REGT_POINTER:
tmp = newResultIntPtr();
cc.mov(regA[regnum], tmp);
break;
case REGT_STRING:
case REGT_FLOAT | REGT_MULTIREG2:
case REGT_FLOAT | REGT_MULTIREG3:
default:
I_FatalError("Unsupported OP_RESULT type encountered in EmitNativeCall\n");
break;
}
}
ParamOpcodes.Clear();
}
asmjit::FuncSignature JitCompiler::CreateFuncSignature(VMFunction *func)
{
using namespace asmjit;
TArray<uint8_t> args;
FString key;
for (unsigned int i = 0; i < func->Proto->ArgumentTypes.Size(); i++)
{
const PType *type = func->Proto->ArgumentTypes[i];
if (func->ArgFlags[i] & (VARF_Out | VARF_Ref))
{
args.Push(TypeIdOf<void*>::kTypeId);
key += "v";
}
else if (type == TypeVector2)
{
args.Push(TypeIdOf<double>::kTypeId);
args.Push(TypeIdOf<double>::kTypeId);
key += "ff";
}
else if (type == TypeVector3)
{
args.Push(TypeIdOf<double>::kTypeId);
args.Push(TypeIdOf<double>::kTypeId);
args.Push(TypeIdOf<double>::kTypeId);
key += "fff";
}
else if (type == TypeFloat64)
{
args.Push(TypeIdOf<double>::kTypeId);
key += "f";
}
else if (type == TypeString)
{
args.Push(TypeIdOf<void*>::kTypeId);
key += "s";
}
else if (type->isIntCompatible())
{
args.Push(TypeIdOf<int>::kTypeId);
key += "i";
}
else
{
args.Push(TypeIdOf<void*>::kTypeId);
key += "v";
}
}
// FuncSignature only keeps a pointer to its args array. Keep a copy of each args array variant.
static std::map<FString, std::unique_ptr<TArray<uint8_t>>> argsCache;
std::unique_ptr<TArray<uint8_t>> &cachedArgs = argsCache[key];
if (!cachedArgs) cachedArgs.reset(new TArray<uint8_t>(args));
FuncSignature signature;
signature.init(CallConv::kIdHost, TypeIdOf<void>::kTypeId, cachedArgs->Data(), cachedArgs->Size());
return signature;
}

View file

@ -38,9 +38,7 @@ private:
#include "vmops.h"
#undef xx
//static asmjit::FuncSignature CreateFuncSignature(VMScriptFunction *sfunc);
//static asmjit::CCFunc *CodegenThunk(asmjit::X86Compiler &cc, VMScriptFunction *sfunc, void *nativefunc);
//void SetupNative();
static asmjit::FuncSignature CreateFuncSignature(VMFunction *sfunc);
void Setup();
void CreateRegisters();
@ -52,11 +50,11 @@ private:
void EmitOpcode();
void EmitPopFrame();
void EmitDoCall(asmjit::X86Gp ptr, VMFunction *target);
void EmitScriptCall(asmjit::X86Gp vmfunc, asmjit::X86Gp paramsptr);
void EmitNativeCall(VMNativeFunction *target);
void EmitVMCall(asmjit::X86Gp ptr);
void EmitVtbl(const VMOP *op);
int StoreCallParams(bool simpleFrameTarget);
int StoreCallParams();
void LoadInOuts();
void LoadReturns(const VMOP *retval, int numret);
void FillReturns(const VMOP *retval, int numret);

View file

@ -464,6 +464,9 @@ public:
// Return value is the number of results.
NativeCallType NativeCall;
// Function pointer to a native function to be called directly by the JIT using the platform calling convention
void *DirectNativeCall = nullptr;
private:
static int NativeScriptCall(VMFunction *func, VMValue *params, int numparams, VMReturn *ret, int numret);
};