hle: Improve safety (#2778)

* timezone: Make timezone implementation safe

* hle: Do not use TrimEnd to parse ASCII strings

This adds an util that handle reading an ASCII string in a safe way.
Previously it was possible to read malformed data that could cause
various undefined behaviours in multiple services.

* hid: Remove an useless unsafe modifier on keyboard update

* Address gdkchan's comment

* Address gdkchan's comment
This commit is contained in:
Mary 2021-10-25 00:13:20 +02:00 committed by GitHub
parent b4dc33efc2
commit 51fa1b2cb0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 141 additions and 172 deletions

View file

@ -2,6 +2,7 @@
using Ryujinx.Common.Utilities;
using Ryujinx.HLE.Utilities;
using System;
using System.Buffers.Binary;
using System.IO;
using System.Runtime.InteropServices;
using System.Text;
@ -107,40 +108,24 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
public int TransitionTime;
}
private static int Detzcode32(byte[] bytes)
private static int Detzcode32(ReadOnlySpan<byte> bytes)
{
if (BitConverter.IsLittleEndian)
{
Array.Reverse(bytes, 0, bytes.Length);
}
return BitConverter.ToInt32(bytes, 0);
return BinaryPrimitives.ReadInt32BigEndian(bytes);
}
private static unsafe int Detzcode32(int* data)
private static int Detzcode32(int value)
{
int result = *data;
if (BitConverter.IsLittleEndian)
{
byte[] bytes = BitConverter.GetBytes(result);
Array.Reverse(bytes, 0, bytes.Length);
result = BitConverter.ToInt32(bytes, 0);
return BinaryPrimitives.ReverseEndianness(value);
}
return result;
return value;
}
private static unsafe long Detzcode64(long* data)
private static long Detzcode64(ReadOnlySpan<byte> bytes)
{
long result = *data;
if (BitConverter.IsLittleEndian)
{
byte[] bytes = BitConverter.GetBytes(result);
Array.Reverse(bytes, 0, bytes.Length);
result = BitConverter.ToInt64(bytes, 0);
}
return result;
return BinaryPrimitives.ReadInt64BigEndian(bytes);
}
private static bool DifferByRepeat(long t1, long t0)
@ -148,7 +133,7 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
return (t1 - t0) == SecondsPerRepeat;
}
private static unsafe bool TimeTypeEquals(TimeZoneRule outRules, byte aIndex, byte bIndex)
private static bool TimeTypeEquals(TimeZoneRule outRules, byte aIndex, byte bIndex)
{
if (aIndex < 0 || aIndex >= outRules.TypeCount || bIndex < 0 || bIndex >= outRules.TypeCount)
{
@ -158,17 +143,14 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
TimeTypeInfo a = outRules.Ttis[aIndex];
TimeTypeInfo b = outRules.Ttis[bIndex];
fixed (char* chars = outRules.Chars)
{
return a.GmtOffset == b.GmtOffset &&
a.IsDaySavingTime == b.IsDaySavingTime &&
a.IsStandardTimeDaylight == b.IsStandardTimeDaylight &&
a.IsGMT == b.IsGMT &&
StringUtils.CompareCStr(chars + a.AbbreviationListIndex, chars + b.AbbreviationListIndex) == 0;
}
return a.GmtOffset == b.GmtOffset &&
a.IsDaySavingTime == b.IsDaySavingTime &&
a.IsStandardTimeDaylight == b.IsStandardTimeDaylight &&
a.IsGMT == b.IsGMT &&
StringUtils.CompareCStr(outRules.Chars[a.AbbreviationListIndex..], outRules.Chars[b.AbbreviationListIndex..]) == 0;
}
private static int GetQZName(char[] name, int namePosition, char delimiter)
private static int GetQZName(ReadOnlySpan<char> name, int namePosition, char delimiter)
{
int i = namePosition;
@ -403,7 +385,7 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
return 0;
}
private static bool ParsePosixName(Span<char> name, out TimeZoneRule outRules, bool lastDitch)
private static bool ParsePosixName(ReadOnlySpan<char> name, out TimeZoneRule outRules, bool lastDitch)
{
outRules = new TimeZoneRule
{
@ -414,9 +396,10 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
};
int stdLen;
Span<char> stdName = name;
int namePosition = 0;
int stdOffset = 0;
ReadOnlySpan<char> stdName = name;
int namePosition = 0;
int stdOffset = 0;
if (lastDitch)
{
@ -433,7 +416,8 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
int stdNamePosition = namePosition;
namePosition = GetQZName(name.ToArray(), namePosition, '>');
namePosition = GetQZName(name, namePosition, '>');
if (name[namePosition] != '>')
{
return false;
@ -465,7 +449,7 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
int destLen = 0;
int dstOffset = 0;
Span<char> destName = name.Slice(namePosition);
ReadOnlySpan<char> destName = name.Slice(namePosition);
if (TzCharsArraySize < charCount)
{
@ -903,7 +887,7 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
return ParsePosixName(name.ToCharArray(), out outRules, false);
}
internal static unsafe bool ParseTimeZoneBinary(out TimeZoneRule outRules, Stream inputData)
internal static bool ParseTimeZoneBinary(out TimeZoneRule outRules, Stream inputData)
{
outRules = new TimeZoneRule
{
@ -967,12 +951,11 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
timeCount = 0;
fixed (byte* workBufferPtrStart = workBuffer)
{
byte* p = workBufferPtrStart;
Span<byte> p = workBuffer;
for (int i = 0; i < outRules.TimeCount; i++)
{
long at = Detzcode64((long*)p);
long at = Detzcode64(p);
outRules.Types[i] = 1;
if (timeCount != 0 && at <= outRules.Ats[timeCount - 1])
@ -988,13 +971,15 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
outRules.Ats[timeCount++] = at;
p += TimeTypeSize;
p = p[TimeTypeSize..];
}
timeCount = 0;
for (int i = 0; i < outRules.TimeCount; i++)
{
byte type = *p++;
byte type = p[0];
p = p[1..];
if (outRules.TypeCount <= type)
{
return false;
@ -1011,18 +996,20 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
for (int i = 0; i < outRules.TypeCount; i++)
{
TimeTypeInfo ttis = outRules.Ttis[i];
ttis.GmtOffset = Detzcode32((int*)p);
p += 4;
ttis.GmtOffset = Detzcode32(p);
p = p[sizeof(int)..];
if (*p >= 2)
if (p[0] >= 2)
{
return false;
}
ttis.IsDaySavingTime = *p != 0;
p++;
ttis.IsDaySavingTime = p[0] != 0;
p = p[1..];
int abbreviationListIndex = p[0];
p = p[1..];
int abbreviationListIndex = *p++;
if (abbreviationListIndex >= outRules.CharCount)
{
return false;
@ -1033,12 +1020,9 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
outRules.Ttis[i] = ttis;
}
fixed (char* chars = outRules.Chars)
{
Encoding.ASCII.GetChars(p, outRules.CharCount, chars, outRules.CharCount);
}
Encoding.ASCII.GetChars(p[..outRules.CharCount].ToArray()).CopyTo(outRules.Chars.AsSpan());
p += outRules.CharCount;
p = p[outRules.CharCount..];
outRules.Chars[outRules.CharCount] = '\0';
for (int i = 0; i < outRules.TypeCount; i++)
@ -1049,14 +1033,14 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
}
else
{
if (*p >= 2)
if (p[0] >= 2)
{
return false;
}
outRules.Ttis[i].IsStandardTimeDaylight = *p++ != 0;
outRules.Ttis[i].IsStandardTimeDaylight = p[0] != 0;
p = p[1..];
}
}
for (int i = 0; i < outRules.TypeCount; i++)
@ -1067,17 +1051,18 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
}
else
{
if (*p >= 2)
if (p[0] >= 2)
{
return false;
}
outRules.Ttis[i].IsGMT = *p++ != 0;
outRules.Ttis[i].IsGMT = p[0] != 0;
p = p[1..];
}
}
long position = (p - workBufferPtrStart);
long position = (workBuffer.Length - p.Length);
long nRead = streamLength - position;
if (nRead < 0)
@ -1107,77 +1092,75 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
int abbreviationCount = 0;
charCount = outRules.CharCount;
fixed (char* chars = outRules.Chars)
Span<char> chars = outRules.Chars;
for (int i = 0; i < tempRules.TypeCount; i++)
{
for (int i = 0; i < tempRules.TypeCount; i++)
ReadOnlySpan<char> tempChars = tempRules.Chars;
ReadOnlySpan<char> tempAbbreviation = tempChars[tempRules.Ttis[i].AbbreviationListIndex..];
int j;
for (j = 0; j < charCount; j++)
{
fixed (char* tempChars = tempRules.Chars)
if (StringUtils.CompareCStr(chars[j..], tempAbbreviation) == 0)
{
char* tempAbbreviation = tempChars + tempRules.Ttis[i].AbbreviationListIndex;
int j;
for (j = 0; j < charCount; j++)
{
if (StringUtils.CompareCStr(chars + j, tempAbbreviation) == 0)
{
tempRules.Ttis[i].AbbreviationListIndex = j;
abbreviationCount++;
break;
}
}
if (j >= charCount)
{
int abbreviationLength = StringUtils.LengthCstr(tempAbbreviation);
if (j + abbreviationLength < TzMaxChars)
{
for (int x = 0; x < abbreviationLength; x++)
{
chars[j + x] = tempAbbreviation[x];
}
charCount = j + abbreviationLength + 1;
tempRules.Ttis[i].AbbreviationListIndex = j;
abbreviationCount++;
}
}
tempRules.Ttis[i].AbbreviationListIndex = j;
abbreviationCount++;
break;
}
}
if (abbreviationCount == tempRules.TypeCount)
if (j >= charCount)
{
outRules.CharCount = charCount;
// Remove trailing
while (1 < outRules.TimeCount && (outRules.Types[outRules.TimeCount - 1] == outRules.Types[outRules.TimeCount - 2]))
int abbreviationLength = StringUtils.LengthCstr(tempAbbreviation);
if (j + abbreviationLength < TzMaxChars)
{
outRules.TimeCount--;
}
int i;
for (i = 0; i < tempRules.TimeCount; i++)
{
if (outRules.TimeCount == 0 || outRules.Ats[outRules.TimeCount - 1] < tempRules.Ats[i])
for (int x = 0; x < abbreviationLength; x++)
{
break;
chars[j + x] = tempAbbreviation[x];
}
}
while (i < tempRules.TimeCount && outRules.TimeCount < TzMaxTimes)
charCount = j + abbreviationLength + 1;
tempRules.Ttis[i].AbbreviationListIndex = j;
abbreviationCount++;
}
}
}
if (abbreviationCount == tempRules.TypeCount)
{
outRules.CharCount = charCount;
// Remove trailing
while (1 < outRules.TimeCount && (outRules.Types[outRules.TimeCount - 1] == outRules.Types[outRules.TimeCount - 2]))
{
outRules.TimeCount--;
}
int i;
for (i = 0; i < tempRules.TimeCount; i++)
{
if (outRules.TimeCount == 0 || outRules.Ats[outRules.TimeCount - 1] < tempRules.Ats[i])
{
outRules.Ats[outRules.TimeCount] = tempRules.Ats[i];
outRules.Types[outRules.TimeCount] = (byte)(outRules.TypeCount + (byte)tempRules.Types[i]);
outRules.TimeCount++;
i++;
break;
}
}
for (i = 0; i < tempRules.TypeCount; i++)
{
outRules.Ttis[outRules.TypeCount++] = tempRules.Ttis[i];
}
while (i < tempRules.TimeCount && outRules.TimeCount < TzMaxTimes)
{
outRules.Ats[outRules.TimeCount] = tempRules.Ats[i];
outRules.Types[outRules.TimeCount] = (byte)(outRules.TypeCount + (byte)tempRules.Types[i]);
outRules.TimeCount++;
i++;
}
for (i = 0; i < tempRules.TypeCount; i++)
{
outRules.Ttis[outRules.TypeCount++] = tempRules.Ttis[i];
}
}
}
@ -1467,17 +1450,11 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
{
calendarAdditionalInfo.IsDaySavingTime = rules.Ttis[ttiIndex].IsDaySavingTime;
unsafe
{
fixed (char* timeZoneAbbreviation = &rules.Chars[rules.Ttis[ttiIndex].AbbreviationListIndex])
{
int timeZoneSize = Math.Min(StringUtils.LengthCstr(timeZoneAbbreviation), 8);
for (int i = 0; i < timeZoneSize; i++)
{
calendarAdditionalInfo.TimezoneName[i] = timeZoneAbbreviation[i];
}
}
}
ReadOnlySpan<char> timeZoneAbbreviation = rules.Chars.AsSpan()[rules.Ttis[ttiIndex].AbbreviationListIndex..];
int timeZoneSize = Math.Min(StringUtils.LengthCstr(timeZoneAbbreviation), 8);
timeZoneAbbreviation[..timeZoneSize].CopyTo(calendarAdditionalInfo.TimezoneName.AsSpan());
}
return result;