// schnorrRecover recovers a public key using a signature, hash function, // and message. It also attempts to verify the signature against the // regenerated public key. func schnorrRecover(curve *secp256k1.KoblitzCurve, sig, msg []byte, hashFunc func([]byte) []byte) (*secp256k1.PublicKey, bool, error) { if len(msg) != scalarSize { str := fmt.Sprintf("wrong size for message (got %v, want %v)", len(msg), scalarSize) return nil, false, schnorrError(ErrBadInputSize, str) } if len(sig) != SignatureSize { str := fmt.Sprintf("wrong size for signature (got %v, want %v)", len(sig), SignatureSize) return nil, false, schnorrError(ErrBadInputSize, str) } sigR := sig[:32] sigS := sig[32:] sigRCopy := make([]byte, scalarSize, scalarSize) copy(sigRCopy, sigR) toHash := append(sigRCopy, msg...) h := hashFunc(toHash) hBig := new(big.Int).SetBytes(h) // If the hash ends up larger than the order of the curve, abort. // Same thing for hash == 0 (as unlikely as that is...). if hBig.Cmp(curve.N) >= 0 { str := fmt.Sprintf("hash of (R || m) too big") return nil, false, schnorrError(ErrSchnorrHashValue, str) } if hBig.Cmp(bigZero) == 0 { str := fmt.Sprintf("hash of (R || m) is zero value") return nil, false, schnorrError(ErrSchnorrHashValue, str) } // Convert s to big int. sBig := EncodedBytesToBigInt(copyBytes(sigS)) // We also can't have s greater than the order of the curve. if sBig.Cmp(curve.N) >= 0 { str := fmt.Sprintf("s value is too big") return nil, false, schnorrError(ErrInputValue, str) } // r can't be larger than the curve prime. rBig := EncodedBytesToBigInt(copyBytes(sigR)) if rBig.Cmp(curve.P) == 1 { str := fmt.Sprintf("given R was greater than curve prime") return nil, false, schnorrError(ErrBadSigRNotOnCurve, str) } // Decompress the Y value. We know that the first bit must // be even. Use the PublicKey struct to make it easier. compressedPoint := make([]byte, PubKeyBytesLen, PubKeyBytesLen) compressedPoint[0] = pubkeyCompressed copy(compressedPoint[1:], sigR) rPoint, err := secp256k1.ParsePubKey(compressedPoint, curve) if err != nil { str := fmt.Sprintf("bad r point") return nil, false, schnorrError(ErrRegenerateRPoint, str) } // Get the inverse of the hash. hInv := new(big.Int).ModInverse(hBig, curve.N) hInv.Mod(hInv, curve.N) // Negate s. sBig.Sub(curve.N, sBig) sBig.Mod(sBig, curve.N) // s' = -s * inverse(h). sBig.Mul(sBig, hInv) sBig.Mod(sBig, curve.N) // Q = h^(-1)R + s'G lx, ly := curve.ScalarMult(rPoint.GetX(), rPoint.GetY(), hInv.Bytes()) rx, ry := curve.ScalarBaseMult(sBig.Bytes()) pkx, pky := curve.Add(lx, ly, rx, ry) // Check if the public key is on the curve. if !curve.IsOnCurve(pkx, pky) { str := fmt.Sprintf("pubkey not on curve") return nil, false, schnorrError(ErrPubKeyOffCurve, str) } pubkey := secp256k1.NewPublicKey(curve, pkx, pky) // Verify this signature. Slow, lots of double checks, could be more // cheaply implemented as // hQ + sG - R == 0 // which this function checks. // This will sometimes pass even for corrupted signatures, but // this shouldn't be a concern because whoever is using the // results should be checking the returned public key against // some known one anyway. In the case of these Schnorr signatures, // relatively high numbers of corrupted signatures (50-70%) // seem to produce valid pubkeys and valid signatures. _, err = schnorrVerify(curve, sig, pubkey, msg, hashFunc) if err != nil { str := fmt.Sprintf("pubkey/sig pair could not be validated") return nil, false, schnorrError(ErrRegenSig, str) } return pubkey, true, nil }
// schnorrVerify is the internal function for verification of a secp256k1 // Schnorr signature. A secure hash function may be passed for the calculation // of r. // This is identical to the Schnorr verification function found in libsecp256k1: // https://github.com/bitcoin/secp256k1/tree/master/src/modules/schnorr func schnorrVerify(curve *secp256k1.KoblitzCurve, sig []byte, pubkey *secp256k1.PublicKey, msg []byte, hashFunc func([]byte) []byte) (bool, error) { if len(msg) != scalarSize { str := fmt.Sprintf("wrong size for message (got %v, want %v)", len(msg), scalarSize) return false, schnorrError(ErrBadInputSize, str) } if len(sig) != SignatureSize { str := fmt.Sprintf("wrong size for signature (got %v, want %v)", len(sig), SignatureSize) return false, schnorrError(ErrBadInputSize, str) } if pubkey == nil { str := fmt.Sprintf("nil pubkey") return false, schnorrError(ErrInputValue, str) } if !curve.IsOnCurve(pubkey.GetX(), pubkey.GetY()) { str := fmt.Sprintf("pubkey point is not on curve") return false, schnorrError(ErrPointNotOnCurve, str) } sigR := sig[:32] sigS := sig[32:] sigRCopy := make([]byte, scalarSize, scalarSize) copy(sigRCopy, sigR) toHash := append(sigRCopy, msg...) h := hashFunc(toHash) hBig := new(big.Int).SetBytes(h) // If the hash ends up larger than the order of the curve, abort. // Same thing for hash == 0 (as unlikely as that is...). if hBig.Cmp(curve.N) >= 0 { str := fmt.Sprintf("hash of (R || m) too big") return false, schnorrError(ErrSchnorrHashValue, str) } if hBig.Cmp(bigZero) == 0 { str := fmt.Sprintf("hash of (R || m) is zero value") return false, schnorrError(ErrSchnorrHashValue, str) } // Convert s to big int. sBig := EncodedBytesToBigInt(copyBytes(sigS)) // We also can't have s greater than the order of the curve. if sBig.Cmp(curve.N) >= 0 { str := fmt.Sprintf("s value is too big") return false, schnorrError(ErrInputValue, str) } // r can't be larger than the curve prime. rBig := EncodedBytesToBigInt(copyBytes(sigR)) if rBig.Cmp(curve.P) == 1 { str := fmt.Sprintf("given R was greater than curve prime") return false, schnorrError(ErrBadSigRNotOnCurve, str) } // r' = hQ + sG lx, ly := curve.ScalarMult(pubkey.GetX(), pubkey.GetY(), h) rx, ry := curve.ScalarBaseMult(sigS) rlx, rly := curve.Add(lx, ly, rx, ry) if rly.Bit(0) == 1 { str := fmt.Sprintf("calculated R y-value was odd") return false, schnorrError(ErrBadSigRYValue, str) } if !curve.IsOnCurve(rlx, rly) { str := fmt.Sprintf("calculated R point was not on curve") return false, schnorrError(ErrBadSigRNotOnCurve, str) } rlxB := BigIntToEncodedBytes(rlx) // r == r' --> valid signature if !bytes.Equal(sigR, rlxB[:]) { str := fmt.Sprintf("calculated R point was not given R") return false, schnorrError(ErrUnequalRValues, str) } return true, nil }