Add more type hints to register printing logic (#3490)

* Clarify typing in register printing

* Fix additional errors related to None

* lint

* Add comment
pull/3482/merge
OBarronCS 1 day ago committed by GitHub
parent dddb82b075
commit 5a29ea00b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -18,6 +18,7 @@ from typing import TypeVar
import unicorn as U import unicorn as U
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
from typing_extensions import override
import pwndbg import pwndbg
import pwndbg.aglib.arch import pwndbg.aglib.arch
@ -42,6 +43,9 @@ from pwndbg.color import ColorParamSpec
from pwndbg.color import message from pwndbg.color import message
from pwndbg.color import theme from pwndbg.color import theme
from pwndbg.commands import CommandCategory from pwndbg.commands import CommandCategory
from pwndbg.lib.regs import BitFlags
from pwndbg.lib.regs import RegisterContextProtocol
from pwndbg.lib.regs import VisitableRegister
if pwndbg.dbg.is_gdblib_available(): if pwndbg.dbg.is_gdblib_available():
import gdb import gdb
@ -823,7 +827,7 @@ def calculate_padding_to_align(length, align):
return 0 if length % align == 0 else (align - (length % align)) return 0 if length % align == 0 else (align - (length % align))
def compact_regs(regs, width=None, target=sys.stdout): def compact_regs(regs: List[str], width=None, target=sys.stdout) -> List[str]:
columns = max(0, int(pwndbg.config.show_compact_regs_columns)) columns = max(0, int(pwndbg.config.show_compact_regs_columns))
min_width = max(1, int(pwndbg.config.show_compact_regs_min_width)) min_width = max(1, int(pwndbg.config.show_compact_regs_min_width))
separation = max(1, int(pwndbg.config.show_compact_regs_separation)) separation = max(1, int(pwndbg.config.show_compact_regs_separation))
@ -844,7 +848,7 @@ def compact_regs(regs, width=None, target=sys.stdout):
# => min_width = (window_width - (columns - 1) * separation) / columns # => min_width = (window_width - (columns - 1) * separation) / columns
min_width = max(min_width, (width - (columns - 1) * separation) // columns) min_width = max(min_width, (width - (columns - 1) * separation) // columns)
result = [] result: List[str] = []
line = "" line = ""
line_length = 0 line_length = 0
@ -931,13 +935,13 @@ pwndbg.config.add_param("show-flags", False, "whether to show flags registers")
pwndbg.config.add_param("show-retaddr-reg", True, "whether to show return address register") pwndbg.config.add_param("show-retaddr-reg", True, "whether to show return address register")
class RegisterContext: class RegisterContext(RegisterContextProtocol):
changed: List[str] changed: List[str]
def __init__(self): def __init__(self):
self.changed = pwndbg.aglib.regs.changed self.changed = pwndbg.aglib.regs.changed
def get_prefix(self, reg): def get_prefix(self, reg: str) -> str:
# Make the register stand out and give a color if changed # Make the register stand out and give a color if changed
regname = C.register(reg.ljust(4).upper()) regname = C.register(reg.ljust(4).upper())
if reg in self.changed: if reg in self.changed:
@ -952,14 +956,15 @@ class RegisterContext:
) )
return f"{m}{regname}" return f"{m}{regname}"
def get_register_value(self, reg): def get_register_value(self, reg: str) -> int | None:
val = pwndbg.aglib.regs.read_reg(reg) val = pwndbg.aglib.regs.read_reg(reg)
if val is None: if val is None:
print(message.warn(f"Unknown register: {reg!r}")) print(message.warn(f"Unknown register: {reg!r}"))
return None return None
return val return val
def flag_register_context(self, reg, bit_flags): @override
def flag_register_context(self, reg: str, bit_flags: BitFlags) -> str | None:
val = self.get_register_value(reg) val = self.get_register_value(reg)
if val is None: if val is None:
return None return None
@ -967,7 +972,8 @@ class RegisterContext:
prefix = self.get_prefix(reg) prefix = self.get_prefix(reg)
return f"{prefix} {desc}" return f"{prefix} {desc}"
def segment_registers_context(self, regs): @override
def segment_registers_context(self, regs: list[str]) -> str | None:
result = "" result = ""
for reg in regs: for reg in regs:
val = self.get_register_value(reg) val = self.get_register_value(reg)
@ -977,7 +983,8 @@ class RegisterContext:
result += f"{prefix} {hex(val)} " result += f"{prefix} {hex(val)} "
return result return result
def addressing_register_context(self, reg, is_virtual): @override
def addressing_register_context(self, reg: str, is_virtual: bool) -> str | None:
if is_virtual: if is_virtual:
return self.register_context_default(reg) return self.register_context_default(reg)
val = self.get_register_value(reg) val = self.get_register_value(reg)
@ -993,7 +1000,7 @@ class RegisterContext:
pass pass
return f"{prefix} {desc}" return f"{prefix} {desc}"
def register_context_default(self, reg): def register_context_default(self, reg: str) -> str | None:
val = self.get_register_value(reg) val = self.get_register_value(reg)
if val is None: if val is None:
return None return None
@ -1003,14 +1010,15 @@ class RegisterContext:
return f"{prefix} {desc}" return f"{prefix} {desc}"
def get_regs(regs: List[str] = None): def get_regs(in_regs: List[str | VisitableRegister | None] | None = None):
regs: List[Any] = regs # Python default parameters are instantiated once and shared across calls.
result = [] # Instead of a default value of [], we need to do this check so we get a fresh list each time
if in_regs is None:
in_regs = []
regs: List[str | VisitableRegister | None] = in_regs
result: List[str] = []
rc = RegisterContext() rc = RegisterContext()
if regs is None:
regs = []
if len(regs) == 0: if len(regs) == 0:
regs += pwndbg.aglib.regs.gpr regs += pwndbg.aglib.regs.gpr
@ -1046,11 +1054,12 @@ def get_regs(regs: List[str] = None):
for reg in regs: for reg in regs:
if reg is None: if reg is None:
continue continue
# If it's a VisitableRegister which has special logic to determine what to print
if not isinstance(reg, str): if not isinstance(reg, str):
desc = reg.context(rc) desc = reg.context(rc)
if desc is not None: if desc is not None:
result.append(desc) result.append(desc)
continue continue
desc = rc.register_context_default(reg) desc = rc.register_context_default(reg)
if desc is not None: if desc is not None:
result.append(desc) result.append(desc)

@ -12,15 +12,39 @@ from typing import Dict
from typing import Iterator from typing import Iterator
from typing import List from typing import List
from typing import OrderedDict from typing import OrderedDict
from typing import Protocol
from typing import Set from typing import Set
from typing import Tuple from typing import Tuple
from typing import Union from typing import Union
from typing_extensions import override
import pwndbg.lib.disasm.helpers as bit_math import pwndbg.lib.disasm.helpers as bit_math
from pwndbg.lib.arch import PWNDBG_SUPPORTED_ARCHITECTURES_TYPE from pwndbg.lib.arch import PWNDBG_SUPPORTED_ARCHITECTURES_TYPE
# The printing logic for registers uses the Visitor Pattern
# An implementation of RegisterContextProtocol is defined outside of this class
# (this is a lib/ file, so it shouldn't directly be able to access the process)
#
# Instances of VisitableRegister will call the methods of RegisterContextProtocol to do their logic.
class RegisterContextProtocol(Protocol):
def flag_register_context(self, reg: str, bit_flags: BitFlags) -> str | None:
...
def addressing_register_context(self, reg: str, is_virtual: bool) -> str | None:
...
def segment_registers_context(self,regs: list[str]) -> str | None:
...
class BitFlags:
# Represents a register or a set of registers that can be printed in the context register view
class VisitableRegister(Protocol):
def context(self, rc: RegisterContextProtocol) -> str | None:
...
class BitFlags(VisitableRegister):
# this is intentionally uninitialized -- arm uses the same self.flags structuture for different registers # this is intentionally uninitialized -- arm uses the same self.flags structuture for different registers
# for example # for example
# - aarch64_cpsr_flags is used for "cpsr", "spsr_el1", "spsr_el2", "spsr_el3" # - aarch64_cpsr_flags is used for "cpsr", "spsr_el1", "spsr_el2", "spsr_el3"
@ -29,29 +53,27 @@ class BitFlags:
flags: OrderedDict[str, Union[int, Tuple[int, int]]] flags: OrderedDict[str, Union[int, Tuple[int, int]]]
value: int value: int
def __init__(self, flags: List[Tuple[str, Union[int, Tuple[int, int]]]] = [], value=None): def __init__(self, flags: List[Tuple[str, Union[int, Tuple[int, int]]]] = []):
self.regname = "" self.regname = ""
self.flags = {} self.flags = OrderedDict()
for name, bits in flags: for name, bits in flags:
self.flags[name] = bits self.flags[name] = bits
self.value = value self.value = 0
def __getattr__(self, name): def __getattr__(self, name: str):
if name in {"regname"}:
return self.__dict__[name]
return getattr(self.flags, name) return getattr(self.flags, name)
def __getitem__(self, key): def __getitem__(self, key: str) -> int:
r = self.flags[key] r = self.flags[key]
if isinstance(r, int): if isinstance(r, int):
return (self.value >> r) & 1 return (self.value >> r) & 1
s, e = r s, e = r
return ((~((1 << s) - 1) & ((1 << (e + 1)) - 1)) & self.value) >> s return ((~((1 << s) - 1) & ((1 << (e + 1)) - 1)) & self.value) >> s
def __setitem__(self, key, value): def __setitem__(self, key: str, value: int) -> None:
self.flags[key] = value self.flags[key] = value
def __delitem__(self, key): def __delitem__(self, key: str):
del self.flags[key] del self.flags[key]
def __iter__(self): def __iter__(self):
@ -63,14 +85,15 @@ class BitFlags:
def __repr__(self): def __repr__(self):
return f"BitFlags({self.flags})" return f"BitFlags({self.flags})"
def update(self, regname: str): def update(self, regname: str) -> None:
self.regname = regname self.regname = regname
def context(self, rc): @override
def context(self, rc: RegisterContextProtocol) -> str | None:
return rc.flag_register_context(self.regname, self) return rc.flag_register_context(self.regname, self)
class AddressingRegister: class AddressingRegister(VisitableRegister):
""" """
Represents a register that is used to store an address, e.g. cr3, gsbase, fsbase Represents a register that is used to store an address, e.g. cr3, gsbase, fsbase
""" """
@ -84,14 +107,15 @@ class AddressingRegister:
self.value = 0 self.value = 0
self.is_virtual = is_virtual self.is_virtual = is_virtual
def update(self, regname: str): def update(self, regname: str) -> None:
pass pass
def context(self, rc): @override
def context(self, rc: RegisterContextProtocol) -> str | None:
return rc.addressing_register_context(self.reg, self.is_virtual) return rc.addressing_register_context(self.reg, self.is_virtual)
class SegmentRegisters: class SegmentRegisters(VisitableRegister):
""" """
Represents the x86 segment register set Represents the x86 segment register set
""" """
@ -101,7 +125,8 @@ class SegmentRegisters:
def __init__(self, regs: List[str]): def __init__(self, regs: List[str]):
self.regs = regs self.regs = regs
def context(self, rc): @override
def context(self, rc: RegisterContextProtocol) -> str | None:
return rc.segment_registers_context(self.regs) return rc.segment_registers_context(self.regs)
@ -122,7 +147,7 @@ class KernelRegisterSet:
def __init__( def __init__(
self, self,
segments: SegmentRegisters | None, segments: SegmentRegisters,
controls: Dict[str, BitFlags | AddressingRegister] = {}, controls: Dict[str, BitFlags | AddressingRegister] = {},
msrs: Dict[str, BitFlags | AddressingRegister] = {}, msrs: Dict[str, BitFlags | AddressingRegister] = {},
): ):
@ -262,7 +287,7 @@ class RegisterSet:
# Otherwise, the values will be clobbered # Otherwise, the values will be clobbered
# https://github.com/pwndbg/pwndbg/pull/2337 # https://github.com/pwndbg/pwndbg/pull/2337
self.emulated_regs_order: List[UnicornRegisterWrite] = [] self.emulated_regs_order: List[UnicornRegisterWrite] = []
# Avoid duplicates # Avoid duplicates
seen_emulated_register: set[str] = set() seen_emulated_register: set[str] = set()

Loading…
Cancel
Save