57 lines
2.3 KiB
Diff
57 lines
2.3 KiB
Diff
|
|
diff --git a/tensilelite/src/msgpack/MessagePack.cpp b/tensilelite/src/msgpack/MessagePack.cpp
|
||
|
|
index de97929c..dbc397e0 100644
|
||
|
|
--- a/tensilelite/src/msgpack/MessagePack.cpp
|
||
|
|
+++ b/tensilelite/src/msgpack/MessagePack.cpp
|
||
|
|
@@ -28,6 +28,8 @@
|
||
|
|
|
||
|
|
#include <Tensile/msgpack/Loading.hpp>
|
||
|
|
|
||
|
|
+#include <zstd.h>
|
||
|
|
+
|
||
|
|
#include <fstream>
|
||
|
|
|
||
|
|
namespace Tensile
|
||
|
|
@@ -86,6 +88,34 @@ namespace Tensile
|
||
|
|
return nullptr;
|
||
|
|
}
|
||
|
|
|
||
|
|
+ // Check if the file is zstd compressed
|
||
|
|
+ char magic[4];
|
||
|
|
+ in.read(magic, 4);
|
||
|
|
+ bool isCompressed = (in.gcount() == 4 && magic[0] == '\x28' && magic[1] == '\xB5' && magic[2] == '\x2F' && magic[3] == '\xFD');
|
||
|
|
+ // Reset file pointer to the beginning
|
||
|
|
+ in.seekg(0, std::ios::beg);
|
||
|
|
+
|
||
|
|
+ if (isCompressed) {
|
||
|
|
+ // Decompress zstd file
|
||
|
|
+ std::vector<char> compressedData((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());
|
||
|
|
+
|
||
|
|
+ size_t decompressedSize = ZSTD_getFrameContentSize(compressedData.data(), compressedData.size());
|
||
|
|
+ if (decompressedSize == ZSTD_CONTENTSIZE_ERROR || decompressedSize == ZSTD_CONTENTSIZE_UNKNOWN) {
|
||
|
|
+ if(Debug::Instance().printDataInit())
|
||
|
|
+ std::cout << "Error: Unable to determine decompressed size for " << filename << std::endl;
|
||
|
|
+ return false;
|
||
|
|
+ }
|
||
|
|
+
|
||
|
|
+ std::vector<char> decompressedData(decompressedSize);
|
||
|
|
+ size_t dSize = ZSTD_decompress(decompressedData.data(), decompressedSize, compressedData.data(), compressedData.size());
|
||
|
|
+ if (ZSTD_isError(dSize)) {
|
||
|
|
+ if(Debug::Instance().printDataInit())
|
||
|
|
+ std::cout << "Error: ZSTD decompression failed for " << filename << std::endl;
|
||
|
|
+ return false;
|
||
|
|
+ }
|
||
|
|
+
|
||
|
|
+ msgpack::unpack(result, decompressedData.data(), dSize);
|
||
|
|
+ } else {
|
||
|
|
msgpack::unpacker unp;
|
||
|
|
bool finished_parsing;
|
||
|
|
constexpr size_t buffer_size = 1 << 19;
|
||
|
|
@@ -109,6 +139,7 @@ namespace Tensile
|
||
|
|
|
||
|
|
return nullptr;
|
||
|
|
}
|
||
|
|
+ }
|
||
|
|
}
|
||
|
|
catch(std::runtime_error const& exc)
|
||
|
|
{
|