From 06201ff3d6d9485b2657fc9ac4aa8a306884322f Mon Sep 17 00:00:00 2001 From: Damien George Date: Sat, 1 Mar 2014 19:50:50 +0000 Subject: [PATCH] py: Implement bit-shift and not operations for mpz. Implement not, shl and shr in mpz library. Add function to create mpzs on the stack, used for memory efficiency when rhs is a small int. Factor out code to parse base-prefix of number into a dedicated function. --- py/mpz.c | 203 +++++++++++++++++++++++++--------------- py/mpz.h | 16 +++- py/objint_mpz.c | 36 +++++-- py/parse.c | 20 +--- py/parsenum.c | 34 ++----- py/parsenumbase.c | 40 ++++++++ py/parsenumbase.h | 1 + py/py.mk | 1 + tests/basics/int-mpz.py | 57 +++++++++++ 9 files changed, 273 insertions(+), 135 deletions(-) create mode 100644 py/parsenumbase.c create mode 100644 py/parsenumbase.h create mode 100644 tests/basics/int-mpz.py diff --git a/py/mpz.c b/py/mpz.c index 8dc6b37b9b..6e54d154fe 100644 --- a/py/mpz.c +++ b/py/mpz.c @@ -10,19 +10,27 @@ #if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ -#define DIG_SIZE (15) +#define DIG_SIZE (MPZ_DIG_SIZE) #define DIG_MASK ((1 << DIG_SIZE) - 1) /* - definition of normalise: - ? + mpz is an arbitrary precision integer type with a public API. + + mpn functions act on non-negative integers represented by an array of generalised + digits (eg a word per digit). You also need to specify separately the length of the + array. There is no public API for mpn. Rather, the functions are used by mpz to + implement its features. + + Integer values are stored little endian (first digit is first in memory). + + Definition of normalise: ? */ /* compares i with j returns sign(i - j) assumes i, j are normalised */ -int mpn_cmp(const mpz_dig_t *idig, uint ilen, const mpz_dig_t *jdig, uint jlen) { +STATIC int mpn_cmp(const mpz_dig_t *idig, uint ilen, const mpz_dig_t *jdig, uint jlen) { if (ilen < jlen) { return -1; } if (ilen > jlen) { return 1; } @@ -37,39 +45,46 @@ int mpn_cmp(const mpz_dig_t *idig, uint ilen, const mpz_dig_t *jdig, uint jlen) /* computes i = j << n returns number of digits in i - assumes enough memory in i; assumes normalised j + assumes enough memory in i; assumes normalised j; assumes n > 0 can have i, j pointing to same memory */ -/* unfinished -uint mpn_shl(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) { - uint n_whole = n / DIG_SIZE; +STATIC uint mpn_shl(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) { + uint n_whole = (n + DIG_SIZE - 1) / DIG_SIZE; uint n_part = n % DIG_SIZE; - idig += jlen + n_whole + 1; + // start from the high end of the digit arrays + idig += jlen + n_whole - 1; + jdig += jlen - 1; - for (uint i = jlen; i > 0; --i, ++idig, ++jdig) { - mpz_dbl_dig_t d = *jdig; - if (i > 1) { - d |= jdig[1] << DIG_SIZE; - } - d <<= n_part; - *idig = d & DIG_MASK; + // shift the digits + mpz_dbl_dig_t d = 0; + for (uint i = jlen; i > 0; i--, idig--, jdig--) { + d |= *jdig; + *idig = d >> (DIG_SIZE - n_part); + d <<= DIG_SIZE; } - if (idig[-1] == 0) { - --jlen; + // store remaining bits + *idig = d >> (DIG_SIZE - n_part); + idig -= n_whole - 1; + memset(idig, 0, n_whole - 1); + + // work out length of result + jlen += n_whole; + if (idig[jlen - 1] == 0) { + jlen--; } + // return length of result return jlen; } -*/ /* computes i = j >> n returns number of digits in i - assumes enough memory in i; assumes normalised j + assumes enough memory in i; assumes normalised j; assumes n > 0 can have i, j pointing to same memory */ -uint mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) { +STATIC uint mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) { uint n_whole = n / DIG_SIZE; uint n_part = n % DIG_SIZE; @@ -80,7 +95,7 @@ uint mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) { jdig += n_whole; jlen -= n_whole; - for (uint i = jlen; i > 0; --i, ++idig, ++jdig) { + for (uint i = jlen; i > 0; i--, idig++, jdig++) { mpz_dbl_dig_t d = *jdig; if (i > 1) { d |= jdig[1] << DIG_SIZE; @@ -90,7 +105,7 @@ uint mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) { } if (idig[-1] == 0) { - --jlen; + jlen--; } return jlen; @@ -101,7 +116,7 @@ uint mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) { assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen can have i, j, k pointing to same memory */ -uint mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t *kdig, uint klen) { +STATIC uint mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t *kdig, uint klen) { mpz_dig_t *oidig = idig; mpz_dbl_dig_t carry = 0; @@ -131,7 +146,7 @@ uint mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t assumes enough memory in i; assumes normalised j, k; assumes j >= k can have i, j, k pointing to same memory */ -uint mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t *kdig, uint klen) { +STATIC uint mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t *kdig, uint klen) { mpz_dig_t *oidig = idig; mpz_dbl_dig_signed_t borrow = 0; @@ -159,7 +174,7 @@ uint mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t returns number of digits in i assumes enough memory in i; assumes normalised i; assumes dmul != 0 */ -uint mpn_mul_dig_add_dig(mpz_dig_t *idig, uint ilen, mpz_dig_t dmul, mpz_dig_t dadd) { +STATIC uint mpn_mul_dig_add_dig(mpz_dig_t *idig, uint ilen, mpz_dig_t dmul, mpz_dig_t dadd) { mpz_dig_t *oidig = idig; mpz_dbl_dig_t carry = dadd; @@ -181,7 +196,7 @@ uint mpn_mul_dig_add_dig(mpz_dig_t *idig, uint ilen, mpz_dig_t dmul, mpz_dig_t d assumes enough memory in i; assumes i is zeroed; assumes normalised j, k can have j, k point to same memory */ -uint mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, mpz_dig_t *kdig, uint klen) { +STATIC uint mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, mpz_dig_t *kdig, uint klen) { mpz_dig_t *oidig = idig; uint ilen = 0; @@ -214,7 +229,7 @@ uint mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, mpz_dig_t *kdig, uint modifies den_dig memory, but restors it to original state at end */ -void mpn_div(mpz_dig_t *num_dig, machine_uint_t *num_len, mpz_dig_t *den_dig, machine_uint_t den_len, mpz_dig_t *quo_dig, machine_uint_t *quo_len) { +STATIC void mpn_div(mpz_dig_t *num_dig, machine_uint_t *num_len, mpz_dig_t *den_dig, machine_uint_t den_len, mpz_dig_t *quo_dig, machine_uint_t *quo_len) { mpz_dig_t *orig_num_dig = num_dig; mpz_dig_t *orig_quo_dig = quo_dig; mpz_dig_t norm_shift = 0; @@ -343,9 +358,7 @@ void mpn_div(mpz_dig_t *num_dig, machine_uint_t *num_len, mpz_dig_t *den_dig, ma } } -#define MIN_ALLOC (4) -#define ALIGN_ALLOC (2) -#define NUM_DIG_FOR_INT (sizeof(machine_int_t) * 8 / DIG_SIZE + 1) +#define MIN_ALLOC (2) static const uint log_base2_floor[] = { 0, @@ -359,13 +372,10 @@ static const uint log_base2_floor[] = { 4, 4, 4, 5 }; -bool mpz_int_is_sml_int(machine_int_t i) { - return -(1 << DIG_SIZE) < i && i < (1 << DIG_SIZE); -} - void mpz_init_zero(mpz_t *z) { - z->alloc = 0; z->neg = 0; + z->fixed_dig = 0; + z->alloc = 0; z->len = 0; z->dig = NULL; } @@ -375,8 +385,17 @@ void mpz_init_from_int(mpz_t *z, machine_int_t val) { mpz_set_from_int(z, val); } +void mpz_init_fixed_from_int(mpz_t *z, mpz_dig_t *dig, uint alloc, machine_int_t val) { + z->neg = 0; + z->fixed_dig = 1; + z->alloc = alloc; + z->len = 0; + z->dig = dig; + mpz_set_from_int(z, val); +} + void mpz_deinit(mpz_t *z) { - if (z != NULL) { + if (z != NULL && !z->fixed_dig) { m_del(mpz_dig_t, z->dig, z->alloc); } } @@ -407,23 +426,26 @@ void mpz_free(mpz_t *z) { } STATIC void mpz_need_dig(mpz_t *z, uint need) { - uint alloc; if (need < MIN_ALLOC) { - alloc = MIN_ALLOC; - } else { - alloc = (need + ALIGN_ALLOC) & (~(ALIGN_ALLOC - 1)); + need = MIN_ALLOC; } - if (z->dig == NULL || z->alloc < alloc) { - z->dig = m_renew(mpz_dig_t, z->dig, z->alloc, alloc); - z->alloc = alloc; + if (z->dig == NULL || z->alloc < need) { + if (z->fixed_dig) { + // cannot reallocate fixed buffers + assert(0); + return; + } + z->dig = m_renew(mpz_dig_t, z->dig, z->alloc, need); + z->alloc = need; } } mpz_t *mpz_clone(const mpz_t *src) { mpz_t *z = m_new_obj(mpz_t); - z->alloc = src->alloc; z->neg = src->neg; + z->fixed_dig = 0; + z->alloc = src->alloc; z->len = src->len; if (src->dig == NULL) { z->dig = NULL; @@ -434,6 +456,9 @@ mpz_t *mpz_clone(const mpz_t *src) { return z; } +/* sets dest = src + can have dest, src the same +*/ void mpz_set(mpz_t *dest, const mpz_t *src) { mpz_need_dig(dest, src->len); dest->neg = src->neg; @@ -442,7 +467,7 @@ void mpz_set(mpz_t *dest, const mpz_t *src) { } void mpz_set_from_int(mpz_t *z, machine_int_t val) { - mpz_need_dig(z, NUM_DIG_FOR_INT); + mpz_need_dig(z, MPZ_NUM_DIG_FOR_INT); if (val < 0) { z->neg = 1; @@ -527,6 +552,9 @@ int mpz_cmp(const mpz_t *z1, const mpz_t *z2) { return cmp; } +#if 0 +// obsolete +// compares mpz with an integer that fits within DIG_SIZE bits int mpz_cmp_sml_int(const mpz_t *z, machine_int_t sml_int) { int cmp; if (z->neg == 0) { @@ -554,6 +582,7 @@ int mpz_cmp_sml_int(const mpz_t *z, machine_int_t sml_int) { if (cmp > 0) return 1; return 0; } +#endif #if 0 these functions are unused @@ -631,50 +660,71 @@ void mpz_neg_inpl(mpz_t *dest, const mpz_t *z) { dest->neg = 1 - dest->neg; } -#if 0 -not finished +/* computes dest = ~z (= -z - 1) + can have dest, z the same +*/ +void mpz_not_inpl(mpz_t *dest, const mpz_t *z) { + if (dest != z) { + mpz_set(dest, z); + } + if (dest->neg) { + dest->neg = 0; + mpz_dig_t k = 1; + dest->len = mpn_sub(dest->dig, dest->dig, dest->len, &k, 1); + } else { + mpz_dig_t k = 1; + dest->len = mpn_add(dest->dig, dest->dig, dest->len, &k, 1); + dest->neg = 1; + } +} + /* computes dest = lhs << rhs can have dest, lhs the same */ void mpz_shl_inpl(mpz_t *dest, const mpz_t *lhs, machine_int_t rhs) { - if (dest != lhs) { + if (lhs->len == 0 || rhs == 0) { mpz_set(dest, lhs); - } - - if (dest.len == 0 || rhs == 0) { - return dest; - } - - if (rhs < 0) { - dest->len = mpn_shr(dest->len, dest->dig, -rhs); + } else if (rhs < 0) { + mpz_shr_inpl(dest, lhs, -rhs); } else { - dest->len = mpn_shl(dest->len, dest->dig, rhs); + mpz_need_dig(dest, lhs->len + (rhs + DIG_SIZE - 1) / DIG_SIZE); + dest->len = mpn_shl(dest->dig, lhs->dig, lhs->len, rhs); + dest->neg = lhs->neg; } - - return dest; } /* computes dest = lhs >> rhs can have dest, lhs the same */ void mpz_shr_inpl(mpz_t *dest, const mpz_t *lhs, machine_int_t rhs) { - if (dest != lhs) { + if (lhs->len == 0 || rhs == 0) { mpz_set(dest, lhs); - } - - if (dest.len == 0 || rhs == 0) { - return dest; - } - - if (rhs < 0) { - dest->len = mpn_shl(dest->len, dest->dig, -rhs); + } else if (rhs < 0) { + mpz_shl_inpl(dest, lhs, -rhs); } else { - dest->len = mpn_shr(dest->len, dest->dig, rhs); + mpz_need_dig(dest, lhs->len); + dest->len = mpn_shr(dest->dig, lhs->dig, lhs->len, rhs); + dest->neg = lhs->neg; + if (dest->neg) { + // arithmetic shift right, rounding to negative infinity + uint n_whole = rhs / DIG_SIZE; + uint n_part = rhs % DIG_SIZE; + mpz_dig_t round_up = 0; + for (uint i = 0; i < lhs->len && i < n_whole; i++) { + if (lhs->dig[i] != 0) { + round_up = 1; + break; + } + } + if (n_whole < lhs->len && (lhs->dig[n_whole] & ((1 << n_part) - 1)) != 0) { + round_up = 1; + } + if (round_up) { + dest->len = mpn_add(dest->dig, dest->dig, dest->len, &round_up, 1); + } + } } - - return dest; } -#endif /* computes dest = lhs + rhs can have dest, lhs, rhs the same @@ -931,12 +981,11 @@ machine_int_t mpz_as_int(const mpz_t *i) { machine_int_t val = 0; mpz_dig_t *d = i->dig + i->len; - while (--d >= i->dig) - { + while (--d >= i->dig) { machine_int_t oldval = val; val = (val << DIG_SIZE) | *d; - if (val < oldval) - { + if (val < oldval) { + // TODO need better handling of conversion overflow if (i->neg == 0) { return 0x7fffffff; } else { diff --git a/py/mpz.h b/py/mpz.h index 13a96fd797..8f5fc3720e 100644 --- a/py/mpz.h +++ b/py/mpz.h @@ -4,15 +4,21 @@ typedef int32_t mpz_dbl_dig_signed_t; typedef struct _mpz_t { machine_uint_t neg : 1; - machine_uint_t alloc : 31; + machine_uint_t fixed_dig : 1; + machine_uint_t alloc : 30; machine_uint_t len; mpz_dig_t *dig; } mpz_t; -bool mpz_int_is_sml_int(machine_int_t i); +#define MPZ_DIG_SIZE (15) // see mpn_div for why this needs to be at most 15 +#define MPZ_NUM_DIG_FOR_INT (sizeof(machine_int_t) * 8 / MPZ_DIG_SIZE + 1) + +// convenience macro to declare an mpz with a digit array from the stack, initialised by an integer +#define MPZ_CONST_INT(z, val) mpz_t z; mpz_dig_t z ## _digits[MPZ_NUM_DIG_FOR_INT]; mpz_init_fixed_from_int(&z, z_digits, MPZ_NUM_DIG_FOR_INT, val); void mpz_init_zero(mpz_t *z); void mpz_init_from_int(mpz_t *z, machine_int_t val); +void mpz_init_fixed_from_int(mpz_t *z, mpz_dig_t *dig, uint dig_alloc, machine_int_t val); void mpz_deinit(mpz_t *z); mpz_t *mpz_zero(); @@ -33,7 +39,6 @@ bool mpz_is_odd(const mpz_t *z); bool mpz_is_even(const mpz_t *z); int mpz_cmp(const mpz_t *lhs, const mpz_t *rhs); -int mpz_cmp_sml_int(const mpz_t *lhs, machine_int_t sml_int); mpz_t *mpz_abs(const mpz_t *z); mpz_t *mpz_neg(const mpz_t *z); @@ -44,8 +49,9 @@ mpz_t *mpz_pow(const mpz_t *lhs, const mpz_t *rhs); void mpz_abs_inpl(mpz_t *dest, const mpz_t *z); void mpz_neg_inpl(mpz_t *dest, const mpz_t *z); -//void mpz_shl_inpl(mpz_t *dest, const mpz_t *lhs, machine_int_t rhs); -//void mpz_shr_inpl(mpz_t *dest, const mpz_t *lhs, machine_int_t rhs); +void mpz_not_inpl(mpz_t *dest, const mpz_t *z); +void mpz_shl_inpl(mpz_t *dest, const mpz_t *lhs, machine_int_t rhs); +void mpz_shr_inpl(mpz_t *dest, const mpz_t *lhs, machine_int_t rhs); void mpz_add_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs); void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs); void mpz_mul_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs); diff --git a/py/objint_mpz.c b/py/objint_mpz.c index a0889da9e4..6e1ee1a999 100644 --- a/py/objint_mpz.c +++ b/py/objint_mpz.c @@ -7,6 +7,7 @@ #include "misc.h" #include "mpconfig.h" #include "qstr.h" +#include "parsenumbase.h" #include "obj.h" #include "mpz.h" #include "objint.h" @@ -39,17 +40,20 @@ mp_obj_t int_unary_op(int op, mp_obj_t o_in) { case RT_UNARY_OP_BOOL: return MP_BOOL(!mpz_is_zero(&o->mpz)); case RT_UNARY_OP_POSITIVE: return o_in; case RT_UNARY_OP_NEGATIVE: { mp_obj_int_t *o2 = mp_obj_int_new_mpz(); mpz_neg_inpl(&o2->mpz, &o->mpz); return o2; } - //case RT_UNARY_OP_INVERT: ~ not implemented for mpz + case RT_UNARY_OP_INVERT: { mp_obj_int_t *o2 = mp_obj_int_new_mpz(); mpz_not_inpl(&o2->mpz, &o->mpz); return o2; } default: return NULL; // op not supported } } mp_obj_t int_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { mpz_t *zlhs = &((mp_obj_int_t*)lhs_in)->mpz; - mpz_t *zrhs; + const mpz_t *zrhs; + mpz_t z_int; + mpz_dig_t z_int_dig[MPZ_NUM_DIG_FOR_INT]; if (MP_OBJ_IS_SMALL_INT(rhs_in)) { - zrhs = mpz_from_int(MP_OBJ_SMALL_INT_VALUE(rhs_in)); + mpz_init_fixed_from_int(&z_int, z_int_dig, MPZ_NUM_DIG_FOR_INT, MP_OBJ_SMALL_INT_VALUE(rhs_in)); + zrhs = &z_int; } else if (MP_OBJ_IS_TYPE(rhs_in, &int_type)) { zrhs = &((mp_obj_int_t*)rhs_in)->mpz; } else { @@ -95,10 +99,22 @@ mp_obj_t int_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { //case RT_BINARY_OP_XOR: //case RT_BINARY_OP_INPLACE_XOR: - //case RT_BINARY_OP_LSHIFT: - //case RT_BINARY_OP_INPLACE_LSHIFT: - //case RT_BINARY_OP_RSHIFT: - //case RT_BINARY_OP_INPLACE_RSHIFT: + case RT_BINARY_OP_LSHIFT: + case RT_BINARY_OP_INPLACE_LSHIFT: + case RT_BINARY_OP_RSHIFT: + case RT_BINARY_OP_INPLACE_RSHIFT: { + // TODO check conversion overflow + machine_int_t irhs = mpz_as_int(zrhs); + if (irhs < 0) { + nlr_jump(mp_obj_new_exception_msg(&mp_type_ValueError, "negative shift count")); + } + if (op == RT_BINARY_OP_LSHIFT || op == RT_BINARY_OP_INPLACE_LSHIFT) { + mpz_shl_inpl(&res->mpz, zlhs, irhs); + } else { + mpz_shr_inpl(&res->mpz, zlhs, irhs); + } + break; + } case RT_BINARY_OP_POWER: case RT_BINARY_OP_INPLACE_POWER: @@ -158,7 +174,11 @@ mp_obj_t mp_obj_new_int_from_uint(machine_uint_t value) { mp_obj_t mp_obj_new_int_from_long_str(const char *str) { mp_obj_int_t *o = mp_obj_int_new_mpz(); uint len = strlen(str); - uint n = mpz_set_from_str(&o->mpz, str, len, false, 10); + int base = 0; + int skip = mp_parse_num_base(str, len, &base); + str += skip; + len -= skip; + uint n = mpz_set_from_str(&o->mpz, str, len, false, base); if (n != len) { nlr_jump(mp_obj_new_exception_msg(&mp_type_SyntaxError, "invalid syntax for number")); } diff --git a/py/parse.c b/py/parse.c index e70456e814..a7b73a5673 100644 --- a/py/parse.c +++ b/py/parse.c @@ -10,6 +10,7 @@ #include "mpconfig.h" #include "qstr.h" #include "lexer.h" +#include "parsenumbase.h" #include "parse.h" #define RULE_ACT_KIND_MASK (0xf0) @@ -241,23 +242,8 @@ STATIC void push_result_token(parser_t *parser, const mp_lexer_t *lex) { machine_int_t int_val = 0; int len = tok->len; const char *str = tok->str; - int base = 10; - int i = 0; - if (len >= 3 && str[0] == '0') { - if (str[1] == 'o' || str[1] == 'O') { - // octal - base = 8; - i = 2; - } else if (str[1] == 'x' || str[1] == 'X') { - // hexadecimal - base = 16; - i = 2; - } else if (str[1] == 'b' || str[1] == 'B') { - // binary - base = 2; - i = 2; - } - } + int base = 0; + int i = mp_parse_num_base(str, len, &base); bool overflow = false; for (; i < len; i++) { machine_int_t old_val = int_val; diff --git a/py/parsenum.c b/py/parsenum.c index 64594cd1b4..8e290da338 100644 --- a/py/parsenum.c +++ b/py/parsenum.c @@ -5,6 +5,7 @@ #include "qstr.h" #include "nlr.h" #include "obj.h" +#include "parsenumbase.h" #include "parsenum.h" #if defined(UNIX) @@ -33,38 +34,15 @@ mp_obj_t mp_parse_num_integer(const char *restrict str, uint len, int base) { // preced sign if (c == '+' || c == '-') { neg = - (c == '-'); - c = *(p++); - } - - // find real radix base, and strip preced '0x', '0o' and '0b' - // TODO somehow merge with similar code in parse.c - if ((base == 0 || base == 16) && c == '0') { - c = *(p++); - if ((c | 32) == 'x') { - base = 16; - } else if (base == 0 && (c | 32) == 'o') { - base = 8; - } else if (base == 0 && (c | 32) == 'b') { - base = 2; - } else { - base = 10; - p -= 2; - } - } else if (base == 8 && c == '0') { - c = *(p++); - if ((c | 32) != 'o') { - p -= 2; - } - } else if (base == 2 && c == '0') { - c = *(p++); - if ((c | 32) != 'b') { - p -= 2; - } } else { - if (base == 0) base = 10; p--; } + len -= p - str; + int skip = mp_parse_num_base(p, len, &base); + p += skip; + len -= skip; + errno = 0; found = strtol(p, &num, base); if (errno) { diff --git a/py/parsenumbase.c b/py/parsenumbase.c new file mode 100644 index 0000000000..ad24cc678b --- /dev/null +++ b/py/parsenumbase.c @@ -0,0 +1,40 @@ +#include "misc.h" +#include "mpconfig.h" +#include "parsenumbase.h" + +// find real radix base, and strip preceding '0x', '0o' and '0b' +// puts base in *base, and returns number of bytes to skip the prefix +int mp_parse_num_base(const char *str, uint len, int *base) { + const char *p = str; + int c = *(p++); + if ((*base == 0 || *base == 16) && c == '0') { + c = *(p++); + if ((c | 32) == 'x') { + *base = 16; + } else if (*base == 0 && (c | 32) == 'o') { + *base = 8; + } else if (*base == 0 && (c | 32) == 'b') { + *base = 2; + } else { + *base = 10; + p -= 2; + } + } else if (*base == 8 && c == '0') { + c = *(p++); + if ((c | 32) != 'o') { + p -= 2; + } + } else if (*base == 2 && c == '0') { + c = *(p++); + if ((c | 32) != 'b') { + p -= 2; + } + } else { + if (*base == 0) { + *base = 10; + } + p--; + } + return p - str; +} + diff --git a/py/parsenumbase.h b/py/parsenumbase.h new file mode 100644 index 0000000000..483596e329 --- /dev/null +++ b/py/parsenumbase.h @@ -0,0 +1 @@ +int mp_parse_num_base(const char *str, uint len, int *base); diff --git a/py/py.mk b/py/py.mk index 0285bc05fa..a12e44d4d5 100644 --- a/py/py.mk +++ b/py/py.mk @@ -32,6 +32,7 @@ PY_O_BASENAME = \ asmthumb.o \ emitnthumb.o \ emitinlinethumb.o \ + parsenumbase.o \ parsenum.o \ runtime.o \ map.o \ diff --git a/tests/basics/int-mpz.py b/tests/basics/int-mpz.py new file mode 100644 index 0000000000..0500d794cf --- /dev/null +++ b/tests/basics/int-mpz.py @@ -0,0 +1,57 @@ +# to test arbitrariy precision integers + +x = 1000000000000000000000000000000 +y = 2000000000000000000000000000000 + +# printing +print(x) +print(y) + +# addition +print(x + 1) +print(x + y) + +# subtraction +print(x - 1) +print(x - y) +print(y - x) + +# multiplication +print(x * 2) +print(x * y) + +# integer division +print(x // 2) +print(y // x) + +# bit inversion +print(~x) +print(~(-x)) + +# left shift +x = 0x10000000000000000000000 +for i in range(32): + x = x << 1 + print(x) + +# right shift +x = 0x10000000000000000000000 +for i in range(32): + x = x >> 1 + print(x) + +# left shift of a negative number +for i in range(8): + print(-10000000000000000000000000 << i) + print(-10000000000000000000000001 << i) + print(-10000000000000000000000002 << i) + print(-10000000000000000000000003 << i) + print(-10000000000000000000000004 << i) + +# right shift of a negative number +for i in range(8): + print(-10000000000000000000000000 >> i) + print(-10000000000000000000000001 >> i) + print(-10000000000000000000000002 >> i) + print(-10000000000000000000000003 >> i) + print(-10000000000000000000000004 >> i)