From f6532bb9e054fd20af5063b1e77ef98863f36c84 Mon Sep 17 00:00:00 2001 From: Damien George Date: Sun, 15 Feb 2015 01:10:13 +0000 Subject: [PATCH] py: Simplify and remove redundant code for __iter__ method lookup. --- py/obj.h | 2 +- py/objtype.c | 14 +++---------- py/runtime.c | 46 +++++++++++++++++++------------------------ tests/basics/iter0.py | 6 ++++++ tests/basics/iter1.py | 9 +++++++++ 5 files changed, 39 insertions(+), 38 deletions(-) create mode 100644 tests/basics/iter0.py diff --git a/py/obj.h b/py/obj.h index 0f6179630c..b6f1b02731 100644 --- a/py/obj.h +++ b/py/obj.h @@ -285,7 +285,7 @@ struct _mp_obj_type_t { // value=MP_OBJ_NULL means delete, value=MP_OBJ_SENTINEL means load, else store // can return MP_OBJ_NULL if op not supported - mp_fun_1_t getiter; + mp_fun_1_t getiter; // corresponds to __iter__ special method mp_fun_1_t iternext; // may return MP_OBJ_STOP_ITERATION as an optimisation instead of raising StopIteration() (with no args) mp_buffer_p_t buffer_p; diff --git a/py/objtype.c b/py/objtype.c index 95a7e6b5f3..6accaa74ed 100644 --- a/py/objtype.c +++ b/py/objtype.c @@ -644,21 +644,13 @@ STATIC mp_obj_t instance_getiter(mp_obj_t self_in) { }; mp_obj_class_lookup(&lookup, self->base.type); if (member[0] == MP_OBJ_NULL) { - // This kinda duplicates code in mp_getiter() - lookup.attr = MP_QSTR___getitem__; - lookup.meth_offset = 0; // TODO - mp_obj_class_lookup(&lookup, self->base.type); - if (member[0] != MP_OBJ_NULL) { - // __getitem__ exists, create an iterator - return mp_obj_new_getitem_iter(member); - } return MP_OBJ_NULL; - } - if (member[0] == MP_OBJ_SENTINEL) { + } else if (member[0] == MP_OBJ_SENTINEL) { mp_obj_type_t *type = mp_obj_get_type(self->subobj[0]); return type->getiter(self->subobj[0]); + } else { + return mp_call_method_n_kw(0, 0, member); } - return mp_call_method_n_kw(0, 0, member); } STATIC mp_int_t instance_get_buffer(mp_obj_t self_in, mp_buffer_info_t *bufinfo, mp_uint_t flags) { diff --git a/py/runtime.c b/py/runtime.c index 75dd467507..c1ce9fb88c 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -957,37 +957,31 @@ void mp_store_attr(mp_obj_t base, qstr attr, mp_obj_t value) { mp_obj_t mp_getiter(mp_obj_t o_in) { assert(o_in); + + // check for native getiter (corresponds to __iter__) mp_obj_type_t *type = mp_obj_get_type(o_in); if (type->getiter != NULL) { mp_obj_t iter = type->getiter(o_in); - if (iter == MP_OBJ_NULL) { - goto not_iterable; + if (iter != MP_OBJ_NULL) { + return iter; } - return iter; + } + + // check for __getitem__ + mp_obj_t dest[2]; + mp_load_method_maybe(o_in, MP_QSTR___getitem__, dest); + if (dest[0] != MP_OBJ_NULL) { + // __getitem__ exists, create and return an iterator + return mp_obj_new_getitem_iter(dest); + } + + // object not iterable + if (MICROPY_ERROR_REPORTING == MICROPY_ERROR_REPORTING_TERSE) { + nlr_raise(mp_obj_new_exception_msg(&mp_type_TypeError, + "object not iterable")); } else { - // check for __iter__ method - mp_obj_t dest[2]; - mp_load_method_maybe(o_in, MP_QSTR___iter__, dest); - if (dest[0] != MP_OBJ_NULL) { - // __iter__ exists, call it and return its result - return mp_call_method_n_kw(0, 0, dest); - } else { - mp_load_method_maybe(o_in, MP_QSTR___getitem__, dest); - if (dest[0] != MP_OBJ_NULL) { - // __getitem__ exists, create an iterator - return mp_obj_new_getitem_iter(dest); - } else { - // object not iterable -not_iterable: - if (MICROPY_ERROR_REPORTING == MICROPY_ERROR_REPORTING_TERSE) { - nlr_raise(mp_obj_new_exception_msg(&mp_type_TypeError, - "object not iterable")); - } else { - nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_TypeError, - "'%s' object is not iterable", mp_obj_get_type_str(o_in))); - } - } - } + nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_TypeError, + "'%s' object is not iterable", mp_obj_get_type_str(o_in))); } } diff --git a/tests/basics/iter0.py b/tests/basics/iter0.py new file mode 100644 index 0000000000..6110e8fa58 --- /dev/null +++ b/tests/basics/iter0.py @@ -0,0 +1,6 @@ +# builtin type that is not iterable +try: + for i in 1: + pass +except TypeError: + print('TypeError') diff --git a/tests/basics/iter1.py b/tests/basics/iter1.py index c2ef86a635..5bd7f5090b 100644 --- a/tests/basics/iter1.py +++ b/tests/basics/iter1.py @@ -1,5 +1,14 @@ # test user defined iterators +# this class is not iterable +class NotIterable: + pass +try: + for i in NotIterable(): + pass +except TypeError: + print('TypeError') + class MyStopIteration(StopIteration): pass