Skip to content

Commit

Permalink
gt: parallel torus multiexp
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Nov 28, 2024
1 parent 475cc7d commit 823b9f4
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 69 deletions.
1 change: 1 addition & 0 deletions constantine/math/pairings/gt_multiexp.nim
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ template withTorus[exponentsBits: static int, GT](
var r_torus {.noInit.}: T2Prj[F]
multiExpProc(r_torus, elemsTorus, expos, len, c)
r.fromTorus2_vartime(r_torus)
freeHeap(elemsTorus)

# Combined accel
# -----------------------------------------------------------------------------------------------------------------------
Expand Down
248 changes: 184 additions & 64 deletions constantine/math/pairings/gt_multiexp_parallel.nim
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import constantine/named/algebras,
constantine/math/arithmetic,
constantine/named/zoo_endomorphisms,
constantine/platforms/abstractions,
./cyclotomic_subgroups,
./cyclotomic_subgroups, ./gt_prj,
constantine/threadpool

import ./gt_multiexp {.all.}
Expand All @@ -27,21 +27,21 @@ import ./gt_multiexp {.all.}
# #
# ########################################################### #

proc bucketAccumReduce_withInit[bits: static int, GT](
windowProd: ptr GT,
buckets: ptr GT or ptr UncheckedArray[GT],
proc bucketAccumReduce_withInit[bits: static int, GtAcc, GtElt](
windowProd: ptr GtAcc,
buckets: ptr GtAcc or ptr UncheckedArray[GtAcc],
bitIndex: int, miniMultiExpKind: static MiniMultiExpKind, c: static int,
elems: ptr UncheckedArray[GT], expos: ptr UncheckedArray[BigInt[bits]], N: int) =
elems: ptr UncheckedArray[GtElt], expos: ptr UncheckedArray[BigInt[bits]], N: int) =
const numBuckets = 1 shl (c-1)
let buckets = cast[ptr UncheckedArray[GT]](buckets)
let buckets = cast[ptr UncheckedArray[GtAcc]](buckets)
for i in 0 ..< numBuckets:
buckets[i].setNeutral()
bucketAccumReduce(windowProd[], buckets, bitIndex, miniMultiExpKind, c, elems, expos, N)

proc multiexpImpl_vartime_parallel[bits: static int, GT](
proc multiexpImpl_vartime_parallel[bits: static int, GtAcc, GtElt](
tp: Threadpool,
r: ptr GT,
elems: ptr UncheckedArray[GT], expos: ptr UncheckedArray[BigInt[bits]],
r: ptr GtAcc,
elems: ptr UncheckedArray[GtElt], expos: ptr UncheckedArray[BigInt[bits]],
N: int, c: static int) =

# Prologue
Expand All @@ -53,10 +53,10 @@ proc multiexpImpl_vartime_parallel[bits: static int, GT](
# Instead of storing the result in futures, risking them being scattered in memory
# we store them in a contiguous array, and the synchronizing future just returns a bool.
# top window is done on this thread
let miniMultiExpsResults = allocHeapArray(GT, numFullWindows)
let miniMultiExpsResults = allocHeapArray(GtAcc, numFullWindows)
let miniMultiExpsReady = allocStackArray(FlowVar[bool], numFullWindows)

let bucketsMatrix = allocHeapArray(GT, numBuckets*numWindows)
let bucketsMatrix = allocHeapArray(GtAcc, numBuckets*numWindows)

# Algorithm
# ---------
Expand All @@ -78,32 +78,22 @@ proc multiexpImpl_vartime_parallel[bits: static int, GT](
# Last window is done sync on this thread, directly initializing r
const excess = bits mod c
const top = bits-excess

when top != 0:
when excess != 0:
bucketAccumReduce_withInit(
r,
bucketsMatrix[numFullWindows*numBuckets].addr,
bitIndex = top, kTopWindow, c,
elems, expos, N)
else:
r[].setNeutral()

# 3. Final reduction, r initialized to what would be miniMSMsReady[numWindows-1]
when excess != 0:
for w in countdown(numWindows-2, 0):
for _ in 0 ..< c:
r[].cyclotomic_square()
discard sync miniMultiExpsReady[w]
r[] ~*= miniMultiExpsResults[w]
elif numWindows >= 2:
discard sync miniMultiExpsReady[numWindows-2]
r[] = miniMultiExpsResults[numWindows-2]
for w in countdown(numWindows-3, 0):
for _ in 0 ..< c:
r[].cyclotomic_square()
discard sync miniMultiExpsReady[w]
r[] ~*= miniMultiExpsResults[w]
const msmKind = if top == 0: kBottomWindow
elif excess == 0: kFullWindow
else: kTopWindow

bucketAccumReduce_withInit(
r,
bucketsMatrix[numFullWindows*numBuckets].addr,
bitIndex = top, msmKind, c,
elems, expos, N)

# 3. Final reduction
for w in countdown(numFullWindows-1, 0):
for _ in 0 ..< c:
r[].cyclotomic_square()
discard sync miniMultiExpsReady[w]
r[] ~*= miniMultiExpsResults[w]

# Cleanup
# -------
Expand Down Expand Up @@ -170,14 +160,119 @@ template withEndo[exponentsBits: static int, GT](
else:
multiExpProc(tp, r, elems, expos, N, c)

# Torus acceleration
# -----------------------------------------------------------------------------------------------------------------------

template withTorus[exponentsBits: static int, GT](
multiExpProc: untyped,
tp: Threadpool,
r: ptr GT,
elems: ptr UncheckedArray[GT],
expos: ptr UncheckedArray[BigInt[exponentsBits]],
len: int, c: static int) =
static: doAssert Gt is QuadraticExt, "GT was: " & $Gt
type F = typeof(elems[0].c0)
var elemsTorus = allocHeapArrayAligned(T2Aff[F], len, alignment = 64)
# TODO: macro symbol resolution bug
# syncScope:
# tp.parallelFor i in 0 ..< N:
# captures: {elems, elemsTorus}
# # TODO: Parallel batch conversion
# elemsTorus.fromGT_vartime(elems[i])
elemsTorus.toOpenArray(0, len-1).batchFromGT_vartime(
elems.toOpenArray(0, len-1)
)
var r_torus {.noInit.}: T2Prj[F]
multiExpProc(tp, r_torus.addr, elemsTorus, expos, len, c)
r[].fromTorus2_vartime(r_torus)
freeHeap(elemsTorus)

# Combined accel
# -----------------------------------------------------------------------------------------------------------------------

# Endomorphism acceleration on a torus can be implemented through either of the following approaches:
# - First convert to Torus then apply endomorphism acceleration
# - or apply endomorphism acceleration then convert to Torus
#
# The first approach minimizes memory as we use a compressed torus representation and is easier to compose (no withEndoTorus)
# the second approach reuses Constantine's Frobenius implementation.
# It's unsure which one is more efficient, but difference is dwarfed by the rest of the compute.

proc applyEndoTorus_parallel[bits: static int, GT](
tp: Threadpool,
elems: ptr UncheckedArray[GT],
expos: ptr UncheckedArray[BigInt[bits]],
N: int): auto =
## Decompose (elems, expos) into mini-scalars
## and apply Torus conversion
## Returns a new triplet (endoTorusElems, endoTorusExpos, N)
## endoTorusElems and endoTorusExpos MUST be freed afterwards

const M = when Gt.Name.getEmbeddingDegree() == 6: 2
elif Gt.Name.getEmbeddingDegree() == 12: 4
else: {.error: "Unconfigured".}

const L = Fr[Gt.Name].bits().computeEndoRecodedLength(M)
let splitExpos = allocHeapArray(array[M, BigInt[L]], N)
let endoBasis = allocHeapArray(array[M, GT], N)

type F = typeof(elems[0].c0)
let endoTorusBasis = allocHeapArray(array[M, T2Aff[F]], N)

syncScope:
tp.parallelFor i in 0 ..< N:
captures: {elems, expos, splitExpos, endoBasis, endoTorusBasis}

var negateElems {.noinit.}: array[M, SecretBool]
splitExpos[i].decomposeEndo(negateElems, expos[i], Fr[Gt.Name].bits(), Gt.Name, G2) # 𝔾ₜ has same decomposition as 𝔾₂
if negateElems[0].bool:
endoBasis[i][0].cyclotomic_inv(elems[i])
else:
endoBasis[i][0] = elems[i]

cast[ptr array[M-1, GT]](endoBasis[i][1].addr)[].computeEndomorphisms(elems[i])
for m in 1 ..< M:
if negateElems[m].bool:
endoBasis[i][m].cyclotomic_inv()

# TODO: we batch-torus convert M by M
# but we could parallel batch convert over the whole range
endoTorusBasis[i].batchFromGT_vartime(endoBasis[i])

let endoTorusElems = cast[ptr UncheckedArray[GT]](endoTorusBasis)
let endoExpos = cast[ptr UncheckedArray[BigInt[L]]](splitExpos)
freeHeapAligned(endoBasis)

return (endoTorusElems, endoExpos, M*N)

template withEndoTorus[exponentsBits: static int, GT](
multiExpProc: untyped,
tp: Threadpool,
r: ptr GT,
elems: ptr UncheckedArray[GT],
expos: ptr UncheckedArray[BigInt[exponentsBits]],
N: int, c: static int) =
when Gt.Name.hasEndomorphismAcceleration() and
EndomorphismThreshold <= exponentsBits and
exponentsBits <= Fr[Gt.Name].bits():
let (endoTorusElems, endoExpos, endoN) = applyEndoTorus_parallel(tp, elems, expos, N)
# Given that bits and N changed, we are able to use a bigger `c`
# TODO: bench
multiExpProc(tp, r, endoTorusElems, endoExpos, endoN, c)
freeHeap(endoTorusElems)
freeHeap(endoExpos)
else:
withTorus(multiExpProc, r, elems, expos, N, c)

# Algorithm selection
# -----------------------------------------------------------------------------------------------------------------------

proc multiexp_dispatch_vartime_parallel[bits: static int, GT](
tp: Threadpool,
r: ptr GT,
elems: ptr UncheckedArray[GT],
expos: ptr UncheckedArray[BigInt[bits]], N: int) =
expos: ptr UncheckedArray[BigInt[bits]], N: int,
useTorus: static bool) =
## Multiexponentiation:
## r <- g₀^a₀ + g₁^a₁ + ... + gₙ^aₙ
let c = bestBucketBitSize(N, bits, useSignedBuckets = true, useManualTuning = true)
Expand All @@ -186,53 +281,77 @@ proc multiexp_dispatch_vartime_parallel[bits: static int, GT](
# we are able to use a bigger `c`
# TODO: benchmark

case c
of 2: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 2)
of 3: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 3)
of 4: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 4)
of 5: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 5)
of 6: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 6)
of 7: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 7)
of 8: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 8)
of 9: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 9)
of 10: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 10)
of 11: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 11)
of 12: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 12)
of 13: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 13)
of 14: multiExpImpl_vartime_parallel(tp, r, elems, expos, N, c = 14)
of 15: multiExpImpl_vartime_parallel(tp, r, elems, expos, N, c = 15)

of 16..17: multiExpImpl_vartime_parallel(tp, r, elems, expos, N, c = 16)
when useTorus:
case c
of 2: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 2)
of 3: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 3)
of 4: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 4)
of 5: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 5)
of 6: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 6)
of 7: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 7)
of 8: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 8)
of 9: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 9)
of 10: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 10)
of 11: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 11)
of 12: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 12)
of 13: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 13)
of 14: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 14)
of 15: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 15)

of 16..17: withTorus(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 16)
else:
unreachable()
else:
unreachable()
case c
of 2: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 2)
of 3: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 3)
of 4: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 4)
of 5: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 5)
of 6: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 6)
of 7: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 7)
of 8: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 8)
of 9: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 9)
of 10: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 10)
of 11: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 11)
of 12: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 12)
of 13: withEndo(multiExpImpl_vartime_parallel, tp, r, elems, expos, N, c = 13)
of 14: multiExpImpl_vartime_parallel(tp, r, elems, expos, N, c = 14)
of 15: multiExpImpl_vartime_parallel(tp, r, elems, expos, N, c = 15)

of 16..17: multiExpImpl_vartime_parallel(tp, r, elems, expos, N, c = 16)
else:
unreachable()

proc multiExp_vartime_parallel*[bits: static int, GT](
tp: Threadpool,
r: ptr GT,
elems: ptr UncheckedArray[GT],
expos: ptr UncheckedArray[BigInt[bits]],
len: int) {.meter, inline.} =
len: int,
useTorus: static bool = false) {.meter, inline.} =
## Multiexponentiation:
## r <- g₀^a₀ + g₁^a₁ + ... + gₙ^aₙ
tp.multiExp_dispatch_vartime_parallel(r, elems, expos, len)
tp.multiExp_dispatch_vartime_parallel(r, elems, expos, len, useTorus)

proc multiExp_vartime_parallel*[bits: static int, GT](
tp: Threadpool,
r: var GT,
elems: openArray[GT],
expos: openArray[BigInt[bits]]) {.meter, inline.} =
expos: openArray[BigInt[bits]],
useTorus: static bool = false) {.meter, inline.} =
## Multiexponentiation:
## r <- g₀^a₀ + g₁^a₁ + ... + gₙ^aₙ
debug: doAssert elems.len == expos.len
let N = elems.len
tp.multiExp_dispatch_vartime_parallel(r.addr, elems.asUnchecked(), expos.asUnchecked(), N)
tp.multiExp_dispatch_vartime_parallel(r.addr, elems.asUnchecked(), expos.asUnchecked(), N, useTorus)

proc multiExp_vartime_parallel*[F, GT](
tp: Threadpool,
r: ptr GT,
elems: ptr UncheckedArray[GT],
expos: ptr UncheckedArray[F],
len: int) {.meter.} =
len: int,
useTorus: static bool = false) {.meter.} =
## Multiexponentiation:
## r <- g₀^a₀ + g₁^a₁ + ... + gₙ^aₙ
let n = cast[int](len)
Expand All @@ -242,17 +361,18 @@ proc multiExp_vartime_parallel*[F, GT](
tp.parallelFor i in 0 ..< n:
captures: {expos, expos_big}
expos_big[i].fromField(expos[i])
tp.multiExp_vartime_parallel(r, elems, expos_big, n)
tp.multiExp_vartime_parallel(r, elems, expos_big, n, useTorus)

freeHeapAligned(expos_big)

proc multiExp_vartime_parallel*[GT](
tp: Threadpool,
r: var GT,
elems: openArray[GT],
expos: openArray[Fr]) {.meter, inline.} =
expos: openArray[Fr],
useTorus: static bool = false) {.meter, inline.} =
## Multiexponentiation:
## r <- g₀^a₀ + g₁^a₁ + ... + gₙ^aₙ
debug: doAssert elems.len == expos.len
let N = elems.len
tp.multiExp_vartime_parallel(r.addr, elems.asUnchecked(), expos.asUnchecked(), N)
tp.multiExp_vartime_parallel(r.addr, elems.asUnchecked(), expos.asUnchecked(), N, useTorus)
1 change: 1 addition & 0 deletions constantine/math/pairings/gt_prj.nim
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ proc batchFromGT_vartime*[F](dst: var openArray[T2Aff[F]],
## so this is about a ~25% speedup

# TODO: handle neutral element
# TODO: Parallel batch inversion

debug: doAssert dst.len == src.len

Expand Down
2 changes: 1 addition & 1 deletion tests/math_pairings/t_pairing_bls12_381_gt_multiexp.nim
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import
# Test utilities
./t_pairing_template

const numPoints = [1, 2, 8, 16, 128, 256, 1024]
const numPoints = [1, 2, 3, 4, 5, 6, 7, 8, 16, 128, 256, 1024]

runGTmultiexpTests(
# Torus-based cryptography requires quadratic extension
Expand Down
7 changes: 5 additions & 2 deletions tests/parallel/t_pairing_bls12_381_gt_multiexp_parallel.nim
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ import
# Test utilities
./t_pairing_template_parallel

const numPoints = [1, 2, 8, 16, 128, 256, 1024]
const numPoints = [1, 2, 3, 4, 5, 6, 7, 8, 16, 128, 256, 1024]

runGTmultiexp_parallel_Tests(
GT = Fp12[BLS12_381],
# Torus-based cryptography requires quadratic extension
# but by default cubic extensions are faster
# GT = Fp12[BLS12_381],
GT = QuadraticExt[Fp6[BLS12_381]],
numPoints,
Iters = 4)
6 changes: 4 additions & 2 deletions tests/parallel/t_pairing_template_parallel.nim
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,12 @@ proc runGTmultiexp_parallel_Tests*[N: static int](GT: typedesc, num_points: arra
t.gtExp_vartime(elems[i], exponents[i])
naive *= t

var mexp: GT
tp.multiExp_vartime_parallel(mexp, elems, exponents)
var mexp, mexp_torus: GT
tp.multiExp_vartime_parallel(mexp, elems, exponents, useTorus = false)
tp.multiExp_vartime_parallel(mexp_torus, elems, exponents, useTorus = true)

doAssert bool(naive == mexp)
doAssert bool(naive == mexp_torus)

stdout.write '.'

Expand Down

0 comments on commit 823b9f4

Please sign in to comment.