diff --git a/math/src/main/codegen/breeze/linalg/operators/CSCMatrixOps.scala b/math/src/main/codegen/breeze/linalg/operators/CSCMatrixOps.scala index b1f2bc12f..8f48472ce 100644 --- a/math/src/main/codegen/breeze/linalg/operators/CSCMatrixOps.scala +++ b/math/src/main/codegen/breeze/linalg/operators/CSCMatrixOps.scala @@ -56,6 +56,22 @@ trait CSCMatrixOps extends CSCMatrixOps_Ring { this: CSCMatrix.type => } } + implicit def canMulDVt_CSC_eq_DVt[T]( + implicit op: OpMulMatrix.Impl2[DenseMatrix[T], CSCMatrix[T], DenseMatrix[T]], + zero: Zero[T], + ct: ClassTag[T]): OpMulMatrix.Impl2[Transpose[DenseVector[T]], CSCMatrix[T], Transpose[DenseVector[T]]] = + new OpMulMatrix.Impl2[Transpose[DenseVector[T]], CSCMatrix[T], Transpose[DenseVector[T]]] { + def apply(v: Transpose[DenseVector[T]], v2: CSCMatrix[T]): Transpose[DenseVector[T]] = { + require(v2.rows == v.inner.length) + + val dm = v.inner.toDenseMatrix + val multiplied = op(dm, v2) + + new Transpose[DenseVector[T]](multiplied.toDenseVector) + + } + } + @expand @expand.valify implicit def csc_OpNeg[@expand.args(Int, Double, Float, Long) T]: OpNeg.Impl[CSCMatrix[T], CSCMatrix[T]] = { diff --git a/math/src/main/scala/breeze/linalg/DenseVector.scala b/math/src/main/scala/breeze/linalg/DenseVector.scala index 4172c03be..4eec7be4f 100644 --- a/math/src/main/scala/breeze/linalg/DenseVector.scala +++ b/math/src/main/scala/breeze/linalg/DenseVector.scala @@ -241,6 +241,14 @@ class DenseVector[@spec(Double, Int, Float, Long) V]( def toScalaVector()(implicit cm: ClassTag[V]): scala.Vector[V] = this.toArray.toVector // + def asCscRow(implicit man: ClassTag[V], zero: Zero[V]): CSCMatrix[V] = { + if (length == 0) + CSCMatrix.zeros[V](1, length) + else { + CSCMatrix.create(1, length, data) + } + } + @throws(classOf[ObjectStreamException]) protected def writeReplace(): Object = { new DenseVector.SerializedForm(data, offset, stride, length) @@ -402,6 +410,23 @@ object DenseVector result } + def fromCSCMatrix[V: ClassTag](csc: CSCMatrix[V]): DenseVector[V] = { + assert(csc.rows == 1) + val indices = new Array[V](csc.cols) + var i = 0 + while (i < csc.cols) { + var j = csc.colPtrs(i) + while (j < csc.colPtrs(i + 1)) { + indices(i) = csc.data(j) + j += 1 + } + i += 1 + } + + DenseVector.create(indices, 0 ,1, csc.cols) + + } + // capabilities implicit def canCreateZerosLike[V: ClassTag: Zero]: CanCreateZerosLike[DenseVector[V], DenseVector[V]] = diff --git a/math/src/test/scala/breeze/linalg/DenseVectorTest.scala b/math/src/test/scala/breeze/linalg/DenseVectorTest.scala index 1f58b819d..ee10e31f2 100644 --- a/math/src/test/scala/breeze/linalg/DenseVectorTest.scala +++ b/math/src/test/scala/breeze/linalg/DenseVectorTest.scala @@ -589,6 +589,23 @@ class DenseVectorTest extends FunSuite with Checkers { assert(fromNew === slice) } } + + test("#715 - transpose DV * CSCMatrix") { + val dv = DenseVector(1,2,3,4) + + val csc = CSCMatrix.zeros[Int](4, 4) + csc(1, 1) = 1 + csc(1, 2) = 2 + csc(2, 1) = 2 + csc(2, 2) = 4 + + val multiplied = dv.t * csc + + val expected = DenseVector(0,8,16,0).t + + assert(multiplied == expected) + } + } abstract class DenseVectorPropertyTestBase[T: ClassTag] extends TensorSpaceTestBase[DenseVector[T], Int, T] {