diff --git a/py/formatfloat.c b/py/formatfloat.c index 21ed2d5508..58a423e38c 100644 --- a/py/formatfloat.c +++ b/py/formatfloat.c @@ -170,10 +170,20 @@ int mp_format_float(FPTYPE f, char *buf, size_t buf_size, char fmt, int prec, ch if (fp_iszero(f)) { e = 0; - if (fmt == 'e') { - e_sign = '+'; - } else if (fmt == 'f') { + if (fmt == 'f') { + // Truncate precision to prevent buffer overflow + if (prec + 2 > buf_remaining) { + prec = buf_remaining - 2; + } num_digits = prec + 1; + } else { + // Truncate precision to prevent buffer overflow + if (prec + 6 > buf_remaining) { + prec = buf_remaining - 6; + } + if (fmt == 'e') { + e_sign = '+'; + } } } else if (fp_isless1(f)) { // We need to figure out what an integer digit will be used @@ -275,6 +285,12 @@ int mp_format_float(FPTYPE f, char *buf, size_t buf_size, char fmt, int prec, ch if (fmt == 'e' && prec > (buf_remaining - FPMIN_BUF_SIZE)) { prec = buf_remaining - FPMIN_BUF_SIZE; } + if (fmt == 'g'){ + // Truncate precision to prevent buffer overflow + if (prec + (FPMIN_BUF_SIZE - 1) > buf_remaining) { + prec = buf_remaining - (FPMIN_BUF_SIZE - 1); + } + } // If the user specified 'g' format, and e is < prec, then we'll switch // to the fixed format. @@ -378,6 +394,9 @@ int mp_format_float(FPTYPE f, char *buf, size_t buf_size, char fmt, int prec, ch } } + // verify that we did not overrun the input buffer so far + assert((size_t)(s + 1 - buf) <= buf_size); + if (org_fmt == 'g' && prec > 0) { // Remove trailing zeros and a trailing decimal point while (s[-1] == '0') { diff --git a/tests/float/string_format_modulo2.py b/tests/float/string_format_modulo2.py new file mode 100644 index 0000000000..d35f2a2c02 --- /dev/null +++ b/tests/float/string_format_modulo2.py @@ -0,0 +1,24 @@ +# test formatting floats with large precision, that it doesn't overflow the buffer + +def test(num, num_str): + if num == float('inf') or num == 0.0 and num_str != '0.0': + # skip numbers that overflow or underflow the FP precision + return + for kind in ('e', 'f', 'g'): + # check precision either side of the size of the buffer (32 bytes) + for prec in range(23, 36, 2): + fmt = '%.' + '%d' % prec + kind + s = fmt % num + check = abs(float(s) - num) + if num > 1: + check /= num + if check > 1e-6: + print('FAIL', num_str, fmt, s, len(s), check) + +# check pure zero +test(0.0, '0.0') + +# check most powers of 10, making sure to include exponents with 3 digits +for e in range(-101, 102): + num = pow(10, e) + test(num, '1e%d' % e)