From 37a8c1c57518ac35724ee026cdf6f2f2871aaea2 Mon Sep 17 00:00:00 2001 From: Lucian Copeland Date: Wed, 13 Jan 2021 19:05:07 -0500 Subject: [PATCH] Complete non-blocking implementations, add socket close checking --- ports/esp32s2/common-hal/socketpool/Socket.c | 162 +++++++++++++++---- ports/esp32s2/common-hal/socketpool/Socket.h | 3 + ports/esp32s2/supervisor/port.c | 5 + shared-bindings/socketpool/Socket.c | 9 +- shared-bindings/socketpool/Socket.h | 2 +- 5 files changed, 143 insertions(+), 38 deletions(-) diff --git a/ports/esp32s2/common-hal/socketpool/Socket.c b/ports/esp32s2/common-hal/socketpool/Socket.c index 757156e08d..743414ae72 100644 --- a/ports/esp32s2/common-hal/socketpool/Socket.c +++ b/ports/esp32s2/common-hal/socketpool/Socket.c @@ -38,6 +38,32 @@ #include "components/lwip/lwip/src/include/lwip/sys.h" #include "components/lwip/lwip/src/include/lwip/netdb.h" +STATIC socketpool_socket_obj_t * open_socket_handles[CONFIG_LWIP_MAX_SOCKETS]; // 4 on the wrover/wroom + +void socket_reset(void) { + for (size_t i = 0; i < MP_ARRAY_SIZE(open_socket_handles); i++) { + if (open_socket_handles[i]) { + if (open_socket_handles[i]->num > 0) { + common_hal_socketpool_socket_close(open_socket_handles[i]); + open_socket_handles[i] = NULL; + } else { + // accidentally got a TCP socket in here, or something. + open_socket_handles[i] = NULL; + } + } + } +} + +bool register_open_socket(socketpool_socket_obj_t* self) { + for (size_t i = 0; i < MP_ARRAY_SIZE(open_socket_handles); i++) { + if (open_socket_handles[i] == NULL) { + open_socket_handles[i] = self; + return true; + } + } + return false; +} + STATIC void _lazy_init_LWIP(socketpool_socket_obj_t* self) { if (self->num != -1) { return; //safe to call on existing socket @@ -47,7 +73,7 @@ STATIC void _lazy_init_LWIP(socketpool_socket_obj_t* self) { } int socknum = -1; socknum = lwip_socket(self->family, self->type, self->ipproto); - if (socknum < 0) { + if (socknum < 0 || !register_open_socket(self)) { mp_raise_RuntimeError(translate("Out of sockets")); } self->num = socknum; @@ -78,34 +104,74 @@ bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t* self, bind_addr.sin_family = AF_INET; bind_addr.sin_port = htons(port); - return lwip_bind(self->num, (struct sockaddr *)&bind_addr, sizeof(bind_addr)) == 0; + int opt = 1; + int err = lwip_setsockopt(self->num, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); + if (err != 0) { + mp_raise_RuntimeError(translate("Issue setting SO_REUSEADDR")); + } + int result = lwip_bind(self->num, (struct sockaddr *)&bind_addr, sizeof(bind_addr)) == 0; + return result; } bool common_hal_socketpool_socket_listen(socketpool_socket_obj_t* self, int backlog) { return lwip_listen(self->num, backlog) == 0; } -int common_hal_socketpool_socket_accept(socketpool_socket_obj_t* self, +socketpool_socket_obj_t* common_hal_socketpool_socket_accept(socketpool_socket_obj_t* self, uint8_t* ip, uint *port) { struct sockaddr_in accept_addr; socklen_t socklen = sizeof(accept_addr); - int newsoc = -1; - //(self->timeout_ms == 0 || supervisor_ticks_ms64() - start_ticks <= self->timeout_ms) - while ((newsoc == -1) && !mp_hal_is_interrupted() ) { + bool timed_out = false; + uint64_t start_ticks = supervisor_ticks_ms64(); + + if (self->timeout_ms != (uint)-1) { + mp_printf(&mp_plat_print, "will timeout"); + } else { + mp_printf(&mp_plat_print, "won't timeout"); + } + + // Allow timeouts and interrupts + while (newsoc == -1 && + !timed_out && + !mp_hal_is_interrupted()) { + if (self->timeout_ms != (uint)-1) { + timed_out = supervisor_ticks_ms64() - start_ticks >= self->timeout_ms; + } RUN_BACKGROUND_TASKS; newsoc = lwip_accept(self->num, (struct sockaddr *)&accept_addr, &socklen); + // In non-blocking mode, fail instead of looping + if (newsoc == -1 && self->timeout_ms == 0) { + mp_raise_OSError(MP_EAGAIN); + } } - mp_printf(&mp_plat_print, "oldsoc:%d newsoc:%d\n",self->num, newsoc); - memcpy((void *)ip, (void*)&accept_addr.sin_addr.s_addr, sizeof(accept_addr.sin_addr.s_addr)); - *port = accept_addr.sin_port; + if (!timed_out) { + // harmless on failure but avoiding memcpy is faster + memcpy((void *)ip, (void*)&accept_addr.sin_addr.s_addr, sizeof(accept_addr.sin_addr.s_addr)); + *port = accept_addr.sin_port; + } else { + mp_raise_OSError(ETIMEDOUT); + } if (newsoc > 0) { + // Create the socket + socketpool_socket_obj_t *sock = m_new_obj_with_finaliser(socketpool_socket_obj_t); + sock->base.type = &socketpool_socket_type; + sock->num = newsoc; + sock->tls = NULL; + sock->ssl_context = NULL; + sock->pool = self->pool; + + if (!register_open_socket(sock)) { + mp_raise_OSError(MP_EBADF); + } + lwip_fcntl(newsoc, F_SETFL, O_NONBLOCK); - return newsoc; + return sock; } else { - return 0; + mp_raise_OSError(MP_EBADF); + return NULL; } } @@ -158,9 +224,10 @@ bool common_hal_socketpool_socket_get_connected(socketpool_socket_obj_t* self) { } mp_uint_t common_hal_socketpool_socket_send(socketpool_socket_obj_t* self, const uint8_t* buf, mp_uint_t len) { - size_t sent = -1; + int sent = -1; if (self->num != -1) { // LWIP Socket + // TODO: deal with potential failure/add timeout? sent = lwip_send(self->num, buf, len, 0); } else if (self->tls != NULL) { // TLS Socket @@ -174,15 +241,27 @@ mp_uint_t common_hal_socketpool_socket_send(socketpool_socket_obj_t* self, const } mp_uint_t common_hal_socketpool_socket_recv_into(socketpool_socket_obj_t* self, const uint8_t* buf, mp_uint_t len) { - size_t received = 0; + int received = 0; + bool timed_out = false; if (self->num != -1) { // LWIP Socket - mp_printf(&mp_plat_print, "lwip_recv:\n"); - - received = lwip_recv(self->num, (void*) buf, len - 1, 0); - mp_printf(&mp_plat_print, "received:%d\n",received); + uint64_t start_ticks = supervisor_ticks_ms64(); + received = -1; + while (received == -1 && + !timed_out && + !mp_hal_is_interrupted()) { + if (self->timeout_ms != (uint)-1) { + timed_out = supervisor_ticks_ms64() - start_ticks >= self->timeout_ms; + } + RUN_BACKGROUND_TASKS; + received = lwip_recv(self->num, (void*) buf, len - 1, 0); + // In non-blocking mode, fail instead of looping + if (received == -1 && self->timeout_ms == 0) { + mp_raise_OSError(MP_EAGAIN); + } + } } else if (self->tls != NULL) { // TLS Socket int status = 0; @@ -194,8 +273,11 @@ mp_uint_t common_hal_socketpool_socket_recv_into(socketpool_socket_obj_t* self, } while (received == 0 && status >= 0 && - (self->timeout_ms == 0 || supervisor_ticks_ms64() - start_ticks <= self->timeout_ms) && + !timed_out && !mp_hal_is_interrupted()) { + if (self->timeout_ms != (uint)-1) { + timed_out = self->timeout_ms == 0 || supervisor_ticks_ms64() - start_ticks >= self->timeout_ms; + } RUN_BACKGROUND_TASKS; size_t available = esp_tls_get_bytes_avail(self->tls); if (available == 0) { @@ -219,11 +301,13 @@ mp_uint_t common_hal_socketpool_socket_recv_into(socketpool_socket_obj_t* self, } } } + } else { + // Socket does not have a valid descriptor of either type + mp_raise_OSError(MP_EBADF); } - if (received == 0) { - // socket closed - mp_raise_OSError(0); + if (timed_out) { + mp_raise_OSError(ETIMEDOUT); } return received; } @@ -270,19 +354,39 @@ mp_uint_t common_hal_socketpool_socket_recvfrom_into(socketpool_socket_obj_t* se struct sockaddr_in source_addr; socklen_t socklen = sizeof(source_addr); - mp_printf(&mp_plat_print, "recvfrom_into\n"); - int bytes_received = lwip_recvfrom(self->num, buf, len - 1, 0, (struct sockaddr *)&source_addr, &socklen); - mp_printf(&mp_plat_print, "received:%d\n",bytes_received); - memcpy((void *)ip, (void*)&source_addr.sin_addr.s_addr, sizeof(source_addr.sin_addr.s_addr)); - *port = source_addr.sin_port; + // LWIP Socket + uint64_t start_ticks = supervisor_ticks_ms64(); + int received = -1; + bool timed_out = false; + while (received == -1 && + !timed_out && + !mp_hal_is_interrupted()) { + if (self->timeout_ms != (uint)-1) { + timed_out = supervisor_ticks_ms64() - start_ticks >= self->timeout_ms; + } + RUN_BACKGROUND_TASKS; + received = lwip_recvfrom(self->num, buf, len - 1, 0, (struct sockaddr *)&source_addr, &socklen); - if (bytes_received < 0) { + // In non-blocking mode, fail instead of looping + if (received == -1 && self->timeout_ms == 0) { + mp_raise_OSError(MP_EAGAIN); + } + } + + if (!timed_out) { + memcpy((void *)ip, (void*)&source_addr.sin_addr.s_addr, sizeof(source_addr.sin_addr.s_addr)); + *port = source_addr.sin_port; + } else { + mp_raise_OSError(ETIMEDOUT); + } + + if (received < 0) { mp_raise_BrokenPipeError(); return 0; } else { - buf[bytes_received] = 0; // Null-terminate whatever we received - return bytes_received; + buf[received] = 0; // Null-terminate whatever we received + return received; } } diff --git a/ports/esp32s2/common-hal/socketpool/Socket.h b/ports/esp32s2/common-hal/socketpool/Socket.h index 3cffeeb6a1..4e6cfa5ef6 100644 --- a/ports/esp32s2/common-hal/socketpool/Socket.h +++ b/ports/esp32s2/common-hal/socketpool/Socket.h @@ -47,4 +47,7 @@ typedef struct { mp_uint_t timeout_ms; } socketpool_socket_obj_t; +void socket_reset(void); +bool register_open_socket(socketpool_socket_obj_t* self); + #endif // MICROPY_INCLUDED_ESP32S2_COMMON_HAL_SOCKETPOOL_SOCKET_H diff --git a/ports/esp32s2/supervisor/port.c b/ports/esp32s2/supervisor/port.c index 7037b4f051..1b123d19d1 100644 --- a/ports/esp32s2/supervisor/port.c +++ b/ports/esp32s2/supervisor/port.c @@ -46,6 +46,7 @@ #include "common-hal/pwmio/PWMOut.h" #include "common-hal/touchio/TouchIn.h" #include "common-hal/watchdog/WatchDogTimer.h" +#include "common-hal/socketpool/Socket.h" #include "common-hal/wifi/__init__.h" #include "supervisor/memory.h" #include "supervisor/shared/tick.h" @@ -174,6 +175,10 @@ void reset_port(void) { #if CIRCUITPY_WIFI wifi_reset(); #endif + +#if CIRCUITPY_SOCKETPOOL + socket_reset(); +#endif } void reset_to_bootloader(void) { diff --git a/shared-bindings/socketpool/Socket.c b/shared-bindings/socketpool/Socket.c index a92e508b61..0074173405 100644 --- a/shared-bindings/socketpool/Socket.c +++ b/shared-bindings/socketpool/Socket.c @@ -115,14 +115,7 @@ STATIC mp_obj_t socketpool_socket_accept(mp_obj_t self_in) { uint8_t ip[4]; uint port; - int socknum = common_hal_socketpool_socket_accept(self, ip, &port); - - socketpool_socket_obj_t *sock = m_new_obj_with_finaliser(socketpool_socket_obj_t); - sock->base.type = &socketpool_socket_type; - sock->num = socknum; - sock->tls = NULL; - sock->ssl_context = NULL; - sock->pool = self->pool; + socketpool_socket_obj_t * sock = common_hal_socketpool_socket_accept(self, ip, &port); mp_obj_t tuple_contents[2]; tuple_contents[0] = MP_OBJ_FROM_PTR(sock); diff --git a/shared-bindings/socketpool/Socket.h b/shared-bindings/socketpool/Socket.h index e2ea32d392..b5dceb50f4 100644 --- a/shared-bindings/socketpool/Socket.h +++ b/shared-bindings/socketpool/Socket.h @@ -35,7 +35,7 @@ void common_hal_socketpool_socket_settimeout(socketpool_socket_obj_t* self, mp_u bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t* self, const char* host, size_t hostlen, uint8_t port); bool common_hal_socketpool_socket_listen(socketpool_socket_obj_t* self, int backlog); -int common_hal_socketpool_socket_accept(socketpool_socket_obj_t* self, uint8_t* ip, uint *port); +socketpool_socket_obj_t * common_hal_socketpool_socket_accept(socketpool_socket_obj_t* self, uint8_t* ip, uint *port); bool common_hal_socketpool_socket_connect(socketpool_socket_obj_t* self, const char* host, size_t hostlen, mp_int_t port); mp_uint_t common_hal_socketpool_socket_send(socketpool_socket_obj_t* self, const uint8_t* buf, mp_uint_t len);