How to calculate A * B mod N, where numbers are 64 bit integers

I made a program to calculate ‘A * B mod N’, where A, B and N are 64 bit integers.

The main problem is avoiding overflow when calculating ‘A * B’.

I searched for a while and finally decided to use the interleaved modular multiplication method, which is used in the field of digital circuit design.

I found some bugs when I tested corner cases, so I have to fix them [2018.9.2]. I think I’ve fixed them [2018.9.3]. I’ve just improved it! [2018.9.5]

The program is shown below:

$ cat modular_multiplicate.c
#include <assert.h>
#include <errno.h>
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

/* Local aliases: Assert expands to assert() and uint64 to uint64_t. */
#define Assert assert
#define uint64 uint64_t

/*
 * Calculate (x * y) % m, where x and y are in [0, 2^64) and m is in
 * [1, 2^64).
 *
 * x, y: the factors; each is reduced modulo m before it is used.
 * m:    the modulus; must be non-zero (checked with assert()).
 *
 * Returns the product reduced modulo m, i.e. a value in [0, m).
 *
 * If both reduced factors are below 2^32 the product is formed directly.
 * Otherwise the interleaved modular multiplication algorithm is used so
 * that no intermediate value overflows 64 bits.
 */
static uint64_t modular_multiplicate(uint64_t x, uint64_t y, const uint64_t m)
{
	int		i, bits;
	uint64_t	r = 0;

	assert(1 <= m);

	/* Because (x * y) % m == ((x % m) * (y % m)) % m. */
	if (x >= m)
		x %= m;
	if (y >= m)
		y %= m;

	/* Return the trivial result. */
	if (x == 0 || y == 0 || m == 1)
		return 0;

	/* Both factors are below 2^32, so (x * y) cannot overflow 64 bits. */
	if ((x | y) <= UINT32_MAX)
		return (x * y) % m;

	/* Make y the smaller factor: the loop below runs once per bit of y. */
	if (x < y)
	{
		uint64_t tmp = x;
		x = y;
		y = tmp;
	}

	/* Interleaved modular multiplication algorithm
	 *
	 *   D.N. Amanor, et al, "Efficient hardware architecture for
	 *    modular multiplication on FPGAs", in Field Programmable
	 *    Logic and Applications, 2005. International Conference on,
	 *    Aug 2005, pp. 539-542.
	 *
	 * This algorithm is usually used in the field of digital circuit
	 * design.
	 *
	 * Input: X, Y, M; 0 <= X, Y < M;
	 * Output: R = X * Y mod M;
	 * bits: number of bits of Y
	 * Y[i]: i th bit of Y
	 *
	 * 1. R = 0;
	 * 2. for (i = bits - 1; i >= 0; i--) {
	 * 3. 	R = 2 * R;
	 * 4. 	if (Y[i] == 0x1)
	 * 5. 		R += X;
	 * 6. 	if (R >= M) R -= M;
	 * 7.	if (R >= M) R -= M;
	 *   }
	 *
	 * In Steps 3 and 5, overflow must be avoided.
	 * Steps 6 and 7 are replaced here by one modular operation (R %= M).
	 */

	/*
	 * Count the significant bits of y (y != 0 here, so bits >= 1).
	 * Leading zero bits would only keep r at 0, so skipping them does
	 * not change the result.
	 */
	bits = 0;
	for (uint64_t rest = y; rest != 0; rest >>= 1)
		bits++;

	for (i = bits - 1; i >= 0; i--)
	{
		if (r > UINT64_MAX / 2)
			/*
			 * 2 * r would overflow.  Here 2^63 <= r < m, so
			 * (2 * r) % m == 2 * r - m == m - 2 * (m - r),
			 * and (m - r) < 2^63 can be shifted safely.
			 */
			r = m - ((m - r) << 1);
		else
			r <<= 1;

		if ((y >> i) & 0x1)
		{
			if (r > UINT64_MAX - x)
				/*
				 * r + x would overflow, so add (x - m)
				 * instead.  x - m wraps around because
				 * x < m, but the sum wraps back: the true
				 * value r + x - m lies in (0, 2^64).
				 */
				r += x - m;
			else
				r += x;
		}

		r %= m;
	}

	return r;
}

/*
 * Parse an unsigned decimal integer that fits in 64 bits.
 *
 * text: NUL-terminated string; must consist of decimal digits only.
 * out:  receives the parsed value on success.
 *
 * Returns 0 on success, or -1 if text is empty, starts with a sign or
 * white space, has trailing characters, or is out of range.
 */
static int parse_u64(const char *text, uint64_t *out)
{
	char			*end;
	unsigned long long	value;

	/* strtoull() accepts a leading sign and negates the value; refuse it. */
	if (text[0] < '0' || text[0] > '9')
		return -1;

	errno = 0;
	value = strtoull(text, &end, 10);
	if (end == text || *end != '\0' || errno == ERANGE)
		return -1;

	*out = (uint64_t) value;
	return 0;
}

/*
 * Command-line driver: prints "(A * B) % N = R" for the three decimal
 * arguments A, B and N.
 *
 * Returns 0 on success, or EXIT_FAILURE (with a message on stderr) when
 * the argument count is wrong, an argument is not a valid 64-bit
 * unsigned integer, or N is zero.
 */
int main(int argc, char **argv) {

	uint64_t a, b, n, r;

	if (argc != 4) {
		fprintf(stderr, "Syntax Error:\n\tUsage:%s A B N\n",
			argc > 0 ? argv[0] : "modular_multiplicate");
		return EXIT_FAILURE;
	}

	if (parse_u64(argv[1], &a) != 0 ||
	    parse_u64(argv[2], &b) != 0 ||
	    parse_u64(argv[3], &n) != 0) {
		fprintf(stderr, "Error: A, B and N must be decimal integers in [0, 2^64)\n");
		return EXIT_FAILURE;
	}

	/* The modulus must be non-zero; modular_multiplicate() asserts this. */
	if (n == 0) {
		fprintf(stderr, "Error: N must not be 0\n");
		return EXIT_FAILURE;
	}

	r = modular_multiplicate(a, b, n);
	printf("(%" PRIu64 " * %" PRIu64 ") %% %" PRIu64 " = %" PRIu64 "\n", a, b, n, r);
	return 0;
}

You can compile and run it.

$ gcc modular_multiplicate.c -o modular_multiplicate

$ ./modular_multiplicate 9223372036854775806 9223372036854775806 9223372036854775807
(9223372036854775806 * 9223372036854775806) % 9223372036854775807 = 1

$ ./modular_multiplicate 9223372036854775807 9223372036854775806 9223372036854775807
(9223372036854775807 * 9223372036854775806) % 9223372036854775807 = 0

$ ./modular_multiplicate 922337203685 9223372036854775806 9223372036854775807
(922337203685 * 9223372036854775806) % 9223372036854775807 = 9223371114517572122

$ ./modular_multiplicate 922337203685477580 9223372036854775806 9223372036854775807
(922337203685477580 * 9223372036854775806) % 9223372036854775807 = 8301034833169298227

$ ./modular_multiplicate 9223372036854775807 9223372036854775807 9223372036854775806
(9223372036854775807 * 9223372036854775807) % 9223372036854775806 = 1

$ ./modular_multiplicate 18446744073709551615 18446744073709551615 18446744073709551615
(18446744073709551615 * 18446744073709551615) % 18446744073709551615 = 0

$ ./modular_multiplicate 18446744073709551614 18446744073709551614 18446744073709551615
(18446744073709551614 * 18446744073709551614) % 18446744073709551615 = 1

$ ./modular_multiplicate 18446744073709551615 18446744073709551615 18446744073709551614
(18446744073709551615 * 18446744073709551615) % 18446744073709551614 = 1

$ ./modular_multiplicate 184467440737095516 18446744073709551 18446744073709551615
(184467440737095516 * 18446744073709551) % 18446744073709551615 = 9405072465980814891

$ ./modular_multiplicate 18446744073709551613 18446744073709551614 18446744073709551615
(18446744073709551613 * 18446744073709551614) % 18446744073709551615 = 2

$ ./modular_multiplicate 18446744073709551612 18446744073709551614 18446744073709551615
(18446744073709551612 * 18446744073709551614) % 18446744073709551615 = 3

I provide a simple test suite.

$ cat ./modular_multiplicate_test.sh
#!/bin/bash

##
## Usage: modular_multiplicate_test.sh < modular_multiplicate.dat
##
## Reads test cases from standard input.  A test case is one line of the form
##     TEST [id] A x B % N = EXPECTED
## Lines whose first field does not start with "TEST" are ignored.
##
## For each case the program is run once with A, B and N, and the 7th field
## of its output is compared with EXPECTED.  On a match the case id, the
## command and the program output are printed; otherwise "ERROR:[id]".
##
## Exit status: 0 if every case matched, 1 if at least one did not.
##

PROG="./modular_multiplicate"
STATUS=0

# Fields are read straight into named variables, so "[1]" is never exposed
# to pathname expansion and runs of blanks between fields do not matter.
# The "|| [ -n ... ]" part keeps a final line without a trailing newline.
while read -r tag id a _ b _ n _ expected _ || [ -n "${tag}" ]; do
    case "${tag}" in
        TEST*) ;;
        *) continue ;;
    esac

    STMT="${PROG} ${a} ${b} ${n}"
    OUT=$("${PROG}" "${a}" "${b}" "${n}")
    RET=$(printf '%s\n' "${OUT}" | awk '{print $7}')

    # An incomplete TEST line (no expected value) counts as a failure, so
    # that empty program output can never be mistaken for a match.
    if [ -z "${expected}" ] || [ "${RET}" != "${expected}" ]; then
        echo "ERROR:${id}"
        STATUS=1
    else
        echo "${id} ${STMT}"
        echo "${OUT}"
    fi
done

exit "${STATUS}"

$ cat modular_multiplicate.dat
##
## Usage: modular_multiplicate_test.sh < modular_multiplicate.dat
##

TEST [1] 92233720 x 854775806 % 5807 = 2875
TEST [2] 92233720 x 854775806 % 4775807 = 3519799

##  9223372036854775807 = (0x7fffffffffffffff)
TEST [3] 9223372036854775806 x 9223372036854775806 % 9223372036854775807 = 1
TEST [4] 9223372036854775805 x 9223372036854775806 % 9223372036854775807 = 2
TEST [5] 9223372036854775804 x 9223372036854775806 % 9223372036854775807  = 3
TEST [6] 11 x 9223372036854775806 % 9223372036854775807 = 9223372036854775796
TEST [7] 2 x 9223372036854775806 % 9223372036854775807  = 9223372036854775805
TEST [8] 9223372036854775807 x 9223372036854775806 % 9223372036854775807 = 0
TEST [9] 922337203685 x 9223372036854775806 % 9223372036854775807 = 9223371114517572122
TEST [10] 922337203685477580 x 9223372036854775806 % 9223372036854775807 = 8301034833169298227

TEST [11] 9223372036854775807 x 9223372036854775807 % 9223372036854775806 = 1

## 18446744073709551615 = (0xffffffffffffffff)
TEST [12] 18446744073709551615 x 18446744073709551615 % 18446744073709551615 = 0
TEST [13] 18446744073709551614 x 18446744073709551614 % 18446744073709551615 = 1
TEST [14] 18446744073709551613 x 18446744073709551614 % 18446744073709551615  = 2
TEST [15] 18446744073709551612 x 18446744073709551614 % 18446744073709551615 = 3
TEST [16] 11 x 18446744073709551614 % 18446744073709551615 = 18446744073709551604
TEST [17] 2 x 18446744073709551614 % 18446744073709551615  = 18446744073709551613
TEST [18] 184467440737095516 x 18446744073709551 % 18446744073709551615  = 9405072465980814891
TEST [19] 18446744073709551613 x 18446744073709551614 % 18446744073709551615  = 2
TEST [20] 18446744073709551612 x 18446744073709551614 % 18446744073709551615  = 3

TEST [21] 18446744073709551615 x 18446744073709551615 % 18446744073709551614  = 1