Skip to content

Commit

Permalink
Fix benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
DrA1ex committed Sep 22, 2023
1 parent b7d8ec0 commit d8af3ff
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions examples/src/benchmark.js
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ const Count = 2000;
const trainData = Matrix.random_2d(Count, Sizes[0]);
const singleTrainData = trainData.slice(0, 1);

const tfTrainData = tf.tensor(Matrix.copy_2d(trainData));
const tfSingleData = tf.tensor(Matrix.copy_2d(singleTrainData));
const tfTrainData = tf.tensor(trainData.map(t => Array.from(t)));
const tfSingleData = tf.tensor(singleTrainData.map(t => Array.from(t)));

const brTrainData = trainData.map(d => ({input: d, output: d}));
const brSingleData = brTrainData.slice(0, 1);
Expand Down
2 changes: 1 addition & 1 deletion examples/src/benchmark_gpu.js
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ tfModel.compile({loss: "meanSquaredError", optimizer: "sgd"});

const trainData = Matrix.random_2d(Count, Sizes[0]);

const tfTrainData = tf.tensor(Matrix.copy_2d(trainData));
const tfTrainData = tf.tensor(trainData.map(t => Array.from(t)));

const brTrainData = trainData.map(d => ({input: d, output: d}));

Expand Down

0 comments on commit d8af3ff

Please sign in to comment.