362 lines
12 KiB
C
362 lines
12 KiB
C
/*
|
|
* This file is part of the MicroPython project, http://micropython.org/
|
|
*
|
|
* The MIT License (MIT)
|
|
*
|
|
* Copyright (c) 2016 Linaro Ltd.
|
|
* Copyright (c) 2019 Paul Sokolovsky
|
|
* Copyright (c) 2022 Jeff Epler for Adafruit Industries
|
|
*
|
|
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
* of this software and associated documentation files (the "Software"), to deal
|
|
* in the Software without restriction, including without limitation the rights
|
|
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
* copies of the Software, and to permit persons to whom the Software is
|
|
* furnished to do so, subject to the following conditions:
|
|
*
|
|
* The above copyright notice and this permission notice shall be included in
|
|
* all copies or substantial portions of the Software.
|
|
*
|
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
|
* THE SOFTWARE.
|
|
*/
|
|
|
|
#include "shared-bindings/ssl/SSLSocket.h"
|
|
#include "shared-bindings/socketpool/Socket.h"
|
|
#include "shared-bindings/ssl/SSLContext.h"
|
|
#include "shared-bindings/socketpool/SocketPool.h"
|
|
#include "shared-bindings/socketpool/Socket.h"
|
|
|
|
#include "shared/runtime/interrupt_char.h"
|
|
#include "py/mperrno.h"
|
|
#include "py/mphal.h"
|
|
#include "py/objstr.h"
|
|
#include "py/runtime.h"
|
|
#include "py/stream.h"
|
|
#include "supervisor/shared/tick.h"
|
|
|
|
#if defined(MBEDTLS_ERROR_C)
|
|
#include "../../lib/mbedtls_errors/mp_mbedtls_errors.c"
|
|
#endif
|
|
|
|
#ifdef MBEDTLS_DEBUG_C
#include "mbedtls/debug.h"

// Debug sink handed to mbedtls (see mbedtls_ssl_conf_dbg): prints each message
// on the platform print stream, tagged with the originating file and line.
// The context pointer and the message level are not used.
STATIC void mbedtls_debug(void *ctx, int level, const char *file, int line, const char *str) {
    (void)level;
    (void)ctx;
    mp_printf(&mp_plat_print, "DBG:%s:%04d: %s\n", file, line, str);
}

// Local trace helper: printf-style message prefixed with this file and line.
#define DEBUG(fmt, ...) mp_printf(&mp_plat_print, "DBG:%s:%04d: " fmt "\n", __FILE__, __LINE__,##__VA_ARGS__)
#else
// Without MBEDTLS_DEBUG_C, trace statements expand to an empty statement.
#define DEBUG(...) do {} while (0)
#endif
|
|
|
|
// Raise a Python OSError for a failed mbedtls call. Never returns.
//
// _mbedtls_ssl_send and _mbedtls_ssl_recv (below) hand socket errno values to
// mbedtls as small negative numbers. Codes strictly between -256 and 0 are
// assumed to be such values and are raised as OSError with the sign restored.
// Any other code is raised as OSError carrying the raw value, together with
// the mbedtls message text when MBEDTLS_ERROR_C is enabled.
STATIC NORETURN void mbedtls_raise_error(int err) {
    // The -256 boundary between errno values and mbedtls codes is a heuristic.
    const bool is_socket_errno = (err > -256) && (err < 0);
    if (is_socket_errno) {
        mp_raise_OSError(-err);
    }

    #if defined(MBEDTLS_ERROR_C)
    // mbedtls_strerror costs about 1.5KB because of the error strings.
    // MBEDTLS_ERROR_C is the define mbedtls itself uses to conditionally include
    // it; it is set/unset in the MBEDTLS_CONFIG_FILE named by the Makefile.

    #define ERR_STR_MAX 80 // mbedtls_strerror truncates if it doesn't fit

    // Both allocations are of the "maybe" kind, so failure yields NULL here
    // rather than a MemoryError.
    mp_obj_str_t *msg = m_new_obj_maybe(mp_obj_str_t);
    byte *text = m_new_maybe(byte, ERR_STR_MAX);
    if (!msg || !text) {
        // No room for a message: fall back to the bare error number.
        mp_raise_OSError(err);
    }

    // Render the message text, then wrap the buffer in a str object by hand.
    mbedtls_strerror(err, (char *)text, ERR_STR_MAX);
    msg->base.type = &mp_type_str;
    msg->data = text;
    msg->len = strlen((char *)text);
    msg->hash = qstr_compute_hash(msg->data, msg->len);

    // OSError(err, message)
    mp_obj_t exc_args[2] = { MP_OBJ_NEW_SMALL_INT(err), MP_OBJ_FROM_PTR(msg) };
    nlr_raise(mp_obj_exception_make_new(&mp_type_OSError, 2, 0, exc_args));
    #else
    // mbedtls was built without error strings: raise with the number alone,
    // which is typically a large negative value.
    mp_raise_OSError(err);
    #endif
}
|
|
|
|
STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
|
|
mp_obj_t sock = *(mp_obj_t *)ctx;
|
|
|
|
// mp_uint_t out_sz = sock_stream->write(sock, buf, len, &err);
|
|
mp_int_t out_sz = socketpool_socket_send(sock, buf, len);
|
|
DEBUG("socket_send() -> %d", out_sz);
|
|
if (out_sz < 0) {
|
|
int err = -out_sz;
|
|
DEBUG("sock_stream->write() -> %d nonblocking? %d", out_sz, mp_is_nonblocking_error(err));
|
|
if (mp_is_nonblocking_error(err)) {
|
|
return MBEDTLS_ERR_SSL_WANT_WRITE;
|
|
}
|
|
return -err; // convert an MP_ERRNO to something mbedtls passes through as error
|
|
} else {
|
|
return out_sz;
|
|
}
|
|
}
|
|
|
|
// _mbedtls_ssl_recv is called by mbedtls to receive bytes from the underlying socket
|
|
STATIC int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
|
|
mp_obj_t sock = *(mp_obj_t *)ctx;
|
|
|
|
mp_int_t out_sz = socketpool_socket_recv_into(sock, buf, len);
|
|
DEBUG("socket_recv() -> %d", out_sz);
|
|
if (out_sz < 0) {
|
|
int err = -out_sz;
|
|
if (mp_is_nonblocking_error(err)) {
|
|
return MBEDTLS_ERR_SSL_WANT_READ;
|
|
}
|
|
return -err;
|
|
} else {
|
|
return out_sz;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t *self,
|
|
socketpool_socket_obj_t *socket, bool server_side, const char *server_hostname) {
|
|
|
|
if (socket->type != SOCKETPOOL_SOCK_STREAM) {
|
|
mp_raise_RuntimeError(translate("Invalid socket for TLS"));
|
|
}
|
|
|
|
ssl_sslsocket_obj_t *o = m_new_obj_with_finaliser(ssl_sslsocket_obj_t);
|
|
o->base.type = &ssl_sslsocket_type;
|
|
o->ssl_context = self;
|
|
o->sock = socket;
|
|
|
|
mbedtls_ssl_init(&o->ssl);
|
|
mbedtls_ssl_config_init(&o->conf);
|
|
mbedtls_x509_crt_init(&o->cacert);
|
|
mbedtls_x509_crt_init(&o->cert);
|
|
mbedtls_pk_init(&o->pkey);
|
|
mbedtls_ctr_drbg_init(&o->ctr_drbg);
|
|
#ifdef MBEDTLS_DEBUG_C
|
|
// Debug level (0-4) 1=warning, 2=info, 3=debug, 4=verbose
|
|
mbedtls_debug_set_threshold(4);
|
|
#endif
|
|
|
|
mbedtls_entropy_init(&o->entropy);
|
|
const byte seed[] = "upy";
|
|
int ret = mbedtls_ctr_drbg_seed(&o->ctr_drbg, mbedtls_entropy_func, &o->entropy, seed, sizeof(seed));
|
|
if (ret != 0) {
|
|
goto cleanup;
|
|
}
|
|
|
|
ret = mbedtls_ssl_config_defaults(&o->conf,
|
|
server_side ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT,
|
|
MBEDTLS_SSL_TRANSPORT_STREAM,
|
|
MBEDTLS_SSL_PRESET_DEFAULT);
|
|
if (ret != 0) {
|
|
goto cleanup;
|
|
}
|
|
|
|
if (self->crt_bundle_attach != NULL) {
|
|
mbedtls_ssl_conf_authmode(&o->conf, MBEDTLS_SSL_VERIFY_REQUIRED);
|
|
self->crt_bundle_attach(&o->conf);
|
|
} else if (self->cacert_buf && self->cacert_bytes) {
|
|
ret = mbedtls_x509_crt_parse(&o->cacert, self->cacert_buf, self->cacert_bytes);
|
|
if (ret != 0) {
|
|
goto cleanup;
|
|
}
|
|
mbedtls_ssl_conf_authmode(&o->conf, MBEDTLS_SSL_VERIFY_REQUIRED);
|
|
mbedtls_ssl_conf_ca_chain(&o->conf, &o->cacert, NULL);
|
|
|
|
} else {
|
|
mbedtls_ssl_conf_authmode(&o->conf, MBEDTLS_SSL_VERIFY_NONE);
|
|
}
|
|
mbedtls_ssl_conf_rng(&o->conf, mbedtls_ctr_drbg_random, &o->ctr_drbg);
|
|
#ifdef MBEDTLS_DEBUG_C
|
|
mbedtls_ssl_conf_dbg(&o->conf, mbedtls_debug, NULL);
|
|
#endif
|
|
|
|
ret = mbedtls_ssl_setup(&o->ssl, &o->conf);
|
|
if (ret != 0) {
|
|
goto cleanup;
|
|
}
|
|
|
|
if (server_hostname != NULL) {
|
|
ret = mbedtls_ssl_set_hostname(&o->ssl, server_hostname);
|
|
if (ret != 0) {
|
|
goto cleanup;
|
|
}
|
|
}
|
|
|
|
mbedtls_ssl_set_bio(&o->ssl, &o->sock, _mbedtls_ssl_send, _mbedtls_ssl_recv, NULL);
|
|
|
|
if (self->cert_buf.buf != NULL) {
|
|
ret = mbedtls_pk_parse_key(&o->pkey, self->key_buf.buf, self->key_buf.len + 1, NULL, 0);
|
|
if (ret != 0) {
|
|
goto cleanup;
|
|
}
|
|
ret = mbedtls_x509_crt_parse(&o->cert, self->cert_buf.buf, self->cert_buf.len + 1);
|
|
if (ret != 0) {
|
|
goto cleanup;
|
|
}
|
|
|
|
ret = mbedtls_ssl_conf_own_cert(&o->conf, &o->cert, &o->pkey);
|
|
if (ret != 0) {
|
|
goto cleanup;
|
|
}
|
|
}
|
|
return o;
|
|
cleanup:
|
|
mbedtls_pk_free(&o->pkey);
|
|
mbedtls_x509_crt_free(&o->cert);
|
|
mbedtls_x509_crt_free(&o->cacert);
|
|
mbedtls_ssl_free(&o->ssl);
|
|
mbedtls_ssl_config_free(&o->conf);
|
|
mbedtls_ctr_drbg_free(&o->ctr_drbg);
|
|
mbedtls_entropy_free(&o->entropy);
|
|
|
|
if (ret == MBEDTLS_ERR_SSL_ALLOC_FAILED) {
|
|
mp_raise_type(&mp_type_MemoryError);
|
|
} else if (ret == MBEDTLS_ERR_PK_BAD_INPUT_DATA) {
|
|
mp_raise_ValueError(MP_ERROR_TEXT("invalid key"));
|
|
} else if (ret == MBEDTLS_ERR_X509_BAD_INPUT_DATA) {
|
|
mp_raise_ValueError(MP_ERROR_TEXT("invalid cert"));
|
|
} else {
|
|
mbedtls_raise_error(ret);
|
|
}
|
|
}
|
|
|
|
// Read up to `len` decrypted bytes from the TLS connection into `buf`.
//
// Returns:
//   > 0  number of bytes stored in buf
//   0    end of stream (including the peer's close-notify alert)
//   < 0  (when the result is viewed as a signed value) an error:
//        -MP_EWOULDBLOCK if mbedtls needs more socket I/O before it can
//        make progress, otherwise the negative mbedtls error code unchanged.
//
// Every error is reported as a negative value. Negating an mbedtls error code
// would produce a positive number that cannot be told apart from a byte count.
mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t *buf, uint32_t len) {
    int ret = mbedtls_ssl_read(&self->ssl, buf, len);
    DEBUG("recv_into mbedtls_ssl_read() -> %d\n", ret);
    if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
        DEBUG("returning %d\n", 0);
        // end of stream
        return 0;
    }
    if (ret >= 0) {
        DEBUG("returning %d\n", ret);
        return ret;
    }
    if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
        // WANT_WRITE can occur on a read: if the handshake is not finished, or
        // during renegotiation, the protocol may need to send its next
        // handshake message first. Both cases mean "try again later".
        ret = -MP_EWOULDBLOCK;
    }
    DEBUG("returning [error case] %d\n", ret);
    return ret;
}
|
|
|
|
// Write up to `len` bytes from `buf` to the TLS connection.
//
// Returns:
//   >= 0 number of bytes accepted by mbedtls_ssl_write
//   < 0  (when the result is viewed as a signed value) an error:
//        -MP_EWOULDBLOCK if mbedtls needs more socket I/O before it can
//        make progress, otherwise the negative mbedtls error code unchanged.
//
// Every error is reported as a negative value. Negating an mbedtls error code
// would produce a positive number that cannot be told apart from a byte count.
mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t *buf, uint32_t len) {
    int ret = mbedtls_ssl_write(&self->ssl, buf, len);
    DEBUG("send mbedtls_ssl_write() -> %d\n", ret);
    if (ret >= 0) {
        DEBUG("returning %d\n", ret);
        return ret;
    }
    if (ret == MBEDTLS_ERR_SSL_WANT_WRITE || ret == MBEDTLS_ERR_SSL_WANT_READ) {
        // WANT_READ can occur on a write: if the handshake is not finished, or
        // during renegotiation, the protocol may need to read the next
        // handshake message first. Both cases mean "try again later".
        ret = -MP_EWOULDBLOCK;
    }
    DEBUG("returning [error case] %d\n", ret);
    return ret;
}
|
|
|
|
// Binding is not implemented for SSL sockets here: always raises
// NotImplementedError (with no message) and never returns a value.
bool common_hal_ssl_sslsocket_bind(ssl_sslsocket_obj_t *self, const char *host, size_t hostlen, uint32_t port) {
    mp_raise_NotImplementedError(NULL);
}
|
|
|
|
// Close the SSL socket: flag it closed, close the underlying socket, then
// release every mbedtls object owned by this SSLSocket.
void common_hal_ssl_sslsocket_close(ssl_sslsocket_obj_t *self) {
    self->closed = true;
    common_hal_socketpool_socket_close(self->sock);

    // Key and certificates.
    mbedtls_pk_free(&self->pkey);
    mbedtls_x509_crt_free(&self->cert);
    mbedtls_x509_crt_free(&self->cacert);

    // TLS session and its configuration.
    mbedtls_ssl_free(&self->ssl);
    mbedtls_ssl_config_free(&self->conf);

    // Random generator and its entropy source.
    mbedtls_ctr_drbg_free(&self->ctr_drbg);
    mbedtls_entropy_free(&self->entropy);
}
|
|
|
|
// Drive the TLS handshake to completion.
//
// mbedtls_ssl_handshake is retried for as long as it reports WANT_READ or
// WANT_WRITE. Between attempts, background tasks run, any pending exception is
// handled, and the loop waits 1 ms. On any other failure the socket is marked
// closed, its mbedtls state is released, and the error is raised as
// MemoryError, ValueError or OSError depending on the mbedtls code.
STATIC void do_handshake(ssl_sslsocket_obj_t *self) {
    int status;
    for (;;) {
        status = mbedtls_ssl_handshake(&self->ssl);
        if (status == 0) {
            // Handshake finished.
            return;
        }
        const bool retry = (status == MBEDTLS_ERR_SSL_WANT_READ) || (status == MBEDTLS_ERR_SSL_WANT_WRITE);
        if (!retry) {
            break;
        }
        RUN_BACKGROUND_TASKS;
        if (MP_STATE_THREAD(mp_pending_exception) != MP_OBJ_NULL) {
            mp_handle_pending(true);
        }
        mp_hal_delay_ms(1);
    }

    // Hard failure: tear the TLS state down before raising.
    self->closed = true;
    mbedtls_pk_free(&self->pkey);
    mbedtls_x509_crt_free(&self->cert);
    mbedtls_x509_crt_free(&self->cacert);
    mbedtls_ssl_free(&self->ssl);
    mbedtls_ssl_config_free(&self->conf);
    mbedtls_ctr_drbg_free(&self->ctr_drbg);
    mbedtls_entropy_free(&self->entropy);

    switch (status) {
        case MBEDTLS_ERR_SSL_ALLOC_FAILED:
            mp_raise_type(&mp_type_MemoryError);
            break;
        case MBEDTLS_ERR_PK_BAD_INPUT_DATA:
            mp_raise_ValueError(MP_ERROR_TEXT("invalid key"));
            break;
        case MBEDTLS_ERR_X509_BAD_INPUT_DATA:
            mp_raise_ValueError(MP_ERROR_TEXT("invalid cert"));
            break;
        default:
            mbedtls_raise_error(status);
            break;
    }
}
|
|
|
|
// Connect the underlying socket to host:port, then run the TLS handshake over
// it. Failures in the handshake are raised by do_handshake.
void common_hal_ssl_sslsocket_connect(ssl_sslsocket_obj_t *self, const char *host, size_t hostlen, uint32_t port) {
    // Plain transport connection first...
    common_hal_socketpool_socket_connect(self->sock, host, hostlen, port);
    // ...then negotiate TLS on top of it.
    do_handshake(self);
}
|
|
|
|
// Report the socket's `closed` flag, which is set by close() and by a failed
// handshake.
bool common_hal_ssl_sslsocket_get_closed(ssl_sslsocket_obj_t *self) {
    const bool is_closed = self->closed;
    return is_closed;
}
|
|
|
|
// Report "connected" as simply the inverse of the `closed` flag; no separate
// connection state is consulted.
bool common_hal_ssl_sslsocket_get_connected(ssl_sslsocket_obj_t *self) {
    const bool is_closed = self->closed;
    return !is_closed;
}
|
|
|
|
// Listening is not implemented for SSL sockets here: always raises
// NotImplementedError (with no message) and never returns a value.
bool common_hal_ssl_sslsocket_listen(ssl_sslsocket_obj_t *self, int backlog) {
    mp_raise_NotImplementedError(NULL);
}
|
|
|
|
// Accepting connections is not implemented for SSL sockets here: always raises
// NotImplementedError (with no message) and never returns a value.
ssl_sslsocket_obj_t *common_hal_ssl_sslsocket_accept(ssl_sslsocket_obj_t *self, uint8_t *ip, uint32_t *port) {
    mp_raise_NotImplementedError(NULL);
}
|
|
|
|
// Store the timeout (in milliseconds) on the underlying socket; the SSL layer
// keeps no timeout of its own.
void common_hal_ssl_sslsocket_settimeout(ssl_sslsocket_obj_t *self, uint32_t timeout_ms) {
    self->sock->timeout = timeout_ms;
}
|