You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pwndbg/pwndbg/commands/kmem_trace.py

258 lines
9.7 KiB
Python

from __future__ import annotations
import argparse
import threading
import pwndbg.aglib.regs
import pwndbg.aglib.symbol
import pwndbg.arguments
import pwndbg.color as C
import pwndbg.color.message as M
import pwndbg.commands
from pwndbg.dbg import BreakpointLocation
# Command-line interface of the `kmem_trace` command.
# When neither -s nor -b is given, the command enables both allocator tracers.
parser = argparse.ArgumentParser(
    description="""
Tracing kernel memory (SLUB and buddy) allocations and frees.
Unless --all is specified, only the allocations and frees triggered by the current function (until it returns) will be printed.
The --all option may be helpful if you also want to trace frees scheduled with rcu or if the traced command steps out of the current function.
"""
)
parser.add_argument("-s", "--trace-slab", action="store_true", help="enable slab allocator tracing")
parser.add_argument(
    "-b", "--trace-buddy", action="store_true", help="enable buddy allocator tracing"
)
parser.add_argument("-v", "--verbose", action="store_true", help="print backtraces")
# The traced command is handed to the debugger verbatim; it defaults to a single `n`.
parser.add_argument(
    "-c", "--command", type=str, default="n", help="command to be traced (e.g. `n`, `nextret`)"
)
parser.add_argument(
    "--all",
    action="store_true",
    help="display ALL memory allocations/frees regardless if they are triggered by the current function.",
)
class KmemTracepointsData:
    """Accumulates the formatted output lines of one tracing run.

    Attributes:
        results: output lines, in the order they were recorded.
        order: scratch slot for the page order captured when a page
            allocation function is entered (written by the breakpoint
            handlers, read when the allocation returns).
        mutex: re-entrant lock guarding ``results``.
        verbose: when true, the backtrace is appended after each result.
        curr: function name used to filter results, or ``None`` to record
            everything.
    """

    def __init__(self, verbose, trace_all):
        self.results = []
        self.order = None
        self.mutex = threading.RLock()
        self.verbose = verbose
        self.curr = None  # None means tracing all
        if not trace_all:
            # Resolve the pc of the selected frame's parent to "symbol+offset"
            # and keep only the symbol part; results are later filtered on it.
            # (original note: "current frame only accounting for jumps")
            pc = pwndbg.dbg.selected_frame().parent().pc()
            self.curr = pwndbg.aglib.symbol.resolve_addr(pc).split("+")[0]

    def add_result(self, result: str):
        """Record ``result`` if it passes the function filter.

        Empty results are ignored. Otherwise the current backtrace is fetched
        and the result is kept when no filter is set or when any backtrace
        line contains ``self.curr``. In verbose mode the backtrace lines are
        appended right after the result.
        """
        if not result:
            return
        with self.mutex:
            bt = pwndbg.commands.context.context_backtrace(False)
            if not self.curr or any(self.curr in line for line in bt):
                self.results.append(result)
                if self.verbose:
                    self.results += bt

    def _format_kmem_tracepoint_output(self, prefix, name, type, addr):
        """Build one coloured output line: ``<prefix> <name> <type> @ <addr>``.

        Prefixes containing "FREE" are coloured red, all others green; the
        columns are left-justified to 12, 16 and 4 characters respectively.
        """
        prefix = prefix.ljust(12, " ")
        if "FREE" in prefix:
            prefix = C.red(prefix)
        else:
            prefix = C.green(prefix)
        name = C.blue(name.ljust(16, " "))
        type = type.ljust(4, " ")
        return f"{prefix} {name} {type} @ {C.blue(hex(addr))}"

    def format_slab_kmem_tracepoint_output(self, is_free: bool, objaddr: int):
        """Record a slab allocation or free of the object at ``objaddr``.

        A zero address is ignored. If the containing slab cache cannot be
        determined, a warning line is recorded instead (unfiltered).
        """
        if objaddr == 0:
            return
        prefix = "[SLAB FREE]" if is_free else "[SLAB ALLOC]"
        try:
            cache = pwndbg.aglib.kernel.slab.find_containing_slab_cache(objaddr)
            name = cache.name
        except Exception:
            # Best effort: any lookup failure is reported as an invalid object.
            # The address is printed in hex like every other output line, and
            # the shared list is only touched while holding the lock.
            with self.mutex:
                self.results.append(
                    M.warn(f"{prefix} invalid SLUB object @ {hex(objaddr)}")
                )
            return
        result = self._format_kmem_tracepoint_output(prefix, name, "obj", objaddr)
        self.add_result(result)

    def format_page_kmem_tracepoint_output(self, is_free: bool, page: int, order: int):
        """Record a buddy allocation or free of ``page`` with the given order.

        The line also shows the address returned by ``page_to_virt(page)``,
        labelled "physmap".
        """
        prefix = "[PAGE FREE]" if is_free else "[PAGE ALLOC]"
        name = f"order-{order}"
        physmap = pwndbg.aglib.kernel.page_to_virt(page)
        result = self._format_kmem_tracepoint_output(prefix, name, "page", page)
        result += f" (physmap: {C.red(hex(physmap))})"
        self.add_result(result)
class KmemTracepoints:
    """Manages the internal breakpoints used to trace kernel allocators.

    On construction the addresses of the known slab/buddy allocation and free
    functions are resolved. ``register_breakpoints`` installs one internal
    breakpoint per resolved address and ``remove_breakpoints`` removes them
    again. The handlers forward what they observe to ``self.data`` (a
    ``KmemTracepointsData``), which holds the formatted results.
    """

    def __init__(self):
        # try to capture the lowest possible level of exported functions in the (de)alloc chain
        # for example __alloc_pages_bulk calls __alloc_pages and only __alloc_pages is included
        # lists might not be complete
        # try to resolve all names, if does not exist, means it is not exported for that version
        kmalloc_names = (  # (tries to) include all slab alloc functions for all v5.x and v6.x
            "__kmalloc",
            "__kmalloc_node",
            "__kmalloc_node_track_caller",
            "__kmalloc_track_caller",
            "__krealloc",
            "kmalloc_order",
            "kmalloc_order_trace",
            "kmem_cache_alloc",
            "kmem_cache_alloc_node",
            "kmem_cache_alloc_node_trace",
            "kmem_cache_alloc_trace",
            "kmem_cache_alloc_lru",
            "krealloc",
            "kmalloc_node_trace",
            "kmalloc_trace",
            "__kmalloc_node_noprof",
            "__kmalloc_noprof",
            "kmalloc_node_trace_noprof",
            "kmalloc_node_track_caller_noprof",
            "kmalloc_trace_noprof",
            "kmem_cache_alloc_lru_noprof",
            "kmem_cache_alloc_node_noprof",
            "kmem_cache_alloc_noprof",
            "krealloc_noprof",
            "__kmalloc_node_track_caller_noprof",
            "__kmalloc_cache_node_noprof",
            "__kmalloc_cache_noprof",
        )
        self.kallocs = self.resolve_names(kmalloc_names)
        kfree_names = ("kfree",)
        self.kfrees = self.resolve_names(kfree_names)
        palloc_names = (  # all of those functions have the 2nd arg == order
            "__alloc_frozen_pages_noprof",
            "__alloc_pages",
            "__alloc_pages_nodemask",
            "alloc_pages_noprof",
        )
        self.pallocs = self.resolve_names(palloc_names)
        pfree_names = (  # page *, order
            "__free_pages",
        )
        self.pfrees = self.resolve_names(pfree_names)
        self.sps = []  # stop points currently installed by register_breakpoints
        self.data = None  # KmemTracepointsData of the current tracing run
        self.slab_tracepoints_enabled = True
        self.buddy_tracepoints_enabled = True

    def resolve_names(self, names):
        """Return the addresses of the symbols in ``names`` that can be resolved.

        Names for which the symbol lookup returns ``None`` are skipped, so the
        result may be shorter than ``names``.
        """
        result = []
        for name in names:
            addr = pwndbg.aglib.symbol.lookup_symbol_addr(name)
            if addr is not None:
                result.append(addr)
        return result

    @staticmethod
    def _kalloc_handler() -> bool:
        """Return-time handler of a slab allocation: record the returned object."""
        self = get_kmem_tracepoints()
        objaddr = pwndbg.aglib.regs.read_reg_uncached(pwndbg.aglib.regs.retval)
        self.data.format_slab_kmem_tracepoint_output(False, objaddr)
        return False

    @staticmethod
    def kalloc_handler(sp: pwndbg.dbg_mod.StopPoint) -> bool:
        """Entry handler of a slab allocation function.

        The allocated address is only known on return, so the actual recording
        is deferred to ``_kalloc_handler`` via ``trace_ret``.
        """
        pwndbg.dbg.selected_inferior().trace_ret(KmemTracepoints._kalloc_handler, True)
        return False

    @staticmethod
    def kfree_handler(sp: pwndbg.dbg_mod.StopPoint) -> bool:
        """Entry handler of ``kfree``: record the object passed as argument 0."""
        self = get_kmem_tracepoints()
        objaddr = pwndbg.arguments.argument(0)
        self.data.format_slab_kmem_tracepoint_output(True, objaddr)
        return False

    @staticmethod
    def _palloc_handler() -> bool:
        """Return-time handler of a page allocation.

        Records the returned page together with the order that
        ``palloc_handler`` stored in ``self.data.order`` on entry.
        """
        self = get_kmem_tracepoints()
        page = pwndbg.aglib.regs.read_reg_uncached(pwndbg.aglib.regs.retval)
        order = self.data.order
        self.data.format_page_kmem_tracepoint_output(False, page, order)
        return False

    @staticmethod
    def palloc_handler(sp: pwndbg.dbg_mod.StopPoint) -> bool:
        """Entry handler of a page allocation function.

        Saves the order (argument 1) for the return-time handler and schedules
        that handler via ``trace_ret``.
        """
        self = get_kmem_tracepoints()
        order = pwndbg.arguments.argument(1)
        pwndbg.dbg.selected_inferior().trace_ret(KmemTracepoints._palloc_handler, True)
        self.data.order = order
        return False

    @staticmethod
    def pfree_handler(sp: pwndbg.dbg_mod.StopPoint) -> bool:
        """Entry handler of ``__free_pages``: record page (arg 0) and order (arg 1)."""
        self = get_kmem_tracepoints()
        page = pwndbg.arguments.argument(0)
        order = pwndbg.arguments.argument(1)
        # The formatter takes (is_free, page, order). The previous call passed
        # an extra leading `self.results` argument, which raised TypeError on
        # every traced page free.
        self.data.format_page_kmem_tracepoint_output(True, page, order)
        return False

    def _install(self, inf, addrs, handler):
        """Install one internal breakpoint per address and remember its stop point."""
        for addr in addrs:
            sp = inf.break_at(BreakpointLocation(addr), handler, internal=True)
            self.sps.append(sp)

    def register_breakpoints(self, verbose, trace_all):
        """Start a new tracing run and install the enabled tracepoints.

        A fresh ``KmemTracepointsData`` is created from ``verbose`` and
        ``trace_all``; slab and buddy breakpoints are installed according to
        ``slab_tracepoints_enabled`` / ``buddy_tracepoints_enabled``.
        """
        # Not read by this class (results live in self.data.results); kept so
        # the attribute still exists for any external reader.
        self.results = []
        inf = pwndbg.dbg.selected_inferior()
        self.data = KmemTracepointsData(verbose, trace_all)
        if self.slab_tracepoints_enabled:
            self._install(inf, self.kallocs, KmemTracepoints.kalloc_handler)
            self._install(inf, self.kfrees, KmemTracepoints.kfree_handler)
        if self.buddy_tracepoints_enabled:
            self._install(inf, self.pallocs, KmemTracepoints.palloc_handler)
            self._install(inf, self.pfrees, KmemTracepoints.pfree_handler)

    def remove_breakpoints(self):
        """Remove every installed stop point and re-enable both tracers."""
        for sp in self.sps:
            sp.remove()
        self.sps = []
        self.slab_tracepoints_enabled = True
        self.buddy_tracepoints_enabled = True
@pwndbg.lib.cache.cache_until("objfile")
def get_kmem_tracepoints():
    """Return the `KmemTracepoints` instance shared by the command and its handlers.

    The result is cached by the `cache_until("objfile")` decorator, so symbol
    resolution in the constructor is not repeated on every call.
    """
    tracepoints = KmemTracepoints()
    return tracepoints
@pwndbg.commands.Command(parser, category=pwndbg.commands.CommandCategory.KERNEL)
@pwndbg.commands.OnlyWhenQemuKernel
@pwndbg.commands.OnlyWithKernelDebugSymbols
@pwndbg.commands.OnlyWhenPagingEnabled
def kmem_trace(trace_slab: bool, trace_buddy: bool, verbose: bool, command: str, all: bool):
    # Entry point of the `kmem_trace` command: install the tracepoints, run
    # `command` under them, then print everything that was recorded.
    tps = get_kmem_tracepoints()
    # With neither -s nor -b given, trace both allocators.
    if not trace_slab and not trace_buddy:
        trace_slab = trace_buddy = True
    tps.slab_tracepoints_enabled = trace_slab
    tps.buddy_tracepoints_enabled = trace_buddy
    try:
        tps.register_breakpoints(verbose, all)
        print(M.success("Finished registering tracepoints."))
        old_val = pwndbg.config.context_backtrace_lines.value
        pwndbg.config.context_backtrace_lines.value = 1000  # enable full backtrace
        try:
            pwndbg.dbg.selected_inferior().runcmd(command)
        finally:
            # Restore the user's setting even if the traced command fails.
            pwndbg.config.context_backtrace_lines.value = old_val
        pwndbg.commands.context.context()
    finally:
        # Never leave internal breakpoints behind, including ones installed
        # before a failure part-way through registration.
        tps.remove_breakpoints()
    print("\n".join(tps.data.results))
    # NOTE(review): presumably suppresses the next automatic context display,
    # since the context was already printed above — confirm.
    pwndbg.dbg.ctx_suspend_once()