Skip to content

Commit

Permalink
fix: use too much memory of mkldnn models (#2782)
Browse files Browse the repository at this point in the history
The `realSize` of `DnnTensor` means the element number of this tensor. But mkldnn will return the bytes of memory, we should make it to the elements number.
  • Loading branch information
i8run authored Mar 25, 2019
1 parent 0c87be5 commit 1298c80
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.intel.analytics.bigdl.nn.mkldnn

import com.intel.analytics.bigdl.mkl._
import com.intel.analytics.bigdl.tensor.DnnStorage

sealed trait MemoryData extends Serializable {
def shape: Array[Int]
Expand Down Expand Up @@ -81,13 +82,23 @@ sealed trait MemoryData extends Serializable {

def getRealSize: Long = {
require(primitiveDesc != UNDEFINED && primitiveDesc != ERROR)
MklDnn.PrimitiveDescGetSize(primitiveDesc)
MklDnn.PrimitiveDescGetSize(primitiveDesc) / getDataTypeBytes
}

def getPaddingShape: Array[Int] = {
require(description != UNDEFINED && description != ERROR)
Memory.GetPaddingShape(description)
}

private def getDataTypeBytes: Int = {
dataType match {
case DataType.F32 => DnnStorage.FLOAT_BYTES
case DataType.S32 => DnnStorage.INT_BYTES
case DataType.S8 => DnnStorage.INT8_BYTES
case DataType.U8 => DnnStorage.INT8_BYTES
case _ => throw new UnsupportedOperationException(s"unsupported data type")
}
}
}

case class HeapData(private var _shape: Array[Int], private var _layout: Int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ private[bigdl] class Pointer(val address: Long)

object DnnStorage {
private[tensor] val CACHE_LINE_SIZE = System.getProperty("bigdl.cache.line", "64").toInt
private[tensor] val FLOAT_BYTES: Int = 4
private[tensor] val INT8_BYTES: Int = 1
private[tensor] val INT_BYTES: Int = 4
private[bigdl] val FLOAT_BYTES: Int = 4
private[bigdl] val INT8_BYTES: Int = 1
private[bigdl] val INT_BYTES: Int = 4

import java.util.concurrent.ConcurrentHashMap
private val nativeStorages: ConcurrentHashMap[Long, Boolean] = new ConcurrentHashMap()
Expand Down

0 comments on commit 1298c80

Please sign in to comment.