Skip to content

mul_mod for 2^31 < m < 2^32 #111

@mizar

Description

@mizar

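`mul_mod` can easily be adapted to the case 2^31 < m < 2^32 simply by improving the borrow check on the final subtraction. The current code detects the borrow by testing `m <= v` on the result truncated to 32 bits. That only works while the wrapped value 2^32 + d − m is guaranteed to be at least m, i.e. while m ≤ 2^31. Checking the borrow of the 64-bit subtraction directly removes that restriction.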

(current code: ac-library)

https://github.com/atcoder/ac-library/blob/6c88a70c8f95fef575af354900d107fbd0db1a12/atcoder/internal_math.hpp#L22-L62

(current code: ac-library-rs)

/// Precomputed modulus data for fast modular multiplication by Barrett reduction.
///
/// See <https://en.wikipedia.org/wiki/Barrett_reduction>.
///
/// NOTE: reconsider after Ice Lake
pub(crate) struct Barrett {
    /// Barrett constant: `ceil(2^64 / m)` reduced mod `2^64`, as computed by `Barrett::new`.
    pub(crate) im: u64,
    /// The modulus `m`.
    pub(crate) _m: u32,
}
impl Barrett {
    /// Precomputes the Barrett constant for modulus `m`.
    ///
    /// # Arguments
    /// * `m` — `1 <= m`. `m` must additionally lie within the range accepted by
    ///   [`mul_mod`]; the original library documents only `1 <= m`. See the
    ///   [pull request comment](https://github.com/rust-lang-ja/ac-library-rs/pull/3#discussion_r484661007)
    ///   for details.
    pub(crate) fn new(m: u32) -> Barrett {
        // ceil(2^64 / m) computed as floor((2^64 - 1) / m) + 1; the addition
        // wraps to 0 when m == 1.
        let im = (u64::MAX / u64::from(m)).wrapping_add(1);
        Barrett { _m: m, im }
    }

    /// Returns the modulus `m` this reducer was built for.
    pub(crate) fn umod(&self) -> u32 {
        self._m
    }

    /// Returns `a * b % m`.
    ///
    /// # Parameters
    /// * `a` — `0 <= a < m`
    /// * `b` — `0 <= b < m`
    pub(crate) fn mul(&self, a: u32, b: u32) -> u32 {
        mul_mod(a, b, self.umod(), self.im)
    }
}
/// Calculates `a * b % m` using a precomputed Barrett constant.
///
/// * `a` `0 <= a < m`
/// * `b` `0 <= b < m`
/// * `m` `1 <= m < 2^32`
/// * `im` = ceil(2^64 / `m`) mod 2^64, i.e. `(u64::MAX / m).wrapping_add(1)`
///   (which is `0` when `m == 1`)
#[allow(clippy::many_single_char_names)]
pub(crate) fn mul_mod(a: u32, b: u32, m: u32, im: u64) -> u32 {
    // [1] m = 1
    //   a = b = im = 0, so z = x = 0 and the result is 0.
    // [2] 2 <= m < 2^32
    //   im = ceil(2^64 / m), so 2^32 + 2 <= im <= 2^63
    //   -> im * m = 2^64 + r (0 <= r < m)
    //   let z = a*b = c*m + d (0 <= c, d < m)
    //   z * im = c*(im*m) + d*im = c*2^64 + c*r + d*im
    //   c*r + d*im <= (m-1)^2 + (2^64 + r - im) < 2^64 * 2
    //   -> x = (z * im) >> 64 is c or c + 1, and z - x*m is d or d - m.
    //   c <= m - 2, so x <= m - 1 and x*m < 2^64 never wraps.
    // The d - m case must be detected from the borrow of the 64-bit
    // subtraction. Comparing the 32-bit truncation 2^32 + d - m against m
    // (the previous `m <= v` check) is only valid for m <= 2^31.
    let z = u64::from(a) * u64::from(b);
    let x = ((u128::from(z) * u128::from(im)) >> 64) as u64;
    let (v, borrowed) = z.overflowing_sub(x.wrapping_mul(u64::from(m)));
    if borrowed {
        // v = 2^64 + d - m; truncating and adding m back (mod 2^32) yields d.
        (v as u32).wrapping_add(m)
    } else {
        v as u32
    }
}

  • $m\in\text{自然数(natural number)}\mathbb{N},\quad 1\le m\lt 2^{32}$
  • $\lbrace a,b\rbrace\in\text{整数(integer)}\mathbb{Z},\quad 0\le \lbrace a, b\rbrace\lt m$
  • $\displaystyle\bar{m'} = \left(\left\lfloor\frac{2^{64}-1}{m}\right\rfloor+1\right)\bmod 2^{64}=\left\lceil\frac{2^{64}}{m}\right\rceil\bmod 2^{64}$
  • $\displaystyle x=\left\lfloor\frac{ab\bar{m'}}{2^{64}}\right\rfloor$
  • $ab\mod m=ab-xm\quad(ab\ge xm)$
  • $ab\mod m=ab-xm+m\quad(ab\lt xm)$

(proof)

  1. when $m=1$: $a=b=0$ and $\bar{m'}=2^{64}\bmod 2^{64}=0$, so $x=0$ and the result is $0$, which is correct.
  2. when $2\le m\lt 2^{32}$,
    • $2^{32}+2=\left\lceil\frac{2^{64}}{2^{32}-1}\right\rceil\le\bar{m'}=\left\lceil\frac{2^{64}}{m}\right\rceil\le \left\lceil\frac{2^{64}}{2}\right\rceil=2^{63}$
    • $\bar{m'}\hspace{.1em}m=2^{64}+r\quad(0\le r\lt m)$
    • $z = ab = cm + d\quad(0\le\lbrace c,d\rbrace\lt m)$
    • $z\hspace{.1em}\bar{m'}=ab\hspace{.1em}\bar{m'}=(cm+d)\hspace{.1em}\bar{m'}=c(\bar{m'}\hspace{.1em}m)+d\hspace{.1em}\bar{m'}=2^{64}c+c\hspace{.1em}r+d\hspace{.1em}\bar{m'}$
    • $2^{64}c\le z\hspace{.1em}\bar{m'}\lt 2^{64}(c+2)$
      • $z\hspace{.1em}\bar{m'}=2^{64}c+c\hspace{.1em}r+d\hspace{.1em}\bar{m'}$
      • $0\le c\hspace{.1em}r\le (m-1)^2\le(2^{32}-2)^2=2^{64}-2^{34}+4$
      • $0\le d\hspace{.1em}\bar{m'}\le\bar{m'}\hspace{.1em}(m-1)=2^{64}+r-\bar{m'}\le 2^{64}+(2^{32}-2)-(2^{32}+2)=2^{64}-4$
    • $x=\left\lfloor\frac{ab\hspace{.1em}\bar{m'}}{2^{64}}\right\rfloor=\lbrace c$ or $(c+1)\rbrace$
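    • $z-xm=ab-\left\lfloor\frac{ab\hspace{.1em}\bar{m'}}{2^{64}}\right\rfloor m=\lbrace d$ or $(d-m)\rbrace$
    • Since $-m\le d-m\lt 0$, the case $x=c+1$ is exactly the case where the 64-bit subtraction $z-xm$ borrows; adding $m$ back then yields $d$. Testing the 32-bit truncation instead ($m\le (z-xm)\bmod 2^{32}$) works only if $2^{32}+d-m\ge m$ always holds, i.e. only when $m\le 2^{31}$.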

(C++: $1\le m\lt 2^{32}$ draft code)

https://godbolt.org/z/9Gz1oGrTa

#ifdef _MSC_VER
#include <intrin.h>
#endif

// Barrett multiplication as currently implemented (kept for comparison).
// @param a  `0 <= a < _m`
// @param b  `0 <= b < _m`
// @param _m `1 <= _m <= 2^31`; the final `_m <= v` borrow check is only
//           valid in this range.
// @param im `ceil(2^64 / _m) mod 2^64`
// @return `a * b % _m`
unsigned int barrett_mul_before(unsigned int a, unsigned int b, unsigned int _m,
                                unsigned long long im) {
    // [1] m = 1
    // a = b = im = 0, so okay
    // [2] m >= 2
    // im = ceil(2^64 / m)
    // -> im * m = 2^64 + r (0 <= r < m)
    // let z = a*b = c*m + d (0 <= c, d < m)
    // a*b * im = (c*m + d) * im = c*(im*m) + d*im = c*2^64 + c*r + d*im
    // c*r + d*im < m * m + m * im < m * m + 2^64 + m <= 2^64 + m * (m + 1) < 2^64 * 2
    // ((ab * im) >> 64) == c or c + 1
    unsigned long long z = a;
    z *= b;
#ifdef _MSC_VER
    unsigned long long x;
    _umul128(z, im, &x);
#else
    unsigned long long x = (unsigned long long)(((unsigned __int128)(z)*im) >> 64);
#endif
    unsigned int v = (unsigned int)(z - x * _m);
    // Borrow detected on the 32-bit truncation: 2^32 + d - m >= m only when m <= 2^31.
    if (_m <= v) v += _m;
    return v;
}

// Barrett multiplication with a 64-bit borrow check (proposed).
// @param a  `0 <= a < _m`
// @param b  `0 <= b < _m`
// @param _m `1 <= _m < 2^32`
// @param im `ceil(2^64 / _m) mod 2^64` (0 when _m == 1)
// @return `a * b % _m`
unsigned int barrett_mul_after(unsigned int a, unsigned int b, unsigned int _m,
                               unsigned long long im) {
    // [1] m = 1
    // a = b = im = 0, so okay
    // [2] m >= 2
    // im = ceil(2^64 / m)
    // -> im * m = 2^64 + r (0 <= r < m)
    // let z = a*b = c*m + d (0 <= c, d < m)
    // a*b * im = (c*m + d) * im = c*(im*m) + d*im = c*2^64 + c*r + d*im
    // c*r + d*im < m * m + m * im < m * m + 2^64 + m <= 2^64 + m * (m + 1) < 2^64 * 2
    // ((ab * im) >> 64) == c or c + 1, so z - x*m == d or d - m
    unsigned long long z = a;
    z *= b;
#ifdef _MSC_VER
    unsigned long long x;
    _umul128(z, im, &x);
#else
    unsigned long long x = (unsigned long long)(((unsigned __int128)(z)*im) >> 64);
#endif
    unsigned long long y = x * _m;
    // z < y exactly when x == c + 1; the wrapped difference plus _m is then d (mod 2^64).
    return (unsigned int)(z - y + (z < y ? _m : 0));
}

(Rust: $1\le m\lt 2^{32}$ draft code)

https://rust.godbolt.org/z/7P5rjahMn

/// Calculates `a * b % m` (current implementation, kept for comparison).
///
/// * `a` `0 <= a < m`
/// * `b` `0 <= b < m`
/// * `m` `1 <= m <= 2^31`
/// * `im` = ceil(2^64 / `m`)
#[allow(clippy::many_single_char_names)]
pub fn mul_mod_before(a: u32, b: u32, m: u32, im: u64) -> u32 {
    // [1] m = 1
    // a = b = im = 0, so okay
    // [2] m >= 2
    // im = ceil(2^64 / m)
    // -> im * m = 2^64 + r (0 <= r < m)
    // let z = a*b = c*m + d (0 <= c, d < m)
    // a*b * im = (c*m + d) * im = c*(im*m) + d*im = c*2^64 + c*r + d*im
    // c*r + d*im < m * m + m * im < m * m + 2^64 + m <= 2^64 + m * (m + 1) < 2^64 * 2
    // ((ab * im) >> 64) == c or c + 1
    let mut z = a as u64;
    z *= b as u64;
    let x = (((z as u128) * (im as u128)) >> 64) as u64;
    let mut v = z.wrapping_sub(x.wrapping_mul(m as u64)) as u32;
    // Borrow detected on the 32-bit truncation: 2^32 + d - m >= m only when m <= 2^31.
    if m <= v {
        v = v.wrapping_add(m);
    }
    v
}

/// Calculates `a * b % m` (proposed implementation).
///
/// * `a` `0 <= a < m`
/// * `b` `0 <= b < m`
/// * `m` `1 <= m < 2^32`
/// * `im` = ceil(2^64 / `m`) mod 2^64 = (floor((2^64 - 1) / `m`) + 1) mod 2^64
#[allow(clippy::many_single_char_names)]
pub fn mul_mod_after(a: u32, b: u32, m: u32, im: u64) -> u32 {
    // [1] m = 1
    // a = b = im = 0, so okay
    // [2] m >= 2
    // im = ceil(2^64 / m)
    // -> im * m = 2^64 + r (0 <= r < m)
    // let z = a*b = c*m + d (0 <= c, d < m)
    // a*b * im = (c*m + d) * im = c*(im*m) + d*im = c*2^64 + c*r + d*im
    // c*r + d*im < m * m + m * im < m * m + 2^64 + m <= 2^64 + m * (m + 1) < 2^64 * 2
    // ((ab * im) >> 64) == c or c + 1, so z - x*m == d or d - m
    let z = (a as u64) * (b as u64);
    let x = (((z as u128) * (im as u128)) >> 64) as u64;
    // The 64-bit borrow flag is set exactly in the d - m case, for any m < 2^32.
    match z.overflowing_sub(x.wrapping_mul(m as u64)) {
        (v, true) => (v as u32).wrapping_add(m),
        (v, false) => v as u32,
    }
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions