diff --git a/ports/raspberrypi/common-hal/socketpool/Socket.c b/ports/raspberrypi/common-hal/socketpool/Socket.c index 95f79c5880..fb777a3f04 100644 --- a/ports/raspberrypi/common-hal/socketpool/Socket.c +++ b/ports/raspberrypi/common-hal/socketpool/Socket.c @@ -37,6 +37,7 @@ #include "py/stream.h" #include "shared-bindings/socketpool/SocketPool.h" #include "shared/runtime/interrupt_char.h" +#include "shared/netutils/netutils.h" #include "supervisor/port.h" #include "supervisor/shared/tick.h" #include "supervisor/workflow.h" @@ -115,6 +116,10 @@ static inline void poll_sockets(void) { #ifdef MICROPY_EVENT_POLL_HOOK MICROPY_EVENT_POLL_HOOK; #else + RUN_BACKGROUND_TASKS; + if (MP_STATE_THREAD(mp_pending_exception) != MP_OBJ_NULL) { + mp_handle_pending(true); + } mp_hal_delay_ms(1); #endif } @@ -739,14 +744,117 @@ int socketpool_socket_accept(socketpool_socket_obj_t *self, uint8_t *ip, uint32_ return -MP_EBADF; } -socketpool_socket_obj_t *common_hal_socketpool_socket_accept(socketpool_socket_obj_t *self, +socketpool_socket_obj_t *common_hal_socketpool_socket_accept(socketpool_socket_obj_t *socket, uint8_t *ip, uint32_t *port) { - mp_raise_NotImplementedError(NULL); + if (socket->type != MOD_NETWORK_SOCK_STREAM) { + mp_raise_OSError(MP_EOPNOTSUPP); + } + + // Create new socket object, do it here because we must not raise an out-of-memory + // exception when the LWIP concurrency lock is held + socketpool_socket_obj_t *socket2 = m_new_ll_obj_with_finaliser(socketpool_socket_obj_t); + socket2->base.type = &socketpool_socket_type; + + MICROPY_PY_LWIP_ENTER + + if (socket->pcb.tcp == NULL) { + MICROPY_PY_LWIP_EXIT + m_del_obj(socketpool_socket_obj_t, socket2); + mp_raise_OSError(MP_EBADF); + } + + // I need to do this because "tcp_accepted", later, is a macro. + struct tcp_pcb *listener = socket->pcb.tcp; + if (listener->state != LISTEN) { + MICROPY_PY_LWIP_EXIT + m_del_obj(socketpool_socket_obj_t, socket2); + mp_raise_OSError(MP_EINVAL); + } + + // accept incoming connection + struct tcp_pcb *volatile *incoming_connection = &lwip_socket_incoming_array(socket)[socket->incoming.connection.iget]; + if (*incoming_connection == NULL) { + if (socket->timeout == 0) { + MICROPY_PY_LWIP_EXIT + m_del_obj(socketpool_socket_obj_t, socket2); + mp_raise_OSError(MP_EAGAIN); + } else if (socket->timeout != -1) { + mp_uint_t retries = socket->timeout / 100; + while (*incoming_connection == NULL) { + MICROPY_PY_LWIP_EXIT + if (retries-- == 0) { + m_del_obj(socketpool_socket_obj_t, socket2); + mp_raise_OSError(MP_ETIMEDOUT); + } + mp_hal_delay_ms(100); + MICROPY_PY_LWIP_REENTER + } + } else { + while (*incoming_connection == NULL) { + MICROPY_PY_LWIP_EXIT + poll_sockets(); + MICROPY_PY_LWIP_REENTER + } + } + } + + // We get a new pcb handle... + socket2->pcb.tcp = *incoming_connection; + if (++socket->incoming.connection.iget >= socket->incoming.connection.alloc) { + socket->incoming.connection.iget = 0; + } + *incoming_connection = NULL; + + // ...and set up the new socket for it. + socket2->domain = MOD_NETWORK_AF_INET; + socket2->type = MOD_NETWORK_SOCK_STREAM; + socket2->incoming.pbuf = NULL; + socket2->timeout = socket->timeout; + socket2->state = STATE_CONNECTED; + socket2->recv_offset = 0; + socket2->callback = MP_OBJ_NULL; + tcp_arg(socket2->pcb.tcp, (void *)socket2); + tcp_err(socket2->pcb.tcp, _lwip_tcp_error); + tcp_recv(socket2->pcb.tcp, _lwip_tcp_recv); + + tcp_accepted(listener); + + MICROPY_PY_LWIP_EXIT + + // output values + memcpy(ip, &(socket2->pcb.tcp->remote_ip), NETUTILS_IPV4ADDR_BUFSIZE); + *port = (mp_uint_t)socket2->pcb.tcp->remote_port; + return MP_OBJ_FROM_PTR(socket2); } -bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t *self, +bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t *socket, const char *host, size_t hostlen, uint32_t port) { - mp_raise_NotImplementedError(NULL); + uint8_t ip[NETUTILS_IPV4ADDR_BUFSIZE]; + + // get address + ip_addr_t bind_addr; + int error = socketpool_resolve_host(socket->pool, host, &bind_addr); + if (error != 0) { + mp_raise_OSError(EHOSTUNREACH); + } + + err_t err = ERR_ARG; + switch (socket->type) { + case MOD_NETWORK_SOCK_STREAM: { + err = tcp_bind(socket->pcb.tcp, &bind_addr, port); + break; + } + case MOD_NETWORK_SOCK_DGRAM: { + err = udp_bind(socket->pcb.udp, &bind_addr, port); + break; + } + } + + if (err != ERR_OK) { + mp_raise_OSError(error_lookup_table[-err]); + } + + return mp_const_none; } STATIC err_t _lwip_tcp_close_poll(void *arg, struct tcp_pcb *pcb) { @@ -891,8 +999,34 @@ bool common_hal_socketpool_socket_get_connected(socketpool_socket_obj_t *socket) return socket->state == STATE_CONNECTED; } -bool common_hal_socketpool_socket_listen(socketpool_socket_obj_t *self, int backlog) { - mp_raise_NotImplementedError(NULL); +bool common_hal_socketpool_socket_listen(socketpool_socket_obj_t *socket, int backlog) { + if (socket->type != MOD_NETWORK_SOCK_STREAM) { + mp_raise_OSError(MP_EOPNOTSUPP); + } + + struct tcp_pcb *new_pcb = tcp_listen_with_backlog(socket->pcb.tcp, (u8_t)backlog); + if (new_pcb == NULL) { + mp_raise_OSError(MP_ENOMEM); + } + socket->pcb.tcp = new_pcb; + + // Allocate memory for the backlog of connections + if (backlog <= 1) { + socket->incoming.connection.alloc = 0; + socket->incoming.connection.tcp.item = NULL; + } else { + socket->incoming.connection.alloc = backlog; + socket->incoming.connection.tcp.array = m_new0(struct tcp_pcb *, backlog); + } + socket->incoming.connection.iget = 0; + socket->incoming.connection.iput = 0; + + tcp_accept(new_pcb, _lwip_tcp_accept); + + // Socket is no longer considered "new" for purposes of polling + socket->state = STATE_LISTENING; + + return mp_const_none; } mp_uint_t common_hal_socketpool_socket_recvfrom_into(socketpool_socket_obj_t *socket,