extmod/modussl: Fix ussl read/recv/send/write errors when non-blocking.
Also fix related problems with socket on esp32, improve docs for wrap_socket, and add more tests.
This commit is contained in:
parent
2eed9780ba
commit
2c1299b007
|
@ -13,16 +13,23 @@ facilities for network sockets, both client-side and server-side.
|
|||
Functions
|
||||
---------
|
||||
|
||||
.. function:: ussl.wrap_socket(sock, server_side=False, keyfile=None, certfile=None, cert_reqs=CERT_NONE, ca_certs=None)
|
||||
|
||||
.. function:: ussl.wrap_socket(sock, server_side=False, keyfile=None, certfile=None, cert_reqs=CERT_NONE, ca_certs=None, do_handshake=True)
|
||||
Takes a `stream` *sock* (usually usocket.socket instance of ``SOCK_STREAM`` type),
|
||||
and returns an instance of ssl.SSLSocket, which wraps the underlying stream in
|
||||
an SSL context. Returned object has the usual `stream` interface methods like
|
||||
``read()``, ``write()``, etc. In MicroPython, the returned object does not expose
|
||||
socket interface and methods like ``recv()``, ``send()``. In particular, a
|
||||
server-side SSL socket should be created from a normal socket returned from
|
||||
``read()``, ``write()``, etc.
|
||||
A server-side SSL socket should be created from a normal socket returned from
|
||||
:meth:`~usocket.socket.accept()` on a non-SSL listening server socket.
|
||||
|
||||
- *do_handshake* determines whether the handshake is done as part of the ``wrap_socket``
|
||||
or whether it is deferred to be done as part of the initial reads or writes
|
||||
(there is no ``do_handshake`` method as in CPython).
|
||||
For blocking sockets doing the handshake immediately is standard. For non-blocking
|
||||
sockets (i.e. when the *sock* passed into ``wrap_socket`` is in non-blocking mode)
|
||||
the handshake should generally be deferred because otherwise ``wrap_socket`` blocks
|
||||
until it completes. Note that in AXTLS the handshake can be deferred until the first
|
||||
read or write but it then blocks until completion.
|
||||
|
||||
Depending on the underlying module implementation in a particular
|
||||
:term:`MicroPython port`, some or all keyword arguments above may be not supported.
|
||||
|
||||
|
@ -31,6 +38,11 @@ Functions
|
|||
Some implementations of ``ussl`` module do NOT validate server certificates,
|
||||
which makes an SSL connection established prone to man-in-the-middle attacks.
|
||||
|
||||
CPython's ``wrap_socket`` returns an ``SSLSocket`` object which has methods typical
|
||||
for sockets, such as ``send``, ``recv``, etc. MicroPython's ``wrap_socket``
|
||||
returns an object more similar to CPython's ``SSLObject`` which does not have
|
||||
these socket methods.
|
||||
|
||||
Exceptions
|
||||
----------
|
||||
|
||||
|
|
|
@ -167,10 +167,15 @@ STATIC mp_obj_ssl_socket_t *ussl_socket_new(mp_obj_t sock, struct ssl_args *args
|
|||
o->ssl_sock = ssl_client_new(o->ssl_ctx, (long)sock, NULL, 0, ext);
|
||||
|
||||
if (args->do_handshake.u_bool) {
|
||||
int res = ssl_handshake_status(o->ssl_sock);
|
||||
int r = ssl_handshake_status(o->ssl_sock);
|
||||
|
||||
if (res != SSL_OK) {
|
||||
ussl_raise_error(res);
|
||||
if (r != SSL_OK) {
|
||||
if (r == SSL_CLOSE_NOTIFY) { // EOF
|
||||
r = MP_ENOTCONN;
|
||||
} else if (r == SSL_EAGAIN) {
|
||||
r = MP_EAGAIN;
|
||||
}
|
||||
ussl_raise_error(r);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -242,8 +247,24 @@ STATIC mp_uint_t ussl_socket_write(mp_obj_t o_in, const void *buf, mp_uint_t siz
|
|||
return MP_STREAM_ERROR;
|
||||
}
|
||||
|
||||
mp_int_t r = ssl_write(o->ssl_sock, buf, size);
|
||||
mp_int_t r;
|
||||
eagain:
|
||||
r = ssl_write(o->ssl_sock, buf, size);
|
||||
if (r == 0) {
|
||||
// see comment in ussl_socket_read above
|
||||
if (o->blocking) {
|
||||
goto eagain;
|
||||
} else {
|
||||
r = SSL_EAGAIN;
|
||||
}
|
||||
}
|
||||
if (r < 0) {
|
||||
if (r == SSL_CLOSE_NOTIFY || r == SSL_ERROR_CONN_LOST) {
|
||||
return 0; // EOF
|
||||
}
|
||||
if (r == SSL_EAGAIN) {
|
||||
r = MP_EAGAIN;
|
||||
}
|
||||
*errcode = r;
|
||||
return MP_STREAM_ERROR;
|
||||
}
|
||||
|
|
|
@ -133,6 +133,7 @@ STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
|
|||
}
|
||||
}
|
||||
|
||||
// _mbedtls_ssl_recv is called by mbedtls to receive bytes from the underlying socket
|
||||
STATIC int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
|
||||
mp_obj_t sock = *(mp_obj_t *)ctx;
|
||||
|
||||
|
@ -171,7 +172,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
|
|||
mbedtls_pk_init(&o->pkey);
|
||||
mbedtls_ctr_drbg_init(&o->ctr_drbg);
|
||||
#ifdef MBEDTLS_DEBUG_C
|
||||
// Debug level (0-4)
|
||||
// Debug level (0-4) 1=warning, 2=info, 3=debug, 4=verbose
|
||||
mbedtls_debug_set_threshold(0);
|
||||
#endif
|
||||
|
||||
|
|
|
@ -558,7 +558,8 @@ int _socket_send(socket_obj_t *sock, const char *data, size_t datalen) {
|
|||
MP_THREAD_GIL_EXIT();
|
||||
int r = lwip_write(sock->fd, data + sentlen, datalen - sentlen);
|
||||
MP_THREAD_GIL_ENTER();
|
||||
if (r < 0 && errno != EWOULDBLOCK) {
|
||||
// lwip returns EINPROGRESS when trying to send right after a non-blocking connect
|
||||
if (r < 0 && errno != EWOULDBLOCK && errno != EINPROGRESS) {
|
||||
mp_raise_OSError(errno);
|
||||
}
|
||||
if (r > 0) {
|
||||
|
@ -567,7 +568,7 @@ int _socket_send(socket_obj_t *sock, const char *data, size_t datalen) {
|
|||
check_for_exceptions();
|
||||
}
|
||||
if (sentlen == 0) {
|
||||
mp_raise_OSError(MP_ETIMEDOUT);
|
||||
mp_raise_OSError(sock->retries == 0 ? MP_EWOULDBLOCK : MP_ETIMEDOUT);
|
||||
}
|
||||
return sentlen;
|
||||
}
|
||||
|
@ -650,7 +651,8 @@ STATIC mp_uint_t socket_stream_write(mp_obj_t self_in, const void *buf, mp_uint_
|
|||
if (r > 0) {
|
||||
return r;
|
||||
}
|
||||
if (r < 0 && errno != EWOULDBLOCK) {
|
||||
// lwip returns MP_EINPROGRESS when trying to write right after a non-blocking connect
|
||||
if (r < 0 && errno != EWOULDBLOCK && errno != EINPROGRESS) {
|
||||
*errcode = errno;
|
||||
return MP_STREAM_ERROR;
|
||||
}
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
# test that socket.accept() on a socket with timeout raises ETIMEDOUT
|
||||
|
||||
try:
|
||||
import usocket as socket
|
||||
import uerrno as errno, usocket as socket
|
||||
except:
|
||||
import socket
|
||||
import errno, socket
|
||||
|
||||
try:
|
||||
socket.socket.settimeout
|
||||
|
@ -18,5 +18,5 @@ s.listen(1)
|
|||
try:
|
||||
s.accept()
|
||||
except OSError as er:
|
||||
print(er.args[0] in (110, "timed out")) # 110 is ETIMEDOUT; CPython uses a string
|
||||
print(er.args[0] in (errno.ETIMEDOUT, "timed out")) # CPython uses a string instead of errno
|
||||
s.close()
|
||||
|
|
|
@ -0,0 +1,147 @@
|
|||
# test that socket.connect() on a non-blocking socket raises EINPROGRESS
|
||||
# and that an immediate write/send/read/recv does the right thing
|
||||
|
||||
try:
|
||||
import sys, time
|
||||
import uerrno as errno, usocket as socket, ussl as ssl
|
||||
except:
|
||||
import socket, errno, ssl
|
||||
isMP = sys.implementation.name == "micropython"
|
||||
|
||||
|
||||
def dp(e):
|
||||
# uncomment next line for development and testing, to print the actual exceptions
|
||||
# print(repr(e))
|
||||
pass
|
||||
|
||||
|
||||
# do_connect establishes the socket and wraps it if tls is True.
|
||||
# If handshake is true, the initial connect (and TLS handshake) is
|
||||
# allowed to be performed before returning.
|
||||
def do_connect(peer_addr, tls, handshake):
|
||||
s = socket.socket()
|
||||
s.setblocking(False)
|
||||
try:
|
||||
# print("Connecting to", peer_addr)
|
||||
s.connect(peer_addr)
|
||||
except OSError as er:
|
||||
print("connect:", er.args[0] == errno.EINPROGRESS)
|
||||
if er.args[0] != errno.EINPROGRESS:
|
||||
print(" got", er.args[0])
|
||||
# wrap with ssl/tls if desired
|
||||
if tls:
|
||||
try:
|
||||
if sys.implementation.name == "micropython":
|
||||
s = ssl.wrap_socket(s, do_handshake=handshake)
|
||||
else:
|
||||
s = ssl.wrap_socket(s, do_handshake_on_connect=handshake)
|
||||
print("wrap: True")
|
||||
except Exception as e:
|
||||
dp(e)
|
||||
print("wrap:", e)
|
||||
elif handshake:
|
||||
# just sleep a little bit, this allows any connect() errors to happen
|
||||
time.sleep(0.2)
|
||||
return s
|
||||
|
||||
|
||||
# test runs the test against a specific peer address.
|
||||
def test(peer_addr, tls=False, handshake=False):
|
||||
# MicroPython plain sockets have read/write, but CPython's don't
|
||||
# MicroPython TLS sockets and CPython's have read/write
|
||||
# hasRW captures this wonderful state of affairs
|
||||
hasRW = isMP or tls
|
||||
|
||||
# MicroPython plain sockets and CPython's have send/recv
|
||||
# MicroPython TLS sockets don't have send/recv, but CPython's do
|
||||
# hasSR captures this wonderful state of affairs
|
||||
hasSR = not (isMP and tls)
|
||||
|
||||
# connect + send
|
||||
if hasSR:
|
||||
s = do_connect(peer_addr, tls, handshake)
|
||||
# send -> 4 or EAGAIN
|
||||
try:
|
||||
ret = s.send(b"1234")
|
||||
print("send:", handshake and ret == 4)
|
||||
except OSError as er:
|
||||
#
|
||||
dp(er)
|
||||
print("send:", er.args[0] in (errno.EAGAIN, errno.EINPROGRESS))
|
||||
s.close()
|
||||
else: # fake it...
|
||||
print("connect:", True)
|
||||
if tls:
|
||||
print("wrap:", True)
|
||||
print("send:", True)
|
||||
|
||||
# connect + write
|
||||
if hasRW:
|
||||
s = do_connect(peer_addr, tls, handshake)
|
||||
# write -> None
|
||||
try:
|
||||
ret = s.write(b"1234")
|
||||
print("write:", ret in (4, None)) # SSL may accept 4 into buffer
|
||||
except OSError as er:
|
||||
dp(er)
|
||||
print("write:", False) # should not raise
|
||||
except ValueError as er: # CPython
|
||||
dp(er)
|
||||
print("write:", er.args[0] == "Write on closed or unwrapped SSL socket.")
|
||||
s.close()
|
||||
else: # fake it...
|
||||
print("connect:", True)
|
||||
if tls:
|
||||
print("wrap:", True)
|
||||
print("write:", True)
|
||||
|
||||
if hasSR:
|
||||
# connect + recv
|
||||
s = do_connect(peer_addr, tls, handshake)
|
||||
# recv -> EAGAIN
|
||||
try:
|
||||
print("recv:", s.recv(10))
|
||||
except OSError as er:
|
||||
dp(er)
|
||||
print("recv:", er.args[0] == errno.EAGAIN)
|
||||
s.close()
|
||||
else: # fake it...
|
||||
print("connect:", True)
|
||||
if tls:
|
||||
print("wrap:", True)
|
||||
print("recv:", True)
|
||||
|
||||
# connect + read
|
||||
if hasRW:
|
||||
s = do_connect(peer_addr, tls, handshake)
|
||||
# read -> None
|
||||
try:
|
||||
ret = s.read(10)
|
||||
print("read:", ret is None)
|
||||
except OSError as er:
|
||||
dp(er)
|
||||
print("read:", False) # should not raise
|
||||
except ValueError as er: # CPython
|
||||
dp(er)
|
||||
print("read:", er.args[0] == "Read on closed or unwrapped SSL socket.")
|
||||
s.close()
|
||||
else: # fake it...
|
||||
print("connect:", True)
|
||||
if tls:
|
||||
print("wrap:", True)
|
||||
print("read:", True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# these tests use a non-existent test IP address, this way the connect takes forever and
|
||||
# we can see EAGAIN/None (https://tools.ietf.org/html/rfc5737)
|
||||
print("--- Plain sockets to nowhere ---")
|
||||
test(socket.getaddrinfo("192.0.2.1", 80)[0][-1], False, False)
|
||||
print("--- SSL sockets to nowhere ---")
|
||||
# this test fails with AXTLS because do_handshake=False blocks on first read/write and
|
||||
# there it times out until the connect is aborted
|
||||
test(socket.getaddrinfo("192.0.2.1", 443)[0][-1], True, False)
|
||||
print("--- Plain sockets ---")
|
||||
test(socket.getaddrinfo("micropython.org", 80)[0][-1], False, True)
|
||||
print("--- SSL sockets ---")
|
||||
test(socket.getaddrinfo("micropython.org", 443)[0][-1], True, True)
|
|
@ -0,0 +1,51 @@
|
|||
# test that socket.connect() on a non-blocking socket raises EINPROGRESS
|
||||
# and that an immediate write/send/read/recv does the right thing
|
||||
|
||||
import sys
|
||||
|
||||
try:
|
||||
import uerrno as errno, usocket as socket, ussl as ssl
|
||||
except:
|
||||
import errno, socket, ssl
|
||||
|
||||
|
||||
def test(addr, hostname, block=True):
|
||||
print("---", hostname or addr)
|
||||
s = socket.socket()
|
||||
s.setblocking(block)
|
||||
try:
|
||||
s.connect(addr)
|
||||
print("connected")
|
||||
except OSError as e:
|
||||
if e.args[0] != errno.EINPROGRESS:
|
||||
raise
|
||||
print("EINPROGRESS")
|
||||
|
||||
try:
|
||||
if sys.implementation.name == "micropython":
|
||||
s = ssl.wrap_socket(s, do_handshake=block)
|
||||
else:
|
||||
s = ssl.wrap_socket(s, do_handshake_on_connect=block)
|
||||
print("wrap: True")
|
||||
except OSError:
|
||||
print("wrap: error")
|
||||
|
||||
if not block:
|
||||
try:
|
||||
while s.write(b"0") is None:
|
||||
pass
|
||||
except (ValueError, OSError): # CPython raises ValueError, MicroPython raises OSError
|
||||
print("write: error")
|
||||
s.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# connect to plain HTTP port, oops!
|
||||
addr = socket.getaddrinfo("micropython.org", 80)[0][-1]
|
||||
test(addr, None)
|
||||
# connect to plain HTTP port, oops!
|
||||
addr = socket.getaddrinfo("micropython.org", 80)[0][-1]
|
||||
test(addr, None, False)
|
||||
# connect to server with self-signed cert, oops!
|
||||
addr = socket.getaddrinfo("test.mosquitto.org", 8883)[0][-1]
|
||||
test(addr, "test.mosquitto.org")
|
|
@ -0,0 +1,116 @@
|
|||
try:
|
||||
import usocket as socket, ussl as ssl, uerrno as errno, sys
|
||||
except:
|
||||
import socket, ssl, errno, sys, time, select
|
||||
|
||||
|
||||
def test_one(site, opts):
|
||||
ai = socket.getaddrinfo(site, 443)
|
||||
addr = ai[0][-1]
|
||||
print(addr)
|
||||
|
||||
# Connect the raw socket
|
||||
s = socket.socket()
|
||||
s.setblocking(False)
|
||||
try:
|
||||
s.connect(addr)
|
||||
raise OSError(-1, "connect blocks")
|
||||
except OSError as e:
|
||||
if e.args[0] != errno.EINPROGRESS:
|
||||
raise
|
||||
|
||||
if sys.implementation.name != "micropython":
|
||||
# in CPython we have to wait, otherwise wrap_socket is not happy
|
||||
select.select([], [s], [])
|
||||
|
||||
try:
|
||||
# Wrap with SSL
|
||||
try:
|
||||
if sys.implementation.name == "micropython":
|
||||
s = ssl.wrap_socket(s, do_handshake=False)
|
||||
else:
|
||||
s = ssl.wrap_socket(s, do_handshake_on_connect=False)
|
||||
except OSError as e:
|
||||
if e.args[0] != errno.EINPROGRESS:
|
||||
raise
|
||||
print("wrapped")
|
||||
|
||||
# CPython needs to be told to do the handshake
|
||||
if sys.implementation.name != "micropython":
|
||||
while True:
|
||||
try:
|
||||
s.do_handshake()
|
||||
break
|
||||
except ssl.SSLError as err:
|
||||
if err.args[0] == ssl.SSL_ERROR_WANT_READ:
|
||||
select.select([s], [], [])
|
||||
elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
|
||||
select.select([], [s], [])
|
||||
else:
|
||||
raise
|
||||
time.sleep(0.1)
|
||||
# print("shook hands")
|
||||
|
||||
# Write HTTP request
|
||||
out = b"GET / HTTP/1.0\r\nHost: %s\r\n\r\n" % bytes(site, "latin")
|
||||
while len(out) > 0:
|
||||
n = s.write(out)
|
||||
if n is None:
|
||||
continue
|
||||
if n > 0:
|
||||
out = out[n:]
|
||||
elif n == 0:
|
||||
raise OSError(-1, "unexpected EOF in write")
|
||||
print("wrote")
|
||||
|
||||
# Read response
|
||||
resp = b""
|
||||
while True:
|
||||
try:
|
||||
b = s.read(128)
|
||||
except OSError as err:
|
||||
if err.args[0] == 2: # 2=ssl.SSL_ERROR_WANT_READ:
|
||||
continue
|
||||
raise
|
||||
if b is None:
|
||||
continue
|
||||
if len(b) > 0:
|
||||
if len(resp) < 1024:
|
||||
resp += b
|
||||
elif len(b) == 0:
|
||||
break
|
||||
print("read")
|
||||
|
||||
if resp[:7] != b"HTTP/1.":
|
||||
raise ValueError("response doesn't start with HTTP/1.")
|
||||
# print(resp)
|
||||
|
||||
finally:
|
||||
s.close()
|
||||
|
||||
|
||||
SITES = [
|
||||
"google.com",
|
||||
{"host": "www.google.com"},
|
||||
"micropython.org",
|
||||
"pypi.org",
|
||||
"api.telegram.org",
|
||||
{"host": "api.pushbullet.com", "sni": True},
|
||||
]
|
||||
|
||||
|
||||
def main():
|
||||
for site in SITES:
|
||||
opts = {}
|
||||
if isinstance(site, dict):
|
||||
opts = site
|
||||
site = opts["host"]
|
||||
try:
|
||||
test_one(site, opts)
|
||||
print(site, "ok")
|
||||
except Exception as e:
|
||||
print(site, "error")
|
||||
print("DONE")
|
||||
|
||||
|
||||
main()
|
|
@ -27,6 +27,8 @@ def test_one(site, opts):
|
|||
|
||||
s.write(b"GET / HTTP/1.0\r\nHost: %s\r\n\r\n" % bytes(site, "latin"))
|
||||
resp = s.read(4096)
|
||||
if resp[:7] != b"HTTP/1.":
|
||||
raise ValueError("response doesn't start with HTTP/1.")
|
||||
# print(resp)
|
||||
|
||||
finally:
|
||||
|
@ -36,10 +38,10 @@ def test_one(site, opts):
|
|||
SITES = [
|
||||
"google.com",
|
||||
"www.google.com",
|
||||
"micropython.org",
|
||||
"pypi.org",
|
||||
"api.telegram.org",
|
||||
{"host": "api.pushbullet.com", "sni": True},
|
||||
# "w9rybpfril.execute-api.ap-southeast-2.amazonaws.com",
|
||||
{"host": "w9rybpfril.execute-api.ap-southeast-2.amazonaws.com", "sni": True},
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
google.com ok
|
||||
www.google.com ok
|
||||
micropython.org ok
|
||||
pypi.org ok
|
||||
api.telegram.org ok
|
||||
api.pushbullet.com ok
|
||||
w9rybpfril.execute-api.ap-southeast-2.amazonaws.com ok
|
||||
|
|
Loading…
Reference in New Issue