Skip to content

Commit

Permalink
ONNX Inference (#180)
Browse files Browse the repository at this point in the history
Completed tasks:
- Ability to export model to ONNX format
- Integration with Inference module for ONNX models
- Unit test for ONNX Inference module
- Example notebook for ONNX Inference module
  • Loading branch information
PvtKaefsky authored Dec 24, 2024
1 parent 817a99d commit 94b2e74
Show file tree
Hide file tree
Showing 14 changed files with 2,149 additions and 17 deletions.
1 change: 1 addition & 0 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ scipy = "~=1.10.1"
pymonad = "*"
distributed = "*"
dask = "*"
onnxruntime = "==1.19.2"

[packages]
pytorch-lifestream = {editable = true, path = "."}
Expand Down
87 changes: 83 additions & 4 deletions Pipfile.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 9 additions & 3 deletions ptls/data_load/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,17 +390,23 @@ def padded_collate_emb_valid(batch):
'pos_distribution': pos_distribution}
return x.float(), y


def padded_collate_wo_target(batch):
def _collate(batch):
new_x_ = defaultdict(list)
for x in batch:
for k, v in x.items():
new_x_[k].append(v)

lengths = torch.IntTensor([len(e) for e in next(iter(new_x_.values()))])
new_x = {k: torch.nn.utils.rnn.pad_sequence(v, batch_first=True) for k, v in new_x_.items()}
return new_x, lengths

def padded_collate_wo_target(batch):
    """Collate feature dicts (no target) into a single ``PaddedBatch``."""
    padded_features, seq_lengths = _collate(batch)
    return PaddedBatch(padded_features, seq_lengths)


def collate_wo_target(batch):
    """Collate feature dicts (no target) into a plain ``(features, lengths)``
    tuple — the raw-tensor form required by the ONNX export path, which
    cannot consume a ``PaddedBatch`` object."""
    padded_features, seq_lengths = _collate(batch)
    return padded_features, seq_lengths

class ZeroDownSampler(Sampler):
def __init__(self, targets):
Expand Down
8 changes: 6 additions & 2 deletions ptls/data_load/datasets/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from torch.utils.data import DataLoader

from ptls.data_load import IterableChain, padded_collate_wo_target
from ptls.data_load import IterableChain, padded_collate_wo_target, collate_wo_target
from ptls.data_load.filter_dataset import FilterDataset
from ptls.data_load.iterable_processing import ToTorch, FilterNonArray, ISeqLenLimit

Expand All @@ -12,6 +12,7 @@ def inference_data_loader(
max_seq_len: int = 10000,
num_workers: int = 0,
batch_size: int = 512,
onnx = False
):
"""
Generate an inference data loader. The data loader will return a batch of sequences.
Expand All @@ -23,6 +24,7 @@ def inference_data_loader(
num_workers: the number of workers for the dataloader. Default: 0 - single-process loader.
batch_size: the batch size. Default: 512. The number of samples (before splitting to subsequences) in
each batch.
onnx: flag for ONNX export. Default: False
Returns:
DataLoader
Expand All @@ -37,9 +39,11 @@ def inference_data_loader(
)
)

collate_fn = collate_wo_target if onnx else padded_collate_wo_target

return DataLoader(
dataset=dataset,
collate_fn=padded_collate_wo_target,
collate_fn=collate_fn,
shuffle=False,
num_workers=num_workers,
batch_size=batch_size,
Expand Down
6 changes: 5 additions & 1 deletion ptls/frames/abs_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def __init__(self,

self._optimizer_partial = optimizer_partial
self._lr_scheduler_partial = lr_scheduler_partial
self._col_names = None
self._seq_len = None

@property
def metric_name(self):
Expand Down Expand Up @@ -67,7 +69,9 @@ def seq_encoder(self):
return self._seq_encoder

def forward(self, x):
return self._seq_encoder(x)
names = self._col_names
seq_len = self._seq_len
return self._seq_encoder(x, names, seq_len)

def training_step(self, batch, _):
y_h, y = self.shared_step(*batch)
Expand Down
4 changes: 3 additions & 1 deletion ptls/frames/coles/multimodal_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def multimodal_trx_encoder(self, x):
length = length + source_length
return res, length

def forward(self, x, **kwargs):
def forward(self, x, names=None, seq_len=None, **kwargs):
if names and seq_len is not None:
raise NotImplementedError
x, length = self.multimodal_trx_encoder(x)
x = self.merge_by_time(x)
padded_x = PaddedBatch(payload=x, length=length)
Expand Down
81 changes: 80 additions & 1 deletion ptls/frames/inference_module.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import pandas as pd
import pytorch_lightning as pl
import torch
import numpy as np
import onnxruntime as ort

from itertools import chain
from ptls.data_load.padded_batch import PaddedBatch
Expand Down Expand Up @@ -43,7 +45,7 @@ def to_pandas(self, x):
len_mask = v.seq_len_mask.bool().cpu().numpy()
v = v.payload
if type(v) is torch.Tensor:
v = v.cpu().numpy()
v = v.detach().cpu().numpy()
if type(v) is list or len(v.shape) == 1:
scalar_features[k] = v
elif k.startswith('target'):
Expand All @@ -69,6 +71,9 @@ def to_pandas(self, x):
out_df = pd.concat([out_df.reset_index(drop=True), df_expand], axis = 1)

return out_df

def to_numpy(self, tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

@staticmethod
def to_pandas_record(x, expand_features, scalar_features, seq_features, len_mask):
Expand Down Expand Up @@ -157,3 +162,77 @@ def to_pandas(x):
dataframes.append(pd.DataFrame(v, columns=[f'{col}_{i:04d}' for i in range(v.shape[1])]))

return pd.concat(dataframes, axis=1)

class ONNXInferenceModule(InferenceModule):
    """Inference wrapper that exports a model to ONNX and runs it via onnxruntime.

    Args:
        model: a ptls inference model exposing ``to_onnx`` (the
            pytorch-lightning ``LightningModule`` export API).
        dl: dataloader built for the ONNX path (``onnx=True`` collate),
            yielding ``(features_dict, lengths)`` tuples.
        model_out_name: file path for the exported ``.onnx`` model.
        pandas_output: if True, ``forward`` returns a ``pd.DataFrame``
            instead of a tensor.
    """

    def __init__(self, model, dl, model_out_name='emb.onnx', pandas_output=False):
        super().__init__(model)
        self.model = model
        self.pandas_output = pandas_output
        self.model_out_name = model_out_name
        # Prefer CUDA when present; onnxruntime falls back to CPU otherwise.
        self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

        # Trace the export with one real batch and remember the feature
        # metadata the encoder needs to rebuild a PaddedBatch from raw input.
        batch = next(iter(dl))
        features, names, seq_len = self.preprocessing(batch)
        model._col_names = names
        model._seq_len = seq_len
        model.example_input_array = features
        self.export(self.model_out_name, model)

        self.ort_session = ort.InferenceSession(
            self.model_out_name,
            providers=self.providers
        )

    def stack(self, x):
        """Stack the feature tensors of batch *x* along a new leading axis.

        ``x`` is a ``(features_dict, lengths)`` tuple. NOTE(review): assumes
        every feature tensor has identical shape — confirm with the collate.
        """
        return torch.stack(list(x[0].values()))

    def preprocessing(self, x):
        """Split a collated batch into (stacked features, feature names, lengths)."""
        features = self.stack(x)
        names = list(x[0].keys())
        seq_len = x[1]
        return features, names, seq_len

    def export(self,
               path: str,
               model
               ) -> None:
        """Export *model* to ONNX at *path* with dynamic batch/sequence axes."""
        model.to_onnx(path,
                      export_params=True,
                      input_names=["input"],
                      output_names=["output"],
                      dynamic_axes={
                          "input": {
                              0: "features",
                              1: "batch_size",
                              2: "seq_len"
                          },
                          "output": {
                              0: "batch_size",
                              1: "hidden_size"
                          }
                      }
                      )

    def forward(self, x, dtype: torch.dtype = torch.float16):
        """Run one batch through the ONNX session.

        Returns a tensor of *dtype* (or a DataFrame if ``pandas_output``).
        """
        inputs = self.to_numpy(self.stack(x))
        out = self.ort_session.run(None, {"input": inputs})
        out = torch.tensor(out[0], dtype=dtype)
        if self.pandas_output:
            return self.to_pandas(out)
        return out

    def to(self, device):
        # The ONNX session manages device placement itself; keep the
        # nn.Module-style API as a no-op.
        return self

    def size(self):
        """Size of the exported ONNX file in bytes.

        Bug fix: the attribute is ``model_out_name`` — ``self.model_name``
        was never set and raised AttributeError.
        """
        return os.path.getsize(self.model_out_name)

    def predict(self, dl, dtype: torch.dtype = torch.float16):
        """Run inference over every batch of *dl*; returns a list of outputs."""
        pred = []
        with torch.no_grad():
            for batch in dl:
                pred.append(self(batch, dtype=dtype))
        return pred
7 changes: 7 additions & 0 deletions ptls/nn/seq_encoder/agg_feature_seq_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,13 @@ def forward(self, x: PaddedBatch):
:return:
"""

if isinstance(x, PaddedBatch) is False:
pre_x = dict()
self._col_names = ["mcc_code", "tr_type", "amount"]
for i, field_name in enumerate(self._col_names):
pre_x[field_name] = x[i]
x = PaddedBatch(pre_x, self._seq_len)

feature_arrays = x.payload
device = x.device
B, T = x.seq_feature_shape
Expand Down
8 changes: 5 additions & 3 deletions ptls/nn/seq_encoder/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def category_names(self):
def embedding_size(self):
return self.seq_encoder.embedding_size

def forward(self, x: PaddedBatch):
def forward(self, x, names=None, seq_len=None):
if names and seq_len is not None:
raise NotImplementedError
x = self.trx_encoder(x)
x = self.seq_encoder(x)
return x
Expand Down Expand Up @@ -122,8 +124,8 @@ def __init__(self,
is_reduce_sequence=is_reduce_sequence,
)

def forward(self, x: PaddedBatch, h_0=None):
x = self.trx_encoder(x)
def forward(self, x, names=None, seq_len=None, h_0=None):
x = self.trx_encoder(x, names, seq_len)
x = self.seq_encoder(x, h_0)
return x

Expand Down
Loading

0 comments on commit 94b2e74

Please sign in to comment.