diff --git a/src/d_dehacked.cpp b/src/d_dehacked.cpp index 8258a656e..eef9b137a 100644 --- a/src/d_dehacked.cpp +++ b/src/d_dehacked.cpp @@ -3030,7 +3030,6 @@ void FinishDehPatch () while (subclass == nullptr); AActor *defaults2 = GetDefaultByType (subclass); - memcpy ((void *)defaults2, (void *)defaults1, sizeof(AActor)); // Make a copy of the replaced class's state labels FStateDefinitions statedef; diff --git a/src/dobjtype.cpp b/src/dobjtype.cpp index 0e17299cd..b7e666fc1 100644 --- a/src/dobjtype.cpp +++ b/src/dobjtype.cpp @@ -156,14 +156,14 @@ PClassType::PClassType() //========================================================================== // -// PClassType :: Derive +// PClassType :: DeriveData // //========================================================================== -void PClassType::Derive(PClass *newclass) +void PClassType::DeriveData(PClass *newclass) { assert(newclass->IsKindOf(RUNTIME_CLASS(PClassType))); - Super::Derive(newclass); + Super::DeriveData(newclass); static_cast(newclass)->TypeTableType = TypeTableType; } @@ -3030,12 +3030,14 @@ void PClass::DestroySpecials(void *addr) const // //========================================================================== -void PClass::Derive(PClass *newclass) +void PClass::Derive(PClass *newclass, FName name) { + newclass->bRuntimeClass = true; newclass->ParentClass = this; newclass->ConstructNative = ConstructNative; newclass->Symbols.SetParentTable(&this->Symbols); - newclass->InitializeDefaults(); + newclass->TypeName = name; + newclass->mDescriptiveName.Format("Class<%s>", name.GetChars()); } //========================================================================== @@ -3082,6 +3084,18 @@ void PClass::InitializeDefaults() } } +//========================================================================== +// +// PClass :: DeriveData +// +// Copies inheritable data to the child class. +// +//========================================================================== + +void PClass::DeriveData(PClass *newclass) +{ +} + //========================================================================== // // PClass :: CreateDerivedClass @@ -3126,11 +3140,10 @@ PClass *PClass::CreateDerivedClass(FName name, unsigned int size) // Create a new type object of the same type as us. (We may be a derived class of PClass.) type = static_cast(GetClass()->CreateNew()); - type->TypeName = name; type->Size = size; - type->bRuntimeClass = true; - type->mDescriptiveName.Format("Class<%s>", name.GetChars()); - Derive(type); + Derive(type, name); + type->InitializeDefaults(); + type->Virtuals = Virtuals; DeriveData(type); if (!notnew) { @@ -3178,8 +3191,7 @@ PField *PClass::AddField(FName name, PType *type, DWORD flags) // PClass :: FindClassTentative // // Like FindClass but creates a placeholder if no class is found. -// CreateDerivedClass will automatically fill in the placeholder when the -// actual class is defined. +// This will be filled in when the actual class is constructed. // //========================================================================== @@ -3201,17 +3213,59 @@ PClass *PClass::FindClassTentative(FName name, bool fatal) PClass *type = static_cast(GetClass()->CreateNew()); DPrintf(DMSG_SPAMMY, "Creating placeholder class %s : %s\n", name.GetChars(), TypeName.GetChars()); - type->TypeName = name; - type->ParentClass = this; - type->ConstructNative = ConstructNative; + Derive(type, name); type->Size = TentativeClass; - type->bRuntimeClass = true; - type->mDescriptiveName.Format("Class<%s>", name.GetChars()); - type->Symbols.SetParentTable(&Symbols); TypeTable.AddType(type, RUNTIME_CLASS(PClass), (intptr_t)type->Outer, name, bucket); return type; } +//========================================================================== +// +// PClass :: FindVirtualIndex +// +// Compares a prototype with the existing list of virtual functions +// and returns an index if something matching is found. +// +//========================================================================== + +int PClass::FindVirtualIndex(FName name, PPrototype *proto) +{ + for (unsigned i = 0; i < Virtuals.Size(); i++) + { + if (Virtuals[i]->Name == name) + { + auto vproto = Virtuals[i]->Proto; + if (vproto->ReturnTypes.Size() != proto->ReturnTypes.Size() || + vproto->ArgumentTypes.Size() != proto->ArgumentTypes.Size()) + { + continue; // number of parameters does not match, so it's incompatible + } + bool fail = false; + // The first argument is self and will mismatch so just skip it. + for (unsigned a = 1; a < proto->ArgumentTypes.Size(); a++) + { + if (proto->ArgumentTypes[a] != vproto->ArgumentTypes[a]) + { + fail = true; + break; + } + } + if (fail) continue; + + for (unsigned a = 0; a < proto->ReturnTypes.Size(); a++) + { + if (proto->ReturnTypes[a] != vproto->ReturnTypes[a]) + { + fail = true; + break; + } + } + if (!fail) return i; + } + } + return -1; +} + //========================================================================== // // PClass :: BuildFlatPointers diff --git a/src/dobjtype.h b/src/dobjtype.h index 62860d0bd..7746dce29 100644 --- a/src/dobjtype.h +++ b/src/dobjtype.h @@ -31,6 +31,7 @@ enum VARF_Implicit = (1<<12), // implicitly created parameters (i.e. do not compare types when checking function signatures) VARF_Static = (1<<13), // static class data (by necessity read only.) VARF_InternalAccess = (1<<14), // overrides VARF_ReadOnly for internal script code. + VARF_Override = (1<<15), // overrides a virtual function from the parent class. }; // Symbol information ------------------------------------------------------- @@ -756,7 +757,7 @@ protected: // We unravel _WITH_META here just as we did for PType. enum { MetaClassNum = CLASSREG_PClassClass }; TArray SpecialInits; - virtual void Derive(PClass *newclass); + void Derive(PClass *newclass, FName name); void InitializeSpecials(void *addr) const; void SetSuper(); public: @@ -768,8 +769,9 @@ public: bool ReadValue(FSerializer &ar, const char *key,void *addr) const override; bool ReadAllFields(FSerializer &ar, void *addr) const; void InitializeDefaults(); + int FindVirtualIndex(FName name, PPrototype *proto); + virtual void DeriveData(PClass *newclass); - virtual void DeriveData(PClass *newclass) {} static void StaticInit(); static void StaticShutdown(); static void StaticBootstrap(); @@ -782,6 +784,7 @@ public: BYTE *Defaults; bool bRuntimeClass; // class was defined at run-time, not compile-time bool bExported; // This type has been declared in a script + TArray Virtuals; // virtual function table void (*ConstructNative)(void *); @@ -838,7 +841,7 @@ class PClassType : public PClass protected: public: PClassType(); - virtual void Derive(PClass *newclass); + virtual void DeriveData(PClass *newclass); PClass *TypeTableType; // The type to use for hashing into the type table }; diff --git a/src/sc_man_scanner.re b/src/sc_man_scanner.re index bc0a4b2e2..194582d73 100644 --- a/src/sc_man_scanner.re +++ b/src/sc_man_scanner.re @@ -168,6 +168,7 @@ std2: 'optional' { RET(TK_Optional); } 'export' { RET(TK_Export); } 'virtual' { RET(TK_Virtual); } + 'override' { RET(TK_Override); } 'super' { RET(TK_Super); } 'global' { RET(TK_Global); } 'stop' { RET(TK_Stop); } diff --git a/src/sc_man_tokens.h b/src/sc_man_tokens.h index 2b93c9ac2..801d02be3 100644 --- a/src/sc_man_tokens.h +++ b/src/sc_man_tokens.h @@ -105,6 +105,7 @@ xx(TK_Iterator, "'iterator'") xx(TK_Optional, "'optional'") xx(TK_Export, "'expert'") xx(TK_Virtual, "'virtual'") +xx(TK_Override, "'override'") xx(TK_Super, "'super'") xx(TK_Null, "'null'") xx(TK_Global, "'global'") diff --git a/src/scripting/codegeneration/codegen.cpp b/src/scripting/codegeneration/codegen.cpp index 316ae4246..49b1d6bae 100644 --- a/src/scripting/codegeneration/codegen.cpp +++ b/src/scripting/codegeneration/codegen.cpp @@ -6693,12 +6693,16 @@ ExpEmit FxVMFunctionCall::Emit(VMFunctionBuilder *build) } } + VMFunction *vmfunc = Function->Variants[0].Implementation; + bool staticcall = (vmfunc->Final || vmfunc->VirtualIndex == -1 || NoVirtual); + count = 0; // Emit code to pass implied parameters + ExpEmit selfemit; if (Function->Variants[0].Flags & VARF_Method) { assert(Self != nullptr); - ExpEmit selfemit = Self->Emit(build); + selfemit = Self->Emit(build); assert(selfemit.RegType == REGT_POINTER); build->Emit(OP_PARAM, 0, selfemit.RegType, selfemit.RegNum); count += 1; @@ -6718,8 +6722,9 @@ ExpEmit FxVMFunctionCall::Emit(VMFunctionBuilder *build) } count += 2; } - selfemit.Free(build); + if (staticcall) selfemit.Free(build); } + else staticcall = true; // Emit code to pass explicit parameters for (unsigned i = 0; i < ArgList.Size(); ++i) { @@ -6729,27 +6734,56 @@ ExpEmit FxVMFunctionCall::Emit(VMFunctionBuilder *build) ArgList.ShrinkToFit(); // Get a constant register for this function - VMFunction *vmfunc = Function->Variants[0].Implementation; - int funcaddr = build->GetConstantAddress(vmfunc, ATAG_OBJECT); - // Emit the call - if (EmitTail) - { // Tail call - build->Emit(OP_TAIL_K, funcaddr, count, 0); - ExpEmit call; - call.Final = true; - return call; - } - else if (vmfunc->Proto->ReturnTypes.Size() > 0) - { // Call, expecting one result - ExpEmit reg(build, vmfunc->Proto->ReturnTypes[0]->GetRegType(), vmfunc->Proto->ReturnTypes[0]->GetRegCount()); - build->Emit(OP_CALL_K, funcaddr, count, 1); - build->Emit(OP_RESULT, 0, EncodeRegType(reg), reg.RegNum); - return reg; + if (staticcall) + { + int funcaddr = build->GetConstantAddress(vmfunc, ATAG_OBJECT); + // Emit the call + if (EmitTail) + { // Tail call + build->Emit(OP_TAIL_K, funcaddr, count, 0); + ExpEmit call; + call.Final = true; + return call; + } + else if (vmfunc->Proto->ReturnTypes.Size() > 0) + { // Call, expecting one result + ExpEmit reg(build, vmfunc->Proto->ReturnTypes[0]->GetRegType(), vmfunc->Proto->ReturnTypes[0]->GetRegCount()); + build->Emit(OP_CALL_K, funcaddr, count, 1); + build->Emit(OP_RESULT, 0, EncodeRegType(reg), reg.RegNum); + return reg; + } + else + { // Call, expecting no results + build->Emit(OP_CALL_K, funcaddr, count, 0); + return ExpEmit(); + } } else - { // Call, expecting no results - build->Emit(OP_CALL_K, funcaddr, count, 0); - return ExpEmit(); + { + selfemit.Free(build); + ExpEmit funcreg(build, REGT_POINTER); + build->Emit(OP_VTBL, funcreg.RegNum, selfemit.RegNum, vmfunc->VirtualIndex); + if (EmitTail) + { // Tail call + build->Emit(OP_TAIL, funcreg.RegNum, count, 0); + ExpEmit call; + call.Final = true; + return call; + } + else if (vmfunc->Proto->ReturnTypes.Size() > 0) + { // Call, expecting one result + ExpEmit reg(build, vmfunc->Proto->ReturnTypes[0]->GetRegType(), vmfunc->Proto->ReturnTypes[0]->GetRegCount()); + build->Emit(OP_CALL, funcreg.RegNum, count, 1); + build->Emit(OP_RESULT, 0, EncodeRegType(reg), reg.RegNum); + return reg; + } + else + { // Call, expecting no results + build->Emit(OP_CALL, funcreg.RegNum, count, 0); + return ExpEmit(); + } + + } } diff --git a/src/scripting/vm/vm.h b/src/scripting/vm/vm.h index abce2f557..6c7878a9a 100644 --- a/src/scripting/vm/vm.h +++ b/src/scripting/vm/vm.h @@ -655,7 +655,9 @@ class VMFunction : public DObject HAS_OBJECT_POINTERS; public: bool Native; + bool Final = false; // cannot be overridden BYTE ImplicitArgs = 0; // either 0 for static, 1 for method or 3 for action + int VirtualIndex = -1; FName Name; TArray DefaultArgs; diff --git a/src/scripting/vm/vmdisasm.cpp b/src/scripting/vm/vmdisasm.cpp index 239a0dc51..49c929f05 100644 --- a/src/scripting/vm/vmdisasm.cpp +++ b/src/scripting/vm/vmdisasm.cpp @@ -84,6 +84,7 @@ #define RPI8 MODE_AP | MODE_BIMMZ | MODE_CUNUSED #define KPI8 MODE_AKP | MODE_BIMMZ | MODE_CUNUSED #define RPI8I8 MODE_AP | MODE_BIMMZ | MODE_CIMMZ +#define RPRPI8 MODE_AP | MODE_BP | MODE_CIMMZ #define KPI8I8 MODE_AKP | MODE_BIMMZ | MODE_CIMMZ #define I8BCP MODE_AIMMZ | MODE_BCJOINT | MODE_BCPARAM #define THROW MODE_AIMMZ | MODE_BCTHROW diff --git a/src/scripting/vm/vmexec.h b/src/scripting/vm/vmexec.h index 69f406dd8..a6eda2a24 100644 --- a/src/scripting/vm/vmexec.h +++ b/src/scripting/vm/vmexec.h @@ -525,6 +525,15 @@ begin: } } NEXTOP; + OP(VTBL): + ASSERTA(a); ASSERTA(B); + { + auto o = (DObject*)reg.a[B]; + auto p = o->GetClass(); + assert(C < p->Virtuals.Size()); + reg.a[a] = p->Virtuals[C]; + } + NEXTOP; OP(CALL_K): ASSERTKA(a); assert(konstatag[a] == ATAG_OBJECT); diff --git a/src/scripting/vm/vmops.h b/src/scripting/vm/vmops.h index e7ff19c22..368f143bd 100644 --- a/src/scripting/vm/vmops.h +++ b/src/scripting/vm/vmops.h @@ -82,6 +82,7 @@ xx(PARAM, param, __BCP), // push parameter encoded in BC for function call (B=r xx(PARAMI, parami, I24), // push immediate, signed integer for function call xx(CALL, call, RPI8I8), // Call function pkA with parameter count B and expected result count C xx(CALL_K, call, KPI8I8), +xx(VTBL, vtbl, RPRPI8), // dereferences a virtual method table. xx(TAIL, tail, RPI8), // Call+Ret in a single instruction xx(TAIL_K, tail, KPI8), xx(RESULT, result, __BCP), // Result should go in register encoded in BC (in caller, after CALL) diff --git a/src/scripting/zscript/zcc-parse.lemon b/src/scripting/zscript/zcc-parse.lemon index fc707428e..6caef777e 100644 --- a/src/scripting/zscript/zcc-parse.lemon +++ b/src/scripting/zscript/zcc-parse.lemon @@ -905,6 +905,8 @@ decl_flags(X) ::= decl_flags(A) META(T). { X.Int = A.Int | ZCC_Meta; X.SourceLo decl_flags(X) ::= decl_flags(A) ACTION(T). { X.Int = A.Int | ZCC_Action; X.SourceLoc = A.SourceLoc ? A.SourceLoc : T.SourceLoc; } decl_flags(X) ::= decl_flags(A) READONLY(T). { X.Int = A.Int | ZCC_ReadOnly; X.SourceLoc = A.SourceLoc ? A.SourceLoc : T.SourceLoc; } decl_flags(X) ::= decl_flags(A) DEPRECATED(T). { X.Int = A.Int | ZCC_Deprecated; X.SourceLoc = A.SourceLoc ? A.SourceLoc : T.SourceLoc; } +decl_flags(X) ::= decl_flags(A) VIRTUAL(T). { X.Int = A.Int | ZCC_Virtual; X.SourceLoc = A.SourceLoc ? A.SourceLoc : T.SourceLoc; } +decl_flags(X) ::= decl_flags(A) OVERRIDE(T). { X.Int = A.Int | ZCC_Override; X.SourceLoc = A.SourceLoc ? A.SourceLoc : T.SourceLoc; } func_const(X) ::= . { X.Int = 0; X.SourceLoc = stat->sc->GetMessageLine(); } func_const(X) ::= CONST(T). { X.Int = ZCC_FuncConst; X.SourceLoc = T.SourceLoc; } diff --git a/src/scripting/zscript/zcc_compile.cpp b/src/scripting/zscript/zcc_compile.cpp index 7d338cc69..6c837cead 100644 --- a/src/scripting/zscript/zcc_compile.cpp +++ b/src/scripting/zscript/zcc_compile.cpp @@ -1258,7 +1258,7 @@ bool ZCCCompiler::CompileFields(PStruct *type, TArray &Fiel PType *fieldtype = DetermineType(type, field, field->Names->Name, field->Type, true, true); // For structs only allow 'deprecated', for classes exclude function qualifiers. - int notallowed = forstruct? ~ZCC_Deprecated : ZCC_Latent | ZCC_Final | ZCC_Action | ZCC_Static | ZCC_FuncConst | ZCC_Abstract; + int notallowed = forstruct? ~ZCC_Deprecated : ZCC_Latent | ZCC_Final | ZCC_Action | ZCC_Static | ZCC_FuncConst | ZCC_Abstract | ZCC_Virtual | ZCC_Override; if (field->Flags & notallowed) { @@ -1894,6 +1894,7 @@ void ZCCCompiler::InitDefaults() if (!c->Type()->IsDescendantOf(RUNTIME_CLASS(AActor))) { if (c->Defaults.Size()) Error(c->cls, "%s: Non-actor classes may not have defaults", c->Type()->TypeName.GetChars()); + if (c->Type()->ParentClass) c->Type()->ParentClass->DeriveData(c->Type()); } else { @@ -1989,6 +1990,16 @@ void ZCCCompiler::InitFunctions() for (auto c : Classes) { + // cannot be done earlier because it requires the parent class to be processed by this code, too. + if (c->Type()->ParentClass != nullptr) + { + if (c->Type()->ParentClass->Virtuals.Size() == 0) + { + // This a VMClass which didn't get processed here. + c->Type()->ParentClass->Virtuals = c->Type()->ParentClass->ParentClass->Virtuals; + } + c->Type()->Virtuals = c->Type()->ParentClass->Virtuals; + } for (auto f : c->Functions) { rets.Clear(); @@ -2030,13 +2041,32 @@ void ZCCCompiler::InitFunctions() if (f->Flags & ZCC_Private) varflags |= VARF_Private; if (f->Flags & ZCC_Protected) varflags |= VARF_Protected; if (f->Flags & ZCC_Deprecated) varflags |= VARF_Deprecated; + if (f->Flags & ZCC_Virtual) varflags |= VARF_Virtual; + if (f->Flags & ZCC_Override) varflags |= VARF_Override; if (f->Flags & ZCC_Action) varflags |= VARF_Action|VARF_Final, implicitargs = 3; // Action implies Final. if (f->Flags & ZCC_Static) varflags = (varflags & ~VARF_Method) | VARF_Final, implicitargs = 0; // Static implies Final. - if ((f->Flags & (ZCC_Action | ZCC_Static)) == (ZCC_Action | ZCC_Static)) + + if (varflags & VARF_Override) varflags &= ~VARF_Virtual; // allow 'virtual override'. + // Only one of these flags may be used. + static int exclude[] = { ZCC_Virtual, ZCC_Override, ZCC_Action, ZCC_Static }; + static const char * print[] = { "virtual", "override", "action", "static" }; + int fc = 0; + FString build; + for (int i = 0; i < 4; i++) { - Error(f, "%s: Action and Static on the same function is not allowed.", FName(f->Name).GetChars()); + if (f->Flags & exclude[i]) + { + fc++; + if (build.Len() > 0) build += ", "; + build += print[i]; + } + } + if (fc > 1) + { + Error(f, "Invalid combination of qualifiers %s on function %s.", FName(f->Name).GetChars(), build.GetChars() ); varflags |= VARF_Method; } + if (varflags & VARF_Override) varflags |= VARF_Virtual; // Now that the flags are checked, make all override functions virtual as well. if (f->Flags & ZCC_Native) { @@ -2193,7 +2223,41 @@ void ZCCCompiler::InitFunctions() { sym->Variants[0].Implementation->DefaultArgs = std::move(argdefaults); } - // todo: Check inheritance. + + if (varflags & VARF_Virtual) + { + if (varflags & VARF_Final) + { + sym->Variants[0].Implementation->Final = true; + } + int vindex = c->Type()->FindVirtualIndex(sym->SymbolName, sym->Variants[0].Proto); + // specifying 'override' is necessary to prevent one of the biggest problem spots with virtual inheritance: Mismatching argument types. + if (varflags & VARF_Override) + { + if (vindex == -1) + { + Error(p, "Attempt to override non-existent virtual function %s", FName(f->Name).GetChars()); + } + else + { + auto oldfunc = c->Type()->Virtuals[vindex]; + if (oldfunc->Final) + { + Error(p, "Attempt to override final function %s", FName(f->Name).GetChars()); + } + c->Type()->Virtuals[vindex] = sym->Variants[0].Implementation; + sym->Variants[0].Implementation->VirtualIndex = vindex; + } + } + else + { + if (vindex != -1) + { + Error(p, "Function %s attempts to override parent function without 'override' qualifier", FName(f->Name).GetChars()); + } + sym->Variants[0].Implementation->VirtualIndex = c->Type()->Virtuals.Push(sym->Variants[0].Implementation); + } + } } } } diff --git a/src/scripting/zscript/zcc_parser.cpp b/src/scripting/zscript/zcc_parser.cpp index c84b72c66..098a38d82 100644 --- a/src/scripting/zscript/zcc_parser.cpp +++ b/src/scripting/zscript/zcc_parser.cpp @@ -123,6 +123,8 @@ static void InitTokenMap() TOKENDEF (TK_Private, ZCC_PRIVATE); TOKENDEF (TK_Protected, ZCC_PROTECTED); TOKENDEF (TK_Latent, ZCC_LATENT); + TOKENDEF (TK_Virtual, ZCC_VIRTUAL); + TOKENDEF (TK_Override, ZCC_OVERRIDE); TOKENDEF (TK_Final, ZCC_FINAL); TOKENDEF (TK_Meta, ZCC_META); TOKENDEF (TK_Deprecated, ZCC_DEPRECATED); diff --git a/src/scripting/zscript/zcc_parser.h b/src/scripting/zscript/zcc_parser.h index 01ea0d2d0..ec8bd956a 100644 --- a/src/scripting/zscript/zcc_parser.h +++ b/src/scripting/zscript/zcc_parser.h @@ -33,6 +33,8 @@ enum ZCC_FuncConst = 1 << 10, ZCC_Abstract = 1 << 11, ZCC_Extension = 1 << 12, + ZCC_Virtual = 1 << 13, + ZCC_Override = 1 << 14, }; // Function parameter modifiers diff --git a/wadsrc/static/zscript/base.txt b/wadsrc/static/zscript/base.txt index 4564f0e86..33c1c6135 100644 --- a/wadsrc/static/zscript/base.txt +++ b/wadsrc/static/zscript/base.txt @@ -1,6 +1,6 @@ class Object native { - /*virtual*/ native void Destroy(); + virtual native void Destroy(); native class GetClass(); }