bytes: Implement comparison and other binary operations.

Should support everything supported by strings.
This commit is contained in:
Paul Sokolovsky 2014-05-10 04:26:10 +03:00
parent 070c78af5d
commit 7b0f9a7d9b
3 changed files with 65 additions and 9 deletions

View File

@ -251,11 +251,13 @@ STATIC const byte *find_subbytes(const byte *haystack, machine_uint_t hlen, cons
STATIC mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { STATIC mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
GET_STR_DATA_LEN(lhs_in, lhs_data, lhs_len); GET_STR_DATA_LEN(lhs_in, lhs_data, lhs_len);
mp_obj_type_t *lhs_type = mp_obj_get_type(lhs_in);
mp_obj_type_t *rhs_type = mp_obj_get_type(rhs_in);
switch (op) { switch (op) {
case MP_BINARY_OP_ADD: case MP_BINARY_OP_ADD:
case MP_BINARY_OP_INPLACE_ADD: case MP_BINARY_OP_INPLACE_ADD:
if (MP_OBJ_IS_STR(rhs_in)) { if (lhs_type == rhs_type) {
// add 2 strings // add 2 strings or bytes
GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len); GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len);
int alloc_len = lhs_len + rhs_len; int alloc_len = lhs_len + rhs_len;
@ -270,7 +272,7 @@ STATIC mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
// code for non-qstr // code for non-qstr
byte *data; byte *data;
mp_obj_t s = mp_obj_str_builder_start(mp_obj_get_type(lhs_in), alloc_len, &data); mp_obj_t s = mp_obj_str_builder_start(lhs_type, alloc_len, &data);
memcpy(data, lhs_data, lhs_len); memcpy(data, lhs_data, lhs_len);
memcpy(data + lhs_len, rhs_data, rhs_len); memcpy(data + lhs_len, rhs_data, rhs_len);
return mp_obj_str_builder_end(s); return mp_obj_str_builder_end(s);
@ -279,7 +281,7 @@ STATIC mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
case MP_BINARY_OP_IN: case MP_BINARY_OP_IN:
/* NOTE `a in b` is `b.__contains__(a)` */ /* NOTE `a in b` is `b.__contains__(a)` */
if (MP_OBJ_IS_STR(rhs_in)) { if (lhs_type == rhs_type) {
GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len); GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len);
return MP_BOOL(find_subbytes(lhs_data, lhs_len, rhs_data, rhs_len, 1) != NULL); return MP_BOOL(find_subbytes(lhs_data, lhs_len, rhs_data, rhs_len, 1) != NULL);
} }
@ -292,7 +294,7 @@ STATIC mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
} }
int n = MP_OBJ_SMALL_INT_VALUE(rhs_in); int n = MP_OBJ_SMALL_INT_VALUE(rhs_in);
byte *data; byte *data;
mp_obj_t s = mp_obj_str_builder_start(mp_obj_get_type(lhs_in), lhs_len * n, &data); mp_obj_t s = mp_obj_str_builder_start(lhs_type, lhs_len * n, &data);
mp_seq_multiply(lhs_data, sizeof(*lhs_data), lhs_len, n, data); mp_seq_multiply(lhs_data, sizeof(*lhs_data), lhs_len, n, data);
return mp_obj_str_builder_end(s); return mp_obj_str_builder_end(s);
} }
@ -310,14 +312,13 @@ STATIC mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
return str_modulo_format(lhs_in, n_args, args); return str_modulo_format(lhs_in, n_args, args);
} }
// These 2 are never passed here, dealt with as a special case in mp_binary_op(). //case MP_BINARY_OP_NOT_EQUAL: // This is never passed here
//case MP_BINARY_OP_EQUAL: case MP_BINARY_OP_EQUAL: // This will be passed only for bytes, str is dealt with in mp_obj_equal()
//case MP_BINARY_OP_NOT_EQUAL:
case MP_BINARY_OP_LESS: case MP_BINARY_OP_LESS:
case MP_BINARY_OP_LESS_EQUAL: case MP_BINARY_OP_LESS_EQUAL:
case MP_BINARY_OP_MORE: case MP_BINARY_OP_MORE:
case MP_BINARY_OP_MORE_EQUAL: case MP_BINARY_OP_MORE_EQUAL:
if (MP_OBJ_IS_STR(rhs_in)) { if (lhs_type == rhs_type) {
GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len); GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len);
return MP_BOOL(mp_seq_cmp_bytes(op, lhs_data, lhs_len, rhs_data, rhs_len)); return MP_BOOL(mp_seq_cmp_bytes(op, lhs_data, lhs_len, rhs_data, rhs_len));
} }

View File

@ -83,6 +83,10 @@ bool m_seq_get_fast_slice_indexes(machine_uint_t len, mp_obj_t slice, machine_ui
// Special-case comparison function for sequences of bytes // Special-case comparison function for sequences of bytes
// Don't pass MP_BINARY_OP_NOT_EQUAL here // Don't pass MP_BINARY_OP_NOT_EQUAL here
bool mp_seq_cmp_bytes(int op, const byte *data1, uint len1, const byte *data2, uint len2) { bool mp_seq_cmp_bytes(int op, const byte *data1, uint len1, const byte *data2, uint len2) {
if (op == MP_BINARY_OP_EQUAL && len1 != len2) {
return false;
}
// Let's deal only with > & >= // Let's deal only with > & >=
if (op == MP_BINARY_OP_LESS || op == MP_BINARY_OP_LESS_EQUAL) { if (op == MP_BINARY_OP_LESS || op == MP_BINARY_OP_LESS_EQUAL) {
SWAP(const byte*, data1, data2); SWAP(const byte*, data1, data2);

View File

@ -0,0 +1,51 @@
print(b"" == b"")
print(b"" > b"")
print(b"" < b"")
print(b"" == b"1")
print(b"1" == b"")
print("==")
print(b"" > b"1")
print(b"1" > b"")
print(b"" < b"1")
print(b"1" < b"")
print(b"" >= b"1")
print(b"1" >= b"")
print(b"" <= b"1")
print(b"1" <= b"")
print(b"1" == b"1")
print(b"1" != b"1")
print(b"1" == b"2")
print(b"1" == b"10")
print(b"1" > b"1")
print(b"1" > b"2")
print(b"2" > b"1")
print(b"10" > b"1")
print(b"1/" > b"1")
print(b"1" > b"10")
print(b"1" > b"1/")
print(b"1" < b"1")
print(b"2" < b"1")
print(b"1" < b"2")
print(b"1" < b"10")
print(b"1" < b"1/")
print(b"10" < b"1")
print(b"1/" < b"1")
print(b"1" >= b"1")
print(b"1" >= b"2")
print(b"2" >= b"1")
print(b"10" >= b"1")
print(b"1/" >= b"1")
print(b"1" >= b"10")
print(b"1" >= b"1/")
print(b"1" <= b"1")
print(b"2" <= b"1")
print(b"1" <= b"2")
print(b"1" <= b"10")
print(b"1" <= b"1/")
print(b"10" <= b"1")
print(b"1/" <= b"1")