diff --git a/py/map.c b/py/map.c index e44cf33d26..524b6c8a5b 100644 --- a/py/map.c +++ b/py/map.c @@ -132,9 +132,29 @@ void mp_set_init(mp_set_t *set, int n) { set->table = m_new0(mp_obj_t, set->alloc); } +static void mp_set_rehash(mp_set_t *set) { + int old_alloc = set->alloc; + mp_obj_t *old_table = set->table; + set->alloc = get_doubling_prime_greater_or_equal_to(set->alloc + 1); + set->used = 0; + set->table = m_new0(mp_obj_t, set->alloc); + for (int i = 0; i < old_alloc; i++) { + if (old_table[i] != NULL) { + mp_set_lookup(set, old_table[i], true); + } + } + m_del(mp_obj_t, old_table, old_alloc); +} + mp_obj_t mp_set_lookup(mp_set_t *set, mp_obj_t index, bool add_if_not_found) { int hash = mp_obj_hash(index); - assert(set->alloc); /* FIXME: if alloc is ever 0 when doing a lookup, this'll fail: */ + if (set->alloc == 0) { + if (add_if_not_found) { + mp_set_rehash(set); + } else { + return NULL; + } + } int pos = hash % set->alloc; for (;;) { mp_obj_t elem = set->table[pos]; @@ -143,17 +163,7 @@ mp_obj_t mp_set_lookup(mp_set_t *set, mp_obj_t index, bool add_if_not_found) { if (add_if_not_found) { if (set->used + 1 >= set->alloc) { // not enough room in table, rehash it - int old_alloc = set->alloc; - mp_obj_t *old_table = set->table; - set->alloc = get_doubling_prime_greater_or_equal_to(set->alloc + 1); - set->used = 0; - set->table = m_new(mp_obj_t, set->alloc); - for (int i = 0; i < old_alloc; i++) { - if (old_table[i] != NULL) { - mp_set_lookup(set, old_table[i], true); - } - } - m_del(mp_obj_t, old_table, old_alloc); + mp_set_rehash(set); // restart the search for the new element pos = hash % set->alloc; } else { @@ -173,3 +183,13 @@ mp_obj_t mp_set_lookup(mp_set_t *set, mp_obj_t index, bool add_if_not_found) { } } } + +void mp_set_clear(mp_set_t *set) { + set->used = 0; + machine_uint_t a = set->alloc; + set->alloc = 0; + set->table = m_renew(mp_obj_t, set->table, a, set->alloc); + for (uint i=0; ialloc; i++) { + set->table[i] = NULL; + } +} diff --git a/py/map.h b/py/map.h index 5ce4e835b6..ba6bf9e6ee 100644 --- a/py/map.h +++ b/py/map.h @@ -32,3 +32,4 @@ void mp_map_clear(mp_map_t *map); void mp_set_init(mp_set_t *set, int n); mp_obj_t mp_set_lookup(mp_set_t *set, mp_obj_t index, bool add_if_not_found); +void mp_set_clear(mp_set_t *set); diff --git a/py/objset.c b/py/objset.c index a74d1eb6a3..8bd006a761 100644 --- a/py/objset.c +++ b/py/objset.c @@ -104,6 +104,16 @@ static mp_obj_t set_add(mp_obj_t self_in, mp_obj_t item) { } static MP_DEFINE_CONST_FUN_OBJ_2(set_add_obj, set_add); +static mp_obj_t set_clear(mp_obj_t self_in) { + assert(MP_OBJ_IS_TYPE(self_in, &set_type)); + mp_obj_set_t *self = self_in; + + mp_set_clear(&self->set); + + return mp_const_none; +} +static MP_DEFINE_CONST_FUN_OBJ_1(set_clear_obj, set_clear); + /******************************************************************************/ /* set constructors & public C API */ @@ -111,6 +121,7 @@ static MP_DEFINE_CONST_FUN_OBJ_2(set_add_obj, set_add); static const mp_method_t set_type_methods[] = { { "add", &set_add_obj }, + { "clear", &set_clear_obj }, { NULL, NULL }, // end-of-list sentinel }; diff --git a/tests/basics/tests/set_clear.py b/tests/basics/tests/set_clear.py new file mode 100644 index 0000000000..6fda93f0fb --- /dev/null +++ b/tests/basics/tests/set_clear.py @@ -0,0 +1,3 @@ +s = {1, 2, 3, 4} +print(s.clear()) +print(list(s))