Skip to content

Commit

Permalink
Merge pull request #561 from aviaIguazio/1.6.x-dev
Browse files Browse the repository at this point in the history
[Stocks] fix input size
  • Loading branch information
aviaIguazio authored Jan 15, 2024
2 parents 0b07b9e + 1c02ae8 commit ad8fd1d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions stocks-prediction/src/train_stocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,13 @@ def handler(vector_name='stocks',
n_layers=1,
seq_size=5,
epochs=3,
input_size=11,
model_filepath=''
):
context = get_or_create_ctx(name='train-context')
dataset = StocksDataset(vector_name, seq_size, start_time, end_time)
training_set = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True)
#input_size = dataset.data[0][0].shape[0]
input_size = dataset.data[0][0].shape[0]
context.logger.info("input size {}".format(input_size))
output_size = 1
# creating the model
model = Model(input_size=input_size, output_size=output_size, hidden_dim=hidden_dim, n_layers=n_layers,
Expand Down

0 comments on commit ad8fd1d

Please sign in to comment.