Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perf: Poseidon2 GKR circuit #1410

Open
wants to merge 37 commits into
base: master
Choose a base branch
from
Open

Conversation

Tabaie
Copy link
Contributor

@Tabaie Tabaie commented Feb 4, 2025

@Tabaie Tabaie marked this pull request as ready for review February 12, 2025 15:14
Copy link
Collaborator

@ivokub ivokub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I merged the current Poseidon2 PR and fixed the conflicts. When that PR is done with the issues addressed then we can merge once again.

I checked the GKR gates and they seem correct. However, I'm not able to fully follow how the full permutation is implemented, the gate registrations etc. are inlined and Poseidon2 permutation steps are unrolled.

See the comments and I would come back to review the PR once again, maybe we can have a look at the unrolled methods in a call.

Currently the PR is only hardcoded for BLS12-377, which I think for start is fine as it is really difficult to support different native elements. I actually encountered the same issue when implementing non-native sumcheck where I implemented generic arithEngine interface which can be used to perform operations on different types (see https://github.com/Consensys/gnark/blob/master/std/recursion/sumcheck/arithengine.go), but it was exploratory approahc.

See also the suggested edit below to resolve one TODO in the test file.

@@ -44,7 +44,6 @@ func (api *API) Mul(i1, i2 constraint.GkrVariable, in ...constraint.GkrVariable)
return api.namedGate2PlusIn("mul", i1, i2, in...)
}

// TODO @Tabaie This can be useful
func (api *API) Println(a ...constraint.GkrVariable) {
func (api *API) Println(...constraint.GkrVariable) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would remove unimplemented function.

"math/big"
)

// SolveAll IS A TEST FUNCTION USED ONLY TO DEBUG a GKR circuit
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it is only a test function not supposed to be publicly exposed, then I would rename the package of this file to gkr_test. Then it won't be publicly exposed and only available in test files which also set package gkr_test. This also requires that this is a test file, i.e. should have a suffix _test.go. So I would recommend renaming it to debugging_test.go.

Permutation
Ins []frontend.Variable
Outs []frontend.Variable
plainHasher hash.Hash
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plainHasher not used?

@@ -0,0 +1,388 @@
package poseidon2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would refactor the GKR Poseidon2 implementation into a separate package std/permutation/poseidon2/gkr.

// SHA256 hash of the hash parameters - as a unique identifier
// Note that the identifier is only unique with respect to the size parameters
// t, d, rF, rP
func (h *Permutation) hash(curve ecc.ID) []byte {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Method doesn't seem to be used? But I see the point of it, however, instead of having separate implementation of the unique identifier have a look at https://github.com/Consensys/gnark-crypto/blob/master/ecc/bls12-377/fr/poseidon2/poseidon2.go#L73-L75. I think it would be nice to have consistent identifiers across.

Currently it is not nicely accessible though as the Parameters structs in the gnark-crypto package are typed per field. However, we could implement some getters to get the parameter values (numbers of rounds, degree and width) and then in the gnark poseidon2 package we could have generic interface

type Parameters interface {
    String() string
    GetDegree() int
    GetWidth() int
    GetNbFullRounds() int
    GetNbPartialRounds() int
}

gkr.Gates["pow4"] = pow4Gate{}
gkr.Gates["pow4Times"] = pow4TimesGate{}

sBox := func(round int, u constraint.GkrVariable) constraint.GkrVariable {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

round argument not used. Can omit

return &res
}

func (p *GkrPermutations) finalize(api frontend.API) error {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method is currently quite difficult to follow - it inlines hash registration, gate registration (and all the helpers to derive names). I would separate the gate registration into a separate function (which can be registered automatically at init). I would define them separately at package-level init() function.

Then, it would be nice if we would have gates collected as in methods of the Permutation methods a la matMulExternal (similar to matMulExternalInPlace but doesn't do in place but rather returns the result) etc.

}

const (
rF = 6
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even though currently it supports only one set of parameters and it would be difficult to generalize (as we need to hint permutations of different curves), then I would avoid hardcoding the parameters here. It makes any changes very complicated, needing to update parameters here. We can instead provide the parameters in the NewGkrPermutations and then refer to the variables instead.

@ivokub
Copy link
Collaborator

ivokub commented Feb 13, 2025

Suggested edit:

diff --git a/std/gkr/api_test.go b/std/gkr/api_test.go
index 3519f8e1..55dad8fb 100644
--- a/std/gkr/api_test.go
+++ b/std/gkr/api_test.go
@@ -16,8 +16,6 @@ import (
 	bw6761 "github.com/consensys/gnark/constraint/bw6-761"
 	"github.com/consensys/gnark/test"
 
-	"github.com/consensys/gnark-crypto/kzg"
-	"github.com/consensys/gnark/backend/plonk"
 	bn254 "github.com/consensys/gnark/constraint/bn254"
 	"github.com/stretchr/testify/require"
 
@@ -26,15 +24,12 @@ import (
 	"github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr"
 	bn254MiMC "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc"
 	"github.com/consensys/gnark/backend/groth16"
-	"github.com/consensys/gnark/backend/witness"
 	"github.com/consensys/gnark/constraint"
 	"github.com/consensys/gnark/frontend"
 	"github.com/consensys/gnark/frontend/cs/r1cs"
-	"github.com/consensys/gnark/frontend/cs/scs"
 	stdHash "github.com/consensys/gnark/std/hash"
 	"github.com/consensys/gnark/std/hash/mimc"
 	test_vector_utils "github.com/consensys/gnark/std/internal/test_vectors_utils"
-	"github.com/consensys/gnark/test/unsafekzg"
 )
 
 // compressThreshold --> if linear expressions are larger than this, the frontend will introduce
@@ -69,6 +64,7 @@ func (c *doubleNoDependencyCircuit) Define(api frontend.API) error {
 }
 
 func TestDoubleNoDependencyCircuit(t *testing.T) {
+	assert := test.NewAssert(t)
 
 	xValuess := [][]frontend.Variable{
 		{1, 1},
@@ -77,12 +73,13 @@ func TestDoubleNoDependencyCircuit(t *testing.T) {
 
 	hashes := []string{"-1", "-20"}
 
-	for _, xValues := range xValuess {
+	for i, xValues := range xValuess {
 		for _, hashName := range hashes {
 			assignment := doubleNoDependencyCircuit{X: xValues}
 			circuit := doubleNoDependencyCircuit{X: make([]frontend.Variable, len(xValues)), hashName: hashName}
-
-			test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment))
+			assert.Run(func(assert *test.Assert) {
+				assert.CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BN254))
+			}, fmt.Sprintf("xValue=%d/hash=%s", i, hashName))
 
 		}
 	}
@@ -115,6 +112,7 @@ func (c *sqNoDependencyCircuit) Define(api frontend.API) error {
 }
 
 func TestSqNoDependencyCircuit(t *testing.T) {
+	assert := test.NewAssert(t)
 
 	xValuess := [][]frontend.Variable{
 		{1, 1},
@@ -123,12 +121,13 @@ func TestSqNoDependencyCircuit(t *testing.T) {
 
 	hashes := []string{"-1", "-20"}
 
-	for _, xValues := range xValuess {
+	for i, xValues := range xValuess {
 		for _, hashName := range hashes {
 			assignment := sqNoDependencyCircuit{X: xValues}
 			circuit := sqNoDependencyCircuit{X: make([]frontend.Variable, len(xValues)), hashName: hashName}
-			testGroth16(t, &circuit, &assignment)
-			testPlonk(t, &circuit, &assignment)
+			assert.Run(func(assert *test.Assert) {
+				assert.CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BN254))
+			}, fmt.Sprintf("xValues=%d/hash=%s", i, hashName))
 		}
 	}
 }
@@ -168,6 +167,7 @@ func (c *mulNoDependencyCircuit) Define(api frontend.API) error {
 }
 
 func TestMulNoDependency(t *testing.T) {
+	assert := test.NewAssert(t)
 	xValuess := [][]frontend.Variable{
 		{1, 2},
 	}
@@ -189,9 +189,9 @@ func TestMulNoDependency(t *testing.T) {
 				Y:        make([]frontend.Variable, len(yValuess[i])),
 				hashName: hashName,
 			}
-
-			testGroth16(t, &circuit, &assignment)
-			testPlonk(t, &circuit, &assignment)
+			assert.Run(func(assert *test.Assert) {
+				assert.CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BN254))
+			}, fmt.Sprintf("xValues=%d/hash=%s", i, hashName))
 		}
 	}
 }
@@ -240,14 +240,13 @@ func (c *mulWithDependencyCircuit) Define(api frontend.API) error {
 }
 
 func TestSolveMulWithDependency(t *testing.T) {
+	assert := test.NewAssert(t)
 	assignment := mulWithDependencyCircuit{
 		XLast: 1,
 		Y:     []frontend.Variable{3, 2},
 	}
 	circuit := mulWithDependencyCircuit{Y: make([]frontend.Variable, len(assignment.Y)), hashName: "-20"}
-
-	testGroth16(t, &circuit, &assignment)
-	testPlonk(t, &circuit, &assignment)
+	assert.CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BN254))
 }
 
 func TestApiMul(t *testing.T) {
@@ -387,54 +386,6 @@ func (c *benchMiMCMerkleTreeCircuit) Define(api frontend.API) error {
 	return solution.Verify("-20", challenge)
 }
 
-// TODO @Tabaie just try using IsSolved instead?
-func testGroth16(t *testing.T, circuit, assignment frontend.Circuit) {
-	cs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, circuit, frontend.WithCompressThreshold(compressThreshold))
-	require.NoError(t, err)
-	var (
-		fullWitness   witness.Witness
-		publicWitness witness.Witness
-		pk            groth16.ProvingKey
-		vk            groth16.VerifyingKey
-		proof         groth16.Proof
-	)
-	fullWitness, err = frontend.NewWitness(assignment, ecc.BN254.ScalarField())
-	require.NoError(t, err)
-	publicWitness, err = fullWitness.Public()
-	require.NoError(t, err)
-	pk, vk, err = groth16.Setup(cs)
-	require.NoError(t, err)
-	proof, err = groth16.Prove(cs, pk, fullWitness)
-	require.NoError(t, err)
-	err = groth16.Verify(proof, vk, publicWitness)
-	require.NoError(t, err)
-}
-
-func testPlonk(t *testing.T, circuit, assignment frontend.Circuit) {
-	cs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, circuit, frontend.WithCompressThreshold(compressThreshold))
-	require.NoError(t, err)
-	var (
-		fullWitness   witness.Witness
-		publicWitness witness.Witness
-		pk            plonk.ProvingKey
-		vk            plonk.VerifyingKey
-		proof         plonk.Proof
-		kzgSrs        kzg.SRS
-	)
-	fullWitness, err = frontend.NewWitness(assignment, ecc.BN254.ScalarField())
-	require.NoError(t, err)
-	publicWitness, err = fullWitness.Public()
-	require.NoError(t, err)
-	kzgSrs, srsLagrange, err := unsafekzg.NewSRS(cs)
-	require.NoError(t, err)
-	pk, vk, err = plonk.Setup(cs, kzgSrs, srsLagrange)
-	require.NoError(t, err)
-	proof, err = plonk.Prove(cs, pk, fullWitness)
-	require.NoError(t, err)
-	err = plonk.Verify(proof, vk, publicWitness)
-	require.NoError(t, err)
-}
-
 func registerMiMC() {
 	bn254.RegisterHashBuilder("mimc", func() hash.Hash {
 		return bn254MiMC.NewMiMC()
@@ -646,19 +597,21 @@ func BenchmarkMiMCNoGkrFullDepthSolve(b *testing.B) {
 }
 
 func TestMiMCFullDepthNoDepSolve(t *testing.T) {
+	assert := test.NewAssert(t)
 	registerMiMC()
 	for i := 0; i < 100; i++ {
 		circuit, assignment := mimcNoDepCircuits(5, 1<<2, "-20")
-		testGroth16(t, circuit, assignment)
-		testPlonk(t, circuit, assignment)
+		assert.Run(func(assert *test.Assert) {
+			assert.CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithCurves(ecc.BN254))
+		}, fmt.Sprintf("i=%d", i))
 	}
 }
 
 func TestMiMCFullDepthNoDepSolveWithMiMCHash(t *testing.T) {
+	assert := test.NewAssert(t)
 	registerMiMC()
 	circuit, assignment := mimcNoDepCircuits(5, 1<<2, "mimc")
-	testGroth16(t, circuit, assignment)
-	testPlonk(t, circuit, assignment)
+	assert.CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithCurves(ecc.BN254))
 }
 
 func mimcNoGkrCircuits(mimcDepth, nbInstances int) (circuit, assignment frontend.Circuit) {
diff --git a/std/gkr/testing.go b/std/gkr/testing.go
index 5f824cac..5a3f7337 100644
--- a/std/gkr/testing.go
+++ b/std/gkr/testing.go
@@ -3,12 +3,13 @@ package gkr
 import (
 	"errors"
 	"fmt"
+	"math/big"
+
 	"github.com/consensys/gnark-crypto/ecc"
 	frBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
 	gkrBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/gkr"
 	hint "github.com/consensys/gnark/constraint/solver"
 	"github.com/consensys/gnark/frontend"
-	"math/big"
 )
 
 // SolveAll IS A TEST FUNCTION USED ONLY TO DEBUG a GKR circuit

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants