Implement * and *= for array.array

This commit is contained in:
Jeff Epler 2018-03-22 21:52:25 -05:00
parent 9ab39eb2d2
commit cdb83b18ec
2 changed files with 61 additions and 0 deletions

View File

@ -241,6 +241,39 @@ STATIC mp_obj_t array_unary_op(mp_unary_op_t op, mp_obj_t o_in) {
STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) { STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
mp_obj_array_t *lhs = MP_OBJ_TO_PTR(lhs_in); mp_obj_array_t *lhs = MP_OBJ_TO_PTR(lhs_in);
switch (op) { switch (op) {
case MP_BINARY_OP_MULTIPLY:
case MP_BINARY_OP_INPLACE_MULTIPLY: {
if (!MP_OBJ_IS_INT(rhs_in)) {
return MP_OBJ_NULL; // op not supported
}
mp_uint_t repeat = mp_obj_get_int(rhs_in);
bool inplace = (op == MP_BINARY_OP_INPLACE_MULTIPLY);
mp_buffer_info_t lhs_bufinfo;
array_get_buffer(lhs_in, &lhs_bufinfo, MP_BUFFER_READ);
mp_obj_array_t *res;
byte *ptr;
size_t orig_lhs_bufinfo_len = lhs_bufinfo.len;
if(inplace) {
res = lhs;
size_t item_sz = mp_binary_get_size('@', lhs->typecode, NULL);
lhs->items = m_renew(byte, lhs->items, (lhs->len + lhs->free) * item_sz, lhs->len * repeat * item_sz);
lhs->len = lhs->len * repeat;
lhs->free = 0;
if (!repeat)
return MP_OBJ_FROM_PTR(res);
repeat--;
ptr = (byte*)res->items + orig_lhs_bufinfo_len;
} else {
res = array_new(lhs_bufinfo.typecode, lhs->len * repeat);
ptr = (byte*)res->items;
}
if(orig_lhs_bufinfo_len) {
for(;repeat--; ptr += orig_lhs_bufinfo_len) {
memcpy(ptr, lhs_bufinfo.buf, orig_lhs_bufinfo_len);
}
}
return MP_OBJ_FROM_PTR(res);
}
case MP_BINARY_OP_ADD: { case MP_BINARY_OP_ADD: {
// allow to add anything that has the buffer protocol (extension to CPython) // allow to add anything that has the buffer protocol (extension to CPython)
mp_buffer_info_t lhs_bufinfo; mp_buffer_info_t lhs_bufinfo;

28
tests/basics/array_mul.py Normal file
View File

@ -0,0 +1,28 @@
try:
import array
except ImportError:
print("SKIP")
raise SystemExit
a1 = array.array('I', [1])
a2 = array.array('I', [2]) * 2
a3 = (a1 + a2)
print(a3)
a3 *= 5
print(a3)
a3 *= 0
print(a3)
a4 = a2 * 0
print(a4)
a4 *= 0
print(a4)
a4 = a4 * 2
print(a4)
a4 *= 2
print(a4)