diff --git a/py/objdict.c b/py/objdict.c index 02aedacdd6..b63ea89137 100644 --- a/py/objdict.c +++ b/py/objdict.c @@ -110,8 +110,10 @@ STATIC void dict_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kind_ } } -mp_obj_t mp_obj_dict_make_new(const mp_obj_type_t *type, size_t n_args, size_t n_kw, const mp_obj_t *args) { - mp_obj_t dict_out = mp_obj_new_dict(0); +// This is a helper function to initialize an empty, but typed dictionary with +// a given number of slots. +STATIC mp_obj_t dict_new_typed(const mp_obj_type_t *type, const size_t n) { + mp_obj_t dict_out = mp_obj_new_dict(n); mp_obj_dict_t *dict = MP_OBJ_TO_PTR(dict_out); dict->base.type = type; #if MICROPY_PY_COLLECTIONS_ORDEREDDICT @@ -119,6 +121,11 @@ mp_obj_t mp_obj_dict_make_new(const mp_obj_type_t *type, size_t n_args, size_t n dict->map.is_ordered = 1; } #endif + return dict_out; +} + +mp_obj_t mp_obj_dict_make_new(const mp_obj_type_t *type, size_t n_args, size_t n_kw, const mp_obj_t *args) { + mp_obj_t dict_out = dict_new_typed(type, 0); if (n_args > 0 || n_kw > 0) { mp_obj_t args2[2] = {dict_out, args[0]}; // args[0] is always valid, even if it's not a positional arg mp_map_t kwargs; @@ -264,6 +271,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_1(dict_copy_obj, mp_obj_dict_copy); #if MICROPY_PY_BUILTINS_DICT_FROMKEYS // this is a classmethod STATIC mp_obj_t dict_fromkeys(size_t n_args, const mp_obj_t *args) { + mp_obj_type_t *type = MP_OBJ_TO_PTR(args[0]); mp_obj_t iter = mp_getiter(args[1], NULL); mp_obj_t value = mp_const_none; mp_obj_t next = MP_OBJ_NULL; @@ -277,9 +285,9 @@ STATIC mp_obj_t dict_fromkeys(size_t n_args, const mp_obj_t *args) { mp_obj_t len = mp_obj_len_maybe(args[1]); if (len == MP_OBJ_NULL) { /* object's type doesn't have a __len__ slot */ - self_out = mp_obj_new_dict(0); + self_out = dict_new_typed(type, 0); } else { - self_out = mp_obj_new_dict(MP_OBJ_SMALL_INT_VALUE(len)); + self_out = dict_new_typed(type, MP_OBJ_SMALL_INT_VALUE(len)); } mp_obj_dict_t *self = MP_OBJ_TO_PTR(self_out); diff --git a/tests/basics/ordereddict1.py b/tests/basics/ordereddict1.py index a6f305ff78..b70d7ff5d1 100644 --- a/tests/basics/ordereddict1.py +++ b/tests/basics/ordereddict1.py @@ -41,3 +41,14 @@ try: d.popitem() except: print('empty') + +# fromkeys returns the correct type and order +d = dict.fromkeys('abcdefghij') +print(type(d) == dict) +d = OrderedDict.fromkeys('abcdefghij') +print(type(d) == OrderedDict) +print(''.join(d)) + +# fromkey handles ordering with duplicates +d = OrderedDict.fromkeys('abcdefghijjihgfedcba') +print(''.join(d))