py/mpz: Fix mpn_div so that it doesn't modify memory of denominator.
Previous to this patch bignum division and modulo would temporarily modify the RHS argument to the operation (eg x/y would modify y), but on return the RHS would be restored to its original value. This is not allowed because arguments to binary operations are const, and in particular might live in ROM. The modification was to normalise the arg (and then unnormalise before returning), and this patch makes it so the normalisation is done on the fly and the arg is now accessed as read-only. This change doesn't increase the order complexity of the operation, and actually reduces code size.
This commit is contained in:
parent
de5e0ed2e0
commit
460b086333
57
py/mpz.c
57
py/mpz.c
@ -454,10 +454,8 @@ STATIC mp_uint_t mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mpz_d
|
||||
assumes num_dig has enough memory to be extended by 1 digit
|
||||
assumes quo_dig has enough memory (as many digits as num)
|
||||
assumes quo_dig is filled with zeros
|
||||
modifies den_dig memory, but restors it to original state at end
|
||||
*/
|
||||
|
||||
STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, mp_uint_t den_len, mpz_dig_t *quo_dig, mp_uint_t *quo_len) {
|
||||
STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, const mpz_dig_t *den_dig, mp_uint_t den_len, mpz_dig_t *quo_dig, mp_uint_t *quo_len) {
|
||||
mpz_dig_t *orig_num_dig = num_dig;
|
||||
mpz_dig_t *orig_quo_dig = quo_dig;
|
||||
mpz_dig_t norm_shift = 0;
|
||||
@ -478,6 +476,11 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
|
||||
}
|
||||
}
|
||||
|
||||
// We need to normalise the denominator (leading bit of leading digit is 1)
|
||||
// so that the division routine works. Since the denominator memory is
|
||||
// read-only we do the normalisation on the fly, each time a digit of the
|
||||
// denominator is needed. We need to know is how many bits to shift by.
|
||||
|
||||
// count number of leading zeros in leading digit of denominator
|
||||
{
|
||||
mpz_dig_t d = den_dig[den_len - 1];
|
||||
@ -487,13 +490,6 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
|
||||
}
|
||||
}
|
||||
|
||||
// normalise denomenator (leading bit of leading digit is 1)
|
||||
for (mpz_dig_t *den = den_dig, carry = 0; den < den_dig + den_len; ++den) {
|
||||
mpz_dig_t d = *den;
|
||||
*den = ((d << norm_shift) | carry) & DIG_MASK;
|
||||
carry = (mpz_dbl_dig_t)d >> (DIG_SIZE - norm_shift);
|
||||
}
|
||||
|
||||
// now need to shift numerator by same amount as denominator
|
||||
// first, increase length of numerator in case we need more room to shift
|
||||
num_dig[*num_len] = 0;
|
||||
@ -505,7 +501,10 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
|
||||
}
|
||||
|
||||
// cache the leading digit of the denominator
|
||||
lead_den_digit = den_dig[den_len - 1];
|
||||
lead_den_digit = (mpz_dbl_dig_t)den_dig[den_len - 1] << norm_shift;
|
||||
if (den_len >= 2) {
|
||||
lead_den_digit |= (mpz_dbl_dig_t)den_dig[den_len - 2] >> (DIG_SIZE - norm_shift);
|
||||
}
|
||||
|
||||
// point num_dig to last digit in numerator
|
||||
num_dig += *num_len - 1;
|
||||
@ -540,10 +539,13 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
|
||||
// round up).
|
||||
|
||||
if (DIG_SIZE < 8 * sizeof(mpz_dbl_dig_t) / 2) {
|
||||
const mpz_dig_t *d = den_dig;
|
||||
mpz_dbl_dig_t d_norm = 0;
|
||||
mpz_dbl_dig_signed_t borrow = 0;
|
||||
|
||||
for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
|
||||
borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)*d; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
|
||||
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
|
||||
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
|
||||
borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK); // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
|
||||
*n = borrow & DIG_MASK;
|
||||
borrow >>= DIG_SIZE;
|
||||
}
|
||||
@ -553,9 +555,12 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
|
||||
|
||||
// adjust quotient if it is too big
|
||||
for (; borrow != 0; --quo) {
|
||||
d = den_dig;
|
||||
d_norm = 0;
|
||||
mpz_dbl_dig_t carry = 0;
|
||||
for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
|
||||
carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d;
|
||||
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
|
||||
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
|
||||
carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK);
|
||||
*n = carry & DIG_MASK;
|
||||
carry >>= DIG_SIZE;
|
||||
}
|
||||
@ -566,10 +571,13 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
|
||||
borrow += carry;
|
||||
}
|
||||
} else { // DIG_SIZE == 8 * sizeof(mpz_dbl_dig_t) / 2
|
||||
const mpz_dig_t *d = den_dig;
|
||||
mpz_dbl_dig_t d_norm = 0;
|
||||
mpz_dbl_dig_t borrow = 0;
|
||||
|
||||
for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
|
||||
mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)(*d);
|
||||
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
|
||||
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
|
||||
mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK);
|
||||
if (x >= *n || *n - x <= borrow) {
|
||||
borrow += (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)*n;
|
||||
*n = (-borrow) & DIG_MASK;
|
||||
@ -590,9 +598,12 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
|
||||
|
||||
// adjust quotient if it is too big
|
||||
for (; borrow != 0; --quo) {
|
||||
d = den_dig;
|
||||
d_norm = 0;
|
||||
mpz_dbl_dig_t carry = 0;
|
||||
for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
|
||||
carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d;
|
||||
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
|
||||
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
|
||||
carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK);
|
||||
*n = carry & DIG_MASK;
|
||||
carry >>= DIG_SIZE;
|
||||
}
|
||||
@ -614,13 +625,6 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
|
||||
--(*num_len);
|
||||
}
|
||||
|
||||
// unnormalise denomenator
|
||||
for (mpz_dig_t *den = den_dig + den_len - 1, carry = 0; den >= den_dig; --den) {
|
||||
mpz_dig_t d = *den;
|
||||
*den = ((d >> norm_shift) | carry) & DIG_MASK;
|
||||
carry = (mpz_dbl_dig_t)d << (DIG_SIZE - norm_shift);
|
||||
}
|
||||
|
||||
// unnormalise numerator (remainder now)
|
||||
for (mpz_dig_t *num = orig_num_dig + *num_len - 1, carry = 0; num >= orig_num_dig; --num) {
|
||||
mpz_dig_t n = *num;
|
||||
@ -1506,7 +1510,6 @@ void mpz_divmod_inpl(mpz_t *dest_quo, mpz_t *dest_rem, const mpz_t *lhs, const m
|
||||
dest_quo->len = 0;
|
||||
mpz_need_dig(dest_rem, lhs->len + 1); // +1 necessary?
|
||||
mpz_set(dest_rem, lhs);
|
||||
//rhs->dig[rhs->len] = 0;
|
||||
mpn_div(dest_rem->dig, &dest_rem->len, rhs->dig, rhs->len, dest_quo->dig, &dest_quo->len);
|
||||
|
||||
// check signs and do Python style modulo
|
||||
|
Loading…
Reference in New Issue
Block a user