diff --git a/lib/ThroatQueueFunction.js b/lib/ThroatQueueFunction.js index 87a2397..c4fe662 100644 --- a/lib/ThroatQueueFunction.js +++ b/lib/ThroatQueueFunction.js @@ -4,7 +4,7 @@ const NextTickPromise = Q() function ThroatQueueFunction(n = 5){ const running = [] - const ret = async function(what){ + let ret = async function(cancellationState, what){ if(what === null){ if(running.length === 0) return running await Promise.all(running) @@ -13,7 +13,7 @@ function ThroatQueueFunction(n = 5){ // This shouldn't happen if we correctly await on the throat while(running.length >= n){ - await Q.safeRace(running) + await cancellationState.promiseWrap(Q.cancelledRace(running)) } // call fn @@ -23,7 +23,7 @@ function ThroatQueueFunction(n = 5){ if(typeof what === 'function'){ what = what() } - return await what + return await cancellationState.promiseWrap(what) } finally { for(let i = 0 ; i < running.length; i++){ if(running[i].id === idObj){ @@ -38,14 +38,19 @@ function ThroatQueueFunction(n = 5){ const r = rFn() r.id = idObj r.fn = what + r.cancel = ()=>{ + cancellationState.cancel() + } running.push(r) await r while(running.length >= n){ - await Q.safeRace(running) + await cancellationState.promiseWrap(Q.cancelledRace(running)) } } + ret = Q.canceller(ret) + ret.running = running return ret diff --git a/test/throat_test.js b/test/throat_test.js index c2c4e9d..c0166d1 100644 --- a/test/throat_test.js +++ b/test/throat_test.js @@ -78,6 +78,65 @@ describe('ThroatQueueFunction', function(){ await tf(null) expect(count).to.be.eql(10) + }) + it('should run 5 times only if cancelled from fn', async() => { + const tf = ThroatQueueFunction(5) + + let count = 0 + const deferred = Q.defer() + const p = [] + for(let i = 0; i < 10; i++){ + p.push(tf(()=>{ + count++ + return deferred.promise + })) + } + + await Q.delay(10) + + + expect(count).to.be.eql(5) + + for(const pp of p){ + pp.cancel() + } + + deferred.resolve(true) + + await tf(null) + expect(count).to.be.eql(5) + + }) + it('should run 5 times only if cancelled from running', async() => { + const tf = ThroatQueueFunction(5) + + let count = 0 + const deferred = Q.defer() + const p = [] + for(let i = 0; i < 10; i++){ + p.push(tf(Q.canceller(async(cancellationState)=>{ + await cancellationState.promiseWrap(deferred.promise) + count++ + }))) + } + + await Q.delay(10) + + + expect(tf.running.length).to.be.eql(5) + + for(const pp of tf.running){ + pp.cancel() + } + + deferred.resolve(true) + + + expect(tf.running.length).to.be.eql(5) + + await tf(null) + expect(count).to.be.eql(5) + }) it('should capture stack trace', async() => { async function testFn(){