diff --git a/py/repl.c b/py/repl.c index bca1be584e..4cafb88e2d 100644 --- a/py/repl.c +++ b/py/repl.c @@ -14,43 +14,66 @@ bool str_startswith_word(const char *str, const char *head) { return head[i] == '\0' && (str[i] == '\0' || !unichar_isalpha(str[i])); } -bool mp_repl_is_compound_stmt(const char *line) { - // compound if line starts with a certain keyword - if ( - str_startswith_word(line, "if") - || str_startswith_word(line, "while") - || str_startswith_word(line, "for") - || str_startswith_word(line, "try") - || str_startswith_word(line, "with") - || str_startswith_word(line, "def") - || str_startswith_word(line, "class") - || str_startswith_word(line, "@") - ) { - return true; +bool mp_repl_continue_with_input(const char *input) { + // check for blank input + if (input[0] == '\0') { + return false; } - // also "compound" if unmatched open bracket or triple quote + // check if input starts with a certain keyword + bool starts_with_compound_keyword = + input[0] == '@' + || str_startswith_word(input, "if") + || str_startswith_word(input, "while") + || str_startswith_word(input, "for") + || str_startswith_word(input, "try") + || str_startswith_word(input, "with") + || str_startswith_word(input, "def") + || str_startswith_word(input, "class") + ; + + // check for unmatched open bracket or triple quote + // TODO don't look at triple quotes inside single quotes int n_paren = 0; int n_brack = 0; int n_brace = 0; int in_triple_quote = 0; - for (const char *l = line; *l; l++) { - switch (*l) { + const char *i; + for (i = input; *i; i++) { + switch (*i) { case '(': n_paren += 1; break; case ')': n_paren -= 1; break; case '[': n_brack += 1; break; case ']': n_brack -= 1; break; case '{': n_brace += 1; break; case '}': n_brace -= 1; break; + case '\'': + if (in_triple_quote != '"' && i[1] == '\'' && i[2] == '\'') { + i += 2; + in_triple_quote = '\'' - in_triple_quote; + } + break; case '"': - if (l[1] == '"' && l[2] == '"') { - l += 2; - in_triple_quote = 1 - in_triple_quote; + if (in_triple_quote != '\'' && i[1] == '"' && i[2] == '"') { + i += 2; + in_triple_quote = '"' - in_triple_quote; } break; } } - return n_paren > 0 || n_brack > 0 || n_brace > 0 || in_triple_quote != 0; + + // continue if unmatched brackets or quotes + if (n_paren > 0 || n_brack > 0 || n_brace > 0 || in_triple_quote != 0) { + return true; + } + + // continue if compound keyword and last line was not empty + if (starts_with_compound_keyword && i[-1] != '\n') { + return true; + } + + // otherwise, don't continue + return false; } #endif // MICROPY_ENABLE_REPL_HELPERS diff --git a/py/repl.h b/py/repl.h index 23259fa90d..cba77aad0b 100644 --- a/py/repl.h +++ b/py/repl.h @@ -1,3 +1,3 @@ #if MICROPY_ENABLE_REPL_HELPERS -bool mp_repl_is_compound_stmt(const char *line); +bool mp_repl_continue_with_input(const char *input); #endif diff --git a/stm/pyexec.c b/stm/pyexec.c index f3dfd70aab..52a436218e 100644 --- a/stm/pyexec.c +++ b/stm/pyexec.c @@ -283,15 +283,12 @@ void pyexec_repl(void) { continue; } - if (mp_repl_is_compound_stmt(vstr_str(&line))) { - for (;;) { - vstr_add_char(&line, '\n'); - int len = vstr_len(&line); - int ret = readline(&line, "... "); - if (ret == VCP_CHAR_CTRL_D || vstr_len(&line) == len) { - // done entering compound statement - break; - } + while (mp_repl_continue_with_input(vstr_str(&line))) { + vstr_add_char(&line, '\n'); + int ret = readline(&line, "... "); + if (ret == VCP_CHAR_CTRL_D) { + // stop entering compound statement + break; } } diff --git a/stmhal/pyexec.c b/stmhal/pyexec.c index b960198ec7..298e58a5fd 100644 --- a/stmhal/pyexec.c +++ b/stmhal/pyexec.c @@ -204,15 +204,12 @@ friendly_repl_reset: continue; } - if (mp_repl_is_compound_stmt(vstr_str(&line))) { - for (;;) { - vstr_add_char(&line, '\n'); - int len = vstr_len(&line); - int ret = readline(&line, "... "); - if (ret == VCP_CHAR_CTRL_D || vstr_len(&line) == len) { - // done entering compound statement - break; - } + while (mp_repl_continue_with_input(vstr_str(&line))) { + vstr_add_char(&line, '\n'); + int ret = readline(&line, "... "); + if (ret == VCP_CHAR_CTRL_D) { + // stop entering compound statement + break; } } diff --git a/teensy/main.c b/teensy/main.c index bfb7413e7f..eb153c245d 100644 --- a/teensy/main.c +++ b/teensy/main.c @@ -399,15 +399,12 @@ void do_repl(void) { continue; } - if (mp_repl_is_compound_stmt(vstr_str(&line))) { - for (;;) { - vstr_add_char(&line, '\n'); - int len = vstr_len(&line); - int ret = readline(&line, "... "); - if (ret == 0 || vstr_len(&line) == len) { - // done entering compound statement - break; - } + while (mp_repl_continue_with_input(vstr_str(&line))) { + vstr_add_char(&line, '\n'); + int ret = readline(&line, "... "); + if (ret == 0) { + // stop entering compound statement + break; } } diff --git a/unix/main.c b/unix/main.c index 1549054f04..11df4cadf2 100644 --- a/unix/main.c +++ b/unix/main.c @@ -146,17 +146,15 @@ STATIC void do_repl(void) { // EOF return; } - if (mp_repl_is_compound_stmt(line)) { - for (;;) { - char *line2 = prompt("... "); - if (line2 == NULL || strlen(line2) == 0) { - break; - } - char *line3 = strjoin(line, '\n', line2); - free(line); - free(line2); - line = line3; + while (mp_repl_continue_with_input(line)) { + char *line2 = prompt("... "); + if (line2 == NULL) { + break; } + char *line3 = strjoin(line, '\n', line2); + free(line); + free(line2); + line = line3; } mp_lexer_t *lex = mp_lexer_new_from_str_len(MP_QSTR__lt_stdin_gt_, line, strlen(line), false); diff --git a/windows/main.c b/windows/main.c index 36d98f73d0..5ba21eef30 100644 --- a/windows/main.c +++ b/windows/main.c @@ -126,17 +126,15 @@ static void do_repl(void) { // EOF return; } - if (mp_repl_is_compound_stmt(line)) { - for (;;) { - char *line2 = prompt("... "); - if (line2 == NULL || strlen(line2) == 0) { - break; - } - char *line3 = str_join(line, '\n', line2); - free(line); - free(line2); - line = line3; + while (mp_repl_continue_with_input(line)) { + char *line2 = prompt("... "); + if (line2 == NULL) { + break; } + char *line3 = str_join(line, '\n', line2); + free(line); + free(line2); + line = line3; } mp_lexer_t *lex = mp_lexer_new_from_str_len(MP_QSTR__lt_stdin_gt_, line, strlen(line), false);