diff --git a/py/emitnative.c b/py/emitnative.c index 6c8e9feeba..2a657b6964 100644 --- a/py/emitnative.c +++ b/py/emitnative.c @@ -2289,7 +2289,8 @@ STATIC void emit_native_binary_op(emit_t *emit, mp_binary_op_t op) { DEBUG_printf("binary_op(" UINT_FMT ")\n", op); vtype_kind_t vtype_lhs = peek_vtype(emit, 1); vtype_kind_t vtype_rhs = peek_vtype(emit, 0); - if (vtype_lhs == VTYPE_INT && vtype_rhs == VTYPE_INT) { + if ((vtype_lhs == VTYPE_INT || vtype_lhs == VTYPE_UINT) + && (vtype_rhs == VTYPE_INT || vtype_rhs == VTYPE_UINT)) { // for integers, inplace and normal ops are equivalent, so use just normal ops if (MP_BINARY_OP_INPLACE_OR <= op && op <= MP_BINARY_OP_INPLACE_POWER) { op += MP_BINARY_OP_OR - MP_BINARY_OP_INPLACE_OR; @@ -2306,9 +2307,13 @@ STATIC void emit_native_binary_op(emit_t *emit, mp_binary_op_t op) { if (op == MP_BINARY_OP_LSHIFT) { ASM_LSL_REG(emit->as, REG_RET); } else { - ASM_ASR_REG(emit->as, REG_RET); + if (vtype_lhs == VTYPE_UINT) { + ASM_LSR_REG(emit->as, REG_RET); + } else { + ASM_ASR_REG(emit->as, REG_RET); + } } - emit_post_push_reg(emit, VTYPE_INT, REG_RET); + emit_post_push_reg(emit, vtype_lhs, REG_RET); return; } #endif @@ -2316,6 +2321,10 @@ STATIC void emit_native_binary_op(emit_t *emit, mp_binary_op_t op) { // special cases for floor-divide and module because we dispatch to helper functions if (op == MP_BINARY_OP_FLOOR_DIVIDE || op == MP_BINARY_OP_MODULO) { emit_pre_pop_reg_reg(emit, &vtype_rhs, REG_ARG_2, &vtype_lhs, REG_ARG_1); + if (vtype_lhs != VTYPE_INT) { + EMIT_NATIVE_VIPER_TYPE_ERROR(emit, + MP_ERROR_TEXT("div/mod not implemented for uint"), mp_binary_op_method_name[op]); + } if (op == MP_BINARY_OP_FLOOR_DIVIDE) { emit_call(emit, MP_F_SMALL_INT_FLOOR_DIVIDE); } else { @@ -2334,31 +2343,35 @@ STATIC void emit_native_binary_op(emit_t *emit, mp_binary_op_t op) { if (op == MP_BINARY_OP_LSHIFT) { ASM_LSL_REG_REG(emit->as, REG_ARG_2, reg_rhs); } else { - ASM_ASR_REG_REG(emit->as, REG_ARG_2, reg_rhs); + if (vtype_lhs == VTYPE_UINT) { + ASM_LSR_REG_REG(emit->as, REG_ARG_2, reg_rhs); + } else { + ASM_ASR_REG_REG(emit->as, REG_ARG_2, reg_rhs); + } } - emit_post_push_reg(emit, VTYPE_INT, REG_ARG_2); + emit_post_push_reg(emit, vtype_lhs, REG_ARG_2); return; } #endif if (op == MP_BINARY_OP_OR) { ASM_OR_REG_REG(emit->as, REG_ARG_2, reg_rhs); - emit_post_push_reg(emit, VTYPE_INT, REG_ARG_2); + emit_post_push_reg(emit, vtype_lhs, REG_ARG_2); } else if (op == MP_BINARY_OP_XOR) { ASM_XOR_REG_REG(emit->as, REG_ARG_2, reg_rhs); - emit_post_push_reg(emit, VTYPE_INT, REG_ARG_2); + emit_post_push_reg(emit, vtype_lhs, REG_ARG_2); } else if (op == MP_BINARY_OP_AND) { ASM_AND_REG_REG(emit->as, REG_ARG_2, reg_rhs); - emit_post_push_reg(emit, VTYPE_INT, REG_ARG_2); + emit_post_push_reg(emit, vtype_lhs, REG_ARG_2); } else if (op == MP_BINARY_OP_ADD) { ASM_ADD_REG_REG(emit->as, REG_ARG_2, reg_rhs); - emit_post_push_reg(emit, VTYPE_INT, REG_ARG_2); + emit_post_push_reg(emit, vtype_lhs, REG_ARG_2); } else if (op == MP_BINARY_OP_SUBTRACT) { ASM_SUB_REG_REG(emit->as, REG_ARG_2, reg_rhs); - emit_post_push_reg(emit, VTYPE_INT, REG_ARG_2); + emit_post_push_reg(emit, vtype_lhs, REG_ARG_2); } else if (op == MP_BINARY_OP_MULTIPLY) { ASM_MUL_REG_REG(emit->as, REG_ARG_2, reg_rhs); - emit_post_push_reg(emit, VTYPE_INT, REG_ARG_2); + emit_post_push_reg(emit, vtype_lhs, REG_ARG_2); } else if (MP_BINARY_OP_LESS <= op && op <= MP_BINARY_OP_NOT_EQUAL) { // comparison ops are (in enum order): // MP_BINARY_OP_LESS @@ -2367,11 +2380,26 @@ STATIC void emit_native_binary_op(emit_t *emit, mp_binary_op_t op) { // MP_BINARY_OP_LESS_EQUAL // MP_BINARY_OP_MORE_EQUAL // MP_BINARY_OP_NOT_EQUAL + + if (vtype_lhs != vtype_rhs) { + EMIT_NATIVE_VIPER_TYPE_ERROR(emit, MP_ERROR_TEXT("comparison of int and uint")); + } + + size_t op_idx = op - MP_BINARY_OP_LESS + (vtype_lhs == VTYPE_UINT ? 0 : 6); + need_reg_single(emit, REG_RET, 0); #if N_X64 asm_x64_xor_r64_r64(emit->as, REG_RET, REG_RET); asm_x64_cmp_r64_with_r64(emit->as, reg_rhs, REG_ARG_2); - static byte ops[6] = { + static byte ops[6 + 6] = { + // unsigned + ASM_X64_CC_JB, + ASM_X64_CC_JA, + ASM_X64_CC_JE, + ASM_X64_CC_JBE, + ASM_X64_CC_JAE, + ASM_X64_CC_JNE, + // signed ASM_X64_CC_JL, ASM_X64_CC_JG, ASM_X64_CC_JE, @@ -2379,11 +2407,19 @@ STATIC void emit_native_binary_op(emit_t *emit, mp_binary_op_t op) { ASM_X64_CC_JGE, ASM_X64_CC_JNE, }; - asm_x64_setcc_r8(emit->as, ops[op - MP_BINARY_OP_LESS], REG_RET); + asm_x64_setcc_r8(emit->as, ops[op_idx], REG_RET); #elif N_X86 asm_x86_xor_r32_r32(emit->as, REG_RET, REG_RET); asm_x86_cmp_r32_with_r32(emit->as, reg_rhs, REG_ARG_2); - static byte ops[6] = { + static byte ops[6 + 6] = { + // unsigned + ASM_X86_CC_JB, + ASM_X86_CC_JA, + ASM_X86_CC_JE, + ASM_X86_CC_JBE, + ASM_X86_CC_JAE, + ASM_X86_CC_JNE, + // signed ASM_X86_CC_JL, ASM_X86_CC_JG, ASM_X86_CC_JE, @@ -2391,24 +2427,39 @@ STATIC void emit_native_binary_op(emit_t *emit, mp_binary_op_t op) { ASM_X86_CC_JGE, ASM_X86_CC_JNE, }; - asm_x86_setcc_r8(emit->as, ops[op - MP_BINARY_OP_LESS], REG_RET); + asm_x86_setcc_r8(emit->as, ops[op_idx], REG_RET); #elif N_THUMB asm_thumb_cmp_rlo_rlo(emit->as, REG_ARG_2, reg_rhs); - static uint16_t ops[6] = { - ASM_THUMB_OP_ITE_GE, + static uint16_t ops[6 + 6] = { + // unsigned + ASM_THUMB_OP_ITE_CC, + ASM_THUMB_OP_ITE_HI, + ASM_THUMB_OP_ITE_EQ, + ASM_THUMB_OP_ITE_LS, + ASM_THUMB_OP_ITE_CS, + ASM_THUMB_OP_ITE_NE, + // signed + ASM_THUMB_OP_ITE_LT, ASM_THUMB_OP_ITE_GT, ASM_THUMB_OP_ITE_EQ, - ASM_THUMB_OP_ITE_GT, + ASM_THUMB_OP_ITE_LE, ASM_THUMB_OP_ITE_GE, - ASM_THUMB_OP_ITE_EQ, + ASM_THUMB_OP_ITE_NE, }; - static byte ret[6] = { 0, 1, 1, 0, 1, 0, }; - asm_thumb_op16(emit->as, ops[op - MP_BINARY_OP_LESS]); - asm_thumb_mov_rlo_i8(emit->as, REG_RET, ret[op - MP_BINARY_OP_LESS]); - asm_thumb_mov_rlo_i8(emit->as, REG_RET, ret[op - MP_BINARY_OP_LESS] ^ 1); + asm_thumb_op16(emit->as, ops[op_idx]); + asm_thumb_mov_rlo_i8(emit->as, REG_RET, 1); + asm_thumb_mov_rlo_i8(emit->as, REG_RET, 0); #elif N_ARM asm_arm_cmp_reg_reg(emit->as, REG_ARG_2, reg_rhs); - static uint ccs[6] = { + static uint ccs[6 + 6] = { + // unsigned + ASM_ARM_CC_CC, + ASM_ARM_CC_HI, + ASM_ARM_CC_EQ, + ASM_ARM_CC_LS, + ASM_ARM_CC_CS, + ASM_ARM_CC_NE, + // signed ASM_ARM_CC_LT, ASM_ARM_CC_GT, ASM_ARM_CC_EQ, @@ -2416,9 +2467,17 @@ STATIC void emit_native_binary_op(emit_t *emit, mp_binary_op_t op) { ASM_ARM_CC_GE, ASM_ARM_CC_NE, }; - asm_arm_setcc_reg(emit->as, REG_RET, ccs[op - MP_BINARY_OP_LESS]); + asm_arm_setcc_reg(emit->as, REG_RET, ccs[op_idx]); #elif N_XTENSA || N_XTENSAWIN - static uint8_t ccs[6] = { + static uint8_t ccs[6 + 6] = { + // unsigned + ASM_XTENSA_CC_LTU, + 0x80 | ASM_XTENSA_CC_LTU, // for GTU we'll swap args + ASM_XTENSA_CC_EQ, + 0x80 | ASM_XTENSA_CC_GEU, // for LEU we'll swap args + ASM_XTENSA_CC_GEU, + ASM_XTENSA_CC_NE, + // signed ASM_XTENSA_CC_LT, 0x80 | ASM_XTENSA_CC_LT, // for GT we'll swap args ASM_XTENSA_CC_EQ, @@ -2426,7 +2485,7 @@ STATIC void emit_native_binary_op(emit_t *emit, mp_binary_op_t op) { ASM_XTENSA_CC_GE, ASM_XTENSA_CC_NE, }; - uint8_t cc = ccs[op - MP_BINARY_OP_LESS]; + uint8_t cc = ccs[op_idx]; if ((cc & 0x80) == 0) { asm_xtensa_setcc_reg_reg_reg(emit->as, cc, REG_RET, REG_ARG_2, reg_rhs); } else { diff --git a/tests/micropython/viper_binop_arith_uint.py b/tests/micropython/viper_binop_arith_uint.py new file mode 100644 index 0000000000..e4270a10a7 --- /dev/null +++ b/tests/micropython/viper_binop_arith_uint.py @@ -0,0 +1,32 @@ +# test arithmetic operators with uint type + + +@micropython.viper +def add(x: uint, y: uint): + return x + y, y + x + + +print("add") +print(*add(1, 2)) +print(*(x & 0xFFFFFFFF for x in add(-1, -2))) + + +@micropython.viper +def sub(x: uint, y: uint): + return x - y, y - x + + +print("sub") +print(*(x & 0xFFFFFFFF for x in sub(1, 2))) +print(*(x & 0xFFFFFFFF for x in sub(-1, -2))) + + +@micropython.viper +def mul(x: uint, y: uint): + return x * y, y * x + + +print("mul") +print(*mul(2, 3)) +print(*(x & 0xFFFFFFFF for x in mul(2, -3))) +print(*mul(-2, -3)) diff --git a/tests/micropython/viper_binop_arith_uint.py.exp b/tests/micropython/viper_binop_arith_uint.py.exp new file mode 100644 index 0000000000..72f84b716a --- /dev/null +++ b/tests/micropython/viper_binop_arith_uint.py.exp @@ -0,0 +1,10 @@ +add +3 3 +4294967293 4294967293 +sub +4294967295 1 +1 4294967295 +mul +6 6 +4294967290 4294967290 +6 6 diff --git a/tests/micropython/viper_binop_bitwise_uint.py b/tests/micropython/viper_binop_bitwise_uint.py new file mode 100644 index 0000000000..3bc7ba8d11 --- /dev/null +++ b/tests/micropython/viper_binop_bitwise_uint.py @@ -0,0 +1,58 @@ +# test bitwise operators on uint type + + +@micropython.viper +def shl(x: uint, y: uint) -> uint: + return x << y + + +print("shl") +print(shl(1, 0)) +print(shl(1, 30)) +print(shl(-1, 10) & 0xFFFFFFFF) + + +@micropython.viper +def shr(x: uint, y: uint) -> uint: + return x >> y + + +print("shr") +print(shr(1, 0)) +print(shr(16, 3)) +print(shr(-1, 1) in (0x7FFFFFFF, 0x7FFFFFFF_FFFFFFFF)) + + +@micropython.viper +def and_(x: uint, y: uint): + return x & y, y & x + + +print("and") +print(*and_(1, 0)) +print(*and_(1, 3)) +print(*and_(-1, 2)) +print(*(x & 0xFFFFFFFF for x in and_(-1, -2))) + + +@micropython.viper +def or_(x: uint, y: uint): + return x | y, y | x + + +print("or") +print(*or_(1, 0)) +print(*or_(1, 2)) +print(*(x & 0xFFFFFFFF for x in or_(-1, 2))) + + +@micropython.viper +def xor(x: uint, y: uint): + return x ^ y, y ^ x + + +print("xor") +print(*xor(1, 0)) +print(*xor(1, 3)) +print(*(x & 0xFFFFFFFF for x in xor(-1, 3))) +print(*xor(-1, -3)) diff --git a/tests/micropython/viper_binop_bitwise_uint.py.exp b/tests/micropython/viper_binop_bitwise_uint.py.exp new file mode 100644 index 0000000000..3ad6469a2f --- /dev/null +++ b/tests/micropython/viper_binop_bitwise_uint.py.exp @@ -0,0 +1,22 @@ +shl +1 +1073741824 +4294966272 +shr +1 +2 +True +and +0 0 +1 1 +2 2 +4294967294 4294967294 +or +1 1 +3 3 +4294967295 4294967295 +xor +1 1 +2 2 +4294967292 4294967292 +2 2 diff --git a/tests/micropython/viper_binop_comp_uint.py b/tests/micropython/viper_binop_comp_uint.py new file mode 100644 index 0000000000..85aa32c78c --- /dev/null +++ b/tests/micropython/viper_binop_comp_uint.py @@ -0,0 +1,31 @@ +# test comparison operators with uint type + + +@micropython.viper +def f(x: uint, y: uint): + if x < y: + print(" <", end="") + if x > y: + print(" >", end="") + if x == y: + print(" ==", end="") + if x <= y: + print(" <=", end="") + if x >= y: + print(" >=", end="") + if x != y: + print(" !=", end="") + + +def test(a, b): + print(a, b, end="") + f(a, b) + print() + + +test(1, 1) +test(2, 1) +test(1, 2) +test(2, -1) +test(-2, 1) +test(-2, -1) diff --git a/tests/micropython/viper_binop_comp_uint.py.exp b/tests/micropython/viper_binop_comp_uint.py.exp new file mode 100644 index 0000000000..cacce62b6e --- /dev/null +++ b/tests/micropython/viper_binop_comp_uint.py.exp @@ -0,0 +1,6 @@ +1 1 == <= >= +2 1 > >= != +1 2 < <= != +2 -1 < <= != +-2 1 > >= != +-2 -1 < <= != diff --git a/tests/micropython/viper_error.py b/tests/micropython/viper_error.py index 790f3d75c4..80617af0c1 100644 --- a/tests/micropython/viper_error.py +++ b/tests/micropython/viper_error.py @@ -52,6 +52,7 @@ test("@micropython.viper\ndef f() -> int: return []") # can't do binary op between incompatible types test("@micropython.viper\ndef f(): 1 + []") +test("@micropython.viper\ndef f(x:int, y:uint): x < y") # can't load test("@micropython.viper\ndef f(): 1[0]") @@ -73,6 +74,8 @@ test("@micropython.viper\ndef f(x:int): -x") test("@micropython.viper\ndef f(x:int): ~x") # binary op not implemented +test("@micropython.viper\ndef f(x:uint, y:uint): res = x // y") +test("@micropython.viper\ndef f(x:uint, y:uint): res = x % y") test("@micropython.viper\ndef f(x:int): res = x in x") # yield (from) not implemented diff --git a/tests/micropython/viper_error.py.exp b/tests/micropython/viper_error.py.exp index da9a0ca93e..31c85b1d87 100644 --- a/tests/micropython/viper_error.py.exp +++ b/tests/micropython/viper_error.py.exp @@ -6,6 +6,7 @@ ViperTypeError("local 'x' has type 'int' but source is 'object'",) ViperTypeError("can't implicitly convert 'ptr' to 'bool'",) ViperTypeError("return expected 'int' but got 'object'",) ViperTypeError("can't do binary op between 'int' and 'object'",) +ViperTypeError('comparison of int and uint',) ViperTypeError("can't load from 'int'",) ViperTypeError("can't load from 'int'",) ViperTypeError("can't store to 'int'",) @@ -17,6 +18,8 @@ ViperTypeError('must raise an object',) ViperTypeError('unary op __pos__ not implemented',) ViperTypeError('unary op __neg__ not implemented',) ViperTypeError('unary op __invert__ not implemented',) +ViperTypeError('div/mod not implemented for uint',) +ViperTypeError('div/mod not implemented for uint',) ViperTypeError('binary op not implemented',) NotImplementedError('native yield',) NotImplementedError('native yield',)