compare all static ok hosts with port, add 127.0.0.1 and localhost to it

use strncmp rather than memcmp, one of the strings coul be smaller than the other
This commit is contained in:
Neradoc 2022-07-24 15:15:10 +02:00
parent 6575598ae6
commit 09915ab0b9
1 changed files with 16 additions and 17 deletions

View File

@ -350,41 +350,40 @@ static bool _endswith(const char *str, const char *suffix) {
return strcmp(str + (strlen(str) - strlen(suffix)), suffix) == 0;
}
const char *ok_hosts[] = {"code.circuitpython.org"};
const char *ok_hosts[] = {
"code.circuitpython.org",
"127.0.0.1",
"localhost",
};
static bool _origin_ok(const char *origin) {
const char *http = "http://";
const char *local = ".local";
// note: redirected requests send an Origin of "null" and will be caught by this
if (memcmp(origin, http, strlen(http)) != 0) {
if (strncmp(origin, http, strlen(http)) != 0) {
return false;
}
// These are prefix checks up to : so that any port works.
const char *hostname = common_hal_mdns_server_get_hostname(&mdns);
const char *end = origin + strlen(http) + strlen(hostname) + strlen(local);
if (memcmp(origin + strlen(http), hostname, strlen(hostname)) == 0 &&
memcmp(origin + strlen(http) + strlen(hostname), local, strlen(local)) == 0 &&
if (strncmp(origin + strlen(http), hostname, strlen(hostname)) == 0 &&
strncmp(origin + strlen(http) + strlen(hostname), local, strlen(local)) == 0 &&
(end[0] == '\0' || end[0] == ':')) {
return true;
}
end = origin + strlen(http) + strlen(_our_ip_encoded);
if (memcmp(origin + strlen(http), _our_ip_encoded, strlen(_our_ip_encoded)) == 0 &&
if (strncmp(origin + strlen(http), _our_ip_encoded, strlen(_our_ip_encoded)) == 0 &&
(end[0] == '\0' || end[0] == ':')) {
return true;
}
const char *localhost = "127.0.0.1";
end = origin + strlen(http) + strlen(localhost);
if (memcmp(origin + strlen(http), localhost, strlen(localhost)) == 0
&& (end[0] == '\0' || end[0] == ':')) {
return true;
}
for (size_t i = 0; i < MP_ARRAY_SIZE(ok_hosts); i++) {
// This checks exactly.
if (strcmp(origin + strlen(http), ok_hosts[i]) == 0) {
// Allows any port
end = origin + strlen(http) + strlen(ok_hosts[i]);
if (strncmp(origin + strlen(http), ok_hosts[i], strlen(ok_hosts[i])) == 0
&& (end[0] == '\0' || end[0] == ':')) {
return true;
}
}
@ -911,7 +910,7 @@ static bool _reply(socketpool_socket_obj_t *socket, _request *request) {
} else if (strlen(request->origin) > 0 && !_origin_ok(request->origin)) {
ESP_LOGE(TAG, "bad origin %s", request->origin);
_reply_forbidden(socket, request);
} else if (memcmp(request->path, "/fs/", 4) == 0) {
} else if (strncmp(request->path, "/fs/", 4) == 0) {
if (strcasecmp(request->method, "OPTIONS") == 0) {
// OPTIONS is sent for CORS preflight, unauthenticated
_reply_access_control(socket, request);
@ -1032,7 +1031,7 @@ static bool _reply(socketpool_socket_obj_t *socket, _request *request) {
}
}
}
} else if (memcmp(request->path, "/cp/", 4) == 0) {
} else if (strncmp(request->path, "/cp/", 4) == 0) {
const char *path = request->path + 3;
if (strcasecmp(request->method, "OPTIONS") == 0) {
// handle preflight requests to /cp/
@ -1177,7 +1176,7 @@ static void _process_request(socketpool_socket_obj_t *socket, _request *request)
request->state = STATE_HEADER_KEY;
if (strcasecmp(request->header_key, "Authorization") == 0) {
const char *prefix = "Basic ";
request->authenticated = memcmp(request->header_value, prefix, strlen(prefix)) == 0 &&
request->authenticated = strncmp(request->header_value, prefix, strlen(prefix)) == 0 &&
strcmp(_api_password, request->header_value + strlen(prefix)) == 0;
} else if (strcasecmp(request->header_key, "Host") == 0) {
request->redirect = strcmp(request->header_value, "circuitpython.local") == 0;