Montgomery Multiplication - Algorithmica
Montgomery Multiplication

Unsurprisingly, a large fraction of computation in modular arithmetic is often spent on calculating the modulo operation, which is as slow as general integer division and typically takes 15-20 cycles, depending on the operand size.

The best way to deal with this nuisance is to avoid the modulo operation altogether, delaying it or replacing it with predication, which can be done, for example, when calculating modular sums:

const int M = 1e9 + 7;

// input: array of n integers in the [0, M) range
// output: sum modulo M
int slow_sum(int *a, int n) {
    int s = 0;
    for (int i = 0; i < n; i++)
        s = (s + a[i]) % M;
    return s;
}

int fast_sum(int *a, int n) {
    int s = 0;
    for (int i = 0; i < n; i++) {
        s += a[i]; // s < 2 * M
        s = (s >= M ? s - M : s); // will be replaced with cmov
    }
    return s;
}

int faster_sum(int *a, int n) {
    long long s = 0; // 64-bit integer to handle overflow
    for (int i = 0; i < n; i++)
        s += a[i]; // will be vectorized
    return s % M;
}

However, sometimes you only have a chain of modular multiplications, and there is no good way to avoid computing the remainder of the division, other than with the integer division tricks that require a constant modulo and some precomputation.

But there is another technique designed specifically for modular arithmetic, called Montgomery multiplication.

#Montgomery Space

Montgomery multiplication works by first transforming the multipliers into Montgomery space, where modular multiplication can be performed cheaply, and then transforming them back when their actual values are needed. Unlike general integer division methods, Montgomery multiplication is not efficient for performing just one modular reduction and only becomes worthwhile when there is a chain of modular operations.

The space is defined by the modulo $n$ and a positive integer $r \ge n$ coprime to $n$ (which requires $n$ to be odd when $r$ is a power of two). The algorithm involves modulo and division by $r$, so in practice $r$ is chosen to be $2^{32}$ or $2^{64}$, so that these operations can be done with a bitwise AND and a right-shift respectively.

Definition. The representative $\bar x$ of a number $x$ in the Montgomery space is defined as

$$\bar{x} = x \cdot r \bmod n$$

Computing this transformation involves a multiplication and a modulo — an expensive operation that we wanted to optimize away in the first place — which is why we only use this method when the overhead of transforming numbers to and from the Montgomery space is worth it and not for general modular multiplication.

Inside the Montgomery space, addition, subtraction, and checking for equality are performed as usual:

$$x \cdot r + y \cdot r \equiv (x + y) \cdot r \bmod n$$

However, this is not the case for multiplication. Denoting multiplication in the Montgomery space as $*$ and the “normal” multiplication as $\cdot$, we expect the result to be:

$$\bar{x} * \bar{y} = \overline{x \cdot y} = (x \cdot y) \cdot r \bmod n$$

But the normal multiplication in the Montgomery space yields:

$$\bar{x} \cdot \bar{y} = (x \cdot y) \cdot r \cdot r \bmod n$$

Therefore, the multiplication in the Montgomery space is defined as

$$\bar{x} * \bar{y} = \bar{x} \cdot \bar{y} \cdot r^{-1} \bmod n$$
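To make these identities concrete, here is a deliberately slow reference model of the space that computes every operation with an ordinary %. It assumes $r = 2^{32}$ and the prime modulo $n = 10^9 + 7$, so that $r^{-1} \bmod n$ can be obtained with Fermat's little theorem; the helper names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

typedef uint64_t u64;

const u64 n = 1000000007;              // prime modulo (assumed for this sketch)
const u64 r_mod_n = (1ull << 32) % n;  // r = 2^32, reduced modulo n

// a^e mod n with plain binary exponentiation; all factors are below n < 2^30,
// so every product fits in 64 bits
u64 power(u64 a, u64 e) {
    u64 res = 1;
    a %= n;
    while (e) {
        if (e & 1)
            res = res * a % n;
        a = a * a % n;
        e >>= 1;
    }
    return res;
}

const u64 r_inv = power(r_mod_n, n - 2);  // r^{-1} mod n by Fermat's little theorem

u64 to_space(u64 x)         { return x % n * r_mod_n % n; }         // xbar = x * r mod n
u64 from_space(u64 xbar)    { return xbar * r_inv % n; }            // x = xbar * r^{-1} mod n
u64 add(u64 xbar, u64 ybar) { return (xbar + ybar) % n; }           // same as for normal numbers
u64 mul(u64 xbar, u64 ybar) { return xbar * ybar % n * r_inv % n; } // xbar * ybar * r^{-1} mod n
```

Multiplying two representatives with a plain product and a % would instead leave an extra factor of $r$ in the result, which is exactly what the $r^{-1}$ in mul removes.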

This means that, after we normally multiply two numbers in the Montgomery space, we need to reduce the result by multiplying it by $r^{-1}$ and taking the modulo, and there is an efficient way to do this particular operation.

#Montgomery Reduction

Assume that $r = 2^{32}$, the modulo $n$ is 32-bit, and the number $x$ we need to reduce is 64-bit (the product of two 32-bit numbers). Our goal is to calculate $y = x \cdot r^{-1} \bmod n$.

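Since $r$ is coprime with $n$, we know that there are two integers $r^{-1}$ and $n^\prime$ such that

$$r \cdot r^{-1} + n \cdot n^\prime = 1$$

and both $r^{-1}$ and $n^\prime$ can be computed, e.g., using the extended Euclidean algorithm. Taking this identity modulo $n$ shows that $r^{-1}$ is the inverse of $r$ modulo $n$, and taking it modulo $r$ shows that $n^\prime$ is the inverse of $n$ modulo $r$, so $n^\prime$ can be chosen in the $[0, r)$ range (which makes $r^{-1}$ negative, but only its residue modulo $n$ matters).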
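A sketch of that computation for $r = 2^{32}$ and $n = 10^9 + 7$, using the textbook iterative extended Euclidean algorithm (the function name ext_gcd is hypothetical). Since $r$ and $n$ are both positive, the two Bézout coefficients it returns have opposite signs; reducing them into $[0, n)$ and $[0, r)$ gives the two modular inverses.

```cpp
#include <cassert>
#include <cstdint>

typedef int64_t i64;

// Iterative extended Euclidean algorithm: returns g = gcd(a, b)
// and sets s, t so that a * s + b * t = g
i64 ext_gcd(i64 a, i64 b, i64 &s, i64 &t) {
    i64 old_r = a, cur_r = b;
    i64 old_s = 1, cur_s = 0;
    i64 old_t = 0, cur_t = 1;
    while (cur_r != 0) {
        i64 q = old_r / cur_r, tmp;
        tmp = old_r - q * cur_r; old_r = cur_r; cur_r = tmp;
        tmp = old_s - q * cur_s; old_s = cur_s; cur_s = tmp;
        tmp = old_t - q * cur_t; old_t = cur_t; cur_t = tmp;
    }
    s = old_s;
    t = old_t;
    return old_r;
}

const i64 r = 1ll << 32;   // r = 2^32
const i64 n = 1000000007;  // the modulo used throughout the article
```

For example, calling ext_gcd(r, n, r_inv, n_prime) returns 1 and fills in coefficients satisfying $r \cdot r^{-1} + n \cdot n^\prime = 1$ exactly; all intermediate values stay well within 64 bits for these inputs.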

Using this identity, we can express $r \cdot r^{-1}$ as $(1 - n \cdot n^\prime)$ and write $x \cdot r^{-1}$ as

$$
\begin{aligned}
x \cdot r^{-1} &= x \cdot r \cdot r^{-1} / r \\
&= x \cdot (1 - n \cdot n^{\prime}) / r \\
&= (x - x \cdot n \cdot n^{\prime}) / r \\
&\equiv (x - x \cdot n \cdot n^{\prime} + k \cdot r \cdot n) / r && \pmod n \quad \text{(for any integer $k$)} \\
&\equiv (x - (x \cdot n^{\prime} - k \cdot r) \cdot n) / r && \pmod n
\end{aligned}
$$

Now, if we choose $k$ to be $\lfloor x \cdot n^\prime / r \rfloor$ (the upper 64 bits of the 96-bit product $x \cdot n^\prime$), then $k \cdot r$ cancels out these upper bits, and $(x \cdot n^{\prime} - k \cdot r)$ is simply equal to $x \cdot n^{\prime} \bmod r$ (the lower 32 bits of $x \cdot n^\prime$), implying:

$$x \cdot r^{-1} \equiv (x - (x \cdot n^{\prime} \bmod r) \cdot n) / r \pmod n$$

The algorithm itself just evaluates this formula, performing two multiplications to calculate $q = x \cdot n^{\prime} \bmod r$ and $m = q \cdot n$, and then subtracting $m$ from $x$ and right-shifting the result to divide it by $r$.
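A literal, unoptimized transcription of the formula may help before looking at the tuned version. This sketch assumes $r = 2^{32}$ and $n = 10^9 + 7$, computes $n^\prime = n^{-1} \bmod r$ with the Newton-style iteration described later in the article, and uses the signed __int128 type (a GCC/Clang extension) so that $x - m$ is allowed to go negative. The names n_prime and reduce_formula are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

typedef uint32_t u32;
typedef uint64_t u64;
typedef __int128 i128;  // GCC/Clang extension, used so that x - m can be negative

const u64 r = 1ull << 32;
const u32 n = 1000000007;

// n' = n^{-1} mod r for odd n, via the Newton-style iteration from the
// "Faster Inverse and Transform" section (u32 arithmetic wraps modulo 2^32)
constexpr u32 n_prime() {
    u32 x = 1;
    for (int i = 0; i < 5; i++)
        x *= 2 - n * x;
    return x;
}

// Literal evaluation of (x - (x * n' mod r) * n) / r: the result is congruent
// to x * r^{-1} modulo n but may lie anywhere in (-n, n)
i128 reduce_formula(u64 x) {
    u64 q = (x % r) * n_prime() % r;  // q = x * n' mod r
    i128 m = (i128) q * n;            // m = q * n
    i128 diff = (i128) x - m;
    assert(diff % r == 0);            // x and m agree in their lower 32 bits
    return diff / r;                  // exact division by r
}
```

The assert documents the key property of the choice of $q$: the subtraction always clears the lower 32 bits, so the division by $r$ is exact.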

The only remaining thing to handle is that the result may not be in the $[0, n)$ range; but since

$$x < n \cdot n < r \cdot n \implies x / r < n$$

and

$$m = q \cdot n < r \cdot n \implies m / r < n$$

it is guaranteed that

$$-n < (x - m) / r < n$$

Therefore, we can simply check if the result is negative and, in that case, add $n$ to it, giving the following algorithm:

typedef __uint32_t u32;
typedef __uint64_t u64;

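// inverse(a, m) is not shown; it is assumed to return a^{-1} mod m (e.g., via the extended Euclidean algorithm)
const u32 n = 1e9 + 7, nr = inverse(n, 1ull << 32);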

u32 reduce(u64 x) {
    u32 q = u32(x) * nr;      // q = x * n' mod r
    u64 m = (u64) q * n;      // m = q * n
    u32 y = (x - m) >> 32;    // y = (x - m) / r
    return x < m ? y + n : y; // if y < 0, add n to make it be in the [0, n) range
}
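The range argument can also be checked exhaustively on a small case. This sketch keeps $r = 2^{32}$ and the exact reduce from above, but assumes the hypothetical tiny modulo $n = 97$ so that all $97^2$ products can be enumerated; $n^\prime$ is computed with the Newton-style iteration described later in the article, and $r^{-1} \bmod n$ by brute force.

```cpp
#include <cassert>
#include <cstdint>

typedef uint32_t u32;
typedef uint64_t u64;

const u32 n = 97;  // small odd modulo (an assumption) so that every product can be tested

// n' = n^{-1} mod 2^32 (Newton-style iteration, covered later in the article)
constexpr u32 inverse_mod_r(u32 a) {
    u32 x = 1;
    for (int i = 0; i < 5; i++)
        x *= 2 - a * x;
    return x;
}

const u32 nr = inverse_mod_r(n);

// Exact Montgomery reduction from the text: returns x * r^{-1} mod n in [0, n)
u32 reduce(u64 x) {
    u32 q = u32(x) * nr;      // q = x * n' mod r
    u64 m = (u64) q * n;      // m = q * n
    u32 y = (x - m) >> 32;    // y = (x - m) / r, wrapped around if negative
    return x < m ? y + n : y; // the wrap-around cancels out when n is added
}

// Compares reduce against x * r^{-1} mod n for every product of two numbers in [0, n)
bool check_all() {
    u64 r_mod_n = (1ull << 32) % n;
    u64 r_inv = 0;
    while (r_mod_n * r_inv % n != 1)  // brute-force r^{-1} mod n (n is tiny)
        r_inv++;
    for (u64 a = 0; a < n; a++)
        for (u64 b = 0; b < n; b++) {
            u32 y = reduce(a * b);
            if (y >= n || y != a * b % n * r_inv % n)
                return false;
        }
    return true;
}
```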

This last check is relatively cheap, but it is still on the critical path. If we are fine with the result being in the $[0, 2 \cdot n - 2]$ range instead of $[0, n)$, we can remove it and add $n$ to the result unconditionally:

u32 reduce(u64 x) {
    u32 q = u32(x) * nr;
    u64 m = (u64) q * n;
    u32 y = (x - m) >> 32;
    return y + n; // in the [0, 2 * n - 2] range
}
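Because the lazy result is only bounded by $2n$, it is worth checking that such values can be fed straight back into further multiplications. With inputs below $2n$, products stay below $4n^2$, which is less than $r \cdot n$ as long as $4n < r$ (true for $n = 10^9 + 7 < 2^{30}$, assumed here), so the range argument above still applies. The sketch below, with a hypothetical power function, keeps every intermediate value lazily reduced and corrects only the final result.

```cpp
#include <cassert>
#include <cstdint>

typedef uint32_t u32;
typedef uint64_t u64;

const u32 n = 1000000007;  // n < 2^30, hence 4 * n < r = 2^32

// n' = n^{-1} mod 2^32 (Newton-style iteration, covered later in the article)
constexpr u32 inverse_mod_r(u32 a) {
    u32 x = 1;
    for (int i = 0; i < 5; i++)
        x *= 2 - a * x;
    return x;
}

const u32 nr = inverse_mod_r(n);

// Lazy reduction from the text: congruent to x * r^{-1} mod n, in [0, 2n - 2]
u32 reduce_lazy(u64 x) {
    u32 q = u32(x) * nr;
    u64 m = (u64) q * n;
    u32 y = (x - m) >> 32;
    return y + n;
}

// a^e mod n with every intermediate value kept lazily reduced (below 2n);
// only the final result is corrected into [0, n)
u32 power(u32 a, u32 e) {
    u32 abar = (u64(a) << 32) % n;  // a * r mod n
    u32 res = (u64(1) << 32) % n;   // 1 * r mod n
    while (e) {
        if (e & 1)
            res = reduce_lazy((u64) res * abar);
        abar = reduce_lazy((u64) abar * abar);
        e >>= 1;
    }
    u32 y = reduce_lazy(res);       // leave the space; here y is in [1, n]
    return y >= n ? y - n : y;
}
```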

We can also move the >> 32 operation one step earlier in the computation graph and compute $\lfloor x / r \rfloor - \lfloor m / r \rfloor$ instead of $(x - m) / r$. This is correct because the lower 32 bits of $x$ and $m$ are equal anyway since

$$m = q \cdot n \equiv x \cdot n^\prime \cdot n \equiv x \pmod r$$

But why would we voluntarily choose to perform two right-shifts instead of just one? This is beneficial because for ((u64) q * n) >> 32 we need to do a 32-by-32 multiplication and take the upper 32 bits of the result (which the x86 mul instruction already writes in a separate register, so it doesn't cost anything), and the other right-shift x >> 32 is not on the critical path.

u32 reduce(u64 x) {
    u32 q = u32(x) * nr;
    u32 m = ((u64) q * n) >> 32;
    return (x >> 32) + n - m;
}
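Since both versions compute $(x - m) / r + n$ modulo $2^{32}$, they should agree bit for bit, not just modulo $n$. A randomized sketch of that claim, assuming $n = 10^9 + 7$ and products of two numbers below $n$; the function names and the fixed seed are arbitrary choices for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <random>

typedef uint32_t u32;
typedef uint64_t u64;

const u32 n = 1000000007;

// n' = n^{-1} mod 2^32 (Newton-style iteration, covered later in the article)
constexpr u32 inverse_mod_r(u32 a) {
    u32 x = 1;
    for (int i = 0; i < 5; i++)
        x *= 2 - a * x;
    return x;
}

const u32 nr = inverse_mod_r(n);

// (x - m) / r computed with a single shift of the 64-bit difference
u32 reduce_lazy(u64 x) {
    u32 q = u32(x) * nr;
    u64 m = (u64) q * n;
    u32 y = (x - m) >> 32;
    return y + n;
}

// floor(x / r) - floor(m / r): only the high half of q * n is needed
u32 reduce_shifted(u64 x) {
    u32 q = u32(x) * nr;
    u32 m = ((u64) q * n) >> 32;
    return (x >> 32) + n - m;
}

// Returns true if both versions agree on `trials` random products a * b with a, b < n
bool variants_agree(int trials) {
    std::mt19937_64 rng(2024);
    std::uniform_int_distribution<u64> dist(0, n - 1);
    for (int i = 0; i < trials; i++) {
        u64 x = dist(rng) * dist(rng);
        if (reduce_lazy(x) != reduce_shifted(x))
            return false;
    }
    return true;
}
```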

One of the main advantages of Montgomery multiplication over other modular reduction methods is that it doesn't require very large data types: it only needs a $\log_2 r$-bit by $\log_2 r$-bit multiplication that returns both the lower and the higher $\log_2 r$ bits of the result. This operation has special support on most hardware, which also makes the method easily generalizable to SIMD and larger data types:

typedef __uint128_t u128;

// member function: assumes the structure stores a 64-bit modulo n and nr = n^{-1} mod 2^64
u64 reduce(u128 x) const {
    u64 q = u64(x) * nr;
    u64 m = ((u128) q * n) >> 64;
    return (x >> 64) + n - m;
}

Note that a 128-by-64 modulo is not possible with general integer division tricks: the compiler falls back to calling a slow long arithmetic library function to support it.
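Here is a self-contained sketch of the 64-bit variant, assuming GCC/Clang's __uint128_t and an odd modulo below $2^{62}$ (the Mersenne prime $2^{61} - 1$ is a convenient test case), so that lazily reduced values in $[0, 2n)$ can be multiplied again without breaking the range argument. The struct name Montgomery64 and its restore method are hypothetical; $n^\prime$ is computed with the Newton-style iteration from the next section, which needs six steps for 64 bits.

```cpp
#include <cassert>
#include <cstdint>

typedef uint64_t u64;
typedef __uint128_t u128;

struct Montgomery64 {
    u64 n, nr;

    // n must be odd; n < 2^62 keeps products of lazily reduced values below r * n
    constexpr Montgomery64(u64 n) : n(n), nr(1) {
        for (int i = 0; i < 6; i++)  // log2(64) = 6 doubling steps
            nr *= 2 - n * nr;
    }

    // returns a number in the [0, 2 * n - 2] range congruent to x * r^{-1} mod n
    u64 reduce(u128 x) const {
        u64 q = u64(x) * nr;
        u64 m = ((u128) q * n) >> 64;
        return (x >> 64) + n - m;
    }

    u64 multiply(u64 x, u64 y) const { return reduce((u128) x * y); }

    // x * r mod n; uses the slow 128-by-64 modulo, acceptable for a one-off transform
    u64 transform(u64 x) const { return ((u128) x << 64) % n; }

    // back to the normal space, fully reduced into [0, n)
    u64 restore(u64 x) const {
        u64 y = reduce(x);
        return y >= n ? y - n : y;
    }
};
```

A typical use is to transform the operands once, chain as many multiply calls as needed, and call restore only on the final value.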

#Faster Inverse and Transform

Montgomery multiplication itself is fast, but it requires some precomputation:

  • inverting $n$ modulo $r$ to compute $n^\prime$,
  • transforming a number to the Montgomery space,
  • transforming a number from the Montgomery space.

The last operation is already efficiently performed with the reduce procedure we just implemented, but the first two can be slightly optimized.

Computing the inverse $n^\prime = n^{-1} \bmod r$ can be done faster than with the extended Euclidean algorithm by taking advantage of the fact that $r$ is a power of two and using the following identity:

$$a \cdot x \equiv 1 \bmod 2^k \implies a \cdot x \cdot (2 - a \cdot x) \equiv 1 \bmod 2^{2k}$$

Proof: since $a \cdot x \equiv 1 \bmod 2^k$, we can write $a \cdot x = 1 + m \cdot 2^k$ for some integer $m$. Then

$$
\begin{aligned}
a \cdot x \cdot (2 - a \cdot x) &= 2 \cdot a \cdot x - (a \cdot x)^2 \\
&= 2 \cdot (1 + m \cdot 2^k) - (1 + m \cdot 2^k)^2 \\
&= 2 + 2 \cdot m \cdot 2^k - 1 - 2 \cdot m \cdot 2^k - m^2 \cdot 2^{2k} \\
&= 1 - m^2 \cdot 2^{2k} \\
&\equiv 1 \bmod 2^{2k}.
\end{aligned}
$$

We can start with $x = 1$ as the inverse of $a$ modulo $2^1$ (which holds for any odd $a$) and apply this identity exactly $\log_2 \log_2 r$ times, each time doubling the number of bits in the inverse, somewhat reminiscent of Newton's method.
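A sketch of this iteration for 32- and 64-bit words. Unsigned overflow in C++ wraps around, so the arithmetic is automatically modulo $2^{32}$ or $2^{64}$: starting from $x = 1$, five steps reach 32 correct bits and six reach 64. The names inverse32 and inverse64 are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

typedef uint32_t u32;
typedef uint64_t u64;

// Inverse of an odd a modulo 2^32: each step doubles the number of correct low bits
// (1 -> 2 -> 4 -> 8 -> 16 -> 32), and u32 wrap-around provides the "mod 2^32"
constexpr u32 inverse32(u32 a) {
    u32 x = 1;
    for (int i = 0; i < 5; i++)
        x *= 2 - a * x;
    return x;
}

// Same for 2^64: one more doubling step (32 -> 64 bits)
constexpr u64 inverse64(u64 a) {
    u64 x = 1;
    for (int i = 0; i < 6; i++)
        x *= 2 - a * x;
    return x;
}

// being constexpr, the inverse can be computed at compile time
static_assert(inverse32(1000000007u) * 1000000007u == 1, "32-bit inverse");
static_assert(inverse64(1000000007ull) * 1000000007ull == 1, "64-bit inverse");
```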

Transforming a number into the Montgomery space can be done by multiplying it by $r$ and computing modulo the usual way, but we can also take advantage of this relation:

$$\bar{x} = x \cdot r \bmod n = x * r^2$$

Transforming a number into the space is just a Montgomery multiplication by $r^2$. Therefore, we can precompute $r^2 \bmod n$ and perform a multiplication and reduction instead, which may or may not actually be faster: multiplying a number by $r = 2^{k}$ can be implemented with a left-shift, while multiplication by $r^2 \bmod n$ cannot.
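A sketch comparing the two ways of entering the space, assuming $n = 10^9 + 7$ and the lazy reduce from earlier, whose result can be as large as $2n - 2$, so the two transforms agree only modulo $n$. Because $r \bmod n < n < 2^{32}$, the constant $r^2 \bmod n$ can be precomputed with ordinary 64-bit arithmetic. The names r2, transform_shift, and transform_r2 are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

typedef uint32_t u32;
typedef uint64_t u64;

const u32 n = 1000000007;

// n' = n^{-1} mod 2^32 (Newton-style iteration from this section)
constexpr u32 inverse_mod_r(u32 a) {
    u32 x = 1;
    for (int i = 0; i < 5; i++)
        x *= 2 - a * x;
    return x;
}

const u32 nr = inverse_mod_r(n);

// r^2 mod n, precomputed once: (r mod n)^2 < n^2 fits in 64 bits
const u64 r_mod_n = (1ull << 32) % n;
const u32 r2 = r_mod_n * r_mod_n % n;

// lazy reduction: result congruent to x * r^{-1} mod n, in [0, 2n - 2]
u32 reduce(u64 x) {
    u32 q = u32(x) * nr;
    u32 m = ((u64) q * n) >> 32;
    return (x >> 32) + n - m;
}

// x * r mod n with a shift and a 64-bit modulo
u32 transform_shift(u32 x) { return (u64(x) << 32) % n; }

// x * r^2 * r^{-1} = x * r (mod n) with one multiplication and a reduction
u32 transform_r2(u32 x) { return reduce((u64) x * r2); }
```

For every x in $[0, n)$, transform_r2(x) % n equals transform_shift(x); which of the two is faster in practice depends on how the compiler handles the modulo by the constant n.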

#Complete Implementation

It is convenient to wrap everything into a single constexpr structure:

struct Montgomery {
    u32 n, nr;
    
    constexpr Montgomery(u32 n) : n(n), nr(1) {
        // log2(32) = 5 doubling steps: 1 -> 2 -> 4 -> 8 -> 16 -> 32 bits
        for (int i = 0; i < 5; i++)
            nr *= 2 - n * nr;
    }

    u32 reduce(u64 x) const {
        u32 q = u32(x) * nr;
        u32 m = ((u64) q * n) >> 32;
        return (x >> 32) + n - m;
        // returns a number in the [0, 2 * n - 2] range
        // (add a "x < n ? x : x - n" type of check if you need a proper modulo)
    }

    u32 multiply(u32 x, u32 y) const {
        return reduce((u64) x * y);
    }

    u32 transform(u32 x) const {
        return (u64(x) << 32) % n;
        // can also be implemented as multiply(x, r^2 mod n)
    }
};

To test its performance, we can plug Montgomery multiplication into the binary exponentiation:

constexpr Montgomery space(M);

int inverse(int _a) {
    u64 a = space.transform(_a);
    u64 r = space.transform(1);
    
    #pragma GCC unroll(30)
    for (int l = 0; l < 30; l++) {
        if ( (M - 2) >> l & 1 )
            r = space.multiply(r, a);
        a = space.multiply(a, a);
    }

    return space.reduce(r);
}

While vanilla binary exponentiation with a compiler-generated fast modulo trick requires ~170ns per inverse call, this implementation takes ~166ns, going down to ~158ns if we omit transform and reduce (a reasonable use case is for inverse to be used as a subprocedure in a bigger modular computation). This is a small improvement, but Montgomery multiplication becomes much more advantageous for SIMD applications and larger data types.
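The baseline is not listed in the text. The sketch below assumes it is an ordinary binary exponentiation with % by the compile-time constant M, and it only checks that both versions return the same inverses; it does not reproduce any of the timings. The names vanilla_inverse and montgomery_inverse are hypothetical. The latter also normalizes its result into $[0, M)$, because reduce can return $M$ itself when the input is divisible by $M$.

```cpp
#include <cassert>
#include <cstdint>

typedef uint32_t u32;
typedef uint64_t u64;

const int M = 1e9 + 7;

// The structure from the text (lazy reduce, results in [0, 2n - 2])
struct Montgomery {
    u32 n, nr;

    constexpr Montgomery(u32 n) : n(n), nr(1) {
        for (int i = 0; i < 5; i++)
            nr *= 2 - n * nr;
    }

    u32 reduce(u64 x) const {
        u32 q = u32(x) * nr;
        u32 m = ((u64) q * n) >> 32;
        return (x >> 32) + n - m;
    }

    u32 multiply(u32 x, u32 y) const { return reduce((u64) x * y); }

    u32 transform(u32 x) const { return (u64(x) << 32) % n; }
};

constexpr Montgomery space(M);

// Assumed baseline: a^(M - 2) mod M by Fermat's little theorem, using % by the
// compile-time constant M (which the compiler can replace with a multiplication)
int vanilla_inverse(int _a) {
    u64 a = _a, r = 1;
    for (int l = 0; l < 30; l++) {
        if ((M - 2) >> l & 1)
            r = r * a % M;
        a = a * a % M;
    }
    return r;
}

// The Montgomery version from the text, with a final correction into [0, M)
int montgomery_inverse(int _a) {
    u32 a = space.transform(_a);
    u32 r = space.transform(1);
    for (int l = 0; l < 30; l++) {
        if ((M - 2) >> l & 1)
            r = space.multiply(r, a);
        a = space.multiply(a, a);
    }
    u32 y = space.reduce(r);
    return y >= (u32) M ? y - M : y;
}
```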

Exercise. Implement efficient modular matrix multiplication.