objstr: Make .join() support bytes.

This commit is contained in:
Paul Sokolovsky 2014-05-11 21:13:01 +03:00
parent 7e7940c39d
commit 5e5d69b35e
2 changed files with 18 additions and 4 deletions

View File

@ -357,7 +357,8 @@ STATIC mp_obj_t str_subscr(mp_obj_t self_in, mp_obj_t index, mp_obj_t value) {
} }
STATIC mp_obj_t str_join(mp_obj_t self_in, mp_obj_t arg) { STATIC mp_obj_t str_join(mp_obj_t self_in, mp_obj_t arg) {
assert(MP_OBJ_IS_STR(self_in)); assert(is_str_or_bytes(self_in));
const mp_obj_type_t *self_type = mp_obj_get_type(self_in);
// get separation string // get separation string
GET_STR_DATA_LEN(self_in, sep_str, sep_len); GET_STR_DATA_LEN(self_in, sep_str, sep_len);
@ -379,8 +380,9 @@ STATIC mp_obj_t str_join(mp_obj_t self_in, mp_obj_t arg) {
// count required length // count required length
int required_len = 0; int required_len = 0;
for (int i = 0; i < seq_len; i++) { for (int i = 0; i < seq_len; i++) {
if (!MP_OBJ_IS_STR(seq_items[i])) { if (mp_obj_get_type(seq_items[i]) != self_type) {
nlr_raise(mp_obj_new_exception_msg(&mp_type_TypeError, "join expected a list of str's")); nlr_raise(mp_obj_new_exception_msg(&mp_type_TypeError,
"join expects a list of str/bytes objects consistent with self object"));
} }
if (i > 0) { if (i > 0) {
required_len += sep_len; required_len += sep_len;
@ -391,7 +393,7 @@ STATIC mp_obj_t str_join(mp_obj_t self_in, mp_obj_t arg) {
// make joined string // make joined string
byte *data; byte *data;
mp_obj_t joined_str = mp_obj_str_builder_start(mp_obj_get_type(self_in), required_len, &data); mp_obj_t joined_str = mp_obj_str_builder_start(self_type, required_len, &data);
for (int i = 0; i < seq_len; i++) { for (int i = 0; i < seq_len; i++) {
if (i > 0) { if (i > 0) {
memcpy(data, sep_str, sep_len); memcpy(data, sep_str, sep_len);

View File

@ -10,3 +10,15 @@ print(''.join(''))
print(''.join('abc')) print(''.join('abc'))
print(','.join('abc')) print(','.join('abc'))
print(','.join('abc' for i in range(5))) print(','.join('abc' for i in range(5)))
print(b','.join([b'abc', b'123']))
try:
print(b','.join(['abc', b'123']))
except TypeError:
print("TypeError")
try:
print(','.join([b'abc', b'123']))
except TypeError:
print("TypeError")