py/objarray: Prohibit comparison of mismatching types.
Array equality is defined as each element being equal but to keep code size down MicroPython implements a binary comparison. This can only be used correctly for elements with the same binary layout though so turn it into an NotImplementedError when comparing types for which the binary comparison yielded incorrect results: types with different sizes, and floating point numbers because nan != nan.
This commit is contained in:
parent
6affcb0104
commit
57365d8557
@ -258,6 +258,16 @@ STATIC mp_obj_t array_unary_op(mp_unary_op_t op, mp_obj_t o_in) {
|
||||
}
|
||||
}
|
||||
|
||||
STATIC int typecode_for_comparison(int typecode) {
|
||||
if (typecode == BYTEARRAY_TYPECODE) {
|
||||
typecode = 'B';
|
||||
}
|
||||
if (typecode <= 'Z') {
|
||||
typecode += 32; // to lowercase
|
||||
}
|
||||
return typecode;
|
||||
}
|
||||
|
||||
STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
|
||||
mp_obj_array_t *lhs = MP_OBJ_TO_PTR(lhs_in);
|
||||
switch (op) {
|
||||
@ -319,7 +329,20 @@ STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs
|
||||
if (!mp_get_buffer(rhs_in, &rhs_bufinfo, MP_BUFFER_READ)) {
|
||||
return mp_const_false;
|
||||
}
|
||||
return mp_obj_new_bool(mp_seq_cmp_bytes(op, lhs_bufinfo.buf, lhs_bufinfo.len, rhs_bufinfo.buf, rhs_bufinfo.len));
|
||||
// mp_seq_cmp_bytes is used so only compatible representations can be correctly compared.
|
||||
// The type doesn't matter: array/bytearray/str/bytes all have the same buffer layout, so
|
||||
// just check if the typecodes are compatible; for testing equality the types should have the
|
||||
// same code except for signedness, and not be floating point because nan never equals nan.
|
||||
// Note that typecode_for_comparison always returns lowercase letters to save code size.
|
||||
// No need for (& TYPECODE_MASK) here: xxx_get_buffer already takes care of that.
|
||||
const int lhs_code = typecode_for_comparison(lhs_bufinfo.typecode);
|
||||
const int rhs_code = typecode_for_comparison(rhs_bufinfo.typecode);
|
||||
if (lhs_code == rhs_code && lhs_code != 'f' && lhs_code != 'd') {
|
||||
return mp_obj_new_bool(mp_seq_cmp_bytes(op, lhs_bufinfo.buf, lhs_bufinfo.len, rhs_bufinfo.buf, rhs_bufinfo.len));
|
||||
}
|
||||
// mp_obj_equal_not_equal treats returning MP_OBJ_NULL as 'fall back to pointer comparison'
|
||||
// for MP_BINARY_OP_EQUAL but that is incompatible with CPython.
|
||||
mp_raise_NotImplementedError(NULL);
|
||||
}
|
||||
|
||||
default:
|
||||
|
@ -41,14 +41,23 @@ except ValueError:
|
||||
# equality (CPython requires both sides are array)
|
||||
print(bytes(array.array('b', [0x61, 0x62, 0x63])) == b'abc')
|
||||
print(array.array('b', [0x61, 0x62, 0x63]) == b'abc')
|
||||
print(array.array('B', [0x61, 0x62, 0x63]) == b'abc')
|
||||
print(array.array('b', [0x61, 0x62, 0x63]) != b'abc')
|
||||
print(array.array('b', [0x61, 0x62, 0x63]) == b'xyz')
|
||||
print(array.array('b', [0x61, 0x62, 0x63]) != b'xyz')
|
||||
print(b'abc' == array.array('b', [0x61, 0x62, 0x63]))
|
||||
print(b'abc' == array.array('B', [0x61, 0x62, 0x63]))
|
||||
print(b'abc' != array.array('b', [0x61, 0x62, 0x63]))
|
||||
print(b'xyz' == array.array('b', [0x61, 0x62, 0x63]))
|
||||
print(b'xyz' != array.array('b', [0x61, 0x62, 0x63]))
|
||||
|
||||
compatible_typecodes = []
|
||||
for t in ["b", "h", "i", "l", "q"]:
|
||||
compatible_typecodes.append((t, t))
|
||||
compatible_typecodes.append((t, t.upper()))
|
||||
for a, b in compatible_typecodes:
|
||||
print(array.array(a, [1, 2]) == array.array(b, [1, 2]))
|
||||
|
||||
class X(array.array):
|
||||
pass
|
||||
|
||||
|
@ -17,3 +17,15 @@ print(a[0])
|
||||
a = array.array('P')
|
||||
a.append(1)
|
||||
print(a[0])
|
||||
|
||||
# comparison between mismatching binary layouts is not implemented
|
||||
typecodes = ["b", "h", "i", "l", "q", "P", "O", "S", "f", "d"]
|
||||
for a in typecodes:
|
||||
for b in typecodes:
|
||||
if a == b and a not in ["f", "d"]:
|
||||
continue
|
||||
try:
|
||||
array.array(a) == array.array(b)
|
||||
print('FAIL')
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
9
tests/cpydiff/module_array_comparison.py
Normal file
9
tests/cpydiff/module_array_comparison.py
Normal file
@ -0,0 +1,9 @@
|
||||
"""
|
||||
categories: Modules,array
|
||||
description: Comparison between different typecodes not supported
|
||||
cause: Code size
|
||||
workaround: Compare individual elements
|
||||
"""
|
||||
import array
|
||||
|
||||
array.array("b", [1, 2]) == array.array("i", [1, 2])
|
Loading…
Reference in New Issue
Block a user