py: Implement bit-shift and not operations for mpz.

Implement not, shl and shr in mpz library.  Add function to create mpzs
on the stack, used for memory efficiency when rhs is a small int.
Factor out code to parse base-prefix of number into a dedicated function.
This commit is contained in:
Damien George 2014-03-01 19:50:50 +00:00
parent 793838a919
commit 06201ff3d6
9 changed files with 273 additions and 135 deletions

203
py/mpz.c
View File

@ -10,19 +10,27 @@
#if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ #if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ
#define DIG_SIZE (15) #define DIG_SIZE (MPZ_DIG_SIZE)
#define DIG_MASK ((1 << DIG_SIZE) - 1) #define DIG_MASK ((1 << DIG_SIZE) - 1)
/* /*
definition of normalise: mpz is an arbitrary precision integer type with a public API.
?
mpn functions act on non-negative integers represented by an array of generalised
digits (eg a word per digit). You also need to specify separately the length of the
array. There is no public API for mpn. Rather, the functions are used by mpz to
implement its features.
Integer values are stored little endian (first digit is first in memory).
Definition of normalise: ?
*/ */
/* compares i with j /* compares i with j
returns sign(i - j) returns sign(i - j)
assumes i, j are normalised assumes i, j are normalised
*/ */
int mpn_cmp(const mpz_dig_t *idig, uint ilen, const mpz_dig_t *jdig, uint jlen) { STATIC int mpn_cmp(const mpz_dig_t *idig, uint ilen, const mpz_dig_t *jdig, uint jlen) {
if (ilen < jlen) { return -1; } if (ilen < jlen) { return -1; }
if (ilen > jlen) { return 1; } if (ilen > jlen) { return 1; }
@ -37,39 +45,46 @@ int mpn_cmp(const mpz_dig_t *idig, uint ilen, const mpz_dig_t *jdig, uint jlen)
/* computes i = j << n /* computes i = j << n
returns number of digits in i returns number of digits in i
assumes enough memory in i; assumes normalised j assumes enough memory in i; assumes normalised j; assumes n > 0
can have i, j pointing to same memory can have i, j pointing to same memory
*/ */
/* unfinished STATIC uint mpn_shl(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) {
uint mpn_shl(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) { uint n_whole = (n + DIG_SIZE - 1) / DIG_SIZE;
uint n_whole = n / DIG_SIZE;
uint n_part = n % DIG_SIZE; uint n_part = n % DIG_SIZE;
idig += jlen + n_whole + 1; // start from the high end of the digit arrays
idig += jlen + n_whole - 1;
jdig += jlen - 1;
for (uint i = jlen; i > 0; --i, ++idig, ++jdig) { // shift the digits
mpz_dbl_dig_t d = *jdig; mpz_dbl_dig_t d = 0;
if (i > 1) { for (uint i = jlen; i > 0; i--, idig--, jdig--) {
d |= jdig[1] << DIG_SIZE; d |= *jdig;
} *idig = d >> (DIG_SIZE - n_part);
d <<= n_part; d <<= DIG_SIZE;
*idig = d & DIG_MASK;
} }
if (idig[-1] == 0) { // store remaining bits
--jlen; *idig = d >> (DIG_SIZE - n_part);
idig -= n_whole - 1;
memset(idig, 0, n_whole - 1);
// work out length of result
jlen += n_whole;
if (idig[jlen - 1] == 0) {
jlen--;
} }
// return length of result
return jlen; return jlen;
} }
*/
/* computes i = j >> n /* computes i = j >> n
returns number of digits in i returns number of digits in i
assumes enough memory in i; assumes normalised j assumes enough memory in i; assumes normalised j; assumes n > 0
can have i, j pointing to same memory can have i, j pointing to same memory
*/ */
uint mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) { STATIC uint mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) {
uint n_whole = n / DIG_SIZE; uint n_whole = n / DIG_SIZE;
uint n_part = n % DIG_SIZE; uint n_part = n % DIG_SIZE;
@ -80,7 +95,7 @@ uint mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) {
jdig += n_whole; jdig += n_whole;
jlen -= n_whole; jlen -= n_whole;
for (uint i = jlen; i > 0; --i, ++idig, ++jdig) { for (uint i = jlen; i > 0; i--, idig++, jdig++) {
mpz_dbl_dig_t d = *jdig; mpz_dbl_dig_t d = *jdig;
if (i > 1) { if (i > 1) {
d |= jdig[1] << DIG_SIZE; d |= jdig[1] << DIG_SIZE;
@ -90,7 +105,7 @@ uint mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) {
} }
if (idig[-1] == 0) { if (idig[-1] == 0) {
--jlen; jlen--;
} }
return jlen; return jlen;
@ -101,7 +116,7 @@ uint mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) {
assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
can have i, j, k pointing to same memory can have i, j, k pointing to same memory
*/ */
uint mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t *kdig, uint klen) { STATIC uint mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t *kdig, uint klen) {
mpz_dig_t *oidig = idig; mpz_dig_t *oidig = idig;
mpz_dbl_dig_t carry = 0; mpz_dbl_dig_t carry = 0;
@ -131,7 +146,7 @@ uint mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t
assumes enough memory in i; assumes normalised j, k; assumes j >= k assumes enough memory in i; assumes normalised j, k; assumes j >= k
can have i, j, k pointing to same memory can have i, j, k pointing to same memory
*/ */
uint mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t *kdig, uint klen) { STATIC uint mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t *kdig, uint klen) {
mpz_dig_t *oidig = idig; mpz_dig_t *oidig = idig;
mpz_dbl_dig_signed_t borrow = 0; mpz_dbl_dig_signed_t borrow = 0;
@ -159,7 +174,7 @@ uint mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t
returns number of digits in i returns number of digits in i
assumes enough memory in i; assumes normalised i; assumes dmul != 0 assumes enough memory in i; assumes normalised i; assumes dmul != 0
*/ */
uint mpn_mul_dig_add_dig(mpz_dig_t *idig, uint ilen, mpz_dig_t dmul, mpz_dig_t dadd) { STATIC uint mpn_mul_dig_add_dig(mpz_dig_t *idig, uint ilen, mpz_dig_t dmul, mpz_dig_t dadd) {
mpz_dig_t *oidig = idig; mpz_dig_t *oidig = idig;
mpz_dbl_dig_t carry = dadd; mpz_dbl_dig_t carry = dadd;
@ -181,7 +196,7 @@ uint mpn_mul_dig_add_dig(mpz_dig_t *idig, uint ilen, mpz_dig_t dmul, mpz_dig_t d
assumes enough memory in i; assumes i is zeroed; assumes normalised j, k assumes enough memory in i; assumes i is zeroed; assumes normalised j, k
can have j, k point to same memory can have j, k point to same memory
*/ */
uint mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, mpz_dig_t *kdig, uint klen) { STATIC uint mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, mpz_dig_t *kdig, uint klen) {
mpz_dig_t *oidig = idig; mpz_dig_t *oidig = idig;
uint ilen = 0; uint ilen = 0;
@ -214,7 +229,7 @@ uint mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, mpz_dig_t *kdig, uint
modifies den_dig memory, but restores it to original state at end modifies den_dig memory, but restores it to original state at end
*/ */
void mpn_div(mpz_dig_t *num_dig, machine_uint_t *num_len, mpz_dig_t *den_dig, machine_uint_t den_len, mpz_dig_t *quo_dig, machine_uint_t *quo_len) { STATIC void mpn_div(mpz_dig_t *num_dig, machine_uint_t *num_len, mpz_dig_t *den_dig, machine_uint_t den_len, mpz_dig_t *quo_dig, machine_uint_t *quo_len) {
mpz_dig_t *orig_num_dig = num_dig; mpz_dig_t *orig_num_dig = num_dig;
mpz_dig_t *orig_quo_dig = quo_dig; mpz_dig_t *orig_quo_dig = quo_dig;
mpz_dig_t norm_shift = 0; mpz_dig_t norm_shift = 0;
@ -343,9 +358,7 @@ void mpn_div(mpz_dig_t *num_dig, machine_uint_t *num_len, mpz_dig_t *den_dig, ma
} }
} }
#define MIN_ALLOC (4) #define MIN_ALLOC (2)
#define ALIGN_ALLOC (2)
#define NUM_DIG_FOR_INT (sizeof(machine_int_t) * 8 / DIG_SIZE + 1)
static const uint log_base2_floor[] = { static const uint log_base2_floor[] = {
0, 0,
@ -359,13 +372,10 @@ static const uint log_base2_floor[] = {
4, 4, 4, 5 4, 4, 4, 5
}; };
bool mpz_int_is_sml_int(machine_int_t i) {
return -(1 << DIG_SIZE) < i && i < (1 << DIG_SIZE);
}
void mpz_init_zero(mpz_t *z) { void mpz_init_zero(mpz_t *z) {
z->alloc = 0;
z->neg = 0; z->neg = 0;
z->fixed_dig = 0;
z->alloc = 0;
z->len = 0; z->len = 0;
z->dig = NULL; z->dig = NULL;
} }
@ -375,8 +385,17 @@ void mpz_init_from_int(mpz_t *z, machine_int_t val) {
mpz_set_from_int(z, val); mpz_set_from_int(z, val);
} }
// Initialise z to hold val, using digit storage supplied by the caller.
// dig/alloc describe that storage; fixed_dig is set so that mpz_need_dig will
// not reallocate the buffer and mpz_deinit will not free it.
void mpz_init_fixed_from_int(mpz_t *z, mpz_dig_t *dig, uint alloc, machine_int_t val) {
    // attach the caller's buffer and flag it as fixed
    z->dig = dig;
    z->alloc = alloc;
    z->fixed_dig = 1;

    // begin as an empty, non-negative value ...
    z->len = 0;
    z->neg = 0;

    // ... then load the requested integer into the fixed buffer
    mpz_set_from_int(z, val);
}
void mpz_deinit(mpz_t *z) { void mpz_deinit(mpz_t *z) {
if (z != NULL) { if (z != NULL && !z->fixed_dig) {
m_del(mpz_dig_t, z->dig, z->alloc); m_del(mpz_dig_t, z->dig, z->alloc);
} }
} }
@ -407,23 +426,26 @@ void mpz_free(mpz_t *z) {
} }
STATIC void mpz_need_dig(mpz_t *z, uint need) { STATIC void mpz_need_dig(mpz_t *z, uint need) {
uint alloc;
if (need < MIN_ALLOC) { if (need < MIN_ALLOC) {
alloc = MIN_ALLOC; need = MIN_ALLOC;
} else {
alloc = (need + ALIGN_ALLOC) & (~(ALIGN_ALLOC - 1));
} }
if (z->dig == NULL || z->alloc < alloc) { if (z->dig == NULL || z->alloc < need) {
z->dig = m_renew(mpz_dig_t, z->dig, z->alloc, alloc); if (z->fixed_dig) {
z->alloc = alloc; // cannot reallocate fixed buffers
assert(0);
return;
}
z->dig = m_renew(mpz_dig_t, z->dig, z->alloc, need);
z->alloc = need;
} }
} }
mpz_t *mpz_clone(const mpz_t *src) { mpz_t *mpz_clone(const mpz_t *src) {
mpz_t *z = m_new_obj(mpz_t); mpz_t *z = m_new_obj(mpz_t);
z->alloc = src->alloc;
z->neg = src->neg; z->neg = src->neg;
z->fixed_dig = 0;
z->alloc = src->alloc;
z->len = src->len; z->len = src->len;
if (src->dig == NULL) { if (src->dig == NULL) {
z->dig = NULL; z->dig = NULL;
@ -434,6 +456,9 @@ mpz_t *mpz_clone(const mpz_t *src) {
return z; return z;
} }
/* sets dest = src
can have dest, src the same
*/
void mpz_set(mpz_t *dest, const mpz_t *src) { void mpz_set(mpz_t *dest, const mpz_t *src) {
mpz_need_dig(dest, src->len); mpz_need_dig(dest, src->len);
dest->neg = src->neg; dest->neg = src->neg;
@ -442,7 +467,7 @@ void mpz_set(mpz_t *dest, const mpz_t *src) {
} }
void mpz_set_from_int(mpz_t *z, machine_int_t val) { void mpz_set_from_int(mpz_t *z, machine_int_t val) {
mpz_need_dig(z, NUM_DIG_FOR_INT); mpz_need_dig(z, MPZ_NUM_DIG_FOR_INT);
if (val < 0) { if (val < 0) {
z->neg = 1; z->neg = 1;
@ -527,6 +552,9 @@ int mpz_cmp(const mpz_t *z1, const mpz_t *z2) {
return cmp; return cmp;
} }
#if 0
// obsolete
// compares mpz with an integer that fits within DIG_SIZE bits
int mpz_cmp_sml_int(const mpz_t *z, machine_int_t sml_int) { int mpz_cmp_sml_int(const mpz_t *z, machine_int_t sml_int) {
int cmp; int cmp;
if (z->neg == 0) { if (z->neg == 0) {
@ -554,6 +582,7 @@ int mpz_cmp_sml_int(const mpz_t *z, machine_int_t sml_int) {
if (cmp > 0) return 1; if (cmp > 0) return 1;
return 0; return 0;
} }
#endif
#if 0 #if 0
these functions are unused these functions are unused
@ -631,50 +660,71 @@ void mpz_neg_inpl(mpz_t *dest, const mpz_t *z) {
dest->neg = 1 - dest->neg; dest->neg = 1 - dest->neg;
} }
#if 0 /* computes dest = ~z (= -z - 1)
not finished can have dest, z the same
*/
void mpz_not_inpl(mpz_t *dest, const mpz_t *z) {
    if (dest != z) {
        mpz_set(dest, z);
    }
    if (dest->len == 0) {
        // ~0 = -1; zero has no digits, so it cannot go through mpn_add/mpn_sub,
        // which require the first operand to be at least as long as the second
        mpz_need_dig(dest, 1);
        dest->dig[0] = 1;
        dest->len = 1;
        dest->neg = 1;
    } else if (dest->neg) {
        // dest = -m with m >= 1, so ~dest = m - 1, which is non-negative
        mpz_dig_t k = 1;
        dest->neg = 0;
        dest->len = mpn_sub(dest->dig, dest->dig, dest->len, &k, 1);
    } else {
        // dest = m with m >= 1, so ~dest = -(m + 1)
        // mpn_add assumes the output has enough memory, and the increment can
        // carry into a new top digit, so reserve one extra digit first
        mpz_need_dig(dest, dest->len + 1);
        mpz_dig_t k = 1;
        dest->len = mpn_add(dest->dig, dest->dig, dest->len, &k, 1);
        dest->neg = 1;
    }
}
/* computes dest = lhs << rhs /* computes dest = lhs << rhs
can have dest, lhs the same can have dest, lhs the same
*/ */
void mpz_shl_inpl(mpz_t *dest, const mpz_t *lhs, machine_int_t rhs) { void mpz_shl_inpl(mpz_t *dest, const mpz_t *lhs, machine_int_t rhs) {
if (dest != lhs) { if (lhs->len == 0 || rhs == 0) {
mpz_set(dest, lhs); mpz_set(dest, lhs);
} } else if (rhs < 0) {
mpz_shr_inpl(dest, lhs, -rhs);
if (dest.len == 0 || rhs == 0) {
return dest;
}
if (rhs < 0) {
dest->len = mpn_shr(dest->len, dest->dig, -rhs);
} else { } else {
dest->len = mpn_shl(dest->len, dest->dig, rhs); mpz_need_dig(dest, lhs->len + (rhs + DIG_SIZE - 1) / DIG_SIZE);
dest->len = mpn_shl(dest->dig, lhs->dig, lhs->len, rhs);
dest->neg = lhs->neg;
} }
return dest;
} }
/* computes dest = lhs >> rhs /* computes dest = lhs >> rhs
can have dest, lhs the same can have dest, lhs the same
*/ */
void mpz_shr_inpl(mpz_t *dest, const mpz_t *lhs, machine_int_t rhs) { void mpz_shr_inpl(mpz_t *dest, const mpz_t *lhs, machine_int_t rhs) {
if (dest != lhs) { if (lhs->len == 0 || rhs == 0) {
mpz_set(dest, lhs); mpz_set(dest, lhs);
} } else if (rhs < 0) {
mpz_shl_inpl(dest, lhs, -rhs);
if (dest.len == 0 || rhs == 0) {
return dest;
}
if (rhs < 0) {
dest->len = mpn_shl(dest->len, dest->dig, -rhs);
} else { } else {
dest->len = mpn_shr(dest->len, dest->dig, rhs); mpz_need_dig(dest, lhs->len);
dest->len = mpn_shr(dest->dig, lhs->dig, lhs->len, rhs);
dest->neg = lhs->neg;
if (dest->neg) {
// arithmetic shift right, rounding to negative infinity
uint n_whole = rhs / DIG_SIZE;
uint n_part = rhs % DIG_SIZE;
mpz_dig_t round_up = 0;
for (uint i = 0; i < lhs->len && i < n_whole; i++) {
if (lhs->dig[i] != 0) {
round_up = 1;
break;
}
}
if (n_whole < lhs->len && (lhs->dig[n_whole] & ((1 << n_part) - 1)) != 0) {
round_up = 1;
}
if (round_up) {
dest->len = mpn_add(dest->dig, dest->dig, dest->len, &round_up, 1);
}
}
} }
return dest;
} }
#endif
/* computes dest = lhs + rhs /* computes dest = lhs + rhs
can have dest, lhs, rhs the same can have dest, lhs, rhs the same
@ -931,12 +981,11 @@ machine_int_t mpz_as_int(const mpz_t *i) {
machine_int_t val = 0; machine_int_t val = 0;
mpz_dig_t *d = i->dig + i->len; mpz_dig_t *d = i->dig + i->len;
while (--d >= i->dig) while (--d >= i->dig) {
{
machine_int_t oldval = val; machine_int_t oldval = val;
val = (val << DIG_SIZE) | *d; val = (val << DIG_SIZE) | *d;
if (val < oldval) if (val < oldval) {
{ // TODO need better handling of conversion overflow
if (i->neg == 0) { if (i->neg == 0) {
return 0x7fffffff; return 0x7fffffff;
} else { } else {

View File

@ -4,15 +4,21 @@ typedef int32_t mpz_dbl_dig_signed_t;
typedef struct _mpz_t { typedef struct _mpz_t {
machine_uint_t neg : 1; machine_uint_t neg : 1;
machine_uint_t alloc : 31; machine_uint_t fixed_dig : 1;
machine_uint_t alloc : 30;
machine_uint_t len; machine_uint_t len;
mpz_dig_t *dig; mpz_dig_t *dig;
} mpz_t; } mpz_t;
bool mpz_int_is_sml_int(machine_int_t i); #define MPZ_DIG_SIZE (15) // see mpn_div for why this needs to be at most 15
#define MPZ_NUM_DIG_FOR_INT (sizeof(machine_int_t) * 8 / MPZ_DIG_SIZE + 1)
// convenience macro to declare an mpz with a digit array from the stack, initialised by an integer
// declares two names in the enclosing scope: the mpz itself (z) and its backing array (z with "_digits" pasted on)
// the same pasted name must be used in the init call, otherwise the macro only compiles when the argument is literally "z"
#define MPZ_CONST_INT(z, val) mpz_t z; mpz_dig_t z ## _digits[MPZ_NUM_DIG_FOR_INT]; mpz_init_fixed_from_int(&z, z ## _digits, MPZ_NUM_DIG_FOR_INT, val);
void mpz_init_zero(mpz_t *z); void mpz_init_zero(mpz_t *z);
void mpz_init_from_int(mpz_t *z, machine_int_t val); void mpz_init_from_int(mpz_t *z, machine_int_t val);
void mpz_init_fixed_from_int(mpz_t *z, mpz_dig_t *dig, uint dig_alloc, machine_int_t val);
void mpz_deinit(mpz_t *z); void mpz_deinit(mpz_t *z);
mpz_t *mpz_zero(); mpz_t *mpz_zero();
@ -33,7 +39,6 @@ bool mpz_is_odd(const mpz_t *z);
bool mpz_is_even(const mpz_t *z); bool mpz_is_even(const mpz_t *z);
int mpz_cmp(const mpz_t *lhs, const mpz_t *rhs); int mpz_cmp(const mpz_t *lhs, const mpz_t *rhs);
int mpz_cmp_sml_int(const mpz_t *lhs, machine_int_t sml_int);
mpz_t *mpz_abs(const mpz_t *z); mpz_t *mpz_abs(const mpz_t *z);
mpz_t *mpz_neg(const mpz_t *z); mpz_t *mpz_neg(const mpz_t *z);
@ -44,8 +49,9 @@ mpz_t *mpz_pow(const mpz_t *lhs, const mpz_t *rhs);
void mpz_abs_inpl(mpz_t *dest, const mpz_t *z); void mpz_abs_inpl(mpz_t *dest, const mpz_t *z);
void mpz_neg_inpl(mpz_t *dest, const mpz_t *z); void mpz_neg_inpl(mpz_t *dest, const mpz_t *z);
//void mpz_shl_inpl(mpz_t *dest, const mpz_t *lhs, machine_int_t rhs); void mpz_not_inpl(mpz_t *dest, const mpz_t *z);
//void mpz_shr_inpl(mpz_t *dest, const mpz_t *lhs, machine_int_t rhs); void mpz_shl_inpl(mpz_t *dest, const mpz_t *lhs, machine_int_t rhs);
void mpz_shr_inpl(mpz_t *dest, const mpz_t *lhs, machine_int_t rhs);
void mpz_add_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs); void mpz_add_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs);
void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs); void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs);
void mpz_mul_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs); void mpz_mul_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs);

View File

@ -7,6 +7,7 @@
#include "misc.h" #include "misc.h"
#include "mpconfig.h" #include "mpconfig.h"
#include "qstr.h" #include "qstr.h"
#include "parsenumbase.h"
#include "obj.h" #include "obj.h"
#include "mpz.h" #include "mpz.h"
#include "objint.h" #include "objint.h"
@ -39,17 +40,20 @@ mp_obj_t int_unary_op(int op, mp_obj_t o_in) {
case RT_UNARY_OP_BOOL: return MP_BOOL(!mpz_is_zero(&o->mpz)); case RT_UNARY_OP_BOOL: return MP_BOOL(!mpz_is_zero(&o->mpz));
case RT_UNARY_OP_POSITIVE: return o_in; case RT_UNARY_OP_POSITIVE: return o_in;
case RT_UNARY_OP_NEGATIVE: { mp_obj_int_t *o2 = mp_obj_int_new_mpz(); mpz_neg_inpl(&o2->mpz, &o->mpz); return o2; } case RT_UNARY_OP_NEGATIVE: { mp_obj_int_t *o2 = mp_obj_int_new_mpz(); mpz_neg_inpl(&o2->mpz, &o->mpz); return o2; }
//case RT_UNARY_OP_INVERT: ~ not implemented for mpz case RT_UNARY_OP_INVERT: { mp_obj_int_t *o2 = mp_obj_int_new_mpz(); mpz_not_inpl(&o2->mpz, &o->mpz); return o2; }
default: return NULL; // op not supported default: return NULL; // op not supported
} }
} }
mp_obj_t int_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { mp_obj_t int_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
mpz_t *zlhs = &((mp_obj_int_t*)lhs_in)->mpz; mpz_t *zlhs = &((mp_obj_int_t*)lhs_in)->mpz;
mpz_t *zrhs; const mpz_t *zrhs;
mpz_t z_int;
mpz_dig_t z_int_dig[MPZ_NUM_DIG_FOR_INT];
if (MP_OBJ_IS_SMALL_INT(rhs_in)) { if (MP_OBJ_IS_SMALL_INT(rhs_in)) {
zrhs = mpz_from_int(MP_OBJ_SMALL_INT_VALUE(rhs_in)); mpz_init_fixed_from_int(&z_int, z_int_dig, MPZ_NUM_DIG_FOR_INT, MP_OBJ_SMALL_INT_VALUE(rhs_in));
zrhs = &z_int;
} else if (MP_OBJ_IS_TYPE(rhs_in, &int_type)) { } else if (MP_OBJ_IS_TYPE(rhs_in, &int_type)) {
zrhs = &((mp_obj_int_t*)rhs_in)->mpz; zrhs = &((mp_obj_int_t*)rhs_in)->mpz;
} else { } else {
@ -95,10 +99,22 @@ mp_obj_t int_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
//case RT_BINARY_OP_XOR: //case RT_BINARY_OP_XOR:
//case RT_BINARY_OP_INPLACE_XOR: //case RT_BINARY_OP_INPLACE_XOR:
//case RT_BINARY_OP_LSHIFT: case RT_BINARY_OP_LSHIFT:
//case RT_BINARY_OP_INPLACE_LSHIFT: case RT_BINARY_OP_INPLACE_LSHIFT:
//case RT_BINARY_OP_RSHIFT: case RT_BINARY_OP_RSHIFT:
//case RT_BINARY_OP_INPLACE_RSHIFT: case RT_BINARY_OP_INPLACE_RSHIFT: {
// TODO check conversion overflow
machine_int_t irhs = mpz_as_int(zrhs);
if (irhs < 0) {
nlr_jump(mp_obj_new_exception_msg(&mp_type_ValueError, "negative shift count"));
}
if (op == RT_BINARY_OP_LSHIFT || op == RT_BINARY_OP_INPLACE_LSHIFT) {
mpz_shl_inpl(&res->mpz, zlhs, irhs);
} else {
mpz_shr_inpl(&res->mpz, zlhs, irhs);
}
break;
}
case RT_BINARY_OP_POWER: case RT_BINARY_OP_POWER:
case RT_BINARY_OP_INPLACE_POWER: case RT_BINARY_OP_INPLACE_POWER:
@ -158,7 +174,11 @@ mp_obj_t mp_obj_new_int_from_uint(machine_uint_t value) {
mp_obj_t mp_obj_new_int_from_long_str(const char *str) { mp_obj_t mp_obj_new_int_from_long_str(const char *str) {
mp_obj_int_t *o = mp_obj_int_new_mpz(); mp_obj_int_t *o = mp_obj_int_new_mpz();
uint len = strlen(str); uint len = strlen(str);
uint n = mpz_set_from_str(&o->mpz, str, len, false, 10); int base = 0;
int skip = mp_parse_num_base(str, len, &base);
str += skip;
len -= skip;
uint n = mpz_set_from_str(&o->mpz, str, len, false, base);
if (n != len) { if (n != len) {
nlr_jump(mp_obj_new_exception_msg(&mp_type_SyntaxError, "invalid syntax for number")); nlr_jump(mp_obj_new_exception_msg(&mp_type_SyntaxError, "invalid syntax for number"));
} }

View File

@ -10,6 +10,7 @@
#include "mpconfig.h" #include "mpconfig.h"
#include "qstr.h" #include "qstr.h"
#include "lexer.h" #include "lexer.h"
#include "parsenumbase.h"
#include "parse.h" #include "parse.h"
#define RULE_ACT_KIND_MASK (0xf0) #define RULE_ACT_KIND_MASK (0xf0)
@ -241,23 +242,8 @@ STATIC void push_result_token(parser_t *parser, const mp_lexer_t *lex) {
machine_int_t int_val = 0; machine_int_t int_val = 0;
int len = tok->len; int len = tok->len;
const char *str = tok->str; const char *str = tok->str;
int base = 10; int base = 0;
int i = 0; int i = mp_parse_num_base(str, len, &base);
if (len >= 3 && str[0] == '0') {
if (str[1] == 'o' || str[1] == 'O') {
// octal
base = 8;
i = 2;
} else if (str[1] == 'x' || str[1] == 'X') {
// hexadecimal
base = 16;
i = 2;
} else if (str[1] == 'b' || str[1] == 'B') {
// binary
base = 2;
i = 2;
}
}
bool overflow = false; bool overflow = false;
for (; i < len; i++) { for (; i < len; i++) {
machine_int_t old_val = int_val; machine_int_t old_val = int_val;

View File

@ -5,6 +5,7 @@
#include "qstr.h" #include "qstr.h"
#include "nlr.h" #include "nlr.h"
#include "obj.h" #include "obj.h"
#include "parsenumbase.h"
#include "parsenum.h" #include "parsenum.h"
#if defined(UNIX) #if defined(UNIX)
@ -33,38 +34,15 @@ mp_obj_t mp_parse_num_integer(const char *restrict str, uint len, int base) {
// preceding sign // preceding sign
if (c == '+' || c == '-') { if (c == '+' || c == '-') {
neg = - (c == '-'); neg = - (c == '-');
c = *(p++);
}
// find real radix base, and strip preced '0x', '0o' and '0b'
// TODO somehow merge with similar code in parse.c
if ((base == 0 || base == 16) && c == '0') {
c = *(p++);
if ((c | 32) == 'x') {
base = 16;
} else if (base == 0 && (c | 32) == 'o') {
base = 8;
} else if (base == 0 && (c | 32) == 'b') {
base = 2;
} else { } else {
base = 10;
p -= 2;
}
} else if (base == 8 && c == '0') {
c = *(p++);
if ((c | 32) != 'o') {
p -= 2;
}
} else if (base == 2 && c == '0') {
c = *(p++);
if ((c | 32) != 'b') {
p -= 2;
}
} else {
if (base == 0) base = 10;
p--; p--;
} }
len -= p - str;
int skip = mp_parse_num_base(p, len, &base);
p += skip;
len -= skip;
errno = 0; errno = 0;
found = strtol(p, &num, base); found = strtol(p, &num, base);
if (errno) { if (errno) {

40
py/parsenumbase.c Normal file
View File

@ -0,0 +1,40 @@
#include "misc.h"
#include "mpconfig.h"
#include "parsenumbase.h"
// find real radix base, and strip preceding '0x', '0o' and '0b'
// puts base in *base, and returns number of bytes to skip the prefix
//
// str/len: the digits of the number (any sign must already be stripped)
// *base:   on entry the requested base, or 0 to auto-detect from the prefix;
//          on exit the base to use (0 is replaced by the detected base, or 10)
// returns: 2 if a prefix matching the base was found, otherwise 0
//
// A prefix is only recognised when at least two characters are available, so
// no byte beyond str[len - 1] is ever read.  An explicitly requested base is
// never overridden: with *base == 16 the string "0123" stays hexadecimal.
int mp_parse_num_base(const char *str, uint len, int *base) {
    int skip = 0;

    if (len >= 2 && str[0] == '0') {
        // OR-ing with 32 folds an ASCII upper-case letter onto lower case
        int marker = str[1] | 32;
        if ((*base == 0 || *base == 16) && marker == 'x') {
            *base = 16;
            skip = 2;
        } else if ((*base == 0 || *base == 8) && marker == 'o') {
            *base = 8;
            skip = 2;
        } else if ((*base == 0 || *base == 2) && marker == 'b') {
            *base = 2;
            skip = 2;
        }
    }

    // auto-detect without a recognised prefix means decimal
    if (*base == 0) {
        *base = 10;
    }

    return skip;
}

1
py/parsenumbase.h Normal file
View File

@ -0,0 +1 @@
int mp_parse_num_base(const char *str, uint len, int *base);

View File

@ -32,6 +32,7 @@ PY_O_BASENAME = \
asmthumb.o \ asmthumb.o \
emitnthumb.o \ emitnthumb.o \
emitinlinethumb.o \ emitinlinethumb.o \
parsenumbase.o \
parsenum.o \ parsenum.o \
runtime.o \ runtime.o \
map.o \ map.o \

57
tests/basics/int-mpz.py Normal file
View File

@ -0,0 +1,57 @@
# to test arbitrary precision integers
x = 1000000000000000000000000000000
y = 2000000000000000000000000000000
# printing
print(x)
print(y)
# addition
print(x + 1)
print(x + y)
# subtraction
print(x - 1)
print(x - y)
print(y - x)
# multiplication
print(x * 2)
print(x * y)
# integer division
print(x // 2)
print(y // x)
# bit inversion
print(~x)
print(~(-x))
# left shift
x = 0x10000000000000000000000
for i in range(32):
x = x << 1
print(x)
# right shift
x = 0x10000000000000000000000
for i in range(32):
x = x >> 1
print(x)
# left shift of a negative number
for i in range(8):
print(-10000000000000000000000000 << i)
print(-10000000000000000000000001 << i)
print(-10000000000000000000000002 << i)
print(-10000000000000000000000003 << i)
print(-10000000000000000000000004 << i)
# right shift of a negative number
for i in range(8):
print(-10000000000000000000000000 >> i)
print(-10000000000000000000000001 >> i)
print(-10000000000000000000000002 >> i)
print(-10000000000000000000000003 >> i)
print(-10000000000000000000000004 >> i)