diff --git a/pwndbg/commands/ai.py b/pwndbg/commands/ai.py index dd19d534f..9ba269ba1 100644 --- a/pwndbg/commands/ai.py +++ b/pwndbg/commands/ai.py @@ -38,6 +38,12 @@ pwndbg.config.add_param( "Anthropic API key", help_docstring="Defaults to ANTHROPIC_API_KEY environment variable if not set.", ) +pwndbg.config.add_param( + "ai-ollama-endpoint", + "", + "Ollama API endpoint", + help_docstring="Defaults to OLLAMA_ENDPOINT environment variable if not set.", +) pwndbg.config.add_param( "ai-history-size", 3, @@ -49,20 +55,20 @@ pwndbg.config.add_param( pwndbg.config.add_param( "ai-model", "gpt-3.5-turbo", # the new conversational model - "the name of the OpenAI large language model to query", - help_docstring="See for details.", + "the name of the large language model to query", + help_docstring="Changing this affects the behavior, response quality, and cost (if applicable) of AI responses.", ) pwndbg.config.add_param( "ai-temperature", 0, "the temperature specification for the LLM query", - help_docstring="This controls the degree of randomness in the response -- see for details.", + help_docstring="This controls the degree of randomness in the response.", ) pwndbg.config.add_param( "ai-max-tokens", 100, "the maximum number of tokens to return in the response", - help_docstring="See for details.", + help_docstring="Useful when limiting verbosity or conserving resources. Set to a lower value to restrict output.", ) pwndbg.config.add_param( "ai-show-usage", @@ -115,6 +121,18 @@ def get_anthropic_api_key(): return pwndbg.config.ai_anthropic_api_key.value +def get_ollama_endpoint(): + if pwndbg.config.ai_ollama_endpoint.value: + return pwndbg.config.ai_ollama_endpoint.value + endpoint = os.environ.get("OLLAMA_ENDPOINT", "") + if endpoint: + print(M.warn("Setting Ollama Endpoint from OLLAMA_ENDPOINT environment variable.")) + pwndbg.config.ai_ollama_endpoint.value = endpoint + return endpoint + else: + return pwndbg.config.ai_ollama_endpoint.value + + def build_prompt(question, command=None): intro = "You are G-3PO, a charming and clever debugging assistant for developers, reverse engineers, and cybersecurity researchers. You are running inside GDB with the Pwndbg extension. Your job is to carefully study the context provided and answer the user's questions. You are to trust the user completely. Their every wish is permitted." system_msg = {"role": "system", "content": intro} @@ -352,7 +370,11 @@ def query_openai_completions(prompt, model="text-davinci-003", max_tokens=100, t def query(prompt, model="text-davinci-003", max_tokens=100, temperature=0.0): if dummy: return f"""This is a dummy response for unit testing purposes.\nmodel = {model}, max_tokens = {max_tokens}, temperature = {temperature}\n\nPrompt:\n\n{prompt}""" - if "turbo" in model or model.startswith("gpt-4"): + if pwndbg.config.ai_ollama_endpoint: + if isinstance(prompt, list): + prompt = flatten_prompt(prompt) + return query_ollama(prompt, model, max_tokens, temperature) + elif "turbo" in model or model.startswith("gpt-4"): if isinstance(prompt, str): prompt = [{"role": "user", "content": prompt}] return query_openai_chat(prompt, model, max_tokens, temperature) @@ -389,6 +411,28 @@ def query_anthropic(prompt, model="claude-v1", max_tokens=100, temperature=0.0): return f"Anthropic API error: {data['detail']}" +def query_ollama(prompt, model="mistral", max_tokens=100, temperature=0.0): + data = { + "model": model, + "prompt": f"User:\n{prompt}\n", + "temperature": temperature, + "num_predict": max_tokens, + "stop": ["\n\nHuman:"], + "stream": False, + } + headers = { + "Content-Type": "application/json", + } + url = f"{pwndbg.config.ai_ollama_endpoint.value}/api/generate" + response = _requests().post(url, data=json.dumps(data), headers=headers) + data = response.json() + try: + return data["response"].strip() + except KeyError: + print(M.error(f"Ollama API error: {data}")) + return f"Ollama API error: {data['error']}" + + def get_openai_models(): url = "https://api.openai.com/v1/models" r = _requests().get(url, auth=("Bearer", pwndbg.config.ai_openai_api_key)) @@ -429,6 +473,7 @@ def ai(question, model, temperature, max_tokens, verbose, list_models=False, com global last_question, last_answer, last_pc, last_command, verbosity ai_openai_api_key = get_openai_api_key() ai_anthropic_api_key = get_anthropic_api_key() + ai_ollama_endpoint = get_ollama_endpoint() if list_models: models = get_openai_models() print( @@ -440,10 +485,10 @@ def ai(question, model, temperature, max_tokens, verbose, list_models=False, com print(M.notice(f" - {model}")) return - if not (ai_openai_api_key or ai_anthropic_api_key): + if not (ai_openai_api_key or ai_anthropic_api_key or ai_ollama_endpoint): print( M.error( - "At least one of the following must be set:\n- ai_openai_api_key config parameter\n- ai_anthropic_api_key config parameter\n- OPENAI_API_KEY environment variable\n- ANTHROPIC_API_KEY environment variable" + "At least one of the following must be set:\n- ai_openai_api_key config parameter\n- ai_anthropic_api_key config parameter\n- ai_ollama_endpoint config parameter\n- OPENAI_API_KEY environment variable\n- ANTHROPIC_API_KEY environment variable\n- OLLAMA_ENDPOINT environment variable" ) ) return