py: str.split: handle non-default separator.

This commit is contained in:
Damien George 2014-04-06 11:11:15 +01:00
parent 36dd19ae27
commit deed087e2c
2 changed files with 77 additions and 24 deletions

View File

@ -33,6 +33,7 @@ const mp_obj_t mp_const_empty_bytes;
STATIC mp_obj_t mp_obj_new_str_iterator(mp_obj_t str);
STATIC mp_obj_t mp_obj_new_bytes_iterator(mp_obj_t str);
STATIC mp_obj_t str_new(const mp_obj_type_t *type, const byte* data, uint len);
STATIC void bad_implicit_conversion(mp_obj_t self_in) __attribute__((noreturn));
/******************************************************************************/
/* str */
@ -367,38 +368,71 @@ bad_arg:
#define is_ws(c) ((c) == ' ' || (c) == '\t')
STATIC mp_obj_t str_split(uint n_args, const mp_obj_t *args) {
int splits = -1;
machine_int_t splits = -1;
mp_obj_t sep = mp_const_none;
if (n_args > 1) {
sep = args[1];
if (n_args > 2) {
splits = MP_OBJ_SMALL_INT_VALUE(args[2]);
splits = mp_obj_get_int(args[2]);
}
}
assert(sep == mp_const_none);
(void)sep; // unused; to hush compiler warning
mp_obj_t res = mp_obj_new_list(0, NULL);
GET_STR_DATA_LEN(args[0], s, len);
const byte *top = s + len;
const byte *start;
// Initial whitespace is not counted as split, so we pre-do it
while (s < top && is_ws(*s)) s++;
while (s < top && splits != 0) {
start = s;
while (s < top && !is_ws(*s)) s++;
mp_obj_list_append(res, mp_obj_new_str(start, s - start, false));
if (s >= top) {
break;
}
if (sep == mp_const_none) {
// sep not given, so separate on whitespace
// Initial whitespace is not counted as split, so we pre-do it
while (s < top && is_ws(*s)) s++;
if (splits > 0) {
splits--;
while (s < top && splits != 0) {
const byte *start = s;
while (s < top && !is_ws(*s)) s++;
mp_obj_list_append(res, mp_obj_new_str(start, s - start, false));
if (s >= top) {
break;
}
while (s < top && is_ws(*s)) s++;
if (splits > 0) {
splits--;
}
}
}
if (s < top) {
mp_obj_list_append(res, mp_obj_new_str(s, top - s, false));
if (s < top) {
mp_obj_list_append(res, mp_obj_new_str(s, top - s, false));
}
} else {
// sep given
uint sep_len;
const char *sep_str = mp_obj_str_get_data(sep, &sep_len);
if (sep_len == 0) {
nlr_raise(mp_obj_new_exception_msg(&mp_type_ValueError, "empty separator"));
}
for (;;) {
const byte *start = s;
for (;;) {
if (splits == 0 || s + sep_len > top) {
s = top;
break;
} else if (memcmp(s, sep_str, sep_len) == 0) {
break;
}
s++;
}
mp_obj_list_append(res, mp_obj_new_str(start, s - start, false));
if (s >= top) {
break;
}
s += sep_len;
if (splits > 0) {
splits--;
}
}
}
return res;
@ -1052,7 +1086,7 @@ STATIC mp_obj_t str_modulo_format(mp_obj_t pattern, uint n_args, const mp_obj_t
}
pfenv_print_int(&pfenv_vstr, arg_as_int(arg), 1, 16, 'A', flags, fill, width);
break;
default:
nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_ValueError,
"unsupported format character '%c' (0x%x) at index %d",
@ -1191,8 +1225,7 @@ STATIC mp_obj_t str_count(uint n_args, const mp_obj_t *args) {
STATIC mp_obj_t str_partitioner(mp_obj_t self_in, mp_obj_t arg, machine_int_t direction) {
assert(MP_OBJ_IS_STR(self_in));
if (!MP_OBJ_IS_STR(arg)) {
nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_TypeError,
"Can't convert '%s' object to str implicitly", mp_obj_get_type_str(arg)));
bad_implicit_conversion(arg);
}
GET_STR_DATA_LEN(self_in, str, str_len);
@ -1365,8 +1398,7 @@ bool mp_obj_str_equal(mp_obj_t s1, mp_obj_t s2) {
}
}
void bad_implicit_conversion(mp_obj_t self_in) __attribute__((noreturn));
void bad_implicit_conversion(mp_obj_t self_in) {
STATIC void bad_implicit_conversion(mp_obj_t self_in) {
nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_TypeError, "Can't convert '%s' object to str implicitly", mp_obj_get_type_str(self_in)));
}

View File

@ -1,3 +1,4 @@
# default separator (whitespace)
print("a b".split())
print(" a b ".split(None))
print(" a b ".split(None, 1))
@ -5,3 +6,23 @@ print(" a b ".split(None, 2))
print(" a b c ".split(None, 1))
print(" a b c ".split(None, 0))
print(" a b c ".split(None, -1))
# empty separator should fail
try:
"abc".split('')
except ValueError:
print("ValueError")
# non-empty separator
print("abc".split("a"))
print("abc".split("b"))
print("abc".split("c"))
print("abc".split("z"))
print("abc".split("ab"))
print("abc".split("bc"))
print("abc".split("abc"))
print("abc".split("abcd"))
print("abcabc".split("bc"))
print("abcabc".split("bc", 0))
print("abcabc".split("bc", 1))
print("abcabc".split("bc", 2))