diff --git a/internal/stats/latest_stats.csv b/internal/stats/latest_stats.csv index c0bb1990e7..b6e8158c49 100644 --- a/internal/stats/latest_stats.csv +++ b/internal/stats/latest_stats.csv @@ -95,24 +95,24 @@ pairing_bn254,bn254,groth16,506052,823961 pairing_bn254,bn254,plonk,1646819,1573151 pairing_bw6761,bn254,groth16,1589471,2646707 pairing_bw6761,bn254,plonk,5318762,5097941 -scalar_mul_G1_bn254,bn254,groth16,108168,163915 -scalar_mul_G1_bn254,bn254,plonk,355353,340385 -scalar_mul_G1_bn254_incomplete,bn254,groth16,51579,81902 -scalar_mul_G1_bn254_incomplete,bn254,plonk,185916,179316 -scalar_mul_G2_bls12381,bn254,groth16,303414,456759 -scalar_mul_G2_bls12381,bn254,plonk,989717,947171 -scalar_mul_G2_bn254,bn254,groth16,217155,326138 -scalar_mul_G2_bn254,bn254,plonk,721207,689834 -scalar_mul_G2_bw6761,bn254,groth16,389209,617382 -scalar_mul_G2_bw6761,bn254,plonk,1244592,1192510 -scalar_mul_P256,bn254,groth16,96724,151768 -scalar_mul_P256,bn254,plonk,328895,315729 -scalar_mul_P256_incomplete,bn254,groth16,75542,121798 -scalar_mul_P256_incomplete,bn254,plonk,263160,253523 -scalar_mul_secp256k1,bn254,groth16,108204,163975 -scalar_mul_secp256k1,bn254,plonk,355502,340530 -scalar_mul_secp256k1_incomplete,bn254,groth16,51619,81970 -scalar_mul_secp256k1_incomplete,bn254,plonk,186082,179475 +scalar_mul_G1_bn254,bn254,groth16,107499,163170 +scalar_mul_G1_bn254,bn254,plonk,353793,338853 +scalar_mul_G1_bn254_incomplete,bn254,groth16,50892,81121 +scalar_mul_G1_bn254_incomplete,bn254,plonk,184266,177694 +scalar_mul_G2_bls12381,bn254,groth16,302753,456022 +scalar_mul_G2_bls12381,bn254,plonk,988292,945775 +scalar_mul_G2_bn254,bn254,groth16,216495,325402 +scalar_mul_G2_bn254,bn254,plonk,719650,688306 +scalar_mul_G2_bw6761,bn254,groth16,387972,616049 +scalar_mul_G2_bw6761,bn254,plonk,1241833,1189810 +scalar_mul_P256,bn254,groth16,96434,151466 +scalar_mul_P256,bn254,plonk,328264,315107 +scalar_mul_P256_incomplete,bn254,groth16,75252,121496 +scalar_mul_P256_incomplete,bn254,plonk,262529,252901 +scalar_mul_secp256k1,bn254,groth16,107536,163231 +scalar_mul_secp256k1,bn254,plonk,353942,338998 +scalar_mul_secp256k1_incomplete,bn254,groth16,50932,81189 +scalar_mul_secp256k1_incomplete,bn254,plonk,184432,177853 selector/binaryMux_4,bn254,groth16,5,3 selector/binaryMux_4,bls12_377,groth16,5,3 selector/binaryMux_4,bls12_381,groth16,5,3 diff --git a/std/algebra/emulated/sw_bls12381/g2.go b/std/algebra/emulated/sw_bls12381/g2.go index 47b3a6aa55..a81fdc5ab2 100644 --- a/std/algebra/emulated/sw_bls12381/g2.go +++ b/std/algebra/emulated/sw_bls12381/g2.go @@ -693,6 +693,9 @@ func (g2 *G2) scalarMulGLVAndFakeGLV(Q *G2Affine, s *Scalar, opts ...algopts.Alg if err != nil { panic(err) } + var st ScalarField + // LLL Hermite bound: u_i, v_i < γ₄·r^(1/4), fits in (BitLen+3)/4 + 2 bits. + nbits := (st.Modulus().BitLen()+3)/4 + 2 // handle 0-scalar and (-1)-scalar cases var isScalarZero, isScalarZeroOrMinusOne, isScalarOne, isScalarMinusOne frontend.Variable @@ -708,7 +711,8 @@ func (g2 *G2) scalarMulGLVAndFakeGLV(Q *G2Affine, s *Scalar, opts ...algopts.Alg // Decompose s into (u1, u2, v1, v2) via LLL: s·(v1 + λ·v2) + u1 + λ·u2 ≡ 0 // (mod r), with each sub-scalar bounded by ~r^(1/4). - signs, sd, err := g2.fr.NewHintGeneric(rationalReconstructExtG2, 4, 4, nil, []*emulated.Element[ScalarField]{_s, g2.eigenvalue}) + signs, sd, err := g2.fr.NewHintGeneric(rationalReconstructExtG2, 4, 4, nil, []*emulated.Element[ScalarField]{_s, g2.eigenvalue}, + emulated.WithHintOutputRangeCheckBits(map[int]int{4: nbits, 5: nbits, 6: nbits, 7: nbits})) if err != nil { panic(fmt.Sprintf("rationalReconstructExtG2 hint: %v", err)) } @@ -716,7 +720,7 @@ func (g2 *G2) scalarMulGLVAndFakeGLV(Q *G2Affine, s *Scalar, opts ...algopts.Alg isNegu1, isNegu2, isNegv1, isNegv2 := signs[0], signs[1], signs[2], signs[3] // Verify s·(v1 + λ·v2) + u1 + λ·u2 ≡ 0 (mod r). - var st ScalarField + sv1 := g2.fr.Mul(_s, v1) sλv2 := g2.fr.Mul(_s, g2.fr.Mul(g2.eigenvalue, v2)) λu2 := g2.fr.Mul(g2.eigenvalue, u2) @@ -834,8 +838,6 @@ func (g2 *G2) scalarMulGLVAndFakeGLV(Q *G2Affine, s *Scalar, opts ...algopts.Alg g2GenPoint := &G2Affine{P: *g2.g2Gen} Acc = addFn(Acc, g2GenPoint) - // LLL Hermite bound: u_i, v_i < γ₄·r^(1/4), fits in (BitLen+3)/4 + 2 bits. - nbits := (st.Modulus().BitLen()+3)/4 + 2 u1bits := g2.fr.ToBits(u1) u2bits := g2.fr.ToBits(u2) v1bits := g2.fr.ToBits(v1) diff --git a/std/algebra/emulated/sw_bls12381/hints.go b/std/algebra/emulated/sw_bls12381/hints.go index 86141eee50..051f7d0d73 100644 --- a/std/algebra/emulated/sw_bls12381/hints.go +++ b/std/algebra/emulated/sw_bls12381/hints.go @@ -6,7 +6,6 @@ import ( "math/big" "github.com/consensys/gnark-crypto/algebra/lattice" - "github.com/consensys/gnark-crypto/ecc" bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" "github.com/consensys/gnark-crypto/ecc/bls12-381/hash_to_curve" @@ -24,7 +23,6 @@ func GetHints() []solver.Hint { finalExpHint, pairingCheckHint, millerLoopAndCheckFinalExpHint, - decomposeScalarG1, scalarMulG2Hint, rationalReconstructExtG2, g1SqrtRatioHint, @@ -287,49 +285,6 @@ func millerLoopAndCheckFinalExpHint(nativeMod *big.Int, nativeInputs, nativeOutp }) } -func decomposeScalarG1(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { - return emulated.UnwrapHintContext(mod, inputs, outputs, func(hc emulated.HintContext) error { - moduli := hc.EmulatedModuli() - if len(moduli) != 1 { - return fmt.Errorf("expecting one moduli, got %d", len(moduli)) - } - _, nativeOutputs := hc.NativeInputsOutputs() - if len(nativeOutputs) != 2 { - return fmt.Errorf("expecting two outputs, got %d", len(nativeOutputs)) - } - emuInputs, emuOutputs := hc.InputsOutputs(moduli[0]) - if len(emuInputs) != 2 { - return fmt.Errorf("expecting two inputs, got %d", len(emuInputs)) - } - if len(emuOutputs) != 2 { - return fmt.Errorf("expecting two outputs, got %d", len(emuOutputs)) - } - - glvBasis := new(ecc.Lattice) - ecc.PrecomputeLattice(moduli[0], emuInputs[1], glvBasis) - sp := ecc.SplitScalar(emuInputs[0], glvBasis) - emuOutputs[0].Set(&sp[0]) - emuOutputs[1].Set(&sp[1]) - nativeOutputs[0].SetUint64(0) - nativeOutputs[1].SetUint64(0) - // we need the absolute values for the in-circuit computations, - // otherwise the negative values will be reduced modulo the SNARK scalar - // field and not the emulated field. - // output0 = |s0| mod r - // output1 = |s1| mod r - if emuOutputs[0].Sign() == -1 { - emuOutputs[0].Neg(emuOutputs[0]) - nativeOutputs[0].SetUint64(1) - } - if emuOutputs[1].Sign() == -1 { - emuOutputs[1].Neg(emuOutputs[1]) - nativeOutputs[1].SetUint64(1) - } - - return nil - }) -} - // g1SqrtRatio computes the square root of u/v and returns 0 iff u/v was indeed a quadratic residue // if not, we get sqrt(Z * u / v). Recall that Z is non-residue // If v = 0, u/v is meaningless and the output is unspecified, without raising an error. diff --git a/std/algebra/emulated/sw_bn254/g2.go b/std/algebra/emulated/sw_bn254/g2.go index 6df014c75b..208d704098 100644 --- a/std/algebra/emulated/sw_bn254/g2.go +++ b/std/algebra/emulated/sw_bn254/g2.go @@ -478,6 +478,9 @@ func (g2 *G2) scalarMulGLVAndFakeGLV(Q *G2Affine, s *Scalar, opts ...algopts.Alg if err != nil { panic(err) } + var st ScalarField + // u1, u2, v1, v2 < c*r^{1/4} where c ≈ 1.25 + nbits := (st.Modulus().BitLen()+3)/4 + 2 // handle 0-scalar and (-1)-scalar cases var isScalarZero, isScalarZeroOrMinusOne, isScalarOne, isScalarMinusOne frontend.Variable @@ -493,7 +496,8 @@ func (g2 *G2) scalarMulGLVAndFakeGLV(Q *G2Affine, s *Scalar, opts ...algopts.Alg // Decompose s into (u1, u2, v1, v2) via LLL: s·(v1 + λ·v2) + u1 + λ·u2 ≡ 0 // (mod r), with each sub-scalar bounded by ~r^(1/4). - signs, sd, err := g2.fr.NewHintGeneric(rationalReconstructExtG2, 4, 4, nil, []*emulated.Element[ScalarField]{_s, g2.eigenvalue}) + signs, sd, err := g2.fr.NewHintGeneric(rationalReconstructExtG2, 4, 4, nil, []*emulated.Element[ScalarField]{_s, g2.eigenvalue}, + emulated.WithHintOutputRangeCheckBits(map[int]int{4: nbits, 5: nbits, 6: nbits, 7: nbits})) if err != nil { panic(fmt.Sprintf("rationalReconstructExtG2 hint: %v", err)) } @@ -501,7 +505,7 @@ func (g2 *G2) scalarMulGLVAndFakeGLV(Q *G2Affine, s *Scalar, opts ...algopts.Alg isNegu1, isNegu2, isNegv1, isNegv2 := signs[0], signs[1], signs[2], signs[3] // Verify s·(v1 + λ·v2) + u1 + λ·u2 ≡ 0 (mod r). - var st ScalarField + sv1 := g2.fr.Mul(_s, v1) sλv2 := g2.fr.Mul(_s, g2.fr.Mul(g2.eigenvalue, v2)) λu2 := g2.fr.Mul(g2.eigenvalue, u2) @@ -629,8 +633,6 @@ func (g2 *G2) scalarMulGLVAndFakeGLV(Q *G2Affine, s *Scalar, opts ...algopts.Alg g2GenPoint := &G2Affine{P: *g2.g2Gen} Acc = addFn(Acc, g2GenPoint) - // u1, u2, v1, v2 < c*r^{1/4} where c ≈ 1.25 - nbits := (st.Modulus().BitLen()+3)/4 + 2 u1bits := g2.fr.ToBits(u1) u2bits := g2.fr.ToBits(u2) v1bits := g2.fr.ToBits(v1) diff --git a/std/algebra/emulated/sw_bw6761/g2.go b/std/algebra/emulated/sw_bw6761/g2.go index 3b0bce4cfc..2b19a5e10f 100644 --- a/std/algebra/emulated/sw_bw6761/g2.go +++ b/std/algebra/emulated/sw_bw6761/g2.go @@ -408,6 +408,9 @@ func (g2 *G2) scalarMulGLVAndFakeGLV(Q *G2Affine, s *Scalar, opts ...algopts.Alg if err != nil { panic(err) } + var st ScalarField + // u1, u2, v1, v2 < c*r^{1/4} where c ≈ 1.25 + nbits := (st.Modulus().BitLen()+3)/4 + 2 // handle 0-scalar and (-1)-scalar cases var isScalarZero, isScalarZeroOrMinusOne, isScalarOne, isScalarMinusOne frontend.Variable @@ -423,7 +426,8 @@ func (g2 *G2) scalarMulGLVAndFakeGLV(Q *G2Affine, s *Scalar, opts ...algopts.Alg // Decompose s into (u1, u2, v1, v2) via LLL: s·(v1 + λ·v2) + u1 + λ·u2 ≡ 0 // (mod r), with each sub-scalar bounded by ~r^(1/4). - signs, sd, err := g2.fr.NewHintGeneric(rationalReconstructExtG2, 4, 4, nil, []*emulated.Element[ScalarField]{_s, g2.eigenvalue}) + signs, sd, err := g2.fr.NewHintGeneric(rationalReconstructExtG2, 4, 4, nil, []*emulated.Element[ScalarField]{_s, g2.eigenvalue}, + emulated.WithHintOutputRangeCheckBits(map[int]int{4: nbits, 5: nbits, 6: nbits, 7: nbits})) if err != nil { panic(fmt.Sprintf("rationalReconstructExtG2 hint: %v", err)) } @@ -431,7 +435,6 @@ func (g2 *G2) scalarMulGLVAndFakeGLV(Q *G2Affine, s *Scalar, opts ...algopts.Alg isNegu1, isNegu2, isNegv1, isNegv2 := signs[0], signs[1], signs[2], signs[3] // Verify s·(v1 + λ·v2) + u1 + λ·u2 ≡ 0 (mod r). - var st ScalarField sv1 := g2.fr.Mul(_s, v1) sλv2 := g2.fr.Mul(_s, g2.fr.Mul(g2.eigenvalue, v2)) λu2 := g2.fr.Mul(g2.eigenvalue, u2) @@ -559,8 +562,6 @@ func (g2 *G2) scalarMulGLVAndFakeGLV(Q *G2Affine, s *Scalar, opts ...algopts.Alg g2GenPoint := &G2Affine{P: *g2.g2Gen} Acc = addFn(Acc, g2GenPoint) - // u1, u2, v1, v2 < c*r^{1/4} where c ≈ 1.25 - nbits := (st.Modulus().BitLen()+3)/4 + 2 u1bits := g2.fr.ToBits(u1) u2bits := g2.fr.ToBits(u2) v1bits := g2.fr.ToBits(v1) diff --git a/std/algebra/emulated/sw_bw6761/hints.go b/std/algebra/emulated/sw_bw6761/hints.go index 9821a6b074..67d44d0ce9 100644 --- a/std/algebra/emulated/sw_bw6761/hints.go +++ b/std/algebra/emulated/sw_bw6761/hints.go @@ -5,7 +5,6 @@ import ( "math/big" "github.com/consensys/gnark-crypto/algebra/lattice" - "github.com/consensys/gnark-crypto/ecc" bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761" "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/std/math/emulated" @@ -20,7 +19,6 @@ func GetHints() []solver.Hint { return []solver.Hint{ finalExpHint, pairingCheckHint, - decomposeScalarG1, scalarMulG2Hint, rationalReconstructExtG2, } @@ -118,49 +116,6 @@ func finalExpWitness(millerLoop *bw6761.E6, mInv *big.Int) (residueWitness bw676 return residueWitness } -func decomposeScalarG1(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { - return emulated.UnwrapHintContext(mod, inputs, outputs, func(hc emulated.HintContext) error { - moduli := hc.EmulatedModuli() - if len(moduli) != 1 { - return fmt.Errorf("expecting one moduli, got %d", len(moduli)) - } - _, nativeOutputs := hc.NativeInputsOutputs() - if len(nativeOutputs) != 2 { - return fmt.Errorf("expecting two outputs, got %d", len(nativeOutputs)) - } - emuInputs, emuOutputs := hc.InputsOutputs(moduli[0]) - if len(emuInputs) != 2 { - return fmt.Errorf("expecting two inputs, got %d", len(emuInputs)) - } - if len(emuOutputs) != 2 { - return fmt.Errorf("expecting two outputs, got %d", len(emuOutputs)) - } - - glvBasis := new(ecc.Lattice) - ecc.PrecomputeLattice(moduli[0], emuInputs[1], glvBasis) - sp := ecc.SplitScalar(emuInputs[0], glvBasis) - emuOutputs[0].Set(&sp[0]) - emuOutputs[1].Set(&sp[1]) - nativeOutputs[0].SetUint64(0) - nativeOutputs[1].SetUint64(0) - // we need the absolute values for the in-circuit computations, - // otherwise the negative values will be reduced modulo the SNARK scalar - // field and not the emulated field. - // output0 = |s0| mod r - // output1 = |s1| mod r - if emuOutputs[0].Sign() == -1 { - emuOutputs[0].Neg(emuOutputs[0]) - nativeOutputs[0].SetUint64(1) - } - if emuOutputs[1].Sign() == -1 { - emuOutputs[1].Neg(emuOutputs[1]) - nativeOutputs[1].SetUint64(1) - } - - return nil - }) -} - func scalarMulG2Hint(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { return emulated.UnwrapHintContext(field, inputs, outputs, func(hc emulated.HintContext) error { moduli := hc.EmulatedModuli() diff --git a/std/algebra/emulated/sw_emulated/point.go b/std/algebra/emulated/sw_emulated/point.go index 0404d1a385..9ed6f6b662 100644 --- a/std/algebra/emulated/sw_emulated/point.go +++ b/std/algebra/emulated/sw_emulated/point.go @@ -669,6 +669,8 @@ func (c *Curve[B, S]) scalarMulGLV(Q *AffinePoint[B], s *emulated.Element[S], op if err != nil { panic(err) } + var st S + nbits := st.Modulus().BitLen()>>1 + 2 addFn := c.Add var isPointAtInfinity frontend.Variable if !cfg.IncompleteArithmetic { @@ -688,7 +690,9 @@ func (c *Curve[B, S]) scalarMulGLV(Q *AffinePoint[B], s *emulated.Element[S], op // sub-scalars. // decompose s into s1 and s2 - sdBits, sd, err := c.scalarApi.NewHintGeneric(decomposeScalarG1, 2, 2, nil, []*emulated.Element[S]{s, c.eigenvalue}) + sdBits, sd, err := c.scalarApi.NewHintGeneric(decomposeScalarG1, 2, 2, nil, []*emulated.Element[S]{s, c.eigenvalue}, + emulated.WithHintOutputRangeCheckBits(map[int]int{2: nbits, 3: nbits}), + ) if err != nil { panic(fmt.Sprintf("compute GLV decomposition: %v", err)) } @@ -704,8 +708,6 @@ func (c *Curve[B, S]) scalarMulGLV(Q *AffinePoint[B], s *emulated.Element[S], op s1bits := c.scalarApi.ToBits(s1) s2bits := c.scalarApi.ToBits(s2) - var st S - nbits := st.Modulus().BitLen()>>1 + 2 // precompute -Q, Q, 3Q, -Φ(Q), Φ(Q), 3Φ(Q) var tableQ, tablePhiQ [3]*AffinePoint[B] @@ -1024,8 +1026,12 @@ func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulat // and boolean flags sdBits and tdBits to negate the points Q, Φ(Q), R and // Φ(R) instead of the corresponding sub-scalars. + var st S + nbits := st.Modulus().BitLen()>>1 + 2 // decompose s into s1 and s2 - sdBits, sd, err := c.scalarApi.NewHintGeneric(decomposeScalarG1, 2, 2, nil, []*emulated.Element[S]{s, c.eigenvalue}) + sdBits, sd, err := c.scalarApi.NewHintGeneric(decomposeScalarG1, 2, 2, nil, []*emulated.Element[S]{s, c.eigenvalue}, + emulated.WithHintOutputRangeCheckBits(map[int]int{2: nbits, 3: nbits}), + ) if err != nil { panic(fmt.Sprintf("compute GLV decomposition s: %v", err)) } @@ -1040,7 +1046,9 @@ func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulat ) // decompose t into t1 and t2 - tdBits, td, err := c.scalarApi.NewHintGeneric(decomposeScalarG1, 2, 2, nil, []*emulated.Element[S]{t, c.eigenvalue}) + tdBits, td, err := c.scalarApi.NewHintGeneric(decomposeScalarG1, 2, 2, nil, []*emulated.Element[S]{t, c.eigenvalue}, + emulated.WithHintOutputRangeCheckBits(map[int]int{2: nbits, 3: nbits}), + ) if err != nil { panic(fmt.Sprintf("compute GLV decomposition t: %v", err)) } @@ -1135,8 +1143,6 @@ func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulat s2bits := c.scalarApi.ToBits(s2) t1bits := c.scalarApi.ToBits(t1) t2bits := c.scalarApi.ToBits(t2) - var st S - nbits := st.Modulus().BitLen()>>1 + 2 // At each iteration we look up the point Bi from: // B1 = +Q + R + Φ(Q) + Φ(R) @@ -1362,6 +1368,8 @@ func (c *Curve[B, S]) scalarMulFakeGLV(Q *AffinePoint[B], s *emulated.Element[S] if err != nil { panic(err) } + var st S + nbits := (st.Modulus().BitLen() + 1) / 2 var isScalarZero frontend.Variable _s := s @@ -1370,9 +1378,11 @@ func (c *Curve[B, S]) scalarMulFakeGLV(Q *AffinePoint[B], s *emulated.Element[S] _s = c.scalarApi.Select(isScalarZero, c.scalarApi.One(), s) } - // First we find the sub-salars s1, s2 s.t. s1 + s2*s = 0 mod r and s1, s2 < sqrt(r). + // First we find the sub-salars s1, s2 s.t. s1 + s2*s = 0 mod r and |s1|,|s2| < γ₂·√r ≈ 1.15·√r. // we also output the sign in case s2 is negative. In that case we compute _s2 = -s2 mod r. - sign, sd, err := c.scalarApi.NewHintGeneric(rationalReconstruct, 1, 2, nil, []*emulated.Element[S]{_s}) + sign, sd, err := c.scalarApi.NewHintGeneric(rationalReconstruct, 1, 2, nil, []*emulated.Element[S]{_s}, + // we know that the hint will return s1, s2 < sqrt(r) so we can set the hint output range check bits to nbits = ceil(log2(sqrt(r))) = ceil(log2(r)/2) + emulated.WithHintOutputRangeCheckBits(map[int]int{1: nbits, 2: nbits})) if err != nil { panic(fmt.Sprintf("rationalReconstruct hint: %v", err)) } @@ -1409,8 +1419,6 @@ func (c *Curve[B, S]) scalarMulFakeGLV(Q *AffinePoint[B], s *emulated.Element[S] r1 = c.baseApi.Select(isInputPointAtInfinity, &dummy.Y, r1) } - var st S - nbits := (st.Modulus().BitLen() + 1) / 2 s1bits := c.scalarApi.ToBits(s1) s2bits := c.scalarApi.ToBits(s2) @@ -1635,6 +1643,11 @@ func (c *Curve[B, S]) scalarMulGLVAndFakeGLV(P *AffinePoint[B], s *emulated.Elem if err != nil { panic(err) } + var st S + // LLL Hermite bound (gnark-crypto/algebra/lattice): u1, u2, v1, v2 are + // bounded by γ₄·r^(1/4) ≈ 1.25·r^(1/4), which fits in (BitLen+3)/4 + 2 bits. + // This is tighter than the previous heuristic BitLen/4 + 9 (saves ~7 iters). + nbits := (st.Modulus().BitLen()+3)/4 + 2 // handle 0-scalar and (-1)-scalar cases var isScalarZero, isScalarZeroOrMinusOne, isScalarOne, isScalarMinusOne frontend.Variable @@ -1676,7 +1689,10 @@ func (c *Curve[B, S]) scalarMulGLVAndFakeGLV(P *AffinePoint[B], s *emulated.Elem // Eisenstein integers real and imaginary parts can be negative. So we // return the absolute value in the hint and negate the corresponding // points here when needed. - signs, sd, err := c.scalarApi.NewHintGeneric(rationalReconstructExt, 4, 4, nil, []*emulated.Element[S]{_s, c.eigenvalue}) + signs, sd, err := c.scalarApi.NewHintGeneric(rationalReconstructExt, 4, 4, nil, []*emulated.Element[S]{_s, c.eigenvalue}, + // we later need to check that u1, u2, v1, v2 < c*r^(1/4) so we provide a hint output range check with nbits = (BitLen+3)/4 + 2 + emulated.WithHintOutputRangeCheckBits(map[int]int{4: nbits, 5: nbits, 6: nbits, 7: nbits}), + ) if err != nil { panic(fmt.Sprintf("rationalReconstructExt hint: %v", err)) } @@ -1685,7 +1701,6 @@ func (c *Curve[B, S]) scalarMulGLVAndFakeGLV(P *AffinePoint[B], s *emulated.Elem // We need to check that: // s*(v1 + λ*v2) + u1 + λ*u2 = 0 - var st S sv1 := c.scalarApi.Mul(_s, v1) sλv2 := c.scalarApi.Mul(_s, c.scalarApi.Mul(c.eigenvalue, v2)) λu2 := c.scalarApi.Mul(c.eigenvalue, u2) @@ -1800,10 +1815,6 @@ func (c *Curve[B, S]) scalarMulGLVAndFakeGLV(P *AffinePoint[B], s *emulated.Elem g := c.Generator() Acc = addFn(Acc, g) - // LLL Hermite bound (gnark-crypto/algebra/lattice): u1, u2, v1, v2 are - // bounded by γ₄·r^(1/4) ≈ 1.25·r^(1/4), which fits in (BitLen+3)/4 + 2 bits. - // This is tighter than the previous heuristic BitLen/4 + 9 (saves ~7 iters). - nbits := (st.Modulus().BitLen()+3)/4 + 2 u1bits := c.scalarApi.ToBits(u1) u2bits := c.scalarApi.ToBits(u2) v1bits := c.scalarApi.ToBits(v1) diff --git a/std/algebra/native/sw_bls12377/hints.go b/std/algebra/native/sw_bls12377/hints.go index 3302be200b..7e07a2961d 100644 --- a/std/algebra/native/sw_bls12377/hints.go +++ b/std/algebra/native/sw_bls12377/hints.go @@ -12,9 +12,7 @@ import ( func GetHints() []solver.Hint { return []solver.Hint{ - decomposeScalarG1, decomposeScalarG1Simple, - decomposeScalarG2, scalarMulGLVG1Hint, rationalReconstructExt, pairingCheckHint, @@ -234,60 +232,6 @@ func decomposeScalarG1Simple(scalarField *big.Int, inputs []*big.Int, outputs [] return nil } -func decomposeScalarG1(scalarField *big.Int, inputs []*big.Int, outputs []*big.Int) error { - if len(inputs) != 1 { - return errors.New("expecting one input") - } - if len(outputs) != 3 { - return errors.New("expecting three outputs") - } - cc := getInnerCurveConfig(scalarField) - sp := ecc.SplitScalar(inputs[0], cc.glvBasis) - outputs[0].Set(&(sp[0])) - outputs[1].Set(&(sp[1])) - one := big.NewInt(1) - // add (lambda+1, lambda) until scalar compostion is over Fr to ensure that - // the high bits are set in decomposition. - for outputs[0].Cmp(cc.lambda) < 1 && outputs[1].Cmp(cc.lambda) < 1 { - outputs[0].Add(outputs[0], cc.lambda) - outputs[0].Add(outputs[0], one) - outputs[1].Add(outputs[1], cc.lambda) - } - // figure out how many times we have overflowed - outputs[2].Mul(outputs[1], cc.lambda).Add(outputs[2], outputs[0]) - outputs[2].Sub(outputs[2], inputs[0]) - outputs[2].Div(outputs[2], cc.fr) - - return nil -} - -func decomposeScalarG2(scalarField *big.Int, inputs []*big.Int, outputs []*big.Int) error { - if len(inputs) != 1 { - return errors.New("expecting one input") - } - if len(outputs) != 3 { - return errors.New("expecting three outputs") - } - cc := getInnerCurveConfig(scalarField) - sp := ecc.SplitScalar(inputs[0], cc.glvBasis) - outputs[0].Set(&(sp[0])) - outputs[1].Set(&(sp[1])) - one := big.NewInt(1) - // add (lambda+1, lambda) until scalar compostion is over Fr to ensure that - // the high bits are set in decomposition. - for outputs[0].Cmp(cc.lambda) < 1 && outputs[1].Cmp(cc.lambda) < 1 { - outputs[0].Add(outputs[0], cc.lambda) - outputs[0].Add(outputs[0], one) - outputs[1].Add(outputs[1], cc.lambda) - } - // figure out how many times we have overflowed - outputs[2].Mul(outputs[1], cc.lambda).Add(outputs[2], outputs[0]) - outputs[2].Sub(outputs[2], inputs[0]) - outputs[2].Div(outputs[2], cc.fr) - - return nil -} - func scalarMulGLVG1Hint(scalarField *big.Int, inputs []*big.Int, outputs []*big.Int) error { if len(inputs) != 3 { return errors.New("expecting three inputs") diff --git a/std/algebra/native/sw_grumpkin/g1.go b/std/algebra/native/sw_grumpkin/g1.go index 22feb69c53..978e340ffb 100644 --- a/std/algebra/native/sw_grumpkin/g1.go +++ b/std/algebra/native/sw_grumpkin/g1.go @@ -190,7 +190,7 @@ func (p *G1Affine) scalarMulGLV(api frontend.API, q G1Affine, s frontend.Variabl // curve. cc := getInnerCurveConfig(api.Compiler().Field()) - s1, s2 := callDecomposeScalar(api, s, true) + s1, s2 := callDecomposeScalar(api, s) nbits := 127 s1bits := api.ToBinary(s1, nbits) @@ -477,8 +477,8 @@ func (p *G1Affine) jointScalarMulUnsafe(api frontend.API, q, r G1Affine, s, t fr // DoubleAndAdd accumulator collisions. func (p *G1Affine) jointScalarMulGLVUnsafe(api frontend.API, q, r G1Affine, s, t frontend.Variable) *G1Affine { cc := getInnerCurveConfig(api.Compiler().Field()) - s1, s2 := callDecomposeScalar(api, s, false) - t1, t2 := callDecomposeScalar(api, t, false) + s1, s2 := callDecomposeScalar(api, s) + t1, t2 := callDecomposeScalar(api, t) nbits := cc.fr.BitLen()>>1 + 1 s1bits := api.ToBinary(s1, nbits) diff --git a/std/algebra/native/sw_grumpkin/hints.go b/std/algebra/native/sw_grumpkin/hints.go index 21ca653e88..0e4ffa9cce 100644 --- a/std/algebra/native/sw_grumpkin/hints.go +++ b/std/algebra/native/sw_grumpkin/hints.go @@ -23,15 +23,21 @@ func init() { } func decomposeScalar(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { - return emulated.UnwrapHintWithNativeInput(nativeInputs, nativeOutputs, func(nnMod *big.Int, nninputs, nnOutputs []*big.Int) error { - if len(nninputs) != 1 { - return errors.New("expecting one input") + return emulated.UnwrapHintContext(nativeMod, nativeInputs, nativeOutputs, func(hc emulated.HintContext) error { + moduli := hc.EmulatedModuli() + if len(moduli) != 1 { + return errors.New("expecting one modulus") } + nativeInputs, _ := hc.NativeInputsOutputs() + if len(nativeInputs) != 1 { + return errors.New("expecting one native input") + } + _, nnOutputs := hc.InputsOutputs(moduli[0]) if len(nnOutputs) != 2 { return errors.New("expecting two outputs") } cc := getInnerCurveConfig(nativeMod) - sp := ecc.SplitScalar(nninputs[0], cc.glvBasis) + sp := ecc.SplitScalar(nativeInputs[0], cc.glvBasis) nnOutputs[0].Set(&(sp[0])) nnOutputs[1].Neg(&(sp[1])) @@ -39,7 +45,7 @@ func decomposeScalar(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) }) } -func callDecomposeScalar(api frontend.API, s frontend.Variable, simple bool) (s1, s2 frontend.Variable) { +func callDecomposeScalar(api frontend.API, s frontend.Variable) (s1, s2 frontend.Variable) { cc := getInnerCurveConfig(api.Compiler().Field()) sapi, err := emulated.NewField[emparams.GrumpkinFr](api) if err != nil { @@ -51,7 +57,7 @@ func callDecomposeScalar(api frontend.API, s frontend.Variable, simple bool) (s1 // the hints allow to decompose the scalar s into s1 and s2 such that // s1 + λ * s2 == s mod r, // where λ is third root of one in 𝔽_r. - sd, err := sapi.NewHintWithNativeInput(decomposeScalar, 2, s) + _, sd, err := sapi.NewHintGeneric(decomposeScalar, 0, 2, []frontend.Variable{s}, nil, emulated.WithHintOutputRangeCheckBits(map[int]int{0: 127, 1: 127})) if err != nil { panic(err) } diff --git a/std/algebra/native/sw_grumpkin/wrapper.go b/std/algebra/native/sw_grumpkin/wrapper.go index 861a69be53..9670c3678b 100644 --- a/std/algebra/native/sw_grumpkin/wrapper.go +++ b/std/algebra/native/sw_grumpkin/wrapper.go @@ -208,7 +208,7 @@ func (c *Curve) MultiScalarMul(P []*G1Affine, scalars []*Scalar, opts ...algopts } gamma := c.packScalarToVar(scalars[0]) // decompose gamma in the endomorphism eigenvalue basis and bit-decompose the sub-scalars - gamma1, gamma2 := callDecomposeScalar(c.api, gamma, true) + gamma1, gamma2 := callDecomposeScalar(c.api, gamma) nbits := 127 gamma1Bits := c.api.ToBinary(gamma1, nbits) gamma2Bits := c.api.ToBinary(gamma2, nbits) diff --git a/std/algebra/native/twistededwards/hints.go b/std/algebra/native/twistededwards/hints.go index 0a52dba254..6e3a6ca6f5 100644 --- a/std/algebra/native/twistededwards/hints.go +++ b/std/algebra/native/twistededwards/hints.go @@ -3,7 +3,6 @@ package twistededwards import ( "errors" "math/big" - "sync" "github.com/consensys/gnark-crypto/algebra/lattice" "github.com/consensys/gnark-crypto/ecc" @@ -19,7 +18,6 @@ func GetHints() []solver.Hint { return []solver.Hint{ rationalReconstruct, scalarMulHint, - decomposeScalar, } } @@ -27,43 +25,6 @@ func init() { solver.RegisterHint(GetHints()...) } -type glvParams struct { - lambda, order big.Int - glvBasis ecc.Lattice -} - -func decomposeScalar(scalarField *big.Int, inputs []*big.Int, res []*big.Int) error { - // the efficient endomorphism exists on Bandersnatch only - if scalarField.Cmp(ecc.BLS12_381.ScalarField()) != 0 { - return errors.New("no efficient endomorphism is available on this curve") - } - var glv glvParams - var init sync.Once - init.Do(func() { - glv.lambda.SetString("8913659658109529928382530854484400854125314752504019737736543920008458395397", 10) - glv.order.SetString("13108968793781547619861935127046491459309155893440570251786403306729687672801", 10) - ecc.PrecomputeLattice(&glv.order, &glv.lambda, &glv.glvBasis) - }) - - // sp[0] is always negative because, in SplitScalar(), we always round above - // the determinant/2 computed in PrecomputeLattice() which is negative for Bandersnatch. - // Thus taking -sp[0] here and negating the point in ScalarMul(). - // If we keep -sp[0] it will be reduced mod r (the BLS12-381 prime order) - // and not the Bandersnatch prime order (Order) and the result will be incorrect. - // Also, if we reduce it mod Order here, we can't use api.ToBinary(sp[0], 129) - // and hence we can't reduce optimally the number of constraints. - sp := ecc.SplitScalar(inputs[0], &glv.glvBasis) - res[0].Neg(&(sp[0])) - res[1].Set(&(sp[1])) - - // figure out how many times we have overflowed - res[2].Mul(res[1], &glv.lambda).Sub(res[2], res[0]) - res[2].Sub(res[2], inputs[0]) - res[2].Div(res[2], &glv.order) - - return nil -} - // rationalReconstruct decomposes a scalar s ∈ Fr into (s1, s2, signBit, k) such // that s1 + s2·s = k·r in the integers, with |s1|, |s2| < γ₂·√r ≈ 1.15·√r // (proven LLL/Hermite bound). Replaces the older heuristic-bound HalfGCD. diff --git a/std/math/emulated/field.go b/std/math/emulated/field.go index 9f308df1e5..e23a29ca1f 100644 --- a/std/math/emulated/field.go +++ b/std/math/emulated/field.go @@ -211,13 +211,24 @@ func (f *Field[T]) modulusPrev() *Element[T] { // less constraints will be generated. // If strict is false, each limbs is constrained to have width as defined by field parameter. func (f *Field[T]) packLimbs(limbs []frontend.Variable, strict bool) *Element[T] { + nbBits := len(limbs) * int(f.fParams.BitsPerLimb()) + if strict { + nbBits = f.fParams.Modulus().BitLen() + } + return f.packLimbsWithWidth(limbs, nbBits) +} + +// packLimbsWithWidth returns an element from the given limbs with range check to nbBits. +// The number of limbs must be exactly ceil(nbBits / BitsPerLimb()), and each limb is +// range-checked accordingly. +func (f *Field[T]) packLimbsWithWidth(limbs []frontend.Variable, nbBits int) *Element[T] { if !f.useSmallFieldOptimization() { e := f.newInternalElement(limbs, 0) - f.enforceWidth(e, strict) + f.enforceWidth(e, nbBits) return e } else { e := f.newInternalElement(limbs, uint(f.smallAdditionalOverflow())) - f.smallEnforceWidth(e, strict) + f.smallEnforceWidth(e, nbBits) return e } } @@ -279,9 +290,9 @@ func (f *Field[T]) enforceWidthConditional(a *Element[T]) (didConstrain bool) { } if didConstrain { if !f.useSmallFieldOptimization() { - f.enforceWidth(a, true) + f.enforceWidth(a, f.fParams.Modulus().BitLen()) } else { - f.smallEnforceWidth(a, true) + f.smallEnforceWidth(a, f.fParams.Modulus().BitLen()) } } return diff --git a/std/math/emulated/field_assert.go b/std/math/emulated/field_assert.go index 2d0bb885a4..450d50281f 100644 --- a/std/math/emulated/field_assert.go +++ b/std/math/emulated/field_assert.go @@ -7,37 +7,49 @@ import ( "github.com/consensys/gnark/profile" ) -// enforceWidth enforces the width of the limbs. When modWidth is true, then the -// limbs are asserted to be the width of the modulus (highest limb may be less -// than full limb width). Otherwise, every limb is assumed to have same width -// (defined by the field parameter). -func (f *Field[T]) enforceWidth(a *Element[T], modWidth bool) { +// enforceWidth enforces the width of the limbs to nbBits. The number of limbs +// must be exactly ceil(nbBits / BitsPerLimb()); it panics otherwise. All limbs +// except the last are range-checked to BitsPerLimb() bits; the last limb is +// range-checked to ((nbBits-1) % BitsPerLimb()) + 1 bits. +func (f *Field[T]) enforceWidth(a *Element[T], nbBits int) { + bitsPerLimb := int(f.fParams.BitsPerLimb()) + expectedNbLimbs := (nbBits + bitsPerLimb - 1) / bitsPerLimb if _, aConst := f.constantValue(a); aConst { - if modWidth && len(a.Limbs) != int(f.fParams.NbLimbs()) { + if len(a.Limbs) != expectedNbLimbs { panic("constant limb width doesn't match parametrized field") } } - if modWidth && len(a.Limbs) != int(f.fParams.NbLimbs()) { - panic("enforcing modulus width element with inexact number of limbs") + if len(a.Limbs) != expectedNbLimbs { + panic("enforcing width element with inexact number of limbs") } for i := range a.Limbs { - limbNbBits := int(f.fParams.BitsPerLimb()) - if modWidth && i == len(a.Limbs)-1 { + limbNbBits := bitsPerLimb + if i == len(a.Limbs)-1 { // take only required bits from the most significant limb - limbNbBits = ((f.fParams.Modulus().BitLen() - 1) % int(f.fParams.BitsPerLimb())) + 1 + limbNbBits = ((nbBits - 1) % bitsPerLimb) + 1 } f.rangeCheck(a.Limbs[i], limbNbBits) } } -func (f *Field[T]) smallEnforceWidth(a *Element[T], modWidth bool) { - if modWidth && len(a.Limbs) != int(f.fParams.NbLimbs()) { - panic("enforcing modulus width element with inexact number of limbs") +// smallEnforceWidth enforces a custom bit width on an element via the small +// field optimization path. Unlike enforceWidth, each limb is range-checked +// to nbBits+overflow. The limb count must equal ceil(nbBits / BitsPerLimb()). +func (f *Field[T]) smallEnforceWidth(a *Element[T], nbBits int) { + bitsPerLimb := int(f.fParams.BitsPerLimb()) + expectedNbLimbs := (nbBits + bitsPerLimb - 1) / bitsPerLimb + if len(a.Limbs) != expectedNbLimbs { + panic("enforcing width element with inexact number of limbs") } for i := range a.Limbs { - f.rangeCheck(a.Limbs[i], f.fParams.Modulus().BitLen()+int(a.overflow)) + limbNbBits := bitsPerLimb + if i == len(a.Limbs)-1 { + // take only required bits from the most significant limb + limbNbBits = ((nbBits - 1) % bitsPerLimb) + 1 + } + f.rangeCheck(a.Limbs[i], limbNbBits+int(a.overflow)) } } diff --git a/std/math/emulated/field_hint.go b/std/math/emulated/field_hint.go index 9570cedce8..be3be4117f 100644 --- a/std/math/emulated/field_hint.go +++ b/std/math/emulated/field_hint.go @@ -8,6 +8,7 @@ import ( "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" limbs "github.com/consensys/gnark/std/internal/limbcomposition" + "github.com/consensys/gnark/std/rangecheck" ) // UnwrapHint unwraps the native inputs into nonnative inputs. Then it calls @@ -306,10 +307,12 @@ func wrapGenericHintInputs[T1, T2 FieldParams]( } // unwrapGenericHintOutputs unwraps the wrapped outputs from the hint function -// into elements of different fields. -func unwrapGenericHintOutputs[T1, T2 FieldParams](field *big.Int, fp1 *Field[T1], fp2 *Field[T2], +// into elements of different fields. If cfg is non-nil, per-output range checks +// from [WithHintOutputRangeCheckBits] are applied: native outputs (indices +// 0..nbNativeOutputs-1), then emulated1 outputs, then emulated2 outputs. +func unwrapGenericHintOutputs[T1, T2 FieldParams](api frontend.API, field *big.Int, fp1 *Field[T1], fp2 *Field[T2], nbNativeOutputs, nbEmulated1Outputs, nbEmulated2Outputs int, - hintOutputs []frontend.Variable, + hintOutputs []frontend.Variable, cfg *hintConfig, ) (nativeOutputs []frontend.Variable, emulated1Outputs []*Element[T1], emulated2Outputs []*Element[T2], err error) { effNbLimbs1, _ := GetEffectiveFieldParams[T1](field) effNbLimbs2, _ := GetEffectiveFieldParams[T2](field) @@ -318,14 +321,32 @@ func unwrapGenericHintOutputs[T1, T2 FieldParams](field *big.Int, fp1 *Field[T1] return nil, nil, nil, fmt.Errorf("hint outputs length mismatch: expected %d, got %d", nbExpectedOutputs, len(hintOutputs)) } nativeOutputs = hintOutputs[:nbNativeOutputs] + // apply range checks on native outputs when requested + rchecker := rangecheck.New(api) + for i := range nbNativeOutputs { + bits, ok := cfg.outputRangeCheckBits[i] + if !ok { + continue + } + if bits > 0 { + if bits > field.BitLen() { + return nil, nil, nil, fmt.Errorf("range check bits for native output %d exceed native field modulus bit size %d", i, field.BitLen()) + } + rchecker.Check(nativeOutputs[i], bits) + } + // bits <= 0: no check + } if nbEmulated1Outputs > 0 { if fp1 == nil { return nil, nil, nil, errors.New("nil emulated1 field") } emulated1Outputs = make([]*Element[T1], nbEmulated1Outputs) for i := range nbEmulated1Outputs { - limbs := hintOutputs[nbNativeOutputs+i*int(effNbLimbs1) : nbNativeOutputs+(i+1)*int(effNbLimbs1)] - emulated1Outputs[i] = fp1.packLimbs(limbs, true) + allLimbs := hintOutputs[nbNativeOutputs+i*int(effNbLimbs1) : nbNativeOutputs+(i+1)*int(effNbLimbs1)] + // apply range checks on emulated outputs when requested. + if err := unwrapOutputRangeCheck(cfg, nbNativeOutputs, i, fp1, allLimbs, emulated1Outputs); err != nil { + return nil, nil, nil, err + } } } if nbEmulated2Outputs > 0 { @@ -334,13 +355,48 @@ func unwrapGenericHintOutputs[T1, T2 FieldParams](field *big.Int, fp1 *Field[T1] } emulated2Outputs = make([]*Element[T2], nbEmulated2Outputs) for i := range nbEmulated2Outputs { - limbs := hintOutputs[nbNativeOutputs+nbEmulated1Outputs*int(effNbLimbs1)+i*int(effNbLimbs2) : nbNativeOutputs+nbEmulated1Outputs*int(effNbLimbs1)+(i+1)*int(effNbLimbs2)] - emulated2Outputs[i] = fp2.packLimbs(limbs, true) + allLimbs := hintOutputs[nbNativeOutputs+nbEmulated1Outputs*int(effNbLimbs1)+i*int(effNbLimbs2) : nbNativeOutputs+nbEmulated1Outputs*int(effNbLimbs1)+(i+1)*int(effNbLimbs2)] + // apply range checks on emulated outputs when requested. + if err := unwrapOutputRangeCheck(cfg, nbNativeOutputs+nbEmulated1Outputs, i, fp2, allLimbs, emulated2Outputs); err != nil { + return nil, nil, nil, err + } } } return nativeOutputs, emulated1Outputs, emulated2Outputs, nil } +func unwrapOutputRangeCheck[T FieldParams](cfg *hintConfig, startIdx, idx int, fp *Field[T], limbs []frontend.Variable, outputs []*Element[T]) error { + bits, ok := cfg.outputRangeCheckBits[startIdx+idx] + if !ok { + // there is no override + outputs[idx] = fp.packLimbs(limbs, true) + return nil + } + if bits > 0 { + // if there is an override with positive bits, then we need to apply + // range check to the limbs and pack with custom width + + // sanity check that the given bound is not larger than the field modulus bit size + if bits > int(fp.fParams.Modulus().BitLen()) { + return fmt.Errorf("range check output %d bits override %d exceed field modulus bit size %d", startIdx+idx, bits, fp.fParams.Modulus().BitLen()) + } + + // lets compute the expected number of limbs we have based on the bits + nbCustomLimbs := (bits + int(fp.fParams.BitsPerLimb()) - 1) / int(fp.fParams.BitsPerLimb()) + // we don't need to assert that the limbs beyond nbCustomLimbs are zero, + // we only return the element with custom number of limbs so the caller + // can never access the potentially non-zero limbs beyond nbCustomLimbs, + // and thus there is no soundness issue. + outputs[idx] = fp.packLimbsWithWidth(limbs[:nbCustomLimbs], bits) + } else { + // bits <= 0: pack with default width but without range check. + // however, return it as non-internal so that when it is used as input to another operation, + // it will be properly range checked. + outputs[idx] = &Element[T]{Limbs: limbs, overflow: 0, internal: false} + } + return nil +} + // Hint is a non-native hint function which takes a [HintContext] as an argument // which allows to access inputs and outputs over different fields. // @@ -391,7 +447,11 @@ type Hint func(HintContext) error // // here we can use hc to access inputs and outputs for the given field modulus // }) // } -func (f *Field[T]) NewHintGeneric(hf solver.Hint, nbNativeOutputs, nbEmulatedOutputs int, nativeInputs []frontend.Variable, nonNativeInputs []*Element[T]) ([]frontend.Variable, []*Element[T], error) { +func (f *Field[T]) NewHintGeneric(hf solver.Hint, nbNativeOutputs, nbEmulatedOutputs int, nativeInputs []frontend.Variable, nonNativeInputs []*Element[T], opts ...HintOption) ([]frontend.Variable, []*Element[T], error) { + cfg, err := applyHintOptions(opts, []int{nbNativeOutputs, nbEmulatedOutputs}) + if err != nil { + return nil, nil, fmt.Errorf("apply hint options: %w", err) + } for i := range nonNativeInputs { nonNativeInputs[i].Initialize(f.api.Compiler().Field()) } @@ -403,7 +463,7 @@ func (f *Field[T]) NewHintGeneric(hf solver.Hint, nbNativeOutputs, nbEmulatedOut if err != nil { return nil, nil, fmt.Errorf("call hint: %w", err) } - nres, em1res, em2res, err := unwrapGenericHintOutputs[T, T](f.api.Compiler().Field(), f, nil, nbNativeOutputs, nbEmulatedOutputs, 0, outputs) + nres, em1res, em2res, err := unwrapGenericHintOutputs[T, T](f.api, f.api.Compiler().Field(), f, nil, nbNativeOutputs, nbEmulatedOutputs, 0, outputs, cfg) if err != nil { return nil, nil, fmt.Errorf("unwrap generic hint context: %w", err) } @@ -594,7 +654,12 @@ func NewVarGenericHint[T1, T2 FieldParams]( emulated1Inputs []*Element[T1], emulated2Inputs []*Element[T2], hf solver.Hint, + opts ...HintOption, ) (nativeOutputs []frontend.Variable, emulated1Outputs []*Element[T1], emulated2Outputs []*Element[T2], err error) { + cfg, err := applyHintOptions(opts, []int{nbNativeOutputs, nbEmulated1Outputs, nbEmulated2Outputs}) + if err != nil { + return nil, nil, nil, fmt.Errorf("apply hint options: %w", err) + } fp1, err := NewField[T1](api) if err != nil { return nil, nil, nil, fmt.Errorf("create emulated1 field: %w", err) @@ -619,7 +684,129 @@ func NewVarGenericHint[T1, T2 FieldParams]( if err != nil { return nil, nil, nil, fmt.Errorf("call hint: %w", err) } - return unwrapGenericHintOutputs(nativeField, fp1, fp2, + return unwrapGenericHintOutputs(api, nativeField, fp1, fp2, nbNativeOutputs, nbEmulated1Outputs, nbEmulated2Outputs, - outputs) + outputs, cfg) +} + +// hintConfig holds the configuration for a generic hint. +type hintConfig struct { + // outputRangeCheckBits holds the number of bits being range checked for + // each output of the hint. The indexing starts from native outputs and then + // goes through emulated outputs in order. If not set, then we range check + // the default number of bits for the field (i.e. the bit size of the + // modulus), except for the native field where we do not perform range + // checks by default (as it is natively enforced by the field). This is used + // to optimize the hint when we know that the outputs will be in a certain + // range, so we can perform range checks on the outputs and avoid some + // constraints in the hint logic. + outputRangeCheckBits map[int]int +} + +// HintOption is a functional option for configuring the generic hint. +// The generic hint functions [Field.NewHintGeneric] and [NewVarGenericHint] accept a +// variable number of these options to customize the behavior of the hint. +type HintOption func(*hintConfig) error + +// applyHintOptions applies the given options to a new genericHintConfig. +func applyHintOptions(opts []HintOption, nbOutputs []int) (*hintConfig, error) { + cfg := &hintConfig{} + for _, opt := range opts { + if err := opt(cfg); err != nil { + return nil, err + } + } + // validate that the output range check bits do not have indices larger than the total number of outputs + if cfg.outputRangeCheckBits != nil { + for idx := range cfg.outputRangeCheckBits { + if idx < 0 { + return nil, fmt.Errorf("output range check bits index cannot be negative, got %d", idx) + } + totalOutputs := 0 + for _, n := range nbOutputs { + totalOutputs += n + } + if idx >= totalOutputs { + return nil, fmt.Errorf("output range check bits index %d is out of range for total outputs %d", idx, totalOutputs) + } + } + } + return cfg, nil +} + +// WithHintOutputRangeCheckBits allows to set the number of bits being range +// checked for each output of the hint. The indexing starts from native outputs +// and then goes through emulated outputs in order. +// +// When not set for any specific output, then we perform the default range checking: +// +// - for native outputs, we do not perform range checks by default (as it is natively +// enforced by the field) +// +// - for emulated outputs, we range check the default number of bits for the field +// (i.e. the bit size of the modulus). But we don't perform strict range checking +// against the modulus, only for the bit length. In case the modulus is not a power +// of 2, then it means that we can have some values which are in the range of the +// bit length but are still larger than the modulus. This is to avoid some +// constraints in the hint logic and it is still safe as long as the hint logic +// does not rely on strict range checks against the modulus. +// +// There is a method [Field.AssertIsInRange] if such strict range checks are needed. +// +// When set, then we perform range checking as follows: +// - when the bits > 0, then we range check the output to be less than 2^bits. +// This is useful when we know that the output will be in a certain range which is +// smaller than the field modulus, so we can perform range checks on the outputs +// and avoid some constraints in the hint logic. +// - when the bits <= 0, then we do not perform range checks on the outputs, not even +// the default ones. This is useful in case where we want to enforce range checking +// logic in the circuit itself. By default it is unsafe to use. +// +// This also affects how many limbs the returned outputs have. When the bits > +// 0, then we ensure that the number of limbs is sufficient to represent values +// up to 2^bits. When the bits <= 0, then we do not perform any checks on the +// number of limbs and it is default, given by the emulation parameters. +// +// All checks are done in-circuit, the solver does not enforce any checks on the +// hint outputs, so it is the responsibility of the user to ensure that the hint +// logic is consistent with the range checks being performed. +// +// For example when we work over native field BN254, emulated fields BLS12381Fp +// and BLS12381Fr, having correpondingly 2, 3, 4 outputs, then the indexing of +// the outputs is as follows: +// - output 0 and 1 are native outputs +// - output 2, 3, 4 are emulated outputs for BLS12381Fp +// - output 5, 6, 7, 8 are emulated outputs for BLS12381Fr +// +// We would call the NewVarGenericHint function as follows: +// +// nativeOutputs, emulatedFpOutputs, emulatedFrOutputs, err := emulated.NewVarGenericHint( +// api, +// 3, +// 4, +// 5, +// nativeInputs, +// emulatedFpInputs, +// emulatedFrInputs, +// MyHintFn, +// emulated.WithHintOutputRangeCheckBits(map[int]int{ +// 0: 4, // range check the first native output to be less than 16 +// 1: 8, // range check the second native output to be less than 256 +// 2: 10, // range check the first emulated output for BLS12381Fp to be less than 1024 +// // no range check for the second and third emulated outputs for BLS12381Fp +// // no range check for the emulated outputs for BLS12381Fr +// }), +// +// ) +func WithHintOutputRangeCheckBits(outputRangeCheckBits map[int]int) func(*hintConfig) error { + return func(cfg *hintConfig) error { + if outputRangeCheckBits == nil { + return errors.New("output range check bits map cannot be nil") + } + if cfg.outputRangeCheckBits != nil { + return errors.New("output range check bits map is already set") + } + cfg.outputRangeCheckBits = outputRangeCheckBits + return nil + } } diff --git a/std/math/emulated/field_hint_test.go b/std/math/emulated/field_hint_test.go index f1ad583ecc..2fda161cbf 100644 --- a/std/math/emulated/field_hint_test.go +++ b/std/math/emulated/field_hint_test.go @@ -10,6 +10,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/test" ) @@ -662,3 +663,245 @@ func TestMatchingFieldHint(t *testing.T) { testMatchingFieldHint[BN254Fr](t) testMatchingFieldHint[BLS12381Fr](t) } + +func hintSquareAllInputs(mod *big.Int, inputs, outputs []*big.Int) error { + return UnwrapHintContext(mod, inputs, outputs, func(ctx HintContext) error { + nativeInputs, nativeOutputs := ctx.NativeInputsOutputs() + if len(nativeInputs) != len(nativeOutputs) { + return fmt.Errorf("expected same number of native inputs and outputs, got %d inputs and %d outputs", len(nativeInputs), len(nativeOutputs)) + } + for i := range nativeOutputs { + nativeOutputs[i].Mul(nativeInputs[i], nativeInputs[i]) + nativeOutputs[i].Mod(nativeOutputs[i], ctx.NativeModulus()) + } + for _, m := range ctx.EmulatedModuli() { + emulatedInputs, emulatedOutputs := ctx.InputsOutputs(m) + if len(emulatedInputs) != len(emulatedOutputs) { + return fmt.Errorf("expected same number of emulated inputs and outputs for modulus %s, got %d inputs and %d outputs", m.String(), len(emulatedInputs), len(emulatedOutputs)) + } + for i := range emulatedOutputs { + emulatedOutputs[i].Mul(emulatedInputs[i], emulatedInputs[i]) + emulatedOutputs[i].Mod(emulatedOutputs[i], m) + } + } + + return nil + }) +} + +type customRangeCheckHintCircuit1[T1 FieldParams] struct { + NativeIn [2]frontend.Variable + Emulated1In [2]Element[T1] + rangeCheckAmount map[int]int +} + +func (c *customRangeCheckHintCircuit1[T1]) Define(api frontend.API) error { + f, err := NewField[T1](api) + if err != nil { + return fmt.Errorf("new field: %w", err) + } + // request 2 native and 2 emulated outputs, with a range check option + outNat, outEm, err := f.NewHintGeneric(hintSquareAllInputs, 2, 2, c.NativeIn[:], []*Element[T1]{&c.Emulated1In[0], &c.Emulated1In[1]}, + WithHintOutputRangeCheckBits(c.rangeCheckAmount)) + if err != nil { + return fmt.Errorf("new hint: %w", err) + } + for i := range outNat { + api.AssertIsDifferent(outNat[i], c.NativeIn[i]) + } + for i := range outEm { + f.AssertIsDifferent(outEm[i], &c.Emulated1In[i]) + } + return nil +} + +func TestCustomRangeCheckHint(t *testing.T) { + assert := test.NewAssert(t) + circuit := customRangeCheckHintCircuit1[Secp256k1Fp]{ + rangeCheckAmount: map[int]int{1: 8, 3: 8}, + } + witness := customRangeCheckHintCircuit1[Secp256k1Fp]{ + NativeIn: [2]frontend.Variable{3, 4}, + Emulated1In: [2]Element[Secp256k1Fp]{ValueOf[Secp256k1Fp](5), ValueOf[Secp256k1Fp](6)}, + } + invalidWitness := customRangeCheckHintCircuit1[Secp256k1Fp]{ + NativeIn: [2]frontend.Variable{1 << 7, 1 << 6}, + Emulated1In: [2]Element[Secp256k1Fp]{ValueOf[Secp256k1Fp](1 << 7), ValueOf[Secp256k1Fp](1 << 6)}, + } + assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithInvalidAssignment(&invalidWitness), + test.WithCurves(ecc.BN254), + test.WithSolverOpts(solver.WithHints(hintSquareAllInputs))) +} + +// varGenericHintCircuit exercises [NewVarGenericHint] which operates over the +// native field and two distinct emulated fields. We reuse hintSquareAllInputs +// which iterates over the native field and all emulated moduli. Note that the +// two emulated fields must use distinct moduli (and distinct from the native +// modulus), as the hint context looks up inputs/outputs by modulus value. +type varGenericHintCircuit[T1, T2 FieldParams] struct { + NativeIn [2]frontend.Variable + Emulated1In [2]Element[T1] + Emulated2In [2]Element[T2] + rangeCheckAmount map[int]int +} + +func (c *varGenericHintCircuit[T1, T2]) Define(api frontend.API) error { + f1, err := NewField[T1](api) + if err != nil { + return fmt.Errorf("new field 1: %w", err) + } + f2, err := NewField[T2](api) + if err != nil { + return fmt.Errorf("new field 2: %w", err) + } + outNat, outEm1, outEm2, err := NewVarGenericHint(api, 2, 2, 2, + c.NativeIn[:], + []*Element[T1]{&c.Emulated1In[0], &c.Emulated1In[1]}, + []*Element[T2]{&c.Emulated2In[0], &c.Emulated2In[1]}, + hintSquareAllInputs, + WithHintOutputRangeCheckBits(c.rangeCheckAmount), + ) + if err != nil { + return fmt.Errorf("new var generic hint: %w", err) + } + for i := range outNat { + api.AssertIsDifferent(outNat[i], c.NativeIn[i]) + } + for i := range outEm1 { + f1.AssertIsDifferent(outEm1[i], &c.Emulated1In[i]) + } + for i := range outEm2 { + f2.AssertIsDifferent(outEm2[i], &c.Emulated2In[i]) + } + return nil +} + +func TestVarGenericHint(t *testing.T) { + assert := test.NewAssert(t) + circuit := varGenericHintCircuit[Secp256k1Fp, BLS12381Fr]{ + rangeCheckAmount: map[int]int{0: 8, 2: 8, 4: 8}, + } + witness := varGenericHintCircuit[Secp256k1Fp, BLS12381Fr]{ + NativeIn: [2]frontend.Variable{3, 4}, + Emulated1In: [2]Element[Secp256k1Fp]{ValueOf[Secp256k1Fp](5), ValueOf[Secp256k1Fp](6)}, + Emulated2In: [2]Element[BLS12381Fr]{ValueOf[BLS12381Fr](7), ValueOf[BLS12381Fr](8)}, + } + invalidWitness := varGenericHintCircuit[Secp256k1Fp, BLS12381Fr]{ + NativeIn: [2]frontend.Variable{1 << 32, 1 << 8}, + Emulated1In: [2]Element[Secp256k1Fp]{ValueOf[Secp256k1Fp](1 << 32), ValueOf[Secp256k1Fp](1 << 8)}, + Emulated2In: [2]Element[BLS12381Fr]{ValueOf[BLS12381Fr](1 << 32), ValueOf[BLS12381Fr](1 << 8)}, + } + assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithInvalidAssignment(&invalidWitness), + test.WithCurves(ecc.BN254), + test.WithSolverOpts(solver.WithHints(hintSquareAllInputs))) +} + +// zeroBitsHintCircuit exercises the bits==0 range check option, which disables +// range checking (same as bits<0). The hint outputs are returned as-is, not +// forced to zero, and no range check constraints are emitted. +type zeroBitsHintCircuit[T1 FieldParams] struct { + NativeIn frontend.Variable + Emulated1In Element[T1] +} + +func (c *zeroBitsHintCircuit[T1]) Define(api frontend.API) error { + f, err := NewField[T1](api) + if err != nil { + return fmt.Errorf("new field: %w", err) + } + // output index 0 is the native output, index 1 is the emulated output; both + // with bits==0, i.e. no range check. + outNat, outEm, err := f.NewHintGeneric(hintSquareAllInputs, 1, 1, []frontend.Variable{c.NativeIn}, []*Element[T1]{&c.Emulated1In}, + WithHintOutputRangeCheckBits(map[int]int{0: 0, 1: 0})) + if err != nil { + return fmt.Errorf("new hint: %w", err) + } + // the actual hint outputs are returned (not forced to zero) and are usable. + nRes := api.Mul(c.NativeIn, c.NativeIn) + api.AssertIsEqual(outNat[0], nRes) + emRes := f.Mul(&c.Emulated1In, &c.Emulated1In) + f.AssertIsEqual(outEm[0], emRes) + return nil +} + +func TestZeroBitsHint(t *testing.T) { + assert := test.NewAssert(t) + circuit := zeroBitsHintCircuit[Secp256k1Fp]{} + witness := zeroBitsHintCircuit[Secp256k1Fp]{ + NativeIn: 5, + Emulated1In: ValueOf[Secp256k1Fp](6), + } + assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), + test.WithCurves(ecc.BN254), + test.WithSolverOpts(solver.WithHints(hintSquareAllInputs))) +} + +// multiLimbRangeCheckHintCircuit exercises a positive bits range check that +// spans multiple limbs. Secp256k1Fp uses 64-bit limbs, so 130 bits requires 3 +// limbs and exercises packLimbsWithWidth across a limb boundary (unlike the +// 8-bit single-limb case in TestCustomRangeCheckHint). +type multiLimbRangeCheckHintCircuit[T1 FieldParams] struct { + Emulated1In [1]Element[T1] +} + +func (c *multiLimbRangeCheckHintCircuit[T1]) Define(api frontend.API) error { + f, err := NewField[T1](api) + if err != nil { + return fmt.Errorf("new field: %w", err) + } + _, outEm, err := f.NewHintGeneric(hintSquareAllInputs, 0, 1, nil, []*Element[T1]{&c.Emulated1In[0]}, + WithHintOutputRangeCheckBits(map[int]int{0: 130})) + if err != nil { + return fmt.Errorf("new hint: %w", err) + } + f.AssertIsDifferent(outEm[0], &c.Emulated1In[0]) + return nil +} + +func TestMultiLimbRangeCheckHint(t *testing.T) { + assert := test.NewAssert(t) + circuit := multiLimbRangeCheckHintCircuit[Secp256k1Fp]{} + // valid: (2^60)^2 = 2^120 < 2^130. + validIn := new(big.Int).Lsh(big.NewInt(1), 60) + // invalid: (2^66)^2 = 2^132 >= 2^130, so the range check fails. + invalidIn := new(big.Int).Lsh(big.NewInt(1), 66) + witness := multiLimbRangeCheckHintCircuit[Secp256k1Fp]{ + Emulated1In: [1]Element[Secp256k1Fp]{ValueOf[Secp256k1Fp](validIn)}, + } + invalidWitness := multiLimbRangeCheckHintCircuit[Secp256k1Fp]{ + Emulated1In: [1]Element[Secp256k1Fp]{ValueOf[Secp256k1Fp](invalidIn)}, + } + assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithInvalidAssignment(&invalidWitness), + test.WithCurves(ecc.BN254), + test.WithSolverOpts(solver.WithHints(hintSquareAllInputs))) +} + +// badIndexHintCircuit requests a single emulated output (output index 0) but +// allows the test to supply an arbitrary range check index map to exercise the +// validation in applyHintOptions. +type badIndexHintCircuit[T1 FieldParams] struct { + In Element[T1] + idx map[int]int +} + +func (c *badIndexHintCircuit[T1]) Define(api frontend.API) error { + f, err := NewField[T1](api) + if err != nil { + return fmt.Errorf("new field: %w", err) + } + _, _, err = f.NewHintGeneric(hintSquareAllInputs, 0, 1, nil, []*Element[T1]{&c.In}, + WithHintOutputRangeCheckBits(c.idx)) + return err +} + +func TestHintRangeCheckOptionValidation(t *testing.T) { + assert := test.NewAssert(t) + // out-of-range index: only one output exists (index 0), index 5 is invalid. + _, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, + &badIndexHintCircuit[Secp256k1Fp]{idx: map[int]int{5: 8}}) + assert.Error(err, "expected error for out-of-range output range check index") + // negative index is invalid. + _, err = frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, + &badIndexHintCircuit[Secp256k1Fp]{idx: map[int]int{-1: 8}}) + assert.Error(err, "expected error for negative output range check index") +}