Handle MDNS deinited better.

This commit is contained in:
Scott Shawcroft 2023-01-12 11:24:20 -08:00
parent ca80f30348
commit 5c517b7e5a
No known key found for this signature in database
GPG Key ID: 0DFD512649C052DA
5 changed files with 63 additions and 31 deletions

View File

@ -37,6 +37,7 @@ STATIC bool inited = false;
void mdns_server_construct(mdns_server_obj_t *self, bool workflow) {
if (inited) {
self->inited = false;
return;
}
mdns_init();
@ -46,6 +47,8 @@ void mdns_server_construct(mdns_server_obj_t *self, bool workflow) {
snprintf(self->default_hostname, sizeof(self->default_hostname), "cpy-%02x%02x%02x", mac[3], mac[4], mac[5]);
common_hal_mdns_server_set_hostname(self, self->default_hostname);
self->inited = true;
if (workflow) {
// Set a delegated entry to ourselves. This allows us to respond to "circuitpython.local"
// queries as well.
@ -67,21 +70,23 @@ void common_hal_mdns_server_construct(mdns_server_obj_t *self, mp_obj_t network_
mp_raise_ValueError(translate("mDNS only works with built-in WiFi"));
return;
}
if (inited) {
mdns_server_construct(self, false);
if (common_hal_mdns_server_deinited(self)) {
mp_raise_RuntimeError(translate("mDNS already initialized"));
}
mdns_server_construct(self, false);
}
void common_hal_mdns_server_deinit(mdns_server_obj_t *self) {
if (common_hal_mdns_server_deinited(self)) {
return;
}
self->inited = false;
inited = false;
mdns_free();
}
bool common_hal_mdns_server_deinited(mdns_server_obj_t *self) {
// This returns INVALID_STATE when not initialized and INVALID_PARAM when it
// is.
return mdns_instance_name_set(NULL) == ESP_ERR_INVALID_STATE;
return !self->inited;
}
const char *common_hal_mdns_server_get_hostname(mdns_server_obj_t *self) {

View File

@ -34,4 +34,5 @@ typedef struct {
const char *instance_name;
// "cpy-" "XXXXXX" "\0"
char default_hostname[4 + 6 + 1];
bool inited;
} mdns_server_obj_t;

View File

@ -36,6 +36,8 @@
#include "lwip/apps/mdns.h"
#include "lwip/prot/dns.h"
// Track if we are globally inited. This essentially forces one inited MDNS
// object at a time. (But ignores MDNS objects that are deinited.)
STATIC bool inited = false;
#define NETIF_STA (&cyw43_state.netif[CYW43_ITF_STA])
@ -43,11 +45,13 @@ STATIC bool inited = false;
void mdns_server_construct(mdns_server_obj_t *self, bool workflow) {
if (inited) {
self->inited = false;
return;
}
mdns_resp_init();
inited = true;
self->inited = true;
uint8_t mac[6];
wifi_radio_get_mac_address(&common_hal_wifi_radio_obj, mac);
@ -75,12 +79,16 @@ void common_hal_mdns_server_construct(mdns_server_obj_t *self, mp_obj_t network_
}
void common_hal_mdns_server_deinit(mdns_server_obj_t *self) {
if (common_hal_mdns_server_deinited(self)) {
return;
}
self->inited = false;
inited = false;
mdns_resp_remove_netif(NETIF_STA);
}
bool common_hal_mdns_server_deinited(mdns_server_obj_t *self) {
return !mdns_resp_netif_active(NETIF_STA);
return !self->inited;
}
const char *common_hal_mdns_server_get_hostname(mdns_server_obj_t *self) {
@ -215,7 +223,6 @@ STATIC void alloc_search_result_cb(struct mdns_answer *answer, const char *varpa
if ((flags & MDNS_SEARCH_RESULT_FIRST) != 0) {
// first
mdns_remoteservice_obj_t *service = gc_alloc(sizeof(mdns_remoteservice_obj_t), 0, false);
mp_printf(&mp_plat_print, "found service %p\n", service);
if (service == NULL) {
// alloc fails
mdns_search_stop(state->request_id);

View File

@ -37,4 +37,5 @@ typedef struct {
// "cpy-" "XXXXXX" "\0"
char default_hostname[4 + 6 + 1];
const char *service_type[MDNS_MAX_SERVICES];
bool inited;
} mdns_server_obj_t;

View File

@ -202,7 +202,8 @@ STATIC void _update_encoded_ip(void) {
mdns_server_obj_t *supervisor_web_workflow_mdns(mp_obj_t network_interface) {
#if CIRCUITPY_MDNS
if (network_interface == &common_hal_wifi_radio_obj) {
if (network_interface == &common_hal_wifi_radio_obj &&
mdns.base.type == &mdns_server_type) {
return &mdns;
}
#endif
@ -309,11 +310,6 @@ void supervisor_start_web_workflow(void) {
if (first_start) {
port_changed = false;
#if CIRCUITPY_MDNS
mdns_server_construct(&mdns, true);
mdns.base.type = &mdns_server_type;
common_hal_mdns_server_set_instance_name(&mdns, MICROPY_HW_BOARD_NAME);
#endif
pool.base.type = &socketpool_socketpool_type;
common_hal_socketpool_socketpool_construct(&pool, &common_hal_wifi_radio_obj);
@ -322,13 +318,26 @@ void supervisor_start_web_workflow(void) {
websocket_init();
}
#if CIRCUITPY_MDNS
// Try to start MDNS if the user deinited it.
if (mdns.base.type != &mdns_server_type ||
common_hal_mdns_server_deinited(&mdns)) {
mdns_server_construct(&mdns, true);
mdns.base.type = &mdns_server_type;
if (!common_hal_mdns_server_deinited(&mdns)) {
common_hal_mdns_server_set_instance_name(&mdns, MICROPY_HW_BOARD_NAME);
}
}
#endif
if (port_changed) {
common_hal_socketpool_socket_close(&listening);
}
if (first_start || port_changed) {
web_api_port = new_port;
#if CIRCUITPY_MDNS
common_hal_mdns_server_advertise_service(&mdns, "_circuitpython", "_tcp", web_api_port);
if (!common_hal_mdns_server_deinited(&mdns)) {
common_hal_mdns_server_advertise_service(&mdns, "_circuitpython", "_tcp", web_api_port);
}
#endif
socketpool_socket(&pool, SOCKETPOOL_AF_INET, SOCKETPOOL_SOCK_STREAM, &listening);
common_hal_socketpool_socket_settimeout(&listening, 0);
@ -453,17 +462,18 @@ static bool _origin_ok(const char *origin) {
}
// These are prefix checks up to : so that any port works.
// TODO: Support DHCP hostname in addition to MDNS.
#if CIRCUITPY_MDNS
const char *local = ".local";
const char *hostname = common_hal_mdns_server_get_hostname(&mdns);
const char *end = origin + strlen(http) + strlen(hostname) + strlen(local);
if (strncmp(origin + strlen(http), hostname, strlen(hostname)) == 0 &&
strncmp(origin + strlen(http) + strlen(hostname), local, strlen(local)) == 0 &&
(end[0] == '\0' || end[0] == ':')) {
return true;
}
#else
const char *end;
#if CIRCUITPY_MDNS
if (!common_hal_mdns_server_deinited(&mdns)) {
const char *local = ".local";
const char *hostname = common_hal_mdns_server_get_hostname(&mdns);
end = origin + strlen(http) + strlen(hostname) + strlen(local);
if (strncmp(origin + strlen(http), hostname, strlen(hostname)) == 0 &&
strncmp(origin + strlen(http) + strlen(hostname), local, strlen(local)) == 0 &&
(end[0] == '\0' || end[0] == ':')) {
return true;
}
}
#endif
_update_encoded_ip();
@ -742,12 +752,13 @@ static void _reply_with_file(socketpool_socket_obj_t *socket, _request *request,
}
static void _reply_with_devices_json(socketpool_socket_obj_t *socket, _request *request) {
size_t total_results = 0;
#if CIRCUITPY_MDNS
mdns_remoteservice_obj_t found_devices[32];
size_t total_results = mdns_server_find(&mdns, "_circuitpython", "_tcp", 1, found_devices, MP_ARRAY_SIZE(found_devices));
if (!common_hal_mdns_server_deinited(&mdns)) {
total_results = mdns_server_find(&mdns, "_circuitpython", "_tcp", 1, found_devices, MP_ARRAY_SIZE(found_devices));
}
size_t count = MIN(total_results, MP_ARRAY_SIZE(found_devices));
#else
size_t total_results = 0;
#endif
socketpool_socket_send(socket, (const uint8_t *)OK_JSON, strlen(OK_JSON));
_cors_header(socket, request);
@ -784,10 +795,11 @@ static void _reply_with_version_json(socketpool_socket_obj_t *socket, _request *
_send_str(socket, "\r\n");
mp_print_t _socket_print = {socket, _print_chunk};
#if CIRCUITPY_MDNS
const char *hostname = common_hal_mdns_server_get_hostname(&mdns);
#else
const char *hostname = "";
#if CIRCUITPY_MDNS
if (!common_hal_mdns_server_deinited(&mdns)) {
hostname = common_hal_mdns_server_get_hostname(&mdns);
}
#endif
_update_encoded_ip();
// Note: this leverages the fact that C concats consecutive string literals together.
@ -1032,7 +1044,13 @@ static void _decode_percents(char *str) {
static bool _reply(socketpool_socket_obj_t *socket, _request *request) {
if (request->redirect) {
#if CIRCUITPY_MDNS
_reply_redirect(socket, request, request->path);
if (!common_hal_mdns_server_deinited(&mdns)) {
_reply_redirect(socket, request, request->path);
} else {
_reply_missing(socket, request);
}
#else
_reply_missing(socket, request);
#endif
} else if (strlen(request->origin) > 0 && !_origin_ok(request->origin)) {
_reply_forbidden(socket, request);