From 9f34ec78c44b6037e2f05bf4ebb44130b35e1425 Mon Sep 17 00:00:00 2001 From: Lucian Copeland Date: Fri, 15 Jan 2021 12:01:15 -0500 Subject: [PATCH 1/4] Separate Socket and SSLSocket, add LWIP connect --- locale/circuitpython.pot | 24 +- ports/esp32s2/common-hal/socketpool/Socket.c | 380 +++++++----------- ports/esp32s2/common-hal/socketpool/Socket.h | 2 - .../common-hal/socketpool/SocketPool.c | 22 +- ports/esp32s2/common-hal/ssl/SSLContext.c | 24 +- ports/esp32s2/common-hal/ssl/SSLSocket.c | 169 ++++++++ ports/esp32s2/common-hal/ssl/SSLSocket.h | 44 ++ py/circuitpy_defns.mk | 1 + shared-bindings/socketpool/Socket.c | 207 +++++----- shared-bindings/socketpool/Socket.h | 23 +- shared-bindings/socketpool/SocketPool.c | 3 +- shared-bindings/ssl/SSLContext.h | 3 +- shared-bindings/ssl/SSLSocket.c | 320 +++++++++++++++ shared-bindings/ssl/SSLSocket.h | 46 +++ 14 files changed, 886 insertions(+), 382 deletions(-) create mode 100644 ports/esp32s2/common-hal/ssl/SSLSocket.c create mode 100644 ports/esp32s2/common-hal/ssl/SSLSocket.h create mode 100644 shared-bindings/ssl/SSLSocket.c create mode 100644 shared-bindings/ssl/SSLSocket.h diff --git a/locale/circuitpython.pot b/locale/circuitpython.pot index 3ffc31cc6c..c88a32a051 100644 --- a/locale/circuitpython.pot +++ b/locale/circuitpython.pot @@ -622,6 +622,10 @@ msgstr "" msgid "Cannot reset into bootloader because no bootloader is present." msgstr "" +#: ports/esp32s2/common-hal/socketpool/Socket.c +msgid "Cannot set socket options" +msgstr "" + #: shared-bindings/digitalio/DigitalInOut.c msgid "Cannot set value when direction is input." msgstr "" @@ -854,7 +858,7 @@ msgstr "" msgid "Error in regex" msgstr "" -#: shared-bindings/socketpool/Socket.c +#: shared-bindings/socketpool/Socket.c shared-bindings/ssl/SSLSocket.c msgid "Error: Failure to bind" msgstr "" @@ -912,7 +916,7 @@ msgstr "" msgid "FFT is implemented for linear arrays only" msgstr "" -#: ports/esp32s2/common-hal/socketpool/Socket.c +#: ports/esp32s2/common-hal/ssl/SSLSocket.c msgid "Failed SSL handshake" msgstr "" @@ -1248,7 +1252,7 @@ msgstr "" msgid "Invalid size" msgstr "" -#: ports/esp32s2/common-hal/socketpool/Socket.c +#: ports/esp32s2/common-hal/ssl/SSLContext.c msgid "Invalid socket for TLS" msgstr "" @@ -1256,10 +1260,6 @@ msgstr "" msgid "Invalid state" msgstr "" -#: ports/esp32s2/common-hal/socketpool/Socket.c -msgid "Invalid use of TLS Socket" -msgstr "" - #: shared-bindings/audiomixer/Mixer.c msgid "Invalid voice" msgstr "" @@ -1276,10 +1276,6 @@ msgstr "" msgid "Invalid word/bit length" msgstr "" -#: ports/esp32s2/common-hal/socketpool/Socket.c -msgid "Issue setting SO_REUSEADDR" -msgstr "" - #: shared-bindings/aesio/aes.c msgid "Key must be 16, 24, or 32 bytes long" msgstr "" @@ -1562,7 +1558,7 @@ msgstr "" msgid "Out of memory" msgstr "" -#: ports/esp32s2/common-hal/socketpool/Socket.c +#: ports/esp32s2/common-hal/socketpool/SocketPool.c msgid "Out of sockets" msgstr "" @@ -2027,7 +2023,7 @@ msgstr "" msgid "Unexpected nrfx uuid type" msgstr "" -#: ports/esp32s2/common-hal/socketpool/Socket.c +#: ports/esp32s2/common-hal/ssl/SSLSocket.c #, c-format msgid "Unhandled ESP TLS error %d %d %x %d" msgstr "" @@ -2324,7 +2320,7 @@ msgstr "" msgid "buffer too small" msgstr "" -#: shared-bindings/socketpool/Socket.c +#: shared-bindings/socketpool/Socket.c shared-bindings/ssl/SSLSocket.c msgid "buffer too small for requested bytes" msgstr "" diff --git a/ports/esp32s2/common-hal/socketpool/Socket.c b/ports/esp32s2/common-hal/socketpool/Socket.c index 69ef41c6ec..4cbf4cff26 100644 --- a/ports/esp32s2/common-hal/socketpool/Socket.c +++ b/ports/esp32s2/common-hal/socketpool/Socket.c @@ -38,7 +38,7 @@ #include "components/lwip/lwip/src/include/lwip/sys.h" #include "components/lwip/lwip/src/include/lwip/netdb.h" -STATIC socketpool_socket_obj_t * open_socket_handles[CONFIG_LWIP_MAX_SOCKETS]; // 4 on the wrover/wroom +STATIC socketpool_socket_obj_t * open_socket_handles[CONFIG_LWIP_MAX_SOCKETS]; void socket_reset(void) { for (size_t i = 0; i < MP_ARRAY_SIZE(open_socket_handles); i++) { @@ -47,7 +47,6 @@ void socket_reset(void) { common_hal_socketpool_socket_close(open_socket_handles[i]); open_socket_handles[i] = NULL; } else { - // accidentally got a TCP socket in here, or something. open_socket_handles[i] = NULL; } } @@ -64,59 +63,6 @@ bool register_open_socket(socketpool_socket_obj_t* self) { return false; } -STATIC void _lazy_init_LWIP(socketpool_socket_obj_t* self) { - if (self->num != -1) { - return; //safe to call on existing socket - } - if (self->tls != NULL) { - mp_raise_RuntimeError(translate("Invalid use of TLS Socket")); - } - int socknum = -1; - socknum = lwip_socket(self->family, self->type, self->ipproto); - if (socknum < 0 || !register_open_socket(self)) { - mp_raise_RuntimeError(translate("Out of sockets")); - } - self->num = socknum; - lwip_fcntl(socknum, F_SETFL, O_NONBLOCK); -} - -STATIC void _lazy_init_TLS(socketpool_socket_obj_t* self) { - if (self->type != SOCK_STREAM || self->num != -1) { - mp_raise_RuntimeError(translate("Invalid socket for TLS")); - } - esp_tls_t* tls_handle = esp_tls_init(); - if (tls_handle == NULL) { - mp_raise_espidf_MemoryError(); - } - self->tls = tls_handle; -} - -void common_hal_socketpool_socket_settimeout(socketpool_socket_obj_t* self, mp_uint_t timeout_ms) { - self->timeout_ms = timeout_ms; -} - -bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t* self, - const char* host, size_t hostlen, uint8_t port) { - _lazy_init_LWIP(self); - - struct sockaddr_in bind_addr; - bind_addr.sin_addr.s_addr = inet_addr(host); - bind_addr.sin_family = AF_INET; - bind_addr.sin_port = htons(port); - - int opt = 1; - int err = lwip_setsockopt(self->num, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); - if (err != 0) { - mp_raise_RuntimeError(translate("Issue setting SO_REUSEADDR")); - } - int result = lwip_bind(self->num, (struct sockaddr *)&bind_addr, sizeof(bind_addr)) == 0; - return result; -} - -bool common_hal_socketpool_socket_listen(socketpool_socket_obj_t* self, int backlog) { - return lwip_listen(self->num, backlog) == 0; -} - socketpool_socket_obj_t* common_hal_socketpool_socket_accept(socketpool_socket_obj_t* self, uint8_t* ip, uint *port) { struct sockaddr_in accept_addr; @@ -125,22 +71,16 @@ socketpool_socket_obj_t* common_hal_socketpool_socket_accept(socketpool_socket_o bool timed_out = false; uint64_t start_ticks = supervisor_ticks_ms64(); - if (self->timeout_ms != (uint)-1) { - mp_printf(&mp_plat_print, "will timeout"); - } else { - mp_printf(&mp_plat_print, "won't timeout"); - } - // Allow timeouts and interrupts while (newsoc == -1 && !timed_out && !mp_hal_is_interrupted()) { - if (self->timeout_ms != (uint)-1) { + if (self->timeout_ms != (uint)-1 && self->timeout_ms != 0) { timed_out = supervisor_ticks_ms64() - start_ticks >= self->timeout_ms; } RUN_BACKGROUND_TASKS; newsoc = lwip_accept(self->num, (struct sockaddr *)&accept_addr, &socklen); - // In non-blocking mode, fail instead of looping + // In non-blocking mode, fail instead of timing out if (newsoc == -1 && self->timeout_ms == 0) { mp_raise_OSError(MP_EAGAIN); } @@ -159,8 +99,6 @@ socketpool_socket_obj_t* common_hal_socketpool_socket_accept(socketpool_socket_o socketpool_socket_obj_t *sock = m_new_obj_with_finaliser(socketpool_socket_obj_t); sock->base.type = &socketpool_socket_type; sock->num = newsoc; - sock->tls = NULL; - sock->ssl_context = NULL; sock->pool = self->pool; if (!register_open_socket(sock)) { @@ -175,183 +113,96 @@ socketpool_socket_obj_t* common_hal_socketpool_socket_accept(socketpool_socket_o } } +bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t* self, + const char* host, size_t hostlen, uint8_t port) { + struct sockaddr_in bind_addr; + bind_addr.sin_addr.s_addr = inet_addr(host); + bind_addr.sin_family = AF_INET; + bind_addr.sin_port = htons(port); + + int opt = 1; + int err = lwip_setsockopt(self->num, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); + if (err != 0) { + mp_raise_RuntimeError(translate("Cannot set socket options")); + } + int result = lwip_bind(self->num, (struct sockaddr *)&bind_addr, sizeof(bind_addr)) == 0; + return result; +} + +void common_hal_socketpool_socket_close(socketpool_socket_obj_t* self) { + self->connected = false; + if (self->num >= 0) { + lwip_shutdown(self->num, 0); + lwip_close(self->num); + self->num = -1; + } +} + bool common_hal_socketpool_socket_connect(socketpool_socket_obj_t* self, const char* host, mp_uint_t hostlen, mp_int_t port) { - // For simplicity we use esp_tls for all TCP connections. If it's not SSL, ssl_context will be - // NULL and should still work. This makes regular TCP connections more memory expensive but TLS - // should become more and more common. Therefore, we optimize for the TLS case. - - // Todo: move to SSL Wrapper and add lwip_connect() - _lazy_init_TLS(self); - - esp_tls_cfg_t* tls_config = NULL; - if (self->ssl_context != NULL) { - tls_config = &self->ssl_context->ssl_config; - } - int result = esp_tls_conn_new_sync(host, hostlen, port, tls_config, self->tls); - self->connected = result >= 0; - if (result < 0) { - int esp_tls_code; - int flags; - esp_err_t err = esp_tls_get_and_clear_last_error(self->tls->error_handle, &esp_tls_code, &flags); - - if (err == ESP_ERR_MBEDTLS_SSL_SETUP_FAILED) { - mp_raise_espidf_MemoryError(); - } else if (ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED) { - mp_raise_OSError_msg_varg(translate("Failed SSL handshake")); - } else { - mp_raise_OSError_msg_varg(translate("Unhandled ESP TLS error %d %d %x %d"), esp_tls_code, flags, err, result); - } - } else { - // Connection successful, set the timeout on the underlying socket. We can't rely on the IDF - // to do it because the config structure is only used for TLS connections. Generally, we - // shouldn't hit this timeout because we try to only read available data. However, there is - // always a chance that we try to read something that is used internally. - int fd; - esp_tls_get_conn_sockfd(self->tls, &fd); - struct timeval tv; - tv.tv_sec = 2 * 60; // Two minutes - tv.tv_usec = 0; - setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); - setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)); - } - - return self->connected; -} - -bool common_hal_socketpool_socket_get_connected(socketpool_socket_obj_t* self) { - return self->connected; -} - -mp_uint_t common_hal_socketpool_socket_send(socketpool_socket_obj_t* self, const uint8_t* buf, mp_uint_t len) { - int sent = -1; - if (self->num != -1) { - // LWIP Socket - // TODO: deal with potential failure/add timeout? - sent = lwip_send(self->num, buf, len, 0); - } else if (self->tls != NULL) { - // TLS Socket - sent = esp_tls_conn_write(self->tls, buf, len); - } - - if (sent < 0) { - mp_raise_OSError(MP_ENOTCONN); - } - return sent; -} - -mp_uint_t common_hal_socketpool_socket_recv_into(socketpool_socket_obj_t* self, const uint8_t* buf, mp_uint_t len) { - int received = 0; - bool timed_out = false; - - if (self->num != -1) { - // LWIP Socket - uint64_t start_ticks = supervisor_ticks_ms64(); - received = -1; - while (received == -1 && - !timed_out && - !mp_hal_is_interrupted()) { - if (self->timeout_ms != (uint)-1) { - timed_out = supervisor_ticks_ms64() - start_ticks >= self->timeout_ms; - } - RUN_BACKGROUND_TASKS; - received = lwip_recv(self->num, (void*) buf, len - 1, 0); - - // In non-blocking mode, fail instead of looping - if (received == -1 && self->timeout_ms == 0) { - mp_raise_OSError(MP_EAGAIN); - } - } - } else if (self->tls != NULL) { - // TLS Socket - int status = 0; - uint64_t start_ticks = supervisor_ticks_ms64(); - int sockfd; - esp_err_t err = esp_tls_get_conn_sockfd(self->tls, &sockfd); - if (err != ESP_OK) { - mp_raise_OSError(MP_EBADF); - } - while (received == 0 && - status >= 0 && - !timed_out && - !mp_hal_is_interrupted()) { - if (self->timeout_ms != (uint)-1) { - timed_out = self->timeout_ms == 0 || supervisor_ticks_ms64() - start_ticks >= self->timeout_ms; - } - RUN_BACKGROUND_TASKS; - size_t available = esp_tls_get_bytes_avail(self->tls); - if (available == 0) { - // This reads the raw socket buffer and is used for non-TLS connections - // and between encrypted TLS blocks. - status = lwip_ioctl(sockfd, FIONREAD, &available); - } - size_t remaining = len - received; - if (available > remaining) { - available = remaining; - } - if (available > 0) { - status = esp_tls_conn_read(self->tls, (void*) buf + received, available); - if (status == 0) { - // Reading zero when something is available indicates a closed - // connection. (The available bytes could have been TLS internal.) - break; - } - if (status > 0) { - received += status; - } - } - } - } else { - // Socket does not have a valid descriptor of either type - mp_raise_OSError(MP_EBADF); - } - - if (timed_out) { - mp_raise_OSError(ETIMEDOUT); - } - return received; -} - -mp_uint_t common_hal_socketpool_socket_sendto(socketpool_socket_obj_t* self, - const char* host, size_t hostlen, uint8_t port, const uint8_t* buf, mp_uint_t len) { - - _lazy_init_LWIP(self); - - // Get the IP address string const struct addrinfo hints = { .ai_family = AF_INET, .ai_socktype = SOCK_STREAM, }; - struct addrinfo *result; - int error = lwip_getaddrinfo(host, NULL, &hints, &result); - if (error != 0 || result == NULL) { - return 0; + struct addrinfo *result_i; + int error = lwip_getaddrinfo(host, NULL, &hints, &result_i); + if (error != 0 || result_i == NULL) { + mp_raise_OSError(EHOSTUNREACH); } // Set parameters struct sockaddr_in dest_addr; #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wcast-align" - dest_addr.sin_addr.s_addr = ((struct sockaddr_in *)result->ai_addr)->sin_addr.s_addr; + dest_addr.sin_addr.s_addr = ((struct sockaddr_in *)result_i->ai_addr)->sin_addr.s_addr; #pragma GCC diagnostic pop - freeaddrinfo(result); + freeaddrinfo(result_i); dest_addr.sin_family = AF_INET; dest_addr.sin_port = htons(port); - int bytes_sent = lwip_sendto(self->num, buf, len, 0, (struct sockaddr *)&dest_addr, sizeof(dest_addr)); - if (bytes_sent < 0) { - mp_raise_BrokenPipeError(); - return 0; + // Replace above with function call ----- + + // Switch to blocking mode for this one call + int opts; + opts = lwip_fcntl(self->num,F_GETFL,0); + opts = opts & (~O_NONBLOCK); + lwip_fcntl(self->num, F_SETFL, opts); + + int result = -1; + result = lwip_connect(self->num, (struct sockaddr *)&dest_addr, sizeof(struct sockaddr_in)); + + // Switch back once complete + opts = opts | O_NONBLOCK; + lwip_fcntl(self->num, F_SETFL, opts); + + if (result) { + self->connected = true; + return true; + } else { + mp_raise_OSError(errno); } - return bytes_sent; +} + +bool common_hal_socketpool_socket_get_closed(socketpool_socket_obj_t* self) { + return self->num < 0; +} + +bool common_hal_socketpool_socket_get_connected(socketpool_socket_obj_t* self) { + return self->connected; +} + +mp_uint_t common_hal_socketpool_socket_get_hash(socketpool_socket_obj_t* self) { + return self->num; +} + +bool common_hal_socketpool_socket_listen(socketpool_socket_obj_t* self, int backlog) { + return lwip_listen(self->num, backlog) == 0; } mp_uint_t common_hal_socketpool_socket_recvfrom_into(socketpool_socket_obj_t* self, uint8_t* buf, mp_uint_t len, uint8_t* ip, uint *port) { - _lazy_init_LWIP(self); - struct sockaddr_in source_addr; socklen_t socklen = sizeof(source_addr); @@ -362,7 +213,7 @@ mp_uint_t common_hal_socketpool_socket_recvfrom_into(socketpool_socket_obj_t* se while (received == -1 && !timed_out && !mp_hal_is_interrupted()) { - if (self->timeout_ms != (uint)-1) { + if (self->timeout_ms != (uint)-1 && self->timeout_ms != 0) { timed_out = supervisor_ticks_ms64() - start_ticks >= self->timeout_ms; } RUN_BACKGROUND_TASKS; @@ -389,24 +240,87 @@ mp_uint_t common_hal_socketpool_socket_recvfrom_into(socketpool_socket_obj_t* se return received; } -void common_hal_socketpool_socket_close(socketpool_socket_obj_t* self) { - self->connected = false; - if (self->tls != NULL) { - esp_tls_conn_destroy(self->tls); - self->tls = NULL; +mp_uint_t common_hal_socketpool_socket_recv_into(socketpool_socket_obj_t* self, const uint8_t* buf, mp_uint_t len) { + int received = 0; + bool timed_out = false; + + if (self->num != -1) { + // LWIP Socket + uint64_t start_ticks = supervisor_ticks_ms64(); + received = -1; + while (received == -1 && + !timed_out && + !mp_hal_is_interrupted()) { + if (self->timeout_ms != (uint)-1 && self->timeout_ms != 0) { + timed_out = supervisor_ticks_ms64() - start_ticks >= self->timeout_ms; + } + RUN_BACKGROUND_TASKS; + received = lwip_recv(self->num, (void*) buf, len - 1, 0); + + // In non-blocking mode, fail instead of looping + if (received == -1 && self->timeout_ms == 0) { + mp_raise_OSError(MP_EAGAIN); + } + } + } else { + mp_raise_OSError(MP_EBADF); } - if (self->num >= 0) { - lwip_shutdown(self->num, 0); - lwip_close(self->num); - self->num = -1; + + if (timed_out) { + mp_raise_OSError(ETIMEDOUT); } + return received; } -bool common_hal_socketpool_socket_get_closed(socketpool_socket_obj_t* self) { - return self->tls == NULL && self->num < 0; +mp_uint_t common_hal_socketpool_socket_send(socketpool_socket_obj_t* self, const uint8_t* buf, mp_uint_t len) { + int sent = -1; + if (self->num != -1) { + // LWIP Socket + // TODO: deal with potential failure/add timeout? + sent = lwip_send(self->num, buf, len, 0); + } else { + mp_raise_OSError(MP_EBADF); + } + + if (sent < 0) { + mp_raise_OSError(MP_ENOTCONN); + } + return sent; } +mp_uint_t common_hal_socketpool_socket_sendto(socketpool_socket_obj_t* self, + const char* host, size_t hostlen, uint8_t port, const uint8_t* buf, mp_uint_t len) { -mp_uint_t common_hal_socketpool_socket_get_hash(socketpool_socket_obj_t* self) { - return self->num; + // Set parameters + const struct addrinfo hints = { + .ai_family = AF_INET, + .ai_socktype = SOCK_STREAM, + }; + struct addrinfo *result_i; + int error = lwip_getaddrinfo(host, NULL, &hints, &result_i); + if (error != 0 || result_i == NULL) { + mp_raise_OSError(EHOSTUNREACH); + } + + // Set parameters + struct sockaddr_in dest_addr; + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wcast-align" + dest_addr.sin_addr.s_addr = ((struct sockaddr_in *)result_i->ai_addr)->sin_addr.s_addr; + #pragma GCC diagnostic pop + freeaddrinfo(result_i); + + dest_addr.sin_family = AF_INET; + dest_addr.sin_port = htons(port); + + int bytes_sent = lwip_sendto(self->num, buf, len, 0, (struct sockaddr *)&dest_addr, sizeof(dest_addr)); + if (bytes_sent < 0) { + mp_raise_BrokenPipeError(); + return 0; + } + return bytes_sent; +} + +void common_hal_socketpool_socket_settimeout(socketpool_socket_obj_t* self, mp_uint_t timeout_ms) { + self->timeout_ms = timeout_ms; } diff --git a/ports/esp32s2/common-hal/socketpool/Socket.h b/ports/esp32s2/common-hal/socketpool/Socket.h index 4e6cfa5ef6..b86f5597c4 100644 --- a/ports/esp32s2/common-hal/socketpool/Socket.h +++ b/ports/esp32s2/common-hal/socketpool/Socket.h @@ -41,8 +41,6 @@ typedef struct { int family; int ipproto; bool connected; - esp_tls_t* tls; - ssl_sslcontext_obj_t* ssl_context; socketpool_socketpool_obj_t* pool; mp_uint_t timeout_ms; } socketpool_socket_obj_t; diff --git a/ports/esp32s2/common-hal/socketpool/SocketPool.c b/ports/esp32s2/common-hal/socketpool/SocketPool.c index 5821728ce5..fbd6dca7af 100644 --- a/ports/esp32s2/common-hal/socketpool/SocketPool.c +++ b/ports/esp32s2/common-hal/socketpool/SocketPool.c @@ -25,6 +25,7 @@ */ #include "shared-bindings/socketpool/SocketPool.h" +#include "common-hal/socketpool/Socket.h" #include "py/runtime.h" #include "shared-bindings/wifi/__init__.h" @@ -65,22 +66,23 @@ socketpool_socket_obj_t* common_hal_socketpool_socket(socketpool_socketpool_obj_ mp_raise_NotImplementedError(translate("Only IPv4 sockets supported")); } - // Consider LWIP and MbedTLS "variant" sockets to be incompatible (for now) - // The variant of the socket is determined by whether the socket is wrapped - // by SSL. If no TLS handle is set in sslcontext_wrap_socket, the first call - // of bind() or connect() will create a LWIP socket with a corresponding - // socketnum. - // TODO: move MbedTLS to its own duplicate Socket or Server API, maybe? socketpool_socket_obj_t *sock = m_new_obj_with_finaliser(socketpool_socket_obj_t); sock->base.type = &socketpool_socket_type; - sock->num = -1; sock->type = socket_type; sock->family = addr_family; sock->ipproto = ipproto; - - sock->tls = NULL; - sock->ssl_context = NULL; sock->pool = self; + sock->timeout_ms = (uint)-1; + + // Create LWIP socket + int socknum = -1; + socknum = lwip_socket(sock->family, sock->type, sock->ipproto); + if (socknum < 0 || !register_open_socket(sock)) { + mp_raise_RuntimeError(translate("Out of sockets")); + } + sock->num = socknum; + // Sockets should be nonblocking in most cases + lwip_fcntl(socknum, F_SETFL, O_NONBLOCK); return sock; } diff --git a/ports/esp32s2/common-hal/ssl/SSLContext.c b/ports/esp32s2/common-hal/ssl/SSLContext.c index e24fd338b6..c0179399de 100644 --- a/ports/esp32s2/common-hal/ssl/SSLContext.c +++ b/ports/esp32s2/common-hal/ssl/SSLContext.c @@ -25,6 +25,9 @@ */ #include "shared-bindings/ssl/SSLContext.h" +#include "shared-bindings/ssl/SSLSocket.h" + +#include "bindings/espidf/__init__.h" #include "py/runtime.h" @@ -32,10 +35,25 @@ void common_hal_ssl_sslcontext_construct(ssl_sslcontext_obj_t* self) { } -socketpool_socket_obj_t* common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t* self, +ssl_sslsocket_obj_t* common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t* self, socketpool_socket_obj_t* socket, bool server_side, const char* server_hostname) { - socket->ssl_context = self; + ssl_sslsocket_obj_t *sock = m_new_obj_with_finaliser(ssl_sslsocket_obj_t); + sock->base.type = &ssl_sslsocket_type; + sock->ssl_context = self; + sock->sock = socket; + + if (socket->type != SOCK_STREAM || socket->num != -1) { + mp_raise_RuntimeError(translate("Invalid socket for TLS")); + } + esp_tls_t* tls_handle = esp_tls_init(); + if (tls_handle == NULL) { + mp_raise_espidf_MemoryError(); + } + sock->tls = tls_handle; + + // TODO: do something with the original socket? Don't call a close on the internal LWIP. + // Should we store server hostname on the socket in case connect is called with an ip? - return socket; + return sock; } diff --git a/ports/esp32s2/common-hal/ssl/SSLSocket.c b/ports/esp32s2/common-hal/ssl/SSLSocket.c new file mode 100644 index 0000000000..d8e48c3d59 --- /dev/null +++ b/ports/esp32s2/common-hal/ssl/SSLSocket.c @@ -0,0 +1,169 @@ +/* + * This file is part of the MicroPython project, http://micropython.org/ + * + * The MIT License (MIT) + * + * Copyright (c) 2021 Lucian Copeland for Adafruit Industries + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "shared-bindings/ssl/SSLSocket.h" +#include "shared-bindings/socketpool/Socket.h" +#include "shared-bindings/ssl/SSLContext.h" + +#include "bindings/espidf/__init__.h" +#include "lib/utils/interrupt_char.h" +#include "py/mperrno.h" +#include "py/runtime.h" +#include "supervisor/shared/tick.h" + +void common_hal_ssl_sslsocket_settimeout(ssl_sslsocket_obj_t* self, mp_uint_t timeout_ms) { + self->sock->timeout_ms = timeout_ms; +} + +ssl_sslsocket_obj_t* common_hal_ssl_sslsocket_accept(ssl_sslsocket_obj_t* self, + uint8_t* ip, uint *port) { + socketpool_socket_obj_t * sock = common_hal_socketpool_socket_accept(self->sock, ip, port); + ssl_sslsocket_obj_t * sslsock = common_hal_ssl_sslcontext_wrap_socket(self->ssl_context, sock, false, NULL); + return sslsock; +} + +bool common_hal_ssl_sslsocket_bind(ssl_sslsocket_obj_t* self, + const char* host, size_t hostlen, uint8_t port) { + return common_hal_socketpool_socket_bind(self->sock, host, hostlen, port); +} + +void common_hal_ssl_sslsocket_close(ssl_sslsocket_obj_t* self) { + self->sock->connected = false; + esp_tls_conn_destroy(self->tls); + self->tls = NULL; +} + +bool common_hal_ssl_sslsocket_connect(ssl_sslsocket_obj_t* self, + const char* host, mp_uint_t hostlen, mp_int_t port) { + esp_tls_cfg_t* tls_config = NULL; + tls_config = &self->ssl_context->ssl_config; + int result = esp_tls_conn_new_sync(host, hostlen, port, tls_config, self->tls); + self->sock->connected = result >= 0; + if (result < 0) { + int esp_tls_code; + int flags; + esp_err_t err = esp_tls_get_and_clear_last_error(self->tls->error_handle, &esp_tls_code, &flags); + + if (err == ESP_ERR_MBEDTLS_SSL_SETUP_FAILED) { + mp_raise_espidf_MemoryError(); + } else if (ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED) { + mp_raise_OSError_msg_varg(translate("Failed SSL handshake")); + } else { + mp_raise_OSError_msg_varg(translate("Unhandled ESP TLS error %d %d %x %d"), esp_tls_code, flags, err, result); + } + } else { + // Connection successful, set the timeout on the underlying socket. We can't rely on the IDF + // to do it because the config structure is only used for TLS connections. Generally, we + // shouldn't hit this timeout because we try to only read available data. However, there is + // always a chance that we try to read something that is used internally. + int fd; + esp_tls_get_conn_sockfd(self->tls, &fd); + struct timeval tv; + tv.tv_sec = 2 * 60; // Two minutes + tv.tv_usec = 0; + setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); + setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)); + } + + return self->sock->connected; +} + +bool common_hal_ssl_sslsocket_get_closed(ssl_sslsocket_obj_t* self) { + return self->tls == NULL && self->sock->num < 0; +} + +bool common_hal_ssl_sslsocket_get_connected(ssl_sslsocket_obj_t* self) { + return self->sock->connected; +} + +mp_uint_t common_hal_ssl_sslsocket_get_hash(ssl_sslsocket_obj_t* self) { + return self->sock->num; +} + +bool common_hal_ssl_sslsocket_listen(ssl_sslsocket_obj_t* self, int backlog) { + return common_hal_socketpool_socket_listen(self->sock, backlog); +} + +mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t* self, const uint8_t* buf, mp_uint_t len) { + int received = 0; + bool timed_out = false; + int status = 0; + uint64_t start_ticks = supervisor_ticks_ms64(); + int sockfd; + esp_err_t err = esp_tls_get_conn_sockfd(self->tls, &sockfd); + if (err != ESP_OK) { + mp_raise_OSError(MP_EBADF); + } + while (received == 0 && + status >= 0 && + !timed_out && + !mp_hal_is_interrupted()) { + if (self->sock->timeout_ms != (uint)-1 && self->sock->timeout_ms != 0) { + timed_out = self->sock->timeout_ms == 0 || supervisor_ticks_ms64() - start_ticks >= self->sock->timeout_ms; + } + RUN_BACKGROUND_TASKS; + size_t available = esp_tls_get_bytes_avail(self->tls); + if (available == 0) { + // This reads the raw socket buffer and is used for non-TLS connections + // and between encrypted TLS blocks. + status = lwip_ioctl(sockfd, FIONREAD, &available); + } + size_t remaining = len - received; + if (available > remaining) { + available = remaining; + } + if (available > 0) { + status = esp_tls_conn_read(self->tls, (void*) buf + received, available); + if (status == 0) { + // Reading zero when something is available indicates a closed + // connection. (The available bytes could have been TLS internal.) + break; + } + if (status > 0) { + received += status; + } + } + // In non-blocking mode, fail instead of timing out + if (received==0 && self->sock->timeout_ms == 0) { + mp_raise_OSError(MP_EAGAIN); + } + } + + if (timed_out) { + mp_raise_OSError(ETIMEDOUT); + } + return received; +} + +mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t* self, const uint8_t* buf, mp_uint_t len) { + int sent = -1; + sent = esp_tls_conn_write(self->tls, buf, len); + + if (sent < 0) { + mp_raise_OSError(MP_ENOTCONN); + } + return sent; +} diff --git a/ports/esp32s2/common-hal/ssl/SSLSocket.h b/ports/esp32s2/common-hal/ssl/SSLSocket.h new file mode 100644 index 0000000000..e9e5bff062 --- /dev/null +++ b/ports/esp32s2/common-hal/ssl/SSLSocket.h @@ -0,0 +1,44 @@ +/* + * This file is part of the MicroPython project, http://micropython.org/ + * + * The MIT License (MIT) + * + * Copyright (c) 2021 Lucian Copeland for Adafruit Industries + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#ifndef MICROPY_INCLUDED_ESP32S2_COMMON_HAL_SSL_SSLSOCKET_H +#define MICROPY_INCLUDED_ESP32S2_COMMON_HAL_SSL_SSLSOCKET_H + +#include "py/obj.h" + +#include "common-hal/ssl/SSLContext.h" +#include "common-hal/socketpool/Socket.h" + +#include "components/esp-tls/esp_tls.h" + +typedef struct { + mp_obj_base_t base; + socketpool_socket_obj_t * sock; + esp_tls_t* tls; + ssl_sslcontext_obj_t* ssl_context; +} ssl_sslsocket_obj_t; + +#endif // MICROPY_INCLUDED_ESP32S2_COMMON_HAL_SSL_SSLSOCKET_H diff --git a/py/circuitpy_defns.mk b/py/circuitpy_defns.mk index 3ce7c01173..b208178f37 100644 --- a/py/circuitpy_defns.mk +++ b/py/circuitpy_defns.mk @@ -388,6 +388,7 @@ SRC_COMMON_HAL_ALL = \ socketpool/Socket.c \ ssl/__init__.c \ ssl/SSLContext.c \ + ssl/SSLSocket.c \ supervisor/Runtime.c \ supervisor/__init__.c \ watchdog/WatchDogMode.c \ diff --git a/shared-bindings/socketpool/Socket.c b/shared-bindings/socketpool/Socket.c index 0074173405..f169d6acac 100644 --- a/shared-bindings/socketpool/Socket.c +++ b/shared-bindings/socketpool/Socket.c @@ -3,8 +3,8 @@ * * The MIT License (MIT) * - * SPDX-FileCopyrightText: Copyright (c) 2014 Damien P. George - * 2018 Nick Moore for Adafruit Industries + * Copyright (c) 2020 Scott Shawcroft for Adafruit Industries + * Copyright (c) 2021 Lucian Copeland for Adafruit Industries * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -64,6 +64,25 @@ STATIC mp_obj_t socketpool_socket___exit__(size_t n_args, const mp_obj_t *args) } STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(socketpool_socket___exit___obj, 4, 4, socketpool_socket___exit__); +//| def accept(self) -> Tuple[Socket, Tuple[str, int]]: +//| """Accept a connection on a listening socket of type SOCK_STREAM, +//| creating a new socket of type SOCK_STREAM. +//| Returns a tuple of (new_socket, remote_address)""" +//| +STATIC mp_obj_t socketpool_socket_accept(mp_obj_t self_in) { + socketpool_socket_obj_t *self = MP_OBJ_TO_PTR(self_in); + uint8_t ip[4]; + uint port; + + socketpool_socket_obj_t * sock = common_hal_socketpool_socket_accept(self, ip, &port); + + mp_obj_t tuple_contents[2]; + tuple_contents[0] = MP_OBJ_FROM_PTR(sock); + tuple_contents[1] = netutils_format_inet_addr(ip, port, NETUTILS_BIG); + return mp_obj_new_tuple(2, tuple_contents); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_1(socketpool_socket_accept_obj, socketpool_socket_accept); + //| def bind(self, address: Tuple[str, int]) -> None: //| """Bind a socket to an address //| @@ -89,41 +108,6 @@ STATIC mp_obj_t socketpool_socket_bind(mp_obj_t self_in, mp_obj_t addr_in) { } STATIC MP_DEFINE_CONST_FUN_OBJ_2(socketpool_socket_bind_obj, socketpool_socket_bind); -//| def listen(self, backlog: int) -> None: -//| """Set socket to listen for incoming connections -//| -//| :param ~int backlog: length of backlog queue for waiting connetions""" -//| ... -//| -STATIC mp_obj_t socketpool_socket_listen(mp_obj_t self_in, mp_obj_t backlog_in) { - socketpool_socket_obj_t *self = MP_OBJ_TO_PTR(self_in); - - int backlog = mp_obj_get_int(backlog_in); - - common_hal_socketpool_socket_listen(self, backlog); - return mp_const_none; -} -STATIC MP_DEFINE_CONST_FUN_OBJ_2(socketpool_socket_listen_obj, socketpool_socket_listen); - -//| def accept(self) -> Tuple[Socket, Tuple[str, int]]: -//| """Accept a connection on a listening socket of type SOCK_STREAM, -//| creating a new socket of type SOCK_STREAM. -//| Returns a tuple of (new_socket, remote_address)""" -//| -STATIC mp_obj_t socketpool_socket_accept(mp_obj_t self_in) { - socketpool_socket_obj_t *self = MP_OBJ_TO_PTR(self_in); - uint8_t ip[4]; - uint port; - - socketpool_socket_obj_t * sock = common_hal_socketpool_socket_accept(self, ip, &port); - - mp_obj_t tuple_contents[2]; - tuple_contents[0] = MP_OBJ_FROM_PTR(sock); - tuple_contents[1] = netutils_format_inet_addr(ip, port, NETUTILS_BIG); - return mp_obj_new_tuple(2, tuple_contents); -} -STATIC MP_DEFINE_CONST_FUN_OBJ_1(socketpool_socket_accept_obj, socketpool_socket_accept); - //| def close(self) -> None: //| """Closes this Socket and makes its resources available to its SocketPool.""" //| @@ -159,31 +143,47 @@ STATIC mp_obj_t socketpool_socket_connect(mp_obj_t self_in, mp_obj_t addr_in) { } STATIC MP_DEFINE_CONST_FUN_OBJ_2(socketpool_socket_connect_obj, socketpool_socket_connect); -//| def send(self, bytes: ReadableBuffer) -> int: -//| """Send some bytes to the connected remote address. -//| Suits sockets of type SOCK_STREAM +//| def listen(self, backlog: int) -> None: +//| """Set socket to listen for incoming connections //| -//| :param ~bytes bytes: some bytes to send""" +//| :param ~int backlog: length of backlog queue for waiting connetions""" //| ... //| -STATIC mp_obj_t socketpool_socket_send(mp_obj_t self_in, mp_obj_t buf_in) { +STATIC mp_obj_t socketpool_socket_listen(mp_obj_t self_in, mp_obj_t backlog_in) { socketpool_socket_obj_t *self = MP_OBJ_TO_PTR(self_in); - if (common_hal_socketpool_socket_get_closed(self)) { - // Bad file number. - mp_raise_OSError(MP_EBADF); - } - if (!common_hal_socketpool_socket_get_connected(self)) { - mp_raise_BrokenPipeError(); - } - mp_buffer_info_t bufinfo; - mp_get_buffer_raise(buf_in, &bufinfo, MP_BUFFER_READ); - mp_int_t ret = common_hal_socketpool_socket_send(self, bufinfo.buf, bufinfo.len); - if (ret == -1) { - mp_raise_BrokenPipeError(); - } - return mp_obj_new_int_from_uint(ret); + + int backlog = mp_obj_get_int(backlog_in); + + common_hal_socketpool_socket_listen(self, backlog); + return mp_const_none; } -STATIC MP_DEFINE_CONST_FUN_OBJ_2(socketpool_socket_send_obj, socketpool_socket_send); +STATIC MP_DEFINE_CONST_FUN_OBJ_2(socketpool_socket_listen_obj, socketpool_socket_listen); + +//| def recvfrom_into(self, buffer: WriteableBuffer) -> Tuple[int, Tuple[str, int]]: +//| """Reads some bytes from a remote address. +//| +//| Returns a tuple containing +//| * the number of bytes received into the given buffer +//| * a remote_address, which is a tuple of ip address and port number +//| +//| :param object buffer: buffer to read into""" +//| ... +//| +STATIC mp_obj_t socketpool_socket_recvfrom_into(mp_obj_t self_in, mp_obj_t data_in) { + socketpool_socket_obj_t *self = MP_OBJ_TO_PTR(self_in); + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(data_in, &bufinfo, MP_BUFFER_WRITE); + + byte ip[4]; + mp_uint_t port; + mp_int_t ret = common_hal_socketpool_socket_recvfrom_into(self, + (byte*)bufinfo.buf, bufinfo.len, ip, &port); + mp_obj_t tuple_contents[2]; + tuple_contents[0] = mp_obj_new_int_from_uint(ret); + tuple_contents[1] = netutils_format_inet_addr(ip, port, NETUTILS_BIG); + return mp_obj_new_tuple(2, tuple_contents); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(socketpool_socket_recvfrom_into_obj, socketpool_socket_recvfrom_into); //| def recv_into(self, buffer: WriteableBuffer, bufsize: int) -> int: //| """Reads some bytes from the connected remote address, writing @@ -199,7 +199,6 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(socketpool_socket_send_obj, socketpool_socket_s //| :param int bufsize: optionally, a maximum number of bytes to read.""" //| ... //| - STATIC mp_obj_t socketpool_socket_recv_into(size_t n_args, const mp_obj_t *args) { socketpool_socket_obj_t *self = MP_OBJ_TO_PTR(args[0]); if (common_hal_socketpool_socket_get_closed(self)) { @@ -232,6 +231,32 @@ STATIC mp_obj_t socketpool_socket_recv_into(size_t n_args, const mp_obj_t *args) } STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(socketpool_socket_recv_into_obj, 2, 3, socketpool_socket_recv_into); +//| def send(self, bytes: ReadableBuffer) -> int: +//| """Send some bytes to the connected remote address. +//| Suits sockets of type SOCK_STREAM +//| +//| :param ~bytes bytes: some bytes to send""" +//| ... +//| +STATIC mp_obj_t socketpool_socket_send(mp_obj_t self_in, mp_obj_t buf_in) { + socketpool_socket_obj_t *self = MP_OBJ_TO_PTR(self_in); + if (common_hal_socketpool_socket_get_closed(self)) { + // Bad file number. + mp_raise_OSError(MP_EBADF); + } + if (!common_hal_socketpool_socket_get_connected(self)) { + mp_raise_BrokenPipeError(); + } + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(buf_in, &bufinfo, MP_BUFFER_READ); + mp_int_t ret = common_hal_socketpool_socket_send(self, bufinfo.buf, bufinfo.len); + if (ret == -1) { + mp_raise_BrokenPipeError(); + } + return mp_obj_new_int_from_uint(ret); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(socketpool_socket_send_obj, socketpool_socket_send); + //| def sendto(self, bytes: ReadableBuffer, address: Tuple[str, int]) -> int: //| """Send some bytes to a specific address. //| Suits sockets of type SOCK_DGRAM @@ -240,7 +265,6 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(socketpool_socket_recv_into_obj, 2, 3 //| :param ~tuple address: tuple of (remote_address, remote_port)""" //| ... //| - STATIC mp_obj_t socketpool_socket_sendto(mp_obj_t self_in, mp_obj_t data_in, mp_obj_t addr_in) { socketpool_socket_obj_t *self = MP_OBJ_TO_PTR(self_in); @@ -264,37 +288,28 @@ STATIC mp_obj_t socketpool_socket_sendto(mp_obj_t self_in, mp_obj_t data_in, mp_ } STATIC MP_DEFINE_CONST_FUN_OBJ_3(socketpool_socket_sendto_obj, socketpool_socket_sendto); -//| def recvfrom_into(self, buffer: WriteableBuffer) -> Tuple[int, Tuple[str, int]]: -//| """Reads some bytes from a remote address. +//| def setblocking(self, flag: bool) -> Optional[int]: +//| """Set the blocking behaviour of this socket. //| -//| Returns a tuple containing -//| * the number of bytes received into the given buffer -//| * a remote_address, which is a tuple of ip address and port number -//| -//| :param object buffer: buffer to read into""" +//| :param ~bool flag: False means non-blocking, True means block indefinitely.""" //| ... //| -STATIC mp_obj_t socketpool_socket_recvfrom_into(mp_obj_t self_in, mp_obj_t data_in) { +// method socket.setblocking(flag) +STATIC mp_obj_t socketpool_socket_setblocking(mp_obj_t self_in, mp_obj_t blocking) { socketpool_socket_obj_t *self = MP_OBJ_TO_PTR(self_in); - mp_buffer_info_t bufinfo; - mp_get_buffer_raise(data_in, &bufinfo, MP_BUFFER_WRITE); - - byte ip[4]; - mp_uint_t port; - mp_int_t ret = common_hal_socketpool_socket_recvfrom_into(self, - (byte*)bufinfo.buf, bufinfo.len, ip, &port); - mp_obj_t tuple_contents[2]; - tuple_contents[0] = mp_obj_new_int_from_uint(ret); - tuple_contents[1] = netutils_format_inet_addr(ip, port, NETUTILS_BIG); - return mp_obj_new_tuple(2, tuple_contents); + if (mp_obj_is_true(blocking)) { + common_hal_socketpool_socket_settimeout(self, -1); + } else { + common_hal_socketpool_socket_settimeout(self, 0); + } + return mp_const_none; } -STATIC MP_DEFINE_CONST_FUN_OBJ_2(socketpool_socket_recvfrom_into_obj, socketpool_socket_recvfrom_into); +STATIC MP_DEFINE_CONST_FUN_OBJ_2(socketpool_socket_setblocking_obj, socketpool_socket_setblocking); // //| def setsockopt(self, level: int, optname: int, value: int) -> None: // //| """Sets socket options""" // //| ... // //| - // STATIC mp_obj_t socketpool_socket_setsockopt(size_t n_args, const mp_obj_t *args) { // // mod_network_socket_obj_t *self = MP_OBJ_TO_PTR(args[0]); @@ -324,13 +339,13 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(socketpool_socket_recvfrom_into_obj, socketpool // } // STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(socketpool_socket_setsockopt_obj, 4, 4, socketpool_socket_setsockopt); + //| def settimeout(self, value: int) -> None: //| """Set the timeout value for this socket. //| //| :param ~int value: timeout in seconds. 0 means non-blocking. None means block indefinitely.""" //| ... //| - STATIC mp_obj_t socketpool_socket_settimeout(mp_obj_t self_in, mp_obj_t timeout_in) { socketpool_socket_obj_t *self = MP_OBJ_TO_PTR(self_in); mp_uint_t timeout_ms; @@ -348,24 +363,6 @@ STATIC mp_obj_t socketpool_socket_settimeout(mp_obj_t self_in, mp_obj_t timeout_ } STATIC MP_DEFINE_CONST_FUN_OBJ_2(socketpool_socket_settimeout_obj, socketpool_socket_settimeout); -// //| def setblocking(self, flag: bool) -> Optional[int]: -// //| """Set the blocking behaviour of this socket. -// //| -// //| :param ~bool flag: False means non-blocking, True means block indefinitely.""" -// //| ... -// //| - -// // method socket.setblocking(flag) -// STATIC mp_obj_t socketpool_socket_setblocking(mp_obj_t self_in, mp_obj_t blocking) { -// // if (mp_obj_is_true(blocking)) { -// // return socket_settimeout(self_in, mp_const_none); -// // } else { -// // return socket_settimeout(self_in, MP_OBJ_NEW_SMALL_INT(0)); -// // } -// return mp_const_none; -// } -// STATIC MP_DEFINE_CONST_FUN_OBJ_2(socketpool_socket_setblocking_obj, socketpool_socket_setblocking); - //| def __hash__(self) -> int: //| """Returns a hash for the Socket.""" //| ... @@ -384,19 +381,19 @@ STATIC const mp_rom_map_elem_t socketpool_socket_locals_dict_table[] = { { MP_ROM_QSTR(MP_QSTR___enter__), MP_ROM_PTR(&default___enter___obj) }, { MP_ROM_QSTR(MP_QSTR___exit__), MP_ROM_PTR(&socketpool_socket___exit___obj) }, { MP_ROM_QSTR(MP_QSTR___del__), MP_ROM_PTR(&socketpool_socket_close_obj) }, - { MP_ROM_QSTR(MP_QSTR_close), MP_ROM_PTR(&socketpool_socket_close_obj) }, - { MP_ROM_QSTR(MP_QSTR_bind), MP_ROM_PTR(&socketpool_socket_bind_obj) }, - { MP_ROM_QSTR(MP_QSTR_listen), MP_ROM_PTR(&socketpool_socket_listen_obj) }, { MP_ROM_QSTR(MP_QSTR_accept), MP_ROM_PTR(&socketpool_socket_accept_obj) }, + { MP_ROM_QSTR(MP_QSTR_bind), MP_ROM_PTR(&socketpool_socket_bind_obj) }, + { MP_ROM_QSTR(MP_QSTR_close), MP_ROM_PTR(&socketpool_socket_close_obj) }, { MP_ROM_QSTR(MP_QSTR_connect), MP_ROM_PTR(&socketpool_socket_connect_obj) }, - { MP_ROM_QSTR(MP_QSTR_send), MP_ROM_PTR(&socketpool_socket_send_obj) }, - { MP_ROM_QSTR(MP_QSTR_sendto), MP_ROM_PTR(&socketpool_socket_sendto_obj) }, + { MP_ROM_QSTR(MP_QSTR_listen), MP_ROM_PTR(&socketpool_socket_listen_obj) }, { MP_ROM_QSTR(MP_QSTR_recvfrom_into), MP_ROM_PTR(&socketpool_socket_recvfrom_into_obj) }, { MP_ROM_QSTR(MP_QSTR_recv_into), MP_ROM_PTR(&socketpool_socket_recv_into_obj) }, + { MP_ROM_QSTR(MP_QSTR_send), MP_ROM_PTR(&socketpool_socket_send_obj) }, + { MP_ROM_QSTR(MP_QSTR_sendto), MP_ROM_PTR(&socketpool_socket_sendto_obj) }, + { MP_ROM_QSTR(MP_QSTR_setblocking), MP_ROM_PTR(&socketpool_socket_setblocking_obj) }, // { MP_ROM_QSTR(MP_QSTR_setsockopt), MP_ROM_PTR(&socketpool_socket_setsockopt_obj) }, { MP_ROM_QSTR(MP_QSTR_settimeout), MP_ROM_PTR(&socketpool_socket_settimeout_obj) }, - // { MP_ROM_QSTR(MP_QSTR_setblocking), MP_ROM_PTR(&socketpool_socket_setblocking_obj) }, }; STATIC MP_DEFINE_CONST_DICT(socketpool_socket_locals_dict, socketpool_socket_locals_dict_table); diff --git a/shared-bindings/socketpool/Socket.h b/shared-bindings/socketpool/Socket.h index b5dceb50f4..76af6e1e9b 100644 --- a/shared-bindings/socketpool/Socket.h +++ b/shared-bindings/socketpool/Socket.h @@ -31,22 +31,21 @@ extern const mp_obj_type_t socketpool_socket_type; -void common_hal_socketpool_socket_settimeout(socketpool_socket_obj_t* self, mp_uint_t timeout_ms); - -bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t* self, const char* host, size_t hostlen, uint8_t port); -bool common_hal_socketpool_socket_listen(socketpool_socket_obj_t* self, int backlog); socketpool_socket_obj_t * common_hal_socketpool_socket_accept(socketpool_socket_obj_t* self, uint8_t* ip, uint *port); - -bool common_hal_socketpool_socket_connect(socketpool_socket_obj_t* self, const char* host, size_t hostlen, mp_int_t port); -mp_uint_t common_hal_socketpool_socket_send(socketpool_socket_obj_t* self, const uint8_t* buf, mp_uint_t len); -mp_uint_t common_hal_socketpool_socket_recv_into(socketpool_socket_obj_t* self, const uint8_t* buf, mp_uint_t len); -mp_uint_t common_hal_socketpool_socket_sendto(socketpool_socket_obj_t* self, - const char* host, size_t hostlen, uint8_t port, const uint8_t* buf, mp_uint_t len); -mp_uint_t common_hal_socketpool_socket_recvfrom_into(socketpool_socket_obj_t* self, - uint8_t* buf, mp_uint_t len, uint8_t* ip, uint *port); +bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t* self, const char* host, size_t hostlen, uint8_t port); void common_hal_socketpool_socket_close(socketpool_socket_obj_t* self); +bool common_hal_socketpool_socket_connect(socketpool_socket_obj_t* self, const char* host, size_t hostlen, mp_int_t port); bool common_hal_socketpool_socket_get_closed(socketpool_socket_obj_t* self); bool common_hal_socketpool_socket_get_connected(socketpool_socket_obj_t* self); mp_uint_t common_hal_socketpool_socket_get_hash(socketpool_socket_obj_t* self); +mp_uint_t common_hal_socketpool_socket_get_timeout(socketpool_socket_obj_t* self); +bool common_hal_socketpool_socket_listen(socketpool_socket_obj_t* self, int backlog); +mp_uint_t common_hal_socketpool_socket_recvfrom_into(socketpool_socket_obj_t* self, + uint8_t* buf, mp_uint_t len, uint8_t* ip, uint *port); +mp_uint_t common_hal_socketpool_socket_recv_into(socketpool_socket_obj_t* self, const uint8_t* buf, mp_uint_t len); +mp_uint_t common_hal_socketpool_socket_send(socketpool_socket_obj_t* self, const uint8_t* buf, mp_uint_t len); +mp_uint_t common_hal_socketpool_socket_sendto(socketpool_socket_obj_t* self, + const char* host, size_t hostlen, uint8_t port, const uint8_t* buf, mp_uint_t len); +void common_hal_socketpool_socket_settimeout(socketpool_socket_obj_t* self, mp_uint_t timeout_ms); #endif // MICROPY_INCLUDED_SHARED_BINDINGS_SOCKETPOOL_SOCKET_H diff --git a/shared-bindings/socketpool/SocketPool.c b/shared-bindings/socketpool/SocketPool.c index 8f4069faad..6ff6d5f98d 100644 --- a/shared-bindings/socketpool/SocketPool.c +++ b/shared-bindings/socketpool/SocketPool.c @@ -3,8 +3,7 @@ * * The MIT License (MIT) * - * SPDX-FileCopyrightText: Copyright (c) 2014 Damien P. George - * 2018 Nick Moore for Adafruit Industries + * Copyright (c) 2020 Scott Shawcroft for Adafruit Industries * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/shared-bindings/ssl/SSLContext.h b/shared-bindings/ssl/SSLContext.h index f7f985af70..ad4d7e6a8c 100644 --- a/shared-bindings/ssl/SSLContext.h +++ b/shared-bindings/ssl/SSLContext.h @@ -30,12 +30,13 @@ #include "common-hal/ssl/SSLContext.h" #include "shared-bindings/socketpool/Socket.h" +#include "shared-bindings/ssl/SSLSocket.h" extern const mp_obj_type_t ssl_sslcontext_type; void common_hal_ssl_sslcontext_construct(ssl_sslcontext_obj_t* self); -socketpool_socket_obj_t* common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t* self, +ssl_sslsocket_obj_t* common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t* self, socketpool_socket_obj_t* sock, bool server_side, const char* server_hostname); #endif // MICROPY_INCLUDED_SHARED_BINDINGS_SSL_SSLCONTEXT_H diff --git a/shared-bindings/ssl/SSLSocket.c b/shared-bindings/ssl/SSLSocket.c new file mode 100644 index 0000000000..154d3d1d44 --- /dev/null +++ b/shared-bindings/ssl/SSLSocket.c @@ -0,0 +1,320 @@ +/* + * This file is part of the MicroPython project, http://micropython.org/ + * + * The MIT License (MIT) + * + * Copyright (c) 2021 Lucian Copeland for Adafruit Industries + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "shared-bindings/ssl/SSLSocket.h" + +#include +#include + +#include "lib/utils/context_manager_helpers.h" +#include "py/objtuple.h" +#include "py/objlist.h" +#include "py/runtime.h" +#include "py/mperrno.h" + +#include "lib/netutils/netutils.h" + +//| class SSLSocket: +//| """Implements TLS security on a subset of `socketpool.socket` functions. Cannot be created +//| directly. Instead, call `context.wrap_socket` on an existing socket object. +//| +//| Provides a subset of CPython's `ssl.SSLSocket` API. It only implements the versions of +//| recv that do not allocate bytes objects.""" +//| + +//| def __enter__(self) -> Socket: +//| """No-op used by Context Managers.""" +//| ... +//| +// Provided by context manager helper. + +//| def __exit__(self) -> None: +//| """Automatically closes the Socket when exiting a context. See +//| :ref:`lifetime-and-contextmanagers` for more info.""" +//| ... +//| +STATIC mp_obj_t ssl_sslsocket___exit__(size_t n_args, const mp_obj_t *args) { + (void)n_args; + common_hal_ssl_sslsocket_close(args[0]); + return mp_const_none; +} +STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(ssl_sslsocket___exit___obj, 4, 4, ssl_sslsocket___exit__); + +//| def accept(self) -> Tuple[Socket, Tuple[str, int]]: +//| """Accept a connection on a listening socket of type SOCK_STREAM, +//| creating a new socket of type SOCK_STREAM. +//| Returns a tuple of (new_socket, remote_address)""" +//| +STATIC mp_obj_t ssl_sslsocket_accept(mp_obj_t self_in) { + ssl_sslsocket_obj_t *self = MP_OBJ_TO_PTR(self_in); + uint8_t ip[4]; + uint port; + + ssl_sslsocket_obj_t * sslsock = common_hal_ssl_sslsocket_accept(self, ip, &port); + + mp_obj_t tuple_contents[2]; + tuple_contents[0] = MP_OBJ_FROM_PTR(sslsock); + tuple_contents[1] = netutils_format_inet_addr(ip, port, NETUTILS_BIG); + return mp_obj_new_tuple(2, tuple_contents); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_1(ssl_sslsocket_accept_obj, ssl_sslsocket_accept); + +//| def bind(self, address: Tuple[str, int]) -> None: +//| """Bind a socket to an address +//| +//| :param ~tuple address: tuple of (remote_address, remote_port)""" +//| ... +//| +STATIC mp_obj_t ssl_sslsocket_bind(mp_obj_t self_in, mp_obj_t addr_in) { + ssl_sslsocket_obj_t *self = MP_OBJ_TO_PTR(self_in); + + mp_obj_t *addr_items; + mp_obj_get_array_fixed_n(addr_in, 2, &addr_items); + + size_t hostlen; + const char* host = mp_obj_str_get_data(addr_items[0], &hostlen); + mp_int_t port = mp_obj_get_int(addr_items[1]); + + bool ok = common_hal_ssl_sslsocket_bind(self, host, hostlen, port); + if (!ok) { + mp_raise_ValueError(translate("Error: Failure to bind")); + } + + return mp_const_none; +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(ssl_sslsocket_bind_obj, ssl_sslsocket_bind); + +//| def close(self) -> None: +//| """Closes this Socket""" +//| +STATIC mp_obj_t ssl_sslsocket_close(mp_obj_t self_in) { + ssl_sslsocket_obj_t *self = MP_OBJ_TO_PTR(self_in); + common_hal_ssl_sslsocket_close(self); + return mp_const_none; +} +STATIC MP_DEFINE_CONST_FUN_OBJ_1(ssl_sslsocket_close_obj, ssl_sslsocket_close); + +//| def connect(self, address: Tuple[str, int]) -> None: +//| """Connect a socket to a remote address +//| +//| :param ~tuple address: tuple of (remote_address, remote_port)""" +//| ... +//| +STATIC mp_obj_t ssl_sslsocket_connect(mp_obj_t self_in, mp_obj_t addr_in) { + ssl_sslsocket_obj_t *self = MP_OBJ_TO_PTR(self_in); + + mp_obj_t *addr_items; + mp_obj_get_array_fixed_n(addr_in, 2, &addr_items); + + size_t hostlen; + const char* host = mp_obj_str_get_data(addr_items[0], &hostlen); + mp_int_t port = mp_obj_get_int(addr_items[1]); + + bool ok = common_hal_ssl_sslsocket_connect(self, host, hostlen, port); + if (!ok) { + mp_raise_OSError(0); + } + + return mp_const_none; +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(ssl_sslsocket_connect_obj, ssl_sslsocket_connect); + +//| def listen(self, backlog: int) -> None: +//| """Set socket to listen for incoming connections +//| +//| :param ~int backlog: length of backlog queue for waiting connetions""" +//| ... +//| +STATIC mp_obj_t ssl_sslsocket_listen(mp_obj_t self_in, mp_obj_t backlog_in) { + ssl_sslsocket_obj_t *self = MP_OBJ_TO_PTR(self_in); + + int backlog = mp_obj_get_int(backlog_in); + + common_hal_ssl_sslsocket_listen(self, backlog); + return mp_const_none; +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(ssl_sslsocket_listen_obj, ssl_sslsocket_listen); + +//| def recv_into(self, buffer: WriteableBuffer, bufsize: int) -> int: +//| """Reads some bytes from the connected remote address, writing +//| into the provided buffer. If bufsize <= len(buffer) is given, +//| a maximum of bufsize bytes will be read into the buffer. If no +//| valid value is given for bufsize, the default is the length of +//| the given buffer. +//| +//| Suits sockets of type SOCK_STREAM +//| Returns an int of number of bytes read. +//| +//| :param bytearray buffer: buffer to receive into +//| :param int bufsize: optionally, a maximum number of bytes to read.""" +//| ... +//| +STATIC mp_obj_t ssl_sslsocket_recv_into(size_t n_args, const mp_obj_t *args) { + ssl_sslsocket_obj_t *self = MP_OBJ_TO_PTR(args[0]); + if (common_hal_ssl_sslsocket_get_closed(self)) { + // Bad file number. + mp_raise_OSError(MP_EBADF); + } + // if (!common_hal_ssl_sslsocket_get_connected(self)) { + // // not connected + // mp_raise_OSError(MP_ENOTCONN); + // } + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(args[1], &bufinfo, MP_BUFFER_WRITE); + mp_int_t len = bufinfo.len; + if (n_args == 3) { + mp_int_t given_len = mp_obj_get_int(args[2]); + if (given_len > len) { + mp_raise_ValueError(translate("buffer too small for requested bytes")); + } + if (given_len > 0 && given_len < len) { + len = given_len; + } + } + + if (len == 0) { + return MP_OBJ_NEW_SMALL_INT(0); + } + + mp_int_t ret = common_hal_ssl_sslsocket_recv_into(self, (byte*)bufinfo.buf, len); + return mp_obj_new_int_from_uint(ret); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(ssl_sslsocket_recv_into_obj, 2, 3, ssl_sslsocket_recv_into); + +//| def send(self, bytes: ReadableBuffer) -> int: +//| """Send some bytes to the connected remote address. +//| Suits sockets of type SOCK_STREAM +//| +//| :param ~bytes bytes: some bytes to send""" +//| ... +//| +STATIC mp_obj_t ssl_sslsocket_send(mp_obj_t self_in, mp_obj_t buf_in) { + ssl_sslsocket_obj_t *self = MP_OBJ_TO_PTR(self_in); + if (common_hal_ssl_sslsocket_get_closed(self)) { + // Bad file number. + mp_raise_OSError(MP_EBADF); + } + if (!common_hal_ssl_sslsocket_get_connected(self)) { + mp_raise_BrokenPipeError(); + } + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(buf_in, &bufinfo, MP_BUFFER_READ); + mp_int_t ret = common_hal_ssl_sslsocket_send(self, bufinfo.buf, bufinfo.len); + if (ret == -1) { + mp_raise_BrokenPipeError(); + } + return mp_obj_new_int_from_uint(ret); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(ssl_sslsocket_send_obj, ssl_sslsocket_send); + +// //| def setsockopt(self, level: int, optname: int, value: int) -> None: +// //| """Sets socket options""" +// //| ... +// //| +// STATIC mp_obj_t ssl_sslsocket_setsockopt(size_t n_args, const mp_obj_t *args) { +// } +// STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(ssl_sslsocket_setsockopt_obj, 4, 4, ssl_sslsocket_setsockopt); + +//| def settimeout(self, value: int) -> None: +//| """Set the timeout value for this socket. +//| +//| :param ~int value: timeout in seconds. 0 means non-blocking. None means block indefinitely.""" +//| ... +//| +STATIC mp_obj_t ssl_sslsocket_settimeout(mp_obj_t self_in, mp_obj_t timeout_in) { + ssl_sslsocket_obj_t *self = MP_OBJ_TO_PTR(self_in); + mp_uint_t timeout_ms; + if (timeout_in == mp_const_none) { + timeout_ms = -1; + } else { + #if MICROPY_PY_BUILTINS_FLOAT + timeout_ms = 1000 * mp_obj_get_float(timeout_in); + #else + timeout_ms = 1000 * mp_obj_get_int(timeout_in); + #endif + } + common_hal_ssl_sslsocket_settimeout(self, timeout_ms); + return mp_const_none; +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(ssl_sslsocket_settimeout_obj, ssl_sslsocket_settimeout); + +//| def setblocking(self, flag: bool) -> Optional[int]: +//| """Set the blocking behaviour of this socket. +//| +//| :param ~bool flag: False means non-blocking, True means block indefinitely.""" +//| ... +//| +// method socket.setblocking(flag) +STATIC mp_obj_t ssl_sslsocket_setblocking(mp_obj_t self_in, mp_obj_t blocking) { + ssl_sslsocket_obj_t *self = MP_OBJ_TO_PTR(self_in); + if (mp_obj_is_true(blocking)) { + common_hal_ssl_sslsocket_settimeout(self, -1); + } else { + common_hal_ssl_sslsocket_settimeout(self, 0); + } + return mp_const_none; +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(ssl_sslsocket_setblocking_obj, ssl_sslsocket_setblocking); + +//| def __hash__(self) -> int: +//| """Returns a hash for the Socket.""" +//| ... +//| +STATIC mp_obj_t ssl_sslsocket_unary_op(mp_unary_op_t op, mp_obj_t self_in) { + switch (op) { + case MP_UNARY_OP_HASH: { + return MP_OBJ_NEW_SMALL_INT(common_hal_ssl_sslsocket_get_hash(MP_OBJ_TO_PTR(self_in))); + } + default: + return MP_OBJ_NULL; // op not supported + } +} + +STATIC const mp_rom_map_elem_t ssl_sslsocket_locals_dict_table[] = { + { MP_ROM_QSTR(MP_QSTR___enter__), MP_ROM_PTR(&default___enter___obj) }, + { MP_ROM_QSTR(MP_QSTR___exit__), MP_ROM_PTR(&ssl_sslsocket___exit___obj) }, + { MP_ROM_QSTR(MP_QSTR___del__), MP_ROM_PTR(&ssl_sslsocket_close_obj) }, + + { MP_ROM_QSTR(MP_QSTR_accept), MP_ROM_PTR(&ssl_sslsocket_accept_obj) }, + { MP_ROM_QSTR(MP_QSTR_bind), MP_ROM_PTR(&ssl_sslsocket_bind_obj) }, + { MP_ROM_QSTR(MP_QSTR_close), MP_ROM_PTR(&ssl_sslsocket_close_obj) }, + { MP_ROM_QSTR(MP_QSTR_connect), MP_ROM_PTR(&ssl_sslsocket_connect_obj) }, + { MP_ROM_QSTR(MP_QSTR_listen), MP_ROM_PTR(&ssl_sslsocket_listen_obj) }, + { MP_ROM_QSTR(MP_QSTR_recv_into), MP_ROM_PTR(&ssl_sslsocket_recv_into_obj) }, + { MP_ROM_QSTR(MP_QSTR_send), MP_ROM_PTR(&ssl_sslsocket_send_obj) }, + { MP_ROM_QSTR(MP_QSTR_setblocking), MP_ROM_PTR(&ssl_sslsocket_setblocking_obj) }, + // { MP_ROM_QSTR(MP_QSTR_setsockopt), MP_ROM_PTR(&ssl_sslsocket_setsockopt_obj) }, + { MP_ROM_QSTR(MP_QSTR_settimeout), MP_ROM_PTR(&ssl_sslsocket_settimeout_obj) }, +}; + +STATIC MP_DEFINE_CONST_DICT(ssl_sslsocket_locals_dict, ssl_sslsocket_locals_dict_table); + +const mp_obj_type_t ssl_sslsocket_type = { + { &mp_type_type }, + .name = MP_QSTR_SSLSocket, + .locals_dict = (mp_obj_dict_t*)&ssl_sslsocket_locals_dict, + .unary_op = ssl_sslsocket_unary_op, +}; diff --git a/shared-bindings/ssl/SSLSocket.h b/shared-bindings/ssl/SSLSocket.h new file mode 100644 index 0000000000..d8c589fd80 --- /dev/null +++ b/shared-bindings/ssl/SSLSocket.h @@ -0,0 +1,46 @@ +/* + * This file is part of the MicroPython project, http://micropython.org/ + * + * The MIT License (MIT) + * + * Copyright (c) 2020 Lucian Copeland for Adafruit Industries + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#ifndef MICROPY_INCLUDED_SHARED_BINDINGS_SSL_SSLSOCKET_H +#define MICROPY_INCLUDED_SHARED_BINDINGS_SSL_SSLSOCKET_H + +#include "common-hal/ssl/SSLSocket.h" + +extern const mp_obj_type_t ssl_sslsocket_type; + +ssl_sslsocket_obj_t * common_hal_ssl_sslsocket_accept(ssl_sslsocket_obj_t* self, uint8_t* ip, uint *port); +bool common_hal_ssl_sslsocket_bind(ssl_sslsocket_obj_t* self, const char* host, size_t hostlen, uint8_t port); +void common_hal_ssl_sslsocket_close(ssl_sslsocket_obj_t* self); +bool common_hal_ssl_sslsocket_connect(ssl_sslsocket_obj_t* self, const char* host, size_t hostlen, mp_int_t port); +bool common_hal_ssl_sslsocket_get_closed(ssl_sslsocket_obj_t* self); +bool common_hal_ssl_sslsocket_get_connected(ssl_sslsocket_obj_t* self); +mp_uint_t common_hal_ssl_sslsocket_get_hash(ssl_sslsocket_obj_t* self); +bool common_hal_ssl_sslsocket_listen(ssl_sslsocket_obj_t* self, int backlog); +mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t* self, const uint8_t* buf, mp_uint_t len); +mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t* self, const uint8_t* buf, mp_uint_t len); +void common_hal_ssl_sslsocket_settimeout(ssl_sslsocket_obj_t* self, mp_uint_t timeout_ms); + +#endif // MICROPY_INCLUDED_SHARED_BINDINGS_SSL_SSLSOCKET_H From 815ab5277bf05b0ca85aa084f4e371c5532a7014 Mon Sep 17 00:00:00 2001 From: Lucian Copeland Date: Tue, 26 Jan 2021 14:13:12 -0500 Subject: [PATCH 2/4] Fix stubs error, out of sockets error, invalid TLS leak --- ports/esp32s2/common-hal/socketpool/Socket.c | 8 +++++++- ports/esp32s2/common-hal/ssl/SSLContext.c | 7 ++++--- shared-bindings/ssl/SSLContext.c | 2 +- shared-bindings/ssl/SSLSocket.c | 4 ++-- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/ports/esp32s2/common-hal/socketpool/Socket.c b/ports/esp32s2/common-hal/socketpool/Socket.c index 4cbf4cff26..0022c49c62 100644 --- a/ports/esp32s2/common-hal/socketpool/Socket.c +++ b/ports/esp32s2/common-hal/socketpool/Socket.c @@ -44,8 +44,8 @@ void socket_reset(void) { for (size_t i = 0; i < MP_ARRAY_SIZE(open_socket_handles); i++) { if (open_socket_handles[i]) { if (open_socket_handles[i]->num > 0) { + // Close automatically clears socket handle common_hal_socketpool_socket_close(open_socket_handles[i]); - open_socket_handles[i] = NULL; } else { open_socket_handles[i] = NULL; } @@ -136,6 +136,12 @@ void common_hal_socketpool_socket_close(socketpool_socket_obj_t* self) { lwip_close(self->num); self->num = -1; } + // Remove socket record + for (size_t i = 0; i < MP_ARRAY_SIZE(open_socket_handles); i++) { + if (open_socket_handles[i] == self) { + open_socket_handles[i] = NULL; + } + } } bool common_hal_socketpool_socket_connect(socketpool_socket_obj_t* self, diff --git a/ports/esp32s2/common-hal/ssl/SSLContext.c b/ports/esp32s2/common-hal/ssl/SSLContext.c index c0179399de..6b05905aa6 100644 --- a/ports/esp32s2/common-hal/ssl/SSLContext.c +++ b/ports/esp32s2/common-hal/ssl/SSLContext.c @@ -38,14 +38,15 @@ void common_hal_ssl_sslcontext_construct(ssl_sslcontext_obj_t* self) { ssl_sslsocket_obj_t* common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t* self, socketpool_socket_obj_t* socket, bool server_side, const char* server_hostname) { + if (socket->type != SOCK_STREAM || socket->num != -1) { + mp_raise_RuntimeError(translate("Invalid socket for TLS")); + } + ssl_sslsocket_obj_t *sock = m_new_obj_with_finaliser(ssl_sslsocket_obj_t); sock->base.type = &ssl_sslsocket_type; sock->ssl_context = self; sock->sock = socket; - if (socket->type != SOCK_STREAM || socket->num != -1) { - mp_raise_RuntimeError(translate("Invalid socket for TLS")); - } esp_tls_t* tls_handle = esp_tls_init(); if (tls_handle == NULL) { mp_raise_espidf_MemoryError(); diff --git a/shared-bindings/ssl/SSLContext.c b/shared-bindings/ssl/SSLContext.c index 9d4df72619..44e9e6bbf8 100644 --- a/shared-bindings/ssl/SSLContext.c +++ b/shared-bindings/ssl/SSLContext.c @@ -51,7 +51,7 @@ STATIC mp_obj_t ssl_sslcontext_make_new(const mp_obj_type_t *type, size_t n_args return MP_OBJ_FROM_PTR(s); } -//| def wrap_socket(sock: socketpool.Socket, *, server_side: bool = False, server_hostname: Optional[str] = None) -> socketpool.Socket: +//| def wrap_socket(sock: socketpool.Socket, *, server_side: bool = False, server_hostname: Optional[str] = None) -> ssl.SSLSocket: //| """Wraps the socket into a socket-compatible class that handles SSL negotiation. //| The socket must be of type SOCK_STREAM.""" //| ... diff --git a/shared-bindings/ssl/SSLSocket.c b/shared-bindings/ssl/SSLSocket.c index 154d3d1d44..cd2daeb3e3 100644 --- a/shared-bindings/ssl/SSLSocket.c +++ b/shared-bindings/ssl/SSLSocket.c @@ -45,7 +45,7 @@ //| recv that do not allocate bytes objects.""" //| -//| def __enter__(self) -> Socket: +//| def __enter__(self) -> SSLSocket: //| """No-op used by Context Managers.""" //| ... //| @@ -63,7 +63,7 @@ STATIC mp_obj_t ssl_sslsocket___exit__(size_t n_args, const mp_obj_t *args) { } STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(ssl_sslsocket___exit___obj, 4, 4, ssl_sslsocket___exit__); -//| def accept(self) -> Tuple[Socket, Tuple[str, int]]: +//| def accept(self) -> Tuple[SSLSocket, Tuple[str, int]]: //| """Accept a connection on a listening socket of type SOCK_STREAM, //| creating a new socket of type SOCK_STREAM. //| Returns a tuple of (new_socket, remote_address)""" From a724f6f9545710531b415007f9aa93a7aa90060e Mon Sep 17 00:00:00 2001 From: Lucian Copeland Date: Fri, 29 Jan 2021 11:18:50 -0500 Subject: [PATCH 3/4] Fix documentation builds --- ports/esp32s2/common-hal/ssl/SSLContext.c | 2 +- shared-bindings/ssl/SSLSocket.c | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ports/esp32s2/common-hal/ssl/SSLContext.c b/ports/esp32s2/common-hal/ssl/SSLContext.c index 6b05905aa6..afc3ecce22 100644 --- a/ports/esp32s2/common-hal/ssl/SSLContext.c +++ b/ports/esp32s2/common-hal/ssl/SSLContext.c @@ -38,7 +38,7 @@ void common_hal_ssl_sslcontext_construct(ssl_sslcontext_obj_t* self) { ssl_sslsocket_obj_t* common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t* self, socketpool_socket_obj_t* socket, bool server_side, const char* server_hostname) { - if (socket->type != SOCK_STREAM || socket->num != -1) { + if (socket->type != SOCK_STREAM) { mp_raise_RuntimeError(translate("Invalid socket for TLS")); } diff --git a/shared-bindings/ssl/SSLSocket.c b/shared-bindings/ssl/SSLSocket.c index cd2daeb3e3..c184bceb3f 100644 --- a/shared-bindings/ssl/SSLSocket.c +++ b/shared-bindings/ssl/SSLSocket.c @@ -38,8 +38,8 @@ #include "lib/netutils/netutils.h" //| class SSLSocket: -//| """Implements TLS security on a subset of `socketpool.socket` functions. Cannot be created -//| directly. Instead, call `context.wrap_socket` on an existing socket object. +//| """Implements TLS security on a subset of `socketpool.Socket` functions. Cannot be created +//| directly. Instead, call `wrap_socket` on an existing socket object. //| //| Provides a subset of CPython's `ssl.SSLSocket` API. It only implements the versions of //| recv that do not allocate bytes objects.""" From 8277ffca861754b4df526c62126337c195cc0fee Mon Sep 17 00:00:00 2001 From: Lucian Copeland Date: Sat, 30 Jan 2021 16:13:28 -0500 Subject: [PATCH 4/4] Fix hash, close, error bugs --- ports/esp32s2/common-hal/socketpool/Socket.c | 7 +------ ports/esp32s2/common-hal/ssl/SSLSocket.c | 19 +++++++++++++------ shared-bindings/socketpool/Socket.c | 2 +- shared-bindings/socketpool/Socket.h | 1 - shared-bindings/ssl/SSLSocket.c | 2 +- shared-bindings/ssl/SSLSocket.h | 1 - 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/ports/esp32s2/common-hal/socketpool/Socket.c b/ports/esp32s2/common-hal/socketpool/Socket.c index 0022c49c62..cee940aafc 100644 --- a/ports/esp32s2/common-hal/socketpool/Socket.c +++ b/ports/esp32s2/common-hal/socketpool/Socket.c @@ -3,7 +3,6 @@ * * The MIT License (MIT) * - * Copyright (c) 2020 Scott Shawcroft for Adafruit Industries * Copyright (c) 2020 Lucian Copeland for Adafruit Industries * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -198,10 +197,6 @@ bool common_hal_socketpool_socket_get_connected(socketpool_socket_obj_t* self) { return self->connected; } -mp_uint_t common_hal_socketpool_socket_get_hash(socketpool_socket_obj_t* self) { - return self->num; -} - bool common_hal_socketpool_socket_listen(socketpool_socket_obj_t* self, int backlog) { return lwip_listen(self->num, backlog) == 0; } @@ -289,7 +284,7 @@ mp_uint_t common_hal_socketpool_socket_send(socketpool_socket_obj_t* self, const } if (sent < 0) { - mp_raise_OSError(MP_ENOTCONN); + mp_raise_OSError(errno); } return sent; } diff --git a/ports/esp32s2/common-hal/ssl/SSLSocket.c b/ports/esp32s2/common-hal/ssl/SSLSocket.c index d8e48c3d59..33507e0f4e 100644 --- a/ports/esp32s2/common-hal/ssl/SSLSocket.c +++ b/ports/esp32s2/common-hal/ssl/SSLSocket.c @@ -3,6 +3,7 @@ * * The MIT License (MIT) * + * Copyright (c) 2020 Scott Shawcroft for Adafruit Industries * Copyright (c) 2021 Lucian Copeland for Adafruit Industries * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -51,7 +52,7 @@ bool common_hal_ssl_sslsocket_bind(ssl_sslsocket_obj_t* self, } void common_hal_ssl_sslsocket_close(ssl_sslsocket_obj_t* self) { - self->sock->connected = false; + common_hal_socketpool_socket_close(self->sock); esp_tls_conn_destroy(self->tls); self->tls = NULL; } @@ -99,10 +100,6 @@ bool common_hal_ssl_sslsocket_get_connected(ssl_sslsocket_obj_t* self) { return self->sock->connected; } -mp_uint_t common_hal_ssl_sslsocket_get_hash(ssl_sslsocket_obj_t* self) { - return self->sock->num; -} - bool common_hal_ssl_sslsocket_listen(ssl_sslsocket_obj_t* self, int backlog) { return common_hal_socketpool_socket_listen(self->sock, backlog); } @@ -163,7 +160,17 @@ mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t* self, const uint8_t sent = esp_tls_conn_write(self->tls, buf, len); if (sent < 0) { - mp_raise_OSError(MP_ENOTCONN); + int esp_tls_code; + int flags; + esp_err_t err = esp_tls_get_and_clear_last_error(self->tls->error_handle, &esp_tls_code, &flags); + + if (err == ESP_ERR_MBEDTLS_SSL_SETUP_FAILED) { + mp_raise_espidf_MemoryError(); + } else if (ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED) { + mp_raise_OSError_msg_varg(translate("Failed SSL handshake")); + } else { + mp_raise_OSError_msg_varg(translate("Unhandled ESP TLS error %d %d %x %d"), esp_tls_code, flags, err, sent); + } } return sent; } diff --git a/shared-bindings/socketpool/Socket.c b/shared-bindings/socketpool/Socket.c index f169d6acac..27440487a7 100644 --- a/shared-bindings/socketpool/Socket.c +++ b/shared-bindings/socketpool/Socket.c @@ -370,7 +370,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(socketpool_socket_settimeout_obj, socketpool_so STATIC mp_obj_t socketpool_socket_unary_op(mp_unary_op_t op, mp_obj_t self_in) { switch (op) { case MP_UNARY_OP_HASH: { - return MP_OBJ_NEW_SMALL_INT(common_hal_socketpool_socket_get_hash(MP_OBJ_TO_PTR(self_in))); + return mp_obj_id(self_in); } default: return MP_OBJ_NULL; // op not supported diff --git a/shared-bindings/socketpool/Socket.h b/shared-bindings/socketpool/Socket.h index 76af6e1e9b..637a7a2146 100644 --- a/shared-bindings/socketpool/Socket.h +++ b/shared-bindings/socketpool/Socket.h @@ -37,7 +37,6 @@ void common_hal_socketpool_socket_close(socketpool_socket_obj_t* self); bool common_hal_socketpool_socket_connect(socketpool_socket_obj_t* self, const char* host, size_t hostlen, mp_int_t port); bool common_hal_socketpool_socket_get_closed(socketpool_socket_obj_t* self); bool common_hal_socketpool_socket_get_connected(socketpool_socket_obj_t* self); -mp_uint_t common_hal_socketpool_socket_get_hash(socketpool_socket_obj_t* self); mp_uint_t common_hal_socketpool_socket_get_timeout(socketpool_socket_obj_t* self); bool common_hal_socketpool_socket_listen(socketpool_socket_obj_t* self, int backlog); mp_uint_t common_hal_socketpool_socket_recvfrom_into(socketpool_socket_obj_t* self, diff --git a/shared-bindings/ssl/SSLSocket.c b/shared-bindings/ssl/SSLSocket.c index c184bceb3f..a937952a5d 100644 --- a/shared-bindings/ssl/SSLSocket.c +++ b/shared-bindings/ssl/SSLSocket.c @@ -286,7 +286,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(ssl_sslsocket_setblocking_obj, ssl_sslsocket_se STATIC mp_obj_t ssl_sslsocket_unary_op(mp_unary_op_t op, mp_obj_t self_in) { switch (op) { case MP_UNARY_OP_HASH: { - return MP_OBJ_NEW_SMALL_INT(common_hal_ssl_sslsocket_get_hash(MP_OBJ_TO_PTR(self_in))); + return mp_obj_id(self_in); } default: return MP_OBJ_NULL; // op not supported diff --git a/shared-bindings/ssl/SSLSocket.h b/shared-bindings/ssl/SSLSocket.h index d8c589fd80..b1f2c513d7 100644 --- a/shared-bindings/ssl/SSLSocket.h +++ b/shared-bindings/ssl/SSLSocket.h @@ -37,7 +37,6 @@ void common_hal_ssl_sslsocket_close(ssl_sslsocket_obj_t* self); bool common_hal_ssl_sslsocket_connect(ssl_sslsocket_obj_t* self, const char* host, size_t hostlen, mp_int_t port); bool common_hal_ssl_sslsocket_get_closed(ssl_sslsocket_obj_t* self); bool common_hal_ssl_sslsocket_get_connected(ssl_sslsocket_obj_t* self); -mp_uint_t common_hal_ssl_sslsocket_get_hash(ssl_sslsocket_obj_t* self); bool common_hal_ssl_sslsocket_listen(ssl_sslsocket_obj_t* self, int backlog); mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t* self, const uint8_t* buf, mp_uint_t len); mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t* self, const uint8_t* buf, mp_uint_t len);