-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- LibTorch DLLs aren't copied to plugin's Binaries directory anymore - torchscript_wrapper.dll udpated to PyTorch Build 1.10.1 - copyrights updated
- Loading branch information
Showing
6 changed files
with
83 additions
and
63 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
// VR IK Body Plugin | ||
// (c) Yuri N Kalinin, 2021, [email protected]. All right reserved. | ||
// (c) Yuri N Kalinin, 2021-2022, [email protected]. All right reserved. | ||
|
||
#include "SimplePyTorch.h" | ||
#include "HAL/PlatformProcess.h" | ||
|
@@ -14,27 +14,32 @@ void FSimplePyTorchModule::StartupModule() | |
FString FilePath; | ||
const FString szBinaries = TEXT("Binaries"); | ||
const FString szPlatform = TEXT("Win64"); | ||
|
||
#if WITH_EDITOR | ||
auto ThisPlugin = IPluginManager::Get().FindPlugin(TEXT("SimplePyTorch")); | ||
if (ThisPlugin.IsValid()) | ||
{ | ||
FilePath = FPaths::ConvertRelativePathToFull(ThisPlugin->GetBaseDir()); | ||
|
||
FString PluginBinariesDir = FilePath / TEXT("Source/ThirdParty/pytorch") / szBinaries / szPlatform; | ||
UE_LOG(LogTemp, Log, TEXT("PyTorch third-party dlls directory: %s"), *PluginBinariesDir); | ||
FPlatformProcess::PushDllDirectory(*PluginBinariesDir); | ||
FilePath = FilePath / TEXT("Source/ThirdParty/pytorch") / szBinaries / szPlatform; | ||
} | ||
else | ||
{ | ||
FilePath = FPaths::ProjectDir() / TEXT("Binaries/ThirdParty/PyTorch"); | ||
} | ||
#else | ||
FilePath = FPaths::ConvertRelativePathToFull(FPaths::ProjectDir()); | ||
#endif | ||
FilePath = FilePath / szBinaries / szPlatform / TEXT("torchscript_wrapper.dll"); | ||
FilePath = FPaths::ProjectDir() / TEXT("Binaries/ThirdParty/PyTorch"); | ||
#endif | ||
FPlatformProcess::PushDllDirectory(*FilePath); | ||
FilePath = FilePath / TEXT("torchscript_wrapper.dll"); | ||
|
||
WrapperDllHandle = NULL; | ||
bDllLoaded = false; | ||
|
||
#if PLATFORM_WINDOWS | ||
if (FPaths::FileExists(FilePath)) | ||
{ | ||
UE_LOG(LogTemp, Log, TEXT("SimplePyTorchModule: Loading torch wrapper from %s"), *FilePath); | ||
|
||
WrapperDllHandle = FPlatformProcess::GetDllHandle(*FilePath); | ||
|
||
if (WrapperDllHandle != NULL) | ||
|
@@ -91,6 +96,14 @@ void FSimplePyTorchModule::ShutdownModule() | |
if (WrapperDllHandle != NULL) | ||
{ | ||
FPlatformProcess::FreeDllHandle(WrapperDllHandle); | ||
WrapperDllHandle = NULL; | ||
bDllLoaded = false; | ||
FuncTSW_LoadScriptModel = NULL; | ||
FuncTSW_CheckModel = NULL; | ||
FuncTSW_Forward1d = NULL; | ||
FuncTSW_ForwardTensor = NULL; | ||
FuncTSW_ForwardPass_Def = NULL; | ||
FuncTSW_Execute_Def = NULL; | ||
} | ||
} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
// VR IK Body Plugin | ||
// (c) Yuri N Kalinin, 2021, [email protected]. All right reserved. | ||
// (c) Yuri N Kalinin, 2021-2022, [email protected]. All right reserved. | ||
|
||
#include "SimpleTorchModule.h" | ||
#include "SimplePyTorch.h" | ||
|
@@ -32,7 +32,9 @@ void USimpleTorchModule::BeginDestroy() | |
|
||
USimpleTorchModule* USimpleTorchModule::CreateSimpleTorchModule(UObject* InParent) | ||
{ | ||
return NewObject<USimpleTorchModule>(InParent); | ||
return InParent | ||
? NewObject<USimpleTorchModule>(InParent) | ||
: NewObject<USimpleTorchModule>(); | ||
} | ||
|
||
bool USimpleTorchModule::LoadTorchScriptModel(FString FileName) | ||
|
@@ -71,9 +73,9 @@ bool USimpleTorchModule::IsTorchModelLoaded() const | |
FSimplePyTorchModule& Module = FModuleManager::GetModuleChecked<FSimplePyTorchModule>(TEXT("SimplePyTorch")); | ||
|
||
bool bResult = false; | ||
if (Module.bDllLoaded) | ||
if (Module.bDllLoaded && Module.FuncTSW_CheckModel) | ||
{ | ||
return Module.FuncTSW_CheckModel(ModelId); | ||
bResult = Module.FuncTSW_CheckModel(ModelId); | ||
} | ||
|
||
return bResult; | ||
|
@@ -91,7 +93,7 @@ bool USimpleTorchModule::ExecuteModelMethod(const FString& MethodName, const FSi | |
bool bResult = false; | ||
if (Module.bDllLoaded && Buffer != NULL && OutData.IsDataOwner()) | ||
{ | ||
TArray<int> InDims = InData.GetDimensions(); | ||
TArray<int> InDims = InData.GetDimensions().Array(); | ||
|
||
float* pOutData = OutData.IsValid() | ||
? OutData.GetRawData() | ||
|
@@ -100,21 +102,25 @@ bool USimpleTorchModule::ExecuteModelMethod(const FString& MethodName, const FSi | |
int OutDimsCount = 0; | ||
if (MethodName == TEXT("forward")) | ||
{ | ||
if (Module.FuncTSW_ForwardPass_Def == NULL) return false; | ||
|
||
Module.FuncTSW_ForwardPass_Def(ModelId, InData.GetRawData(), InDims.GetData(), InDims.Num(), | ||
pOutData, BufferDims, &OutDimsCount); | ||
} | ||
else | ||
{ | ||
if (Module.FuncTSW_Execute_Def == NULL) return false; | ||
|
||
Module.FuncTSW_Execute_Def(ModelId, TCHAR_TO_ANSI(*MethodName), InData.GetRawData(), InDims.GetData(), InDims.Num(), | ||
pOutData, BufferDims, &OutDimsCount); | ||
} | ||
|
||
bResult = (OutDimsCount > 0); | ||
bResult = (OutDimsCount > 0) && pOutData != NULL && BufferDims != NULL; | ||
if (bResult) | ||
{ | ||
TArray<int32> OldOutDims = OutData.GetDimensions(); | ||
TArray<int32> OldOutDims = OutData.GetDimensions().Array(); | ||
bool bOutTensorMatches = (OutDimsCount == OutData.GetDimensions().Num()); | ||
TArray<int32> NewOutDims; | ||
TSet<int32> NewOutDims; | ||
|
||
int32 Length = 1; | ||
for (int i = 0; i < OutDimsCount; i++) | ||
|
@@ -197,14 +203,14 @@ void FSimpleTorchTensor::Cleanup() | |
} | ||
} | ||
|
||
int32 FSimpleTorchTensor::GetAddress(TArray<int32> Address) const | ||
int32 FSimpleTorchTensor::GetAddress(TSet<int32> Address) const | ||
{ | ||
if (!Data) | ||
{ | ||
return INDEX_NONE; | ||
} | ||
|
||
TArray<int32> AddrArray = Address; | ||
TArray<int32> AddrArray = Address.Array(); | ||
int32 Addr = 0; | ||
for (int32 i = 0; i < AddressMultipliersCache.Num(); i++) | ||
{ | ||
|
@@ -220,7 +226,7 @@ int32 FSimpleTorchTensor::GetAddress(TArray<int32> Address) const | |
return Addr; | ||
} | ||
|
||
bool FSimpleTorchTensor::Create(TArray<int32> TensorDimensions) | ||
bool FSimpleTorchTensor::Create(TSet<int32> TensorDimensions) | ||
{ | ||
if (Data) | ||
{ | ||
|
@@ -235,7 +241,7 @@ bool FSimpleTorchTensor::Create(TArray<int32> TensorDimensions) | |
|
||
if (DataSize == 0) return false; | ||
|
||
Dimensions = TensorDimensions; | ||
Dimensions = TensorDimensions.Array(); | ||
InitAddressSpace(); | ||
|
||
Data = new float[DataSize]; | ||
|
@@ -244,7 +250,7 @@ bool FSimpleTorchTensor::Create(TArray<int32> TensorDimensions) | |
return true; | ||
} | ||
|
||
bool FSimpleTorchTensor::CreateAsChild(FSimpleTorchTensor* Parent, TArray<int32> Address) | ||
bool FSimpleTorchTensor::CreateAsChild(FSimpleTorchTensor* Parent, TSet<int32> Address) | ||
{ | ||
bDataOwner = false; | ||
|
||
|
@@ -298,9 +304,9 @@ bool FSimpleTorchTensor::CreateAsChild(FSimpleTorchTensor* Parent, TArray<int32> | |
return true; | ||
} | ||
|
||
TArray<int32> FSimpleTorchTensor::GetDimensions() const | ||
TSet<int32> FSimpleTorchTensor::GetDimensions() const | ||
{ | ||
TArray<int32> t; | ||
TSet<int32> t; | ||
for (const auto& d : Dimensions) | ||
t.Add(d); | ||
|
||
|
@@ -331,13 +337,13 @@ float* FSimpleTorchTensor::GetRawData(int32* Size) const | |
} | ||
} | ||
|
||
float* FSimpleTorchTensor::GetCell(TArray<int32> Address) | ||
float* FSimpleTorchTensor::GetCell(TSet<int32> Address) | ||
{ | ||
int32 Addr = GetAddress(Address); | ||
return Addr == INDEX_NONE ? NULL : &Data[Addr]; | ||
} | ||
|
||
float FSimpleTorchTensor::GetValue(TArray<int32> Address) const | ||
float FSimpleTorchTensor::GetValue(TSet<int32> Address) const | ||
{ | ||
int32 Addr = GetAddress(Address); | ||
return Addr == INDEX_NONE ? 0 : Data[Addr]; | ||
|
@@ -373,7 +379,7 @@ bool FSimpleTorchTensor::FromArray(const TArray<float>& InData) | |
return false; | ||
} | ||
|
||
bool FSimpleTorchTensor::Reshape(TArray<int32> NewShape) | ||
bool FSimpleTorchTensor::Reshape(TSet<int32> NewShape) | ||
{ | ||
if (NewShape.Num() == 0) | ||
{ | ||
|
@@ -392,7 +398,7 @@ bool FSimpleTorchTensor::Reshape(TArray<int32> NewShape) | |
|
||
if (NewDataSize == DataSize) | ||
{ | ||
Dimensions = NewShape; | ||
Dimensions = NewShape.Array(); | ||
InitAddressSpace(); | ||
} | ||
else | ||
|
@@ -402,7 +408,7 @@ bool FSimpleTorchTensor::Reshape(TArray<int32> NewShape) | |
Cleanup(); | ||
|
||
DataSize = NewDataSize; | ||
Dimensions = NewShape; | ||
Dimensions = NewShape.Array(); | ||
Data = new float[DataSize]; | ||
|
||
InitAddressSpace(); | ||
|
@@ -424,7 +430,7 @@ FSimpleTorchTensor FSimpleTorchTensor::Detach() | |
return FSimpleTorchTensor(); | ||
} | ||
|
||
TArray<int32> Dims; | ||
TSet<int32> Dims; | ||
for (const auto& Val : Dimensions) Dims.Add(Val); | ||
|
||
FSimpleTorchTensor ret = FSimpleTorchTensor(Dims); | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
// VR IK Body Plugin | ||
// (c) Yuri N Kalinin, 2021, [email protected]. All right reserved. | ||
// (c) Yuri N Kalinin, 2021-2022, [email protected]. All right reserved. | ||
|
||
#pragma once | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
// VR IK Body Plugin | ||
// (c) Yuri N Kalinin, 2021, [email protected]. All right reserved. | ||
// (c) Yuri N Kalinin, 2021-2022, [email protected]. All right reserved. | ||
|
||
#pragma once | ||
|
||
|
@@ -53,7 +53,7 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor | |
void InitAddressSpace(); | ||
|
||
/** Get flat address in Data from multidimensional address */ | ||
int32 GetAddress(TArray<int32> Address) const; | ||
int32 GetAddress(TSet<int32> Address) const; | ||
public: | ||
|
||
FSimpleTorchTensor() | ||
|
@@ -62,15 +62,15 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor | |
, DataSize(0) | ||
, ParentTensor(nullptr) | ||
{} | ||
FSimpleTorchTensor(FSimpleTorchTensor* Parent, TArray<int32> SubAddress) | ||
FSimpleTorchTensor(FSimpleTorchTensor* Parent, TSet<int32> SubAddress) | ||
: Data(NULL) | ||
, bDataOwner(true) | ||
, DataSize(0) | ||
, ParentTensor(nullptr) | ||
{ | ||
CreateAsChild(Parent, SubAddress); | ||
} | ||
FSimpleTorchTensor(TArray<int32> Dimensions) | ||
FSimpleTorchTensor(TSet<int32> Dimensions) | ||
: Data(NULL) | ||
, bDataOwner(true) | ||
, DataSize(0) | ||
|
@@ -87,10 +87,10 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor | |
void Cleanup(); | ||
|
||
// Set dimensions and allocate memory | ||
bool Create(TArray<int32> TensorDimensions); | ||
bool Create(TSet<int32> TensorDimensions); | ||
|
||
// Create tensor as a subtensor in another tensor (share the same memory) | ||
bool CreateAsChild(FSimpleTorchTensor* Parent, TArray<int32> Address); | ||
bool CreateAsChild(FSimpleTorchTensor* Parent, TSet<int32> Address); | ||
|
||
// Is tensor initialized? | ||
bool IsValid() const { return Data != NULL; } | ||
|
@@ -99,7 +99,7 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor | |
bool IsDataOwner() const { return bDataOwner; } | ||
|
||
// Get current tensor dimensions | ||
TArray<int32> GetDimensions() const; | ||
TSet<int32> GetDimensions() const; | ||
|
||
// Get number of itemes in flat array | ||
int32 GetDataSize() const { return DataSize; } | ||
|
@@ -112,11 +112,11 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor | |
float* GetRawData() const { return Data; } | ||
|
||
// Convert multidimensional address to flat address | ||
int32 GetRawAddress(TArray<int32> Address) const { return GetAddress(Address); } | ||
int32 GetRawAddress(TSet<int32> Address) const { return GetAddress(Address); } | ||
|
||
// Get reference to single float value with address | ||
float* GetCell(TArray<int32> Address); | ||
float GetValue(TArray<int32> Address) const; | ||
float* GetCell(TSet<int32> Address); | ||
float GetValue(TSet<int32> Address) const; | ||
|
||
/* Create float array. Only works for tensor with one dimension. | ||
* Ex: auto p = FSimpleTorchTensor({ 4, 12 }); | ||
|
@@ -139,7 +139,7 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor | |
|
||
// Change dimensions. | ||
// Only keeps data if new overall size is equal to old sizse | ||
bool Reshape(TArray<int32> NewShape); | ||
bool Reshape(TSet<int32> NewShape); | ||
|
||
// Create copy of this tensor | ||
FSimpleTorchTensor Detach(); | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
// VR IK Body Plugin | ||
// (c) Yuri N Kalinin, 2021, [email protected]. All right reserved. | ||
// (c) Yuri N Kalinin, 2021-2022, [email protected]. All right reserved. | ||
|
||
using UnrealBuildTool; | ||
using System.IO; | ||
|
@@ -60,27 +60,28 @@ public SimplePyTorch(ReadOnlyTargetRules Target) : base(Target) | |
|
||
if (Target.Platform == UnrealTargetPlatform.Win64) | ||
{ | ||
// my PyTorch wrapper | ||
RuntimeDependencies.Add("$(BinaryOutputDir)/torchscript_wrapper.dll", Path.Combine(TorchBinariesPath, "torchscript_wrapper.dll")); | ||
if (!Target.bBuildEditor) | ||
// LibTorch libraries | ||
string[] DLLs = new string[] | ||
{ | ||
// Copy DLLs to target packaged project | ||
RuntimeDependencies.Add("$(BinaryOutputDir)/asmjit.dll", Path.Combine(TorchBinariesPath, "asmjit.dll")); | ||
RuntimeDependencies.Add("$(BinaryOutputDir)/c10.dll", Path.Combine(TorchBinariesPath, "c10.dll")); | ||
RuntimeDependencies.Add("$(BinaryOutputDir)/caffe2_detectron_ops.dll", Path.Combine(TorchBinariesPath, "caffe2_detectron_ops.dll")); | ||
RuntimeDependencies.Add("$(BinaryOutputDir)/caffe2_module_test_dynamic.dll", Path.Combine(TorchBinariesPath, "caffe2_module_test_dynamic.dll")); | ||
RuntimeDependencies.Add("$(BinaryOutputDir)/fbgemm.dll", Path.Combine(TorchBinariesPath, "fbgemm.dll")); | ||
RuntimeDependencies.Add("$(BinaryOutputDir)/fbjni.dll", Path.Combine(TorchBinariesPath, "fbjni.dll")); | ||
RuntimeDependencies.Add("$(BinaryOutputDir)/libiomp5md.dll", Path.Combine(TorchBinariesPath, "libiomp5md.dll")); | ||
RuntimeDependencies.Add("$(BinaryOutputDir)/libiompstubs5md.dll", Path.Combine(TorchBinariesPath, "libiompstubs5md.dll")); | ||
RuntimeDependencies.Add("$(BinaryOutputDir)/pytorch_jni.dll", Path.Combine(TorchBinariesPath, "pytorch_jni.dll")); | ||
RuntimeDependencies.Add("$(BinaryOutputDir)/torch.dll", Path.Combine(TorchBinariesPath, "torch.dll")); | ||
RuntimeDependencies.Add("$(BinaryOutputDir)/torch_cpu.dll", Path.Combine(TorchBinariesPath, "torch_cpu.dll")); | ||
RuntimeDependencies.Add("$(BinaryOutputDir)/torch_global_deps.dll", Path.Combine(TorchBinariesPath, "torch_global_deps.dll")); | ||
RuntimeDependencies.Add("$(BinaryOutputDir)/uv.dll", Path.Combine(TorchBinariesPath, "uv.dll")); | ||
"asmjit.dll", "c10.dll", "caffe2_detectron_ops.dll", "caffe2_module_test_dynamic.dll", "fbgemm.dll", "fbjni.dll", "libiomp5md.dll", | ||
"libiompstubs5md.dll", "pytorch_jni.dll", "torch.dll", "torch_cpu.dll", "torch_global_deps.dll", "uv.dll" | ||
}; | ||
|
||
// copy all DLLs to the packaged build | ||
if (!Target.bBuildEditor && Target.Type == TargetType.Game) | ||
{ | ||
string DllTargetDir = "$(ProjectDir)/Binaries/ThirdParty/PyTorch/"; | ||
foreach (string DllName in DLLs) | ||
{ | ||
PublicDelayLoadDLLs.Add(DllName); | ||
RuntimeDependencies.Add(Path.Combine(DllTargetDir, DllName), Path.Combine(TorchBinariesPath, DllName)); | ||
} | ||
|
||
// my PyTorch wrapper is loaded dynamically | ||
RuntimeDependencies.Add(Path.Combine(DllTargetDir, "torchscript_wrapper.dll"), Path.Combine(TorchBinariesPath, "torchscript_wrapper.dll")); | ||
// licenses | ||
RuntimeDependencies.Add("$(BinaryOutputDir)/LICENSE.txt", Path.Combine(TorchPath, "LICENSE.txt")); | ||
RuntimeDependencies.Add("$(BinaryOutputDir)/NOTICE.txt", Path.Combine(TorchPath, "NOTICE.txt")); | ||
RuntimeDependencies.Add(Path.Combine(DllTargetDir, "LICENSE.txt"), Path.Combine(TorchPath, "LICENSE.txt"), StagedFileType.NonUFS); | ||
RuntimeDependencies.Add(Path.Combine(DllTargetDir, "NOTICE.txt"), Path.Combine(TorchPath, "NOTICE.txt"), StagedFileType.NonUFS); | ||
} | ||
} | ||
} | ||
|
Binary file modified
BIN
+23.5 KB
(130%)
Source/ThirdParty/pytorch/Binaries/Win64/torchscript_wrapper.dll
Binary file not shown.