py/objarray: Implement more/less comparisons for array.
This commit is contained in:
parent
57365d8557
commit
09be0c083c
|
@ -258,12 +258,13 @@ STATIC mp_obj_t array_unary_op(mp_unary_op_t op, mp_obj_t o_in) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
STATIC int typecode_for_comparison(int typecode) {
|
STATIC int typecode_for_comparison(int typecode, bool *is_unsigned) {
|
||||||
if (typecode == BYTEARRAY_TYPECODE) {
|
if (typecode == BYTEARRAY_TYPECODE) {
|
||||||
typecode = 'B';
|
typecode = 'B';
|
||||||
}
|
}
|
||||||
if (typecode <= 'Z') {
|
if (typecode <= 'Z') {
|
||||||
typecode += 32; // to lowercase
|
typecode += 32; // to lowercase
|
||||||
|
*is_unsigned = true;
|
||||||
}
|
}
|
||||||
return typecode;
|
return typecode;
|
||||||
}
|
}
|
||||||
|
@ -322,7 +323,11 @@ STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs
|
||||||
return mp_const_false;
|
return mp_const_false;
|
||||||
}
|
}
|
||||||
|
|
||||||
case MP_BINARY_OP_EQUAL: {
|
case MP_BINARY_OP_EQUAL:
|
||||||
|
case MP_BINARY_OP_LESS:
|
||||||
|
case MP_BINARY_OP_LESS_EQUAL:
|
||||||
|
case MP_BINARY_OP_MORE:
|
||||||
|
case MP_BINARY_OP_MORE_EQUAL: {
|
||||||
mp_buffer_info_t lhs_bufinfo;
|
mp_buffer_info_t lhs_bufinfo;
|
||||||
mp_buffer_info_t rhs_bufinfo;
|
mp_buffer_info_t rhs_bufinfo;
|
||||||
array_get_buffer(lhs_in, &lhs_bufinfo, MP_BUFFER_READ);
|
array_get_buffer(lhs_in, &lhs_bufinfo, MP_BUFFER_READ);
|
||||||
|
@ -333,11 +338,13 @@ STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs
|
||||||
// The type doesn't matter: array/bytearray/str/bytes all have the same buffer layout, so
|
// The type doesn't matter: array/bytearray/str/bytes all have the same buffer layout, so
|
||||||
// just check if the typecodes are compatible; for testing equality the types should have the
|
// just check if the typecodes are compatible; for testing equality the types should have the
|
||||||
// same code except for signedness, and not be floating point because nan never equals nan.
|
// same code except for signedness, and not be floating point because nan never equals nan.
|
||||||
|
// For > and < the types should be the same and unsigned.
|
||||||
// Note that typecode_for_comparison always returns lowercase letters to save code size.
|
// Note that typecode_for_comparison always returns lowercase letters to save code size.
|
||||||
// No need for (& TYPECODE_MASK) here: xxx_get_buffer already takes care of that.
|
// No need for (& TYPECODE_MASK) here: xxx_get_buffer already takes care of that.
|
||||||
const int lhs_code = typecode_for_comparison(lhs_bufinfo.typecode);
|
bool is_unsigned = false;
|
||||||
const int rhs_code = typecode_for_comparison(rhs_bufinfo.typecode);
|
const int lhs_code = typecode_for_comparison(lhs_bufinfo.typecode, &is_unsigned);
|
||||||
if (lhs_code == rhs_code && lhs_code != 'f' && lhs_code != 'd') {
|
const int rhs_code = typecode_for_comparison(rhs_bufinfo.typecode, &is_unsigned);
|
||||||
|
if (lhs_code == rhs_code && lhs_code != 'f' && lhs_code != 'd' && (op == MP_BINARY_OP_EQUAL || is_unsigned)) {
|
||||||
return mp_obj_new_bool(mp_seq_cmp_bytes(op, lhs_bufinfo.buf, lhs_bufinfo.len, rhs_bufinfo.buf, rhs_bufinfo.len));
|
return mp_obj_new_bool(mp_seq_cmp_bytes(op, lhs_bufinfo.buf, lhs_bufinfo.len, rhs_bufinfo.buf, rhs_bufinfo.len));
|
||||||
}
|
}
|
||||||
// mp_obj_equal_not_equal treats returning MP_OBJ_NULL as 'fall back to pointer comparison'
|
// mp_obj_equal_not_equal treats returning MP_OBJ_NULL as 'fall back to pointer comparison'
|
||||||
|
|
|
@ -66,3 +66,24 @@ print(X('b', [0x61, 0x62, 0x63]) == b'abc')
|
||||||
print(X('b', [0x61, 0x62, 0x63]) != b'abc')
|
print(X('b', [0x61, 0x62, 0x63]) != b'abc')
|
||||||
print(X('b', [0x61, 0x62, 0x63]) == array.array('b', [0x61, 0x62, 0x63]))
|
print(X('b', [0x61, 0x62, 0x63]) == array.array('b', [0x61, 0x62, 0x63]))
|
||||||
print(X('b', [0x61, 0x62, 0x63]) != array.array('b', [0x61, 0x62, 0x63]))
|
print(X('b', [0x61, 0x62, 0x63]) != array.array('b', [0x61, 0x62, 0x63]))
|
||||||
|
|
||||||
|
# other comparisons
|
||||||
|
for typecode in ["B", "H", "I", "L", "Q"]:
|
||||||
|
a = array.array(typecode, [1, 1])
|
||||||
|
print(a < a)
|
||||||
|
print(a <= a)
|
||||||
|
print(a > a)
|
||||||
|
print(a >= a)
|
||||||
|
|
||||||
|
al = array.array(typecode, [1, 0])
|
||||||
|
ab = array.array(typecode, [1, 2])
|
||||||
|
|
||||||
|
print(a < al)
|
||||||
|
print(a <= al)
|
||||||
|
print(a > al)
|
||||||
|
print(a >= al)
|
||||||
|
|
||||||
|
print(a < ab)
|
||||||
|
print(a <= ab)
|
||||||
|
print(a > ab)
|
||||||
|
print(a >= ab)
|
||||||
|
|
|
@ -27,6 +27,26 @@ print(bytearray([1]) == b"1")
|
||||||
print(b"1" == bytearray([1]))
|
print(b"1" == bytearray([1]))
|
||||||
print(bytearray() == bytearray())
|
print(bytearray() == bytearray())
|
||||||
|
|
||||||
|
b1 = bytearray([1, 2, 3])
|
||||||
|
b2 = bytearray([1, 2, 3])
|
||||||
|
b3 = bytearray([1, 3])
|
||||||
|
print(b1 == b2)
|
||||||
|
print(b2 != b3)
|
||||||
|
print(b1 <= b2)
|
||||||
|
print(b1 <= b3)
|
||||||
|
print(b1 < b3)
|
||||||
|
print(b1 >= b2)
|
||||||
|
print(b3 >= b2)
|
||||||
|
print(b3 > b2)
|
||||||
|
print(b1 != b2)
|
||||||
|
print(b2 == b3)
|
||||||
|
print(b1 > b2)
|
||||||
|
print(b1 > b3)
|
||||||
|
print(b1 >= b3)
|
||||||
|
print(b1 < b2)
|
||||||
|
print(b3 < b2)
|
||||||
|
print(b3 <= b2)
|
||||||
|
|
||||||
# comparison with other type should return False
|
# comparison with other type should return False
|
||||||
print(bytearray() == 1)
|
print(bytearray() == 1)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue