# Source code for mythril.laser.smt.bitvec_helper

from typing import Union, overload, List, Set, cast, Any, Callable
import z3

from mythril.laser.smt.bool import Bool, Or
from mythril.laser.smt.bitvec import BitVec
from mythril.laser.smt.array import BaseArray, Array

Annotations = Set[Any]


def _z3_array_converter(array: z3.ArrayRef) -> Array:
    """Wrap a raw z3 array expression in an :class:`Array`.

    :param array: z3 array expression (e.g. built by ``z3.Array``, ``z3.K``
        or an ``z3.If`` over two arrays)
    :return: an ``Array`` whose ``raw`` attribute is ``array``
    """
    # ``z3.Array`` and ``z3.K`` are factory functions, not types; the
    # expression object they produce is a ``z3.ArrayRef``, which is what
    # provides ``domain()`` and ``range()``.
    domain_size = array.domain().size()
    range_size = array.range().size()

    # The name is a placeholder: the freshly created raw expression is
    # replaced by ``array`` right away.
    wrapped = Array("name_to_be_overwritten", domain_size, range_size)
    wrapped.raw = array
    return wrapped


def _comparison_helper(a: BitVec, b: BitVec, operation: Callable) -> Bool:
    """Apply a binary z3 comparison to two bit-vectors.

    :param a: left operand
    :param b: right operand
    :param operation: callable taking the two raw expressions
    :return: Bool wrapping the result, annotated with the union of both
        operands' annotations
    """
    merged = a.annotations.union(b.annotations)
    raw = operation(a.raw, b.raw)
    return Bool(raw, merged)


def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec:
    """Apply a binary z3 arithmetic operation to two bit-vectors.

    :param a: left operand
    :param b: right operand
    :param operation: callable taking the two raw expressions
    :return: BitVec wrapping the result, annotated with the union of both
        operands' annotations
    """
    return BitVec(
        operation(a.raw, b.raw),
        annotations=a.annotations.union(b.annotations),
    )


def LShR(a: BitVec, b: BitVec) -> BitVec:
    """Create a logical shift right expression.

    :param a: value to shift
    :param b: shift amount
    :return: BitVec wrapping ``z3.LShR`` of the two raw expressions, carrying
        the annotations of both operands
    """
    return _arithmetic_helper(a, b, z3.LShR)
@overload
def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> BitVec:
    ...


@overload
def If(a: Union[Bool, bool], b: BaseArray, c: BaseArray) -> BaseArray:
    ...


def If(
    a: Union[Bool, bool],
    b: Union[BaseArray, BitVec, int],
    c: Union[BaseArray, BitVec, int],
) -> Union[BitVec, BaseArray]:
    """Create an if-then-else expression.

    :param a: condition; a plain ``bool`` is wrapped as a z3 boolean constant
    :param b: value selected when the condition holds
    :param c: value selected otherwise
    :return: an array wrapper when both branches are arrays, otherwise a
        BitVec annotated with the union of all three operands' annotations
    """
    condition = a if isinstance(a, Bool) else Bool(z3.BoolVal(a))

    # Array branches: only the raw expression is handed to the converter.
    if isinstance(b, BaseArray) and isinstance(c, BaseArray):
        return _z3_array_converter(z3.If(condition.raw, b.raw, c.raw))

    # Width used to encode int branches: taken from a BitVec branch when
    # there is one (``c`` wins over ``b``), 256 otherwise.
    if isinstance(c, BitVec):
        width = c.size()
    elif isinstance(b, BitVec):
        width = b.size()
    else:
        width = 256

    then_bv = b if isinstance(b, BitVec) else BitVec(z3.BitVecVal(b, width))
    else_bv = c if isinstance(c, BitVec) else BitVec(z3.BitVecVal(c, width))

    merged = condition.annotations.union(then_bv.annotations).union(
        else_bv.annotations
    )
    return BitVec(z3.If(condition.raw, then_bv.raw, else_bv.raw), merged)
def UGT(a: BitVec, b: BitVec) -> Bool:
    """Create an unsigned greater than expression.

    :param a: left operand
    :param b: right operand
    :return: Bool wrapping ``z3.UGT`` applied to the raw expressions
    """
    return _comparison_helper(a, b, operation=z3.UGT)
def UGE(a: BitVec, b: BitVec) -> Bool:
    """Create an unsigned greater than or equal to expression.

    :param a: left operand
    :param b: right operand
    :return: disjunction of ``UGT(a, b)`` and ``a == b``
    """
    strictly_greater = UGT(a, b)
    equal = a == b
    return Or(strictly_greater, equal)
def ULT(a: BitVec, b: BitVec) -> Bool:
    """Create an unsigned less than expression.

    :param a: left operand
    :param b: right operand
    :return: Bool wrapping ``z3.ULT`` applied to the raw expressions
    """
    return _comparison_helper(a, b, operation=z3.ULT)
def ULE(a: BitVec, b: BitVec) -> Bool:
    """Create an unsigned less than or equal to expression.

    :param a: left operand
    :param b: right operand
    :return: disjunction of ``ULT(a, b)`` and ``a == b``
    """
    strictly_less = ULT(a, b)
    equal = a == b
    return Or(strictly_less, equal)
@overload
def Concat(*args: List[BitVec]) -> BitVec:
    ...


@overload
def Concat(*args: BitVec) -> BitVec:
    ...


def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
    """Create a concatenation expression.

    :param args: the bit-vectors to concatenate, given either as separate
        arguments or as one list
    :return: BitVec wrapping ``z3.Concat`` of the raw expressions, annotated
        with the union of all operands' annotations
    """
    # A lone list argument is treated as the operand sequence itself.
    given_as_list = len(args) == 1 and isinstance(args[0], list)
    operands = cast(List[BitVec], args[0] if given_as_list else args)

    raw = z3.Concat([operand.raw for operand in operands])
    merged: Annotations = set().union(
        *(operand.annotations for operand in operands)
    )
    return BitVec(raw, merged)
def Extract(high: int, low: int, bv: BitVec) -> BitVec:
    """Create an extract expression.

    :param high: upper bit index passed to ``z3.Extract``
    :param low: lower bit index passed to ``z3.Extract``
    :param bv: bit-vector to extract from
    :return: BitVec wrapping the extracted expression
    """
    # NOTE(review): ``bv.annotations`` is handed over as-is (no copy) --
    # confirm BitVec does not mutate the set it receives.
    return BitVec(z3.Extract(high, low, bv.raw), annotations=bv.annotations)
def URem(a: BitVec, b: BitVec) -> BitVec:
    """Create an unsigned remainder expression.

    :param a: dividend
    :param b: divisor
    :return: BitVec wrapping ``z3.URem`` applied to the raw expressions
    """
    return _arithmetic_helper(a, b, operation=z3.URem)
def SRem(a: BitVec, b: BitVec) -> BitVec:
    """Create a signed remainder expression.

    :param a: dividend
    :param b: divisor
    :return: BitVec wrapping ``z3.SRem`` applied to the raw expressions
    """
    return _arithmetic_helper(a, b, operation=z3.SRem)
def UDiv(a: BitVec, b: BitVec) -> BitVec:
    """Create an unsigned division expression.

    :param a: dividend
    :param b: divisor
    :return: BitVec wrapping ``z3.UDiv`` applied to the raw expressions
    """
    return _arithmetic_helper(a, b, operation=z3.UDiv)
def Sum(*args: BitVec) -> BitVec:
    """Create sum expression.

    :param args: the bit-vectors to add up
    :return: BitVec wrapping ``z3.Sum`` of the raw expressions, annotated
        with the union of all operands' annotations
    """
    total = z3.Sum([term.raw for term in args])

    merged: Annotations = set()
    for term in args:
        merged.update(term.annotations)

    return BitVec(total, merged)
def BVAddNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) -> Bool:
    """Creates predicate that verifies that the addition doesn't overflow.

    :param a: left operand; a plain int is converted to a bit-vector constant
    :param b: right operand; a plain int is converted to a bit-vector constant
    :param signed: passed through to ``z3.BVAddNoOverflow``
    :return: Bool wrapping the z3 predicate
    """
    # An int operand is encoded with the width of the BitVec operand so both
    # sides share one sort; 256 is only the fallback when neither operand is
    # a BitVec. This is the same width selection that ``If`` performs.
    width = 256
    for operand in (a, b):
        if isinstance(operand, BitVec):
            width = operand.size()

    if not isinstance(a, BitVec):
        a = BitVec(z3.BitVecVal(a, width))
    if not isinstance(b, BitVec):
        b = BitVec(z3.BitVecVal(b, width))

    # NOTE(review): operand annotations are not carried over to the result,
    # unlike the comparison helpers -- confirm this is intended.
    return Bool(z3.BVAddNoOverflow(a.raw, b.raw, signed))
def BVMulNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) -> Bool:
    """Creates predicate that verifies that the multiplication doesn't overflow.

    :param a: left operand; a plain int is converted to a bit-vector constant
    :param b: right operand; a plain int is converted to a bit-vector constant
    :param signed: passed through to ``z3.BVMulNoOverflow``
    :return: Bool wrapping the z3 predicate
    """
    # An int operand is encoded with the width of the BitVec operand so both
    # sides share one sort; 256 is only the fallback when neither operand is
    # a BitVec. This is the same width selection that ``If`` performs.
    width = 256
    for operand in (a, b):
        if isinstance(operand, BitVec):
            width = operand.size()

    if not isinstance(a, BitVec):
        a = BitVec(z3.BitVecVal(a, width))
    if not isinstance(b, BitVec):
        b = BitVec(z3.BitVecVal(b, width))

    # NOTE(review): operand annotations are not carried over to the result,
    # unlike the comparison helpers -- confirm this is intended.
    return Bool(z3.BVMulNoOverflow(a.raw, b.raw, signed))
def BVSubNoUnderflow(
    a: Union[BitVec, int], b: Union[BitVec, int], signed: bool
) -> Bool:
    """Creates predicate that verifies that the subtraction doesn't underflow.

    :param a: left operand; a plain int is converted to a bit-vector constant
    :param b: right operand; a plain int is converted to a bit-vector constant
    :param signed: passed through to ``z3.BVSubNoUnderflow``
    :return: Bool wrapping the z3 predicate
    """
    # An int operand is encoded with the width of the BitVec operand so both
    # sides share one sort; 256 is only the fallback when neither operand is
    # a BitVec. This is the same width selection that ``If`` performs.
    width = 256
    for operand in (a, b):
        if isinstance(operand, BitVec):
            width = operand.size()

    if not isinstance(a, BitVec):
        a = BitVec(z3.BitVecVal(a, width))
    if not isinstance(b, BitVec):
        b = BitVec(z3.BitVecVal(b, width))

    # NOTE(review): operand annotations are not carried over to the result,
    # unlike the comparison helpers -- confirm this is intended.
    return Bool(z3.BVSubNoUnderflow(a.raw, b.raw, signed))