diff --git a/py/modbuiltins.c b/py/modbuiltins.c index 5461816af0..5d4ec8ffe4 100644 --- a/py/modbuiltins.c +++ b/py/modbuiltins.c @@ -173,46 +173,24 @@ STATIC mp_obj_t mp_builtin_chr(mp_obj_t o_in) { MP_DEFINE_CONST_FUN_OBJ_1(mp_builtin_chr_obj, mp_builtin_chr); STATIC mp_obj_t mp_builtin_dir(size_t n_args, const mp_obj_t *args) { - // TODO make this function more general and less of a hack - - mp_obj_dict_t *dict = NULL; - mp_map_t *members = NULL; - if (n_args == 0) { - // make a list of names in the local name space - dict = mp_locals_get(); - } else { // n_args == 1 - // make a list of names in the given object - if (MP_OBJ_IS_TYPE(args[0], &mp_type_module)) { - dict = mp_obj_module_get_globals(args[0]); - } else { - mp_obj_type_t *type; - if (MP_OBJ_IS_TYPE(args[0], &mp_type_type)) { - type = MP_OBJ_TO_PTR(args[0]); - } else { - type = mp_obj_get_type(args[0]); - } - if (type->locals_dict != NULL && type->locals_dict->base.type == &mp_type_dict) { - dict = type->locals_dict; - } - } - if (mp_obj_is_instance_type(mp_obj_get_type(args[0]))) { - mp_obj_instance_t *inst = MP_OBJ_TO_PTR(args[0]); - members = &inst->members; - } - } - mp_obj_t dir = mp_obj_new_list(0, NULL); - if (dict != NULL) { + if (n_args == 0) { + // Make a list of names in the local namespace + mp_obj_dict_t *dict = mp_locals_get(); for (size_t i = 0; i < dict->map.alloc; i++) { if (MP_MAP_SLOT_IS_FILLED(&dict->map, i)) { mp_obj_list_append(dir, dict->map.table[i].key); } } - } - if (members != NULL) { - for (size_t i = 0; i < members->alloc; i++) { - if (MP_MAP_SLOT_IS_FILLED(members, i)) { - mp_obj_list_append(dir, members->table[i].key); + } else { // n_args == 1 + // Make a list of names in the given object + // Implemented by probing all possible qstrs with mp_load_method_maybe + size_t nqstr = QSTR_TOTAL(); + for (size_t i = 1; i < nqstr; ++i) { + mp_obj_t dest[2]; + mp_load_method_maybe(args[0], i, dest); + if (dest[0] != MP_OBJ_NULL) { + mp_obj_list_append(dir, MP_OBJ_NEW_QSTR(i)); } } } diff --git a/tests/basics/builtin_dir.py b/tests/basics/builtin_dir.py index 16e7e669e4..c15a76f4c6 100644 --- a/tests/basics/builtin_dir.py +++ b/tests/basics/builtin_dir.py @@ -17,3 +17,22 @@ foo = Foo() print('__init__' in dir(foo)) print('x' in dir(foo)) +# dir of subclass +class A: + def a(): + pass +class B(A): + def b(): + pass +d = dir(B()) +print(d.count('a'), d.count('b')) + +# dir of class with multiple bases and a common parent +class C(A): + def c(): + pass +class D(B, C): + def d(): + pass +d = dir(D()) +print(d.count('a'), d.count('b'), d.count('c'), d.count('d'))