-
Notifications
You must be signed in to change notification settings - Fork 77
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #23 from zhxiaogg/feature/add_sample_stage
WIP: add sample stage
- Loading branch information
Showing
2 changed files
with
136 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
package akka.stream.contrib | ||
|
||
import java.util.Random | ||
import java.util.concurrent.ThreadLocalRandom | ||
|
||
import akka.stream.stage.{GraphStage, GraphStageLogic, InHandler, OutHandler} | ||
import akka.stream.{Attributes, FlowShape, Inlet, Outlet} | ||
|
||
|
||
object Sample { | ||
/** | ||
* | ||
* returns every nth elements | ||
* | ||
* @param nth must > 0 | ||
* @tparam T | ||
* @return | ||
*/ | ||
def apply[T](nth: Int): Sample[T] = Sample[T](() => nth) | ||
|
||
/** | ||
* | ||
* randomly sampling on a stream | ||
* | ||
* @param maxStep must > 0, default 1000, the randomly step will be between 1 (inclusive) and maxStep (inclusive) | ||
* @tparam T | ||
* @return | ||
*/ | ||
def random[T](maxStep: Int = 1000): Sample[T] = { | ||
require(maxStep > 0, "max step for a random sampling must > 0") | ||
Sample[T](() => ThreadLocalRandom.current().nextInt(maxStep) + 1) | ||
} | ||
} | ||
|
||
|
||
/** | ||
* supports sampling on stream | ||
* | ||
* @param next a lambda returns next sample position | ||
* @tparam T | ||
*/ | ||
case class Sample[T](next: () => Int) extends GraphStage[FlowShape[T, T]] { | ||
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler { | ||
var step = getNextStep() | ||
var counter = 0 | ||
|
||
def onPull(): Unit = { | ||
pull(in) | ||
} | ||
|
||
def onPush(): Unit = { | ||
counter += 1 | ||
if (counter >= step) { | ||
counter = 0 | ||
step = getNextStep() | ||
push(out, grab(in)) | ||
} else { | ||
pull(in) | ||
} | ||
} | ||
|
||
private def getNextStep(): Long = { | ||
val nextStep = next() | ||
require(nextStep > 0, s"sampling step should be a positive value: ${nextStep}") | ||
nextStep | ||
} | ||
|
||
setHandlers(in, out, this) | ||
} | ||
|
||
val in = Inlet[T]("Sample-in") | ||
val out = Outlet[T]("Sample-out") | ||
override val shape = FlowShape(in, out) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
package akka.stream.contrib | ||
|
||
import akka.actor.ActorSystem | ||
import akka.stream._ | ||
import akka.stream.scaladsl.{Sink, Source} | ||
import org.scalatest.{Matchers, WordSpec} | ||
|
||
import scala.concurrent.Await | ||
import scala.concurrent.duration._ | ||
|
||
class SampleSpec extends WordSpec with Matchers { | ||
private implicit val system = ActorSystem("SampleTest") | ||
private implicit val materializer = ActorMaterializer() | ||
|
||
"Sample Stage" should { | ||
"returns every Nth element in stream" in { | ||
val list = 1 to 1000 | ||
val source = Source.fromIterator[Int](() => list.toIterator) | ||
|
||
for (n <- 1 to 100) { | ||
val future = source.via(Sample(n)).runWith(Sink.seq) | ||
val expected = list.filter(_ % n == 0).toList | ||
|
||
Await.result(future, 3 seconds) should ===(expected) | ||
} | ||
} | ||
|
||
"returns elements randomly" in { | ||
// a fake random, increase by 1 for every invocation result | ||
var num = 0 | ||
val mockRandom = () => { | ||
num += 1 | ||
num | ||
} | ||
|
||
val future = Source.fromIterator[Int](() => (1 to 10).toIterator) | ||
.via(Sample(mockRandom)) | ||
.runWith(Sink.seq) | ||
|
||
Await.result(future, 3 seconds) should ===((1 :: 3 :: 6 :: 10 :: Nil)) | ||
} | ||
|
||
"throws exception when next step <= 0" in { | ||
intercept[IllegalArgumentException] { | ||
Await.result(Source.empty.via(Sample(() => 0)).runWith(Sink.seq), 3 seconds) | ||
} | ||
|
||
intercept[IllegalArgumentException] { | ||
Await.result(Source.empty.via(Sample(() => -1)).runWith(Sink.seq), 3 seconds) | ||
} | ||
} | ||
|
||
"throws exceptions when max random step <= 0" in { | ||
intercept[IllegalArgumentException] { | ||
Await.result(Source.empty.via(Sample.random(0)).runWith(Sink.seq), 3 seconds) | ||
} | ||
} | ||
} | ||
} | ||
|
||
|
||
|