Skip to content

Commit

Permalink
Third-party dlls management
Browse files Browse the repository at this point in the history
- LibTorch DLLs aren't copied to plugin's Binaries directory anymore
- torchscript_wrapper.dll udpated to PyTorch Build 1.10.1
- copyrights updated
  • Loading branch information
AntiAnti authored Jan 4, 2022
1 parent db0036c commit e5bde76
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 63 deletions.
33 changes: 23 additions & 10 deletions Source/SimplePyTorch/Private/SimplePyTorch.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// VR IK Body Plugin
// (c) Yuri N Kalinin, 2021, [email protected]. All right reserved.
// (c) Yuri N Kalinin, 2021-2022, [email protected]. All right reserved.

#include "SimplePyTorch.h"
#include "HAL/PlatformProcess.h"
Expand All @@ -14,27 +14,32 @@ void FSimplePyTorchModule::StartupModule()
FString FilePath;
const FString szBinaries = TEXT("Binaries");
const FString szPlatform = TEXT("Win64");

#if WITH_EDITOR
auto ThisPlugin = IPluginManager::Get().FindPlugin(TEXT("SimplePyTorch"));
if (ThisPlugin.IsValid())
{
FilePath = FPaths::ConvertRelativePathToFull(ThisPlugin->GetBaseDir());

FString PluginBinariesDir = FilePath / TEXT("Source/ThirdParty/pytorch") / szBinaries / szPlatform;
UE_LOG(LogTemp, Log, TEXT("PyTorch third-party dlls directory: %s"), *PluginBinariesDir);
FPlatformProcess::PushDllDirectory(*PluginBinariesDir);
FilePath = FilePath / TEXT("Source/ThirdParty/pytorch") / szBinaries / szPlatform;
}
else
{
FilePath = FPaths::ProjectDir() / TEXT("Binaries/ThirdParty/PyTorch");
}
#else
FilePath = FPaths::ConvertRelativePathToFull(FPaths::ProjectDir());
#endif
FilePath = FilePath / szBinaries / szPlatform / TEXT("torchscript_wrapper.dll");
FilePath = FPaths::ProjectDir() / TEXT("Binaries/ThirdParty/PyTorch");
#endif
FPlatformProcess::PushDllDirectory(*FilePath);
FilePath = FilePath / TEXT("torchscript_wrapper.dll");

WrapperDllHandle = NULL;
bDllLoaded = false;

#if PLATFORM_WINDOWS
if (FPaths::FileExists(FilePath))
{
UE_LOG(LogTemp, Log, TEXT("SimplePyTorchModule: Loading torch wrapper from %s"), *FilePath);

WrapperDllHandle = FPlatformProcess::GetDllHandle(*FilePath);

if (WrapperDllHandle != NULL)
Expand Down Expand Up @@ -91,6 +96,14 @@ void FSimplePyTorchModule::ShutdownModule()
if (WrapperDllHandle != NULL)
{
FPlatformProcess::FreeDllHandle(WrapperDllHandle);
WrapperDllHandle = NULL;
bDllLoaded = false;
FuncTSW_LoadScriptModel = NULL;
FuncTSW_CheckModel = NULL;
FuncTSW_Forward1d = NULL;
FuncTSW_ForwardTensor = NULL;
FuncTSW_ForwardPass_Def = NULL;
FuncTSW_Execute_Def = NULL;
}
}

Expand Down
48 changes: 27 additions & 21 deletions Source/SimplePyTorch/Private/SimpleTorchModule.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// VR IK Body Plugin
// (c) Yuri N Kalinin, 2021, [email protected]. All right reserved.
// (c) Yuri N Kalinin, 2021-2022, [email protected]. All right reserved.

#include "SimpleTorchModule.h"
#include "SimplePyTorch.h"
Expand Down Expand Up @@ -32,7 +32,9 @@ void USimpleTorchModule::BeginDestroy()

USimpleTorchModule* USimpleTorchModule::CreateSimpleTorchModule(UObject* InParent)
{
return NewObject<USimpleTorchModule>(InParent);
return InParent
? NewObject<USimpleTorchModule>(InParent)
: NewObject<USimpleTorchModule>();
}

bool USimpleTorchModule::LoadTorchScriptModel(FString FileName)
Expand Down Expand Up @@ -71,9 +73,9 @@ bool USimpleTorchModule::IsTorchModelLoaded() const
FSimplePyTorchModule& Module = FModuleManager::GetModuleChecked<FSimplePyTorchModule>(TEXT("SimplePyTorch"));

bool bResult = false;
if (Module.bDllLoaded)
if (Module.bDllLoaded && Module.FuncTSW_CheckModel)
{
return Module.FuncTSW_CheckModel(ModelId);
bResult = Module.FuncTSW_CheckModel(ModelId);
}

return bResult;
Expand All @@ -91,7 +93,7 @@ bool USimpleTorchModule::ExecuteModelMethod(const FString& MethodName, const FSi
bool bResult = false;
if (Module.bDllLoaded && Buffer != NULL && OutData.IsDataOwner())
{
TArray<int> InDims = InData.GetDimensions();
TArray<int> InDims = InData.GetDimensions().Array();

float* pOutData = OutData.IsValid()
? OutData.GetRawData()
Expand All @@ -100,21 +102,25 @@ bool USimpleTorchModule::ExecuteModelMethod(const FString& MethodName, const FSi
int OutDimsCount = 0;
if (MethodName == TEXT("forward"))
{
if (Module.FuncTSW_ForwardPass_Def == NULL) return false;

Module.FuncTSW_ForwardPass_Def(ModelId, InData.GetRawData(), InDims.GetData(), InDims.Num(),
pOutData, BufferDims, &OutDimsCount);
}
else
{
if (Module.FuncTSW_Execute_Def == NULL) return false;

Module.FuncTSW_Execute_Def(ModelId, TCHAR_TO_ANSI(*MethodName), InData.GetRawData(), InDims.GetData(), InDims.Num(),
pOutData, BufferDims, &OutDimsCount);
}

bResult = (OutDimsCount > 0);
bResult = (OutDimsCount > 0) && pOutData != NULL && BufferDims != NULL;
if (bResult)
{
TArray<int32> OldOutDims = OutData.GetDimensions();
TArray<int32> OldOutDims = OutData.GetDimensions().Array();
bool bOutTensorMatches = (OutDimsCount == OutData.GetDimensions().Num());
TArray<int32> NewOutDims;
TSet<int32> NewOutDims;

int32 Length = 1;
for (int i = 0; i < OutDimsCount; i++)
Expand Down Expand Up @@ -197,14 +203,14 @@ void FSimpleTorchTensor::Cleanup()
}
}

int32 FSimpleTorchTensor::GetAddress(TArray<int32> Address) const
int32 FSimpleTorchTensor::GetAddress(TSet<int32> Address) const
{
if (!Data)
{
return INDEX_NONE;
}

TArray<int32> AddrArray = Address;
TArray<int32> AddrArray = Address.Array();
int32 Addr = 0;
for (int32 i = 0; i < AddressMultipliersCache.Num(); i++)
{
Expand All @@ -220,7 +226,7 @@ int32 FSimpleTorchTensor::GetAddress(TArray<int32> Address) const
return Addr;
}

bool FSimpleTorchTensor::Create(TArray<int32> TensorDimensions)
bool FSimpleTorchTensor::Create(TSet<int32> TensorDimensions)
{
if (Data)
{
Expand All @@ -235,7 +241,7 @@ bool FSimpleTorchTensor::Create(TArray<int32> TensorDimensions)

if (DataSize == 0) return false;

Dimensions = TensorDimensions;
Dimensions = TensorDimensions.Array();
InitAddressSpace();

Data = new float[DataSize];
Expand All @@ -244,7 +250,7 @@ bool FSimpleTorchTensor::Create(TArray<int32> TensorDimensions)
return true;
}

bool FSimpleTorchTensor::CreateAsChild(FSimpleTorchTensor* Parent, TArray<int32> Address)
bool FSimpleTorchTensor::CreateAsChild(FSimpleTorchTensor* Parent, TSet<int32> Address)
{
bDataOwner = false;

Expand Down Expand Up @@ -298,9 +304,9 @@ bool FSimpleTorchTensor::CreateAsChild(FSimpleTorchTensor* Parent, TArray<int32>
return true;
}

TArray<int32> FSimpleTorchTensor::GetDimensions() const
TSet<int32> FSimpleTorchTensor::GetDimensions() const
{
TArray<int32> t;
TSet<int32> t;
for (const auto& d : Dimensions)
t.Add(d);

Expand Down Expand Up @@ -331,13 +337,13 @@ float* FSimpleTorchTensor::GetRawData(int32* Size) const
}
}

float* FSimpleTorchTensor::GetCell(TArray<int32> Address)
float* FSimpleTorchTensor::GetCell(TSet<int32> Address)
{
int32 Addr = GetAddress(Address);
return Addr == INDEX_NONE ? NULL : &Data[Addr];
}

float FSimpleTorchTensor::GetValue(TArray<int32> Address) const
float FSimpleTorchTensor::GetValue(TSet<int32> Address) const
{
int32 Addr = GetAddress(Address);
return Addr == INDEX_NONE ? 0 : Data[Addr];
Expand Down Expand Up @@ -373,7 +379,7 @@ bool FSimpleTorchTensor::FromArray(const TArray<float>& InData)
return false;
}

bool FSimpleTorchTensor::Reshape(TArray<int32> NewShape)
bool FSimpleTorchTensor::Reshape(TSet<int32> NewShape)
{
if (NewShape.Num() == 0)
{
Expand All @@ -392,7 +398,7 @@ bool FSimpleTorchTensor::Reshape(TArray<int32> NewShape)

if (NewDataSize == DataSize)
{
Dimensions = NewShape;
Dimensions = NewShape.Array();
InitAddressSpace();
}
else
Expand All @@ -402,7 +408,7 @@ bool FSimpleTorchTensor::Reshape(TArray<int32> NewShape)
Cleanup();

DataSize = NewDataSize;
Dimensions = NewShape;
Dimensions = NewShape.Array();
Data = new float[DataSize];

InitAddressSpace();
Expand All @@ -424,7 +430,7 @@ FSimpleTorchTensor FSimpleTorchTensor::Detach()
return FSimpleTorchTensor();
}

TArray<int32> Dims;
TSet<int32> Dims;
for (const auto& Val : Dimensions) Dims.Add(Val);

FSimpleTorchTensor ret = FSimpleTorchTensor(Dims);
Expand Down
2 changes: 1 addition & 1 deletion Source/SimplePyTorch/Public/SimplePyTorch.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// VR IK Body Plugin
// (c) Yuri N Kalinin, 2021, [email protected]. All right reserved.
// (c) Yuri N Kalinin, 2021-2022, [email protected]. All right reserved.

#pragma once

Expand Down
22 changes: 11 additions & 11 deletions Source/SimplePyTorch/Public/SimpleTorchModule.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// VR IK Body Plugin
// (c) Yuri N Kalinin, 2021, [email protected]. All right reserved.
// (c) Yuri N Kalinin, 2021-2022, [email protected]. All right reserved.

#pragma once

Expand Down Expand Up @@ -53,7 +53,7 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor
void InitAddressSpace();

/** Get flat address in Data from multidimensional address */
int32 GetAddress(TArray<int32> Address) const;
int32 GetAddress(TSet<int32> Address) const;
public:

FSimpleTorchTensor()
Expand All @@ -62,15 +62,15 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor
, DataSize(0)
, ParentTensor(nullptr)
{}
FSimpleTorchTensor(FSimpleTorchTensor* Parent, TArray<int32> SubAddress)
FSimpleTorchTensor(FSimpleTorchTensor* Parent, TSet<int32> SubAddress)
: Data(NULL)
, bDataOwner(true)
, DataSize(0)
, ParentTensor(nullptr)
{
CreateAsChild(Parent, SubAddress);
}
FSimpleTorchTensor(TArray<int32> Dimensions)
FSimpleTorchTensor(TSet<int32> Dimensions)
: Data(NULL)
, bDataOwner(true)
, DataSize(0)
Expand All @@ -87,10 +87,10 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor
void Cleanup();

// Set dimensions and allocate memory
bool Create(TArray<int32> TensorDimensions);
bool Create(TSet<int32> TensorDimensions);

// Create tensor as a subtensor in another tensor (share the same memory)
bool CreateAsChild(FSimpleTorchTensor* Parent, TArray<int32> Address);
bool CreateAsChild(FSimpleTorchTensor* Parent, TSet<int32> Address);

// Is tensor initialized?
bool IsValid() const { return Data != NULL; }
Expand All @@ -99,7 +99,7 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor
bool IsDataOwner() const { return bDataOwner; }

// Get current tensor dimensions
TArray<int32> GetDimensions() const;
TSet<int32> GetDimensions() const;

// Get number of itemes in flat array
int32 GetDataSize() const { return DataSize; }
Expand All @@ -112,11 +112,11 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor
float* GetRawData() const { return Data; }

// Convert multidimensional address to flat address
int32 GetRawAddress(TArray<int32> Address) const { return GetAddress(Address); }
int32 GetRawAddress(TSet<int32> Address) const { return GetAddress(Address); }

// Get reference to single float value with address
float* GetCell(TArray<int32> Address);
float GetValue(TArray<int32> Address) const;
float* GetCell(TSet<int32> Address);
float GetValue(TSet<int32> Address) const;

/* Create float array. Only works for tensor with one dimension.
* Ex: auto p = FSimpleTorchTensor({ 4, 12 });
Expand All @@ -139,7 +139,7 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor

// Change dimensions.
// Only keeps data if new overall size is equal to old sizse
bool Reshape(TArray<int32> NewShape);
bool Reshape(TSet<int32> NewShape);

// Create copy of this tensor
FSimpleTorchTensor Detach();
Expand Down
41 changes: 21 additions & 20 deletions Source/SimplePyTorch/SimplePyTorch.Build.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// VR IK Body Plugin
// (c) Yuri N Kalinin, 2021, [email protected]. All right reserved.
// (c) Yuri N Kalinin, 2021-2022, [email protected]. All right reserved.

using UnrealBuildTool;
using System.IO;
Expand Down Expand Up @@ -60,27 +60,28 @@ public SimplePyTorch(ReadOnlyTargetRules Target) : base(Target)

if (Target.Platform == UnrealTargetPlatform.Win64)
{
// my PyTorch wrapper
RuntimeDependencies.Add("$(BinaryOutputDir)/torchscript_wrapper.dll", Path.Combine(TorchBinariesPath, "torchscript_wrapper.dll"));
if (!Target.bBuildEditor)
// LibTorch libraries
string[] DLLs = new string[]
{
// Copy DLLs to target packaged project
RuntimeDependencies.Add("$(BinaryOutputDir)/asmjit.dll", Path.Combine(TorchBinariesPath, "asmjit.dll"));
RuntimeDependencies.Add("$(BinaryOutputDir)/c10.dll", Path.Combine(TorchBinariesPath, "c10.dll"));
RuntimeDependencies.Add("$(BinaryOutputDir)/caffe2_detectron_ops.dll", Path.Combine(TorchBinariesPath, "caffe2_detectron_ops.dll"));
RuntimeDependencies.Add("$(BinaryOutputDir)/caffe2_module_test_dynamic.dll", Path.Combine(TorchBinariesPath, "caffe2_module_test_dynamic.dll"));
RuntimeDependencies.Add("$(BinaryOutputDir)/fbgemm.dll", Path.Combine(TorchBinariesPath, "fbgemm.dll"));
RuntimeDependencies.Add("$(BinaryOutputDir)/fbjni.dll", Path.Combine(TorchBinariesPath, "fbjni.dll"));
RuntimeDependencies.Add("$(BinaryOutputDir)/libiomp5md.dll", Path.Combine(TorchBinariesPath, "libiomp5md.dll"));
RuntimeDependencies.Add("$(BinaryOutputDir)/libiompstubs5md.dll", Path.Combine(TorchBinariesPath, "libiompstubs5md.dll"));
RuntimeDependencies.Add("$(BinaryOutputDir)/pytorch_jni.dll", Path.Combine(TorchBinariesPath, "pytorch_jni.dll"));
RuntimeDependencies.Add("$(BinaryOutputDir)/torch.dll", Path.Combine(TorchBinariesPath, "torch.dll"));
RuntimeDependencies.Add("$(BinaryOutputDir)/torch_cpu.dll", Path.Combine(TorchBinariesPath, "torch_cpu.dll"));
RuntimeDependencies.Add("$(BinaryOutputDir)/torch_global_deps.dll", Path.Combine(TorchBinariesPath, "torch_global_deps.dll"));
RuntimeDependencies.Add("$(BinaryOutputDir)/uv.dll", Path.Combine(TorchBinariesPath, "uv.dll"));
"asmjit.dll", "c10.dll", "caffe2_detectron_ops.dll", "caffe2_module_test_dynamic.dll", "fbgemm.dll", "fbjni.dll", "libiomp5md.dll",
"libiompstubs5md.dll", "pytorch_jni.dll", "torch.dll", "torch_cpu.dll", "torch_global_deps.dll", "uv.dll"
};

// copy all DLLs to the packaged build
if (!Target.bBuildEditor && Target.Type == TargetType.Game)
{
string DllTargetDir = "$(ProjectDir)/Binaries/ThirdParty/PyTorch/";
foreach (string DllName in DLLs)
{
PublicDelayLoadDLLs.Add(DllName);
RuntimeDependencies.Add(Path.Combine(DllTargetDir, DllName), Path.Combine(TorchBinariesPath, DllName));
}

// my PyTorch wrapper is loaded dynamically
RuntimeDependencies.Add(Path.Combine(DllTargetDir, "torchscript_wrapper.dll"), Path.Combine(TorchBinariesPath, "torchscript_wrapper.dll"));
// licenses
RuntimeDependencies.Add("$(BinaryOutputDir)/LICENSE.txt", Path.Combine(TorchPath, "LICENSE.txt"));
RuntimeDependencies.Add("$(BinaryOutputDir)/NOTICE.txt", Path.Combine(TorchPath, "NOTICE.txt"));
RuntimeDependencies.Add(Path.Combine(DllTargetDir, "LICENSE.txt"), Path.Combine(TorchPath, "LICENSE.txt"), StagedFileType.NonUFS);
RuntimeDependencies.Add(Path.Combine(DllTargetDir, "NOTICE.txt"), Path.Combine(TorchPath, "NOTICE.txt"), StagedFileType.NonUFS);
}
}
}
Expand Down
Binary file modified Source/ThirdParty/pytorch/Binaries/Win64/torchscript_wrapper.dll
Binary file not shown.

0 comments on commit e5bde76

Please sign in to comment.