diff --git a/py/objdict.c b/py/objdict.c index b3e21aedd2..4cd6363796 100644 --- a/py/objdict.c +++ b/py/objdict.c @@ -17,20 +17,23 @@ typedef struct _mp_obj_dict_t { mp_map_t map; } mp_obj_dict_t; +static mp_obj_t mp_obj_new_dict_iterator(mp_obj_dict_t *dict, int cur); +static mp_map_elem_t *dict_it_iternext_elem(mp_obj_t self_in); + static void dict_print(void (*print)(void *env, const char *fmt, ...), void *env, mp_obj_t self_in) { mp_obj_dict_t *self = self_in; bool first = true; print(env, "{"); - for (int i = 0; i < self->map.alloc; i++) { - if (self->map.table[i].key != NULL) { - if (!first) { - print(env, ", "); - } - first = false; - mp_obj_print_helper(print, env, self->map.table[i].key); - print(env, ": "); - mp_obj_print_helper(print, env, self->map.table[i].value); + mp_obj_t *dict_iter = mp_obj_new_dict_iterator(self, 0); + mp_map_elem_t *next = NULL; + while ((next = dict_it_iternext_elem(dict_iter)) != NULL) { + if (!first) { + print(env, ", "); } + first = false; + mp_obj_print_helper(print, env, next->key); + print(env, ": "); + mp_obj_print_helper(print, env, next->value); } print(env, "}"); } @@ -60,13 +63,73 @@ static mp_obj_t dict_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { } } + +/******************************************************************************/ +/* dict iterator */ + +typedef struct _mp_obj_dict_it_t { + mp_obj_base_t base; + mp_obj_dict_t *dict; + machine_uint_t cur; +} mp_obj_dict_it_t; + +static mp_map_elem_t *dict_it_iternext_elem(mp_obj_t self_in) { + mp_obj_dict_it_t *self = self_in; + machine_uint_t max = self->dict->map.alloc; + mp_map_elem_t *table = self->dict->map.table; + + for (int i = self->cur; i < max; i++) { + if (table[i].key != NULL) { + self->cur = i + 1; + return &(table[i]); + } + } + + return NULL; +} + +mp_obj_t dict_it_iternext(mp_obj_t self_in) { + mp_map_elem_t *next = dict_it_iternext_elem(self_in); + + if (next != NULL) { + return next->key; + } else { + return mp_const_stop_iteration; + } +} + +static const mp_obj_type_t dict_it_type = { + { &mp_const_type }, + "dict_iterator", + .iternext = dict_it_iternext, + .methods = { { NULL, NULL }, }, +}; + +static mp_obj_t mp_obj_new_dict_iterator(mp_obj_dict_t *dict, int cur) { + mp_obj_dict_it_t *o = m_new_obj(mp_obj_dict_it_t); + o->base.type = &dict_it_type; + o->dict = dict; + o->cur = cur; + return o; +} + +static mp_obj_t dict_getiter(mp_obj_t o_in) { + return mp_obj_new_dict_iterator(o_in, 0); +} + +/******************************************************************************/ +/* dict methods */ + +/******************************************************************************/ +/* dict constructors & etc */ + const mp_obj_type_t dict_type = { { &mp_const_type }, "dict", .print = dict_print, .make_new = dict_make_new, .binary_op = dict_binary_op, - .getiter = NULL, + .getiter = dict_getiter, .methods = {{NULL, NULL},}, }; diff --git a/tests/basics/tests/dict_iterator.py b/tests/basics/tests/dict_iterator.py new file mode 100644 index 0000000000..f190e32ffd --- /dev/null +++ b/tests/basics/tests/dict_iterator.py @@ -0,0 +1,3 @@ +d = {1: 2, 3: 4} +for i in d: + print(i, d[i])