py/mpz: Fix mpn_div so that it doesn't modify memory of denominator.

Previous to this patch bignum division and modulo would temporarily
modify the RHS argument to the operation (eg x/y would modify y), but on
return the RHS would be restored to its original value.  This is not
allowed because arguments to binary operations are const, and in
particular might live in ROM.  The modification was to normalise the arg
(and then unnormalise before returning), and this patch makes it so the
normalisation is done on the fly and the arg is now accessed as read-only.

This change doesn't increase the order complexity of the operation, and
actually reduces code size.
This commit is contained in:
Damien George 2016-05-09 17:21:42 +01:00
parent de5e0ed2e0
commit 460b086333

View File

@ -454,10 +454,8 @@ STATIC mp_uint_t mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mpz_d
assumes num_dig has enough memory to be extended by 1 digit
assumes quo_dig has enough memory (as many digits as num)
assumes quo_dig is filled with zeros
modifies den_dig memory, but restors it to original state at end
*/
STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, mp_uint_t den_len, mpz_dig_t *quo_dig, mp_uint_t *quo_len) {
STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, const mpz_dig_t *den_dig, mp_uint_t den_len, mpz_dig_t *quo_dig, mp_uint_t *quo_len) {
mpz_dig_t *orig_num_dig = num_dig;
mpz_dig_t *orig_quo_dig = quo_dig;
mpz_dig_t norm_shift = 0;
@ -478,6 +476,11 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
}
}
// We need to normalise the denominator (leading bit of leading digit is 1)
// so that the division routine works. Since the denominator memory is
// read-only we do the normalisation on the fly, each time a digit of the
// denominator is needed. We need to know is how many bits to shift by.
// count number of leading zeros in leading digit of denominator
{
mpz_dig_t d = den_dig[den_len - 1];
@ -487,13 +490,6 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
}
}
// normalise denomenator (leading bit of leading digit is 1)
for (mpz_dig_t *den = den_dig, carry = 0; den < den_dig + den_len; ++den) {
mpz_dig_t d = *den;
*den = ((d << norm_shift) | carry) & DIG_MASK;
carry = (mpz_dbl_dig_t)d >> (DIG_SIZE - norm_shift);
}
// now need to shift numerator by same amount as denominator
// first, increase length of numerator in case we need more room to shift
num_dig[*num_len] = 0;
@ -505,7 +501,10 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
}
// cache the leading digit of the denominator
lead_den_digit = den_dig[den_len - 1];
lead_den_digit = (mpz_dbl_dig_t)den_dig[den_len - 1] << norm_shift;
if (den_len >= 2) {
lead_den_digit |= (mpz_dbl_dig_t)den_dig[den_len - 2] >> (DIG_SIZE - norm_shift);
}
// point num_dig to last digit in numerator
num_dig += *num_len - 1;
@ -540,10 +539,13 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
// round up).
if (DIG_SIZE < 8 * sizeof(mpz_dbl_dig_t) / 2) {
const mpz_dig_t *d = den_dig;
mpz_dbl_dig_t d_norm = 0;
mpz_dbl_dig_signed_t borrow = 0;
for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)*d; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK); // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
*n = borrow & DIG_MASK;
borrow >>= DIG_SIZE;
}
@ -553,9 +555,12 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
// adjust quotient if it is too big
for (; borrow != 0; --quo) {
d = den_dig;
d_norm = 0;
mpz_dbl_dig_t carry = 0;
for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d;
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK);
*n = carry & DIG_MASK;
carry >>= DIG_SIZE;
}
@ -566,10 +571,13 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
borrow += carry;
}
} else { // DIG_SIZE == 8 * sizeof(mpz_dbl_dig_t) / 2
const mpz_dig_t *d = den_dig;
mpz_dbl_dig_t d_norm = 0;
mpz_dbl_dig_t borrow = 0;
for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)(*d);
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK);
if (x >= *n || *n - x <= borrow) {
borrow += (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)*n;
*n = (-borrow) & DIG_MASK;
@ -590,9 +598,12 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
// adjust quotient if it is too big
for (; borrow != 0; --quo) {
d = den_dig;
d_norm = 0;
mpz_dbl_dig_t carry = 0;
for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d;
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK);
*n = carry & DIG_MASK;
carry >>= DIG_SIZE;
}
@ -614,13 +625,6 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
--(*num_len);
}
// unnormalise denomenator
for (mpz_dig_t *den = den_dig + den_len - 1, carry = 0; den >= den_dig; --den) {
mpz_dig_t d = *den;
*den = ((d >> norm_shift) | carry) & DIG_MASK;
carry = (mpz_dbl_dig_t)d << (DIG_SIZE - norm_shift);
}
// unnormalise numerator (remainder now)
for (mpz_dig_t *num = orig_num_dig + *num_len - 1, carry = 0; num >= orig_num_dig; --num) {
mpz_dig_t n = *num;
@ -1506,7 +1510,6 @@ void mpz_divmod_inpl(mpz_t *dest_quo, mpz_t *dest_rem, const mpz_t *lhs, const m
dest_quo->len = 0;
mpz_need_dig(dest_rem, lhs->len + 1); // +1 necessary?
mpz_set(dest_rem, lhs);
//rhs->dig[rhs->len] = 0;
mpn_div(dest_rem->dig, &dest_rem->len, rhs->dig, rhs->len, dest_quo->dig, &dest_quo->len);
// check signs and do Python style modulo