Rework cancellation in pwndbg-lldb (#3387)

pull/3408/head^2
Matt. 4 weeks ago committed by GitHub
parent b843b3f158
commit 5369d4149a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -7,6 +7,7 @@ import random
import re
import shlex
import sys
from asyncio import CancelledError
from contextlib import contextmanager
from typing import Any
from typing import Awaitable
@ -1866,6 +1867,9 @@ class LLDB(pwndbg.dbg_mod.Debugger):
# Relay used for exceptions originating from commands called through LLDB.
_exception_relay: BaseException | None
in_lldb_command_handler: bool
"Whether an LLDB command handler is currently running"
# temporarily suspend context output
should_suspend_ctx: bool
@ -1880,6 +1884,7 @@ class LLDB(pwndbg.dbg_mod.Debugger):
self.controllers = []
self._current_process_is_gdb_remote = False
self._exception_relay = None
self.in_lldb_command_handler = False
self.should_suspend_ctx = False
import pwndbg
@ -1917,6 +1922,10 @@ class LLDB(pwndbg.dbg_mod.Debugger):
e = self._exception_relay
self._exception_relay = None
if e is not None:
if isinstance(e, CancelledError):
# Cancellations are meaningful to the CLI, raise them unchanged.
raise e
raise pwndbg.dbg_mod.Error(e)
def _execute_lldb_command(self, command: str) -> str:
@ -1950,13 +1959,15 @@ class LLDB(pwndbg.dbg_mod.Debugger):
def __call__(self, _, command, exe_context, result):
try:
debugger.exec_states.append(exe_context)
debugger.in_lldb_command_handler = True
handler(debugger, command, True)
except BaseException as e:
debugger._exception_relay = e
finally:
debugger.in_lldb_command_handler = False
assert (
debugger.exec_states.pop() == exe_context
), "Execution state mismatch on command handler"
except BaseException as e:
debugger._exception_relay = e
raise
# LLDB is very particular with the object paths it will accept. It is at
# its happiest when its pulling objects straight off the module that was

@ -44,6 +44,7 @@ import shutil
import signal
import sys
import threading
from asyncio import CancelledError
from contextlib import contextmanager
from io import BytesIO
from io import TextIOBase
@ -85,6 +86,17 @@ else:
from pwndbg.dbg.lldb.repl.readline import wrap_with_history
class UserCancelledError(CancelledError):
"""
Internal cancellation exception used by the LLDB CLI.
Sometimes, it's necessary to cancel both commands and subroutines mid-execution. This is an internal exception type
that is used in these conditions.
"""
pass
def print_error(msg: str, *args):
"""
Print an error message in the style of the LLDB CLI.
@ -325,10 +337,7 @@ def run(
# Set ourselves up to respond to SIGINT by interrupting the process if it is
# running, and doing nothing otherwise.
def handle_sigint(_sig, _frame):
driver.cancel()
if driver.has_process():
driver.interrupt()
print()
driver.interrupt(in_lldb_command_handler=dbg.in_lldb_command_handler)
signal.signal(signal.SIGINT, handle_sigint)
@ -340,6 +349,7 @@ def run(
last_exc: Exception | None = None
while True:
try:
# Execute the prompt hook.
dbg._fire_prompt_hook()
@ -413,6 +423,8 @@ def run(
if not should_continue:
last_exc = asyncio.CancelledError()
continue
except UserCancelledError as e:
last_exc = e
def exec_repl_command(

@ -1,5 +1,6 @@
from __future__ import annotations
import contextlib
import sys
from asyncio import CancelledError
from typing import Any
@ -13,6 +14,7 @@ import lldb
import pwndbg
from pwndbg.dbg.lldb import YieldContinue
from pwndbg.dbg.lldb import YieldSingleStep
from pwndbg.dbg.lldb.repl import UserCancelledError
from pwndbg.dbg.lldb.repl import print_info
from pwndbg.dbg.lldb.repl.io import IODriver
from pwndbg.dbg.lldb.repl.io import IODriverPlainText
@ -147,6 +149,37 @@ class LaunchResultError(LaunchResult):
self.disconnected = disconnected
def _updates_scope_counter(target: str) -> Callable[[Callable[..., Any]], Any]:
"""
ProcessDriver makes use of scope counters as part of the decision process
for how and when cancellations should be raised. This decorator
automatically updates a given scope counter.
"""
def sub0(fn: Callable[..., Any]) -> Callable[..., Any]:
def sub1(self: ProcessDriver, *args, **kwargs):
setattr(self, target, getattr(self, target) + 1)
try:
if self.debug:
print(
f"[*] ProcessDriver: self.{target} += 1 ({getattr(self, target)})",
file=sys.__stdout__,
)
return fn(self, *args, **kwargs)
finally:
setattr(self, target, getattr(self, target) - 1)
if self.debug:
print(
f"[*] ProcessDriver: self.{target} -= 1 ({getattr(self, target)})",
file=sys.__stdout__,
)
return sub1
return sub0
class ProcessDriver:
"""
Drives the execution of a process, responding to its events and handling its
@ -158,7 +191,18 @@ class ProcessDriver:
listener: lldb.SBListener
debug: bool
eh: EventHandler
cancellation_requested: bool
_hold_cancellation: int
"Nested scope counter for cancellation suspension"
_pending_cancellation: bool
"Whether we should fire a cancellation request once we resume cancellations"
_in_run_until_next_stop: int
"Nested scope counter for _run_until_next_stop"
_in_run_coroutine: int
"Nested scope counter for run_coroutine"
def __init__(self, event_handler: EventHandler, debug=False):
self.io = None
@ -166,7 +210,10 @@ class ProcessDriver:
self.listener = None
self.debug = debug
self.eh = event_handler
self.cancellation_requested = False
self._hold_cancellation = 0
self._pending_cancellation = False
self._in_run_until_next_stop = 0
self._in_run_coroutine = 0
def has_process(self) -> bool:
"""
@ -183,34 +230,81 @@ class ProcessDriver:
"""
return self.process is not None
def cancel(self) -> None:
def interrupt(self, in_lldb_command_handler: bool = False) -> None:
"""
Request that a currently ongoing operation be cancelled.
Interrupts the currently running process or command.
"""
self.cancellation_requested = True
if not self.has_process():
return
def _should_cancel(self) -> bool:
"""
Checks whether a cancellation has been requested, and clears cancellation state.
"""
should = self.cancellation_requested
self._clear_cancel()
if self._in_run_until_next_stop > 0:
if self.debug:
print("[*] ProcessDriver: Sending Interrupt", file=sys.__stdout__)
return should
# If we're in a coroutine, we should tell it to stop as soon as it gets out of _run_until_next_stop()
self._pending_cancellation = self._in_run_coroutine > 0
def _clear_cancel(self) -> None:
"""
Clears cancellation state.
"""
self.cancellation_requested = False
self.process.SendAsyncInterrupt()
elif self._hold_cancellation > 0 and not in_lldb_command_handler:
if self.debug:
print("[*] ProcessDriver: Pushing pending cancellation", file=sys.__stdout__)
def interrupt(self) -> None:
self._pending_cancellation = True
else:
if self.debug:
print(
"[*] ProcessDriver: Requesting cancellation immediately",
file=sys.__stdout__,
end="",
)
if self._hold_cancellation > 0:
print(" (forced by being in a command handler)", file=sys.__stdout__)
else:
print(file=sys.__stdout__)
# This happens even if interrupts are suspended, if we're inside a
# command handler. We shouldn't interrupt LLDB until it starts
# executing our command handler, but we still want to be able to
# interrupt the handler if it's a particularly long-running command
# like `search`.
self._pending_cancellation = False
raise UserCancelledError("user-requested cancellation")
@contextlib.contextmanager
def suspend_interrupts(self, interrupt: Callable[[], None] | None = None):
"""
Interrupts the currently running process.
Sometimes it's necessary to guard against interruption by
self.interrupt, especially when being interrupted would lead to bad
process state.
"""
assert self.has_process(), "called interrupt() on a driver with no process"
self.process.SendAsyncInterrupt()
self._hold_cancellation += 1
try:
if self.debug:
print(
"[*] ProcessDriver: Temporarily suspending cancellations", file=sys.__stdout__
)
yield None
finally:
if self.debug:
print("[*] ProcessDriver: Resuming cancellations", file=sys.__stdout__)
self._hold_cancellation -= 1
if self._hold_cancellation == 0 and self._pending_cancellation:
if self.debug:
print(
"[*] ProcessDriver: Executing pending cancellation", file=sys.__stdout__
)
self._pending_cancellation = False
if interrupt is None:
# The default action is to raise an exception in place.
raise UserCancelledError("user-requested cancellation")
# Perform the called-provided action.
interrupt()
@_updates_scope_counter(target="_in_run_until_next_stop")
def _run_until_next_stop(
self,
with_io: bool = True,
@ -359,8 +453,11 @@ class ProcessDriver:
"""
assert self.has_process(), "called run_lldb_command() on a driver with no process"
with self.suspend_interrupts():
ret = lldb.SBCommandReturnObject()
self.process.GetTarget().GetDebugger().GetCommandInterpreter().HandleCommand(command, ret)
self.process.GetTarget().GetDebugger().GetCommandInterpreter().HandleCommand(
command, ret
)
if ret.IsValid():
# LLDB can give us strings that may fail to encode.
@ -414,22 +511,33 @@ class ProcessDriver:
process in this driver. Returns `True` if the coroutine ran to completion,
and `False` if it was cancelled.
"""
try:
return self._run_coroutine(coroutine)
except CancelledError:
# We got cancelled somewhere else.
return False
@_updates_scope_counter("_in_run_coroutine")
def _run_coroutine(self, coroutine: Coroutine[Any, Any, None]) -> bool:
"""
This loop may be spuriously cancelled. We handle that in run_coroutine().
"""
assert self.has_process(), "called run_coroutine() on a driver with no process"
exception: Exception | None = None
self._clear_cancel()
while True:
if self._should_cancel():
# We were requested to cancel the execution controller.
exception = CancelledError()
exceptions: List[BaseException] = []
def queue_cancel():
exceptions.append(UserCancelledError())
while True:
try:
if exception is None:
if len(exceptions) == 0:
step = coroutine.send(None)
else:
step = coroutine.throw(exception)
step = coroutine.throw(exceptions[-1])
# The coroutine has caught the exception. Continue running
# it as if nothing happened.
exception = None
exceptions.pop()
except StopIteration:
# We got to the end of the coroutine. We're done.
break
@ -438,6 +546,9 @@ class ProcessDriver:
# override our decision. We're done.
break
# Being interrupted here would be bad for keeping the state of the
# process consistent.
with self.suspend_interrupts(interrupt=queue_cancel):
if isinstance(step, YieldSingleStep):
# Pick the currently selected thread and step it forward by one
# instruction.
@ -454,15 +565,15 @@ class ProcessDriver:
# The step failed. Raise an error in the coroutine and give
# it a chance to recover gracefully before we propagate it
# up to the caller.
exception = pwndbg.dbg_mod.Error(
f"Could not perform single step: {e.description}"
exceptions.append(
pwndbg.dbg_mod.Error(f"Could not perform single step: {e.description}")
)
continue
status = self._run_until_next_stop()
if isinstance(status, _PollResultExited):
# The process exited. Cancel the execution controller.
exception = CancelledError()
exceptions.append(CancelledError())
continue
elif isinstance(step, YieldContinue):
@ -488,10 +599,10 @@ class ProcessDriver:
match status:
case _PollResultExited():
# The process exited, Cancel the execution controller.
exception = CancelledError()
exceptions.append(CancelledError())
continue
case _PollResultStopped(event):
event = event
pass
case _:
raise AssertionError(f"unexpected poll result {status}")
@ -515,8 +626,9 @@ class ProcessDriver:
stop, lldb.SBBreakpoint
):
bpwp_id = thread.GetStopReasonDataAtIndex(0)
elif thread.GetStopReason() == lldb.eStopReasonWatchpoint and isinstance(
stop, lldb.SBWatchpoint
elif (
thread.GetStopReason() == lldb.eStopReasonWatchpoint
and isinstance(stop, lldb.SBWatchpoint)
):
bpwp_id = thread.GetStopReasonDataAtIndex(0)
@ -532,16 +644,18 @@ class ProcessDriver:
# Something else that we weren't expecting caused the
# process to stop. Request that the coroutine be
# cancelled.
exception = CancelledError()
exceptions.append(CancelledError())
else:
# The process might've crashed, been terminated, exited, or
# we might've lost connection to it for some other reason.
# Regardless, we should cancel the coroutine.
exception = CancelledError()
exceptions.append(CancelledError())
# Let the caller distinguish between a coroutine that's been run to
# completion and one that got cancelled.
return not isinstance(exception, CancelledError)
if len(exceptions) > 0:
return not isinstance(exceptions[-1], CancelledError)
return True
def _prepare_listener_for(self, target: lldb.SBTarget):
"""

@ -189,6 +189,7 @@ def main() -> None:
def drive(startup: List[str] | None):
async def drive(c):
from pwndbg.dbg.lldb.repl import PwndbgController
from pwndbg.dbg.lldb.repl import UserCancelledError
assert isinstance(c, PwndbgController)
@ -197,7 +198,10 @@ def main() -> None:
await c.execute(line)
while True:
try:
await c.interactive()
except UserCancelledError:
print("^C")
return drive

Loading…
Cancel
Save