diff --git a/py/mpz.c b/py/mpz.c index 3b96c574f3..24c83d9133 100644 --- a/py/mpz.c +++ b/py/mpz.c @@ -1213,6 +1213,22 @@ mpz_t *mpz_mod(const mpz_t *lhs, const mpz_t *rhs) { } #endif +// must return actual int value if it fits in mp_int_t +mp_int_t mpz_hash(const mpz_t *z) { + mp_int_t val = 0; + mpz_dig_t *d = z->dig + z->len; + + while (--d >= z->dig) { + val = (val << DIG_SIZE) | *d; + } + + if (z->neg != 0) { + val = -val; + } + + return val; +} + // TODO check that this correctly handles overflow in all cases mp_int_t mpz_as_int(const mpz_t *i) { mp_int_t val = 0; diff --git a/py/mpz.h b/py/mpz.h index 76c308285e..c0cf60f4b7 100644 --- a/py/mpz.h +++ b/py/mpz.h @@ -96,6 +96,7 @@ void mpz_divmod_inpl(mpz_t *dest_quo, mpz_t *dest_rem, const mpz_t *lhs, const m mpz_t *mpz_div(const mpz_t *lhs, const mpz_t *rhs); mpz_t *mpz_mod(const mpz_t *lhs, const mpz_t *rhs); +mp_int_t mpz_hash(const mpz_t *z); mp_int_t mpz_as_int(const mpz_t *z); bool mpz_as_int_checked(const mpz_t *z, mp_int_t *value); #if MICROPY_PY_BUILTINS_FLOAT diff --git a/py/obj.c b/py/obj.c index cfebbf3466..755da02750 100644 --- a/py/obj.c +++ b/py/obj.c @@ -24,6 +24,7 @@ * THE SOFTWARE. */ +#include #include #include #include @@ -33,6 +34,8 @@ #include "misc.h" #include "qstr.h" #include "obj.h" +#include "mpz.h" +#include "objint.h" #include "runtime0.h" #include "runtime.h" #include "stackctrl.h" @@ -152,6 +155,8 @@ mp_int_t mp_obj_hash(mp_obj_t o_in) { return 1; // needs to hash to same as the integer 1, since True==1 } else if (MP_OBJ_IS_SMALL_INT(o_in)) { return MP_OBJ_SMALL_INT_VALUE(o_in); + } else if (MP_OBJ_IS_TYPE(o_in, &mp_type_int)) { + return mp_obj_int_hash(o_in); } else if (MP_OBJ_IS_STR(o_in) || MP_OBJ_IS_TYPE(o_in, &mp_type_bytes)) { return mp_obj_str_get_hash(o_in); } else if (MP_OBJ_IS_TYPE(o_in, &mp_type_NoneType)) { diff --git a/py/objint.c b/py/objint.c index bca32d13fa..e351c3f394 100644 --- a/py/objint.c +++ b/py/objint.c @@ -215,6 +215,10 @@ char *mp_obj_int_formatted(char **buf, int *buf_size, int *fmt_size, mp_const_ob #if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_NONE +mp_int_t mp_obj_int_hash(mp_obj_t self_in) { + return MP_OBJ_SMALL_INT_VALUE(self_in); +} + bool mp_obj_int_is_positive(mp_obj_t self_in) { return mp_obj_get_int(self_in) >= 0; } diff --git a/py/objint.h b/py/objint.h index 1d12cffae0..b4f8c9ae50 100644 --- a/py/objint.h +++ b/py/objint.h @@ -38,6 +38,7 @@ char *mp_obj_int_formatted(char **buf, int *buf_size, int *fmt_size, mp_const_ob int base, const char *prefix, char base_char, char comma); char *mp_obj_int_formatted_impl(char **buf, int *buf_size, int *fmt_size, mp_const_obj_t self_in, int base, const char *prefix, char base_char, char comma); +mp_int_t mp_obj_int_hash(mp_obj_t self_in); bool mp_obj_int_is_positive(mp_obj_t self_in); mp_obj_t mp_obj_int_unary_op(int op, mp_obj_t o_in); mp_obj_t mp_obj_int_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in); diff --git a/py/objint_longlong.c b/py/objint_longlong.c index 637d9c32c9..466b45c51f 100644 --- a/py/objint_longlong.c +++ b/py/objint_longlong.c @@ -55,6 +55,16 @@ const mp_obj_int_t mp_maxsize_obj = {{&mp_type_int}, INT_MAX}; #endif +mp_int_t mp_obj_int_hash(mp_obj_t self_in) { + if (MP_OBJ_IS_SMALL_INT(self_in)) { + return MP_OBJ_SMALL_INT_VALUE(self_in); + } + mp_obj_int_t *self = self_in; + // truncate value to fit in mp_int_t, which gives the same hash as + // small int if the value fits without truncation + return self->val; +} + bool mp_obj_int_is_positive(mp_obj_t self_in) { if (MP_OBJ_IS_SMALL_INT(self_in)) { return MP_OBJ_SMALL_INT_VALUE(self_in) >= 0; diff --git a/py/objint_mpz.c b/py/objint_mpz.c index 6e1c3c5a84..c60e5c2b83 100644 --- a/py/objint_mpz.c +++ b/py/objint_mpz.c @@ -96,6 +96,14 @@ char *mp_obj_int_formatted_impl(char **buf, int *buf_size, int *fmt_size, mp_con return str; } +mp_int_t mp_obj_int_hash(mp_obj_t self_in) { + if (MP_OBJ_IS_SMALL_INT(self_in)) { + return MP_OBJ_SMALL_INT_VALUE(self_in); + } + mp_obj_int_t *self = self_in; + return mpz_hash(&self->mpz); +} + bool mp_obj_int_is_positive(mp_obj_t self_in) { if (MP_OBJ_IS_SMALL_INT(self_in)) { return MP_OBJ_SMALL_INT_VALUE(self_in) >= 0;