Source code for mythril.laser.plugin.plugins.instruction_profiler

from collections import namedtuple
from datetime import datetime
from typing import Dict, List, Tuple
from mythril.laser.plugin.builder import PluginBuilder
from mythril.laser.plugin.interface import LaserPlugin
from mythril.laser.ethereum.svm import LaserEVM
from mythril.laser.ethereum.state.global_state import GlobalState
from datetime import datetime
import logging

# Type annotations:
#   start_time: datetime
#   end_time: datetime
_InstrExecRecord = namedtuple("_InstrExecRecord", ["start_time", "end_time"])

# Type annotations:
#   total_time: float
#   total_nr: float
#   min_time: float
#   max_time: float
_InstrExecStatistic = namedtuple(
    "_InstrExecStatistic", ["total_time", "total_nr", "min_time", "max_time"]
)

# Map the instruction opcode to its records if all execution times
_InstrExecRecords = Dict[str, List[_InstrExecRecord]]

# Map the instruction opcode to the statistic of its execution times
_InstrExecStatistics = Dict[str, _InstrExecStatistic]

log = logging.getLogger(__name__)


[docs]class InstructionProfilerBuilder(PluginBuilder): name = "instruction-profiler" def __call__(self, *args, **kwargs): return InstructionProfiler()
[docs]class InstructionProfiler(LaserPlugin): """Performance profile for the execution of each instruction.""" def __init__(self): self._reset() def _reset(self): self.records = dict() self.start_time = None
[docs] def initialize(self, symbolic_vm: LaserEVM) -> None: @symbolic_vm.instr_hook("pre", None) def get_start_time(op_code: str): def start_time_wrapper(global_state: GlobalState): self.start_time = datetime.now() return start_time_wrapper @symbolic_vm.instr_hook("post", None) def record(op_code: str): def record_opcode(global_state: GlobalState): end_time = datetime.now() try: self.records[op_code].append( _InstrExecRecord(self.start_time, end_time) ) except KeyError: self.records[op_code] = [ _InstrExecRecord(self.start_time, end_time) ] return record_opcode @symbolic_vm.laser_hook("stop_sym_exec") def print_stats(): total, stats = self._make_stats() s = "Total: {} s\n".format(total) for op in sorted(stats): stat = stats[op] s += "[{:12s}] {:>8.4f} %, nr {:>6}, total {:>8.4f} s, avg {:>8.4f} s, min {:>8.4f} s, max {:>8.4f} s\n".format( op, stat.total_time * 100 / total, stat.total_nr, stat.total_time, stat.total_time / stat.total_nr, stat.min_time, stat.max_time, ) log.info(s)
def _make_stats(self) -> Tuple[float, _InstrExecStatistics]: periods = { op: list( map(lambda r: r.end_time.timestamp() - r.start_time.timestamp(), rs) ) for op, rs in self.records.items() } stats = dict() total_time = 0 for _, (op, times) in enumerate(periods.items()): stat = _InstrExecStatistic( total_time=sum(times), total_nr=len(times), min_time=min(times), max_time=max(times), ) total_time += stat.total_time stats[op] = stat return total_time, stats