diff --git a/extmod/modwebsocket.c b/extmod/modwebsocket.c index 7d2363a4c1..34ed46a6ef 100644 --- a/extmod/modwebsocket.c +++ b/extmod/modwebsocket.c @@ -26,6 +26,8 @@ #include #include +#include +#include #include "py/nlr.h" #include "py/obj.h" @@ -38,11 +40,13 @@ enum { FRAME_HEADER, FRAME_OPT, PAYLOAD }; typedef struct _mp_obj_websocket_t { mp_obj_base_t base; mp_obj_t sock; - uint32_t mask; + uint32_t msg_sz; + byte mask[4]; byte state; byte to_recv; byte mask_pos; - byte buf[4]; + byte buf_pos; + byte buf[6]; } mp_obj_websocket_t; STATIC mp_obj_t websocket_make_new(const mp_obj_type_t *type, size_t n_args, size_t n_kw, const mp_obj_t *args) { @@ -53,9 +57,95 @@ STATIC mp_obj_t websocket_make_new(const mp_obj_type_t *type, size_t n_args, siz o->state = FRAME_HEADER; o->to_recv = 2; o->mask_pos = 0; + o->buf_pos = 0; return o; } +STATIC mp_uint_t websocket_read(mp_obj_t self_in, void *buf, mp_uint_t size, int *errcode) { + mp_obj_websocket_t *self = self_in; + const mp_stream_p_t *stream_p = mp_get_stream_raise(self->sock, MP_STREAM_OP_READ); + while (1) { + if (self->to_recv != 0) { + mp_uint_t out_sz = stream_p->read(self->sock, self->buf + self->buf_pos, self->to_recv, errcode); + if (out_sz == MP_STREAM_ERROR) { + return out_sz; + } + self->buf_pos += out_sz; + self->to_recv -= out_sz; + if (self->to_recv != 0) { + *errcode = EAGAIN; + return MP_STREAM_ERROR; + } + } + + switch (self->state) { + case FRAME_HEADER: { + assert(self->buf[0] & 0x80); + int to_recv = 0; + size_t sz = self->buf[1] & 0x7f; + if (sz == 126) { + // Msg size is next 2 bytes + to_recv += 2; + } else if (sz == 127) { + // Msg size is next 2 bytes + assert(0); + } + if (self->buf[1] & 0x80) { + // Next 4 bytes is mask + to_recv += 4; + } + + self->buf_pos = 0; + self->to_recv = to_recv; + self->msg_sz = sz; // May be overriden by FRAME_OPT + if (to_recv != 0) { + self->state = FRAME_OPT; + } else { + self->state = PAYLOAD; + } + continue; + } + + case FRAME_OPT: { + if ((self->buf_pos & 3) == 2) { + // First two bytes are message length + self->msg_sz = (self->buf[0] << 8) | self->buf[1]; + } + if (self->buf_pos >= 4) { + // Last 4 bytes is mask + memcpy(self->mask, self->buf + self->buf_pos - 4, 4); + } + self->buf_pos = 0; + self->state = PAYLOAD; + continue; + } + + case PAYLOAD: { + size_t sz = MIN(size, self->msg_sz); + mp_uint_t out_sz = stream_p->read(self->sock, buf, sz, errcode); + if (out_sz == MP_STREAM_ERROR) { + return out_sz; + } + + sz = out_sz; + for (byte *p = buf; sz--; p++) { + *p ^= self->mask[self->mask_pos++ & 3]; + } + + self->msg_sz -= out_sz; + if (self->msg_sz == 0) { + self->state = FRAME_HEADER; + self->to_recv = 2; + self->mask_pos = 0; + self->buf_pos = 0; + } + return out_sz; + } + + } + } +} + STATIC mp_uint_t websocket_write(mp_obj_t self_in, const void *buf, mp_uint_t size, int *errcode) { mp_obj_websocket_t *self = self_in; assert(size < 126); @@ -69,12 +159,13 @@ STATIC mp_uint_t websocket_write(mp_obj_t self_in, const void *buf, mp_uint_t si } STATIC const mp_map_elem_t websocket_locals_dict_table[] = { + { MP_OBJ_NEW_QSTR(MP_QSTR_read), (mp_obj_t)&mp_stream_read_obj }, { MP_OBJ_NEW_QSTR(MP_QSTR_write), (mp_obj_t)&mp_stream_write_obj }, }; STATIC MP_DEFINE_CONST_DICT(websocket_locals_dict, websocket_locals_dict_table); STATIC const mp_stream_p_t websocket_stream_p = { -// .read = websocket_read, + .read = websocket_read, .write = websocket_write, };