py/objset: Check that RHS of a binary op is a set/frozenset.
CPython docs explicitly state that the RHS of a set/frozenset binary op must be a set to prevent user errors. It also preserves commutativity of the ops, eg: "abc" & set() is a TypeError, and so should be set() & "abc". This change actually decreases unix (x64) code by 160 bytes; it increases stm32 by 4 bytes and esp8266 by 28 bytes (but previous patch already introduced a much large saving).
This commit is contained in:
parent
01978648fd
commit
2ac1364688
@ -463,6 +463,10 @@ STATIC mp_obj_t set_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) {
|
|||||||
#else
|
#else
|
||||||
bool update = true;
|
bool update = true;
|
||||||
#endif
|
#endif
|
||||||
|
if (op != MP_BINARY_OP_IN && !is_set_or_frozenset(rhs)) {
|
||||||
|
// For all ops except containment the RHS must be a set/frozenset
|
||||||
|
return MP_OBJ_NULL;
|
||||||
|
}
|
||||||
switch (op) {
|
switch (op) {
|
||||||
case MP_BINARY_OP_OR:
|
case MP_BINARY_OP_OR:
|
||||||
return set_union(lhs, rhs);
|
return set_union(lhs, rhs);
|
||||||
|
@ -47,6 +47,18 @@ s1 = s2 = set('abc')
|
|||||||
s1 -= set('ad')
|
s1 -= set('ad')
|
||||||
print(s1 is s2, len(s1))
|
print(s1 is s2, len(s1))
|
||||||
|
|
||||||
|
# RHS must be a set
|
||||||
|
try:
|
||||||
|
print(set('12') >= '1')
|
||||||
|
except TypeError:
|
||||||
|
print('TypeError')
|
||||||
|
|
||||||
|
# RHS must be a set
|
||||||
|
try:
|
||||||
|
print(set('12') <= '123')
|
||||||
|
except TypeError:
|
||||||
|
print('TypeError')
|
||||||
|
|
||||||
# unsupported operator
|
# unsupported operator
|
||||||
try:
|
try:
|
||||||
set('abc') * 2
|
set('abc') * 2
|
||||||
|
@ -39,18 +39,6 @@ try:
|
|||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
print('NotImplementedError')
|
print('NotImplementedError')
|
||||||
|
|
||||||
# should raise type error
|
|
||||||
try:
|
|
||||||
print(set('12') >= '1')
|
|
||||||
except TypeError:
|
|
||||||
print('TypeError')
|
|
||||||
|
|
||||||
# should raise type error
|
|
||||||
try:
|
|
||||||
print(set('12') <= '123')
|
|
||||||
except TypeError:
|
|
||||||
print('TypeError')
|
|
||||||
|
|
||||||
# uPy raises TypeError, shold be ValueError
|
# uPy raises TypeError, shold be ValueError
|
||||||
try:
|
try:
|
||||||
'%c' % b'\x01\x02'
|
'%c' % b'\x01\x02'
|
||||||
|
@ -3,8 +3,6 @@ AttributeError
|
|||||||
TypeError
|
TypeError
|
||||||
NotImplementedError
|
NotImplementedError
|
||||||
NotImplementedError
|
NotImplementedError
|
||||||
True
|
|
||||||
True
|
|
||||||
TypeError, ValueError
|
TypeError, ValueError
|
||||||
NotImplementedError
|
NotImplementedError
|
||||||
NotImplementedError
|
NotImplementedError
|
||||||
|
Loading…
x
Reference in New Issue
Block a user