Skip to content

Commit

Permalink
Merge pull request #23 from zhxiaogg/feature/add_sample_stage
Browse files Browse the repository at this point in the history
WIP: add sample stage
  • Loading branch information
ktoso authored Aug 29, 2016
2 parents c8685c3 + ea0fa9f commit a2602a8
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 0 deletions.
74 changes: 74 additions & 0 deletions src/main/scala/akka/stream/contrib/Sample.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package akka.stream.contrib

import java.util.Random
import java.util.concurrent.ThreadLocalRandom

import akka.stream.stage.{GraphStage, GraphStageLogic, InHandler, OutHandler}
import akka.stream.{Attributes, FlowShape, Inlet, Outlet}


object Sample {
/**
*
* returns every nth elements
*
* @param nth must > 0
* @tparam T
* @return
*/
def apply[T](nth: Int): Sample[T] = Sample[T](() => nth)

/**
*
* randomly sampling on a stream
*
* @param maxStep must > 0, default 1000, the randomly step will be between 1 (inclusive) and maxStep (inclusive)
* @tparam T
* @return
*/
def random[T](maxStep: Int = 1000): Sample[T] = {
require(maxStep > 0, "max step for a random sampling must > 0")
Sample[T](() => ThreadLocalRandom.current().nextInt(maxStep) + 1)
}
}


/**
* supports sampling on stream
*
* @param next a lambda returns next sample position
* @tparam T
*/
case class Sample[T](next: () => Int) extends GraphStage[FlowShape[T, T]] {
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
var step = getNextStep()
var counter = 0

def onPull(): Unit = {
pull(in)
}

def onPush(): Unit = {
counter += 1
if (counter >= step) {
counter = 0
step = getNextStep()
push(out, grab(in))
} else {
pull(in)
}
}

private def getNextStep(): Long = {
val nextStep = next()
require(nextStep > 0, s"sampling step should be a positive value: ${nextStep}")
nextStep
}

setHandlers(in, out, this)
}

val in = Inlet[T]("Sample-in")
val out = Outlet[T]("Sample-out")
override val shape = FlowShape(in, out)
}
62 changes: 62 additions & 0 deletions src/test/scala/akka/stream/contrib/SampleSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package akka.stream.contrib

import akka.actor.ActorSystem
import akka.stream._
import akka.stream.scaladsl.{Sink, Source}
import org.scalatest.{Matchers, WordSpec}

import scala.concurrent.Await
import scala.concurrent.duration._

class SampleSpec extends WordSpec with Matchers {
private implicit val system = ActorSystem("SampleTest")
private implicit val materializer = ActorMaterializer()

"Sample Stage" should {
"returns every Nth element in stream" in {
val list = 1 to 1000
val source = Source.fromIterator[Int](() => list.toIterator)

for (n <- 1 to 100) {
val future = source.via(Sample(n)).runWith(Sink.seq)
val expected = list.filter(_ % n == 0).toList

Await.result(future, 3 seconds) should ===(expected)
}
}

"returns elements randomly" in {
// a fake random, increase by 1 for every invocation result
var num = 0
val mockRandom = () => {
num += 1
num
}

val future = Source.fromIterator[Int](() => (1 to 10).toIterator)
.via(Sample(mockRandom))
.runWith(Sink.seq)

Await.result(future, 3 seconds) should ===((1 :: 3 :: 6 :: 10 :: Nil))
}

"throws exception when next step <= 0" in {
intercept[IllegalArgumentException] {
Await.result(Source.empty.via(Sample(() => 0)).runWith(Sink.seq), 3 seconds)
}

intercept[IllegalArgumentException] {
Await.result(Source.empty.via(Sample(() => -1)).runWith(Sink.seq), 3 seconds)
}
}

"throws exceptions when max random step <= 0" in {
intercept[IllegalArgumentException] {
Await.result(Source.empty.via(Sample.random(0)).runWith(Sink.seq), 3 seconds)
}
}
}
}



0 comments on commit a2602a8

Please sign in to comment.