crypto.ff: allow seamless chaining regardless of representation (#30913)

Finite field elements can be in regular or Montgomery form, and
chaining different operations use to require manual and error-prone
conversions.

Now:

- `add`, `sub` and `mul` convert the second operand to match the
first operand's form
- `sq` and `pow` preserve the input's Montgomery form
- `toPrimitive` and `toBytes` return `UnexpectedRepresentation` if
the element is in Montgomery form, preventing incorrect serialization

This is fully backwards compatible and allows seamless chaining of
operations regardless of their representation.
This commit is contained in:
Frank Denis 2026-01-25 17:42:01 +01:00
parent 99ec1ee353
commit 8709f53d44

View file

@ -329,7 +329,11 @@ fn Fe_(comptime bits: comptime_int) type {
/// Converts the field element to a primitive.
/// This function may not run in constant time.
pub fn toPrimitive(self: Self, comptime T: type) OverflowError!T {
/// Returns an error if the element is in Montgomery form.
pub fn toPrimitive(self: Self, comptime T: type) (OverflowError || RepresentationError)!T {
if (self.montgomery) {
return error.UnexpectedRepresentation;
}
return self.v.toPrimitive(T);
}
@ -343,7 +347,11 @@ fn Fe_(comptime bits: comptime_int) type {
}
/// Converts the field element to a byte string.
pub fn toBytes(self: Self, bytes: []u8, comptime endian: Endian) OverflowError!void {
/// Returns an error if the element is in Montgomery form.
pub fn toBytes(self: Self, bytes: []u8, comptime endian: Endian) (OverflowError || RepresentationError)!void {
if (self.montgomery) {
return error.UnexpectedRepresentation;
}
return self.v.toBytes(bytes, endian);
}
@ -530,19 +538,46 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
/// Adds two field elements (mod m).
pub fn add(self: Self, x: Fe, y: Fe) Fe {
var out = x;
const overflow = out.v.addWithOverflow(y.v);
const underflow: u1 = @bitCast(ct.limbsCmpLt(out.v, self.v));
const need_sub = ct.eql(overflow, underflow);
_ = out.v.conditionalSubWithOverflow(need_sub, self.v);
return out;
if (x.montgomery == y.montgomery) {
@branchHint(.likely);
const overflow = out.v.addWithOverflow(y.v);
const underflow: u1 = @bitCast(ct.limbsCmpLt(out.v, self.v));
const need_sub = ct.eql(overflow, underflow);
_ = out.v.conditionalSubWithOverflow(need_sub, self.v);
return out;
} else {
var y_ = y;
if (y.montgomery) {
self.fromMontgomery(&y_) catch unreachable;
} else {
self.toMontgomery(&y_) catch unreachable;
}
const overflow = out.v.addWithOverflow(y_.v);
const underflow: u1 = @bitCast(ct.limbsCmpLt(out.v, self.v));
const need_sub = ct.eql(overflow, underflow);
_ = out.v.conditionalSubWithOverflow(need_sub, self.v);
return out;
}
}
/// Subtracts two field elements (mod m).
pub fn sub(self: Self, x: Fe, y: Fe) Fe {
var out = x;
const underflow: bool = @bitCast(out.v.subWithOverflow(y.v));
_ = out.v.conditionalAddWithOverflow(underflow, self.v);
return out;
if (x.montgomery == y.montgomery) {
const underflow: bool = @bitCast(out.v.subWithOverflow(y.v));
_ = out.v.conditionalAddWithOverflow(underflow, self.v);
return out;
} else {
var y_ = y;
if (y.montgomery) {
self.fromMontgomery(&y_) catch unreachable;
} else {
self.toMontgomery(&y_) catch unreachable;
}
const underflow: bool = @bitCast(out.v.subWithOverflow(y_.v));
_ = out.v.conditionalAddWithOverflow(underflow, self.v);
return out;
}
}
/// Converts a field element to the Montgomery form.
@ -663,13 +698,15 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
for (e) |b| acc |= b;
if (acc == 0) return error.NullExponent;
const was_montgomery = x.montgomery;
var out = self.one();
self.toMontgomery(&out) catch unreachable;
if (public and e.len < 3 or (e.len == 3 and e[if (endian == .big) 0 else 2] <= 0b1111)) {
// Do not use a precomputation table for short, public exponents
var x_m = x;
if (x.montgomery == false) {
if (!x.montgomery) {
self.toMontgomery(&x_m) catch unreachable;
}
var s = switch (endian) {
@ -702,7 +739,7 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
} else {
// Use a precomputation table for large exponents
var pc = [1]Fe{x} ++ [_]Fe{self.zero} ** 14;
if (x.montgomery == false) {
if (!x.montgomery) {
self.toMontgomery(&pc[0]) catch unreachable;
}
for (1..pc.len) |i| {
@ -747,38 +784,55 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
}
}
}
self.fromMontgomery(&out) catch unreachable;
if (!was_montgomery) {
self.fromMontgomery(&out) catch unreachable;
}
return out;
}
/// Multiplies two field elements.
/// Result preserves the first operand's form.
pub fn mul(self: Self, x: Fe, y: Fe) Fe {
if (x.montgomery != y.montgomery) {
return self.montgomeryMul(x, y);
}
var a_ = x;
if (x.montgomery == false) {
self.toMontgomery(&a_) catch unreachable;
if (x.montgomery) {
const y_ = if (!y.montgomery) blk: {
var yy = y;
self.toMontgomery(&yy) catch unreachable;
break :blk yy;
} else y;
return self.montgomeryMul(x, y_);
} else {
self.fromMontgomery(&a_) catch unreachable;
var x_m = x;
var y_m = if (y.montgomery) blk: {
var yy = y;
self.fromMontgomery(&yy) catch unreachable;
break :blk yy;
} else y;
self.toMontgomery(&x_m) catch unreachable;
self.toMontgomery(&y_m) catch unreachable;
var out = self.montgomeryMul(x_m, y_m);
self.fromMontgomery(&out) catch unreachable;
return out;
}
return self.montgomeryMul(a_, y);
}
/// Squares a field element.
pub fn sq(self: Self, x: Fe) Fe {
var out = x;
if (x.montgomery == true) {
if (x.montgomery) {
return self.montgomerySq(x);
} else {
var out = x;
self.toMontgomery(&out) catch unreachable;
out = self.montgomerySq(out);
self.fromMontgomery(&out) catch unreachable;
return out;
}
out = self.montgomerySq(out);
out.montgomery = false;
self.toMontgomery(&out) catch unreachable;
return out;
}
/// Returns x^e (mod m) in constant time.
pub fn pow(self: Self, x: Fe, e: Fe) NullExponentError!Fe {
pub fn pow(self: Self, x: Fe, e: Fe) (NullExponentError || RepresentationError)!Fe {
if (e.montgomery) {
return error.UnexpectedRepresentation;
}
var buf: [Fe.encoded_bytes]u8 = undefined;
e.toBytes(&buf, native_endian) catch unreachable;
return self.powWithEncodedExponent(x, &buf, native_endian);
@ -786,7 +840,10 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
/// Returns x^e (mod m), assuming that the exponent is public.
/// The function remains constant time with respect to `x`.
pub fn powPublic(self: Self, x: Fe, e: Fe) NullExponentError!Fe {
pub fn powPublic(self: Self, x: Fe, e: Fe) (NullExponentError || RepresentationError)!Fe {
if (e.montgomery) {
return error.UnexpectedRepresentation;
}
var e_normalized = Fe{ .v = e.v.normalize() };
var buf_: [Fe.encoded_bytes]u8 = undefined;
var buf = buf_[0 .. math.divCeil(usize, e_normalized.v.limbs_len * t_bits, 8) catch unreachable];
@ -927,6 +984,8 @@ test "finite field arithmetic" {
try m.toMontgomery(&x);
x_y = m.mul(x, y);
try testing.expect(x_y.montgomery); // result preserves first operand's form
try m.fromMontgomery(&x_y);
try testing.expectEqual(x_y.toPrimitive(u256), 1666576607955767413750776202132407807424848069716933450241);
try m.fromMontgomery(&x);
@ -941,8 +1000,11 @@ test "finite field arithmetic" {
const x_pow_y = try m.powPublic(x, y);
try testing.expectEqual(x_pow_y.toPrimitive(u256), 1631933139300737762906024873185789093007782131928298618473);
try testing.expect(!x_pow_y.montgomery);
try m.toMontgomery(&x);
const x_pow_y2 = try m.powPublic(x, y);
var x_pow_y2 = try m.powPublic(x, y);
try testing.expect(x_pow_y2.montgomery);
try m.fromMontgomery(&x_pow_y2);
try m.fromMontgomery(&x);
try testing.expect(x_pow_y2.eql(x_pow_y));
try testing.expectError(error.NullExponent, m.powPublic(x, m.zero));
@ -953,13 +1015,53 @@ test "finite field arithmetic" {
const x_sq = m.sq(x);
const x_sq2 = m.mul(x, x);
try testing.expect(!x_sq.montgomery);
try testing.expect(!x_sq2.montgomery);
try testing.expect(x_sq.eql(x_sq2));
try m.toMontgomery(&x);
const x_sq3 = m.sq(x);
const x_sq4 = m.mul(x, x);
var x_sq3 = m.sq(x);
var x_sq4 = m.mul(x, x);
try testing.expect(x_sq3.montgomery);
try testing.expect(x_sq4.montgomery);
try m.fromMontgomery(&x_sq3);
try m.fromMontgomery(&x_sq4);
try testing.expect(x_sq.eql(x_sq3));
try testing.expect(x_sq3.eql(x_sq4));
try m.fromMontgomery(&x);
var x_mont = x;
try m.toMontgomery(&x_mont);
// Non-montgomery + montgomery
const add_nm_m = m.add(x, x_mont);
try testing.expect(!add_nm_m.montgomery);
var add_m_nm = m.add(x_mont, x);
try testing.expect(add_m_nm.montgomery);
try m.fromMontgomery(&add_m_nm);
try testing.expect(add_nm_m.eql(add_m_nm));
// Non-montgomery - montgomery
const sub_nm_m = m.sub(x, y);
try testing.expect(!sub_nm_m.montgomery);
var y_mont = y;
try m.toMontgomery(&y_mont);
var sub_m_nm = m.sub(x_mont, y);
try testing.expect(sub_m_nm.montgomery);
try m.fromMontgomery(&sub_m_nm);
try testing.expect(sub_nm_m.eql(sub_m_nm));
// mul: preserves first operand's form
const mul_nm_m = m.mul(x, x_mont);
try testing.expect(!mul_nm_m.montgomery);
const mul_nm_nm = m.mul(x, x);
try testing.expect(mul_nm_m.eql(mul_nm_nm));
var mul_m_nm = m.mul(x_mont, x);
try testing.expect(mul_m_nm.montgomery);
try m.fromMontgomery(&mul_m_nm);
try testing.expect(mul_m_nm.eql(mul_nm_nm));
try testing.expectEqual(x.toPrimitive(u256), 80169837251094269539116136208111827396136208141182357733);
try testing.expectError(error.UnexpectedRepresentation, x_mont.toPrimitive(u256));
}
fn testCt(ct_: anytype) !void {