Binary Exponentiation - Algorithmica
Binary Exponentiation

In modular arithmetic (and computational algebra in general), you often need to raise a number to the $n$-th power (to do modular division, perform primality tests, or compute some combinatorial values), and you usually want to spend fewer than $\Theta(n)$ operations calculating it.

Binary exponentiation, also known as exponentiation by squaring, is a method that allows for computation of the $n$-th power using $O(\log n)$ multiplications, relying on the following observation:

$$
\begin{aligned}
a^{2k} &= (a^k)^2 \\
a^{2k+1} &= (a^k)^2 \cdot a
\end{aligned}
$$

To compute $a^n$, we can recursively compute $a^{\lfloor n/2 \rfloor}$, square it, and then optionally multiply by $a$ if $n$ is odd, corresponding to the following recurrence:

$$
a^n = f(a, n) = \begin{cases}
1, & n = 0 \\
f\left(a, \frac{n}{2}\right)^2, & 2 \mid n \\
f(a, n - 1) \cdot a, & 2 \nmid n
\end{cases}
$$

Since $n$ is at least halved every two recursive transitions, the depth of this recursion and the total number of multiplications are both $O(\log n)$.
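As a worked instance (the value $n = 42$ is chosen only for illustration), unfolding the recurrence step by step gives:

$$
\begin{aligned}
a^{42} &= (a^{21})^2 \\
a^{21} &= a^{20} \cdot a \\
a^{20} &= (a^{10})^2 \\
a^{10} &= (a^{5})^2 \\
a^{5} &= a^{4} \cdot a \\
a^{4} &= (a^{2})^2 \\
a^{2} &= (a^{1})^2 \\
a^{1} &= a^{0} \cdot a = 1 \cdot a
\end{aligned}
$$

That is 8 multiplications in total (5 squarings and 3 multiplications by $a$, the last of them by the trivial $a^0 = 1$), compared to the 41 multiplications of the naive approach.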

#Recursive Implementation

As we already have a recurrence, it is natural to implement the algorithm as a case-matching recursive function:

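const int M = 1e9 + 7; // modulo
typedef unsigned long long u64;

// assumes a < M, so that every product of two residues fits in 64 bits
u64 binpow(u64 a, u64 n) {
    if (n == 0)
        return 1;                        // a^0 = 1
    if (n % 2 == 1)
        return binpow(a, n - 1) * a % M; // odd n: a^n = a^(n-1) * a
    else {
        u64 b = binpow(a, n / 2);        // even n: a^n = (a^(n/2))^2
        return b * b % M;
    }
}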

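In our benchmark, we use $n = m - 2$ so that we compute the multiplicative inverse of $a$ modulo $m$ (by Fermat's little theorem, $a^{m-1} \equiv 1 \pmod m$ for a prime $m$ and any $a \not\equiv 0$, hence $a^{m-2} \equiv a^{-1} \pmod m$):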

u64 inverse(u64 a) {
    return binpow(a, M - 2);
}
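Since inverse(a) gives $a^{-1}$, modular division $x / y \bmod m$ reduces to a multiplication by inverse(y). A minimal sketch, assuming the divisor is not a multiple of M; the helper name divide is hypothetical and not part of the original benchmark, while binpow and inverse are repeated from above so the snippet compiles on its own:

```cpp
#include <cassert>

const int M = 1e9 + 7; // prime modulus, as in the text
typedef unsigned long long u64;

// Recursive binary exponentiation from above: a^n mod M, for a < M.
u64 binpow(u64 a, u64 n) {
    if (n == 0)
        return 1;
    if (n % 2 == 1)
        return binpow(a, n - 1) * a % M;
    u64 b = binpow(a, n / 2);
    return b * b % M;
}

// Fermat's little theorem: a^(M-2) is the inverse of a modulo prime M.
u64 inverse(u64 a) {
    return binpow(a, M - 2);
}

// Hypothetical helper: computes x / y modulo M.
u64 divide(u64 x, u64 y) {
    assert(y % M != 0); // y must be invertible modulo M
    return x % M * inverse(y % M) % M;
}
```

For instance, divide(6, 3) returns 2, and divide(1, 2) returns 500000004, which is $(m + 1) / 2$.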

We use $m = 10^9+7$, which is a modulo value commonly used in competitive programming to calculate checksums in combinatorial problems: it is prime (allowing inverses via binary exponentiation), sufficiently large, does not overflow int in addition or long long in multiplication, and is easy to type as 1e9 + 7.
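These properties can also be checked at compile time. A sketch, assuming C++14 or later (for loops inside constexpr functions); is_prime is an illustrative trial-division helper, not a library function:

```cpp
#include <climits>

// Trial division up to sqrt(m); cheap enough to run at compile time
// for m around 10^9 (about 31,600 iterations).
constexpr bool is_prime(long long m) {
    if (m < 2)
        return false;
    for (long long d = 2; d * d <= m; d++)
        if (m % d == 0)
            return false;
    return true;
}

constexpr long long M = 1000000007; // 10^9 + 7

// Prime, so a^(M - 2) is the inverse of any a not divisible by M.
static_assert(is_prime(M), "modulus must be prime");
// The sum of two residues (each at most M - 1) fits in int.
static_assert(2 * (M - 1) <= INT_MAX, "sum of residues overflows int");
// The product of two residues fits in long long.
static_assert((M - 1) * (M - 1) <= LLONG_MAX, "product of residues overflows long long");
```

If any of these conditions failed, say after changing the modulus, compilation would stop with the corresponding message.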

Since we use the modulus as a compile-time constant, the compiler can optimize the modulo operation by replacing the slow division instruction with a multiplication by a precomputed "magic" constant and a few cheap arithmetic operations (and even when the modulus is not a compile-time constant, it is still cheaper to compute the magic constants by hand once and use them for fast reduction).
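When the modulus is only known at run time, the magic constant can be precomputed manually. One standard way to do this is Barrett reduction, sketched below under two assumptions: the compiler supports the GCC/Clang unsigned __int128 extension, and the modulus is below $2^{32}$ so that the product of two residues fits in 64 bits. The struct name Barrett and its members are illustrative, and this is not necessarily the exact scheme the compiler uses for constant divisors:

```cpp
#include <cassert>

typedef unsigned long long u64;

// Barrett reduction for a modulus m (1 <= m < 2^32) known only at run time.
struct Barrett {
    u64 m; // the modulus
    u64 q; // magic constant: floor((2^64 - 1) / m), computed once

    explicit Barrett(u64 m) : m(m), q(~0ULL / m) {
        assert(m >= 1 && m < (1ULL << 32));
    }

    // x mod m for any 64-bit x, without a division instruction.
    u64 reduce(u64 x) const {
        // t is floor(x / m) or one less than it
        u64 t = (u64)(((unsigned __int128)x * q) >> 64);
        u64 r = x - t * m; // r lies in [0, 2m)
        return r >= m ? r - m : r;
    }

    // a * b mod m, assuming a, b < m so that a * b fits in 64 bits.
    u64 mul(u64 a, u64 b) const {
        return reduce(a * b);
    }
};
```

With such a helper, a call like Barrett(1000000007).mul(a, b), constructed once and reused, could stand in for a * b % M in either binpow implementation when M is not a compile-time constant.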

The execution path, and consequently the running time, depends on the value of $n$. For this particular $n$, the baseline implementation takes around 330ns per call. As recursion introduces some overhead, it makes sense to unroll the implementation into an iterative procedure.

#Iterative Implementation

The result of $a^n$ can be represented as the product of $a$ raised to some powers of two: those that correspond to the 1s in the binary representation of $n$. For example, if $n = 42 = 32 + 8 + 2$, then

$$
a^{42} = a^{32+8+2} = a^{32} \cdot a^8 \cdot a^2
$$

To calculate this product, we can iterate over the bits of $n$ while maintaining two variables: the value of $a^{2^k}$ and the current product after considering the $k$ lowest bits of $n$. On each step, we multiply the current product by $a^{2^k}$ if the $k$-th bit of $n$ is set, and, in either case, square $a^{2^k}$ to get $a^{2^k \cdot 2} = a^{2^{k+1}}$, which will be used on the next iteration.
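For example, tracing the two variables for $n = 42 = 101010_2$ (an illustrative value; the bits of $n$ are read starting from the lowest), the product picks up $a^{2^k}$ exactly at $k = 1, 3, 5$:

$$
\begin{array}{c|c|c|c}
k & \text{bit } k \text{ of } n & a^{2^k} & \text{product after step } k \\
\hline
0 & 0 & a & 1 \\
1 & 1 & a^2 & a^2 \\
2 & 0 & a^4 & a^2 \\
3 & 1 & a^8 & a^{10} \\
4 & 0 & a^{16} & a^{10} \\
5 & 1 & a^{32} & a^{42}
\end{array}
$$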

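u64 binpow(u64 a, u64 n) {
    u64 r = 1; // product of a^(2^k) over the set bits processed so far
    
    while (n) {
        if (n & 1)         // the lowest remaining bit is set
            r = r * a % M; // include the current a^(2^k) in the product
        a = a * a % M;     // a^(2^k) -> a^(2^(k+1))
        n >>= 1;           // move on to the next bit
    }
    
    return r;
}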

The iterative implementation takes about 180ns per call. The heavy calculations are the same; the improvement mainly comes from the reduced dependency chain: a = a * a % M needs to finish before the loop can proceed, and it can now execute concurrently with r = r * a % M.

The performance also benefits from $n$ being a constant, making all branches predictable and letting the scheduler know in advance what needs to be executed. The compiler, however, does not take advantage of it and does not unroll the while(n) n >>= 1 loop. Since $m - 2 < 2^{30}$, we can rewrite it as a for loop that performs exactly 30 iterations:

u64 inverse(u64 a) {
    u64 r = 1;
    
    #pragma GCC unroll(30)
    for (int l = 0; l < 30; l++) {  // M - 2 < 2^30, so 30 bits suffice
        if ( (M - 2) >> l & 1 )     // l-th bit of M - 2, a constant once unrolled
            r = r * a % M;
        a = a * a % M;
    }

    return r;
}

This forces the compiler to generate only the instructions we need, shaving off another 10ns and making the total running time ~170ns.
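The same trick generalizes to any exponent known at compile time. A sketch, assuming C++14 or later; pow_const and bit_length are hypothetical names, and because #pragma GCC unroll is omitted here, full unrolling is left to the optimizer rather than forced:

```cpp
#include <cassert>

typedef unsigned long long u64;
constexpr u64 M = 1000000007; // prime modulus, as in the text

// Number of significant bits in x (0 for x == 0).
constexpr int bit_length(u64 x) {
    int len = 0;
    while (x) {
        len++;
        x >>= 1;
    }
    return len;
}

// a^N mod M for a compile-time exponent N (a must be < M).
// The trip count L is a compile-time constant, so the optimizer
// is free to unroll the loop and fold every bit test.
template <u64 N>
u64 pow_const(u64 a) {
    constexpr int L = bit_length(N);
    u64 r = 1;
    for (int l = 0; l < L; l++) {
        if ((N >> l) & 1)
            r = r * a % M;
        a = a * a % M;
    }
    return r;
}

// Same result as inverse() above: a^(M-2) mod M.
u64 inverse_const(u64 a) {
    return pow_const<M - 2>(a);
}
```

Here pow_const<M - 2> runs the same 30 iterations as the hand-written inverse, since bit_length(M - 2) is 30.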

Note that the performance depends not only on the binary length of $n$, but also on the number of binary 1s. If $n$ is $2^{30}$, it takes around 20ns less, as we don't have to perform any off-path multiplications.
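To make the dependence on the number of 1s concrete, here is a sketch that counts the modular multiplications performed by the while (n) loop of the iterative binpow: one squaring per bit of $n$, plus one extra multiplication per set bit. The function name count_mults is hypothetical:

```cpp
#include <cassert>

typedef unsigned long long u64;

// Counts the modular multiplications performed by the iterative
// binpow loop for exponent n, mirroring its control flow exactly.
int count_mults(u64 n) {
    int mults = 0;
    while (n) {
        if (n & 1)
            mults++; // r = r * a % M (off the critical path)
        mults++;     // a = a * a % M (on the critical path)
        n >>= 1;
    }
    return mults;
}
```

For $n = m - 2 = 10^9 + 5$, whose binary form has 30 bits with 15 of them set, it returns $30 + 15 = 45$; for $n = 2^{30}$ (31 bits, one of them set) it returns $31 + 1 = 32$.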