From d8af3ff7fca2a35a632c4e2ef7f99a47cb8bcee7 Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 22 Sep 2023 14:51:05 +0500 Subject: [PATCH] Fix benchmarks --- examples/src/benchmark.js | 4 ++-- examples/src/benchmark_gpu.js | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/src/benchmark.js b/examples/src/benchmark.js index 3be9edd..88461d0 100644 --- a/examples/src/benchmark.js +++ b/examples/src/benchmark.js @@ -39,8 +39,8 @@ const Count = 2000; const trainData = Matrix.random_2d(Count, Sizes[0]); const singleTrainData = trainData.slice(0, 1); -const tfTrainData = tf.tensor(Matrix.copy_2d(trainData)); -const tfSingleData = tf.tensor(Matrix.copy_2d(singleTrainData)); +const tfTrainData = tf.tensor(trainData.map(t => Array.from(t))); +const tfSingleData = tf.tensor(singleTrainData.map(t => Array.from(t))); const brTrainData = trainData.map(d => ({input: d, output: d})); const brSingleData = brTrainData.slice(0, 1); diff --git a/examples/src/benchmark_gpu.js b/examples/src/benchmark_gpu.js index 87c98f5..ad59c1e 100644 --- a/examples/src/benchmark_gpu.js +++ b/examples/src/benchmark_gpu.js @@ -47,7 +47,7 @@ tfModel.compile({loss: "meanSquaredError", optimizer: "sgd"}); const trainData = Matrix.random_2d(Count, Sizes[0]); -const tfTrainData = tf.tensor(Matrix.copy_2d(trainData)); +const tfTrainData = tf.tensor(trainData.map(t => Array.from(t))); const brTrainData = trainData.map(d => ({input: d, output: d}));