From 46cbe422285d8e1d44d0b39b7b53c31a9d37920d Mon Sep 17 00:00:00 2001
From: Fish
Date: Fri, 6 Dec 2024 16:40:10 -0700
Subject: [PATCH] BackendZ3: Bypass integer string conversion limit. (#577)

* BackendZ3: Bypass integer string conversion limit.

CPython 3.11 introduced integer string conversion length limit (see
https://docs.python.org/3/library/stdtypes.html#integer-string-conversion-length-limitation).
We determined that this security protection does not apply to the threat
model that angr would face, and in fact, causes issues when we deal with
really long strings and integers. Therefore, we bypass the integer
conversion limit in this PR.

We also monkey-patch the Z3 Python binding so that it can accept long
integers in constraints.

* Lint code.

* Use sys.get_int_max_str_digits() as the chunk size.

* Lint code.

* Do not chunk when the limit is disabled (sys.get_int_max_str_digits()
  returns 0), keep Z3_to_int_str working for non-integer values, and
  reject a lone "-" in str_to_int_unlimited.
---
NOTE(review): line breaks, blank lines, and the indentation of context lines
were reconstructed from a whitespace-collapsed copy of this patch. Verify the
context against the target tree before applying. The blob hashes on the
"index" lines predate the latest changes.

 claripy/backends/backend_z3.py | 87 ++++++++++++++++++++++++++++++++++-
 tests/test_z3.py               | 75 ++++++++++++++++++++++++++++++
 2 files changed, 160 insertions(+), 2 deletions(-)

diff --git a/claripy/backends/backend_z3.py b/claripy/backends/backend_z3.py
index 5e1d994bd..efe74b6ea 100644
--- a/claripy/backends/backend_z3.py
+++ b/claripy/backends/backend_z3.py
@@ -38,6 +38,7 @@
 # pylint:disable=unidiomatic-typecheck
 
 ALL_Z3_CONTEXTS = weakref.WeakSet()
+INT_STRING_CHUNK_SIZE: int | None = None  # will be updated later if we are on CPython 3.11+
 
 
 def handle_sigint(signals, frametype):
@@ -97,6 +98,88 @@ def _add_memory_pressure(p):
         __pypy__.add_memory_pressure(p)
 
 
+def int_to_str_unlimited(v: int) -> str:
+    """
+    Convert an integer to a decimal string, without any size limit.
+
+    The integer is converted in chunks of INT_STRING_CHUNK_SIZE decimal digits, so that no single str() call
+    exceeds CPython's integer string conversion length limit.
+
+    :param v: The integer to convert.
+    :return: The string.
+    """
+
+    if INT_STRING_CHUNK_SIZE is None:
+        return str(v)
+
+    if v == 0:
+        return "0"
+
+    is_negative = v < 0
+    if is_negative:
+        v = -v
+
+    mod = 10**INT_STRING_CHUNK_SIZE
+    chunks = []  # least significant chunk first
+    while v > 0:
+        v, chunk = divmod(v, mod)
+        chunk_str = str(chunk)
+        if v > 0:
+            # every chunk except the most significant one keeps its leading zeros
+            chunk_str = chunk_str.zfill(INT_STRING_CHUNK_SIZE)
+        chunks.append(chunk_str)
+    if is_negative:
+        chunks.append("-")
+    # join once instead of repeatedly prepending to a growing string
+    return "".join(reversed(chunks))
+
+
+def Z3_to_int_str(val):
+    # we will monkey-patch Z3 and replace Z3._to_int_str with this version, which is free of integer size limits.
+
+    if isinstance(val, float):
+        return str(int(val))
+    if isinstance(val, bool):
+        return "1" if val else "0"
+    if isinstance(val, int):
+        return int_to_str_unlimited(val)
+    # anything else (e.g., a numeral given as a string, as in z3.IntVal("100")) is stringified as-is, which is what
+    # the original z3._to_int_str does
+    return str(val)
+
+
+if hasattr(sys, "get_int_max_str_digits"):
+    # CPython 3.11+
+    # monkey-patch Z3 so that it can accept long integers
+    z3.z3._to_int_str = Z3_to_int_str
+    # update INT_STRING_CHUNK_SIZE. a limit of 0 means that the limit is disabled; keep None in that case, because
+    # a chunk size of 0 cannot be used for chunking and plain str()/int() work for any length
+    INT_STRING_CHUNK_SIZE = sys.get_int_max_str_digits() or None
+
+
+def str_to_int_unlimited(s: str) -> int:
+    """
+    Convert a decimal string to an integer, without any size limit.
+
+    :param s: The string to convert.
+    :return: The integer.
+    :raises ValueError: If the string is empty or consists of only a minus sign.
+    """
+    if INT_STRING_CHUNK_SIZE is None:
+        return int(s)
+
+    is_negative = s.startswith("-")
+    digits = s[1:] if is_negative else s
+    if not digits:
+        return int(s)  # an exception will be raised, which is intentional
+
+    v = 0
+    for start in range(0, len(digits), INT_STRING_CHUNK_SIZE):
+        chunk = digits[start : start + INT_STRING_CHUNK_SIZE]
+        v = v * 10 ** len(chunk) + int(chunk, 10)
+    return -v if is_negative else v
+
+
 #
 # Some global variables
 #
@@ -457,7 +540,7 @@ def _abstract_internal(self, ctx, ast, split_on=None):
             bv_size = z3.Z3_get_bv_sort_size(ctx, z3_sort)
             if z3.Z3_get_numeral_uint64(ctx, ast, self._c_uint64_p):
                 return claripy.BVV(self._c_uint64_p.contents.value, bv_size)
-            bv_num = int(z3.Z3_get_numeral_string(ctx, ast))
+            bv_num = str_to_int_unlimited(z3.Z3_get_numeral_string(ctx, ast))
             return claripy.BVV(bv_num, bv_size)
         if op_name in ("FPVal", "MinusZero", "MinusInf", "PlusZero", "PlusInf", "NaN"):
             ebits = z3.Z3_fpa_get_ebits(ctx, z3_sort)
@@ -608,7 +691,7 @@ def _abstract_to_primitive(self, ctx, ast):
     def _abstract_bv_val(self, ctx, ast):
         if z3.Z3_get_numeral_uint64(ctx, ast, self._c_uint64_p):
             return self._c_uint64_p.contents.value
-        return int(z3.Z3_get_numeral_string(ctx, ast))
+        return str_to_int_unlimited(z3.Z3_get_numeral_string(ctx, ast))
 
     @staticmethod
     def _abstract_fp_val(ctx, ast, op_name):
diff --git a/tests/test_z3.py b/tests/test_z3.py
index 9cb295da7..5b6a8af2c 100644
--- a/tests/test_z3.py
+++ b/tests/test_z3.py
@@ -1,6 +1,7 @@
 # pylint: disable=missing-class-docstring,no-self-use
 from __future__ import annotations
 
+import sys
 import unittest
 
 import claripy
@@ -35,6 +36,80 @@ def test_extrema(self):
             assert z.min(x, solver=s, extra_constraints=(x >= i,)) == i
             assert z.max(x, solver=s, extra_constraints=(x >= i,)) == range_[1]
 
+    def test_str2int(self):
+        """
+        Test the str_to_int_unlimited function.
+        """
+
+        s2i = claripy.backends.backend_z3.str_to_int_unlimited
+        # get_int_max_str_digits() returns 0 when the limit is disabled; fall back to 640 in that case
+        CHUNK_SIZE = (sys.get_int_max_str_digits() if hasattr(sys, "get_int_max_str_digits") else 0) or 640
+
+        assert s2i("0") == 0
+        assert s2i("1337") == 1337
+        assert s2i("1337133713371337") == 1337133713371337
+        assert s2i("1" + "0" * 639) == 10**639
+        assert s2i("1" + "0" * 640) == 10**640
+        assert s2i("1" + "0" * 641) == 10**641
+        assert s2i("1" + "0" * 640 + "1") == 10**641 + 1
+        assert s2i("1" + "0" * 8192) == 10**8192
+
+        assert s2i("1" + "0" * (CHUNK_SIZE - 1)) == 10 ** (CHUNK_SIZE - 1)
+        assert s2i("1" + "0" * CHUNK_SIZE) == 10**CHUNK_SIZE
+        assert s2i("1" + "0" * (CHUNK_SIZE + 1)) == 10 ** (CHUNK_SIZE + 1)
+
+        assert s2i("-0") == 0
+        assert s2i("-1") == -1
+        assert s2i("-1" + "0" * CHUNK_SIZE) == -(10**CHUNK_SIZE)
+
+        # strings without any digit are rejected, just like int() does
+        for bad in ("", "-"):
+            with self.assertRaises(ValueError):
+                s2i(bad)
+
+    def test_int2str(self):
+        """
+        Test the int_to_str_unlimited function.
+        """
+
+        i2s = claripy.backends.backend_z3.int_to_str_unlimited
+        # get_int_max_str_digits() returns 0 when the limit is disabled; fall back to 640 in that case
+        CHUNK_SIZE = (sys.get_int_max_str_digits() if hasattr(sys, "get_int_max_str_digits") else 0) or 640
+
+        assert i2s(0) == "0"
+        assert i2s(-1) == "-1"
+        assert i2s(1337) == "1337"
+        assert i2s(10**8192) == "1" + "0" * 8192
+        assert i2s(10**CHUNK_SIZE) == "1" + "0" * CHUNK_SIZE
+        assert i2s(-(10**CHUNK_SIZE)) == "-1" + "0" * CHUNK_SIZE
+
+    def test_z3_to_int_str(self):
+        """
+        Test the Z3_to_int_str function that replaces z3's own _to_int_str.
+        """
+
+        to_str = claripy.backends.backend_z3.Z3_to_int_str
+
+        assert to_str(True) == "1"
+        assert to_str(False) == "0"
+        assert to_str(3.0) == "3"
+        assert to_str(-1337) == "-1337"
+        assert to_str(10**8192) == "1" + "0" * 8192
+        # numerals that are already strings are passed through unchanged
+        assert to_str("100") == "100"
+
+    def test_get_long_strings(self):
+        # Python 3.11 introduced limits in decimal-to-int conversion. we bypass it by using the str_to_int_unlimited
+        # method.
+        z3 = claripy.backends.backend_z3
+
+        s = claripy.backends.z3.solver()
+        backend = z3.BackendZ3()
+        x = claripy.BVS("x", 16385 * 8)
+        backend.add(s, [x == 10**16384])
+        d = backend.eval(x, 1, solver=s)
+        assert d == [10**16384]
+
 
 if __name__ == "__main__":
     unittest.main()