diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/Engine.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/Engine.scala index adefe180c76..e0c880573b0 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/Engine.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/Engine.scala @@ -372,6 +372,7 @@ object Engine { // this thread and the omp threads forked from computing. if (engineType == MklDnn) { dnnComputing.setMKLThreadOfMklDnnBackend(MKL.getMklNumThreads) + _model.setMKLThreadOfMklDnnBackend(MKL.getMklNumThreads) } if (System.getProperty("multiThread", "false").toBoolean) { wrapperComputing.setMKLThread(1) diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/ThreadPool.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/ThreadPool.scala index 7a8eea87427..75b5e30f48b 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/ThreadPool.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/ThreadPool.scala @@ -91,7 +91,6 @@ class ThreadPool(private var poolSize: Int) { mklPoolSize = Some(size) (1 to poolSize).map(i => Future { MKL.setNumThreads(size) - BackendMklDnn.setNumThreads(size) val tid = Thread.currentThread().getId() logger.info(s"Set mkl threads to $size on thread $tid") }(context)).foreach(Await.result(_, Duration.Inf))