Cleanup and add guards against malformed MIDS files

This commit is contained in:
Cacodemon345 2020-09-14 15:38:26 +06:00
parent 5bfa76c948
commit bff02053be
2 changed files with 26 additions and 20 deletions

View file

@ -240,7 +240,7 @@ protected:
uint32_t *MakeEvents(uint32_t *events, uint32_t *max_events_p, uint32_t max_time) override; uint32_t *MakeEvents(uint32_t *events, uint32_t *max_events_p, uint32_t max_time) override;
private: private:
std::vector<uint32_t> midiBuffer; std::vector<uint32_t> MidsBuffer;
size_t MidsP, MaxMidsP; size_t MidsP, MaxMidsP;
int FormatFlags; int FormatFlags;

View file

@ -1,6 +1,6 @@
/* /*
** midisource_mids.cpp ** midisource_mids.cpp
** Code to let ZDoom play MIDS MIDI music through the MIDI streaming API. ** Code to let ZMusic play MIDS MIDI music through the MIDI streaming API.
** **
**--------------------------------------------------------------------------- **---------------------------------------------------------------------------
** Copyright 2020 Cacodemon345 ** Copyright 2020 Cacodemon345
@ -66,11 +66,19 @@
MIDSSong::MIDSSong(const uint8_t* data, size_t len) MIDSSong::MIDSSong(const uint8_t* data, size_t len)
{ {
if (len <= 52)
return;
if ((len % 4) != 0)
return;
// Validate the header first. // Validate the header first.
if (data[12] != 'f' || data[13] != 'm' || data[14] != 't' || data[15] != ' ') if (data[12] != 'f' || data[13] != 'm' || data[14] != 't' || data[15] != ' ')
{ {
return; return;
} }
int headerSize = LittleLong(GetInt(&data[16]));
if (headerSize != 12) return;
Division = LittleLong(GetInt(&data[20])); Division = LittleLong(GetInt(&data[20]));
FormatFlags = LittleLong(GetInt(&data[28])); FormatFlags = LittleLong(GetInt(&data[28]));
// Validate the data chunk. // Validate the data chunk.
@ -79,22 +87,21 @@ MIDSSong::MIDSSong(const uint8_t* data, size_t len)
return; return;
} }
int NumBlocks = LittleLong(GetInt(&data[40])); int NumBlocks = LittleLong(GetInt(&data[40]));
const uint8_t* midiData = &data[44]; const uint32_t* midiData = (const uint32_t*)&data[44];
uint32_t tkStart = 0; uint32_t tkStart = 0;
uint32_t cbBuffer = 0; uint32_t cbBuffer = 0;
while (NumBlocks-- > 0) while (NumBlocks-- > 0)
{ {
tkStart = LittleLong(GetInt(midiData)); tkStart = LittleLong(*midiData);
cbBuffer = LittleLong(GetInt(midiData + 4)); cbBuffer = LittleLong(*(midiData + 1));
std::vector<uint32_t> midiMessageData; midiData += 2;
midiMessageData.resize(cbBuffer / 4); if ((cbBuffer % (FormatFlags ? 8 : 12)) != 0) return;
memcpy(midiMessageData.data(), midiData + 8, cbBuffer); MidsBuffer.insert(MidsBuffer.end(), midiData, midiData + (cbBuffer / 4));
midiBuffer.insert(midiBuffer.end(), midiMessageData.begin(), midiMessageData.end()); midiData += cbBuffer / 4;
midiData += 8 + cbBuffer;
} }
MidsP = 0; MidsP = 0;
MaxMidsP = midiBuffer.size() - 1; MaxMidsP = MidsBuffer.size() - 1;
for (auto& curMidiData : midiBuffer) for (auto& curMidiData : MidsBuffer)
{ {
curMidiData = LittleLong(curMidiData); curMidiData = LittleLong(curMidiData);
} }
@ -151,9 +158,9 @@ void MIDSSong::DoRestart()
void MIDSSong::ProcessInitialTempoEvents() void MIDSSong::ProcessInitialTempoEvents()
{ {
if (MEVENT_EVENTTYPE(midiBuffer[FormatFlags ? 1 : 2]) == MEVENT_TEMPO) if (MEVENT_EVENTTYPE(MidsBuffer[FormatFlags ? 1 : 2]) == MEVENT_TEMPO)
{ {
SetTempo(MEVENT_EVENTPARM(midiBuffer[FormatFlags ? 1 : 2])); SetTempo(MEVENT_EVENTPARM(MidsBuffer[FormatFlags ? 1 : 2]));
} }
} }
@ -174,13 +181,12 @@ uint32_t* MIDSSong::MakeEvents(uint32_t* events, uint32_t* max_event_p, uint32_t
max_time = max_time * Division / Tempo; max_time = max_time * Division / Tempo;
while (events < max_event_p && tot_time <= max_time) while (events < max_event_p && tot_time <= max_time)
{ {
events[0] = time = midiBuffer[MidsP++]; events[0] = time = MidsBuffer[MidsP++];
events[1] = FormatFlags ? 0 : midiBuffer[MidsP++]; events[1] = FormatFlags ? 0 : MidsBuffer[MidsP++];
events[2] = midiBuffer[MidsP++]; events[2] = MidsBuffer[MidsP++];
events += 3; events += 3;
tot_time += time; tot_time += time;
if (MidsP >= MaxMidsP) goto end; if (MidsP >= MaxMidsP) break;
} }
end:
return events; return events;
} }