128bit 整数型を使わない 64bit modint

Miller-Rabin素数判定法を使いたい時などに、64bit整数型でmodをとりたい場合があります。
128bit整数が使えない場合、 a \cdot b \bmod m を計算しようとすると、 a \cdot b を計算する時点でオーバーフローしてしまいうまく計算できません。
かといって、このために多倍長整数型を持ち出すと定数倍が悪くなってしまいます。
そこで、モンゴメリ乗算をうまくやることで解決します。

モンゴメリ乗算

モンゴメリ乗算自体についてはここでは深く書かないので、他の記事*1*2を参照してください。
モンゴメリ乗算に使う値をそれぞれ次のように定義します。
 R = 2^{64}
 R_2 = R^{2} \bmod N
 N' = -N^{-1} \bmod R

 R_2, N' はそれぞれ、繰り返し二倍法、ニュートン法で計算できます。
 xモンゴメリ表現を求める関数を  \mathrm{mr}(x) とすると、
 x \cdot y \bmod N \mathrm{mr}(\mathrm{mr}(x \cdot R_2) \cdot y) と計算できます。
もちろんここでも  x \cdot R_2 などを計算するタイミングでオーバーフローが発生する可能性があるため、代わりに  x \cdot yモンゴメリ表現を( x \cdot y を陽に計算せずに)求める関数  \mathrm{mul\_mr}(x, y) を考え、 \mathrm{mul\_mr}(\mathrm{mul\_mr}(x, R_2), y) と計算することにします。

前準備

 \mathrm{low}(x) := x\text{の下位64ビット}
 \mathrm{high}(x) := x\text{の上位64ビット}
と定義します。

先にこれらを計算する方法について考えます。

 \mathrm{low}(x) はオーバーフローを無視してxを計算すれば勝手に求まります。
また、 \mathrm{high}(x \cdot y) も、 x \cdot y を陽に求めず64bit整数のみを使って計算できます。
 x = x_h \cdot 2^{32} + x_l, \quad y = y_h \cdot 2^{32} + y_l というように分解すると、
 x y = x_h \cdot y_h \cdot 2^{64} + (x_h \cdot y_l + x_l \cdot y_h) \cdot 2^{32} + x_l \cdot y_l であることから、オーバーフローを避けて計算すると
 \mathrm{high}(x \cdot y) = x_h \cdot y_h + (x_h \cdot y_l \gg 32) + (x_l \cdot y_h \gg 32) + ( (x_h \cdot y_l \And (2^{32}-1)) + (x_l \cdot y_h \And (2^{32}-1)) + (x_l \cdot y_l \gg 32) ) \gg 32 です。

ちなみに、Java 9 以降であれば Math.multiplyHigh(x, y) + (x >> 63 & y) + (y >> 63 & x)
更に Java 18 以降であれば Math.unsignedMultiplyHigh(x, y) と書けます。

本題

さて、 xモンゴメリ表現は  (x + (x \cdot N' \bmod R) \cdot N) / R で求められますが、
 R=2^{64} のとき、これは  \mathrm{high}(x + \mathrm{low}(x \cdot N') \cdot N) と書けます。
各項の上位64ビットと下位64ビットからの繰り上がりを分けて考えると、
 \mathrm{high}(x + \mathrm{low}(x \cdot N') \cdot N) = \mathrm{high}(x) + \mathrm{high}(\mathrm{low}(x \cdot N') \cdot N) + \mathrm{high}(\mathrm{low}(x) + \mathrm{low}(\mathrm{low}(x \cdot N') \cdot N)) です。
前二項は前の議論から求めることができます。
 \mathrm{high}(\mathrm{low}(x) + \mathrm{low}(\mathrm{low}(x \cdot N') \cdot N)) について考えます。

 N' \cdot N = -1 \bmod R であることを思い出すと、
 \begin{align}
\mathrm{low}(x + \mathrm{low}(\mathrm{low}(x \cdot N') \cdot N)) &= \mathrm{low}(x + x \cdot N' \cdot N) \\
&= \mathrm{low}(x + (R-1) \cdot x) \\
&= 0
\end{align}
となることから、 x + \mathrm{low}(x \cdot N') \cdot N R の倍数です。これと
 0 \leq \mathrm{low}(x), \mathrm{low}(\mathrm{low}(x \cdot N') \cdot N) \lt R を考えると、

 \mathrm{high}(\mathrm{low}(x) + \mathrm{low}(\mathrm{low}(x \cdot N') \cdot N)) は、 \mathrm{low}(x) 0 のときは  0、そうでないときは  1 となります。
ここまでから、  \mathrm{mul\_mr}(x, y) = \mathrm{high}(x \cdot y) + \mathrm{high}(\mathrm{low}(x \cdot y \cdot N') \cdot N) + (\mathrm{low}(x \cdot y) \neq 0) と計算できます。

コード

この実装では  n < 2^{62} まで対応していますが、がんばると  n < 2^{63} までなら対応できると思います。

judge.yosupo.jp

class Montgomery {
    private final long n, r2, nInv;

    public Montgomery(long n) {
        long r2 = (1L << 62) % n;
        for (int i = 0; i < 66; i++) {
            r2 <<= 1;
            if (r2 >= n) {
                r2 -= n;
            }
        }
        long nInv = n;
        for (int i = 0; i < 5; ++i) {
            nInv *= 2 - n * nInv;
        }
        this.n = n;
        this.r2 = r2;
        this.nInv = nInv;
    }

    private static long high(long x, long y) {
        long xh = x >>> 32;
        long yh = y >>> 32;
        long xl = x & 0xFFFFFFFFL;
        long yl = y & 0xFFFFFFFFL;
        return xh * yh + (xh * yl >>> 32) + (xl * yh >>> 32) + ((((xh * yl & 0xFFFFFFFFL) + (xl * yh & 0xFFFFFFFFL)) + (xl * yl >>> 32)) >>> 32);
    }

    private long mulMr(long x, long y) {
        return high(x, y) + high(-nInv * x * y, n) + (x * y == 0 ? 0 : 1);
    }

    public long mod(long x) {
        return x < n ? x : x - n;
    }

    public long mul(long x, long y) {
        return mod(mulMr(mulMr(x, r2), y));
    }

    public long pow(long x, long y) {
        long z = mulMr(x, r2);
        long r = 1;
        while (y > 0) {
            if ((y & 1) == 1) {
                r = mulMr(r, z);
            }
            z = mulMr(z, z);
            y >>= 1;
        }
        return mod(r);
    }
}