diff --git a/internal/stats/latest_stats.csv b/internal/stats/latest_stats.csv index c0bb1990e7..5eaa76d88b 100644 --- a/internal/stats/latest_stats.csv +++ b/internal/stats/latest_stats.csv @@ -87,6 +87,18 @@ math/emulated/secp256k1_64,bn254,plonk,4025,3923 math/emulated/secp256k1_64,bls12_377,plonk,4025,3923 math/emulated/secp256k1_64,bls12_381,plonk,4025,3923 math/emulated/secp256k1_64,bw6_761,plonk,4025,3923 +msm_G1_bn254_2,bn254,groth16,208925,312617 +msm_G1_bn254_2,bn254,plonk,688811,658743 +msm_P256_2,bn254,groth16,185846,288056 +msm_P256_2,bn254,plonk,635297,608874 +msm_babyjubjub_2,bn254,groth16,5269,5683 +msm_babyjubjub_2,bn254,plonk,12389,11848 +msm_bandersnatch_2,bls12_381,groth16,5016,5791 +msm_bandersnatch_2,bls12_381,plonk,12450,11904 +msm_jubjub_2,bls12_381,groth16,5276,5754 +msm_jubjub_2,bls12_381,plonk,12332,11855 +msm_secp256k1_2,bn254,groth16,208997,312737 +msm_secp256k1_2,bn254,plonk,689104,659028 pairing_bls12377,bw6_761,groth16,11876,11876 pairing_bls12377,bw6_761,plonk,41914,40431 pairing_bls12381,bn254,groth16,756837,1242260 diff --git a/internal/stats/snippet.go b/internal/stats/snippet.go index 9534fee643..e2129294dc 100644 --- a/internal/stats/snippet.go +++ b/internal/stats/snippet.go @@ -6,6 +6,7 @@ import ( "github.com/consensys/gnark" "github.com/consensys/gnark-crypto/ecc" + twistededwardsCryptoID "github.com/consensys/gnark-crypto/ecc/twistededwards" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/algebra/algopts" "github.com/consensys/gnark/std/algebra/emulated/sw_bls12381" @@ -13,6 +14,7 @@ import ( "github.com/consensys/gnark/std/algebra/emulated/sw_bw6761" "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" "github.com/consensys/gnark/std/algebra/native/sw_bls12377" + "github.com/consensys/gnark/std/algebra/native/twistededwards" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/std/math/bits" "github.com/consensys/gnark/std/math/emulated" @@ -412,6 +414,144 @@ func initSnippets() { }, ecc.BN254) + // MSM(2, n) snippets for the four curve classes — used to evaluate which + // MSM-size variant is best in complete-arithmetic mode (Phase 4 of plan). + // Baselines: existing scalar_mul_* divided by 2 gives lower bound for + // MSM(2, n) via two ScalarMul + Add. + registerSnippet("msm_secp256k1_2", func(api frontend.API, newVariable func() frontend.Variable) { + cr, err := sw_emulated.New[emulated.Secp256k1Fp, emulated.Secp256k1Fr](api, sw_emulated.GetCurveParams[emulated.Secp256k1Fp]()) + if err != nil { + panic(err) + } + fr, _ := emulated.NewField[emulated.Secp256k1Fr](api) + newFr := func() *emulated.Element[emulated.Secp256k1Fr] { + n, _ := emulated.GetEffectiveFieldParams[emulated.Secp256k1Fr](api.Compiler().Field()) + limbs := make([]frontend.Variable, n) + for i := range limbs { + limbs[i] = newVariable() + } + return fr.NewElement(limbs) + } + fp, _ := emulated.NewField[emulated.Secp256k1Fp](api) + newPoint := func() *sw_emulated.AffinePoint[emulated.Secp256k1Fp] { + n, _ := emulated.GetEffectiveFieldParams[emulated.Secp256k1Fp](api.Compiler().Field()) + x := make([]frontend.Variable, n) + y := make([]frontend.Variable, n) + for i := range x { + x[i] = newVariable() + y[i] = newVariable() + } + return &sw_emulated.AffinePoint[emulated.Secp256k1Fp]{X: *fp.NewElement(x), Y: *fp.NewElement(y)} + } + _, _ = cr.MultiScalarMul( + []*sw_emulated.AffinePoint[emulated.Secp256k1Fp]{newPoint(), newPoint()}, + []*emulated.Element[emulated.Secp256k1Fr]{newFr(), newFr()}, + ) + }, ecc.BN254) + + registerSnippet("msm_P256_2", func(api frontend.API, newVariable func() frontend.Variable) { + cr, err := sw_emulated.New[emulated.P256Fp, emulated.P256Fr](api, sw_emulated.GetCurveParams[emulated.P256Fp]()) + if err != nil { + panic(err) + } + fr, _ := emulated.NewField[emulated.P256Fr](api) + newFr := func() *emulated.Element[emulated.P256Fr] { + n, _ := emulated.GetEffectiveFieldParams[emulated.P256Fr](api.Compiler().Field()) + limbs := make([]frontend.Variable, n) + for i := range limbs { + limbs[i] = newVariable() + } + return fr.NewElement(limbs) + } + fp, _ := emulated.NewField[emulated.P256Fp](api) + newPoint := func() *sw_emulated.AffinePoint[emulated.P256Fp] { + n, _ := emulated.GetEffectiveFieldParams[emulated.P256Fp](api.Compiler().Field()) + x := make([]frontend.Variable, n) + y := make([]frontend.Variable, n) + for i := range x { + x[i] = newVariable() + y[i] = newVariable() + } + return &sw_emulated.AffinePoint[emulated.P256Fp]{X: *fp.NewElement(x), Y: *fp.NewElement(y)} + } + _, _ = cr.MultiScalarMul( + []*sw_emulated.AffinePoint[emulated.P256Fp]{newPoint(), newPoint()}, + []*emulated.Element[emulated.P256Fr]{newFr(), newFr()}, + ) + }, ecc.BN254) + + registerSnippet("msm_G1_bn254_2", func(api frontend.API, newVariable func() frontend.Variable) { + cr, err := sw_emulated.New[emulated.BN254Fp, emulated.BN254Fr](api, sw_emulated.GetCurveParams[emulated.BN254Fp]()) + if err != nil { + panic(err) + } + fr, _ := emulated.NewField[emulated.BN254Fr](api) + newFr := func() *emulated.Element[emulated.BN254Fr] { + n, _ := emulated.GetEffectiveFieldParams[emulated.BN254Fr](api.Compiler().Field()) + limbs := make([]frontend.Variable, n) + for i := range limbs { + limbs[i] = newVariable() + } + return fr.NewElement(limbs) + } + fp, _ := emulated.NewField[emulated.BN254Fp](api) + newPoint := func() *sw_emulated.AffinePoint[emulated.BN254Fp] { + n, _ := emulated.GetEffectiveFieldParams[emulated.BN254Fp](api.Compiler().Field()) + x := make([]frontend.Variable, n) + y := make([]frontend.Variable, n) + for i := range x { + x[i] = newVariable() + y[i] = newVariable() + } + return &sw_emulated.AffinePoint[emulated.BN254Fp]{X: *fp.NewElement(x), Y: *fp.NewElement(y)} + } + _, _ = cr.MultiScalarMul( + []*sw_emulated.AffinePoint[emulated.BN254Fp]{newPoint(), newPoint()}, + []*emulated.Element[emulated.BN254Fr]{newFr(), newFr()}, + ) + }, ecc.BN254) + + // Twisted Edwards DoubleBaseScalarMul snippets — exercise the new + // MSM(3, 2n/3) (no GLV) and MSM(6, n/3) (GLV) variants from PR #1697. + registerSnippet("msm_babyjubjub_2", func(api frontend.API, newVariable func() frontend.Variable) { + curve, err := twistededwards.NewEdCurve(api, twistededwardsCryptoID.BN254) + if err != nil { + panic(err) + } + var P1, P2 twistededwards.Point + P1.X = newVariable() + P1.Y = newVariable() + P2.X = newVariable() + P2.Y = newVariable() + _ = curve.DoubleBaseScalarMulNonZero(P1, P2, newVariable(), newVariable()) + }, ecc.BN254) + + registerSnippet("msm_jubjub_2", func(api frontend.API, newVariable func() frontend.Variable) { + curve, err := twistededwards.NewEdCurve(api, twistededwardsCryptoID.BLS12_381) + if err != nil { + panic(err) + } + var P1, P2 twistededwards.Point + P1.X = newVariable() + P1.Y = newVariable() + P2.X = newVariable() + P2.Y = newVariable() + _ = curve.DoubleBaseScalarMulNonZero(P1, P2, newVariable(), newVariable()) + }, ecc.BLS12_381) + + registerSnippet("msm_bandersnatch_2", func(api frontend.API, newVariable func() frontend.Variable) { + curve, err := twistededwards.NewEdCurve(api, twistededwardsCryptoID.BLS12_381_BANDERSNATCH) + if err != nil { + panic(err) + } + var P1, P2 twistededwards.Point + P1.X = newVariable() + P1.Y = newVariable() + P2.X = newVariable() + P2.Y = newVariable() + _ = curve.DoubleBaseScalarMulNonZero(P1, P2, newVariable(), newVariable()) + }, ecc.BLS12_381) + // G2 scalar mul snippets — exercise the GLV+FakeGLV path to track its cost. registerSnippet("scalar_mul_G2_bls12381", func(api frontend.API, newVariable func() frontend.Variable) { bls12381fp, _ := emulated.NewField[emulated.BLS12381Fp](api) diff --git a/std/algebra/native/twistededwards/curve.go b/std/algebra/native/twistededwards/curve.go index 4dfcf10e09..1fab86ede1 100644 --- a/std/algebra/native/twistededwards/curve.go +++ b/std/algebra/native/twistededwards/curve.go @@ -10,6 +10,7 @@ type curve struct { api frontend.API id twistededwards.ID params *CurveParams + endo *EndoParams // non-nil iff the curve has a GLV endomorphism (Bandersnatch) } func (c *curve) Params() *CurveParams { @@ -44,8 +45,29 @@ func (c *curve) ScalarMul(p1 Point, scalar frontend.Variable) Point { p.scalarMul(c.api, &p1, scalar, c.params) return p } + +// DoubleBaseScalarMul computes s1*p1 + s2*p2. It is complete for all scalar +// inputs, including zero, and for identity points. func (c *curve) DoubleBaseScalarMul(p1, p2 Point, s1, s2 frontend.Variable) Point { var p Point p.doubleBaseScalarMul(c.api, &p1, &p2, s1, s2, c.params) return p } + +// DoubleBaseScalarMulNonZero computes s1*p1 + s2*p2 using the most efficient +// lattice-based MSM variant available for the curve: +// - GLV-equipped curves (Bandersnatch): 6-MSM with r^(1/3)-bounded sub-scalars. +// - non-GLV curves (Jubjub, BabyJubjub, edBLS12-377, edBW6-761): 3-MSM with +// r^(2/3)-bounded sub-scalars and LogUp lookups. +// +// The scalars s1, s2 must be nonzero and p1, p2 must not be the TE identity +// (0, 1). Use DoubleBaseScalarMul for complete edge-case handling. +func (c *curve) DoubleBaseScalarMulNonZero(p1, p2 Point, s1, s2 frontend.Variable) Point { + var p Point + if c.endo != nil { + p.doubleBaseScalarMul6MSMLogUp(c.api, &p1, &p2, s1, s2, c.params, c.endo) + } else { + p.doubleBaseScalarMul3MSMLogUp(c.api, &p1, &p2, s1, s2, c.params) + } + return p +} diff --git a/std/algebra/native/twistededwards/curve_test.go b/std/algebra/native/twistededwards/curve_test.go index 5afcc9a727..b979ad59f4 100644 --- a/std/algebra/native/twistededwards/curve_test.go +++ b/std/algebra/native/twistededwards/curve_test.go @@ -191,6 +191,22 @@ func (circuit *doubleBaseScalarMulCircuit) Define(api frontend.API) error { return nil } +type doubleBaseScalarMulNonZeroCircuit struct { + curveID twistededwards.ID + P1, P2 Point + S1, S2 frontend.Variable + Result Point +} + +func (circuit *doubleBaseScalarMulNonZeroCircuit) Define(api frontend.API) error { + curve, err := NewEdCurve(api, circuit.curveID) + if err != nil { + return err + } + assertPointEqual(api, curve.DoubleBaseScalarMulNonZero(circuit.P1, circuit.P2, circuit.S1, circuit.S2), circuit.Result) + return nil +} + func TestAdd(t *testing.T) { for _, curveID := range curves { params, err := GetCurveParams(curveID) @@ -296,6 +312,21 @@ func TestDoubleBaseScalarMul(t *testing.T) { } } +func TestDoubleBaseScalarMulNonZero(t *testing.T) { + for _, curveID := range curves { + params, err := GetCurveParams(curveID) + if err != nil { + t.Fatalf("%s: get curve params: %v", curveLabel(curveID), err) + } + data := randomTestData(params, curveID) + circuit := &doubleBaseScalarMulNonZeroCircuit{curveID: curveID} + witness := &doubleBaseScalarMulNonZeroCircuit{P1: data.P1, P2: data.P2, S1: data.S1, S2: data.S2, Result: data.DoubleScalarMulResult} + invalidWitness := *witness + invalidWitness.Result = offCurvePoint() + checkCircuitForCurve(t, curveID, circuit, witness, &invalidWitness) + } +} + func TestAddEdgeCases(t *testing.T) { for _, curveID := range curves { params, err := GetCurveParams(curveID) @@ -380,6 +411,9 @@ func TestFixedScalarMulEdgeCases(t *testing.T) { } } +// TestDoubleBaseScalarMulEdgeCases covers the complete public method, including +// zero scalars and identity points. The optimized NonZero variant is tested +// separately. func TestDoubleBaseScalarMulEdgeCases(t *testing.T) { for _, curveID := range curves { params, err := GetCurveParams(curveID) @@ -387,11 +421,14 @@ func TestDoubleBaseScalarMulEdgeCases(t *testing.T) { t.Fatalf("%s: get curve params: %v", curveLabel(curveID), err) } data := testDataForScalars(params, curveID, big.NewInt(1), big.NewInt(2)) + base := Point{X: params.Base[0], Y: params.Base[1]} t.Run(curveLabel(curveID), func(t *testing.T) { - assertSolvedForCurve(t, curveID, &doubleBaseScalarMulCircuit{curveID: curveID}, &doubleBaseScalarMulCircuit{P1: data.P1, P2: data.P2, S1: 0, S2: 0, Result: identityPoint()}) - assertSolvedForCurve(t, curveID, &doubleBaseScalarMulCircuit{curveID: curveID}, &doubleBaseScalarMulCircuit{P1: data.P1, P2: data.P2, S1: 1, S2: 0, Result: data.P1}) - assertSolvedForCurve(t, curveID, &doubleBaseScalarMulCircuit{curveID: curveID}, &doubleBaseScalarMulCircuit{P1: data.P1, P2: data.P2, S1: 0, S2: 1, Result: data.P2}) + circuit := &doubleBaseScalarMulCircuit{curveID: curveID} + assertSolvedForCurve(t, curveID, circuit, &doubleBaseScalarMulCircuit{P1: data.P1, P2: data.P2, S1: 0, S2: 0, Result: identityPoint()}) + assertSolvedForCurve(t, curveID, circuit, &doubleBaseScalarMulCircuit{P1: data.P1, P2: data.P2, S1: 1, S2: 0, Result: data.P1}) + assertSolvedForCurve(t, curveID, circuit, &doubleBaseScalarMulCircuit{P1: data.P1, P2: data.P2, S1: 0, S2: 1, Result: data.P2}) + assertSolvedForCurve(t, curveID, circuit, &doubleBaseScalarMulCircuit{P1: identityPoint(), P2: base, S1: 1, S2: 2, Result: data.P2}) }) } } @@ -630,8 +667,8 @@ func zeroRationalReconstructHint(_ *big.Int, inputs, outputs []*big.Int) error { if len(inputs) != 2 { return errors.New("expecting two inputs") } - if len(outputs) != 4 { - return errors.New("expecting four outputs") + if len(outputs) != 3 { + return errors.New("expecting three outputs") } for i := range outputs { outputs[i].SetUint64(0) @@ -659,3 +696,96 @@ func TestScalarMulFakeGLVRegressionTrivialDecomposition(t *testing.T) { ) assert.Error(err) } + +func forgedBN254DoubleBaseScalarMulHint(_ *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) != 7 { + return errors.New("expecting seven inputs") + } + if len(outputs) != 4 { + return errors.New("expecting four outputs") + } + var p1, p2, q1, q2 tbn254.PointAffine + p1.X.SetBigInt(inputs[0]) + p1.Y.SetBigInt(inputs[1]) + p2.X.SetBigInt(inputs[3]) + p2.Y.SetBigInt(inputs[4]) + q1.ScalarMultiplication(&p1, inputs[2]) + q2.ScalarMultiplication(&p2, inputs[5]) + + var delta tbn254.PointAffine + delta.Set(&p1) + + var q1Hint tbn254.PointAffine + q1Hint.Add(&q1, &delta) + + q1Hint.X.BigInt(outputs[0]) + q1Hint.Y.BigInt(outputs[1]) + q2.X.BigInt(outputs[2]) + q2.Y.BigInt(outputs[3]) + return nil +} + +func forgedBN254DoubleBaseResult(params *CurveParams, s1, s2 *big.Int) (Point, error) { + var p1, p2 tbn254.PointAffine + p1.X.SetBigInt(params.Base[0]) + p1.Y.SetBigInt(params.Base[1]) + p2.Set(&p1) + p1.ScalarMultiplication(&p1, s1) + p2.ScalarMultiplication(&p2, s2) + + p1X, p1Y := new(big.Int), new(big.Int) + p2X, p2Y := new(big.Int), new(big.Int) + p1.X.BigInt(p1X) + p1.Y.BigInt(p1Y) + p2.X.BigInt(p2X) + p2.Y.BigInt(p2Y) + + inputs := []*big.Int{ + p1X, + p1Y, + new(big.Int).Set(s1), + p2X, + p2Y, + new(big.Int).Set(s2), + new(big.Int).Set(params.Order), + } + outputs := []*big.Int{new(big.Int), new(big.Int), new(big.Int), new(big.Int)} + if err := forgedBN254DoubleBaseScalarMulHint(nil, inputs, outputs); err != nil { + return Point{}, err + } + var q1, q2, r tbn254.PointAffine + q1.X.SetBigInt(outputs[0]) + q1.Y.SetBigInt(outputs[1]) + q2.X.SetBigInt(outputs[2]) + q2.Y.SetBigInt(outputs[3]) + r.Add(&q1, &q2) + rX, rY := new(big.Int), new(big.Int) + r.X.BigInt(rX) + r.Y.BigInt(rY) + return Point{X: rX, Y: rY}, nil +} + +func TestDoubleBaseScalarMulNonZeroRejectsForgedPartialHints(t *testing.T) { + assert := require.New(t) + params, err := GetCurveParams(twistededwards.BN254) + assert.NoError(err) + + data := testDataForScalars(params, twistededwards.BN254, big.NewInt(5), big.NewInt(7)) + forged, err := forgedBN254DoubleBaseResult(params, data.S1, data.S2) + assert.NoError(err) + + witness := doubleBaseScalarMulNonZeroCircuit{ + P1: data.P1, + P2: data.P2, + S1: data.S1, + S2: data.S2, + Result: forged, + } + err = test.IsSolved( + &doubleBaseScalarMulNonZeroCircuit{curveID: twistededwards.BN254}, + &witness, + ecc.BN254.ScalarField(), + test.WithReplacementHint(solver.GetHintID(doubleBaseScalarMulHint), forgedBN254DoubleBaseScalarMulHint), + ) + assert.Error(err) +} diff --git a/std/algebra/native/twistededwards/emulatedparams.go b/std/algebra/native/twistededwards/emulatedparams.go new file mode 100644 index 0000000000..c24d016a89 --- /dev/null +++ b/std/algebra/native/twistededwards/emulatedparams.go @@ -0,0 +1,64 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +package twistededwards + +import "math/big" + +// Emulated field parameters for twisted Edwards curve orders. +// These are used for overflow-safe scalar decomposition verification. + +// edBN254Order is the BabyJubjub curve order (251 bits). +type edBN254Order struct{} + +func (edBN254Order) NbLimbs() uint { return 4 } +func (edBN254Order) BitsPerLimb() uint { return 64 } +func (edBN254Order) IsPrime() bool { return true } +func (edBN254Order) Modulus() *big.Int { + r, _ := new(big.Int).SetString("2736030358979909402780800718157159386076813972158567259200215660948447373041", 10) + return r +} + +// edBLS12381Order is the Jubjub curve order (252 bits). +type edBLS12381Order struct{} + +func (edBLS12381Order) NbLimbs() uint { return 4 } +func (edBLS12381Order) BitsPerLimb() uint { return 64 } +func (edBLS12381Order) IsPrime() bool { return true } +func (edBLS12381Order) Modulus() *big.Int { + r, _ := new(big.Int).SetString("6554484396890773809930967563523245729705921265872317281365359162392183254199", 10) + return r +} + +// edBandersnatchOrder is the Bandersnatch curve order (253 bits). +type edBandersnatchOrder struct{} + +func (edBandersnatchOrder) NbLimbs() uint { return 4 } +func (edBandersnatchOrder) BitsPerLimb() uint { return 64 } +func (edBandersnatchOrder) IsPrime() bool { return true } +func (edBandersnatchOrder) Modulus() *big.Int { + r, _ := new(big.Int).SetString("13108968793781547619861935127046491459309155893440570251786403306729687672801", 10) + return r +} + +// edBLS12377Order is the BLS12-377 twisted Edwards curve order (251 bits). +type edBLS12377Order struct{} + +func (edBLS12377Order) NbLimbs() uint { return 4 } +func (edBLS12377Order) BitsPerLimb() uint { return 64 } +func (edBLS12377Order) IsPrime() bool { return true } +func (edBLS12377Order) Modulus() *big.Int { + r, _ := new(big.Int).SetString("2111115437357092606062206234695386632838870926408408195193685246394721360383", 10) + return r +} + +// edBW6761Order is the BW6-761 twisted Edwards curve order (374 bits). +type edBW6761Order struct{} + +func (edBW6761Order) NbLimbs() uint { return 6 } +func (edBW6761Order) BitsPerLimb() uint { return 64 } +func (edBW6761Order) IsPrime() bool { return true } +func (edBW6761Order) Modulus() *big.Int { + r, _ := new(big.Int).SetString("32333053251621136751331591711861691692049189094364332567435817881934511297123972799646723302813083835942624121493", 10) + return r +} diff --git a/std/algebra/native/twistededwards/hints.go b/std/algebra/native/twistededwards/hints.go index 0a52dba254..1bffcdb200 100644 --- a/std/algebra/native/twistededwards/hints.go +++ b/std/algebra/native/twistededwards/hints.go @@ -20,6 +20,9 @@ func GetHints() []solver.Hint { rationalReconstruct, scalarMulHint, decomposeScalar, + doubleBaseScalarMulHint, + multiRationalReconstructHint, + multiRationalReconstructExtHint, } } @@ -64,20 +67,20 @@ func decomposeScalar(scalarField *big.Int, inputs []*big.Int, res []*big.Int) er return nil } -// rationalReconstruct decomposes a scalar s ∈ Fr into (s1, s2, signBit, k) such -// that s1 + s2·s = k·r in the integers, with |s1|, |s2| < γ₂·√r ≈ 1.15·√r +// rationalReconstruct decomposes a scalar s ∈ Fr into (s1, s2, signBit) such +// that s1 + s2·s = 0 mod r, with |s1|, |s2| < γ₂·√r ≈ 1.15·√r // (proven LLL/Hermite bound). Replaces the older heuristic-bound HalfGCD. // // The bit-decomposition convention: s1 ≥ 0 always, s2 = ±|s2| with signBit = 1 -// iff the underlying signed s2 was negative. The integer k is signed. +// iff the underlying signed s2 was negative. func rationalReconstruct(_ *big.Int, inputs, outputs []*big.Int) error { if len(inputs) != 2 { return errors.New("expecting two inputs (s, r)") } - if len(outputs) != 4 { - return errors.New("expecting four outputs (s1, |s2|, signBit, k)") + if len(outputs) != 3 { + return errors.New("expecting three outputs (s1, |s2|, signBit)") } - // Zero scalar: trivial (s1=s2=k=0). The in-circuit IsZero(s2)=0 guard + // Zero scalar: trivial (s1=s2=0). The in-circuit IsZero(s2)=0 guard // rejects this; the caller must pre-route scalar=1 (mirrors the existing // scalarMulFakeGLV: checkedScalar = Select(isScalarZero, 1, scalar)). if inputs[0].Sign() == 0 { @@ -87,9 +90,8 @@ func rationalReconstruct(_ *big.Int, inputs, outputs []*big.Int) error { return nil } - // lattice.RationalReconstruct returns (x, z) with x ≡ z·s mod r, - // so x − z·s = m·r for some signed integer m, with |x|, |z| < γ₂·√r. - // Map onto our convention: s1 + s2·s = k·r ⇒ s1 = x, s2 = −z, k = m. + // lattice.RationalReconstruct returns (x, z) with x ≡ z·s mod r. + // Map onto our convention: s1 + s2·s = 0 mod r ⇒ s1 = x, s2 = −z. res := lattice.RationalReconstruct(inputs[0], inputs[1]) x, z := new(big.Int).Set(res[0]), new(big.Int).Set(res[1]) @@ -101,12 +103,6 @@ func rationalReconstruct(_ *big.Int, inputs, outputs []*big.Int) error { } outputs[0].Set(x) // s1 = x ≥ 0 - // k = (x − z·s) / r computed in signed integers. - k := new(big.Int).Mul(z, inputs[0]) - k.Sub(x, k) - k.Quo(k, inputs[1]) - outputs[3].Set(k) - // s2 = −z, encoded as |s2| + signBit. signBit = 1 iff −z < 0 iff z > 0. outputs[1].Abs(z) outputs[2].SetUint64(0) @@ -167,3 +163,191 @@ func scalarMulHint(field *big.Int, inputs []*big.Int, outputs []*big.Int) error } return nil } + +// doubleBaseScalarMulHint computes [s1]P1 and [s2]P2 separately and returns +// their (X, Y) coords. Inputs: P1.X, P1.Y, s1, P2.X, P2.Y, s2, order. +// Outputs: Q1.X, Q1.Y, Q2.X, Q2.Y where Q1=[s1]P1 and Q2=[s2]P2. +// +// Used by `doubleBaseScalarMul3MSMLogUp` and `doubleBaseScalarMul6MSMLogUp` to +// hint the result that the in-circuit MSM verifies. +func doubleBaseScalarMulHint(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 7 { + return errors.New("expecting seven inputs") + } + if len(outputs) != 4 { + return errors.New("expecting four outputs") + } + if field.Cmp(ecc.BLS12_381.ScalarField()) == 0 { + order, _ := new(big.Int).SetString("13108968793781547619861935127046491459309155893440570251786403306729687672801", 10) + if inputs[6].Cmp(order) == 0 { + var P1, P2 bandersnatch.PointAffine + P1.X.SetBigInt(inputs[0]) + P1.Y.SetBigInt(inputs[1]) + P1.ScalarMultiplication(&P1, inputs[2]) + P2.X.SetBigInt(inputs[3]) + P2.Y.SetBigInt(inputs[4]) + P2.ScalarMultiplication(&P2, inputs[5]) + P1.X.BigInt(outputs[0]) + P1.Y.BigInt(outputs[1]) + P2.X.BigInt(outputs[2]) + P2.Y.BigInt(outputs[3]) + } else { + var P1, P2 jubjub.PointAffine + P1.X.SetBigInt(inputs[0]) + P1.Y.SetBigInt(inputs[1]) + P1.ScalarMultiplication(&P1, inputs[2]) + P2.X.SetBigInt(inputs[3]) + P2.Y.SetBigInt(inputs[4]) + P2.ScalarMultiplication(&P2, inputs[5]) + P1.X.BigInt(outputs[0]) + P1.Y.BigInt(outputs[1]) + P2.X.BigInt(outputs[2]) + P2.Y.BigInt(outputs[3]) + } + } else if field.Cmp(ecc.BN254.ScalarField()) == 0 { + var P1, P2 babyjubjub.PointAffine + P1.X.SetBigInt(inputs[0]) + P1.Y.SetBigInt(inputs[1]) + P1.ScalarMultiplication(&P1, inputs[2]) + P2.X.SetBigInt(inputs[3]) + P2.Y.SetBigInt(inputs[4]) + P2.ScalarMultiplication(&P2, inputs[5]) + P1.X.BigInt(outputs[0]) + P1.Y.BigInt(outputs[1]) + P2.X.BigInt(outputs[2]) + P2.Y.BigInt(outputs[3]) + } else if field.Cmp(ecc.BLS12_377.ScalarField()) == 0 { + var P1, P2 edbls12377.PointAffine + P1.X.SetBigInt(inputs[0]) + P1.Y.SetBigInt(inputs[1]) + P1.ScalarMultiplication(&P1, inputs[2]) + P2.X.SetBigInt(inputs[3]) + P2.Y.SetBigInt(inputs[4]) + P2.ScalarMultiplication(&P2, inputs[5]) + P1.X.BigInt(outputs[0]) + P1.Y.BigInt(outputs[1]) + P2.X.BigInt(outputs[2]) + P2.Y.BigInt(outputs[3]) + } else if field.Cmp(ecc.BW6_761.ScalarField()) == 0 { + var P1, P2 edbw6761.PointAffine + P1.X.SetBigInt(inputs[0]) + P1.Y.SetBigInt(inputs[1]) + P1.ScalarMultiplication(&P1, inputs[2]) + P2.X.SetBigInt(inputs[3]) + P2.Y.SetBigInt(inputs[4]) + P2.ScalarMultiplication(&P2, inputs[5]) + P1.X.BigInt(outputs[0]) + P1.Y.BigInt(outputs[1]) + P2.X.BigInt(outputs[2]) + P2.Y.BigInt(outputs[3]) + } else { + return errors.New("doubleBaseScalarMulHint: unknown curve") + } + return nil +} + +// multiRationalReconstructHint decomposes (k1, k2) jointly via 3-D LLL +// reconstruction: finds (x1, x2, z) with a shared denominator z such that +// +// k1 ≡ x1 / z (mod r) +// k2 ≡ x2 / z (mod r) +// +// with each component bounded by ~r^(2/3). Used by the non-GLV +// `doubleBaseScalarMul3MSMLogUp` path. +// +// inputs: k1, k2, order +// outputs[0..2]: |x1|, |x2|, |z| +// outputs[3..5]: signX1, signX2, signZ +func multiRationalReconstructHint(_ *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) != 3 { + return errors.New("expecting three inputs: k1, k2, order") + } + if len(outputs) != 6 { + return errors.New("expecting six outputs") + } + k1, k2, order := inputs[0], inputs[1], inputs[2] + + if k1.Sign() == 0 && k2.Sign() == 0 { + for i := range outputs { + outputs[i].SetUint64(0) + } + return nil + } + + res := lattice.NewReconstructor(order).MultiRationalReconstruct(k1, k2) + x1, x2, z := res[0], res[1], res[2] + + outputs[0].Abs(x1) + outputs[1].Abs(x2) + outputs[2].Abs(z) + + setSign := func(out *big.Int, val *big.Int) { + if val.Sign() < 0 { + out.SetUint64(1) + } else { + out.SetUint64(0) + } + } + setSign(outputs[3], x1) + setSign(outputs[4], x2) + setSign(outputs[5], z) + + return nil +} + +// multiRationalReconstructExtHint decomposes (k1, k2) jointly via 6-D LLL +// reconstruction: finds (x1, y1, x2, y2, z, t) with shared denominator +// (z + λ·t) such that +// +// k1 ≡ (x1 + λ·y1) / (z + λ·t) (mod r) +// k2 ≡ (x2 + λ·y2) / (z + λ·t) (mod r) +// +// with each component bounded by ~r^(1/3). Used by the GLV-curve +// `doubleBaseScalarMul6MSMLogUp` path. +// +// inputs: k1, k2, order, lambda +// outputs[0..5]: |x1|, |y1|, |x2|, |y2|, |z|, |t| +// outputs[6..11]: signX1, signY1, signX2, signY2, signZ, signT +func multiRationalReconstructExtHint(_ *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) != 4 { + return errors.New("expecting four inputs: k1, k2, order, lambda") + } + if len(outputs) != 12 { + return errors.New("expecting 12 outputs") + } + k1, k2, order, lambda := inputs[0], inputs[1], inputs[2], inputs[3] + + if k1.Sign() == 0 && k2.Sign() == 0 { + for i := range outputs { + outputs[i].SetUint64(0) + } + return nil + } + + rc := lattice.NewReconstructor(order).SetLambda(lambda) + res := rc.MultiRationalReconstructExt(k1, k2) + x1, y1, x2, y2, z, t := res[0], res[1], res[2], res[3], res[4], res[5] + + outputs[0].Abs(x1) + outputs[1].Abs(y1) + outputs[2].Abs(x2) + outputs[3].Abs(y2) + outputs[4].Abs(z) + outputs[5].Abs(t) + + setSign := func(out *big.Int, val *big.Int) { + if val.Sign() < 0 { + out.SetUint64(1) + } else { + out.SetUint64(0) + } + } + setSign(outputs[6], x1) + setSign(outputs[7], y1) + setSign(outputs[8], x2) + setSign(outputs[9], y2) + setSign(outputs[10], z) + setSign(outputs[11], t) + + return nil +} diff --git a/std/algebra/native/twistededwards/point.go b/std/algebra/native/twistededwards/point.go index 1c93fe2650..762dc84ce9 100644 --- a/std/algebra/native/twistededwards/point.go +++ b/std/algebra/native/twistededwards/point.go @@ -3,7 +3,10 @@ package twistededwards -import "github.com/consensys/gnark/frontend" +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/lookup/logderivlookup" +) // neg computes the negative of a point in SNARK coordinates func (p *Point) neg(api frontend.API, p1 *Point) *Point { @@ -134,25 +137,15 @@ func (p *Point) scalarMulFakeGLV(api frontend.API, p1 *Point, scalar frontend.Va // s1 + s * s2 == 0 mod Order. Uses LLL-based lattice rational // reconstruction with proven Hermite bound |s1|, |s2| < γ₂·√r ≈ 1.15·√r // (see [EEMP25] / gnark-crypto/algebra/lattice). - s, err := api.NewHint(rationalReconstruct, 4, checkedScalar, curve.Order) + s, err := api.NewHint(rationalReconstruct, 3, checkedScalar, curve.Order) if err != nil { // err is non-nil only for invalid number of inputs panic(err) } - s1, s2, bit, k := s[0], s[1], s[2], s[3] + s1, s2, bit := s[0], s[1], s[2] - // check that s1 + s2 * s == k*Order - _s2 := api.Mul(s2, checkedScalar) - _k := api.Mul(k, curve.Order) - lhs := api.Select(bit, s1, api.Add(s1, _s2)) - rhs := api.Select(bit, api.Add(_k, _s2), _k) - api.AssertIsEqual(lhs, rhs) - // A malicious hint can provide s1=s2=0, which makes the relation vacuous. - api.AssertIsEqual(api.IsZero(s2), 0) - - n := (curve.Order.BitLen() + 1) / 2 - b1 := api.ToBinary(s1, n) - b2 := api.ToBinary(s2, n) + b1, b2 := verifyScalarDecomposition(api, s1, s2, bit, checkedScalar, curve) + n := len(b1) var res, p2, p3, tmp Point q, err := api.NewHint(scalarMulHint, 2, p1.X, p1.Y, checkedScalar, curve.Order) @@ -183,3 +176,321 @@ func (p *Point) scalarMulFakeGLV(api frontend.API, p1 *Point, scalar frontend.Va return p } + +// phi is the GLV endomorphism on Bandersnatch: (x, y) → ((1-y²)·E1/(x·y), +// (y²+E0)·E0/(y²-E0)) acts as scalar multiplication by Lambda on the prime- +// order subgroup. Used by `doubleBaseScalarMul6MSMLogUp` only. +func (p *Point) phi(api frontend.API, p1 *Point, curve *CurveParams, endo *EndoParams) *Point { + xy := api.Mul(p1.X, p1.Y) + yy := api.Mul(p1.Y, p1.Y) + f := api.Sub(1, yy) + f = api.Mul(f, endo.Endo[1]) + g := api.Add(yy, endo.Endo[0]) + g = api.Mul(g, endo.Endo[0]) + h := api.Sub(yy, endo.Endo[0]) + + p.X = api.DivUnchecked(f, xy) + p.Y = api.DivUnchecked(g, h) + return p +} + +// doubleBaseScalarMul3MSMLogUp computes s1*P1+s2*P2 using MultiRationalReconstruct. +// This decomposes both scalars with a shared denominator in Z, giving +// ~r^(2/3)-bit scalars. It verifies [x1]P1 + [x2]P2 - [z]R = O where +// R = [s1]P1 + [s2]P2 is hinted. +func (p *Point) doubleBaseScalarMul3MSMLogUp(api frontend.API, p1, p2 *Point, s1, s2 frontend.Variable, curve *CurveParams) *Point { + // Get hinted results Q1 = [s1]P1 and Q2 = [s2]P2 + q, err := api.NewHint(doubleBaseScalarMulHint, 4, p1.X, p1.Y, s1, p2.X, p2.Y, s2, curve.Order) + if err != nil { + panic(err) + } + var Q1, Q2 Point + Q1.X, Q1.Y = q[0], q[1] + Q2.X, Q2.Y = q[2], q[3] + + var R Point + R.add(api, &Q1, &Q2, curve) + + // Decompose (s1, s2) into (x1, x2, z) such that + // s1*z ≡ x1 and s2*z ≡ x2 (mod Order). + h, err := api.NewHint(multiRationalReconstructHint, 6, s1, s2, curve.Order) + if err != nil { + panic(err) + } + absX1, absX2, absZ := h[0], h[1], h[2] + signX1, signX2, signZ := h[3], h[4], h[5] + + // Verify the decomposition using emulated arithmetic to avoid native field + // overflow. Also range-checks x1, x2, z and ensures z is non-zero. + bX1, bX2, bZ := verifyScalarDecomposition3D(api, s1, s2, absX1, absX2, absZ, signX1, signX2, signZ, curve) + + var sP1, sP2, sR Point + sP1.X = api.Select(signX1, api.Neg(p1.X), p1.X) + sP1.Y = p1.Y + sP2.X = api.Select(signX2, api.Neg(p2.X), p2.X) + sP2.Y = p2.Y + sR.X = api.Select(signZ, R.X, api.Neg(R.X)) + sR.Y = R.Y + + // Build the 8-entry table for 3-MSM: sP1, sP2, sR. + var table [8]Point + table[0] = Point{X: 0, Y: 1} + table[1] = sP1 + table[2] = sP2 + table[3].add(api, &sP1, &sP2, curve) + table[4] = sR + table[5].add(api, &sP1, &sR, curve) + table[6].add(api, &sP2, &sR, curve) + table[7].add(api, &table[3], &sR, curve) + + // Create LogDerivLookup tables + tableX := logderivlookup.New(api) + tableY := logderivlookup.New(api) + for i := 0; i < 8; i++ { + tableX.Insert(table[i].X) + tableY.Insert(table[i].Y) + } + + n := len(bX1) + + // Compute indices for lookups + indices := make([]frontend.Variable, n) + for i := 0; i < n; i++ { + // index = bX1[i] + 2*bX2[i] + 4*bZ[i] + indices[i] = api.Add(bX1[i], api.Mul(bX2[i], 2), api.Mul(bZ[i], 4)) + } + + // Batch lookup + resX := tableX.Lookup(indices...) + resY := tableY.Lookup(indices...) + + // Initialize accumulator with first entry + var res Point + res.X = resX[n-1] + res.Y = resY[n-1] + + for i := n - 2; i >= 0; i-- { + res.double(api, &res, curve) + var tmp Point + tmp.X = resX[i] + tmp.Y = resY[i] + res.add(api, &res, &tmp, curve) + } + + // Verify accumulator equals identity (0, 1) + api.AssertIsEqual(res.X, 0) + api.AssertIsEqual(res.Y, 1) + + p.X = R.X + p.Y = R.Y + + return p +} + +// doubleBaseScalarMul6MSMLogUp computes s1*P1+s2*P2 using MultiRationalReconstructExt (true 6-MSM). +// This decomposes both scalars with a shared denominator in Z[λ], giving ~r^(1/3)-bit scalars. +// Verifies: [x1]P + [y1]φ(P) + [x2]Q + [y2]φ(Q) = [z]R + [t]φ(R) +// where R = [s1]P + [s2]Q (hinted). +// Only works for curves with efficient endomorphism (e.g., Bandersnatch). +func (p *Point) doubleBaseScalarMul6MSMLogUp(api frontend.API, p1, p2 *Point, s1, s2 frontend.Variable, curve *CurveParams, endo *EndoParams) *Point { + // Get hinted result R = [s1]P + [s2]Q + qHint, err := api.NewHint(doubleBaseScalarMulHint, 4, p1.X, p1.Y, s1, p2.X, p2.Y, s2, curve.Order) + if err != nil { + panic(err) + } + var R Point + // We need Q1 + Q2 = R + var Q1, Q2 Point + Q1.X, Q1.Y = qHint[0], qHint[1] + Q2.X, Q2.Y = qHint[2], qHint[3] + R.add(api, &Q1, &Q2, curve) + + // Decompose (s1, s2) using MultiRationalReconstructExt. Returns + // |x1|, |y1|, |x2|, |y2|, |z|, |t| and their signs. + h, err := api.NewHint(multiRationalReconstructExtHint, 12, s1, s2, curve.Order, endo.Lambda) + if err != nil { + panic(err) + } + absX1, absY1, absX2, absY2, absZ, absT := h[0], h[1], h[2], h[3], h[4], h[5] + signX1, signY1, signX2, signY2, signZ, signT := h[6], h[7], h[8], h[9], h[10], h[11] + + // Verify the decomposition using emulated arithmetic to avoid native field overflow. + // Checks: s_i * (z + λ*t) ≡ x_i + λ*y_i (mod r) for i=1,2 + // Also range-checks sub-scalars and ensures the shared denominator is non-zero. + bX1, bY1, bX2, bY2, bZ, bT := verifyScalarDecomposition6D(api, s1, s2, + absX1, absY1, absX2, absY2, absZ, absT, + signX1, signY1, signX2, signY2, signZ, signT, + curve, endo, + ) + + // Compute φ(P1), φ(P2), φ(R) + var phiP1, phiP2, phiR Point + phiP1.phi(api, p1, curve, endo) + phiP2.phi(api, p2, curve, endo) + phiR.phi(api, &R, curve, endo) + + // Apply signs to create signed points for the 6-MSM + // The verification is: [x1]P + [y1]φ(P) + [x2]Q + [y2]φ(Q) - [z]R - [t]φ(R) = O + // With signs: we negate the point when the sign is 1 + var sP1, sPhiP1, sP2, sPhiP2, sR, sPhiR Point + + // For P1: if signX1 == 1, use -P1, else use P1 + sP1.X = api.Select(signX1, api.Neg(p1.X), p1.X) + sP1.Y = p1.Y + + // For φ(P1): if signY1 == 1, use -φ(P1), else use φ(P1) + sPhiP1.X = api.Select(signY1, api.Neg(phiP1.X), phiP1.X) + sPhiP1.Y = phiP1.Y + + // For P2: if signX2 == 1, use -P2, else use P2 + sP2.X = api.Select(signX2, api.Neg(p2.X), p2.X) + sP2.Y = p2.Y + + // For φ(P2): if signY2 == 1, use -φ(P2), else use φ(P2) + sPhiP2.X = api.Select(signY2, api.Neg(phiP2.X), phiP2.X) + sPhiP2.Y = phiP2.Y + + // For R: we subtract [z]R, so if signZ == 0 (z positive), use -R; if signZ == 1 (z negative), use R + sR.X = api.Select(signZ, R.X, api.Neg(R.X)) + sR.Y = R.Y + + // For φ(R): similarly for t + sPhiR.X = api.Select(signT, phiR.X, api.Neg(phiR.X)) + sPhiR.Y = phiR.Y + + // Build 64-entry table for 6-MSM + // Index = b0 + 2*b1 + 4*b2 + 8*b3 + 16*b4 + 32*b5 + // Points: sP1, sPhiP1, sP2, sPhiP2, sR, sPhiR + var table [64]Point + + // Precompute all 64 combinations + // table[i] = (i&1)*sP1 + ((i>>1)&1)*sPhiP1 + ((i>>2)&1)*sP2 + ((i>>3)&1)*sPhiP2 + ((i>>4)&1)*sR + ((i>>5)&1)*sPhiR + + // Start with identity + table[0] = Point{X: 0, Y: 1} + + // Single points + table[1] = sP1 + table[2] = sPhiP1 + table[4] = sP2 + table[8] = sPhiP2 + table[16] = sR + table[32] = sPhiR + + // 2-combinations + table[3].add(api, &sP1, &sPhiP1, curve) + table[5].add(api, &sP1, &sP2, curve) + table[6].add(api, &sPhiP1, &sP2, curve) + table[9].add(api, &sP1, &sPhiP2, curve) + table[10].add(api, &sPhiP1, &sPhiP2, curve) + table[12].add(api, &sP2, &sPhiP2, curve) + table[17].add(api, &sP1, &sR, curve) + table[18].add(api, &sPhiP1, &sR, curve) + table[20].add(api, &sP2, &sR, curve) + table[24].add(api, &sPhiP2, &sR, curve) + table[33].add(api, &sP1, &sPhiR, curve) + table[34].add(api, &sPhiP1, &sPhiR, curve) + table[36].add(api, &sP2, &sPhiR, curve) + table[40].add(api, &sPhiP2, &sPhiR, curve) + table[48].add(api, &sR, &sPhiR, curve) + + // 3-combinations (build from 2-combinations) + table[7].add(api, &table[3], &sP2, curve) // sP1 + sPhiP1 + sP2 + table[11].add(api, &table[3], &sPhiP2, curve) // sP1 + sPhiP1 + sPhiP2 + table[13].add(api, &table[5], &sPhiP2, curve) // sP1 + sP2 + sPhiP2 + table[14].add(api, &table[6], &sPhiP2, curve) // sPhiP1 + sP2 + sPhiP2 + table[19].add(api, &table[3], &sR, curve) // sP1 + sPhiP1 + sR + table[21].add(api, &table[5], &sR, curve) // sP1 + sP2 + sR + table[22].add(api, &table[6], &sR, curve) // sPhiP1 + sP2 + sR + table[25].add(api, &table[9], &sR, curve) // sP1 + sPhiP2 + sR + table[26].add(api, &table[10], &sR, curve) // sPhiP1 + sPhiP2 + sR + table[28].add(api, &table[12], &sR, curve) // sP2 + sPhiP2 + sR + table[35].add(api, &table[3], &sPhiR, curve) // sP1 + sPhiP1 + sPhiR + table[37].add(api, &table[5], &sPhiR, curve) // sP1 + sP2 + sPhiR + table[38].add(api, &table[6], &sPhiR, curve) // sPhiP1 + sP2 + sPhiR + table[41].add(api, &table[9], &sPhiR, curve) // sP1 + sPhiP2 + sPhiR + table[42].add(api, &table[10], &sPhiR, curve) // sPhiP1 + sPhiP2 + sPhiR + table[44].add(api, &table[12], &sPhiR, curve) // sP2 + sPhiP2 + sPhiR + table[49].add(api, &table[17], &sPhiR, curve) // sP1 + sR + sPhiR + table[50].add(api, &table[18], &sPhiR, curve) // sPhiP1 + sR + sPhiR + table[52].add(api, &table[20], &sPhiR, curve) // sP2 + sR + sPhiR + table[56].add(api, &table[24], &sPhiR, curve) // sPhiP2 + sR + sPhiR + + // 4-combinations + table[15].add(api, &table[7], &sPhiP2, curve) // sP1 + sPhiP1 + sP2 + sPhiP2 + table[23].add(api, &table[7], &sR, curve) // sP1 + sPhiP1 + sP2 + sR + table[27].add(api, &table[11], &sR, curve) // sP1 + sPhiP1 + sPhiP2 + sR + table[29].add(api, &table[13], &sR, curve) // sP1 + sP2 + sPhiP2 + sR + table[30].add(api, &table[14], &sR, curve) // sPhiP1 + sP2 + sPhiP2 + sR + table[39].add(api, &table[7], &sPhiR, curve) // sP1 + sPhiP1 + sP2 + sPhiR + table[43].add(api, &table[11], &sPhiR, curve) // sP1 + sPhiP1 + sPhiP2 + sPhiR + table[45].add(api, &table[13], &sPhiR, curve) // sP1 + sP2 + sPhiP2 + sPhiR + table[46].add(api, &table[14], &sPhiR, curve) // sPhiP1 + sP2 + sPhiP2 + sPhiR + table[51].add(api, &table[19], &sPhiR, curve) // sP1 + sPhiP1 + sR + sPhiR + table[53].add(api, &table[21], &sPhiR, curve) // sP1 + sP2 + sR + sPhiR + table[54].add(api, &table[22], &sPhiR, curve) // sPhiP1 + sP2 + sR + sPhiR + table[57].add(api, &table[25], &sPhiR, curve) // sP1 + sPhiP2 + sR + sPhiR + table[58].add(api, &table[26], &sPhiR, curve) // sPhiP1 + sPhiP2 + sR + sPhiR + table[60].add(api, &table[28], &sPhiR, curve) // sP2 + sPhiP2 + sR + sPhiR + + // 5-combinations + table[31].add(api, &table[15], &sR, curve) // all except sPhiR + table[47].add(api, &table[15], &sPhiR, curve) // all except sR + table[55].add(api, &table[23], &sPhiR, curve) // sP1 + sPhiP1 + sP2 + sR + sPhiR + table[59].add(api, &table[27], &sPhiR, curve) // sP1 + sPhiP1 + sPhiP2 + sR + sPhiR + table[61].add(api, &table[29], &sPhiR, curve) // sP1 + sP2 + sPhiP2 + sR + sPhiR + table[62].add(api, &table[30], &sPhiR, curve) // sPhiP1 + sP2 + sPhiP2 + sR + sPhiR + + // 6-combination (all points) + table[63].add(api, &table[31], &sPhiR, curve) + + // Use LogDerivLookup for the 64-entry table + tableX := logderivlookup.New(api) + tableY := logderivlookup.New(api) + for i := 0; i < 64; i++ { + tableX.Insert(table[i].X) + tableY.Insert(table[i].Y) + } + + n := len(bX1) + + // Compute indices for lookups + indices := make([]frontend.Variable, n) + for i := 0; i < n; i++ { + indices[i] = api.Add( + bX1[i], + api.Mul(bY1[i], 2), + api.Mul(bX2[i], 4), + api.Mul(bY2[i], 8), + api.Mul(bZ[i], 16), + api.Mul(bT[i], 32), + ) + } + + // Batch lookup + lookupX := tableX.Lookup(indices...) + lookupY := tableY.Lookup(indices...) + + // Initialize accumulator with last entry + var acc Point + acc.X = lookupX[n-1] + acc.Y = lookupY[n-1] + + for i := n - 2; i >= 0; i-- { + acc.double(api, &acc, curve) + var tmp Point + tmp.X = lookupX[i] + tmp.Y = lookupY[i] + acc.add(api, &acc, &tmp, curve) + } + + // Verify accumulator equals identity (0, 1) + api.AssertIsEqual(acc.X, 0) + api.AssertIsEqual(acc.Y, 1) + + // Return R (the hinted result) + p.X = R.X + p.Y = R.Y + + return p +} diff --git a/std/algebra/native/twistededwards/scalar_decomp.go b/std/algebra/native/twistededwards/scalar_decomp.go new file mode 100644 index 0000000000..94cac78901 --- /dev/null +++ b/std/algebra/native/twistededwards/scalar_decomp.go @@ -0,0 +1,242 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +package twistededwards + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" +) + +// verifyScalarDecomposition checks s1 + s2*scalar ≡ 0 (mod r) using emulated +// arithmetic to avoid native field overflow. The sign bit controls whether +// the relation is s1 + s2*scalar or s1 - s2*scalar. +// +// s1 and s2 are range-checked to nBits via ToBinary inside this function. +// Returns the bit decompositions of s1 and s2. +func verifyScalarDecomposition( + api frontend.API, + s1, s2, bit, scalar frontend.Variable, + curve *CurveParams, +) (s1Bits, s2Bits []frontend.Variable) { + r := curve.Order + n := (r.BitLen() + 1) / 2 + + // Range-check s1, s2 via ToBinary + s1Bits = api.ToBinary(s1, n) + s2Bits = api.ToBinary(s2, n) + + // Dispatch to the correct emulated field based on the curve order + switch { + case r.BitLen() <= 253 && r.Cmp(edBN254Order{}.Modulus()) == 0: + verifyDecompEmulated[edBN254Order](api, s1, s2, bit, scalar, s1Bits, s2Bits, r) + case r.BitLen() <= 253 && r.Cmp(edBLS12381Order{}.Modulus()) == 0: + verifyDecompEmulated[edBLS12381Order](api, s1, s2, bit, scalar, s1Bits, s2Bits, r) + case r.BitLen() <= 253 && r.Cmp(edBandersnatchOrder{}.Modulus()) == 0: + verifyDecompEmulated[edBandersnatchOrder](api, s1, s2, bit, scalar, s1Bits, s2Bits, r) + case r.BitLen() <= 253 && r.Cmp(edBLS12377Order{}.Modulus()) == 0: + verifyDecompEmulated[edBLS12377Order](api, s1, s2, bit, scalar, s1Bits, s2Bits, r) + case r.Cmp(edBW6761Order{}.Modulus()) == 0: + verifyDecompEmulated[edBW6761Order](api, s1, s2, bit, scalar, s1Bits, s2Bits, r) + default: + panic(fmt.Sprintf("unsupported twisted Edwards curve order: %s", r.String())) + } + + return s1Bits, s2Bits +} + +func verifyDecompEmulated[T emulated.FieldParams]( + api frontend.API, + s1, s2, bit, scalar frontend.Variable, + s1Bits, s2Bits []frontend.Variable, + r *big.Int, +) { + f, err := emulated.NewField[T](api) + if err != nil { + panic(fmt.Sprintf("failed to create emulated field: %v", err)) + } + + scalarBits := api.ToBinary(scalar, api.Compiler().FieldBitLen()) + + s1Emu := f.FromBits(s1Bits...) + s2Emu := f.FromBits(s2Bits...) + scalarEmu := f.FromBits(scalarBits...) + zero := f.Zero() + + // Compute s2 * scalar mod r + s2s := f.Mul(s2Emu, scalarEmu) + + // Check: s1 ± s2*scalar ≡ 0 (mod r) + // When bit=0: s1 + s2*scalar ≡ 0 → s1 ≡ -s2*scalar + // When bit=1: s1 - s2*scalar ≡ 0 → s1 ≡ s2*scalar + // Equivalently: s1 + Select(bit, -s2s, s2s) ≡ 0 + negS2s := f.Neg(s2s) + term := f.Select(bit, negS2s, s2s) + sum := f.Add(s1Emu, term) + f.AssertIsEqual(sum, zero) + + // Ensure s2 is non-zero to prevent trivial decomposition. + // When scalar=0, s2=0 is legitimate. + scalarIsZero := api.IsZero(scalar) + s2Check := f.Select(scalarIsZero, f.One(), s2Emu) + f.AssertIsDifferent(s2Check, zero) +} + +// verifyScalarDecomposition3D checks a shared-denominator decomposition: +// s1*z ≡ x1 (mod r) and s2*z ≡ x2 (mod r). +// Used by doubleBaseScalarMul3MSMLogUp. +func verifyScalarDecomposition3D( + api frontend.API, + s1, s2 frontend.Variable, + absX1, absX2, absZ frontend.Variable, + signX1, signX2, signZ frontend.Variable, + curve *CurveParams, +) (x1Bits, x2Bits, zBits []frontend.Variable) { + r := curve.Order + n := (2*r.BitLen() + 2) / 3 + + x1Bits = api.ToBinary(absX1, n) + x2Bits = api.ToBinary(absX2, n) + zBits = api.ToBinary(absZ, n) + + switch { + case r.Cmp(edBN254Order{}.Modulus()) == 0: + verifyDecomp3DEmulated[edBN254Order](api, s1, s2, signX1, signX2, signZ, x1Bits, x2Bits, zBits) + case r.Cmp(edBLS12381Order{}.Modulus()) == 0: + verifyDecomp3DEmulated[edBLS12381Order](api, s1, s2, signX1, signX2, signZ, x1Bits, x2Bits, zBits) + case r.Cmp(edBandersnatchOrder{}.Modulus()) == 0: + verifyDecomp3DEmulated[edBandersnatchOrder](api, s1, s2, signX1, signX2, signZ, x1Bits, x2Bits, zBits) + case r.Cmp(edBLS12377Order{}.Modulus()) == 0: + verifyDecomp3DEmulated[edBLS12377Order](api, s1, s2, signX1, signX2, signZ, x1Bits, x2Bits, zBits) + case r.Cmp(edBW6761Order{}.Modulus()) == 0: + verifyDecomp3DEmulated[edBW6761Order](api, s1, s2, signX1, signX2, signZ, x1Bits, x2Bits, zBits) + default: + panic(fmt.Sprintf("unsupported twisted Edwards curve order: %s", r.String())) + } + + return +} + +func verifyDecomp3DEmulated[T emulated.FieldParams]( + api frontend.API, + s1, s2 frontend.Variable, + signX1, signX2, signZ frontend.Variable, + x1Bits, x2Bits, zBits []frontend.Variable, +) { + f, err := emulated.NewField[T](api) + if err != nil { + panic(fmt.Sprintf("failed to create emulated field: %v", err)) + } + + nativeBits := api.Compiler().FieldBitLen() + s1Bits := api.ToBinary(s1, nativeBits) + s2Bits := api.ToBinary(s2, nativeBits) + + x1Emu := f.FromBits(x1Bits...) + x2Emu := f.FromBits(x2Bits...) + zEmu := f.FromBits(zBits...) + s1Emu := f.FromBits(s1Bits...) + s2Emu := f.FromBits(s2Bits...) + zero := f.Zero() + + x1Signed := f.Select(signX1, f.Neg(x1Emu), x1Emu) + x2Signed := f.Select(signX2, f.Neg(x2Emu), x2Emu) + zSigned := f.Select(signZ, f.Neg(zEmu), zEmu) + + f.AssertIsEqual(f.Mul(s1Emu, zSigned), x1Signed) + f.AssertIsEqual(f.Mul(s2Emu, zSigned), x2Signed) + + f.AssertIsDifferent(zEmu, zero) +} + +// verifyScalarDecomposition6D checks the 6D decomposition for doubleBaseScalarMul6MSMLogUp. +// Verifies: s_i * (z + λ*t) ≡ x_i + λ*y_i (mod r) for i=1,2 +// All verification is done in emulated arithmetic over the curve order to avoid overflow. +func verifyScalarDecomposition6D( + api frontend.API, + s1, s2 frontend.Variable, + absX1, absY1, absX2, absY2, absZ, absT frontend.Variable, + signX1, signY1, signX2, signY2, signZ, signT frontend.Variable, + curve *CurveParams, + endo *EndoParams, +) (x1Bits, y1Bits, x2Bits, y2Bits, zBits, tBits []frontend.Variable) { + r := curve.Order + n := (r.BitLen() + 2) / 3 + + x1Bits = api.ToBinary(absX1, n) + y1Bits = api.ToBinary(absY1, n) + x2Bits = api.ToBinary(absX2, n) + y2Bits = api.ToBinary(absY2, n) + zBits = api.ToBinary(absZ, n) + tBits = api.ToBinary(absT, n) + + switch { + case r.Cmp(edBandersnatchOrder{}.Modulus()) == 0: + verify6DEmulated[edBandersnatchOrder](api, s1, s2, x1Bits, y1Bits, x2Bits, y2Bits, zBits, tBits, + signX1, signY1, signX2, signY2, signZ, signT, r, endo.Lambda) + default: + // Currently only Bandersnatch has an endomorphism. Add other cases as needed. + panic(fmt.Sprintf("unsupported twisted Edwards curve order for 6D decomposition: %s", r.String())) + } + + return +} + +func verify6DEmulated[T emulated.FieldParams]( + api frontend.API, + s1, s2 frontend.Variable, + absX1Bits, absY1Bits, absX2Bits, absY2Bits, absZBits, absTBits []frontend.Variable, + signX1, signY1, signX2, signY2, signZ, signT frontend.Variable, + r, lambda *big.Int, +) { + f, err := emulated.NewField[T](api) + if err != nil { + panic(fmt.Sprintf("failed to create emulated field: %v", err)) + } + + absX1Emu := f.FromBits(absX1Bits...) + absY1Emu := f.FromBits(absY1Bits...) + absX2Emu := f.FromBits(absX2Bits...) + absY2Emu := f.FromBits(absY2Bits...) + absZEmu := f.FromBits(absZBits...) + absTEmu := f.FromBits(absTBits...) + + lambdaEmu := f.NewElement(lambda) + zero := f.Zero() + + // Signed values in emulated field + x1Emu := f.Select(signX1, f.Neg(absX1Emu), absX1Emu) + y1Emu := f.Select(signY1, f.Neg(absY1Emu), absY1Emu) + x2Emu := f.Select(signX2, f.Neg(absX2Emu), absX2Emu) + y2Emu := f.Select(signY2, f.Neg(absY2Emu), absY2Emu) + zEmu := f.Select(signZ, f.Neg(absZEmu), absZEmu) + tEmu := f.Select(signT, f.Neg(absTEmu), absTEmu) + + // d = z + λ*t (mod r) + dComputed := f.Add(zEmu, f.Mul(lambdaEmu, tEmu)) + + // n1 = x1 + λ*y1 (mod r) + n1Computed := f.Add(x1Emu, f.Mul(lambdaEmu, y1Emu)) + + // n2 = x2 + λ*y2 (mod r) + n2Computed := f.Add(x2Emu, f.Mul(lambdaEmu, y2Emu)) + + // s1 * d ≡ n1 (mod r) + nativeBits := api.Compiler().FieldBitLen() + s1Bits := api.ToBinary(s1, nativeBits) + s1Emu := f.FromBits(s1Bits...) + f.AssertIsEqual(f.Mul(s1Emu, dComputed), n1Computed) + + // s2 * d ≡ n2 (mod r) + s2Bits := api.ToBinary(s2, nativeBits) + s2Emu := f.FromBits(s2Bits...) + f.AssertIsEqual(f.Mul(s2Emu, dComputed), n2Computed) + + // Ensure d non-zero (unless both scalars are zero) + bothZero := api.And(api.IsZero(s1), api.IsZero(s2)) + dCheck := f.Select(bothZero, f.One(), dComputed) + f.AssertIsDifferent(dCheck, zero) +} diff --git a/std/algebra/native/twistededwards/twistededwards.go b/std/algebra/native/twistededwards/twistededwards.go index 0e730e93d6..c20cd732fc 100644 --- a/std/algebra/native/twistededwards/twistededwards.go +++ b/std/algebra/native/twistededwards/twistededwards.go @@ -33,6 +33,10 @@ type Curve interface { ScalarMul(p1 Point, scalar frontend.Variable) Point // DoubleBaseScalarMul computes [s1]p1+[s2]p2 for points that lie on the curve. DoubleBaseScalarMul(p1, p2 Point, s1, s2 frontend.Variable) Point + // DoubleBaseScalarMulNonZero computes [s1]p1+[s2]p2 with the optimized + // lattice MSM path. It requires s1, s2 to be nonzero and p1, p2 to be + // non-identity points. + DoubleBaseScalarMulNonZero(p1, p2 Point, s1, s2 frontend.Variable) Point API() frontend.API } @@ -48,6 +52,15 @@ type CurveParams struct { Base [2]*big.Int // base point coordinates } +// EndoParams holds the GLV endomorphism parameters for curves that have one +// (Bandersnatch). The endomorphism Φ(x, y) = ((Endo[0]·(1−y²)/(x·y) + …) acts as +// scalar multiplication by Lambda on the prime-order subgroup. This is used +// only by the `doubleBaseScalarMul6MSMLogUp` MSM(6, n/3) variant. +type EndoParams struct { + Endo [2]*big.Int + Lambda *big.Int +} + // NewEdCurve returns a new Edwards curve func NewEdCurve(api frontend.API, id twistededwards.ID) (Curve, error) { snarkField, err := GetSnarkField(id) @@ -62,8 +75,18 @@ func NewEdCurve(api frontend.API, id twistededwards.ID) (Curve, error) { return nil, err } - // default - return &curve{api: api, params: params, id: id}, nil + var endo *EndoParams + if id == twistededwards.BLS12_381_BANDERSNATCH { + endo = &EndoParams{ + Endo: [2]*big.Int{new(big.Int), new(big.Int)}, + Lambda: new(big.Int), + } + endo.Endo[0].SetString("37446463827641770816307242315180085052603635617490163568005256780843403514036", 10) + endo.Endo[1].SetString("49199877423542878313146170939139662862850515542392585932876811575731455068989", 10) + endo.Lambda.SetString("8913659658109529928382530854484400854125314752504019737736543920008458395397", 10) + } + + return &curve{api: api, params: params, endo: endo, id: id}, nil } func GetCurveParams(id twistededwards.ID) (*CurveParams, error) {