In modular arithmetic (and computational algebra in general), you often need to raise a number to the $n$-th power — to do modular division, perform primality tests, or compute some combinatorial values — and you usually want to spend fewer than $O(n)$ operations calculating it.
Binary exponentiation, also known as exponentiation by squaring, is a method that allows for computation of the $n$-th power using $O(\log n)$ multiplications, relying on the following observation:
To compute $a^n$, we can recursively compute $a^{\lfloor n/2 \rfloor}$, square it, and then optionally multiply by $a$ if $n$ is odd, corresponding to the following recurrence:

$$
a^n = \begin{cases}
1, & n = 0 \\
\left( a^{n/2} \right)^2, & n \text{ is even} \\
a^{n-1} \cdot a, & n \text{ is odd}
\end{cases}
$$

Since $n$ is at least halved every two recursive transitions, the depth of this recurrence and the total number of multiplications will be at most $O(\log n)$.
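For example, for $n = 13$ the recursion unfolds as

$$
a^{13} = a^{12} \cdot a = \left(a^{6}\right)^2 \cdot a = \left(\left(a^{3}\right)^2\right)^2 \cdot a = \left(\left(a^{2} \cdot a\right)^2\right)^2 \cdot a,
$$

which takes 5 multiplications instead of the 12 needed by naive repeated multiplication.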
### Recursive Implementation
As we already have a recurrence, it is natural to implement the algorithm as a case-matching recursive function:
```cpp
const int M = 1e9 + 7; // modulo
typedef unsigned long long u64;

u64 binpow(u64 a, u64 n) {
    if (n == 0)
        return 1;
    if (n % 2 == 1)
        return binpow(a, n - 1) * a % M;
    else {
        u64 b = binpow(a, n / 2);
        return b * b % M;
    }
}
```
In our benchmark, we use $n = M - 2$ so that we compute the multiplicative inverse of $a$ modulo $M$ (by Fermat's little theorem, $a^{M-2} \equiv a^{-1} \pmod M$ when $M$ is prime and $a$ is not divisible by $M$):
```cpp
u64 inverse(u64 a) {
    return binpow(a, M - 2);
}
```
We use $M = 10^9 + 7$, which is a modulo value commonly used in competitive programming to calculate checksums in combinatorial problems — because it is prime (allowing inverses via binary exponentiation), sufficiently large, not overflowing `int` in addition, not overflowing `long long` in multiplication, and easy to type as `1e9 + 7`.
Since we use it as a compile-time constant in the code, the compiler can optimize the modulo by replacing it with multiplication (even if it were not a compile-time constant, it would still be cheaper to compute the magic constants by hand once and use them for fast reduction).
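To sketch the idea, here is an illustrative Barrett-style reduction for a fixed modulus (an assumed example, not the code used in this benchmark; `fast_mod` and `magic` are names introduced here, and `__uint128_t` is a GCC/Clang extension). The quotient is estimated with one multiplication and a shift, followed by at most one correction:

```cpp
typedef unsigned long long u64;
typedef __uint128_t u128;

const u64 M = 1e9 + 7;
const u64 magic = u64(-1) / M; // floor(2^64 / M), precomputed once (equal to floor((2^64-1)/M) since M is odd)

u64 fast_mod(u64 x) {
    u64 q = u64((u128(x) * magic) >> 64); // approximates x / M, may undershoot by one
    u64 r = x - q * M;                    // hence r lies in [0, 2M)
    return r >= M ? r - M : r;
}
```

Compilers apply the same trick automatically whenever the divisor is a compile-time constant, which is exactly what happens with `% M` in the code above.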
The execution path — and consequently the running time — depends on the value of $n$. For this particular $n$, the baseline implementation takes around 330ns per call. As recursion introduces some overhead, it makes sense to unroll the implementation into an iterative procedure.
### Iterative Implementation
The result of $a^n$ can be represented as the product of $a$ raised to some powers of two — those that correspond to 1s in the binary representation of $n$. For example, if $n = 42 = 32 + 8 + 2$, then

$$
a^{42} = a^{32 + 8 + 2} = a^{32} \cdot a^{8} \cdot a^{2}
$$
To calculate this product, we can iterate over the bits of $n$ maintaining two variables: the value of $a^{2^k}$ and the current product after considering the $k$ lowest bits of $n$. On each step, we multiply the current product by $a^{2^k}$ if the $k$-th bit of $n$ is set, and, in either case, square $a^{2^k}$ to get $a^{2^{k+1}}$ that will be used on the next iteration.
```cpp
u64 binpow(u64 a, u64 n) {
    u64 r = 1;
    while (n) {
        if (n & 1)
            r = r * a % M;
        a = a * a % M;
        n >>= 1;
    }
    return r;
}
```
The iterative implementation takes about 180ns per call. The heavy calculations are the same; the improvement mainly comes from the shorter dependency chain: only `a = a * a % M` needs to finish before the next iteration can proceed, and `r = r * a % M` can execute concurrently with it.
The performance also benefits from $n$ being a constant, making all branches predictable and letting the scheduler know what needs to be executed in advance. The compiler, however, does not take advantage of it and does not unroll the `while (n)` loop. We can rewrite it as a `for` loop that performs a constant 30 iterations (enough, since $M - 2 < 2^{30}$):
```cpp
u64 inverse(u64 a) {
    u64 r = 1;
    #pragma GCC unroll(30)
    for (int l = 0; l < 30; l++) {
        if ((M - 2) >> l & 1)
            r = r * a % M;
        a = a * a % M;
    }
    return r;
}
```
This forces the compiler to generate only the instructions we need, shaving off another 10ns and making the total running time ~170ns.
Note that the performance depends not only on the binary length of $n$, but also on the number of binary 1s. If $n$ is $2^{30}$, it takes around 20ns less as we don't have to perform any off-path multiplications.