Skip to content

Commit

Permalink
refactor scalers.py and add tests for periodic and ple
Browse files (browse the repository at this point in the history)
  • Loading branch information
Artem Sakhno committed Jan 27, 2025
1 parent 9561b6d commit 2c1673c
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 14 deletions.
33 changes: 19 additions & 14 deletions ptls/nn/trx_encoder/scalers.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,14 @@ def output_size(self):

class Periodic(IdentityScaler):
'''
x -> [cos(cx), sin(cx)], c - (num_periods)-dimensional learnable vector
x -> [cos(cx), sin(cx)], c - (num_periods)-dimensional learnable vector initialized from N(0, param_dist_sigma)
From paper "On embeddings for numerical features in tabular deep learning"
'''
def __init__(self, num_periods = 8):
def __init__(self, num_periods = 8, param_dist_sigma = 1):
super().__init__()
self.num_periods = num_periods
self.c = torch.nn.Parameter(torch.randn(1,1, num_periods), requires_grad=True)
self.c = torch.nn.Parameter(torch.normal(0, param_dist_sigma, size=(1,1, num_periods)), requires_grad=True)

def forward(self, x):
x = super().forward(x)
Expand All @@ -124,16 +124,17 @@ def output_size(self):

class PeriodicMLP(IdentityScaler):
'''
x -> [cos(cx), sin(cx)], c - (num_periods)-dimensional learnable vector
Then Linear, Then ReLU
x -> [cos(cx), sin(cx)], c - (num_periods)-dimensional learnable vector initialized from N(0, param_dist_sigma)
Then Linear, then ReLU
From paper "On embeddings for numerical features in tabular deep learning"
'''
def __init__(self, num_periods = 8):
def __init__(self, num_periods = 8, param_dist_sigma = 1, mlp_output_size = -1):
super().__init__()
self.num_periods = num_periods
self.c = torch.nn.Parameter(torch.randn(1,1, num_periods), requires_grad=True)
self.mlp = nn.Linear(self.output_size, self.output_size)
self.mlp_output_size = mlp_output_size if mlp_output_size > 0 else 2 * self.num_periods
self.c = torch.nn.Parameter(torch.normal(0, param_dist_sigma, size=(1,1, num_periods)), requires_grad=True)
self.mlp = nn.Linear(2 * self.num_periods, self.mlp_output_size)
self.relu = nn.ReLU()

def forward(self, x):
Expand All @@ -143,26 +144,28 @@ def forward(self, x):
x = self.mlp(x)
x = self.relu(x)
return x

@property
def output_size(self):
return 2 * self.num_periods

return self.mlp_output_size

class PLE(IdentityScaler):
    '''
    Piecewise-linear encoding of a numeric value, based on bins.

    For every pair of neighbouring bin edges (lo, hi) the output component is
    clamp((x - lo) / (hi - lo), 0, 1), e.g. x -> [1, 1, 1, a, 0, 0, 0].
    From paper "On embeddings for numerical features in tabular deep learning"

    Parameters
    ----------
    bins:
        Sequence of bin edges; `len(bins) - 1` output components are produced.
        Edges are expected to be strictly increasing: equal neighbouring edges
        make the bin width zero and the division ill-defined.
    '''
    def __init__(self, bins=(-1, 0, 1)):
        # Immutable default: a list default would be shared between all instances.
        super().__init__()
        self.size = len(bins) - 1
        # Edges are stored with shape (1, 1, len(bins)).
        # Registered as a non-persistent buffer so that `module.to(device)` moves
        # them together with the module while `state_dict()` keeps its old keys.
        # NOTE(review): assumes IdentityScaler is a torch.nn.Module - it is defined
        # outside this view; confirm.
        self.register_buffer('bins', torch.tensor([[bins,]]), persistent=False)

    def forward(self, x):
        # Fallback for inputs living on a different device than the module:
        # move the edges once instead of reassigning them on every call.
        if self.bins.device != x.device:
            self.bins = self.bins.to(x.device)
        # Parent preprocessing; presumably it makes x broadcastable against the
        # (1, 1, n_bins) edges - TODO confirm against IdentityScaler.forward.
        x = super().forward(x)
        lower = self.bins[:, :, :-1]
        upper = self.bins[:, :, 1:]
        # Relative position of x inside each bin, saturated to [0, 1].
        x = (x - lower) / (upper - lower)
        return x.clamp(0, 1)

    @property
    def output_size(self):
        # One output component per bin (pair of neighbouring edges).
        return self.size
Expand All @@ -174,12 +177,14 @@ class PLE_MLP(IdentityScaler):
From paper "On embeddings for numerical features in tabular deep learning"
'''
def __init__(self, bins = [ 0, 1, 2, 3,]):
def __init__(self, bins = [-1, 0, 1], mlp_output_size = -1):
super().__init__()
self.size = len(bins) - 1
self.mlp_output_size = mlp_output_size if mlp_output_size > 0 else self.size
self.bins = torch.tensor([[bins,]])
self.mlp = nn.Linear(self.output_size, self.output_size)
self.mlp = nn.Linear(self.size, self.mlp_output_size)
self.relu = nn.ReLU()

def forward(self, x):
self.bins = self.bins.to(x.device)
x = super().forward(x)
Expand All @@ -190,7 +195,7 @@ def forward(self, x):
return(x)
@property
def output_size(self):
return self.size
return self.mlp_output_size


def scaler_by_name(name):
Expand Down
82 changes: 82 additions & 0 deletions ptls_tests/test_nn/test_trx_encoder/test_scalers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import torch

from ptls.data_load.padded_batch import PaddedBatch
from ptls.nn.trx_encoder import TrxEncoder

from ptls.nn.trx_encoder.scalers import Periodic, PeriodicMLP, PLE, PLE_MLP


def test_periodic():
    """Periodic scaler inside TrxEncoder yields 2 * num_periods features per step."""
    batch_size, seq_len = 5, 20
    num_periods = 4
    expected_width = 2 * num_periods

    encoder = TrxEncoder(
        numeric_values={
            'amount': Periodic(num_periods=num_periods, param_dist_sigma=3),
        },
    )

    amounts = torch.randn(batch_size, seq_len)
    lengths = torch.randint(10, 20, (batch_size,))
    batch = PaddedBatch(payload={'amount': amounts}, length=lengths)

    encoded = encoder(batch)
    assert encoded.payload.shape == (batch_size, seq_len, expected_width)  # B, T, H
    assert encoder.output_size == expected_width


def test_periodic_mlp():
    """PeriodicMLP scaler inside TrxEncoder yields mlp_output_size features per step."""
    batch_size, seq_len = 5, 20
    num_periods = 4
    mlp_output_size = 32

    encoder = TrxEncoder(
        numeric_values={
            'amount': PeriodicMLP(
                num_periods=num_periods,
                param_dist_sigma=3,
                mlp_output_size=mlp_output_size,
            ),
        },
    )

    amounts = torch.randn(batch_size, seq_len)
    lengths = torch.randint(10, 20, (batch_size,))
    batch = PaddedBatch(payload={'amount': amounts}, length=lengths)

    encoded = encoder(batch)
    assert encoded.payload.shape == (batch_size, seq_len, mlp_output_size)  # B, T, H
    assert encoder.output_size == mlp_output_size



def test_ple():
    """PLE scaler inside TrxEncoder yields len(bins) - 1 features per step."""
    batch_size, seq_len = 5, 20
    bins = [-1, 0, 1]
    expected_width = len(bins) - 1

    encoder = TrxEncoder(
        numeric_values={'amount': PLE(bins=bins)},
    )

    amounts = torch.randn(batch_size, seq_len)
    lengths = torch.randint(10, 20, (batch_size,))
    batch = PaddedBatch(payload={'amount': amounts}, length=lengths)

    encoded = encoder(batch)
    assert encoded.payload.shape == (batch_size, seq_len, expected_width)  # B, T, H
    assert encoder.output_size == expected_width


def test_ple_mlp():
    """PLE_MLP scaler inside TrxEncoder yields mlp_output_size features per step."""
    batch_size, seq_len = 5, 20
    bins = [-1, 0, 1]
    mlp_output_size = 64

    encoder = TrxEncoder(
        numeric_values={
            'amount': PLE_MLP(bins=bins, mlp_output_size=mlp_output_size),
        },
    )

    amounts = torch.randn(batch_size, seq_len)
    lengths = torch.randint(10, 20, (batch_size,))
    batch = PaddedBatch(payload={'amount': amounts}, length=lengths)

    encoded = encoder(batch)
    assert encoded.payload.shape == (batch_size, seq_len, mlp_output_size)  # B, T, H
    assert encoder.output_size == mlp_output_size

0 comments on commit 2c1673c

Please sign in to comment.