diff --git a/constraint/bls12-377/solver.go b/constraint/bls12-377/solver.go index f79940e3be..f57abfae52 100644 --- a/constraint/bls12-377/solver.go +++ b/constraint/bls12-377/solver.go @@ -51,14 +51,15 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(cs.GkrInfo.GetAssignmentHintID, gkr.GetAssignmentHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/constraint/bls12-381/solver.go b/constraint/bls12-381/solver.go index 1bfa4c5884..67f12ef2aa 100644 --- a/constraint/bls12-381/solver.go +++ b/constraint/bls12-381/solver.go @@ -51,14 +51,15 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(cs.GkrInfo.GetAssignmentHintID, gkr.GetAssignmentHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/constraint/bls24-315/solver.go b/constraint/bls24-315/solver.go index 4f5b72c776..05d4c6f11c 100644 --- a/constraint/bls24-315/solver.go +++ b/constraint/bls24-315/solver.go @@ -51,14 +51,15 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(cs.GkrInfo.GetAssignmentHintID, gkr.GetAssignmentHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/constraint/bls24-317/solver.go b/constraint/bls24-317/solver.go index 9462b5d3e4..29af5c28b2 100644 --- a/constraint/bls24-317/solver.go +++ b/constraint/bls24-317/solver.go @@ -51,14 +51,15 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(cs.GkrInfo.GetAssignmentHintID, gkr.GetAssignmentHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/constraint/bn254/solver.go b/constraint/bn254/solver.go index 4ccc03e7e0..5e9b70c548 100644 --- a/constraint/bn254/solver.go +++ b/constraint/bn254/solver.go @@ -51,14 +51,15 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(cs.GkrInfo.GetAssignmentHintID, gkr.GetAssignmentHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/constraint/bw6-633/solver.go b/constraint/bw6-633/solver.go index 642369791f..7fc43652f6 100644 --- a/constraint/bw6-633/solver.go +++ b/constraint/bw6-633/solver.go @@ -51,14 +51,15 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(cs.GkrInfo.GetAssignmentHintID, gkr.GetAssignmentHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/constraint/bw6-761/solver.go b/constraint/bw6-761/solver.go index a65445eb4c..d226b03a53 100644 --- a/constraint/bw6-761/solver.go +++ b/constraint/bw6-761/solver.go @@ -51,14 +51,15 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(cs.GkrInfo.GetAssignmentHintID, gkr.GetAssignmentHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/constraint/solver/gkrgates/registry.go b/constraint/solver/gkrgates/registry.go index 49610a1789..23b821cf63 100644 --- a/constraint/solver/gkrgates/registry.go +++ b/constraint/solver/gkrgates/registry.go @@ -7,6 +7,7 @@ import ( "runtime" "sync" + "github.com/consensys/gnark" "github.com/consensys/gnark-crypto/ecc" bls12377 "github.com/consensys/gnark/internal/gkr/bls12-377" @@ -99,16 +100,48 @@ func WithCurves(curves ...ecc.ID) registerOption { // - name is a human-readable name for the gate. // - f is the polynomial function defining the gate. // - nbIn is the number of inputs to the gate. -func Register(f gkr.GateFunction, nbIn int, options ...registerOption) error { - s := registerSettings{degree: -1, solvableVar: -1, name: GetDefaultGateName(f), curves: []ecc.ID{ecc.BN254}} +// +// If the gate is already registered, it will return false and no error. +func Register(f gkr.GateFunction, nbIn int, options ...registerOption) (registered bool, err error) { + s := registerSettings{degree: -1, solvableVar: -1, name: GetDefaultGateName(f)} for _, option := range options { option(&s) } - for _, curve := range s.curves { + curvesForTesting := s.curves + allowedCurves := s.curves + if len(curvesForTesting) == 0 { + // no restriction on curves, but only test on BN254 + curvesForTesting = []ecc.ID{ecc.BN254} + allowedCurves = gnark.Curves() + } + + gatesLock.Lock() + defer gatesLock.Unlock() + + if g, ok := gates[s.name]; ok { + // gate already registered + if g.NbIn() != nbIn { + return false, fmt.Errorf("gate \"%s\" already registered with a different number of inputs (%d != %d)", s.name, g.NbIn(), nbIn) + } + + for _, curve := range curvesForTesting { + gateVer, err := NewGateVerifier(curve) + if err != nil { + return false, err + } + if !gateVer.equal(f, g.Evaluate, nbIn) { + return false, fmt.Errorf("mismatch with already registered gate \"%s\" (degree %d) over curve %s", s.name, g.Degree(), curve) + } + } + + return false, nil // gate already registered + } + + for _, curve := range curvesForTesting { gateVer, err := NewGateVerifier(curve) if err != nil { - return err + return false, err } if s.degree == -1 { // find a degree @@ -116,14 +149,13 @@ func Register(f gkr.GateFunction, nbIn int, options ...registerOption) error { panic("invalid settings") } const maxAutoDegreeBound = 32 - var err error if s.degree, err = gateVer.findDegree(f, maxAutoDegreeBound, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", s.name, err) + return false, fmt.Errorf("for gate \"%s\": %v", s.name, err) } } else { if !s.noDegreeVerification { // check that the given degree is correct if err = gateVer.verifyDegree(f, s.degree, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", s.name, err) + return false, fmt.Errorf("for gate \"%s\": %v", s.name, err) } } } @@ -135,16 +167,14 @@ func Register(f gkr.GateFunction, nbIn int, options ...registerOption) error { } else { // solvable variable given if !s.noSolvableVarVerification && !gateVer.isVarSolvable(f, s.solvableVar, nbIn) { - return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, s.name) + return false, fmt.Errorf("cannot verify the solvability of variable %d in gate \"%s\"", s.solvableVar, s.name) } } } - gatesLock.Lock() - defer gatesLock.Unlock() - gates[s.name] = gkrtypes.NewGate(f, nbIn, s.degree, s.solvableVar) - return nil + gates[s.name] = gkrtypes.NewGate(f, nbIn, s.degree, s.solvableVar, allowedCurves) + return true, nil } func Get(name gkr.GateName) *gkrtypes.Gate { @@ -160,6 +190,7 @@ type gateVerifier struct { isAdditive func(f gkr.GateFunction, i int, nbIn int) bool findDegree func(f gkr.GateFunction, max, nbIn int) (int, error) verifyDegree func(f gkr.GateFunction, claimedDegree, nbIn int) error + equal func(f1, f2 gkr.GateFunction, nbIn int) bool } func NewGateVerifier(curve ecc.ID) (*gateVerifier, error) { @@ -172,30 +203,37 @@ func NewGateVerifier(curve ecc.ID) (*gateVerifier, error) { o.isAdditive = bls12377.IsGateFunctionAdditive o.findDegree = bls12377.FindGateFunctionDegree o.verifyDegree = bls12377.VerifyGateFunctionDegree + o.equal = bls12377.EqualGateFunction case ecc.BLS12_381: o.isAdditive = bls12381.IsGateFunctionAdditive o.findDegree = bls12381.FindGateFunctionDegree o.verifyDegree = bls12381.VerifyGateFunctionDegree + o.equal = bls12381.EqualGateFunction case ecc.BLS24_315: o.isAdditive = bls24315.IsGateFunctionAdditive o.findDegree = bls24315.FindGateFunctionDegree o.verifyDegree = bls24315.VerifyGateFunctionDegree + o.equal = bls24315.EqualGateFunction case ecc.BLS24_317: o.isAdditive = bls24317.IsGateFunctionAdditive o.findDegree = bls24317.FindGateFunctionDegree o.verifyDegree = bls24317.VerifyGateFunctionDegree + o.equal = bls24317.EqualGateFunction case ecc.BN254: o.isAdditive = bn254.IsGateFunctionAdditive o.findDegree = bn254.FindGateFunctionDegree o.verifyDegree = bn254.VerifyGateFunctionDegree + o.equal = bn254.EqualGateFunction case ecc.BW6_633: o.isAdditive = bw6633.IsGateFunctionAdditive o.findDegree = bw6633.FindGateFunctionDegree o.verifyDegree = bw6633.VerifyGateFunctionDegree + o.equal = bw6633.EqualGateFunction case ecc.BW6_761: o.isAdditive = bw6761.IsGateFunctionAdditive o.findDegree = bw6761.FindGateFunctionDegree o.verifyDegree = bw6761.VerifyGateFunctionDegree + o.equal = bw6761.EqualGateFunction default: err = fmt.Errorf("unsupported curve %s", curve) } diff --git a/constraint/solver/gkrgates/registry_test.go b/constraint/solver/gkrgates/registry_test.go index ec41888ef3..7fe739b152 100644 --- a/constraint/solver/gkrgates/registry_test.go +++ b/constraint/solver/gkrgates/registry_test.go @@ -11,20 +11,38 @@ import ( "github.com/stretchr/testify/assert" ) -func TestRegisterDegreeDetection(t *testing.T) { +func TestRegister(t *testing.T) { testGate := func(name gkr.GateName, f gkr.GateFunction, nbIn, degree int) { t.Run(string(name), func(t *testing.T) { name = name + "-register-gate-test" - assert.NoError(t, Register(f, nbIn, WithDegree(degree), WithName(name)), "given degree must be accepted") + added, err := Register(f, nbIn, WithDegree(degree), WithName(name+"_given")) + assert.NoError(t, err, "given degree must be accepted") + assert.True(t, added, "registration must succeed for given degree") - assert.Error(t, Register(f, nbIn, WithDegree(degree-1), WithName(name)), "lower degree must be rejected") + registered, err := Register(f, nbIn, WithDegree(degree-1), WithName(name+"_lower")) + assert.Error(t, err, "error must be returned for lower degree") + assert.False(t, registered, "registration must fail for lower degree") - assert.Error(t, Register(f, nbIn, WithDegree(degree+1), WithName(name)), "higher degree must be rejected") + registered, err = Register(f, nbIn, WithDegree(degree+1), WithName(name+"_higher")) + assert.Error(t, err, "error must be returned for higher degree") + assert.False(t, registered, "registration must fail for higher degree") - assert.NoError(t, Register(f, nbIn), "no degree must be accepted") + registered, err = Register(f, nbIn, WithName(name+"_no_degree")) + assert.NoError(t, err, "no error must be returned when no degree is specified") + assert.True(t, registered, "registration must succeed when no degree is specified") - assert.Equal(t, degree, Get(name).Degree(), "degree must be detected correctly") + assert.Equal(t, degree, Get(name+"_no_degree").Degree(), "degree must be detected correctly") + + added, err = Register(f, nbIn, WithDegree(degree), WithName(name+"_given")) + assert.NoError(t, err, "given degree must be accepted") + assert.False(t, added, "gate must not be re-registered") + + added, err = Register(func(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { + return api.Add(f(api, x...), 1) + }, nbIn, WithDegree(degree), WithName(name+"_given")) + assert.Error(t, err, "registering another function under the same name must fail") + assert.False(t, added, "gate must not be re-registered") }) } @@ -47,15 +65,23 @@ func TestRegisterDegreeDetection(t *testing.T) { ) }, 2, 1) - // zero polynomial must not be accepted t.Run("zero", func(t *testing.T) { const gateName gkr.GateName = "zero-register-gate-test" - expectedError := fmt.Errorf("for gate %s: %v", gateName, gkrtypes.ErrZeroFunction) + expectedError := fmt.Errorf("for gate \"%s\": %v", gateName, gkrtypes.ErrZeroFunction).Error() zeroGate := func(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { return api.Sub(x[0], x[0]) } - assert.Equal(t, expectedError, Register(zeroGate, 1, WithName(gateName))) - assert.Equal(t, expectedError, Register(zeroGate, 1, WithName(gateName), WithDegree(2))) + // Attempt to register the zero gate without specifying a degree + registered, err := Register(zeroGate, 1, WithName(gateName)) + assert.Error(t, err, "error must be returned for zero polynomial") + assert.EqualError(t, err, expectedError, "error message must match expected error") + assert.False(t, registered, "registration must fail for zero polynomial") + + // Attempt to register the zero gate with a specified degree + registered, err = Register(zeroGate, 1, WithName(gateName), WithDegree(2)) + assert.Error(t, err, "error must be returned for zero polynomial with degree") + assert.EqualError(t, err, expectedError, "error message must match expected error") + assert.False(t, registered, "registration must fail for zero polynomial with degree") }) } diff --git a/internal/generator/backend/template/gkr/gate_testing.go.tmpl b/internal/generator/backend/template/gkr/gate_testing.go.tmpl index 534b4b01c8..89d1343be6 100644 --- a/internal/generator/backend/template/gkr/gate_testing.go.tmpl +++ b/internal/generator/backend/template/gkr/gate_testing.go.tmpl @@ -155,6 +155,17 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error return nil } +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make({{.FieldPackageName}}.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} + {{- if not .CanUseFFT }} // interpolate fits a polynomial of degree len(X) - 1 = len(Y) - 1 to the points (X[i], Y[i]) // Note that the runtime is O(len(X)³) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 5105b0a33d..3e3881d15f 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -735,7 +735,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod {{ .ElementType }} - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index e1d41e8cb8..29698e0e3b 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -2,8 +2,8 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrtypes" @@ -14,105 +14,110 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit - d.circuit.SetNbUniqueOutputs() +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} - d.assignment = make(WireAssignment, len(d.circuit)) - for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) +type newSolvingDataOption func(*newSolvingDataSettings) + +func WithAssignment(assignment gkrtypes.WireAssignment) newSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment } } -// this module assumes that wire and instance indexes respect dependencies +func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } - } + d := SolvingData{ + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) - } + d.maxNbIn = d.circuit.MaxGateNbIn() - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) + for i := range d.assignment { + d.assignment[i] = make([]fr.Element, nbPaddedInstances) + } + + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Errorf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range s.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + // inline equivalent of RepeatUntilEnd + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } + } - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end + return &d +} + +// this module assumes that wire and instance indexes respect dependencies + + +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} - for _, p := range info.Prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } + + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } - fmt.Println(serializable...) } - setOuts(data.assignment, info.Circuit, outs) - return nil } } @@ -120,7 +125,13 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.RepeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -128,7 +139,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_{{.FieldID}}") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } @@ -136,4 +147,14 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { return proof.SerializeToBigInts(outs) } +} + +// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +{{ print "// e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}}"}} +func (a WireAssignment) RepeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } } \ No newline at end of file diff --git a/internal/generator/backend/template/representations/solver.go.tmpl b/internal/generator/backend/template/representations/solver.go.tmpl index fd685e6e21..ddb0b7428c 100644 --- a/internal/generator/backend/template/representations/solver.go.tmpl +++ b/internal/generator/backend/template/representations/solver.go.tmpl @@ -43,14 +43,15 @@ func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, {{ if not .NoGKR -}} // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(cs.GkrInfo.GetAssignmentHintID, gkr.GetAssignmentHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(gkrData)), + csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } {{ end -}} diff --git a/internal/gkr/bls12-377/gate_testing.go b/internal/gkr/bls12-377/gate_testing.go index 415a5ff5b3..9e5a3868f3 100644 --- a/internal/gkr/bls12-377/gate_testing.go +++ b/internal/gkr/bls12-377/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index f5dfad020e..b92ac1249d 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index 39547cff29..04c5f52586 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -9,8 +9,8 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrtypes" @@ -21,105 +21,109 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit - d.circuit.SetNbUniqueOutputs() +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} - d.assignment = make(WireAssignment, len(d.circuit)) - for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) +type newSolvingDataOption func(*newSolvingDataSettings) + +func WithAssignment(assignment gkrtypes.WireAssignment) newSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment } } -// this module assumes that wire and instance indexes respect dependencies +func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } - } + d := SolvingData{ + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) - } + d.maxNbIn = d.circuit.MaxGateNbIn() - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) + for i := range d.assignment { + d.assignment[i] = make([]fr.Element, nbPaddedInstances) + } + + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Errorf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range s.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + // inline equivalent of RepeatUntilEnd + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } + } + + return &d +} - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end +// this module assumes that wire and instance indexes respect dependencies + +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} - for _, p := range info.Prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } + + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } - fmt.Println(serializable...) } - setOuts(data.assignment, info.Circuit, outs) - return nil } } @@ -127,7 +131,13 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.RepeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -135,7 +145,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BLS12_377") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } @@ -144,3 +154,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) RepeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} diff --git a/internal/gkr/bls12-381/gate_testing.go b/internal/gkr/bls12-381/gate_testing.go index ef7694dc18..5b281fd634 100644 --- a/internal/gkr/bls12-381/gate_testing.go +++ b/internal/gkr/bls12-381/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index f5617a59d4..82084049d9 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bls12-381/solver_hints.go b/internal/gkr/bls12-381/solver_hints.go index cb498c78b7..e92e543398 100644 --- a/internal/gkr/bls12-381/solver_hints.go +++ b/internal/gkr/bls12-381/solver_hints.go @@ -9,8 +9,8 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrtypes" @@ -21,105 +21,109 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit - d.circuit.SetNbUniqueOutputs() +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} - d.assignment = make(WireAssignment, len(d.circuit)) - for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) +type newSolvingDataOption func(*newSolvingDataSettings) + +func WithAssignment(assignment gkrtypes.WireAssignment) newSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment } } -// this module assumes that wire and instance indexes respect dependencies +func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } - } + d := SolvingData{ + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) - } + d.maxNbIn = d.circuit.MaxGateNbIn() - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) + for i := range d.assignment { + d.assignment[i] = make([]fr.Element, nbPaddedInstances) + } + + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Errorf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range s.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + // inline equivalent of RepeatUntilEnd + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } + } + + return &d +} - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end +// this module assumes that wire and instance indexes respect dependencies + +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} - for _, p := range info.Prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } + + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } - fmt.Println(serializable...) } - setOuts(data.assignment, info.Circuit, outs) - return nil } } @@ -127,7 +131,13 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.RepeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -135,7 +145,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BLS12_381") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } @@ -144,3 +154,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) RepeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} diff --git a/internal/gkr/bls24-315/gate_testing.go b/internal/gkr/bls24-315/gate_testing.go index 1682d24771..058b53cc06 100644 --- a/internal/gkr/bls24-315/gate_testing.go +++ b/internal/gkr/bls24-315/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index 7d89baf7ef..f182c9176b 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bls24-315/solver_hints.go b/internal/gkr/bls24-315/solver_hints.go index 914c8a9d61..f57537b985 100644 --- a/internal/gkr/bls24-315/solver_hints.go +++ b/internal/gkr/bls24-315/solver_hints.go @@ -9,8 +9,8 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrtypes" @@ -21,105 +21,109 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit - d.circuit.SetNbUniqueOutputs() +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} - d.assignment = make(WireAssignment, len(d.circuit)) - for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) +type newSolvingDataOption func(*newSolvingDataSettings) + +func WithAssignment(assignment gkrtypes.WireAssignment) newSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment } } -// this module assumes that wire and instance indexes respect dependencies +func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } - } + d := SolvingData{ + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) - } + d.maxNbIn = d.circuit.MaxGateNbIn() - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) + for i := range d.assignment { + d.assignment[i] = make([]fr.Element, nbPaddedInstances) + } + + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Errorf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range s.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + // inline equivalent of RepeatUntilEnd + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } + } + + return &d +} - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end +// this module assumes that wire and instance indexes respect dependencies + +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} - for _, p := range info.Prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } + + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } - fmt.Println(serializable...) } - setOuts(data.assignment, info.Circuit, outs) - return nil } } @@ -127,7 +131,13 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.RepeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -135,7 +145,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BLS24_315") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } @@ -144,3 +154,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) RepeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} diff --git a/internal/gkr/bls24-317/gate_testing.go b/internal/gkr/bls24-317/gate_testing.go index 1bffab29e3..ed418ff1b0 100644 --- a/internal/gkr/bls24-317/gate_testing.go +++ b/internal/gkr/bls24-317/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index fc9908b918..a284f14ae9 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bls24-317/solver_hints.go b/internal/gkr/bls24-317/solver_hints.go index f6e1ad993d..d2cc4d32b1 100644 --- a/internal/gkr/bls24-317/solver_hints.go +++ b/internal/gkr/bls24-317/solver_hints.go @@ -9,8 +9,8 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrtypes" @@ -21,105 +21,109 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit - d.circuit.SetNbUniqueOutputs() +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} - d.assignment = make(WireAssignment, len(d.circuit)) - for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) +type newSolvingDataOption func(*newSolvingDataSettings) + +func WithAssignment(assignment gkrtypes.WireAssignment) newSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment } } -// this module assumes that wire and instance indexes respect dependencies +func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } - } + d := SolvingData{ + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) - } + d.maxNbIn = d.circuit.MaxGateNbIn() - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) + for i := range d.assignment { + d.assignment[i] = make([]fr.Element, nbPaddedInstances) + } + + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Errorf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range s.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + // inline equivalent of RepeatUntilEnd + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } + } + + return &d +} - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end +// this module assumes that wire and instance indexes respect dependencies + +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} - for _, p := range info.Prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } + + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } - fmt.Println(serializable...) } - setOuts(data.assignment, info.Circuit, outs) - return nil } } @@ -127,7 +131,13 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.RepeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -135,7 +145,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BLS24_317") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } @@ -144,3 +154,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) RepeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} diff --git a/internal/gkr/bn254/gate_testing.go b/internal/gkr/bn254/gate_testing.go index 716ba3891b..e9311a3ea5 100644 --- a/internal/gkr/bn254/gate_testing.go +++ b/internal/gkr/bn254/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 04cf3512af..14269151b3 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index 7bc3782932..5813b89661 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -9,8 +9,8 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrtypes" @@ -21,105 +21,109 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit - d.circuit.SetNbUniqueOutputs() +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} - d.assignment = make(WireAssignment, len(d.circuit)) - for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) +type newSolvingDataOption func(*newSolvingDataSettings) + +func WithAssignment(assignment gkrtypes.WireAssignment) newSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment } } -// this module assumes that wire and instance indexes respect dependencies +func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } - } + d := SolvingData{ + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) - } + d.maxNbIn = d.circuit.MaxGateNbIn() - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) + for i := range d.assignment { + d.assignment[i] = make([]fr.Element, nbPaddedInstances) + } + + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Errorf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range s.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + // inline equivalent of RepeatUntilEnd + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } + } + + return &d +} - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end +// this module assumes that wire and instance indexes respect dependencies + +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} - for _, p := range info.Prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } + + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } - fmt.Println(serializable...) } - setOuts(data.assignment, info.Circuit, outs) - return nil } } @@ -127,7 +131,13 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.RepeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -135,7 +145,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BN254") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } @@ -144,3 +154,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) RepeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} diff --git a/internal/gkr/bw6-633/gate_testing.go b/internal/gkr/bw6-633/gate_testing.go index 0fafa45a0d..8074b9621c 100644 --- a/internal/gkr/bw6-633/gate_testing.go +++ b/internal/gkr/bw6-633/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index cc1245e726..ec1067f736 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bw6-633/solver_hints.go b/internal/gkr/bw6-633/solver_hints.go index 57343d291f..ef945e25f7 100644 --- a/internal/gkr/bw6-633/solver_hints.go +++ b/internal/gkr/bw6-633/solver_hints.go @@ -9,8 +9,8 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrtypes" @@ -21,105 +21,109 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit - d.circuit.SetNbUniqueOutputs() +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} - d.assignment = make(WireAssignment, len(d.circuit)) - for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) +type newSolvingDataOption func(*newSolvingDataSettings) + +func WithAssignment(assignment gkrtypes.WireAssignment) newSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment } } -// this module assumes that wire and instance indexes respect dependencies +func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } - } + d := SolvingData{ + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) - } + d.maxNbIn = d.circuit.MaxGateNbIn() - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) + for i := range d.assignment { + d.assignment[i] = make([]fr.Element, nbPaddedInstances) + } + + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Errorf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range s.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + // inline equivalent of RepeatUntilEnd + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } + } + + return &d +} - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end +// this module assumes that wire and instance indexes respect dependencies + +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} - for _, p := range info.Prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } + + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } - fmt.Println(serializable...) } - setOuts(data.assignment, info.Circuit, outs) - return nil } } @@ -127,7 +131,13 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.RepeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -135,7 +145,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BW6_633") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } @@ -144,3 +154,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) RepeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} diff --git a/internal/gkr/bw6-761/gate_testing.go b/internal/gkr/bw6-761/gate_testing.go index 6eda2ebe73..0bae6258dc 100644 --- a/internal/gkr/bw6-761/gate_testing.go +++ b/internal/gkr/bw6-761/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index f90f28114b..ad5197feef 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bw6-761/solver_hints.go b/internal/gkr/bw6-761/solver_hints.go index 606f13ec23..1a91928171 100644 --- a/internal/gkr/bw6-761/solver_hints.go +++ b/internal/gkr/bw6-761/solver_hints.go @@ -9,8 +9,8 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrtypes" @@ -21,105 +21,109 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit - d.circuit.SetNbUniqueOutputs() +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} - d.assignment = make(WireAssignment, len(d.circuit)) - for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) +type newSolvingDataOption func(*newSolvingDataSettings) + +func WithAssignment(assignment gkrtypes.WireAssignment) newSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment } } -// this module assumes that wire and instance indexes respect dependencies +func NewSolvingData(info gkrtypes.SolvingInfo, options ...newSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } - } + d := SolvingData{ + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) - } + d.maxNbIn = d.circuit.MaxGateNbIn() - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) + for i := range d.assignment { + d.assignment[i] = make([]fr.Element, nbPaddedInstances) + } + + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Errorf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Errorf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range s.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Errorf("provided assignment for wire %d instance %d is not a valid field element: %w", i, j, err)) } } + // inline equivalent of RepeatUntilEnd + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } + } + + return &d +} - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end +// this module assumes that wire and instance indexes respect dependencies + +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: instance index, wire index, and dummy dependency enforcer") } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} - for _, p := range info.Prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } + + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } - fmt.Println(serializable...) } - setOuts(data.assignment, info.Circuit, outs) - return nil } } @@ -127,7 +131,13 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.RepeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -135,7 +145,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BW6_761") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } @@ -144,3 +154,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// RepeatUntilEnd for each wire, sets all the values starting from n to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.RepeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) RepeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go new file mode 100644 index 0000000000..74b15c77ba --- /dev/null +++ b/internal/gkr/engine_hints.go @@ -0,0 +1,195 @@ +package gkr + +import ( + "errors" + "fmt" + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/constraint/solver/gkrgates" + "github.com/consensys/gnark/frontend" + bls12377 "github.com/consensys/gnark/internal/gkr/bls12-377" + bls12381 "github.com/consensys/gnark/internal/gkr/bls12-381" + bls24315 "github.com/consensys/gnark/internal/gkr/bls24-315" + bls24317 "github.com/consensys/gnark/internal/gkr/bls24-317" + bn254 "github.com/consensys/gnark/internal/gkr/bn254" + bw6633 "github.com/consensys/gnark/internal/gkr/bw6-633" + bw6761 "github.com/consensys/gnark/internal/gkr/bw6-761" + "github.com/consensys/gnark/internal/gkr/gkrinfo" + "github.com/consensys/gnark/internal/gkr/gkrtypes" + "github.com/consensys/gnark/internal/utils" +) + +type TestEngineHints struct { + assignment gkrtypes.WireAssignment + info *gkrinfo.StoringInfo // we retain a reference to the solving info to allow the caller to modify it between calls to Solve and Prove + circuit gkrtypes.Circuit + gateIns []frontend.Variable +} + +func NewTestEngineHints(info *gkrinfo.StoringInfo) (*TestEngineHints, error) { + circuit, err := gkrtypes.CircuitInfoToCircuit(info.Circuit, gkrgates.Get) + if err != nil { + return nil, err + } + + return &TestEngineHints{ + info: info, + circuit: circuit, + gateIns: make([]frontend.Variable, circuit.MaxGateNbIn()), + assignment: make(gkrtypes.WireAssignment, len(circuit)), + }, + err +} + +// Solve solves one instance of a GKR circuit. +// The first input is the index of the instance. The rest are the inputs of the circuit, in their nominal order. +func (h *TestEngineHints) Solve(mod *big.Int, ins []*big.Int, outs []*big.Int) error { + + instanceI := len(h.assignment[0]) + if in0 := ins[0].Uint64(); !ins[0].IsUint64() || in0 > 0xffffffff { + return errors.New("first input must be a uint32 instance index") + } else if in0 != uint64(instanceI) || h.info.NbInstances != instanceI { + return errors.New("first input must equal the number of instances, and calls to Solve must be done in order of instance index") + } + + api := gateAPI{mod} + + inI := 1 + outI := 0 + for wI := range h.circuit { + w := &h.circuit[wI] + var val frontend.Variable + if w.IsInput() { + val = utils.FromInterface(ins[inI]) + inI++ + } else { + for gateInI, inWI := range w.Inputs { + h.gateIns[gateInI] = h.assignment[inWI][instanceI] + } + val = w.Gate.Evaluate(api, h.gateIns[:len(w.Inputs)]...) + } + if w.IsOutput() { + *outs[outI] = utils.FromInterface(val) + outI++ + } + h.assignment[wI] = append(h.assignment[wI], val) + } + return nil +} + +func (h *TestEngineHints) Prove(mod *big.Int, ins, outs []*big.Int) error { + + info, err := gkrtypes.StoringToSolvingInfo(*h.info, gkrgates.Get) + if err != nil { + return fmt.Errorf("failed to convert storing info to solving info: %w", err) + } + + // TODO @Tabaie autogenerate this or decide not to + if mod.Cmp(ecc.BLS12_377.ScalarField()) == 0 { + data := bls12377.NewSolvingData(info, bls12377.WithAssignment(h.assignment)) + return bls12377.ProveHint(info.HashName, data)(mod, ins, outs) + } + if mod.Cmp(ecc.BLS12_381.ScalarField()) == 0 { + data := bls12381.NewSolvingData(info, bls12381.WithAssignment(h.assignment)) + return bls12381.ProveHint(info.HashName, data)(mod, ins, outs) + } + if mod.Cmp(ecc.BLS24_315.ScalarField()) == 0 { + data := bls24315.NewSolvingData(info, bls24315.WithAssignment(h.assignment)) + return bls24315.ProveHint(info.HashName, data)(mod, ins, outs) + } + if mod.Cmp(ecc.BLS24_317.ScalarField()) == 0 { + data := bls24317.NewSolvingData(info, bls24317.WithAssignment(h.assignment)) + return bls24317.ProveHint(info.HashName, data)(mod, ins, outs) + } + if mod.Cmp(ecc.BN254.ScalarField()) == 0 { + data := bn254.NewSolvingData(info, bn254.WithAssignment(h.assignment)) + return bn254.ProveHint(info.HashName, data)(mod, ins, outs) + } + if mod.Cmp(ecc.BW6_633.ScalarField()) == 0 { + data := bw6633.NewSolvingData(info, bw6633.WithAssignment(h.assignment)) + return bw6633.ProveHint(info.HashName, data)(mod, ins, outs) + } + if mod.Cmp(ecc.BW6_761.ScalarField()) == 0 { + data := bw6761.NewSolvingData(info, bw6761.WithAssignment(h.assignment)) + return bw6761.ProveHint(info.HashName, data)(mod, ins, outs) + } + + return errors.New("unsupported modulus") +} + +// GetAssignment returns the assignment for a particular wire and instance. +func (h *TestEngineHints) GetAssignment(_ *big.Int, ins []*big.Int, outs []*big.Int) error { + if len(ins) != 3 || !ins[0].IsUint64() || !ins[1].IsUint64() { + return errors.New("expected 3 inputs: wire index, instance index, and dummy output from the same instance") + } + if len(outs) != 1 { + return errors.New("expected 1 output: the value of the wire at the given instance") + } + *outs[0] = utils.FromInterface(h.assignment[ins[0].Uint64()][ins[1].Uint64()]) + return nil +} + +type gateAPI struct{ *big.Int } + +func (g gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + in1 := utils.FromInterface(i1) + in2 := utils.FromInterface(i2) + + in1.Add(&in1, &in2) + for _, v := range in { + inV := utils.FromInterface(v) + in1.Add(&in1, &inV) + } + return &in1 +} + +func (g gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + x, y := utils.FromInterface(b), utils.FromInterface(c) + x.Mul(&x, &y) + x.Mod(&x, g.Int) // reduce + y = utils.FromInterface(a) + x.Add(&x, &y) + return &x +} + +func (g gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + x := utils.FromInterface(i1) + x.Neg(&x) + return &x +} + +func (g gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + x := utils.FromInterface(i1) + y := utils.FromInterface(i2) + x.Sub(&x, &y) + for _, v := range in { + y = utils.FromInterface(v) + x.Sub(&x, &y) + } + return &x +} + +func (g gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + x := utils.FromInterface(i1) + y := utils.FromInterface(i2) + x.Mul(&x, &y) + for _, v := range in { + y = utils.FromInterface(v) + x.Mul(&x, &y) + } + x.Mod(&x, g.Int) // reduce + return &x +} + +func (g gateAPI) Println(a ...frontend.Variable) { + strings := make([]string, len(a)) + for i := range a { + if s, ok := a[i].(fmt.Stringer); ok { + strings[i] = s.String() + } else { + bigInt := utils.FromInterface(a[i]) + strings[i] = bigInt.String() + } + } +} diff --git a/internal/gkr/gkrinfo/info.go b/internal/gkr/gkrinfo/info.go index de9a845e8d..6d14e37dfb 100644 --- a/internal/gkr/gkrinfo/info.go +++ b/internal/gkr/gkrinfo/info.go @@ -2,11 +2,7 @@ package gkrinfo import ( - "fmt" - "sort" - "github.com/consensys/gnark/constraint/solver" - "github.com/consensys/gnark/internal/utils" ) type ( @@ -17,26 +13,19 @@ type ( } Wire struct { - Gate string - Inputs []int - NbUniqueOutputs int + Gate string + Inputs []int } Circuit []Wire - PrintInfo struct { - Values []any - Instance uint32 - IsGkrVar []bool - } StoringInfo struct { - Circuit Circuit - Dependencies [][]InputDependency // nil for input wires - NbInstances int - HashName string - SolveHintID solver.HintID - ProveHintID solver.HintID - Prints []PrintInfo + Circuit Circuit + NbInstances int + HashName string + GetAssignmentHintID solver.HintID + SolveHintID solver.HintID + ProveHintID solver.HintID } Permutations struct { @@ -51,88 +40,12 @@ func (w Wire) IsInput() bool { return len(w.Inputs) == 0 } -func (w Wire) IsOutput() bool { - return w.NbUniqueOutputs == 0 -} - func (d *StoringInfo) NewInputVariable() int { i := len(d.Circuit) d.Circuit = append(d.Circuit, Wire{}) - d.Dependencies = append(d.Dependencies, nil) return i } -// Compile sorts the Circuit wires, their dependencies and the instances -func (d *StoringInfo) Compile(nbInstances int) (Permutations, error) { - - var p Permutations - d.NbInstances = nbInstances - // sort the instances to decide the order in which they are to be solved - instanceDeps := make([][]int, nbInstances) - for i := range d.Circuit { - for _, dep := range d.Dependencies[i] { - instanceDeps[dep.InputInstance] = append(instanceDeps[dep.InputInstance], dep.OutputInstance) - } - } - - p.SortedInstances, _ = utils.TopologicalSort(instanceDeps) - p.InstancesPermutation = utils.InvertPermutation(p.SortedInstances) - - // this whole circuit sorting is a bit of a charade. if things are built using an api, there's no way it could NOT already be topologically sorted - // worth keeping for future-proofing? - - inputs := utils.Map(d.Circuit, func(w Wire) []int { - return w.Inputs - }) - - var uniqueOuts [][]int - p.SortedWires, uniqueOuts = utils.TopologicalSort(inputs) - p.WiresPermutation = utils.InvertPermutation(p.SortedWires) - wirePermutationAt := utils.SliceAt(p.WiresPermutation) - sorted := make([]Wire, len(d.Circuit)) // TODO: Directly manipulate d.circuit instead - sortedDeps := make([][]InputDependency, len(d.Circuit)) - - // go through the wires in the sorted order and fix the input and dependency indices according to the permutations - for newI, oldI := range p.SortedWires { - oldW := d.Circuit[oldI] - - for depI := range d.Dependencies[oldI] { - dep := &d.Dependencies[oldI][depI] - dep.OutputWire = p.WiresPermutation[dep.OutputWire] - dep.InputInstance = p.InstancesPermutation[dep.InputInstance] - dep.OutputInstance = p.InstancesPermutation[dep.OutputInstance] - } - sort.Slice(d.Dependencies[oldI], func(i, j int) bool { - return d.Dependencies[oldI][i].InputInstance < d.Dependencies[oldI][j].InputInstance - }) - for i := 1; i < len(d.Dependencies[oldI]); i++ { - if d.Dependencies[oldI][i].InputInstance == d.Dependencies[oldI][i-1].InputInstance { - return p, fmt.Errorf("an input wire can only have one dependency per instance") - } - } // TODO: Check that dependencies and explicit assignments cover all instances - - sortedDeps[newI] = d.Dependencies[oldI] - sorted[newI] = Wire{ - Gate: oldW.Gate, - Inputs: utils.Map(oldW.Inputs, wirePermutationAt), - NbUniqueOutputs: len(uniqueOuts[oldI]), - } - } - - // re-arrange the prints - for i := range d.Prints { - for j, isVar := range d.Prints[i].IsGkrVar { - if isVar { - d.Prints[i].Values[j] = uint32(p.WiresPermutation[d.Prints[i].Values[j].(uint32)]) - } - } - } - - d.Circuit, d.Dependencies = sorted, sortedDeps - - return p, nil -} - func (d *StoringInfo) Is() bool { return d.Circuit != nil } diff --git a/internal/gkr/gkrtesting/gkrtesting.go b/internal/gkr/gkrtesting/gkrtesting.go index ce9ba88942..4c901f04a2 100644 --- a/internal/gkr/gkrtesting/gkrtesting.go +++ b/internal/gkr/gkrtesting/gkrtesting.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" + "github.com/consensys/gnark" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" @@ -35,10 +36,10 @@ func NewCache() *Cache { res = api.Mul(res, sum) // sum^7 return res - }, 2, 7, -1) + }, 2, 7, -1, gnark.Curves()) gates["select-input-3"] = gkrtypes.NewGate(func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { return in[2] - }, 3, 1, 0) + }, 3, 1, 0, gnark.Curves()) return &Cache{ circuits: make(map[string]gkrtypes.Circuit), diff --git a/internal/gkr/gkrtypes/types.go b/internal/gkr/gkrtypes/types.go index 7aed5ccd27..d313a7bc59 100644 --- a/internal/gkr/gkrtypes/types.go +++ b/internal/gkr/gkrtypes/types.go @@ -4,6 +4,8 @@ import ( "errors" "fmt" + "github.com/consensys/gnark" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/utils" @@ -17,15 +19,27 @@ type Gate struct { nbIn int // number of inputs degree int // total degree of the polynomial solvableVar int // if there is a variable whose value can be uniquely determined from the value of the gate and the other inputs, its index, -1 otherwise + curves []ecc.ID // curves that the gate is allowed to be used over } -func NewGate(f gkr.GateFunction, nbIn int, degree int, solvableVar int) *Gate { +func NewGate(f gkr.GateFunction, nbIn int, degree int, solvableVar int, curves []ecc.ID) *Gate { + return &Gate{ evaluate: f, nbIn: nbIn, degree: degree, solvableVar: solvableVar, + curves: curves, + } +} + +func (g *Gate) SupportsCurve(curve ecc.ID) bool { + for _, c := range g.curves { + if c == curve { + return true + } } + return false } func (g *Gate) Evaluate(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { @@ -133,49 +147,9 @@ func (c Circuit) MemoryRequirements(nbInstances int) []int { } type SolvingInfo struct { - Circuit Circuit - Dependencies [][]gkrinfo.InputDependency - NbInstances int - HashName string - Prints []gkrinfo.PrintInfo -} - -// Chunks returns intervals of instances that are independent of each other and can be solved in parallel -func (info *SolvingInfo) Chunks() []int { - res := make([]int, 0, 1) - lastSeenDependencyI := make([]int, len(info.Circuit)) - - for start, end := 0, 0; start != info.NbInstances; start = end { - end = info.NbInstances - endWireI := -1 - for wI := range info.Circuit { - deps := info.Dependencies[wI] - if wDepI := lastSeenDependencyI[wI]; wDepI < len(deps) && deps[wDepI].InputInstance < end { - end = deps[wDepI].InputInstance - endWireI = wI - } - } - if endWireI != -1 { - lastSeenDependencyI[endWireI]++ - } - res = append(res, end) - } - return res -} - -// AssignmentOffsets describes the input layout of the Solve hint, by returning -// for each wire, the index of the first hint input element corresponding to it. -func (info *SolvingInfo) AssignmentOffsets() []int { - c := info.Circuit - res := make([]int, len(c)+1) - for i := range c { - nbExplicitAssignments := 0 - if c[i].IsInput() { - nbExplicitAssignments = info.NbInstances - len(info.Dependencies[i]) - } - res[i+1] = res[i] + nbExplicitAssignments - } - return res + Circuit Circuit + NbInstances int + HashName string } // OutputsList for each wire, returns the set of indexes of wires it is input to. @@ -206,7 +180,7 @@ func (c Circuit) OutputsList() [][]int { return res } -func (c Circuit) SetNbUniqueOutputs() { +func (c Circuit) setNbUniqueOutputs() { for i := range c { c[i].NbUniqueOutputs = 0 @@ -254,6 +228,7 @@ func CircuitInfoToCircuit(info gkrinfo.Circuit, gateGetter func(name gkr.GateNam resCircuit := make(Circuit, len(info)) for i := range info { if info[i].Gate == "" && len(info[i].Inputs) == 0 { + resCircuit[i].Gate = Identity() // input wire continue } resCircuit[i].Inputs = info[i].Inputs @@ -262,17 +237,16 @@ func CircuitInfoToCircuit(info gkrinfo.Circuit, gateGetter func(name gkr.GateNam return nil, fmt.Errorf("gate \"%s\" not found", info[i].Gate) } } + resCircuit.setNbUniqueOutputs() return resCircuit, nil } func StoringToSolvingInfo(info gkrinfo.StoringInfo, gateGetter func(name gkr.GateName) *Gate) (SolvingInfo, error) { circuit, err := CircuitInfoToCircuit(info.Circuit, gateGetter) return SolvingInfo{ - Circuit: circuit, - NbInstances: info.NbInstances, - HashName: info.HashName, - Dependencies: info.Dependencies, - Prints: info.Prints, + Circuit: circuit, + NbInstances: info.NbInstances, + HashName: info.HashName, }, err } @@ -388,33 +362,33 @@ var ErrZeroFunction = errors.New("detected a zero function") func Identity() *Gate { return NewGate(func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { return in[0] - }, 1, 1, 0) + }, 1, 1, 0, gnark.Curves()) } // Add2 gate: (x, y) -> x + y func Add2() *Gate { return NewGate(func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { return api.Add(in[0], in[1]) - }, 2, 1, 0) + }, 2, 1, 0, gnark.Curves()) } // Sub2 gate: (x, y) -> x - y func Sub2() *Gate { return NewGate(func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { return api.Sub(in[0], in[1]) - }, 2, 1, 0) + }, 2, 1, 0, gnark.Curves()) } // Neg gate: x -> -x func Neg() *Gate { return NewGate(func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { return api.Neg(in[0]) - }, 1, 1, 0) + }, 1, 1, 0, gnark.Curves()) } // Mul2 gate: (x, y) -> x * y func Mul2() *Gate { return NewGate(func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { return api.Mul(in[0], in[1]) - }, 2, 2, -1) + }, 2, 2, -1, gnark.Curves()) } diff --git a/internal/gkr/small_rational/gate_testing.go b/internal/gkr/small_rational/gate_testing.go index dc29624d7b..6e3dea5781 100644 --- a/internal/gkr/small_rational/gate_testing.go +++ b/internal/gkr/small_rational/gate_testing.go @@ -142,6 +142,17 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error return nil } +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(small_rational.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} + // interpolate fits a polynomial of degree len(X) - 1 = len(Y) - 1 to the points (X[i], Y[i]) // Note that the runtime is O(len(X)³) func interpolate(X, Y []small_rational.SmallRational) (polynomial.Polynomial, error) { diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index e8e78f4b96..cdf62359f2 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod small_rational.SmallRational - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/utils/slices.go b/internal/utils/slices.go index dd2e2db31f..f493bf4bca 100644 --- a/internal/utils/slices.go +++ b/internal/utils/slices.go @@ -16,3 +16,15 @@ func References[T any](v []T) []*T { } return res } + +// ExtendRepeatLast extends the slice s by repeating the last element until it reaches the length n. +func ExtendRepeatLast[T any](s []T, n int) []T { + if n <= len(s) { + return s[:n] + } + s = s[:len(s):len(s)] // ensure s is a slice with a capacity equal to its length + for len(s) < n { + s = append(s, s[len(s)-1]) // append the last element until the length is n + } + return s +} diff --git a/internal/utils/slices_test.go b/internal/utils/slices_test.go new file mode 100644 index 0000000000..f61ec18fed --- /dev/null +++ b/internal/utils/slices_test.go @@ -0,0 +1,25 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExtendRepeatLast(t *testing.T) { + // normal case + s := []int{1, 2, 3} + u := ExtendRepeatLast(s, 5) + assert.Equal(t, []int{1, 2, 3, 3, 3}, u) + + // don't overwrite super-slice + s = []int{1, 2, 3} + u = ExtendRepeatLast(s[:1], 2) + assert.Equal(t, []int{1, 1}, u) + assert.Equal(t, []int{1, 2, 3}, s) + + // trim if n < len(s) + s = []int{1, 2, 3} + u = ExtendRepeatLast(s, 2) + assert.Equal(t, []int{1, 2}, u) +} diff --git a/std/gkrapi/api.go b/std/gkrapi/api.go index 771613ce0d..ae3c2b7954 100644 --- a/std/gkrapi/api.go +++ b/std/gkrapi/api.go @@ -23,12 +23,11 @@ func (api *API) NamedGate(gate gkr.GateName, in ...gkr.Variable) gkr.Variable { Inputs: utils.Map(in, frontendVarToInt), }) api.assignments = append(api.assignments, nil) - api.toStore.Dependencies = append(api.toStore.Dependencies, nil) // formality. Dependencies are only defined for input vars. return gkr.Variable(len(api.toStore.Circuit) - 1) } func (api *API) Gate(gate gkr.GateFunction, in ...gkr.Variable) gkr.Variable { - if err := gkrgates.Register(gate, len(in)); err != nil { + if _, err := gkrgates.Register(gate, len(in)); err != nil { panic(err) } return api.NamedGate(gkrgates.GetDefaultGateName(gate), in...) @@ -59,25 +58,3 @@ func (api *API) Sub(i1, i2 gkr.Variable) gkr.Variable { func (api *API) Mul(i1, i2 gkr.Variable) gkr.Variable { return api.namedGate2PlusIn(gkr.Mul2, i1, i2) } - -// Println writes to the standard output. -// instance determines which values are chosen for gkr.Variable input. -func (api *API) Println(instance int, a ...any) { - isVar := make([]bool, len(a)) - vals := make([]any, len(a)) - for i := range a { - v, ok := a[i].(gkr.Variable) - isVar[i] = ok - if ok { - vals[i] = uint32(v) - } else { - vals[i] = a[i] - } - } - - api.toStore.Prints = append(api.toStore.Prints, gkrinfo.PrintInfo{ - Values: vals, - Instance: uint32(instance), - IsGkrVar: isVar, - }) -} diff --git a/std/gkrapi/api_test.go b/std/gkrapi/api_test.go index 5823c687dd..1f3d187dae 100644 --- a/std/gkrapi/api_test.go +++ b/std/gkrapi/api_test.go @@ -1,7 +1,6 @@ package gkrapi import ( - "bytes" "fmt" "hash" "math/big" @@ -16,10 +15,9 @@ import ( "github.com/consensys/gnark-crypto/ecc" gcHash "github.com/consensys/gnark-crypto/hash" "github.com/consensys/gnark/backend/groth16" - "github.com/consensys/gnark/constraint/solver/gkrgates" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" - "github.com/consensys/gnark/internal/gkr/gkrinfo" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/gkrapi/gkr" stdHash "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/test" @@ -27,7 +25,7 @@ import ( "github.com/stretchr/testify/require" ) -// compressThreshold --> if linear expressions are larger than this, the frontend will introduce +// compressThreshold → if linear expressions are larger than this, the frontend will introduce // intermediate constraints. The lower this number is, the faster compile time should be (to a point) // but resulting circuit will have more constraints (slower proving time). const compressThreshold = 1000 @@ -39,23 +37,21 @@ type doubleNoDependencyCircuit struct { func (c *doubleNoDependencyCircuit) Define(api frontend.API) error { gkrApi := New() - var x gkr.Variable - var err error - if x, err = gkrApi.Import(c.X); err != nil { - return err - } + x := gkrApi.NewInput() z := gkrApi.Add(x, x) - var solution Solution - if solution, err = gkrApi.Solve(api); err != nil { - return err - } - Z := solution.Export(z) - for i := range Z { - api.AssertIsEqual(Z[i], api.Mul(2, c.X[i])) - } + gkrCircuit := gkrApi.Compile(api, c.hashName) - return solution.Verify(c.hashName) + instanceIn := make(map[gkr.Variable]frontend.Variable) + for i := range c.X { + instanceIn[x] = c.X[i] + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } + api.AssertIsEqual(instanceOut[z], api.Mul(2, c.X[i])) + } + return nil } func TestDoubleNoDependencyCircuit(t *testing.T) { @@ -87,23 +83,21 @@ type sqNoDependencyCircuit struct { func (c *sqNoDependencyCircuit) Define(api frontend.API) error { gkrApi := New() - var x gkr.Variable - var err error - if x, err = gkrApi.Import(c.X); err != nil { - return err - } + x := gkrApi.NewInput() z := gkrApi.Mul(x, x) - var solution Solution - if solution, err = gkrApi.Solve(api); err != nil { - return err - } - Z := solution.Export(z) - for i := range Z { - api.AssertIsEqual(Z[i], api.Mul(c.X[i], c.X[i])) - } + gkrCircuit := gkrApi.Compile(api, c.hashName) - return solution.Verify(c.hashName) + instanceIn := make(map[gkr.Variable]frontend.Variable) + for i := range c.X { + instanceIn[x] = c.X[i] + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } + api.AssertIsEqual(instanceOut[z], api.Mul(c.X[i], c.X[i])) + } + return nil } func TestSqNoDependencyCircuit(t *testing.T) { @@ -134,29 +128,23 @@ type mulNoDependencyCircuit struct { func (c *mulNoDependencyCircuit) Define(api frontend.API) error { gkrApi := New() - var x, y gkr.Variable - var err error - if x, err = gkrApi.Import(c.X); err != nil { - return err - } - if y, err = gkrApi.Import(c.Y); err != nil { - return err - } - gkrApi.Println(0, "values of x and y in instance number", 0, x, y) - + x := gkrApi.NewInput() + y := gkrApi.NewInput() z := gkrApi.Mul(x, y) - gkrApi.Println(1, "value of z in instance number", 1, z) - var solution Solution - if solution, err = gkrApi.Solve(api); err != nil { - return err - } - Z := solution.Export(z) + gkrCircuit := gkrApi.Compile(api, c.hashName) + + instanceIn := make(map[gkr.Variable]frontend.Variable) for i := range c.X { - api.AssertIsEqual(Z[i], api.Mul(c.X[i], c.Y[i])) + instanceIn[x] = c.X[i] + instanceIn[y] = c.Y[i] + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } + api.AssertIsEqual(instanceOut[z], api.Mul(c.Y[i], c.X[i])) } - - return solution.Verify(c.hashName) + return nil } func TestMulNoDependency(t *testing.T) { @@ -190,91 +178,68 @@ func TestMulNoDependency(t *testing.T) { } type mulWithDependencyCircuit struct { - XLast frontend.Variable + XFirst frontend.Variable Y []frontend.Variable hashName string } func (c *mulWithDependencyCircuit) Define(api frontend.API) error { gkrApi := New() - var x, y gkr.Variable - var err error - - X := make([]frontend.Variable, len(c.Y)) - X[len(c.Y)-1] = c.XLast - if x, err = gkrApi.Import(X); err != nil { - return err - } - if y, err = gkrApi.Import(c.Y); err != nil { - return err - } + x := gkrApi.NewInput() // x is the state variable + y := gkrApi.NewInput() z := gkrApi.Mul(x, y) - for i := len(X) - 1; i > 0; i-- { - gkrApi.Series(x, z, i-1, i) - } + gkrCircuit := gkrApi.Compile(api, c.hashName) - var solution Solution - if solution, err = gkrApi.Solve(api); err != nil { - return err - } - X = solution.Export(x) - Y := solution.Export(y) - Z := solution.Export(z) + state := c.XFirst + instanceIn := make(map[gkr.Variable]frontend.Variable) - lastI := len(X) - 1 - api.AssertIsEqual(Z[lastI], api.Mul(c.XLast, Y[lastI])) - for i := 0; i < lastI; i++ { - api.AssertIsEqual(Z[i], api.Mul(Z[i+1], Y[i])) + for i := range c.Y { + instanceIn[x] = state + instanceIn[y] = c.Y[i] + + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } + + api.AssertIsEqual(instanceOut[z], api.Mul(state, c.Y[i])) + state = instanceOut[z] // update state for the next iteration } - return solution.Verify(c.hashName) + return nil } func TestSolveMulWithDependency(t *testing.T) { assert := test.NewAssert(t) assignment := mulWithDependencyCircuit{ - XLast: 1, - Y: []frontend.Variable{3, 2}, + XFirst: 1, + Y: []frontend.Variable{3, 2}, } circuit := mulWithDependencyCircuit{Y: make([]frontend.Variable, len(assignment.Y)), hashName: "-20"} assert.CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BN254)) } func TestApiMul(t *testing.T) { - var ( - x gkr.Variable - y gkr.Variable - z gkr.Variable - err error - ) api := New() - x, err = api.Import([]frontend.Variable{nil, nil}) - require.NoError(t, err) - y, err = api.Import([]frontend.Variable{nil, nil}) - require.NoError(t, err) - z = api.Mul(x, y) + x := api.NewInput() + y := api.NewInput() + z := api.Mul(x, y) assertSliceEqual(t, api.toStore.Circuit[z].Inputs, []int{int(x), int(y)}) // TODO: Find out why assert.Equal gives false positives ( []*Wire{x,x} as second argument passes when it shouldn't ) } func BenchmarkMiMCMerkleTree(b *testing.B) { - depth := 14 - bottom := make([]frontend.Variable, 1<= 0; d-- { - for i := 0; i < 1< 1 { + nextLayer := curLayer[:len(curLayer)/2] - challenge, err := api.(frontend.Committer).Commit(Z...) - if err != nil { - return err - } + for i := range nextLayer { + instanceIn[x] = curLayer[2*i] + instanceIn[y] = curLayer[2*i+1] - return solution.Verify("-20", challenge) -} + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } + nextLayer[i] = instanceOut[z] // store the result of the hash + } -func init() { - registerMiMCGate() + curLayer = nextLayer + } + return nil } -func registerMiMCGate() { - // register mimc gate - panicIfError(gkrgates.Register(func(api gkr.GateAPI, input ...frontend.Variable) frontend.Variable { - mimcSnarkTotalCalls++ +func mimcGate(api gkr.GateAPI, input ...frontend.Variable) frontend.Variable { + mimcSnarkTotalCalls++ - if len(input) != 2 { - panic("mimc has fan-in 2") - } - sum := api.Add(input[0], input[1] /*, m.Ark*/) + if len(input) != 2 { + panic("mimc has fan-in 2") + } + sum := api.Add(input[0], input[1] /*, m.Ark*/) - sumCubed := api.Mul(sum, sum, sum) // sum^3 - return api.Mul(sumCubed, sumCubed, sum) - }, 2, gkrgates.WithDegree(7), gkrgates.WithName("MIMC"))) + sumCubed := api.Mul(sum, sum, sum) // sum³ + return api.Mul(sumCubed, sumCubed, sum) } type constPseudoHash int @@ -406,8 +362,6 @@ func (c constPseudoHash) Write(...frontend.Variable) {} func (c constPseudoHash) Reset() {} -var mimcFrTotalCalls = 0 - type mimcNoGkrCircuit struct { X []frontend.Variable Y []frontend.Variable @@ -456,26 +410,33 @@ type mimcNoDepCircuit struct { } func (c *mimcNoDepCircuit) Define(api frontend.API) error { - _gkr := New() - x, err := _gkr.Import(c.X) - if err != nil { - return err + // define the circuit + gkrApi := New() + x := gkrApi.NewInput() + y := gkrApi.NewInput() + + if c.mimcDepth < 1 { + return fmt.Errorf("mimcDepth must be at least 1, got %d", c.mimcDepth) } - var ( - y gkr.Variable - solution Solution - ) - if y, err = _gkr.Import(c.Y); err != nil { - return err + + z := y + for range c.mimcDepth { + z = gkrApi.Gate(mimcGate, x, z) } - z := _gkr.NamedGate("MIMC", x, y) + gkrCircuit := gkrApi.Compile(api, c.hashName) + + instanceIn := make(map[gkr.Variable]frontend.Variable) + for i := range c.X { + instanceIn[x] = c.X[i] + instanceIn[y] = c.Y[i] - if solution, err = _gkr.Solve(api); err != nil { - return err + _, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } } - Z := solution.Export(z) - return solution.Verify(c.hashName, Z...) + return nil } func mimcNoDepCircuits(mimcDepth, nbInstances int, hashName string) (circuit, assignment frontend.Circuit) { @@ -557,64 +518,6 @@ func mimcNoGkrCircuits(mimcDepth, nbInstances int) (circuit, assignment frontend return } -func TestSolveInTestEngine(t *testing.T) { - assignment := testSolveInTestEngineCircuit{ - X: []frontend.Variable{2, 3, 4, 5, 6, 7, 8, 9}, - } - circuit := testSolveInTestEngineCircuit{ - X: make([]frontend.Variable, len(assignment.X)), - } - - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BN254.ScalarField())) - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BLS24_315.ScalarField())) - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BLS12_381.ScalarField())) - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BLS24_317.ScalarField())) - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BW6_633.ScalarField())) - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BW6_761.ScalarField())) - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BLS12_377.ScalarField())) -} - -type testSolveInTestEngineCircuit struct { - X []frontend.Variable -} - -func (c *testSolveInTestEngineCircuit) Define(api frontend.API) error { - gkrBn254 := New() - x, err := gkrBn254.Import(c.X) - if err != nil { - return err - } - Y := make([]frontend.Variable, len(c.X)) - Y[0] = 1 - y, err := gkrBn254.Import(Y) - if err != nil { - return err - } - - z := gkrBn254.Mul(x, y) - - for i := range len(c.X) - 1 { - gkrBn254.Series(y, z, i+1, i) - } - - assignments := gkrBn254.SolveInTestEngine(api) - - product := frontend.Variable(1) - for i := range c.X { - api.AssertIsEqual(assignments[y][i], product) - product = api.Mul(product, c.X[i]) - api.AssertIsEqual(assignments[z][i], product) - } - - return nil -} - -func panicIfError(err error) { - if err != nil { - panic(err) - } -} - func assertSliceEqual[T comparable](t *testing.T, expected, seen []T) { assert.Equal(t, len(expected), len(seen)) for i := range seen { @@ -636,7 +539,7 @@ func (m MiMCCipherGate) Evaluate(api frontend.API, input ...frontend.Variable) f } sum := api.Add(input[0], input[1], m.Ark) - sumCubed := api.Mul(sum, sum, sum) // sum^3 + sumCubed := api.Mul(sum, sum, sum) // sum³ return api.Mul(sumCubed, sumCubed, sum) } @@ -689,46 +592,102 @@ func init() { } } -func ExamplePrintln() { +// pow3Circuit computes x⁴ and also checks the correctness of intermediate value x². +// This is to demonstrate the use of [Circuit.GetValue] and should not be done +// in production code, as it negates the performance benefits of using GKR in the first place. +type pow4Circuit struct { + X []frontend.Variable +} + +func (c *pow4Circuit) Define(api frontend.API) error { + gkrApi := New() + x := gkrApi.NewInput() + x2 := gkrApi.Mul(x, x) // x² + x4 := gkrApi.Mul(x2, x2) // x⁴ + + gkrCircuit := gkrApi.Compile(api, "MIMC") - circuit := &mulNoDependencyCircuit{ - X: make([]frontend.Variable, 2), - Y: make([]frontend.Variable, 2), - hashName: "MIMC", + for i := range c.X { + instanceIn := make(map[gkr.Variable]frontend.Variable) + instanceIn[x] = c.X[i] + + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } + + api.AssertIsEqual(gkrCircuit.GetValue(x, i), c.X[i]) // x + + v := api.Mul(c.X[i], c.X[i]) // x² + api.AssertIsEqual(gkrCircuit.GetValue(x2, i), v) // x² + + v = api.Mul(v, v) // x⁴ + api.AssertIsEqual(gkrCircuit.GetValue(x4, i), v) // x⁴ + api.AssertIsEqual(instanceOut[x4], v) // x⁴ + } + + return nil +} + +func TestPow4Circuit_GetValue(t *testing.T) { + assignment := pow4Circuit{ + X: []frontend.Variable{1, 2, 3, 4, 5}, } - assignment := &mulNoDependencyCircuit{ - X: []frontend.Variable{10, 11}, - Y: []frontend.Variable{12, 13}, + circuit := pow4Circuit{ + X: make([]frontend.Variable, len(assignment.X)), } - field := ecc.BN254.ScalarField() + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment)) +} - // with test engine - err := test.IsSolved(circuit, assignment, field) - panicIfError(err) +func TestWitnessExtend(t *testing.T) { + circuit := doubleNoDependencyCircuit{X: make([]frontend.Variable, 3), hashName: "-1"} + assignment := doubleNoDependencyCircuit{X: []frontend.Variable{0, 0, 1}} - // with groth16 / serialized CS - firstCs, err := frontend.Compile(field, r1cs.NewBuilder, circuit) - panicIfError(err) + cs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &circuit) + require.NoError(t, err) - var bb bytes.Buffer - _, err = firstCs.WriteTo(&bb) - panicIfError(err) - cs := groth16.NewCS(ecc.BN254) - _, err = cs.ReadFrom(&bb) - panicIfError(err) + witness, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) + require.NoError(t, err) - pk, _, err := groth16.Setup(cs) - panicIfError(err) - w, err := frontend.NewWitness(assignment, field) - panicIfError(err) - _, err = groth16.Prove(cs, pk, w) - panicIfError(err) - - // Output: - // values of x and y in instance number 0 10 12 - // value of z in instance number 1 143 - // values of x and y in instance number 0 10 12 - // value of z in instance number 1 143 + _, err = cs.Solve(witness) + require.NoError(t, err) +} + +func TestSingleInstance(t *testing.T) { + circuit := mimcNoDepCircuit{ + X: make([]frontend.Variable, 1), + Y: make([]frontend.Variable, 1), + mimcDepth: 2, + hashName: "MIMC", + } + assignment := mimcNoDepCircuit{ + X: []frontend.Variable{10}, + Y: []frontend.Variable{2}, + } + + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment)) +} + +func TestNoInstance(t *testing.T) { + var circuit testNoInstanceCircuit + assignment := testNoInstanceCircuit{0} + + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment)) +} + +type testNoInstanceCircuit struct { + Dummy frontend.Variable // Plonk prover would fail on an empty witness +} + +func (c *testNoInstanceCircuit) Define(api frontend.API) error { + gkrApi := New() + x := gkrApi.NewInput() + y := gkrApi.Mul(x, x) + gkrApi.Mul(x, y) + + gkrApi.Compile(api, "MIMC") + + return nil } diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 1e7784fb97..390647b89d 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -1,10 +1,10 @@ package gkrapi import ( - "errors" "fmt" "math/bits" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/constraint/solver/gkrgates" "github.com/consensys/gnark/frontend" @@ -15,6 +15,7 @@ import ( fiatshamir "github.com/consensys/gnark/std/fiat-shamir" "github.com/consensys/gnark/std/gkrapi/gkr" "github.com/consensys/gnark/std/hash" + "github.com/consensys/gnark/std/multicommit" ) type circuitDataForSnark struct { @@ -22,18 +23,17 @@ type circuitDataForSnark struct { assignments gkrtypes.WireAssignment } -type Solution struct { - toStore gkrinfo.StoringInfo - assignments gkrtypes.WireAssignment - parentApi frontend.API - permutations gkrinfo.Permutations -} - -func (api *API) nbInstances() int { - if len(api.assignments) == 0 { - return -1 - } - return api.assignments.NbInstances() +type InitialChallengeGetter func() []frontend.Variable + +// Circuit represents a GKR circuit. +type Circuit struct { + toStore gkrinfo.StoringInfo + assignments gkrtypes.WireAssignment + getInitialChallenges InitialChallengeGetter // optional getter for the initial Fiat-Shamir challenge + ins []gkr.Variable + outs []gkr.Variable + api frontend.API // the parent API used for hints + hints *gadget.TestEngineHints // hints for the GKR circuit, used for testing purposes } // New creates a new GKR API @@ -41,170 +41,212 @@ func New() *API { return &API{} } -// log2 returns -1 if x is not a power of 2 -func log2(x uint) int { - if bits.OnesCount(x) != 1 { - return -1 - } - return bits.TrailingZeros(x) +// NewInput creates a new input variable. +func (api *API) NewInput() gkr.Variable { + return gkr.Variable(api.toStore.NewInputVariable()) } -// Series like in an electric circuit, binds an input of an instance to an output of another -func (api *API) Series(input, output gkr.Variable, inputInstance, outputInstance int) *API { - if api.assignments[input][inputInstance] != nil { - panic("dependency attempting to override explicit value assignment") - } - api.toStore.Dependencies[input] = - append(api.toStore.Dependencies[input], gkrinfo.InputDependency{ - OutputWire: int(output), - OutputInstance: outputInstance, - InputInstance: inputInstance, - }) - return api -} +type compileOption func(*Circuit) -// Import creates a new input variable, whose values across all instances are given by assignment. -// If the value in an instance depends on an output of another instance, leave the corresponding index in assignment nil and use Series to specify the dependency. -func (api *API) Import(assignment []frontend.Variable) (gkr.Variable, error) { - nbInstances := len(assignment) - logNbInstances := log2(uint(nbInstances)) - if logNbInstances == -1 { - return -1, errors.New("number of assignments must be a power of 2") +// WithInitialChallenge provides a getter for the initial Fiat-Shamir challenge. +// If not provided, the initial challenge will be a commitment to all the input and output values of the circuit. +func WithInitialChallenge(getInitialChallenge InitialChallengeGetter) compileOption { + return func(c *Circuit) { + c.getInitialChallenges = getInitialChallenge } - - if currentNbInstances := api.nbInstances(); currentNbInstances != -1 && currentNbInstances != nbInstances { - return -1, errors.New("number of assignments must be consistent across all variables") - } - api.assignments = append(api.assignments, assignment) - return gkr.Variable(api.toStore.NewInputVariable()), nil } -// appendNonNil filters out nil values from src and appends the non-nil values to dst. -// i.e. dst = [0,1], src = [nil, 2, nil, 3] => dst = [0,1,2,3]. -func appendNonNil(dst *[]frontend.Variable, src []frontend.Variable) { - for i := range src { - if src[i] != nil { - *dst = append(*dst, src[i]) - } +// Compile finalizes the GKR circuit. +// From this point on, the circuit cannot be modified. +// But instances can be added to the circuit. +func (api *API) Compile(parentApi frontend.API, fiatshamirHashName string, options ...compileOption) *Circuit { + // TODO define levels here + res := Circuit{ + toStore: api.toStore, + assignments: make(gkrtypes.WireAssignment, len(api.toStore.Circuit)), + api: parentApi, } -} -// Solve finalizes the GKR circuit and returns the output variables in the order created -func (api *API) Solve(parentApi frontend.API) (Solution, error) { + res.toStore.HashName = fiatshamirHashName - var p gkrinfo.Permutations var err error - if p, err = api.toStore.Compile(api.assignments.NbInstances()); err != nil { - return Solution{}, err + res.hints, err = gadget.NewTestEngineHints(&res.toStore) + if err != nil { + panic(fmt.Errorf("failed to call GKR hints: %w", err)) } - api.assignments.Permute(p) - - nbInstances := api.toStore.NbInstances - circuit := api.toStore.Circuit - solveHintNIn := 0 - solveHintNOut := 0 + for _, opt := range options { + opt(&res) + } - for i := range circuit { - v := &circuit[i] - in, out := v.IsInput(), v.IsOutput() - if in && out { - return Solution{}, fmt.Errorf("unused input (variable #%d)", i) + notOut := make([]bool, len(res.toStore.Circuit)) + for i := range res.toStore.Circuit { + if res.toStore.Circuit[i].IsInput() { + res.ins = append(res.ins, gkr.Variable(i)) } + for _, inWI := range res.toStore.Circuit[i].Inputs { + notOut[inWI] = true + } + } - if in { - solveHintNIn += nbInstances - len(api.toStore.Dependencies[i]) - } else if out { - solveHintNOut += nbInstances + for i := range res.toStore.Circuit { + if !notOut[i] { + res.outs = append(res.outs, gkr.Variable(i)) } } - // arrange inputs wire first, then in the order solved - ins := make([]frontend.Variable, 0, solveHintNIn) - for i := range circuit { - if circuit[i].IsInput() { - appendNonNil(&ins, api.assignments[i]) + res.toStore.GetAssignmentHintID = solver.GetHintID(res.hints.GetAssignment) + res.toStore.ProveHintID = solver.GetHintID(res.hints.Prove) + res.toStore.SolveHintID = solver.GetHintID(res.hints.Solve) + + parentApi.Compiler().Defer(res.finalize) + + return &res +} + +// AddInstance adds a new instance to the GKR circuit, returning the values of output variables for the instance. +func (c *Circuit) AddInstance(input map[gkr.Variable]frontend.Variable) (map[gkr.Variable]frontend.Variable, error) { + if len(input) != len(c.ins) { + for k := range input { + if k >= gkr.Variable(len(c.ins)) { + return nil, fmt.Errorf("variable %d is out of bounds (max %d)", k, len(c.ins)-1) + } + if !c.toStore.Circuit[k].IsInput() { + return nil, fmt.Errorf("value provided for non-input variable %d", k) + } + } + } + hintIn := make([]frontend.Variable, 1+len(c.ins)) // first input denotes the instance number + hintIn[0] = c.toStore.NbInstances + for hintInI, wI := range c.ins { + if inV, ok := input[wI]; !ok { + return nil, fmt.Errorf("missing entry for input variable %d", wI) + } else { + hintIn[hintInI+1] = inV + c.assignments[wI] = append(c.assignments[wI], inV) } } - solveHintPlaceholder := SolveHintPlaceholder(api.toStore) - outsSerialized, err := parentApi.Compiler().NewHint(solveHintPlaceholder, solveHintNOut, ins...) - api.toStore.SolveHintID = solver.GetHintID(solveHintPlaceholder) + outsSerialized, err := c.api.Compiler().NewHint(c.hints.Solve, len(c.outs), hintIn...) if err != nil { - return Solution{}, err + return nil, fmt.Errorf("failed to call solve hint: %w", err) + } + c.toStore.NbInstances++ + res := make(map[gkr.Variable]frontend.Variable, len(c.outs)) + for i, v := range c.outs { + res[v] = outsSerialized[i] + c.assignments[v] = append(c.assignments[v], outsSerialized[i]) } - for i := range circuit { - if circuit[i].IsOutput() { - api.assignments[i] = outsSerialized[:nbInstances] - outsSerialized = outsSerialized[nbInstances:] + return res, nil +} + +// finalize encodes the verification circuitry for the GKR circuit +func (c *Circuit) finalize(api frontend.API) error { + if api != c.api { + panic("api mismatch") + } + + // if the circuit is empty or with no instances, there is nothing to do. + if len(c.outs) == 0 || len(c.assignments[0]) == 0 { // wire 0 is always an input wire + return nil + } + + // pad instances to the next power of 2 + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(c.toStore.NbInstances))) + // pad instances to the next power of 2 by repeating the last instance + if c.toStore.NbInstances < nbPaddedInstances && c.toStore.NbInstances > 0 { + for _, wI := range c.ins { + c.assignments[wI] = utils.ExtendRepeatLast(c.assignments[wI], nbPaddedInstances) + } + for _, wI := range c.outs { + c.assignments[wI] = utils.ExtendRepeatLast(c.assignments[wI], nbPaddedInstances) } } - for i := range circuit { - for _, dep := range api.toStore.Dependencies[i] { - api.assignments[i][dep.InputInstance] = api.assignments[dep.OutputWire][dep.OutputInstance] + if err := api.(gkrinfo.ConstraintSystem).SetGkrInfo(c.toStore); err != nil { + return err + } + + // if the circuit consists of only one instance, directly solve the circuit + if len(c.assignments[c.ins[0]]) == 1 { + circuit, err := gkrtypes.CircuitInfoToCircuit(c.toStore.Circuit, gkrgates.Get) + if err != nil { + return fmt.Errorf("failed to convert GKR info to circuit: %w", err) + } + gateIn := make([]frontend.Variable, circuit.MaxGateNbIn()) + for wI, w := range circuit { + if w.IsInput() { + continue + } + for inI, inWI := range w.Inputs { + gateIn[inI] = c.assignments[inWI][0] // take the first (only) instance + } + res := w.Gate.Evaluate(api, gateIn[:len(w.Inputs)]...) + if w.IsOutput() { + api.AssertIsEqual(res, c.assignments[wI][0]) + } else { + c.assignments[wI] = append(c.assignments[wI], res) + } } + return nil } - return Solution{ - toStore: api.toStore, - assignments: api.assignments, - parentApi: parentApi, - permutations: p, - }, nil -} + if c.getInitialChallenges != nil { + return c.verify(api, c.getInitialChallenges()) + } + + // default initial challenge is a commitment to all input and output values + insOuts := make([]frontend.Variable, 0, (len(c.ins)+len(c.outs))*len(c.assignments[c.ins[0]])) + for _, in := range c.ins { + insOuts = append(insOuts, c.assignments[in]...) + } + for _, out := range c.outs { + insOuts = append(insOuts, c.assignments[out]...) + } + + multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + return c.verify(api, []frontend.Variable{commitment}) + }, insOuts...) -// Export returns the values of an output variable across all instances -func (s Solution) Export(v gkr.Variable) []frontend.Variable { - return utils.Map(s.permutations.SortedInstances, utils.SliceAt(s.assignments[v])) + return nil } -// Verify encodes the verification circuitry for the GKR circuit -func (s Solution) Verify(hashName string, initialChallenges ...frontend.Variable) error { +func (c *Circuit) verify(api frontend.API, initialChallenges []frontend.Variable) error { + forSnark, err := newCircuitDataForSnark(utils.FieldToCurve(api.Compiler().Field()), c.toStore, c.assignments) + if err != nil { + return fmt.Errorf("failed to create circuit data for snark: %w", err) + } + + hintIns := make([]frontend.Variable, len(initialChallenges)+1) // hack: adding one of the outputs of the solve hint to ensure "prove" is called after "solve" + firstOutputAssignment := c.assignments[c.outs[0]] + hintIns[0] = firstOutputAssignment[len(firstOutputAssignment)-1] // take the last output of the first output wire + + copy(hintIns[1:], initialChallenges) + var ( - err error proofSerialized []frontend.Variable proof gadget.Proof ) - forSnark := newCircuitDataForSnark(s.toStore, s.assignments) - logNbInstances := log2(uint(s.assignments.NbInstances())) - - hintIns := make([]frontend.Variable, len(initialChallenges)+1) // hack: adding one of the outputs of the solve hint to ensure "prove" is called after "solve" - for i, w := range s.toStore.Circuit { - if w.IsOutput() { - hintIns[0] = s.assignments[i][len(s.assignments[i])-1] - break - } - } - copy(hintIns[1:], initialChallenges) - - proveHintPlaceholder := ProveHintPlaceholder(hashName) - if proofSerialized, err = s.parentApi.Compiler().NewHint( - proveHintPlaceholder, gadget.ProofSize(forSnark.circuit, logNbInstances), hintIns...); err != nil { + if proofSerialized, err = api.Compiler().NewHint( + c.hints.Prove, gadget.ProofSize(forSnark.circuit, bits.TrailingZeros(uint(len(c.assignments[0])))), hintIns...); err != nil { return err } - s.toStore.ProveHintID = solver.GetHintID(proveHintPlaceholder) + c.toStore.ProveHintID = solver.GetHintID(c.hints.Prove) - forSnarkSorted := utils.MapRange(0, len(s.toStore.Circuit), slicePtrAt(forSnark.circuit)) + forSnarkSorted := utils.MapRange(0, len(c.toStore.Circuit), slicePtrAt(forSnark.circuit)) if proof, err = gadget.DeserializeProof(forSnarkSorted, proofSerialized); err != nil { return err } var hsh hash.FieldHasher - if hsh, err = hash.GetFieldHasher(hashName, s.parentApi); err != nil { - return err - } - s.toStore.HashName = hashName - - err = gadget.Verify(s.parentApi, forSnark.circuit, forSnark.assignments, proof, fiatshamir.WithHash(hsh, initialChallenges...), gadget.WithSortedCircuit(forSnarkSorted)) - if err != nil { + if hsh, err = hash.GetFieldHasher(c.toStore.HashName, api); err != nil { return err } - return s.parentApi.(gkrinfo.ConstraintSystem).SetGkrInfo(s.toStore) + return gadget.Verify(api, forSnark.circuit, forSnark.assignments, proof, fiatshamir.WithHash(hsh, initialChallenges...), gadget.WithSortedCircuit(forSnarkSorted)) } func slicePtrAt[T any](slice []T) func(int) *T { @@ -213,28 +255,31 @@ func slicePtrAt[T any](slice []T) func(int) *T { } } -func ite[T any](condition bool, ifNot, IfSo T) T { - if condition { - return IfSo +func newCircuitDataForSnark(curve ecc.ID, info gkrinfo.StoringInfo, assignment gkrtypes.WireAssignment) (circuitDataForSnark, error) { + circuit, err := gkrtypes.CircuitInfoToCircuit(info.Circuit, gkrgates.Get) + if err != nil { + return circuitDataForSnark{}, fmt.Errorf("failed to convert GKR info to circuit: %w", err) } - return ifNot -} - -func newCircuitDataForSnark(info gkrinfo.StoringInfo, assignment gkrtypes.WireAssignment) circuitDataForSnark { - circuit := make(gkrtypes.Circuit, len(info.Circuit)) - snarkAssignment := make(gkrtypes.WireAssignment, len(info.Circuit)) for i := range circuit { - w := info.Circuit[i] - circuit[i] = gkrtypes.Wire{ - Gate: gkrgates.Get(ite(w.IsInput(), gkr.GateName(w.Gate), gkr.Identity)), - Inputs: w.Inputs, - NbUniqueOutputs: w.NbUniqueOutputs, + if !circuit[i].Gate.SupportsCurve(curve) { + return circuitDataForSnark{}, fmt.Errorf("gate \"%s\" not usable over curve \"%s\"", info.Circuit[i].Gate, curve) } - snarkAssignment[i] = assignment[i] } + return circuitDataForSnark{ circuit: circuit, - assignments: snarkAssignment, + assignments: assignment, + }, nil +} + +// GetValue is a debugging utility returning the value of variable v at instance i. +// While v can be an input or output variable, GetValue is most useful for querying intermediate values in the circuit. +func (c *Circuit) GetValue(v gkr.Variable, i int) frontend.Variable { + // last input to ensure the solver's work is done before GetAssignment is called + res, err := c.api.Compiler().NewHint(c.hints.GetAssignment, 1, int(v), i, c.assignments[c.outs[0]][i]) + if err != nil { + panic(err) } + return res[0] } diff --git a/std/gkrapi/compile_test.go b/std/gkrapi/compile_test.go deleted file mode 100644 index a0ca992ed4..0000000000 --- a/std/gkrapi/compile_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package gkrapi - -import ( - "testing" - - "github.com/consensys/gnark/internal/gkr/gkrinfo" - "github.com/stretchr/testify/assert" -) - -func TestCompile2Cycles(t *testing.T) { - var d = gkrinfo.StoringInfo{ - Dependencies: [][]gkrinfo.InputDependency{ - nil, - { - { - OutputWire: 0, - OutputInstance: 1, - InputInstance: 0, - }, - }, - }, - Circuit: gkrinfo.Circuit{ - { - Inputs: []int{1}, - }, - { - Inputs: []int{}, - }, - }, - } - - expectedCompiled := gkrinfo.StoringInfo{ - Dependencies: [][]gkrinfo.InputDependency{ - {{ - OutputWire: 1, - OutputInstance: 0, - InputInstance: 1, - }}, - nil, - }, - Circuit: gkrinfo.Circuit{ - { - Inputs: []int{}, - NbUniqueOutputs: 1, - }, - { - Inputs: []int{0}, - }}, - NbInstances: 2, - } - - expectedPermutations := gkrinfo.Permutations{ - SortedInstances: []int{1, 0}, - SortedWires: []int{1, 0}, - InstancesPermutation: []int{1, 0}, - WiresPermutation: []int{1, 0}, - } - - p, err := d.Compile(2) - assert.NoError(t, err) - assert.Equal(t, expectedPermutations, p) - assert.Equal(t, expectedCompiled, d) -} - -func TestCompile3Cycles(t *testing.T) { - var d = gkrinfo.StoringInfo{ - Dependencies: [][]gkrinfo.InputDependency{ - nil, - { - { - OutputWire: 0, - OutputInstance: 2, - InputInstance: 0, - }, - { - OutputWire: 0, - OutputInstance: 1, - InputInstance: 2, - }, - }, - nil, - }, - Circuit: gkrinfo.Circuit{ - { - Inputs: []int{2}, - }, - { - Inputs: []int{}, - }, - { - Inputs: []int{1}, - }, - }, - } - - expectedCompiled := gkrinfo.StoringInfo{ - Dependencies: [][]gkrinfo.InputDependency{ - {{ - OutputWire: 2, - OutputInstance: 0, - InputInstance: 1, - }, { - OutputWire: 2, - OutputInstance: 1, - InputInstance: 2, - }}, - - nil, - nil, - }, - Circuit: gkrinfo.Circuit{ - { - Inputs: []int{}, - NbUniqueOutputs: 1, - }, - { - Inputs: []int{0}, - NbUniqueOutputs: 1, - }, - { - Inputs: []int{1}, - NbUniqueOutputs: 0, - }, - }, - NbInstances: 3, // not allowed if we were actually performing gkr - } - - expectedPermutations := gkrinfo.Permutations{ - SortedInstances: []int{1, 2, 0}, - SortedWires: []int{1, 2, 0}, - InstancesPermutation: []int{2, 0, 1}, - WiresPermutation: []int{2, 0, 1}, - } - - p, err := d.Compile(3) - assert.NoError(t, err) - assert.Equal(t, expectedPermutations, p) - assert.Equal(t, expectedCompiled, d) -} diff --git a/std/gkrapi/example_test.go b/std/gkrapi/example_test.go index 29244e6aab..705f288a73 100644 --- a/std/gkrapi/example_test.go +++ b/std/gkrapi/example_test.go @@ -19,18 +19,22 @@ func Example() { // This means that the imported fr and fp packages are the same, being from BW6-761 and BLS12-377 respectively. TODO @Tabaie delete if no longer have fp imported // It is based on the function DoubleAssign() of type G1Jac in gnark-crypto v0.17.0. // github.com/consensys/gnark-crypto/ecc/bls12-377 - const fsHashName = "MIMC" // register the gates: Doing so is not needed here because // the proof is being computed in the same session as the // SNARK circuit being compiled. // But in production applications it would be necessary. - assertNoError(gkrgates.Register(squareGate, 1)) - assertNoError(gkrgates.Register(sGate, 4)) - assertNoError(gkrgates.Register(zGate, 4)) - assertNoError(gkrgates.Register(xGate, 2)) - assertNoError(gkrgates.Register(yGate, 4)) + _, err := gkrgates.Register(squareGate, 1) + assertNoError(err) + _, err = gkrgates.Register(sGate, 4) + assertNoError(err) + _, err = gkrgates.Register(zGate, 4) + assertNoError(err) + _, err = gkrgates.Register(xGate, 2) + assertNoError(err) + _, err = gkrgates.Register(yGate, 4) + assertNoError(err) const nbInstances = 2 // create instances @@ -63,13 +67,12 @@ func Example() { } circuit := exampleCircuit{ - X: make([]frontend.Variable, nbInstances), - Y: make([]frontend.Variable, nbInstances), - Z: make([]frontend.Variable, nbInstances), - XOut: make([]frontend.Variable, nbInstances), - YOut: make([]frontend.Variable, nbInstances), - ZOut: make([]frontend.Variable, nbInstances), - fsHashName: fsHashName, + X: make([]frontend.Variable, nbInstances), + Y: make([]frontend.Variable, nbInstances), + Z: make([]frontend.Variable, nbInstances), + XOut: make([]frontend.Variable, nbInstances), + YOut: make([]frontend.Variable, nbInstances), + ZOut: make([]frontend.Variable, nbInstances), } assertNoError(test.IsSolved(&circuit, &assignment, ecc.BW6_761.ScalarField())) @@ -80,7 +83,6 @@ func Example() { type exampleCircuit struct { X, Y, Z []frontend.Variable // Jacobian coordinates for each point (input) XOut, YOut, ZOut []frontend.Variable // Jacobian coordinates for the double of each point (expected output) - fsHashName string // name of the hash function used for Fiat-Shamir in the GKR verifier } func (c *exampleCircuit) Define(api frontend.API) error { @@ -90,21 +92,10 @@ func (c *exampleCircuit) Define(api frontend.API) error { gkrApi := gkrapi.New() - // create GKR circuit variables based on the given assignments - X, err := gkrApi.Import(c.X) - if err != nil { - return err - } - - Y, err := gkrApi.Import(c.Y) - if err != nil { - return err - } - - Z, err := gkrApi.Import(c.Z) - if err != nil { - return err - } + // create the GKR circuit + X := gkrApi.NewInput() + Y := gkrApi.NewInput() + Z := gkrApi.NewInput() XX := gkrApi.Gate(squareGate, X) // 405: XX.Square(&p.X) YY := gkrApi.Gate(squareGate, Y) // 406: YY.Square(&p.Y) @@ -116,45 +107,31 @@ func (c *exampleCircuit) Define(api frontend.API) error { // 414: M.Double(&XX).Add(&M, &XX) // Note (but don't explicitly compute) that M = 3XX - Z = gkrApi.Gate(zGate, Z, Y, YY, ZZ) // 415 - 418 - X = gkrApi.Gate(xGate, XX, S) // 419-422 - Y = gkrApi.Gate(yGate, S, X, XX, YYYY) // 423 - 426 - - // have to duplicate X for it to be considered an output variable - X = gkrApi.NamedGate(gkr.Identity, X) - - // solve and prove the circuit - solution, err := gkrApi.Solve(api) - if err != nil { - return err - } - - // check the output - - XOut := solution.Export(X) - YOut := solution.Export(Y) - ZOut := solution.Export(Z) - for i := range XOut { - api.AssertIsEqual(XOut[i], c.XOut[i]) - api.AssertIsEqual(YOut[i], c.YOut[i]) - api.AssertIsEqual(ZOut[i], c.ZOut[i]) - } - - challenges := make([]frontend.Variable, 0, len(c.X)*6) - challenges = append(challenges, XOut...) - challenges = append(challenges, YOut...) - challenges = append(challenges, ZOut...) - challenges = append(challenges, c.X...) - challenges = append(challenges, c.Y...) - challenges = append(challenges, c.Z...) - - challenge, err := api.(frontend.Committer).Commit(challenges...) - if err != nil { - return err + ZOut := gkrApi.Gate(zGate, Z, Y, YY, ZZ) // 415 - 418 + XOut := gkrApi.Gate(xGate, XX, S) // 419-422 + YOut := gkrApi.Gate(yGate, S, XOut, XX, YYYY) // 423 - 426 + + // have to duplicate X for it to be considered an output variable; this is an implementation detail and will be fixed in the future [https://github.com/Consensys/gnark/issues/1452] + XOut = gkrApi.NamedGate(gkr.Identity, XOut) + + gkrCircuit := gkrApi.Compile(api, "MIMC") + + // add input and check output for correctness + instanceIn := make(map[gkr.Variable]frontend.Variable) + for i := range c.X { + instanceIn[X] = c.X[i] + instanceIn[Y] = c.Y[i] + instanceIn[Z] = c.Z[i] + + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return err + } + api.AssertIsEqual(instanceOut[XOut], c.XOut[i]) + api.AssertIsEqual(instanceOut[YOut], c.YOut[i]) + api.AssertIsEqual(instanceOut[ZOut], c.ZOut[i]) } - - // verify the proof - return solution.Verify(c.fsHashName, challenge) + return nil } // custom gates diff --git a/std/gkrapi/hints.go b/std/gkrapi/hints.go deleted file mode 100644 index 577a4d6ed8..0000000000 --- a/std/gkrapi/hints.go +++ /dev/null @@ -1,137 +0,0 @@ -package gkrapi - -import ( - "errors" - "fmt" - "math/big" - "strings" - - "github.com/consensys/gnark-crypto/ecc" - gcHash "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark/constraint/solver" - "github.com/consensys/gnark/constraint/solver/gkrgates" - bls12377 "github.com/consensys/gnark/internal/gkr/bls12-377" - bls12381 "github.com/consensys/gnark/internal/gkr/bls12-381" - bls24315 "github.com/consensys/gnark/internal/gkr/bls24-315" - bls24317 "github.com/consensys/gnark/internal/gkr/bls24-317" - bn254 "github.com/consensys/gnark/internal/gkr/bn254" - bw6633 "github.com/consensys/gnark/internal/gkr/bw6-633" - bw6761 "github.com/consensys/gnark/internal/gkr/bw6-761" - "github.com/consensys/gnark/internal/gkr/gkrinfo" - "github.com/consensys/gnark/internal/gkr/gkrtypes" - "github.com/consensys/gnark/internal/utils" -) - -var testEngineGkrSolvingData = make(map[string]any) - -func modKey(mod *big.Int) string { - return mod.Text(32) -} - -func SolveHintPlaceholder(gkrInfo gkrinfo.StoringInfo) solver.Hint { - return func(mod *big.Int, ins []*big.Int, outs []*big.Int) error { - - solvingInfo, err := gkrtypes.StoringToSolvingInfo(gkrInfo, gkrgates.Get) - if err != nil { - return err - } - - // TODO @Tabaie autogenerate this or decide not to - if mod.Cmp(ecc.BLS12_377.ScalarField()) == 0 { - var data bls12377.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bls12377.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - if mod.Cmp(ecc.BLS12_381.ScalarField()) == 0 { - var data bls12381.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bls12381.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - if mod.Cmp(ecc.BLS24_315.ScalarField()) == 0 { - var data bls24315.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bls24315.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - if mod.Cmp(ecc.BLS24_317.ScalarField()) == 0 { - var data bls24317.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bls24317.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - if mod.Cmp(ecc.BN254.ScalarField()) == 0 { - var data bn254.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bn254.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - if mod.Cmp(ecc.BW6_633.ScalarField()) == 0 { - var data bw6633.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bw6633.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - if mod.Cmp(ecc.BW6_761.ScalarField()) == 0 { - var data bw6761.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bw6761.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - - return errors.New("unsupported modulus") - } -} - -func ProveHintPlaceholder(hashName string) solver.Hint { - return func(mod *big.Int, ins, outs []*big.Int) error { - k := modKey(mod) - data, ok := testEngineGkrSolvingData[k] - if !ok { - return errors.New("solving data not found") - } - delete(testEngineGkrSolvingData, k) - - // TODO @Tabaie autogenerate this or decide not to - if mod.Cmp(ecc.BLS12_377.ScalarField()) == 0 { - return bls12377.ProveHint(hashName, data.(*bls12377.SolvingData))(mod, ins, outs) - } - if mod.Cmp(ecc.BLS12_381.ScalarField()) == 0 { - return bls12381.ProveHint(hashName, data.(*bls12381.SolvingData))(mod, ins, outs) - } - if mod.Cmp(ecc.BLS24_315.ScalarField()) == 0 { - return bls24315.ProveHint(hashName, data.(*bls24315.SolvingData))(mod, ins, outs) - } - if mod.Cmp(ecc.BLS24_317.ScalarField()) == 0 { - return bls24317.ProveHint(hashName, data.(*bls24317.SolvingData))(mod, ins, outs) - } - if mod.Cmp(ecc.BN254.ScalarField()) == 0 { - return bn254.ProveHint(hashName, data.(*bn254.SolvingData))(mod, ins, outs) - } - if mod.Cmp(ecc.BW6_633.ScalarField()) == 0 { - return bw6633.ProveHint(hashName, data.(*bw6633.SolvingData))(mod, ins, outs) - } - if mod.Cmp(ecc.BW6_761.ScalarField()) == 0 { - return bw6761.ProveHint(hashName, data.(*bw6761.SolvingData))(mod, ins, outs) - } - - return errors.New("unsupported modulus") - } -} - -func CheckHashHint(hashName string) solver.Hint { - return func(mod *big.Int, ins, outs []*big.Int) error { - if len(ins) != 2 || len(outs) != 1 { - return errors.New("invalid number of inputs/outputs") - } - - toHash := ins[0].Bytes() - expectedHash := ins[1] - - hsh := gcHash.NewHash(fmt.Sprintf("%s_%s", hashName, strings.ToUpper(utils.FieldToCurve(mod).String()))) - hsh.Write(toHash) - hashed := hsh.Sum(nil) - - if hashed := new(big.Int).SetBytes(hashed); hashed.Cmp(expectedHash) != 0 { - return fmt.Errorf("hash mismatch: expected %s, got %s", expectedHash.String(), hashed.String()) - } - - outs[0].SetBytes(hashed) - - return nil - } -} diff --git a/std/gkrapi/testing.go b/std/gkrapi/testing.go deleted file mode 100644 index 17163c0b5a..0000000000 --- a/std/gkrapi/testing.go +++ /dev/null @@ -1,120 +0,0 @@ -package gkrapi - -import ( - "errors" - "fmt" - "sync" - - "github.com/consensys/gnark/constraint/solver/gkrgates" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/std/gkrapi/gkr" - stdHash "github.com/consensys/gnark/std/hash" -) - -type solveInTestEngineSettings struct { - hashName string -} - -type SolveInTestEngineOption func(*solveInTestEngineSettings) - -func WithHashName(name string) SolveInTestEngineOption { - return func(s *solveInTestEngineSettings) { - s.hashName = name - } -} - -// SolveInTestEngine solves the defined circuit directly inside the SNARK circuit. This means that the method does not compute the GKR proof of the circuit and does not embed the GKR proof verifier inside a SNARK. -// The output is the values of all variables, across all instances; i.e. indexed variable-first, instance-second. -// This method only works under the test engine and should only be called to debug a GKR circuit, as the GKR prover's errors can be obscure. -func (api *API) SolveInTestEngine(parentApi frontend.API, options ...SolveInTestEngineOption) [][]frontend.Variable { - gateVer, err := gkrgates.NewGateVerifier(utils.FieldToCurve(parentApi.Compiler().Field())) - if err != nil { - panic(err) - } - - var s solveInTestEngineSettings - for _, o := range options { - o(&s) - } - if s.hashName != "" { - // hash something and make sure it gives the same answer both on prover and verifier sides - // TODO @Tabaie If indeed cheap, move this feature to Verify so that it is always run - h, err := stdHash.GetFieldHasher(s.hashName, parentApi) - if err != nil { - panic(err) - } - nbBytes := (parentApi.Compiler().FieldBitLen() + 7) / 8 - toHash := frontend.Variable(0) - for i := range nbBytes { - toHash = parentApi.Add(parentApi.Mul(toHash, 256), i%256) - } - h.Reset() - h.Write(toHash) - hashed := h.Sum() - - hintOut, err := parentApi.Compiler().NewHint(CheckHashHint(s.hashName), 1, toHash, hashed) - if err != nil { - panic(err) - } - parentApi.AssertIsEqual(hintOut[0], hashed) // the hint already checks this - } - - res := make([][]frontend.Variable, len(api.toStore.Circuit)) - var verifiedGates sync.Map - for i, w := range api.toStore.Circuit { - res[i] = make([]frontend.Variable, api.nbInstances()) - copy(res[i], api.assignments[i]) - if len(w.Inputs) == 0 { - continue - } - } - for instanceI := range api.nbInstances() { - for wireI, w := range api.toStore.Circuit { - deps := api.toStore.Dependencies[wireI] - if len(deps) != 0 && len(w.Inputs) != 0 { - panic(fmt.Errorf("non-input wire %d should not have dependencies", wireI)) - } - for _, dep := range deps { - if dep.InputInstance == instanceI { - if dep.OutputInstance >= instanceI { - panic(fmt.Errorf("out of order dependency not yet supported in SolveInTestEngine; (wire %d, instance %d) depends on (wire %d, instance %d)", wireI, instanceI, dep.OutputWire, dep.OutputInstance)) - } - if res[wireI][instanceI] != nil { - panic(fmt.Errorf("dependency (wire %d, instance %d) <- (wire %d, instance %d) attempting to override existing value assignment", wireI, instanceI, dep.OutputWire, dep.OutputInstance)) - } - res[wireI][instanceI] = res[dep.OutputWire][dep.OutputInstance] - } - } - - if res[wireI][instanceI] == nil { // no assignment or dependency - if len(w.Inputs) == 0 { - panic(fmt.Errorf("input wire %d, instance %d has no dependency or explicit assignment", wireI, instanceI)) - } - ins := make([]frontend.Variable, len(w.Inputs)) - for i, in := range w.Inputs { - ins[i] = res[in][instanceI] - } - gate := gkrgates.Get(gkr.GateName(w.Gate)) - if gate == nil && !w.IsInput() { - panic(fmt.Errorf("gate %s not found", w.Gate)) - } - if _, ok := verifiedGates.Load(w.Gate); !ok { - verifiedGates.Store(w.Gate, struct{}{}) - - err = errors.Join( - gateVer.VerifyDegree(gate), - gateVer.VerifySolvability(gate), - ) - if err != nil { - panic(fmt.Errorf("gate %s: %w", w.Gate, err)) - } - } - if gate != nil { - res[wireI][instanceI] = gate.Evaluate(parentApi, ins...) - } - } - } - } - return res -} diff --git a/std/lookup/logderivlookup/logderivlookup.go b/std/lookup/logderivlookup/logderivlookup.go index 63f2bc694d..dbeb042762 100644 --- a/std/lookup/logderivlookup/logderivlookup.go +++ b/std/lookup/logderivlookup/logderivlookup.go @@ -1,4 +1,4 @@ -// Package logderiv implements append-only lookups using log-derivative +// Package logderivlookup implements append-only lookups using log-derivative // argument. // // The lookup is based on log-derivative argument as described in [logderivarg]. diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr.go index 218d252eba..800e79ed05 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr.go @@ -1,23 +1,16 @@ package gkr_poseidon2 import ( - "errors" "fmt" - "math/big" "sync" - "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/constraint/solver/gkrgates" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/std/gkrapi" "github.com/consensys/gnark/std/gkrapi/gkr" - "github.com/consensys/gnark/std/hash" - _ "github.com/consensys/gnark/std/hash/mimc" // to ensure mimc is registered "github.com/consensys/gnark-crypto/ecc" - frBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" poseidon2Bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" + "github.com/consensys/gnark/frontend" ) // extKeyGate applies the external matrix mul, then adds the round key @@ -45,7 +38,7 @@ func pow4Gate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { // pow4TimesGate computes a, b -> a⁴ * b func pow4TimesGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 2 { - panic("expected 1 input") + panic("expected 2 input") } y := api.Mul(x[0], x[0]) y = api.Mul(y, y) @@ -115,63 +108,68 @@ func extAddGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { return api.Add(api.Mul(x[0], 2), x[1], x[2]) } -type GkrCompressions struct { - api frontend.API - ins1 []frontend.Variable - ins2 []frontend.Variable - outs []frontend.Variable +type GkrCompressor struct { + api frontend.API + gkrCircuit *gkrapi.Circuit + in1, in2, out gkr.Variable } -// NewGkrCompressions returns an object that can compute the Poseidon2 compression function (currently only for BLS12-377) +// NewGkrCompressor returns an object that can compute the Poseidon2 compression function (currently only for BLS12-377) // which consists of a permutation along with the input fed forward. // The correctness of the compression functions is proven using GKR. -// Note that the solver will need the function RegisterGkrSolverOptions to be called with the desired curves -func NewGkrCompressions(api frontend.API) *GkrCompressions { - res := GkrCompressions{ - api: api, +// Note that the solver will need the function RegisterGkrGates to be called with the desired curves +func NewGkrCompressor(api frontend.API) *GkrCompressor { + if api.Compiler().Field().Cmp(ecc.BLS12_377.ScalarField()) != 0 { + panic("currently only BL12-377 is supported") + } + gkrApi, in1, in2, out, err := defineCircuitBls12377() + if err != nil { + panic(fmt.Errorf("failed to define GKR circuit: %v", err)) + } + return &GkrCompressor{ + api: api, + gkrCircuit: gkrApi.Compile(api, "MIMC"), + in1: in1, + in2: in2, + out: out, } - api.Compiler().Defer(res.finalize) - return &res } -func (p *GkrCompressions) Compress(a, b frontend.Variable) frontend.Variable { - s, err := p.api.Compiler().NewHint(permuteHint, 1, a, b) +func (p *GkrCompressor) Compress(a, b frontend.Variable) frontend.Variable { + outs, err := p.gkrCircuit.AddInstance(map[gkr.Variable]frontend.Variable{p.in1: a, p.in2: b}) if err != nil { panic(err) } - p.ins1 = append(p.ins1, a) - p.ins2 = append(p.ins2, b) - p.outs = append(p.outs, s[0]) - return s[0] + + return outs[p.out] } -// defineCircuit defines the GKR circuit for the Poseidon2 permutation over BLS12-377 +// defineCircuitBls12377 defines the GKR circuit for the Poseidon2 permutation over BLS12-377 // insLeft and insRight are the inputs to the permutation // they must be padded to a power of 2 -func defineCircuit(insLeft, insRight []frontend.Variable) (*gkrapi.API, gkr.Variable, error) { +func defineCircuitBls12377() (gkrApi *gkrapi.API, in1, in2, out gkr.Variable, err error) { // variable indexes const ( xI = iota yI ) + if err = registerGatesBls12377(); err != nil { + return + } + // poseidon2 parameters gateNamer := newRoundGateNamer(poseidon2Bls12377.GetDefaultParameters()) rF := poseidon2Bls12377.GetDefaultParameters().NbFullRounds rP := poseidon2Bls12377.GetDefaultParameters().NbPartialRounds halfRf := rF / 2 - gkrApi := gkrapi.New() + gkrApi = gkrapi.New() - x, err := gkrApi.Import(insLeft) - if err != nil { - return nil, -1, err - } - y, err := gkrApi.Import(insRight) - y0 := y // save to feed forward at the end - if err != nil { - return nil, -1, err - } + x := gkrApi.NewInput() + y := gkrApi.NewInput() + + in1, in2 = x, y // save to feed forward at the end // *** helper functions to register and apply gates *** @@ -240,80 +238,9 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkrapi.API, gkr.Vari } // apply the external matrix one last time to obtain the final value of y - y = gkrApi.NamedGate(gateNamer.linear(yI, rP+rF), y, x, y0) - - return gkrApi, y, nil -} - -func (p *GkrCompressions) finalize(api frontend.API) error { - if p.api != api { - panic("unexpected API") - } - - // register gates - registerGkrSolverOptions(api) - - // pad instances into a power of 2 - // TODO @Tabaie the GKR API to do this automatically? - ins1Padded := make([]frontend.Variable, ecc.NextPowerOfTwo(uint64(len(p.ins1)))) - ins2Padded := make([]frontend.Variable, len(ins1Padded)) - copy(ins1Padded, p.ins1) - copy(ins2Padded, p.ins2) - for i := len(p.ins1); i < len(ins1Padded); i++ { - ins1Padded[i] = 0 - ins2Padded[i] = 0 - } - - gkrApi, y, err := defineCircuit(ins1Padded, ins2Padded) - if err != nil { - return err - } - - // connect to output - // TODO can we save 1 constraint per instance by giving the desired outputs to the gkr api? - solution, err := gkrApi.Solve(api) - if err != nil { - return err - } - yVals := solution.Export(y) - for i := range p.outs { - api.AssertIsEqual(yVals[i], p.outs[i]) - } - - // verify GKR proof - allVals := make([]frontend.Variable, 0, 3*len(p.ins1)) - allVals = append(allVals, p.ins1...) - allVals = append(allVals, p.ins2...) - allVals = append(allVals, p.outs...) - challenge, err := p.api.(frontend.Committer).Commit(allVals...) - if err != nil { - return err - } - return solution.Verify(hash.MIMC.String(), challenge) -} + out = gkrApi.NamedGate(gateNamer.linear(yI, rP+rF), y, x, in2) -// registerGkrSolverOptions is a wrapper for RegisterGkrSolverOptions -// that performs the registration for the curve associated with api. -func registerGkrSolverOptions(api frontend.API) { - RegisterGkrSolverOptions(utils.FieldToCurve(api.Compiler().Field())) -} - -func permuteHint(m *big.Int, ins, outs []*big.Int) error { - if m.Cmp(ecc.BLS12_377.ScalarField()) != 0 { - return errors.New("only bls12-377 supported") - } - if len(ins) != 2 || len(outs) != 1 { - return errors.New("expected 2 inputs and 1 output") - } - var x [2]frBls12377.Element - x[0].SetBigInt(ins[0]) - x[1].SetBigInt(ins[1]) - y0 := x[1] - - err := bls12377Permutation().Permutation(x[:]) - x[1].Add(&x[1], &y0) // feed forward - x[1].BigInt(outs[0]) - return err + return } var bls12377Permutation = sync.OnceValue(func() *poseidon2Bls12377.Permutation { @@ -321,16 +248,15 @@ var bls12377Permutation = sync.OnceValue(func() *poseidon2Bls12377.Permutation { return poseidon2Bls12377.NewPermutation(2, params.NbFullRounds, params.NbPartialRounds) // TODO @Tabaie add NewDefaultPermutation to gnark-crypto }) -// RegisterGkrSolverOptions registers the GKR gates corresponding to the given curves for the solver -func RegisterGkrSolverOptions(curves ...ecc.ID) { +// RegisterGkrGates registers the GKR gates corresponding to the given curves for the solver +func RegisterGkrGates(curves ...ecc.ID) { if len(curves) == 0 { panic("expected at least one curve") } - solver.RegisterHint(permuteHint) for _, curve := range curves { switch curve { case ecc.BLS12_377: - if err := registerGkrGatesBls12377(); err != nil { + if err := registerGatesBls12377(); err != nil { panic(err) } default: @@ -339,7 +265,7 @@ func RegisterGkrSolverOptions(curves ...ecc.ID) { } } -func registerGkrGatesBls12377() error { +func registerGatesBls12377() error { const ( x = iota y @@ -349,29 +275,31 @@ func registerGkrGatesBls12377() error { halfRf := p.NbFullRounds / 2 gateNames := newRoundGateNamer(p) - if err := gkrgates.Register(pow2Gate, 1, gkrgates.WithUnverifiedDegree(2), gkrgates.WithNoSolvableVar()); err != nil { + if _, err := gkrgates.Register(pow2Gate, 1, gkrgates.WithUnverifiedDegree(2), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { return err } - if err := gkrgates.Register(pow4Gate, 1, gkrgates.WithUnverifiedDegree(4), gkrgates.WithNoSolvableVar()); err != nil { + if _, err := gkrgates.Register(pow4Gate, 1, gkrgates.WithUnverifiedDegree(4), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { return err } - if err := gkrgates.Register(pow2TimesGate, 2, gkrgates.WithUnverifiedDegree(3), gkrgates.WithNoSolvableVar()); err != nil { + if _, err := gkrgates.Register(pow2TimesGate, 2, gkrgates.WithUnverifiedDegree(3), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { return err } - if err := gkrgates.Register(pow4TimesGate, 2, gkrgates.WithUnverifiedDegree(5), gkrgates.WithNoSolvableVar()); err != nil { + if _, err := gkrgates.Register(pow4TimesGate, 2, gkrgates.WithUnverifiedDegree(5), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { return err } - if err := gkrgates.Register(intGate2, 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0)); err != nil { + if _, err := gkrgates.Register(intGate2, 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { return err } extKeySBox := func(round int, varIndex int) error { - return gkrgates.Register(extKeyGate(&p.RoundKeys[round][varIndex]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(varIndex, round))) + _, err := gkrgates.Register(extKeyGate(&p.RoundKeys[round][varIndex]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(varIndex, round)), gkrgates.WithCurves(ecc.BLS12_377)) + return err } intKeySBox2 := func(round int) error { - return gkrgates.Register(intKeyGate2(&p.RoundKeys[round][1]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, round))) + _, err := gkrgates.Register(intKeyGate2(&p.RoundKeys[round][1]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, round)), gkrgates.WithCurves(ecc.BLS12_377)) + return err } fullRound := func(i int) error { @@ -415,7 +343,8 @@ func registerGkrGatesBls12377() error { } } - return gkrgates.Register(extAddGate, 3, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, p.NbPartialRounds+p.NbFullRounds))) + _, err := gkrgates.Register(extAddGate, 3, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, p.NbPartialRounds+p.NbFullRounds)), gkrgates.WithCurves(ecc.BLS12_377)) + return err } type roundGateNamer string diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go index 1503054a59..ffa60d8ccb 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go @@ -2,18 +2,20 @@ package gkr_poseidon2 import ( "fmt" + "os" + "runtime/pprof" "testing" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/scs" + _ "github.com/consensys/gnark/std/hash/all" "github.com/consensys/gnark/test" "github.com/stretchr/testify/require" ) -func TestGkrCompression(t *testing.T) { - const n = 2 +func gkrPermutationsCircuits(t require.TestingT, n int) (circuit, assignment testGkrPermutationCircuit) { var k int64 ins := make([][2]frontend.Variable, n) outs := make([]frontend.Variable, n) @@ -32,14 +34,19 @@ func TestGkrCompression(t *testing.T) { k += 2 } - circuit := testGkrPermutationCircuit{ - Ins: ins, - Outs: outs, - } + return testGkrPermutationCircuit{ + Ins: make([][2]frontend.Variable, len(ins)), + Outs: make([]frontend.Variable, len(outs)), + }, testGkrPermutationCircuit{ + Ins: ins, + Outs: outs, + } +} - RegisterGkrSolverOptions(ecc.BLS12_377) +func TestGkrCompression(t *testing.T) { + circuit, assignment := gkrPermutationsCircuits(t, 2) - test.NewAssert(t).CheckCircuit(&testGkrPermutationCircuit{Ins: make([][2]frontend.Variable, len(ins)), Outs: make([]frontend.Variable, len(outs))}, test.WithValidAssignment(&circuit), test.WithCurves(ecc.BLS12_377)) + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BLS12_377)) } type testGkrPermutationCircuit struct { @@ -49,7 +56,7 @@ type testGkrPermutationCircuit struct { func (c *testGkrPermutationCircuit) Define(api frontend.API) error { - pos2 := NewGkrCompressions(api) + pos2 := NewGkrCompressor(api) api.AssertIsEqual(len(c.Ins), len(c.Outs)) for i := range c.Ins { api.AssertIsEqual(c.Outs[i], pos2.Compress(c.Ins[i][0], c.Ins[i][1])) @@ -67,3 +74,27 @@ func TestGkrPermutationCompiles(t *testing.T) { require.NoError(t, err) fmt.Println(cs.GetNbConstraints(), "constraints") } + +func BenchmarkGkrPermutations(b *testing.B) { + circuit, assignmment := gkrPermutationsCircuits(b, 50000) + + cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit) + require.NoError(b, err) + + witness, err := frontend.NewWitness(&assignmment, ecc.BLS12_377.ScalarField()) + require.NoError(b, err) + + // cpu profile + f, err := os.Create("cpu.pprof") + require.NoError(b, err) + defer func() { + require.NoError(b, f.Close()) + }() + + err = pprof.StartCPUProfile(f) + require.NoError(b, err) + defer pprof.StopCPUProfile() + + _, err = cs.Solve(witness) + require.NoError(b, err) +}