extmod/modussl: Fix ussl read/recv/send/write errors when non-blocking.

Also fix related problems with socket on esp32, improve docs for
wrap_socket, and add more tests.
This commit is contained in:
Thorsten von Eicken 2020-04-02 10:01:16 -07:00 committed by Damien George
parent 2eed9780ba
commit 2c1299b007
10 changed files with 372 additions and 19 deletions

View File

@ -13,16 +13,23 @@ facilities for network sockets, both client-side and server-side.
Functions
---------
.. function:: ussl.wrap_socket(sock, server_side=False, keyfile=None, certfile=None, cert_reqs=CERT_NONE, ca_certs=None)
.. function:: ussl.wrap_socket(sock, server_side=False, keyfile=None, certfile=None, cert_reqs=CERT_NONE, ca_certs=None, do_handshake=True)
Takes a `stream` *sock* (usually usocket.socket instance of ``SOCK_STREAM`` type),
and returns an instance of ssl.SSLSocket, which wraps the underlying stream in
an SSL context. Returned object has the usual `stream` interface methods like
``read()``, ``write()``, etc. In MicroPython, the returned object does not expose
socket interface and methods like ``recv()``, ``send()``. In particular, a
server-side SSL socket should be created from a normal socket returned from
``read()``, ``write()``, etc.
A server-side SSL socket should be created from a normal socket returned from
:meth:`~usocket.socket.accept()` on a non-SSL listening server socket.
- *do_handshake* determines whether the handshake is done as part of the ``wrap_socket``
or whether it is deferred to be done as part of the initial reads or writes
(there is no ``do_handshake`` method as in CPython).
For blocking sockets doing the handshake immediately is standard. For non-blocking
sockets (i.e. when the *sock* passed into ``wrap_socket`` is in non-blocking mode)
the handshake should generally be deferred because otherwise ``wrap_socket`` blocks
until it completes. Note that in AXTLS the handshake can be deferred until the first
read or write but it then blocks until completion.
Depending on the underlying module implementation in a particular
:term:`MicroPython port`, some or all keyword arguments above may be not supported.
@ -31,6 +38,11 @@ Functions
Some implementations of ``ussl`` module do NOT validate server certificates,
which makes an SSL connection established prone to man-in-the-middle attacks.
CPython's ``wrap_socket`` returns an ``SSLSocket`` object which has methods typical
for sockets, such as ``send``, ``recv``, etc. MicroPython's ``wrap_socket``
returns an object more similar to CPython's ``SSLObject`` which does not have
these socket methods.
Exceptions
----------

View File

@ -167,10 +167,15 @@ STATIC mp_obj_ssl_socket_t *ussl_socket_new(mp_obj_t sock, struct ssl_args *args
o->ssl_sock = ssl_client_new(o->ssl_ctx, (long)sock, NULL, 0, ext);
if (args->do_handshake.u_bool) {
int res = ssl_handshake_status(o->ssl_sock);
int r = ssl_handshake_status(o->ssl_sock);
if (res != SSL_OK) {
ussl_raise_error(res);
if (r != SSL_OK) {
if (r == SSL_CLOSE_NOTIFY) { // EOF
r = MP_ENOTCONN;
} else if (r == SSL_EAGAIN) {
r = MP_EAGAIN;
}
ussl_raise_error(r);
}
}
@ -242,8 +247,24 @@ STATIC mp_uint_t ussl_socket_write(mp_obj_t o_in, const void *buf, mp_uint_t siz
return MP_STREAM_ERROR;
}
mp_int_t r = ssl_write(o->ssl_sock, buf, size);
mp_int_t r;
eagain:
r = ssl_write(o->ssl_sock, buf, size);
if (r == 0) {
// see comment in ussl_socket_read above
if (o->blocking) {
goto eagain;
} else {
r = SSL_EAGAIN;
}
}
if (r < 0) {
if (r == SSL_CLOSE_NOTIFY || r == SSL_ERROR_CONN_LOST) {
return 0; // EOF
}
if (r == SSL_EAGAIN) {
r = MP_EAGAIN;
}
*errcode = r;
return MP_STREAM_ERROR;
}

View File

@ -133,6 +133,7 @@ STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
}
}
// _mbedtls_ssl_recv is called by mbedtls to receive bytes from the underlying socket
STATIC int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
mp_obj_t sock = *(mp_obj_t *)ctx;
@ -171,7 +172,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
mbedtls_pk_init(&o->pkey);
mbedtls_ctr_drbg_init(&o->ctr_drbg);
#ifdef MBEDTLS_DEBUG_C
// Debug level (0-4)
// Debug level (0-4) 1=warning, 2=info, 3=debug, 4=verbose
mbedtls_debug_set_threshold(0);
#endif

View File

@ -558,7 +558,8 @@ int _socket_send(socket_obj_t *sock, const char *data, size_t datalen) {
MP_THREAD_GIL_EXIT();
int r = lwip_write(sock->fd, data + sentlen, datalen - sentlen);
MP_THREAD_GIL_ENTER();
if (r < 0 && errno != EWOULDBLOCK) {
// lwip returns EINPROGRESS when trying to send right after a non-blocking connect
if (r < 0 && errno != EWOULDBLOCK && errno != EINPROGRESS) {
mp_raise_OSError(errno);
}
if (r > 0) {
@ -567,7 +568,7 @@ int _socket_send(socket_obj_t *sock, const char *data, size_t datalen) {
check_for_exceptions();
}
if (sentlen == 0) {
mp_raise_OSError(MP_ETIMEDOUT);
mp_raise_OSError(sock->retries == 0 ? MP_EWOULDBLOCK : MP_ETIMEDOUT);
}
return sentlen;
}
@ -650,7 +651,8 @@ STATIC mp_uint_t socket_stream_write(mp_obj_t self_in, const void *buf, mp_uint_
if (r > 0) {
return r;
}
if (r < 0 && errno != EWOULDBLOCK) {
// lwip returns MP_EINPROGRESS when trying to write right after a non-blocking connect
if (r < 0 && errno != EWOULDBLOCK && errno != EINPROGRESS) {
*errcode = errno;
return MP_STREAM_ERROR;
}

View File

@ -1,9 +1,9 @@
# test that socket.accept() on a socket with timeout raises ETIMEDOUT
try:
import usocket as socket
import uerrno as errno, usocket as socket
except:
import socket
import errno, socket
try:
socket.socket.settimeout
@ -18,5 +18,5 @@ s.listen(1)
try:
s.accept()
except OSError as er:
print(er.args[0] in (110, "timed out")) # 110 is ETIMEDOUT; CPython uses a string
print(er.args[0] in (errno.ETIMEDOUT, "timed out")) # CPython uses a string instead of errno
s.close()

View File

@ -0,0 +1,147 @@
# test that socket.connect() on a non-blocking socket raises EINPROGRESS
# and that an immediate write/send/read/recv does the right thing
try:
import sys, time
import uerrno as errno, usocket as socket, ussl as ssl
except:
import socket, errno, ssl
isMP = sys.implementation.name == "micropython"
def dp(e):
# uncomment next line for development and testing, to print the actual exceptions
# print(repr(e))
pass
# do_connect establishes the socket and wraps it if tls is True.
# If handshake is true, the initial connect (and TLS handshake) is
# allowed to be performed before returning.
def do_connect(peer_addr, tls, handshake):
s = socket.socket()
s.setblocking(False)
try:
# print("Connecting to", peer_addr)
s.connect(peer_addr)
except OSError as er:
print("connect:", er.args[0] == errno.EINPROGRESS)
if er.args[0] != errno.EINPROGRESS:
print(" got", er.args[0])
# wrap with ssl/tls if desired
if tls:
try:
if sys.implementation.name == "micropython":
s = ssl.wrap_socket(s, do_handshake=handshake)
else:
s = ssl.wrap_socket(s, do_handshake_on_connect=handshake)
print("wrap: True")
except Exception as e:
dp(e)
print("wrap:", e)
elif handshake:
# just sleep a little bit, this allows any connect() errors to happen
time.sleep(0.2)
return s
# test runs the test against a specific peer address.
def test(peer_addr, tls=False, handshake=False):
# MicroPython plain sockets have read/write, but CPython's don't
# MicroPython TLS sockets and CPython's have read/write
# hasRW captures this wonderful state of affairs
hasRW = isMP or tls
# MicroPython plain sockets and CPython's have send/recv
# MicroPython TLS sockets don't have send/recv, but CPython's do
# hasSR captures this wonderful state of affairs
hasSR = not (isMP and tls)
# connect + send
if hasSR:
s = do_connect(peer_addr, tls, handshake)
# send -> 4 or EAGAIN
try:
ret = s.send(b"1234")
print("send:", handshake and ret == 4)
except OSError as er:
#
dp(er)
print("send:", er.args[0] in (errno.EAGAIN, errno.EINPROGRESS))
s.close()
else: # fake it...
print("connect:", True)
if tls:
print("wrap:", True)
print("send:", True)
# connect + write
if hasRW:
s = do_connect(peer_addr, tls, handshake)
# write -> None
try:
ret = s.write(b"1234")
print("write:", ret in (4, None)) # SSL may accept 4 into buffer
except OSError as er:
dp(er)
print("write:", False) # should not raise
except ValueError as er: # CPython
dp(er)
print("write:", er.args[0] == "Write on closed or unwrapped SSL socket.")
s.close()
else: # fake it...
print("connect:", True)
if tls:
print("wrap:", True)
print("write:", True)
if hasSR:
# connect + recv
s = do_connect(peer_addr, tls, handshake)
# recv -> EAGAIN
try:
print("recv:", s.recv(10))
except OSError as er:
dp(er)
print("recv:", er.args[0] == errno.EAGAIN)
s.close()
else: # fake it...
print("connect:", True)
if tls:
print("wrap:", True)
print("recv:", True)
# connect + read
if hasRW:
s = do_connect(peer_addr, tls, handshake)
# read -> None
try:
ret = s.read(10)
print("read:", ret is None)
except OSError as er:
dp(er)
print("read:", False) # should not raise
except ValueError as er: # CPython
dp(er)
print("read:", er.args[0] == "Read on closed or unwrapped SSL socket.")
s.close()
else: # fake it...
print("connect:", True)
if tls:
print("wrap:", True)
print("read:", True)
if __name__ == "__main__":
# these tests use a non-existent test IP address, this way the connect takes forever and
# we can see EAGAIN/None (https://tools.ietf.org/html/rfc5737)
print("--- Plain sockets to nowhere ---")
test(socket.getaddrinfo("192.0.2.1", 80)[0][-1], False, False)
print("--- SSL sockets to nowhere ---")
# this test fails with AXTLS because do_handshake=False blocks on first read/write and
# there it times out until the connect is aborted
test(socket.getaddrinfo("192.0.2.1", 443)[0][-1], True, False)
print("--- Plain sockets ---")
test(socket.getaddrinfo("micropython.org", 80)[0][-1], False, True)
print("--- SSL sockets ---")
test(socket.getaddrinfo("micropython.org", 443)[0][-1], True, True)

View File

@ -0,0 +1,51 @@
# test that socket.connect() on a non-blocking socket raises EINPROGRESS
# and that an immediate write/send/read/recv does the right thing
import sys
try:
import uerrno as errno, usocket as socket, ussl as ssl
except:
import errno, socket, ssl
def test(addr, hostname, block=True):
print("---", hostname or addr)
s = socket.socket()
s.setblocking(block)
try:
s.connect(addr)
print("connected")
except OSError as e:
if e.args[0] != errno.EINPROGRESS:
raise
print("EINPROGRESS")
try:
if sys.implementation.name == "micropython":
s = ssl.wrap_socket(s, do_handshake=block)
else:
s = ssl.wrap_socket(s, do_handshake_on_connect=block)
print("wrap: True")
except OSError:
print("wrap: error")
if not block:
try:
while s.write(b"0") is None:
pass
except (ValueError, OSError): # CPython raises ValueError, MicroPython raises OSError
print("write: error")
s.close()
if __name__ == "__main__":
# connect to plain HTTP port, oops!
addr = socket.getaddrinfo("micropython.org", 80)[0][-1]
test(addr, None)
# connect to plain HTTP port, oops!
addr = socket.getaddrinfo("micropython.org", 80)[0][-1]
test(addr, None, False)
# connect to server with self-signed cert, oops!
addr = socket.getaddrinfo("test.mosquitto.org", 8883)[0][-1]
test(addr, "test.mosquitto.org")

View File

@ -0,0 +1,116 @@
try:
import usocket as socket, ussl as ssl, uerrno as errno, sys
except:
import socket, ssl, errno, sys, time, select
def test_one(site, opts):
ai = socket.getaddrinfo(site, 443)
addr = ai[0][-1]
print(addr)
# Connect the raw socket
s = socket.socket()
s.setblocking(False)
try:
s.connect(addr)
raise OSError(-1, "connect blocks")
except OSError as e:
if e.args[0] != errno.EINPROGRESS:
raise
if sys.implementation.name != "micropython":
# in CPython we have to wait, otherwise wrap_socket is not happy
select.select([], [s], [])
try:
# Wrap with SSL
try:
if sys.implementation.name == "micropython":
s = ssl.wrap_socket(s, do_handshake=False)
else:
s = ssl.wrap_socket(s, do_handshake_on_connect=False)
except OSError as e:
if e.args[0] != errno.EINPROGRESS:
raise
print("wrapped")
# CPython needs to be told to do the handshake
if sys.implementation.name != "micropython":
while True:
try:
s.do_handshake()
break
except ssl.SSLError as err:
if err.args[0] == ssl.SSL_ERROR_WANT_READ:
select.select([s], [], [])
elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
select.select([], [s], [])
else:
raise
time.sleep(0.1)
# print("shook hands")
# Write HTTP request
out = b"GET / HTTP/1.0\r\nHost: %s\r\n\r\n" % bytes(site, "latin")
while len(out) > 0:
n = s.write(out)
if n is None:
continue
if n > 0:
out = out[n:]
elif n == 0:
raise OSError(-1, "unexpected EOF in write")
print("wrote")
# Read response
resp = b""
while True:
try:
b = s.read(128)
except OSError as err:
if err.args[0] == 2: # 2=ssl.SSL_ERROR_WANT_READ:
continue
raise
if b is None:
continue
if len(b) > 0:
if len(resp) < 1024:
resp += b
elif len(b) == 0:
break
print("read")
if resp[:7] != b"HTTP/1.":
raise ValueError("response doesn't start with HTTP/1.")
# print(resp)
finally:
s.close()
SITES = [
"google.com",
{"host": "www.google.com"},
"micropython.org",
"pypi.org",
"api.telegram.org",
{"host": "api.pushbullet.com", "sni": True},
]
def main():
for site in SITES:
opts = {}
if isinstance(site, dict):
opts = site
site = opts["host"]
try:
test_one(site, opts)
print(site, "ok")
except Exception as e:
print(site, "error")
print("DONE")
main()

View File

@ -27,6 +27,8 @@ def test_one(site, opts):
s.write(b"GET / HTTP/1.0\r\nHost: %s\r\n\r\n" % bytes(site, "latin"))
resp = s.read(4096)
if resp[:7] != b"HTTP/1.":
raise ValueError("response doesn't start with HTTP/1.")
# print(resp)
finally:
@ -36,10 +38,10 @@ def test_one(site, opts):
SITES = [
"google.com",
"www.google.com",
"micropython.org",
"pypi.org",
"api.telegram.org",
{"host": "api.pushbullet.com", "sni": True},
# "w9rybpfril.execute-api.ap-southeast-2.amazonaws.com",
{"host": "w9rybpfril.execute-api.ap-southeast-2.amazonaws.com", "sni": True},
]

View File

@ -1,5 +1,6 @@
google.com ok
www.google.com ok
micropython.org ok
pypi.org ok
api.telegram.org ok
api.pushbullet.com ok
w9rybpfril.execute-api.ap-southeast-2.amazonaws.com ok