Merge pull request #693 from jepler/issue236
Implement * and *= for array.array
This commit is contained in:
commit
8d376a3efb
@ -241,6 +241,39 @@ STATIC mp_obj_t array_unary_op(mp_unary_op_t op, mp_obj_t o_in) {
|
||||
STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
|
||||
mp_obj_array_t *lhs = MP_OBJ_TO_PTR(lhs_in);
|
||||
switch (op) {
|
||||
case MP_BINARY_OP_MULTIPLY:
|
||||
case MP_BINARY_OP_INPLACE_MULTIPLY: {
|
||||
if (!MP_OBJ_IS_INT(rhs_in)) {
|
||||
return MP_OBJ_NULL; // op not supported
|
||||
}
|
||||
mp_uint_t repeat = mp_obj_get_int(rhs_in);
|
||||
bool inplace = (op == MP_BINARY_OP_INPLACE_MULTIPLY);
|
||||
mp_buffer_info_t lhs_bufinfo;
|
||||
array_get_buffer(lhs_in, &lhs_bufinfo, MP_BUFFER_READ);
|
||||
mp_obj_array_t *res;
|
||||
byte *ptr;
|
||||
size_t orig_lhs_bufinfo_len = lhs_bufinfo.len;
|
||||
if(inplace) {
|
||||
res = lhs;
|
||||
size_t item_sz = mp_binary_get_size('@', lhs->typecode, NULL);
|
||||
lhs->items = m_renew(byte, lhs->items, (lhs->len + lhs->free) * item_sz, lhs->len * repeat * item_sz);
|
||||
lhs->len = lhs->len * repeat;
|
||||
lhs->free = 0;
|
||||
if (!repeat)
|
||||
return MP_OBJ_FROM_PTR(res);
|
||||
repeat--;
|
||||
ptr = (byte*)res->items + orig_lhs_bufinfo_len;
|
||||
} else {
|
||||
res = array_new(lhs_bufinfo.typecode, lhs->len * repeat);
|
||||
ptr = (byte*)res->items;
|
||||
}
|
||||
if(orig_lhs_bufinfo_len) {
|
||||
for(;repeat--; ptr += orig_lhs_bufinfo_len) {
|
||||
memcpy(ptr, lhs_bufinfo.buf, orig_lhs_bufinfo_len);
|
||||
}
|
||||
}
|
||||
return MP_OBJ_FROM_PTR(res);
|
||||
}
|
||||
case MP_BINARY_OP_ADD: {
|
||||
// allow to add anything that has the buffer protocol (extension to CPython)
|
||||
mp_buffer_info_t lhs_bufinfo;
|
||||
|
28
tests/basics/array_mul.py
Normal file
28
tests/basics/array_mul.py
Normal file
@ -0,0 +1,28 @@
|
||||
try:
|
||||
import array
|
||||
except ImportError:
|
||||
print("SKIP")
|
||||
raise SystemExit
|
||||
|
||||
a1 = array.array('I', [1])
|
||||
a2 = array.array('I', [2]) * 2
|
||||
a3 = (a1 + a2)
|
||||
print(a3)
|
||||
|
||||
a3 *= 5
|
||||
print(a3)
|
||||
|
||||
a3 *= 0
|
||||
print(a3)
|
||||
|
||||
a4 = a2 * 0
|
||||
print(a4)
|
||||
|
||||
a4 *= 0
|
||||
print(a4)
|
||||
|
||||
a4 = a4 * 2
|
||||
print(a4)
|
||||
|
||||
a4 *= 2
|
||||
print(a4)
|
Loading…
x
Reference in New Issue
Block a user