From 5a29ea00b542cacb90fc4fb5f67f499b30b5b85f Mon Sep 17 00:00:00 2001 From: OBarronCS <55004530+OBarronCS@users.noreply.github.com> Date: Fri, 12 Dec 2025 21:57:32 -0500 Subject: [PATCH] Add more type hints to register printing logic (#3490) * Clarify typing in register printing * Fix additional errors related to None * lint * Add comment --- pwndbg/commands/context.py | 41 +++++++++++++++---------- pwndbg/lib/regs.py | 63 ++++++++++++++++++++++++++------------ 2 files changed, 69 insertions(+), 35 deletions(-) diff --git a/pwndbg/commands/context.py b/pwndbg/commands/context.py index 9de6386a3..a2c66531c 100644 --- a/pwndbg/commands/context.py +++ b/pwndbg/commands/context.py @@ -18,6 +18,7 @@ from typing import TypeVar import unicorn as U from typing_extensions import ParamSpec +from typing_extensions import override import pwndbg import pwndbg.aglib.arch @@ -42,6 +43,9 @@ from pwndbg.color import ColorParamSpec from pwndbg.color import message from pwndbg.color import theme from pwndbg.commands import CommandCategory +from pwndbg.lib.regs import BitFlags +from pwndbg.lib.regs import RegisterContextProtocol +from pwndbg.lib.regs import VisitableRegister if pwndbg.dbg.is_gdblib_available(): import gdb @@ -823,7 +827,7 @@ def calculate_padding_to_align(length, align): return 0 if length % align == 0 else (align - (length % align)) -def compact_regs(regs, width=None, target=sys.stdout): +def compact_regs(regs: List[str], width=None, target=sys.stdout) -> List[str]: columns = max(0, int(pwndbg.config.show_compact_regs_columns)) min_width = max(1, int(pwndbg.config.show_compact_regs_min_width)) separation = max(1, int(pwndbg.config.show_compact_regs_separation)) @@ -844,7 +848,7 @@ def compact_regs(regs, width=None, target=sys.stdout): # => min_width = (window_width - (columns - 1) * separation) / columns min_width = max(min_width, (width - (columns - 1) * separation) // columns) - result = [] + result: List[str] = [] line = "" line_length = 0 @@ -931,13 +935,13 @@ pwndbg.config.add_param("show-flags", False, "whether to show flags registers") pwndbg.config.add_param("show-retaddr-reg", True, "whether to show return address register") -class RegisterContext: +class RegisterContext(RegisterContextProtocol): changed: List[str] def __init__(self): self.changed = pwndbg.aglib.regs.changed - def get_prefix(self, reg): + def get_prefix(self, reg: str) -> str: # Make the register stand out and give a color if changed regname = C.register(reg.ljust(4).upper()) if reg in self.changed: @@ -952,14 +956,15 @@ class RegisterContext: ) return f"{m}{regname}" - def get_register_value(self, reg): + def get_register_value(self, reg: str) -> int | None: val = pwndbg.aglib.regs.read_reg(reg) if val is None: print(message.warn(f"Unknown register: {reg!r}")) return None return val - def flag_register_context(self, reg, bit_flags): + @override + def flag_register_context(self, reg: str, bit_flags: BitFlags) -> str | None: val = self.get_register_value(reg) if val is None: return None @@ -967,7 +972,8 @@ class RegisterContext: prefix = self.get_prefix(reg) return f"{prefix} {desc}" - def segment_registers_context(self, regs): + @override + def segment_registers_context(self, regs: list[str]) -> str | None: result = "" for reg in regs: val = self.get_register_value(reg) @@ -977,7 +983,8 @@ class RegisterContext: result += f"{prefix} {hex(val)} " return result - def addressing_register_context(self, reg, is_virtual): + @override + def addressing_register_context(self, reg: str, is_virtual: bool) -> str | None: if is_virtual: return self.register_context_default(reg) val = self.get_register_value(reg) @@ -993,7 +1000,7 @@ class RegisterContext: pass return f"{prefix} {desc}" - def register_context_default(self, reg): + def register_context_default(self, reg: str) -> str | None: val = self.get_register_value(reg) if val is None: return None @@ -1003,14 +1010,15 @@ class RegisterContext: return f"{prefix} {desc}" -def get_regs(regs: List[str] = None): - regs: List[Any] = regs - result = [] +def get_regs(in_regs: List[str | VisitableRegister | None] | None = None): + # Python default parameters are instantiated once and shared across calls. + # Instead of a default value of [], we need to do this check so we get a fresh list each time + if in_regs is None: + in_regs = [] + regs: List[str | VisitableRegister | None] = in_regs + result: List[str] = [] rc = RegisterContext() - if regs is None: - regs = [] - if len(regs) == 0: regs += pwndbg.aglib.regs.gpr @@ -1046,11 +1054,12 @@ def get_regs(regs: List[str] = None): for reg in regs: if reg is None: continue + # If it's a VisitableRegister which has special logic to determine what to print if not isinstance(reg, str): desc = reg.context(rc) if desc is not None: result.append(desc) - continue + continue desc = rc.register_context_default(reg) if desc is not None: result.append(desc) diff --git a/pwndbg/lib/regs.py b/pwndbg/lib/regs.py index 88c68458b..fcb09731d 100644 --- a/pwndbg/lib/regs.py +++ b/pwndbg/lib/regs.py @@ -12,15 +12,39 @@ from typing import Dict from typing import Iterator from typing import List from typing import OrderedDict +from typing import Protocol from typing import Set from typing import Tuple from typing import Union +from typing_extensions import override + import pwndbg.lib.disasm.helpers as bit_math from pwndbg.lib.arch import PWNDBG_SUPPORTED_ARCHITECTURES_TYPE +# The printing logic for registers uses the Visitor Pattern +# An implementation of RegisterContextProtocol is defined outside of this class +# (this is a lib/ file, so it shouldn't directly be able to access the process) +# +# Instances of VisitableRegister will call the methods of RegisterContextProtocol to do their logic. + +class RegisterContextProtocol(Protocol): + def flag_register_context(self, reg: str, bit_flags: BitFlags) -> str | None: + ... + + def addressing_register_context(self, reg: str, is_virtual: bool) -> str | None: + ... + + def segment_registers_context(self,regs: list[str]) -> str | None: + ... -class BitFlags: + +# Represents a register or a set of registers that can be printed in the context register view +class VisitableRegister(Protocol): + def context(self, rc: RegisterContextProtocol) -> str | None: + ... + +class BitFlags(VisitableRegister): # this is intentionally uninitialized -- arm uses the same self.flags structuture for different registers # for example # - aarch64_cpsr_flags is used for "cpsr", "spsr_el1", "spsr_el2", "spsr_el3" @@ -29,29 +53,27 @@ class BitFlags: flags: OrderedDict[str, Union[int, Tuple[int, int]]] value: int - def __init__(self, flags: List[Tuple[str, Union[int, Tuple[int, int]]]] = [], value=None): + def __init__(self, flags: List[Tuple[str, Union[int, Tuple[int, int]]]] = []): self.regname = "" - self.flags = {} + self.flags = OrderedDict() for name, bits in flags: self.flags[name] = bits - self.value = value + self.value = 0 - def __getattr__(self, name): - if name in {"regname"}: - return self.__dict__[name] + def __getattr__(self, name: str): return getattr(self.flags, name) - def __getitem__(self, key): + def __getitem__(self, key: str) -> int: r = self.flags[key] if isinstance(r, int): return (self.value >> r) & 1 s, e = r return ((~((1 << s) - 1) & ((1 << (e + 1)) - 1)) & self.value) >> s - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: int) -> None: self.flags[key] = value - def __delitem__(self, key): + def __delitem__(self, key: str): del self.flags[key] def __iter__(self): @@ -63,14 +85,15 @@ class BitFlags: def __repr__(self): return f"BitFlags({self.flags})" - def update(self, regname: str): + def update(self, regname: str) -> None: self.regname = regname - def context(self, rc): + @override + def context(self, rc: RegisterContextProtocol) -> str | None: return rc.flag_register_context(self.regname, self) -class AddressingRegister: +class AddressingRegister(VisitableRegister): """ Represents a register that is used to store an address, e.g. cr3, gsbase, fsbase """ @@ -84,14 +107,15 @@ class AddressingRegister: self.value = 0 self.is_virtual = is_virtual - def update(self, regname: str): + def update(self, regname: str) -> None: pass - def context(self, rc): + @override + def context(self, rc: RegisterContextProtocol) -> str | None: return rc.addressing_register_context(self.reg, self.is_virtual) -class SegmentRegisters: +class SegmentRegisters(VisitableRegister): """ Represents the x86 segment register set """ @@ -101,7 +125,8 @@ class SegmentRegisters: def __init__(self, regs: List[str]): self.regs = regs - def context(self, rc): + @override + def context(self, rc: RegisterContextProtocol) -> str | None: return rc.segment_registers_context(self.regs) @@ -122,7 +147,7 @@ class KernelRegisterSet: def __init__( self, - segments: SegmentRegisters | None, + segments: SegmentRegisters, controls: Dict[str, BitFlags | AddressingRegister] = {}, msrs: Dict[str, BitFlags | AddressingRegister] = {}, ): @@ -262,7 +287,7 @@ class RegisterSet: # Otherwise, the values will be clobbered # https://github.com/pwndbg/pwndbg/pull/2337 self.emulated_regs_order: List[UnicornRegisterWrite] = [] - + # Avoid duplicates seen_emulated_register: set[str] = set()