diff --git a/ports/esp32/modsocket.c b/ports/esp32/modsocket.c index 94cf3cc57c..0268baa18c 100644 --- a/ports/esp32/modsocket.c +++ b/ports/esp32/modsocket.c @@ -63,6 +63,7 @@ typedef struct _socket_obj_t { uint8_t domain; uint8_t type; uint8_t proto; + bool peer_closed; unsigned int retries; #if MICROPY_PY_USOCKET_EVENTS mp_obj_t events_callback; @@ -233,6 +234,7 @@ STATIC mp_obj_t socket_accept(const mp_obj_t arg0) { sock->domain = self->domain; sock->type = self->type; sock->proto = self->proto; + sock->peer_closed = false; _socket_settimeout(sock, UINT64_MAX); // make the return value @@ -354,23 +356,57 @@ STATIC mp_obj_t socket_setblocking(const mp_obj_t arg0, const mp_obj_t arg1) { } STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_setblocking_obj, socket_setblocking); +// XXX this can end up waiting a very long time if the content is dribbled in one character +// at a time, as the timeout resets each time a recvfrom succeeds ... this is probably not +// good behaviour. +STATIC mp_uint_t _socket_read_data(mp_obj_t self_in, void *buf, size_t size, + struct sockaddr *from, socklen_t *from_len, int *errcode) { + socket_obj_t *sock = MP_OBJ_TO_PTR(self_in); + + // If the peer closed the connection then the lwIP socket API will only return "0" once + // from lwip_recvfrom_r and then block on subsequent calls. To emulate POSIX behaviour, + // which continues to return "0" for each call on a closed socket, we set a flag when + // the peer closed the socket. + if (sock->peer_closed) { + return 0; + } + + // XXX Would be nicer to use RTC to handle timeouts + for (int i = 0; i <= sock->retries; ++i) { + MP_THREAD_GIL_EXIT(); + int r = lwip_recvfrom_r(sock->fd, buf, size, 0, from, from_len); + MP_THREAD_GIL_ENTER(); + if (r == 0) { + sock->peer_closed = true; + } + if (r >= 0) { + return r; + } + if (errno != EWOULDBLOCK) { + *errcode = errno; + return MP_STREAM_ERROR; + } + check_for_exceptions(); + } + + *errcode = sock->retries == 0 ? MP_EWOULDBLOCK : MP_ETIMEDOUT; + return MP_STREAM_ERROR; +} + mp_obj_t _socket_recvfrom(mp_obj_t self_in, mp_obj_t len_in, struct sockaddr *from, socklen_t *from_len) { - socket_obj_t *sock = MP_OBJ_TO_PTR(self_in); size_t len = mp_obj_get_int(len_in); vstr_t vstr; vstr_init_len(&vstr, len); - // XXX Would be nicer to use RTC to handle timeouts - for (int i=0; i<=sock->retries; i++) { - MP_THREAD_GIL_EXIT(); - int r = lwip_recvfrom_r(sock->fd, vstr.buf, len, 0, from, from_len); - MP_THREAD_GIL_ENTER(); - if (r >= 0) { vstr.len = r; return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr); } - if (errno != EWOULDBLOCK) exception_from_errno(errno); - check_for_exceptions(); + int errcode; + mp_uint_t ret = _socket_read_data(self_in, vstr.buf, len, from, from_len, &errcode); + if (ret == MP_STREAM_ERROR) { + exception_from_errno(errcode); } - mp_raise_OSError(MP_ETIMEDOUT); + + vstr.len = ret; + return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr); } STATIC mp_obj_t socket_recv(mp_obj_t self_in, mp_obj_t len_in) { @@ -468,25 +504,8 @@ STATIC mp_obj_t socket_makefile(size_t n_args, const mp_obj_t *args) { } STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(socket_makefile_obj, 1, 3, socket_makefile); - -// XXX this can end up waiting a very long time if the content is dribbled in one character -// at a time, as the timeout resets each time a recvfrom succeeds ... this is probably not -// good behaviour. - STATIC mp_uint_t socket_stream_read(mp_obj_t self_in, void *buf, mp_uint_t size, int *errcode) { - socket_obj_t *sock = self_in; - - // XXX Would be nicer to use RTC to handle timeouts - for (int i=0; i<=sock->retries; i++) { - MP_THREAD_GIL_EXIT(); - int r = lwip_recvfrom_r(sock->fd, buf, size, 0, NULL, NULL); - MP_THREAD_GIL_ENTER(); - if (r >= 0) return r; - if (r < 0 && errno != EWOULDBLOCK) { *errcode = errno; return MP_STREAM_ERROR; } - check_for_exceptions(); - } - *errcode = sock->retries == 0 ? MP_EWOULDBLOCK : MP_ETIMEDOUT; - return MP_STREAM_ERROR; + return _socket_read_data(self_in, buf, size, NULL, NULL, errcode); } STATIC mp_uint_t socket_stream_write(mp_obj_t self_in, const void *buf, mp_uint_t size, int *errcode) { @@ -592,6 +611,7 @@ STATIC mp_obj_t get_socket(size_t n_args, const mp_obj_t *args) { sock->domain = AF_INET; sock->type = SOCK_STREAM; sock->proto = 0; + sock->peer_closed = false; if (n_args > 0) { sock->domain = mp_obj_get_int(args[0]); if (n_args > 1) {