Implemented set.difference and set.difference_update

This commit is contained in:
John R. Lenton 2014-01-12 17:07:17 +00:00
parent 2a24172cdc
commit 032129f3b5
2 changed files with 62 additions and 0 deletions

View File

@ -137,6 +137,45 @@ static mp_obj_t set_discard(mp_obj_t self_in, mp_obj_t item) {
} }
static MP_DEFINE_CONST_FUN_OBJ_2(set_discard_obj, set_discard); static MP_DEFINE_CONST_FUN_OBJ_2(set_discard_obj, set_discard);
static mp_obj_t set_diff_int(int n_args, const mp_obj_t *args, bool update) {
assert(n_args > 0);
assert(MP_OBJ_IS_TYPE(args[0], &set_type));
mp_obj_set_t *self;
if (update) {
self = args[0];
} else {
self = set_copy(args[0]);
}
for (int i = 1; i < n_args; i++) {
mp_obj_t other = args[i];
if (self == other) {
set_clear(self);
} else {
mp_obj_t iter = rt_getiter(other);
mp_obj_t next;
while ((next = rt_iternext(iter)) != mp_const_stop_iteration) {
set_discard(self, next);
}
}
}
return self;
}
static mp_obj_t set_diff(int n_args, const mp_obj_t *args) {
return set_diff_int(n_args, args, false);
}
static MP_DEFINE_CONST_FUN_OBJ_VAR(set_diff_obj, 1, set_diff);
static mp_obj_t set_diff_update(int n_args, const mp_obj_t *args) {
set_diff_int(n_args, args, true);
return mp_const_none;
}
static MP_DEFINE_CONST_FUN_OBJ_VAR(set_diff_update_obj, 1, set_diff_update);
/******************************************************************************/ /******************************************************************************/
/* set constructors & public C API */ /* set constructors & public C API */
@ -146,6 +185,8 @@ static const mp_method_t set_type_methods[] = {
{ "clear", &set_clear_obj }, { "clear", &set_clear_obj },
{ "copy", &set_copy_obj }, { "copy", &set_copy_obj },
{ "discard", &set_discard_obj }, { "discard", &set_discard_obj },
{ "difference", &set_diff_obj },
{ "difference_update", &set_diff_update_obj },
{ NULL, NULL }, // end-of-list sentinel { NULL, NULL }, // end-of-list sentinel
}; };

View File

@ -0,0 +1,21 @@
def report(s):
l = list(s)
l.sort()
print(l)
l = [1, 2, 3, 4]
s = set(l)
outs = [s.difference(),
s.difference({1}),
s.difference({1}, [1, 2]),
s.difference({1}, {1, 2}, {2, 3})]
for out in outs:
report(out)
s = set(l)
print(s.difference_update())
report(s)
print(s.difference_update({1}))
report(s)
print(s.difference_update({1}, [2]))
report(s)