Merge pull request #7291 from jepler/issue6502

Ensure orderly shutdown of ssl socket
This commit is contained in:
Dan Halbert 2022-12-07 19:12:12 -05:00 committed by GitHub
commit 44af05283a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 23 additions and 13 deletions

View File

@ -31,6 +31,8 @@
#include "py/mperrno.h"
#include "py/runtime.h"
#include "shared-bindings/socketpool/SocketPool.h"
#include "shared-bindings/ssl/SSLSocket.h"
#include "common-hal/ssl/SSLSocket.h"
#include "supervisor/port.h"
#include "supervisor/shared/tick.h"
#include "supervisor/workflow.h"
@ -44,7 +46,7 @@
StackType_t socket_select_stack[2 * configMINIMAL_STACK_SIZE];
STATIC int open_socket_fds[CONFIG_LWIP_MAX_SOCKETS];
STATIC bool user_socket[CONFIG_LWIP_MAX_SOCKETS];
STATIC socketpool_socket_obj_t *user_socket[CONFIG_LWIP_MAX_SOCKETS];
StaticTask_t socket_select_task_handle;
STATIC int socket_change_fd = -1;
@ -117,7 +119,7 @@ void socket_user_reset(void) {
for (size_t i = 0; i < MP_ARRAY_SIZE(open_socket_fds); i++) {
open_socket_fds[i] = -1;
user_socket[i] = false;
user_socket[i] = NULL;
}
socket_change_fd = eventfd(0, 0);
// Run this at the same priority as CP so that the web workflow background task can be
@ -134,12 +136,13 @@ void socket_user_reset(void) {
for (size_t i = 0; i < MP_ARRAY_SIZE(open_socket_fds); i++) {
if (open_socket_fds[i] >= 0 && user_socket[i]) {
common_hal_socketpool_socket_close(user_socket[i]);
int num = open_socket_fds[i];
// Close automatically clears socket handle
lwip_shutdown(num, SHUT_RDWR);
lwip_close(num);
open_socket_fds[i] = -1;
user_socket[i] = false;
user_socket[i] = NULL;
}
}
}
@ -171,10 +174,10 @@ STATIC void unregister_open_socket(int fd) {
}
}
STATIC void mark_user_socket(int fd) {
STATIC void mark_user_socket(int fd, socketpool_socket_obj_t *obj) {
for (size_t i = 0; i < MP_ARRAY_SIZE(open_socket_fds); i++) {
if (open_socket_fds[i] == fd) {
user_socket[i] = true;
user_socket[i] = obj;
return;
}
}
@ -236,7 +239,7 @@ socketpool_socket_obj_t *common_hal_socketpool_socket(socketpool_socketpool_obj_
if (!socketpool_socket(self, family, type, sock)) {
mp_raise_RuntimeError(translate("Out of sockets"));
}
mark_user_socket(sock->num);
mark_user_socket(sock->num, sock);
return sock;
}
@ -293,12 +296,12 @@ int socketpool_socket_accept(socketpool_socket_obj_t *self, uint8_t *ip, uint32_
socketpool_socket_obj_t *common_hal_socketpool_socket_accept(socketpool_socket_obj_t *self,
uint8_t *ip, uint32_t *port) {
socketpool_socket_obj_t *sock = m_new_obj_with_finaliser(socketpool_socket_obj_t);
int newsoc = socketpool_socket_accept(self, ip, port, NULL);
if (newsoc > 0) {
mark_user_socket(newsoc);
// Create the socket
socketpool_socket_obj_t *sock = m_new_obj_with_finaliser(socketpool_socket_obj_t);
mark_user_socket(newsoc, sock);
sock->base.type = &socketpool_socket_type;
sock->num = newsoc;
sock->pool = self->pool;
@ -338,6 +341,12 @@ bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t *self,
}
void socketpool_socket_close(socketpool_socket_obj_t *self) {
if (self->ssl_socket) {
ssl_sslsocket_obj_t *ssl_socket = self->ssl_socket;
self->ssl_socket = NULL;
common_hal_ssl_sslsocket_close(ssl_socket);
return;
}
self->connected = false;
if (self->num >= 0) {
lwip_shutdown(self->num, SHUT_RDWR);

View File

@ -24,8 +24,7 @@
* THE SOFTWARE.
*/
#ifndef MICROPY_INCLUDED_ESPRESSIF_COMMON_HAL_SOCKETPOOL_SOCKET_H
#define MICROPY_INCLUDED_ESPRESSIF_COMMON_HAL_SOCKETPOOL_SOCKET_H
#pragma once
#include "py/obj.h"
@ -34,6 +33,8 @@
#include "components/esp-tls/esp_tls.h"
typedef struct ssl_sslsocket_obj ssl_sslsocket_obj_t;
typedef struct {
mp_obj_base_t base;
int num;
@ -42,9 +43,8 @@ typedef struct {
int ipproto;
bool connected;
socketpool_socketpool_obj_t *pool;
ssl_sslsocket_obj_t *ssl_socket;
mp_uint_t timeout_ms;
} socketpool_socket_obj_t;
void socket_user_reset(void);
#endif // MICROPY_INCLUDED_ESPRESSIF_COMMON_HAL_SOCKETPOOL_SOCKET_H

View File

@ -48,6 +48,7 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
sock->base.type = &ssl_sslsocket_type;
sock->ssl_context = self;
sock->sock = socket;
socket->ssl_socket = sock;
// Create a copy of the ESP-TLS config object and store the server hostname
// Note that ESP-TLS will use common_name for both SNI and verification

View File

@ -34,7 +34,7 @@
#include "components/esp-tls/esp_tls.h"
typedef struct {
typedef struct ssl_sslsocket_obj {
mp_obj_base_t base;
socketpool_socket_obj_t *sock;
esp_tls_t *tls;