diff --git a/src/scripting/vm/jit.cpp b/src/scripting/vm/jit.cpp index 28b642a5d6..dc95200052 100644 --- a/src/scripting/vm/jit.cpp +++ b/src/scripting/vm/jit.cpp @@ -6,24 +6,184 @@ extern PString *TypeString; extern PStruct *TypeVector2; extern PStruct *TypeVector3; -static asmjit::JitRuntime *jit; -static int jitRefCount = 0; +static void OutputJitLog(const asmjit::StringLogger &logger); -asmjit::JitRuntime *JitGetRuntime() +static TArray JitBlocks; +static size_t JitBlockPos = 0; +static size_t JitBlockSize = 0; + +static asmjit::CodeInfo GetHostCodeInfo() { - if (!jit) - jit = new asmjit::JitRuntime; - jitRefCount++; - return jit; + static bool firstCall = true; + static asmjit::CodeInfo codeInfo; + + if (firstCall) + { + asmjit::JitRuntime rt; + codeInfo = rt.getCodeInfo(); + firstCall = false; + } + + return codeInfo; } -void JitCleanUp(VMScriptFunction *func) +void *AllocJitMemory(size_t size) { - jitRefCount--; - if (jitRefCount == 0) + using namespace asmjit; + + if (JitBlockPos + size <= JitBlockSize) { - delete jit; - jit = nullptr; + uint8_t *p = JitBlocks[JitBlocks.Size() - 1]; + p += JitBlockPos; + JitBlockPos += size; + return p; + } + else + { + size_t allocatedSize = 0; + void *p = OSUtils::allocVirtualMemory(1024 * 1024, &allocatedSize, OSUtils::kVMWritable | OSUtils::kVMExecutable); + if (!p) + return nullptr; + JitBlocks.Push((uint8_t*)p); + JitBlockSize = allocatedSize; + JitBlockPos = size; + return p; + } +} + +static TArray CreateUnwindInfo(asmjit::CCFunc *func) +{ + TArray info; + + uint32_t version = 1, flags = 0, sizeOfProlog = 0, countOfCodes = 0, frameRegister = 0, frameOffset = 0; + + // To do: query FuncFrameLayout to immitate what X86Internal::emitProlog does + + info.Push(version | (flags << 3) | (sizeOfProlog << 8) | (countOfCodes << 16) | (frameRegister << 24) | (frameOffset << 28)); + + // To do: add UNWIND_CODE entries + + info[0] |= (countOfCodes << 16); + + /* // For reference, we don't need any of this + if (flags & UNW_FLAG_EHANDLER) + { + uint32_t exceptionHandler = 0; + info.Push(exceptionHandler); + } + else if (flags & UNW_FLAG_CHAININFO) + { + uint32_t functionEntry = 0; + info.Push(functionEntry); + } + + if (flags & UNW_FLAG_EHANDLER) + { + uint32_t ExceptionData[]; + info.Push(ExceptionData[]); + } + */ + + return info; +} + +static void *AddJitFunction(asmjit::CodeHolder* code, asmjit::CCFunc *func) +{ + using namespace asmjit; + + size_t codeSize = code->getCodeSize(); + if (codeSize == 0) + return nullptr; + + TArray unwindInfo = CreateUnwindInfo(func); + size_t unwindInfoSize = unwindInfo.Size() * sizeof(uint32_t); + + codeSize = (codeSize + 3) / 4 * 4; + + uint8_t *p = (uint8_t *)AllocJitMemory(codeSize + unwindInfoSize); + if (!p) + return nullptr; + + size_t relocSize = code->relocate(p); + if (relocSize == 0) + return nullptr; + + relocSize = (relocSize + 3) / 4 * 4; + JitBlockPos -= codeSize - relocSize; + +#ifdef WIN32 + uint8_t *unwindptr = p + relocSize; + memcpy(unwindptr, &unwindInfo[0], unwindInfoSize); + + RUNTIME_FUNCTION table; + table.BeginAddress = 0; + table.EndAddress = (DWORD)(ptrdiff_t)(unwindptr - p); + table.UnwindData = (DWORD)(ptrdiff_t)(unwindptr - p); + BOOLEAN result = RtlAddFunctionTable(&table, 1, (DWORD64)p); + if (result == 0) + I_FatalError("RtlAddFunctionTable failed"); +#endif + + return p; +} + +JitFuncPtr JitCompile(VMScriptFunction *sfunc) +{ +#if 0 + if (strcmp(sfunc->PrintableName.GetChars(), "StatusScreen.drawNum") != 0) + return nullptr; +#endif + + using namespace asmjit; + StringLogger logger; + try + { + ThrowingErrorHandler errorHandler; + CodeHolder code; + code.init(GetHostCodeInfo()); + code.setErrorHandler(&errorHandler); + code.setLogger(&logger); + + JitCompiler compiler(&code, sfunc); + CCFunc *func = compiler.Codegen(); + + return reinterpret_cast(AddJitFunction(&code, func)); + } + catch (const std::exception &e) + { + OutputJitLog(logger); + I_FatalError("Unexpected JIT error: %s\n", e.what()); + return nullptr; + } +} + +void JitDumpLog(FILE *file, VMScriptFunction *sfunc) +{ + using namespace asmjit; + StringLogger logger; + try + { + ThrowingErrorHandler errorHandler; + CodeHolder code; + code.init(GetHostCodeInfo()); + code.setErrorHandler(&errorHandler); + code.setLogger(&logger); + + JitCompiler compiler(&code, sfunc); + compiler.Codegen(); + + fwrite(logger.getString(), logger.getLength(), 1, file); + } + catch (const std::exception &e) + { + fwrite(logger.getString(), logger.getLength(), 1, file); + + FString err; + err.Format("Unexpected JIT error: %s\n", e.what()); + fwrite(err.GetChars(), err.Len(), 1, file); + fclose(file); + + I_FatalError("Unexpected JIT error: %s\n", e.what()); } } @@ -46,83 +206,6 @@ static void OutputJitLog(const asmjit::StringLogger &logger) Printf("%s\n", pos); } -//#define DEBUG_JIT - -JitFuncPtr JitCompile(VMScriptFunction *sfunc) -{ -#if defined(DEBUG_JIT) - if (strcmp(sfunc->PrintableName.GetChars(), "StatusScreen.drawNum") != 0) - return nullptr; -#endif - - //Printf("Jitting function: %s\n", sfunc->PrintableName.GetChars()); - - using namespace asmjit; - StringLogger logger; - try - { - auto *jit = JitGetRuntime(); - - ThrowingErrorHandler errorHandler; - CodeHolder code; - code.init(jit->getCodeInfo()); - code.setErrorHandler(&errorHandler); - code.setLogger(&logger); - - JitCompiler compiler(&code, sfunc); - compiler.Codegen(); - - JitFuncPtr fn = nullptr; - Error err = jit->add(&fn, &code); - if (err) - I_FatalError("JitRuntime::add failed: %d", err); - -#if defined(DEBUG_JIT) - OutputJitLog(logger); -#endif - - return fn; - } - catch (const std::exception &e) - { - OutputJitLog(logger); - I_FatalError("Unexpected JIT error: %s\n", e.what()); - return nullptr; - } -} - -void JitDumpLog(FILE *file, VMScriptFunction *sfunc) -{ - using namespace asmjit; - StringLogger logger; - try - { - auto *jit = JitGetRuntime(); - - ThrowingErrorHandler errorHandler; - CodeHolder code; - code.init(jit->getCodeInfo()); - code.setErrorHandler(&errorHandler); - code.setLogger(&logger); - - JitCompiler compiler(&code, sfunc); - compiler.Codegen(); - - fwrite(logger.getString(), logger.getLength(), 1, file); - } - catch (const std::exception &e) - { - fwrite(logger.getString(), logger.getLength(), 1, file); - - FString err; - err.Format("Unexpected JIT error: %s\n", e.what()); - fwrite(err.GetChars(), err.Len(), 1, file); - fclose(file); - - I_FatalError("Unexpected JIT error: %s\n", e.what()); - } -} - ///////////////////////////////////////////////////////////////////////////// static const char *OpNames[NUM_OPS] = @@ -132,7 +215,7 @@ static const char *OpNames[NUM_OPS] = #undef xx }; -void JitCompiler::Codegen() +asmjit::CCFunc *JitCompiler::Codegen() { Setup(); @@ -159,6 +242,8 @@ void JitCompiler::Codegen() cc.endFunc(); cc.finalize(); + + return func; } void JitCompiler::EmitOpcode() @@ -214,7 +299,7 @@ void JitCompiler::Setup() ret = cc.newIntPtr("ret"); // VMReturn *ret numret = cc.newInt32("numret"); // int numret - cc.addFunc(FuncSignature5()); + func = cc.addFunc(FuncSignature5()); cc.setArg(0, unusedFunc); cc.setArg(1, args); cc.setArg(2, numargs); diff --git a/src/scripting/vm/jit.h b/src/scripting/vm/jit.h index 8188aa9b0c..618f37d4ba 100644 --- a/src/scripting/vm/jit.h +++ b/src/scripting/vm/jit.h @@ -4,5 +4,4 @@ #include "vmintern.h" JitFuncPtr JitCompile(VMScriptFunction *func); -void JitCleanUp(VMScriptFunction *func); void JitDumpLog(FILE *file, VMScriptFunction *func); diff --git a/src/scripting/vm/jitintern.h b/src/scripting/vm/jitintern.h index bb8e0a9269..b011652401 100644 --- a/src/scripting/vm/jitintern.h +++ b/src/scripting/vm/jitintern.h @@ -30,7 +30,7 @@ class JitCompiler public: JitCompiler(asmjit::CodeHolder *code, VMScriptFunction *sfunc) : cc(code), sfunc(sfunc) { } - void Codegen(); + asmjit::CCFunc *Codegen(); private: // Declare EmitXX functions for the opcodes: @@ -158,6 +158,7 @@ private: asmjit::X86Compiler cc; VMScriptFunction *sfunc; + asmjit::CCFunc *func = nullptr; asmjit::X86Gp args; asmjit::X86Gp numargs; asmjit::X86Gp ret; diff --git a/src/scripting/vm/vmframe.cpp b/src/scripting/vm/vmframe.cpp index ffd2f93911..4ca4f3f972 100644 --- a/src/scripting/vm/vmframe.cpp +++ b/src/scripting/vm/vmframe.cpp @@ -79,12 +79,6 @@ VMScriptFunction::VMScriptFunction(FName name) VMScriptFunction::~VMScriptFunction() { - if (FunctionJitted) - { - JitCleanUp(this); - FunctionJitted = false; - } - if (Code != NULL) { if (KonstS != NULL) @@ -220,8 +214,6 @@ int VMScriptFunction::FirstScriptCall(VMFunction *func, VMValue *params, int num sfunc->ScriptCall = JitCompile(sfunc); if (!sfunc->ScriptCall) sfunc->ScriptCall = VMExec; - else - sfunc->FunctionJitted = true; return func->ScriptCall(func, params, numparams, ret, numret); } diff --git a/src/scripting/vm/vmintern.h b/src/scripting/vm/vmintern.h index 4291eb5c24..60850ee8ab 100644 --- a/src/scripting/vm/vmintern.h +++ b/src/scripting/vm/vmintern.h @@ -489,6 +489,4 @@ public: private: static int FirstScriptCall(VMFunction *func, VMValue *params, int numparams, VMReturn *ret, int numret); - - bool FunctionJitted = false; };