From 8980ebfa169059ca993e27f164e92aed51185b1f Mon Sep 17 00:00:00 2001 From: Ted Hess Date: Sun, 16 Apr 2023 12:06:15 -0400 Subject: [PATCH] Simplify CORS checks and don't restrict host names. Minor socket cleanup. --- .../espressif/common-hal/socketpool/Socket.c | 1 - supervisor/shared/web_workflow/web_workflow.c | 82 ++++++++----------- 2 files changed, 35 insertions(+), 48 deletions(-) diff --git a/ports/espressif/common-hal/socketpool/Socket.c b/ports/espressif/common-hal/socketpool/Socket.c index 5af3cdbbc3..ef56480672 100644 --- a/ports/espressif/common-hal/socketpool/Socket.c +++ b/ports/espressif/common-hal/socketpool/Socket.c @@ -88,7 +88,6 @@ STATIC void socket_select_task(void *arg) { } assert(num_triggered > 0); - assert(!FD_ISSET(socket_change_fd, &excptfds)); // Notice event trigger if (FD_ISSET(socket_change_fd, &readfds)) { diff --git a/supervisor/shared/web_workflow/web_workflow.c b/supervisor/shared/web_workflow/web_workflow.c index d24547dd09..faadd52bda 100644 --- a/supervisor/shared/web_workflow/web_workflow.c +++ b/supervisor/shared/web_workflow/web_workflow.c @@ -24,6 +24,9 @@ * THE SOFTWARE. */ +// Include strchrnul() +#define _GNU_SOURCE + #include #include @@ -85,8 +88,8 @@ typedef struct { char destination[256]; char header_key[64]; char header_value[256]; - // We store the origin so we can reply back with it. - char origin[64]; + char origin[64]; // We store the origin so we can reply back with it. + char host[64]; // We store the host to check against origin. size_t content_length; size_t offset; uint64_t timestamp_ms; @@ -454,49 +457,33 @@ static bool _endswith(const char *str, const char *suffix) { return strcmp(str + (strlen(str) - strlen(suffix)), suffix) == 0; } -const char *ok_hosts[] = { - "127.0.0.1", - "localhost", -}; +const char http_scheme[] = "http://"; +#define PREFIX_HTTP_LEN (sizeof(http_scheme) - 1) -static bool _origin_ok(const char *origin) { - const char *http = "http://"; - - // note: redirected requests send an Origin of "null" and will be caught by this - if (strncmp(origin, http, strlen(http)) != 0) { - return false; - } - // These are prefix checks up to : so that any port works. - // TODO: Support DHCP hostname in addition to MDNS. - const char *end; - #if CIRCUITPY_MDNS - if (!common_hal_mdns_server_deinited(&mdns)) { - const char *local = ".local"; - const char *hostname = common_hal_mdns_server_get_hostname(&mdns); - end = origin + strlen(http) + strlen(hostname) + strlen(local); - if (strncmp(origin + strlen(http), hostname, strlen(hostname)) == 0 && - strncmp(origin + strlen(http) + strlen(hostname), local, strlen(local)) == 0 && - (end[0] == '\0' || end[0] == ':')) { - return true; - } - } - #endif - - _update_encoded_ip(); - end = origin + strlen(http) + strlen(_our_ip_encoded); - if (strncmp(origin + strlen(http), _our_ip_encoded, strlen(_our_ip_encoded)) == 0 && - (end[0] == '\0' || end[0] == ':')) { +static bool _origin_ok(_request *request) { + // Origin may be 'null' + if (request->origin[0] == '\0') { return true; } - - for (size_t i = 0; i < MP_ARRAY_SIZE(ok_hosts); i++) { - // Allows any port - end = origin + strlen(http) + strlen(ok_hosts[i]); - if (strncmp(origin + strlen(http), ok_hosts[i], strlen(ok_hosts[i])) == 0 - && (end[0] == '\0' || end[0] == ':')) { + // Origin has http prefix? + if (strncmp(request->origin, http_scheme, PREFIX_HTTP_LEN) != 0) { + // Not HTTP scheme request - ok + request->origin[0] = '\0'; + return true; + } + // Host given? + if (request->host[0] != '\0') { + // OK if host and origin match (fqdn + port #) + if (strcmp(request->host, &request->origin[PREFIX_HTTP_LEN]) == 0) { + return true; + } + // DEBUG: OK if origin is 'localhost' (ignoring port #) + *strchrnul(&request->origin[PREFIX_HTTP_LEN], ':') = '\0'; + if (strcmp(&request->origin[PREFIX_HTTP_LEN], "localhost") == 0) { return true; } } + // Otherwise deny request return false; } @@ -517,8 +504,8 @@ static void _cors_header(socketpool_socket_obj_t *socket, _request *request) { _send_strs(socket, "Access-Control-Allow-Credentials: true\r\n", "Vary: Origin, Accept, Upgrade\r\n", - "Access-Control-Allow-Origin: *\r\n", - NULL); + "Access-Control-Allow-Origin: ", + (request->origin[0] == '\0') ? "*" : request->origin, "\r\n", NULL); } static void _reply_continue(socketpool_socket_obj_t *socket, _request *request) { @@ -1086,11 +1073,7 @@ static bool _reply(socketpool_socket_obj_t *socket, _request *request) { #else _reply_missing(socket, request); #endif - -// For now until CORS is sorted, allow always the origin requester. -// Note: caller knows who we are better than us. CORS is not security -// unless browser cooperates. Do not rely on mDNS or IP. - } else if (strlen(request->origin) > 0 && !_origin_ok(request->origin)) { + } else if (!_origin_ok(request)) { _reply_forbidden(socket, request); } else if (strncmp(request->path, "/fs/", 4) == 0) { if (strcasecmp(request->method, "OPTIONS") == 0) { @@ -1314,6 +1297,7 @@ static bool _reply(socketpool_socket_obj_t *socket, _request *request) { static void _reset_request(_request *request) { request->state = STATE_METHOD; request->origin[0] = '\0'; + request->host[0] = '\0'; request->content_length = 0; request->offset = 0; request->timestamp_ms = 0; @@ -1340,6 +1324,7 @@ static void _process_request(socketpool_socket_obj_t *socket, _request *request) if (len == 0 || len == -MP_ENOTCONN) { // Disconnect - clear 'in-progress' _reset_request(request); + common_hal_socketpool_socket_close(socket); } break; } @@ -1421,6 +1406,8 @@ static void _process_request(socketpool_socket_obj_t *socket, _request *request) request->redirect = strncmp(request->header_value, cp_local, strlen(cp_local)) == 0 && (strlen(request->header_value) == strlen(cp_local) || request->header_value[strlen(cp_local)] == ':'); + strncpy(request->host, request->header_value, sizeof(request->host) - 1); + request->host[sizeof(request->host) - 1] = '\0'; } else if (strcasecmp(request->header_key, "Content-Length") == 0) { request->content_length = strtoul(request->header_value, NULL, 10); } else if (strcasecmp(request->header_key, "Expect") == 0) { @@ -1428,7 +1415,8 @@ static void _process_request(socketpool_socket_obj_t *socket, _request *request) } else if (strcasecmp(request->header_key, "Accept") == 0) { request->json = strcasecmp(request->header_value, "application/json") == 0; } else if (strcasecmp(request->header_key, "Origin") == 0) { - strcpy(request->origin, request->header_value); + strncpy(request->origin, request->header_value, sizeof(request->origin) - 1); + request->origin[sizeof(request->origin) - 1] = '\0'; } else if (strcasecmp(request->header_key, "X-Timestamp") == 0) { request->timestamp_ms = strtoull(request->header_value, NULL, 10); } else if (strcasecmp(request->header_key, "Upgrade") == 0) {