Add more type hints to register printing logic (#3490)

* Clarify typing in register printing

* Fix additional errors related to None

* lint

* Add comment
pull/3482/merge
OBarronCS 1 day ago committed by GitHub
parent dddb82b075
commit 5a29ea00b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -18,6 +18,7 @@ from typing import TypeVar
import unicorn as U
from typing_extensions import ParamSpec
from typing_extensions import override
import pwndbg
import pwndbg.aglib.arch
@ -42,6 +43,9 @@ from pwndbg.color import ColorParamSpec
from pwndbg.color import message
from pwndbg.color import theme
from pwndbg.commands import CommandCategory
from pwndbg.lib.regs import BitFlags
from pwndbg.lib.regs import RegisterContextProtocol
from pwndbg.lib.regs import VisitableRegister
if pwndbg.dbg.is_gdblib_available():
import gdb
@ -823,7 +827,7 @@ def calculate_padding_to_align(length, align):
return 0 if length % align == 0 else (align - (length % align))
def compact_regs(regs, width=None, target=sys.stdout):
def compact_regs(regs: List[str], width=None, target=sys.stdout) -> List[str]:
columns = max(0, int(pwndbg.config.show_compact_regs_columns))
min_width = max(1, int(pwndbg.config.show_compact_regs_min_width))
separation = max(1, int(pwndbg.config.show_compact_regs_separation))
@ -844,7 +848,7 @@ def compact_regs(regs, width=None, target=sys.stdout):
# => min_width = (window_width - (columns - 1) * separation) / columns
min_width = max(min_width, (width - (columns - 1) * separation) // columns)
result = []
result: List[str] = []
line = ""
line_length = 0
@ -931,13 +935,13 @@ pwndbg.config.add_param("show-flags", False, "whether to show flags registers")
pwndbg.config.add_param("show-retaddr-reg", True, "whether to show return address register")
class RegisterContext:
class RegisterContext(RegisterContextProtocol):
changed: List[str]
def __init__(self):
self.changed = pwndbg.aglib.regs.changed
def get_prefix(self, reg):
def get_prefix(self, reg: str) -> str:
# Make the register stand out and give a color if changed
regname = C.register(reg.ljust(4).upper())
if reg in self.changed:
@ -952,14 +956,15 @@ class RegisterContext:
)
return f"{m}{regname}"
def get_register_value(self, reg):
def get_register_value(self, reg: str) -> int | None:
val = pwndbg.aglib.regs.read_reg(reg)
if val is None:
print(message.warn(f"Unknown register: {reg!r}"))
return None
return val
def flag_register_context(self, reg, bit_flags):
@override
def flag_register_context(self, reg: str, bit_flags: BitFlags) -> str | None:
val = self.get_register_value(reg)
if val is None:
return None
@ -967,7 +972,8 @@ class RegisterContext:
prefix = self.get_prefix(reg)
return f"{prefix} {desc}"
def segment_registers_context(self, regs):
@override
def segment_registers_context(self, regs: list[str]) -> str | None:
result = ""
for reg in regs:
val = self.get_register_value(reg)
@ -977,7 +983,8 @@ class RegisterContext:
result += f"{prefix} {hex(val)} "
return result
def addressing_register_context(self, reg, is_virtual):
@override
def addressing_register_context(self, reg: str, is_virtual: bool) -> str | None:
if is_virtual:
return self.register_context_default(reg)
val = self.get_register_value(reg)
@ -993,7 +1000,7 @@ class RegisterContext:
pass
return f"{prefix} {desc}"
def register_context_default(self, reg):
def register_context_default(self, reg: str) -> str | None:
val = self.get_register_value(reg)
if val is None:
return None
@ -1003,14 +1010,15 @@ class RegisterContext:
return f"{prefix} {desc}"
def get_regs(regs: List[str] = None):
regs: List[Any] = regs
result = []
def get_regs(in_regs: List[str | VisitableRegister | None] | None = None):
# Python default parameters are instantiated once and shared across calls.
# Instead of a default value of [], we need to do this check so we get a fresh list each time
if in_regs is None:
in_regs = []
regs: List[str | VisitableRegister | None] = in_regs
result: List[str] = []
rc = RegisterContext()
if regs is None:
regs = []
if len(regs) == 0:
regs += pwndbg.aglib.regs.gpr
@ -1046,6 +1054,7 @@ def get_regs(regs: List[str] = None):
for reg in regs:
if reg is None:
continue
# If it's a VisitableRegister which has special logic to determine what to print
if not isinstance(reg, str):
desc = reg.context(rc)
if desc is not None:

@ -12,15 +12,39 @@ from typing import Dict
from typing import Iterator
from typing import List
from typing import OrderedDict
from typing import Protocol
from typing import Set
from typing import Tuple
from typing import Union
from typing_extensions import override
import pwndbg.lib.disasm.helpers as bit_math
from pwndbg.lib.arch import PWNDBG_SUPPORTED_ARCHITECTURES_TYPE
# The printing logic for registers uses the Visitor Pattern
# An implementation of RegisterContextProtocol is defined outside of this class
# (this is a lib/ file, so it shouldn't directly be able to access the process)
#
# Instances of VisitableRegister will call the methods of RegisterContextProtocol to do their logic.
class RegisterContextProtocol(Protocol):
def flag_register_context(self, reg: str, bit_flags: BitFlags) -> str | None:
...
def addressing_register_context(self, reg: str, is_virtual: bool) -> str | None:
...
class BitFlags:
def segment_registers_context(self,regs: list[str]) -> str | None:
...
# Represents a register or a set of registers that can be printed in the context register view
class VisitableRegister(Protocol):
def context(self, rc: RegisterContextProtocol) -> str | None:
...
class BitFlags(VisitableRegister):
# this is intentionally uninitialized -- arm uses the same self.flags structuture for different registers
# for example
# - aarch64_cpsr_flags is used for "cpsr", "spsr_el1", "spsr_el2", "spsr_el3"
@ -29,29 +53,27 @@ class BitFlags:
flags: OrderedDict[str, Union[int, Tuple[int, int]]]
value: int
def __init__(self, flags: List[Tuple[str, Union[int, Tuple[int, int]]]] = [], value=None):
def __init__(self, flags: List[Tuple[str, Union[int, Tuple[int, int]]]] = []):
self.regname = ""
self.flags = {}
self.flags = OrderedDict()
for name, bits in flags:
self.flags[name] = bits
self.value = value
self.value = 0
def __getattr__(self, name):
if name in {"regname"}:
return self.__dict__[name]
def __getattr__(self, name: str):
return getattr(self.flags, name)
def __getitem__(self, key):
def __getitem__(self, key: str) -> int:
r = self.flags[key]
if isinstance(r, int):
return (self.value >> r) & 1
s, e = r
return ((~((1 << s) - 1) & ((1 << (e + 1)) - 1)) & self.value) >> s
def __setitem__(self, key, value):
def __setitem__(self, key: str, value: int) -> None:
self.flags[key] = value
def __delitem__(self, key):
def __delitem__(self, key: str):
del self.flags[key]
def __iter__(self):
@ -63,14 +85,15 @@ class BitFlags:
def __repr__(self):
return f"BitFlags({self.flags})"
def update(self, regname: str):
def update(self, regname: str) -> None:
self.regname = regname
def context(self, rc):
@override
def context(self, rc: RegisterContextProtocol) -> str | None:
return rc.flag_register_context(self.regname, self)
class AddressingRegister:
class AddressingRegister(VisitableRegister):
"""
Represents a register that is used to store an address, e.g. cr3, gsbase, fsbase
"""
@ -84,14 +107,15 @@ class AddressingRegister:
self.value = 0
self.is_virtual = is_virtual
def update(self, regname: str):
def update(self, regname: str) -> None:
pass
def context(self, rc):
@override
def context(self, rc: RegisterContextProtocol) -> str | None:
return rc.addressing_register_context(self.reg, self.is_virtual)
class SegmentRegisters:
class SegmentRegisters(VisitableRegister):
"""
Represents the x86 segment register set
"""
@ -101,7 +125,8 @@ class SegmentRegisters:
def __init__(self, regs: List[str]):
self.regs = regs
def context(self, rc):
@override
def context(self, rc: RegisterContextProtocol) -> str | None:
return rc.segment_registers_context(self.regs)
@ -122,7 +147,7 @@ class KernelRegisterSet:
def __init__(
self,
segments: SegmentRegisters | None,
segments: SegmentRegisters,
controls: Dict[str, BitFlags | AddressingRegister] = {},
msrs: Dict[str, BitFlags | AddressingRegister] = {},
):

Loading…
Cancel
Save