diff --git a/py/builtinimport.c b/py/builtinimport.c index 50780c01d6..c6910138f0 100644 --- a/py/builtinimport.c +++ b/py/builtinimport.c @@ -163,16 +163,6 @@ void do_load(mp_obj_t module_obj, vstr_t *file) { mp_globals_set(old_globals); } -// TODO: Move to objdict? -STATIC inline mp_obj_t mp_obj_dict_get(mp_obj_t dict_in, mp_obj_t key) { - mp_obj_dict_t *dict = dict_in; - mp_map_elem_t *elem = mp_map_lookup(&dict->map, key, MP_MAP_LOOKUP); - if (elem == NULL) { - return elem; - } - return elem->value; -} - mp_obj_t mp_builtin___import__(uint n_args, mp_obj_t *args) { #if DEBUG_PRINT printf("__import__:\n"); diff --git a/py/obj.h b/py/obj.h index 7c83715111..aa78b2a22d 100644 --- a/py/obj.h +++ b/py/obj.h @@ -505,6 +505,7 @@ typedef struct _mp_obj_dict_t { } mp_obj_dict_t; void mp_obj_dict_init(mp_obj_dict_t *dict, int n_args); uint mp_obj_dict_len(mp_obj_t self_in); +mp_obj_t mp_obj_dict_get(mp_obj_t self_in, mp_obj_t index); mp_obj_t mp_obj_dict_store(mp_obj_t self_in, mp_obj_t key, mp_obj_t value); mp_obj_t mp_obj_dict_delete(mp_obj_t self_in, mp_obj_t key); mp_map_t *mp_obj_dict_get_map(mp_obj_t self_in); diff --git a/py/objdict.c b/py/objdict.c index 8a0a08772a..696aad80f5 100644 --- a/py/objdict.c +++ b/py/objdict.c @@ -117,6 +117,17 @@ STATIC mp_obj_t dict_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { } } +// TODO: Make sure this is inlined in dict_subscr() below. +mp_obj_t mp_obj_dict_get(mp_obj_t self_in, mp_obj_t index) { + mp_obj_dict_t *self = self_in; + mp_map_elem_t *elem = mp_map_lookup(&self->map, index, MP_MAP_LOOKUP); + if (elem == NULL) { + nlr_raise(mp_obj_new_exception_msg(&mp_type_KeyError, "")); + } else { + return elem->value; + } +} + STATIC mp_obj_t dict_subscr(mp_obj_t self_in, mp_obj_t index, mp_obj_t value) { if (value == MP_OBJ_NULL) { // delete diff --git a/py/objstr.c b/py/objstr.c index 012d6404f7..d932824456 100644 --- a/py/objstr.c +++ b/py/objstr.c @@ -40,7 +40,7 @@ #include "objstr.h" #include "objlist.h" -STATIC mp_obj_t str_modulo_format(mp_obj_t pattern, uint n_args, const mp_obj_t *args); +STATIC mp_obj_t str_modulo_format(mp_obj_t pattern, uint n_args, const mp_obj_t *args, mp_obj_t dict); const mp_obj_t mp_const_empty_bytes; // use this macro to extract the string hash @@ -307,14 +307,19 @@ STATIC mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { case MP_BINARY_OP_MODULO: { mp_obj_t *args; uint n_args; + mp_obj_t dict = MP_OBJ_NULL; if (MP_OBJ_IS_TYPE(rhs_in, &mp_type_tuple)) { // TODO: Support tuple subclasses? mp_obj_tuple_get(rhs_in, &n_args, &args); + } else if (MP_OBJ_IS_TYPE(rhs_in, &mp_type_dict)) { + args = NULL; + n_args = 0; + dict = rhs_in; } else { args = &rhs_in; n_args = 1; } - return str_modulo_format(lhs_in, n_args, args); + return str_modulo_format(lhs_in, n_args, args, dict); } //case MP_BINARY_OP_NOT_EQUAL: // This is never passed here @@ -1125,7 +1130,7 @@ mp_obj_t mp_obj_str_format(uint n_args, const mp_obj_t *args) { return s; } -STATIC mp_obj_t str_modulo_format(mp_obj_t pattern, uint n_args, const mp_obj_t *args) { +STATIC mp_obj_t str_modulo_format(mp_obj_t pattern, uint n_args, const mp_obj_t *args, mp_obj_t dict) { assert(MP_OBJ_IS_STR(pattern)); GET_STR_DATA_LEN(pattern, str, len); @@ -1137,6 +1142,7 @@ STATIC mp_obj_t str_modulo_format(mp_obj_t pattern, uint n_args, const mp_obj_t pfenv_vstr.print_strn = pfenv_vstr_add_strn; for (const byte *top = str + len; str < top; str++) { + mp_obj_t arg = MP_OBJ_NULL; if (*str != '%') { vstr_add_char(vstr, *str); continue; @@ -1148,9 +1154,21 @@ STATIC mp_obj_t str_modulo_format(mp_obj_t pattern, uint n_args, const mp_obj_t vstr_add_char(vstr, '%'); continue; } - if (arg_i >= n_args) { - nlr_raise(mp_obj_new_exception_msg(&mp_type_TypeError, "not enough arguments for format string")); + + // Dictionary value lookup + if (*str == '(') { + const byte *key = ++str; + while (*str != ')') { + if (str >= top) { + nlr_raise(mp_obj_new_exception_msg(&mp_type_ValueError, "incomplete format key")); + } + ++str; + } + mp_obj_t k_obj = mp_obj_new_str((const char*)key, str - key, true); + arg = mp_obj_dict_get(dict, k_obj); + str++; } + int flags = 0; char fill = ' '; bool alt = false; @@ -1169,6 +1187,9 @@ STATIC mp_obj_t str_modulo_format(mp_obj_t pattern, uint n_args, const mp_obj_t int width = 0; if (str < top) { if (*str == '*') { + if (arg_i >= n_args) { + goto not_enough_args; + } width = mp_obj_get_int(args[arg_i++]); str++; } else { @@ -1181,6 +1202,9 @@ STATIC mp_obj_t str_modulo_format(mp_obj_t pattern, uint n_args, const mp_obj_t if (str < top && *str == '.') { if (++str < top) { if (*str == '*') { + if (arg_i >= n_args) { + goto not_enough_args; + } prec = mp_obj_get_int(args[arg_i++]); str++; } else { @@ -1195,7 +1219,15 @@ STATIC mp_obj_t str_modulo_format(mp_obj_t pattern, uint n_args, const mp_obj_t if (str >= top) { nlr_raise(mp_obj_new_exception_msg(&mp_type_ValueError, "incomplete format")); } - mp_obj_t arg = args[arg_i]; + + // Tuple value lookup + if (arg == MP_OBJ_NULL) { + if (arg_i >= n_args) { +not_enough_args: + nlr_raise(mp_obj_new_exception_msg(&mp_type_TypeError, "not enough arguments for format string")); + } + arg = args[arg_i++]; + } switch (*str) { case 'c': if (MP_OBJ_IS_STR(arg)) { @@ -1284,7 +1316,6 @@ STATIC mp_obj_t str_modulo_format(mp_obj_t pattern, uint n_args, const mp_obj_t "unsupported format character '%c' (0x%x) at index %d", *str, *str, str - start_str)); } - arg_i++; } if (arg_i != n_args) { diff --git a/tests/basics/string-format-modulo.py b/tests/basics/string-format-modulo.py index 0e2c1d1096..c8fdc06f68 100644 --- a/tests/basics/string-format-modulo.py +++ b/tests/basics/string-format-modulo.py @@ -48,3 +48,29 @@ print("%#X" % 18) print("%#6o" % 18) print("%#6x" % 18) print("%#06x" % 18) + +print("%*d" % (5, 10)) +print("%*.*d" % (2, 2, 20)) +# TODO: Formatted incorrectly +#print("%*.*d" % (5, 8, 20)) + +# Cases when "*" used and there's not enough values total +try: + print("%*s" % 5) +except TypeError: + print("TypeError") +try: + print("%*.*s" % (1, 15)) +except TypeError: + print("TypeError") + +print("%(foo)s" % {"foo": "bar", "baz": False}) +try: + print("%(foo)s" % {}) +except KeyError: + print("KeyError") +# Using in "*" with dict got to fail +try: + print("%(foo)*s" % {"foo": "bar"}) +except TypeError: + print("TypeError")