diff --git a/ports/espressif/common-hal/socketpool/Socket.c b/ports/espressif/common-hal/socketpool/Socket.c index 5f34d674b2..edaed97598 100644 --- a/ports/espressif/common-hal/socketpool/Socket.c +++ b/ports/espressif/common-hal/socketpool/Socket.c @@ -31,6 +31,8 @@ #include "py/mperrno.h" #include "py/runtime.h" #include "shared-bindings/socketpool/SocketPool.h" +#include "shared-bindings/ssl/SSLSocket.h" +#include "common-hal/ssl/SSLSocket.h" #include "supervisor/port.h" #include "supervisor/shared/tick.h" #include "supervisor/workflow.h" @@ -44,7 +46,7 @@ StackType_t socket_select_stack[2 * configMINIMAL_STACK_SIZE]; STATIC int open_socket_fds[CONFIG_LWIP_MAX_SOCKETS]; -STATIC bool user_socket[CONFIG_LWIP_MAX_SOCKETS]; +STATIC socketpool_socket_obj_t *user_socket[CONFIG_LWIP_MAX_SOCKETS]; StaticTask_t socket_select_task_handle; STATIC int socket_change_fd = -1; @@ -117,7 +119,7 @@ void socket_user_reset(void) { for (size_t i = 0; i < MP_ARRAY_SIZE(open_socket_fds); i++) { open_socket_fds[i] = -1; - user_socket[i] = false; + user_socket[i] = NULL; } socket_change_fd = eventfd(0, 0); // Run this at the same priority as CP so that the web workflow background task can be @@ -134,12 +136,13 @@ void socket_user_reset(void) { for (size_t i = 0; i < MP_ARRAY_SIZE(open_socket_fds); i++) { if (open_socket_fds[i] >= 0 && user_socket[i]) { + common_hal_socketpool_socket_close(user_socket[i]); int num = open_socket_fds[i]; // Close automatically clears socket handle lwip_shutdown(num, SHUT_RDWR); lwip_close(num); open_socket_fds[i] = -1; - user_socket[i] = false; + user_socket[i] = NULL; } } } @@ -171,10 +174,10 @@ STATIC void unregister_open_socket(int fd) { } } -STATIC void mark_user_socket(int fd) { +STATIC void mark_user_socket(int fd, socketpool_socket_obj_t *obj) { for (size_t i = 0; i < MP_ARRAY_SIZE(open_socket_fds); i++) { if (open_socket_fds[i] == fd) { - user_socket[i] = true; + user_socket[i] = obj; return; } } @@ -236,7 +239,7 @@ socketpool_socket_obj_t *common_hal_socketpool_socket(socketpool_socketpool_obj_ if (!socketpool_socket(self, family, type, sock)) { mp_raise_RuntimeError(translate("Out of sockets")); } - mark_user_socket(sock->num); + mark_user_socket(sock->num, sock); return sock; } @@ -292,12 +295,12 @@ int socketpool_socket_accept(socketpool_socket_obj_t *self, uint8_t *ip, uint32_ socketpool_socket_obj_t *common_hal_socketpool_socket_accept(socketpool_socket_obj_t *self, uint8_t *ip, uint32_t *port) { + socketpool_socket_obj_t *sock = m_new_obj_with_finaliser(socketpool_socket_obj_t); int newsoc = socketpool_socket_accept(self, ip, port, NULL); if (newsoc > 0) { - mark_user_socket(newsoc); // Create the socket - socketpool_socket_obj_t *sock = m_new_obj_with_finaliser(socketpool_socket_obj_t); + mark_user_socket(newsoc, sock); sock->base.type = &socketpool_socket_type; sock->num = newsoc; sock->pool = self->pool; @@ -338,6 +341,12 @@ bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t *self, } void socketpool_socket_close(socketpool_socket_obj_t *self) { + if (self->ssl_socket) { + ssl_sslsocket_obj_t *ssl_socket = self->ssl_socket; + self->ssl_socket = NULL; + common_hal_ssl_sslsocket_close(ssl_socket); + return; + } self->connected = false; if (self->num >= 0) { lwip_shutdown(self->num, SHUT_RDWR); diff --git a/ports/espressif/common-hal/socketpool/Socket.h b/ports/espressif/common-hal/socketpool/Socket.h index b91419807c..b189889cfd 100644 --- a/ports/espressif/common-hal/socketpool/Socket.h +++ b/ports/espressif/common-hal/socketpool/Socket.h @@ -42,6 +42,7 @@ typedef struct { int ipproto; bool connected; socketpool_socketpool_obj_t *pool; + struct ssl_sslsocket_obj *ssl_socket; mp_uint_t timeout_ms; } socketpool_socket_obj_t; diff --git a/ports/espressif/common-hal/ssl/SSLContext.c b/ports/espressif/common-hal/ssl/SSLContext.c index 386986e6be..1923a9c2a9 100644 --- a/ports/espressif/common-hal/ssl/SSLContext.c +++ b/ports/espressif/common-hal/ssl/SSLContext.c @@ -48,6 +48,7 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t sock->base.type = &ssl_sslsocket_type; sock->ssl_context = self; sock->sock = socket; + socket->ssl_socket = sock; // Create a copy of the ESP-TLS config object and store the server hostname // Note that ESP-TLS will use common_name for both SNI and verification diff --git a/ports/espressif/common-hal/ssl/SSLSocket.h b/ports/espressif/common-hal/ssl/SSLSocket.h index 097f19857b..6b65a56223 100644 --- a/ports/espressif/common-hal/ssl/SSLSocket.h +++ b/ports/espressif/common-hal/ssl/SSLSocket.h @@ -34,7 +34,7 @@ #include "components/esp-tls/esp_tls.h" -typedef struct { +typedef struct ssl_sslsocket_obj { mp_obj_base_t base; socketpool_socket_obj_t *sock; esp_tls_t *tls;