From cdb83b18ece63b13d6796e315f0737f833c1ee16 Mon Sep 17 00:00:00 2001 From: Jeff Epler Date: Thu, 22 Mar 2018 21:52:25 -0500 Subject: [PATCH] Implement * and *= for array.array --- py/objarray.c | 33 +++++++++++++++++++++++++++++++++ tests/basics/array_mul.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) create mode 100644 tests/basics/array_mul.py diff --git a/py/objarray.c b/py/objarray.c index a1a979b56f..eb053bd8a9 100644 --- a/py/objarray.c +++ b/py/objarray.c @@ -241,6 +241,39 @@ STATIC mp_obj_t array_unary_op(mp_unary_op_t op, mp_obj_t o_in) { STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) { mp_obj_array_t *lhs = MP_OBJ_TO_PTR(lhs_in); switch (op) { + case MP_BINARY_OP_MULTIPLY: + case MP_BINARY_OP_INPLACE_MULTIPLY: { + if (!MP_OBJ_IS_INT(rhs_in)) { + return MP_OBJ_NULL; // op not supported + } + mp_uint_t repeat = mp_obj_get_int(rhs_in); + bool inplace = (op == MP_BINARY_OP_INPLACE_MULTIPLY); + mp_buffer_info_t lhs_bufinfo; + array_get_buffer(lhs_in, &lhs_bufinfo, MP_BUFFER_READ); + mp_obj_array_t *res; + byte *ptr; + size_t orig_lhs_bufinfo_len = lhs_bufinfo.len; + if(inplace) { + res = lhs; + size_t item_sz = mp_binary_get_size('@', lhs->typecode, NULL); + lhs->items = m_renew(byte, lhs->items, (lhs->len + lhs->free) * item_sz, lhs->len * repeat * item_sz); + lhs->len = lhs->len * repeat; + lhs->free = 0; + if (!repeat) + return MP_OBJ_FROM_PTR(res); + repeat--; + ptr = (byte*)res->items + orig_lhs_bufinfo_len; + } else { + res = array_new(lhs_bufinfo.typecode, lhs->len * repeat); + ptr = (byte*)res->items; + } + if(orig_lhs_bufinfo_len) { + for(;repeat--; ptr += orig_lhs_bufinfo_len) { + memcpy(ptr, lhs_bufinfo.buf, orig_lhs_bufinfo_len); + } + } + return MP_OBJ_FROM_PTR(res); + } case MP_BINARY_OP_ADD: { // allow to add anything that has the buffer protocol (extension to CPython) mp_buffer_info_t lhs_bufinfo; diff --git a/tests/basics/array_mul.py b/tests/basics/array_mul.py new file mode 100644 index 0000000000..bb5f3aa6b1 --- /dev/null +++ b/tests/basics/array_mul.py @@ -0,0 +1,28 @@ +try: + import array +except ImportError: + print("SKIP") + raise SystemExit + +a1 = array.array('I', [1]) +a2 = array.array('I', [2]) * 2 +a3 = (a1 + a2) +print(a3) + +a3 *= 5 +print(a3) + +a3 *= 0 +print(a3) + +a4 = a2 * 0 +print(a4) + +a4 *= 0 +print(a4) + +a4 = a4 * 2 +print(a4) + +a4 *= 2 +print(a4)