diff --git a/.github/helper/py_presubmit.py b/.github/helper/py_presubmit.py
index 9665388424..ba99f23e98 100755
--- a/.github/helper/py_presubmit.py
+++ b/.github/helper/py_presubmit.py
@@ -22,78 +22,80 @@
 def do_checks(changed_files):
-  """Runs all presubmit checks. Returns False if any fails."""
-  checks = [
-      check_license,
-  ]
-  return all([check(changed_files) for check in checks])
+    """Runs all presubmit checks. Returns False if any fails."""
+    checks = [
+        check_license,
+    ]
+    return all([check(changed_files) for check in checks])
-_CHECK_LICENSE_FILENAMES = ['Dockerfile']
+_CHECK_LICENSE_FILENAMES = ["Dockerfile"]
 _CHECK_LICENSE_EXTENSIONS = [
-    '.bash',
-    '.Dockerfile',
-    '.go',
-    '.h',
-    '.htm',
-    '.html',
-    '.proto',
-    '.py',
-    '.rs',
-    '.sh',
-    '.ts',
+    ".bash",
+    ".Dockerfile",
+    ".go",
+    ".h",
+    ".htm",
+    ".html",
+    ".proto",
+    ".py",
+    ".rs",
+    ".sh",
+    ".ts",
 ]
-THIRD_PARTY_DIR_NAME = 'third_party'
+THIRD_PARTY_DIR_NAME = "third_party"
-_LICENSE_STRING = 'http://www.apache.org/licenses/LICENSE-2.0'
+_LICENSE_STRING = "http://www.apache.org/licenses/LICENSE-2.0"
 def check_license(paths):
-  """Validates license header."""
-  if not paths:
-    return True
-
-  success = True
-  for path in paths:
-    path_parts = str(path).split(os.sep)
-    if any(path_part == THIRD_PARTY_DIR_NAME for path_part in path_parts):
-      continue
-    filename = os.path.basename(path)
-    extension = os.path.splitext(path)[1]
-    if (filename not in _CHECK_LICENSE_FILENAMES and
-        extension not in _CHECK_LICENSE_EXTENSIONS):
-      continue
-
-    with open(path) as file_handle:
-      if _LICENSE_STRING not in file_handle.read():
-        print('Missing license header in file %s.' % str(path))
-        success = False
-
-  return success
+    """Validates license header."""
+    if not paths:
+        return True
+
+    success = True
+    for path in paths:
+        path_parts = str(path).split(os.sep)
+        if any(path_part == THIRD_PARTY_DIR_NAME for path_part in path_parts):
+            continue
+        filename = os.path.basename(path)
+        extension = os.path.splitext(path)[1]
+        if (
+            filename not in _CHECK_LICENSE_FILENAMES
+            and extension not in _CHECK_LICENSE_EXTENSIONS
+        ):
+            continue
+
+        with open(path) as file_handle:
+            if _LICENSE_STRING not in file_handle.read():
+                print("Missing license header in file %s." % str(path))
+                success = False
+
+    return success
 def bool_to_returncode(success):
-  """Returns 0 if |success|. Otherwise returns 1."""
-  if success:
-    print('Success.')
-    return 0
+    """Returns 0 if |success|. Otherwise returns 1."""
+    if success:
+        print("Success.")
+        return 0
-  print('Failed.')
-  return 1
+    print("Failed.")
+    return 1
 def get_all_files():
-  """Returns a list of absolute paths of files in this repo."""
-  get_all_files_command = ['git', 'ls-files']
-  output = subprocess.check_output(get_all_files_command).decode().splitlines()
-  return [os.path.abspath(path) for path in output if os.path.isfile(path)]
+    """Returns a list of absolute paths of files in this repo."""
+    get_all_files_command = ["git", "ls-files"]
+    output = subprocess.check_output(get_all_files_command).decode().splitlines()
+    return [os.path.abspath(path) for path in output if os.path.isfile(path)]
 def main():
-  relevant_files = get_all_files()
-  success = do_checks(relevant_files)
-  return bool_to_returncode(success)
+    relevant_files = get_all_files()
+    success = do_checks(relevant_files)
+    return bool_to_returncode(success)
-if __name__ == '__main__':
-  sys.exit(main())
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000..910fdcde95
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,44 @@
+repos:
+  - repo: https://github.com/psf/black.git
+    rev: "24.10.0"
+    hooks:
+      - id: black
+
+  - repo: https://github.com/nbQA-dev/nbQA
+    rev: 1.9.0
+    hooks:
+      - id: nbqa-black
+        name: nbqa-black
+        description: Run 'black' on a Jupyter Notebook
+        entry: nbqa black
+        language: python
+        require_serial: true
+        types_or: [jupyter, markdown]
+        additional_dependencies: [black]
+
+  - repo: https://github.com/pycqa/isort
+    rev: "5.13.2"
+    hooks:
+      - id: isort
+        entry: isort
+        args:
+          - --profile=black
+          - --float-to-top
+
+  - repo: https://github.com/nbQA-dev/nbQA
+    rev: 1.9.0
+    hooks:
+      - id: nbqa-flake8
+        args:
+          - --ignore=E501,E712,W291,F632,E203,F821,F403,W391,F401
+          - --exclude=.*,__init__.py
+        name: nbqa-flake8
+        description: Run 'flake8' on a Jupyter Notebook
+        entry: nbqa flake8
+        language: python
+        require_serial: true
+        types_or: [jupyter, markdown]
+        additional_dependencies:
+          - flake8-variables-names
+          - pep8-naming
+          - flake8-functions-names
diff --git a/agent/analyzer.py b/agent/analyzer.py
index 8f5e32ce36..6765872c47 100644
--- a/agent/analyzer.py
+++ b/agent/analyzer.py
@@ -16,4 +16,4 @@
 class Analyzer(BaseAgent):
-  pass
+    pass
diff --git a/agent/base_agent.py b/agent/base_agent.py
index a5b7475e9d..159da314eb 100644
--- a/agent/base_agent.py
+++ b/agent/base_agent.py
@@ -34,267 +34,301 @@
 class BaseAgent(ABC):
-  """The abstract base class for LLM agents in stages."""
-
-  def __init__(self,
-               trial: int,
-               llm: LLM,
-               args: argparse.Namespace,
-               tools: Optional[list[BaseTool]] = None,
-               name: str = ''):
-    self.trial: int = trial
-    self.llm: LLM = llm
-    self.tools: list[BaseTool] = tools or []
-    self.args = args
-    self.name: str = name or self.__class__.__name__
-    self.chat_history: str = ''  # Communication history between LLM and tool.
- self.max_round = self.args.max_round - - def __repr__(self) -> str: - return self.__class__.__name__ - - def get_tool(self, tool_name: str) -> Optional[BaseTool]: - """Gets a tool of the agent by name.""" - for tool in self.tools: - if tool.name == tool_name: - return tool - return None - - def chat_llm_with_tools(self, client: Any, prompt: Optional[Prompt], tools, - trial) -> Any: - """Chat with LLM with tools.""" - logger.info( - '%s', - trial, - prompt.gettext() if prompt else '', - trial, - trial=trial) - response = self.llm.chat_llm_with_tools(client=client, - prompt=prompt, - tools=tools) - logger.info( - '%s', - trial, - response, - trial, - trial=trial) - return response - - def chat_llm(self, cur_round: int, client: Any, prompt: Prompt, - trial: int) -> str: - """Chat with LLM.""" - logger.info('%s', - cur_round, - prompt.gettext(), - cur_round, - trial=trial) - response = self.llm.chat_llm(client=client, prompt=prompt) - logger.info('%s', - cur_round, - response, - cur_round, - trial=trial) - return response - - def ask_llm(self, cur_round: int, prompt: Prompt, trial: int) -> str: - """Ask LLM.""" - logger.info('%s', - cur_round, - prompt.gettext(), - cur_round, - trial=trial) - response = self.llm.ask_llm(prompt=prompt) - logger.info('%s', - cur_round, - response, - cur_round, - trial=trial) - return response - - def _parse_tag(self, response: str, tag: str) -> str: - """Parses the XML-style tags from LLM response.""" - match = re.search(rf'<{tag}>(.*?)', response, re.DOTALL) - return match.group(1).strip() if match else '' - - def _parse_tags(self, response: str, tag: str) -> list[str]: - """Parses the XML-style tags from LLM response.""" - matches = re.findall(rf'<{tag}>(.*?)', response, re.DOTALL) - return [content.strip() for content in matches] - - def _filter_code(self, raw_code_block: str) -> str: - """Filters out irrelevant lines from |raw_code_block|.""" - # TODO(dongge): Move this function to a separate module. - # Remove markdown-style code block symbols. - filtered_lines = [ - line for line in raw_code_block.splitlines() - if not line.strip().startswith('```') - ] - # Sometimes LLM returns a build script containing only comments. 
- if all(line.strip().startswith('#') for line in filtered_lines): - return '' - filtered_code_block = '\n'.join(filtered_lines) - return filtered_code_block - - def _format_bash_execution_result( - self, - process: sp.CompletedProcess, - previous_prompt: Optional[Prompt] = None) -> str: - """Formats a prompt based on bash execution result.""" - if previous_prompt: - previous_prompt_text = previous_prompt.gettext() - else: - previous_prompt_text = '' - stdout = self.llm.truncate_prompt(process.stdout, - previous_prompt_text).strip() - stderr = self.llm.truncate_prompt(process.stderr, - stdout + previous_prompt_text).strip() - return (f'\n{process.args}\n\n' - f'\n{process.returncode}\n\n' - f'\n{stdout}\n\n' - f'\n{stderr}\n\n') - - def _container_handle_bash_command(self, response: str, tool: BaseTool, - prompt: Prompt) -> Prompt: - """Handles the command from LLM with container |tool|.""" - prompt_text = '' - for command in self._parse_tags(response, 'bash'): - prompt_text += self._format_bash_execution_result( - tool.execute(command), previous_prompt=prompt) + '\n' - prompt.append(prompt_text) - return prompt - - def _container_handle_invalid_tool_usage(self, tool: BaseTool, cur_round: int, - response: str, - prompt: Prompt) -> Prompt: - """Formats a prompt to re-teach LLM how to use the |tool|.""" - logger.warning('ROUND %02d Invalid response from LLM: %s', - cur_round, - response, - trial=self.trial) - prompt_text = (f'No valid instruction received, Please follow the ' - f'interaction protocols:\n{tool.tutorial()}') - prompt.append(prompt_text) - return prompt - - def _container_handle_bash_commands(self, response: str, tool: BaseTool, - prompt: Prompt) -> Prompt: - """Handles the command from LLM with container |tool|.""" - prompt_text = '' - for command in self._parse_tags(response, 'bash'): - prompt_text += self._format_bash_execution_result( - tool.execute(command), previous_prompt=prompt) + '\n' - prompt.append(prompt_text) - return prompt - - def _sleep_random_duration( - self, - trial: int, - min_sec: int = 1, - max_sec: int = 60, - ) -> None: - """Sleeps for a random duration between min_sec and max_sec. 
Agents uses - this to avoid exceeding quota limit (e.g., LLM query frequency).""" - duration = random.randint(min_sec, max_sec) - logger.debug('Sleeping for %d before the next query', duration, trial=trial) - time.sleep(duration) - - @classmethod - def _parse_args(cls) -> argparse.Namespace: - """Parses command line args.""" - parser = argparse.ArgumentParser( - description='Execute agent in cloud with dill files.') - parser.add_argument('-a', - '--agent', - help='The dill file path for the agent to execute.') - parser.add_argument( - '-rh', - '--result-history', - help='The dill file path for the agent input result history.') - parser.add_argument( - '-rn', - '--result-new', - help='The dill file path to store the agent output new result.') - return parser.parse_args() - - @classmethod - def _preprocess_fi_setup(cls) -> None: - """Logic for starting a custom Fuzz Introspector used on cloud builds""" - logger.info('Checkign if we should use local FI', trial=0) - if not os.path.isdir('/workspace/data-dir'): - logger.info('This does not require a local FI.', trial=0) - return - logger.info('We should use local FI.', trial=0) - - # Clone Fuzz Introspector - introspector_repo = 'https://github.com/ossf/fuzz-introspector' - introspector_dst = '/workspace/fuzz-introspector' - sp.check_call(f'git clone {introspector_repo} {introspector_dst}', - shell=True) - fi_web_dir = '/workspace/fuzz-introspector/tools/web-fuzzing-introspection' - # Install reqs - sp.check_call( - 'python3.11 -m pip install --ignore-installed -r requirements.txt', - cwd=fi_web_dir, - shell=True) - - # Copy over the DB - shutil.rmtree(os.path.join(fi_web_dir, 'app/static/assets/db/')) - shutil.copytree('/workspace/data-dir/fuzz_introspector_db', - os.path.join(fi_web_dir, 'app/static/assets/db/')) - - # Launch webapp - fi_environ = os.environ - fi_environ['FUZZ_INTROSPECTOR_SHUTDOWN'] = '1' - fi_environ[ - 'FUZZ_INTROSPECTOR_LOCAL_OSS_FUZZ'] = '/workspace/data-dir/oss-fuzz2' - sp.check_call('python3.11 main.py >> /dev/null &', - shell=True, - env=fi_environ, - cwd=os.path.join(fi_web_dir, 'app')) - - logger.info('Waiting for the webapp to start', trial=0) - - sec_to_wait = 10 - max_wait_iterations = 10 - for idx in range(max_wait_iterations): - time.sleep(sec_to_wait) - - resp = requests.get('http://127.0.0.1:8080', timeout=10) - if 'Fuzzing' in resp.text: - break - if idx == max_wait_iterations - 1: - # Launching FI failed. We can still continue, although context - # will be missing from runs. - logger.info('Failed to start webapp', trial=10) - - introspector.set_introspector_endpoints('http://127.0.0.1:8080/api') - - @classmethod - def cloud_main(cls) -> None: - """Executes agent using dill files. This is for cloud experiments launched - by cloud_builder.py. 
It runs `new_result = agent.execute(result_history)` in - the same way as local experiments, except `agent` and `result_history` are - deserialized from dill files and new_result will be serialized to share data - with the cloud experiment requester.""" - args = cls._parse_args() - - cls._preprocess_fi_setup() - - agent = utils.deserialize_from_dill(args.agent) - agent.llm.cloud_setup() - result_history = utils.deserialize_from_dill(args.result_history) - result = agent.execute(result_history) - utils.serialize_to_dill(result, args.result_new) - - @abstractmethod - def _initial_prompt(self, results: list[Result]) -> Prompt: - """The initial prompt of the agent.""" - - @abstractmethod - def execute(self, result_history: list[Result]) -> Result: - """Executes the agent based on previous result.""" + """The abstract base class for LLM agents in stages.""" + + def __init__( + self, + trial: int, + llm: LLM, + args: argparse.Namespace, + tools: Optional[list[BaseTool]] = None, + name: str = "", + ): + self.trial: int = trial + self.llm: LLM = llm + self.tools: list[BaseTool] = tools or [] + self.args = args + self.name: str = name or self.__class__.__name__ + self.chat_history: str = "" # Communication history between LLM and tool. + self.max_round = self.args.max_round + + def __repr__(self) -> str: + return self.__class__.__name__ + + def get_tool(self, tool_name: str) -> Optional[BaseTool]: + """Gets a tool of the agent by name.""" + for tool in self.tools: + if tool.name == tool_name: + return tool + return None + + def chat_llm_with_tools( + self, client: Any, prompt: Optional[Prompt], tools, trial + ) -> Any: + """Chat with LLM with tools.""" + logger.info( + "%s", + trial, + prompt.gettext() if prompt else "", + trial, + trial=trial, + ) + response = self.llm.chat_llm_with_tools( + client=client, prompt=prompt, tools=tools + ) + logger.info( + "%s", + trial, + response, + trial, + trial=trial, + ) + return response + + def chat_llm(self, cur_round: int, client: Any, prompt: Prompt, trial: int) -> str: + """Chat with LLM.""" + logger.info( + "%s", + cur_round, + prompt.gettext(), + cur_round, + trial=trial, + ) + response = self.llm.chat_llm(client=client, prompt=prompt) + logger.info( + "%s", + cur_round, + response, + cur_round, + trial=trial, + ) + return response + + def ask_llm(self, cur_round: int, prompt: Prompt, trial: int) -> str: + """Ask LLM.""" + logger.info( + "%s", + cur_round, + prompt.gettext(), + cur_round, + trial=trial, + ) + response = self.llm.ask_llm(prompt=prompt) + logger.info( + "%s", + cur_round, + response, + cur_round, + trial=trial, + ) + return response + + def _parse_tag(self, response: str, tag: str) -> str: + """Parses the XML-style tags from LLM response.""" + match = re.search(rf"<{tag}>(.*?)", response, re.DOTALL) + return match.group(1).strip() if match else "" + + def _parse_tags(self, response: str, tag: str) -> list[str]: + """Parses the XML-style tags from LLM response.""" + matches = re.findall(rf"<{tag}>(.*?)", response, re.DOTALL) + return [content.strip() for content in matches] + + def _filter_code(self, raw_code_block: str) -> str: + """Filters out irrelevant lines from |raw_code_block|.""" + # TODO(dongge): Move this function to a separate module. + # Remove markdown-style code block symbols. + filtered_lines = [ + line + for line in raw_code_block.splitlines() + if not line.strip().startswith("```") + ] + # Sometimes LLM returns a build script containing only comments. 
+ if all(line.strip().startswith("#") for line in filtered_lines): + return "" + filtered_code_block = "\n".join(filtered_lines) + return filtered_code_block + + def _format_bash_execution_result( + self, process: sp.CompletedProcess, previous_prompt: Optional[Prompt] = None + ) -> str: + """Formats a prompt based on bash execution result.""" + if previous_prompt: + previous_prompt_text = previous_prompt.gettext() + else: + previous_prompt_text = "" + stdout = self.llm.truncate_prompt(process.stdout, previous_prompt_text).strip() + stderr = self.llm.truncate_prompt( + process.stderr, stdout + previous_prompt_text + ).strip() + return ( + f"\n{process.args}\n\n" + f"\n{process.returncode}\n\n" + f"\n{stdout}\n\n" + f"\n{stderr}\n\n" + ) + + def _container_handle_bash_command( + self, response: str, tool: BaseTool, prompt: Prompt + ) -> Prompt: + """Handles the command from LLM with container |tool|.""" + prompt_text = "" + for command in self._parse_tags(response, "bash"): + prompt_text += ( + self._format_bash_execution_result( + tool.execute(command), previous_prompt=prompt + ) + + "\n" + ) + prompt.append(prompt_text) + return prompt + + def _container_handle_invalid_tool_usage( + self, tool: BaseTool, cur_round: int, response: str, prompt: Prompt + ) -> Prompt: + """Formats a prompt to re-teach LLM how to use the |tool|.""" + logger.warning( + "ROUND %02d Invalid response from LLM: %s", + cur_round, + response, + trial=self.trial, + ) + prompt_text = ( + f"No valid instruction received, Please follow the " + f"interaction protocols:\n{tool.tutorial()}" + ) + prompt.append(prompt_text) + return prompt + + def _container_handle_bash_commands( + self, response: str, tool: BaseTool, prompt: Prompt + ) -> Prompt: + """Handles the command from LLM with container |tool|.""" + prompt_text = "" + for command in self._parse_tags(response, "bash"): + prompt_text += ( + self._format_bash_execution_result( + tool.execute(command), previous_prompt=prompt + ) + + "\n" + ) + prompt.append(prompt_text) + return prompt + + def _sleep_random_duration( + self, + trial: int, + min_sec: int = 1, + max_sec: int = 60, + ) -> None: + """Sleeps for a random duration between min_sec and max_sec. Agents uses + this to avoid exceeding quota limit (e.g., LLM query frequency).""" + duration = random.randint(min_sec, max_sec) + logger.debug("Sleeping for %d before the next query", duration, trial=trial) + time.sleep(duration) + + @classmethod + def _parse_args(cls) -> argparse.Namespace: + """Parses command line args.""" + parser = argparse.ArgumentParser( + description="Execute agent in cloud with dill files." + ) + parser.add_argument( + "-a", "--agent", help="The dill file path for the agent to execute." 
+ ) + parser.add_argument( + "-rh", + "--result-history", + help="The dill file path for the agent input result history.", + ) + parser.add_argument( + "-rn", + "--result-new", + help="The dill file path to store the agent output new result.", + ) + return parser.parse_args() + + @classmethod + def _preprocess_fi_setup(cls) -> None: + """Logic for starting a custom Fuzz Introspector used on cloud builds""" + logger.info("Checkign if we should use local FI", trial=0) + if not os.path.isdir("/workspace/data-dir"): + logger.info("This does not require a local FI.", trial=0) + return + logger.info("We should use local FI.", trial=0) + + # Clone Fuzz Introspector + introspector_repo = "https://github.com/ossf/fuzz-introspector" + introspector_dst = "/workspace/fuzz-introspector" + sp.check_call(f"git clone {introspector_repo} {introspector_dst}", shell=True) + fi_web_dir = "/workspace/fuzz-introspector/tools/web-fuzzing-introspection" + # Install reqs + sp.check_call( + "python3.11 -m pip install --ignore-installed -r requirements.txt", + cwd=fi_web_dir, + shell=True, + ) + + # Copy over the DB + shutil.rmtree(os.path.join(fi_web_dir, "app/static/assets/db/")) + shutil.copytree( + "/workspace/data-dir/fuzz_introspector_db", + os.path.join(fi_web_dir, "app/static/assets/db/"), + ) + + # Launch webapp + fi_environ = os.environ + fi_environ["FUZZ_INTROSPECTOR_SHUTDOWN"] = "1" + fi_environ["FUZZ_INTROSPECTOR_LOCAL_OSS_FUZZ"] = "/workspace/data-dir/oss-fuzz2" + sp.check_call( + "python3.11 main.py >> /dev/null &", + shell=True, + env=fi_environ, + cwd=os.path.join(fi_web_dir, "app"), + ) + + logger.info("Waiting for the webapp to start", trial=0) + + sec_to_wait = 10 + max_wait_iterations = 10 + for idx in range(max_wait_iterations): + time.sleep(sec_to_wait) + + resp = requests.get("http://127.0.0.1:8080", timeout=10) + if "Fuzzing" in resp.text: + break + if idx == max_wait_iterations - 1: + # Launching FI failed. We can still continue, although context + # will be missing from runs. + logger.info("Failed to start webapp", trial=10) + + introspector.set_introspector_endpoints("http://127.0.0.1:8080/api") + + @classmethod + def cloud_main(cls) -> None: + """Executes agent using dill files. This is for cloud experiments launched + by cloud_builder.py. It runs `new_result = agent.execute(result_history)` in + the same way as local experiments, except `agent` and `result_history` are + deserialized from dill files and new_result will be serialized to share data + with the cloud experiment requester.""" + args = cls._parse_args() + + cls._preprocess_fi_setup() + + agent = utils.deserialize_from_dill(args.agent) + agent.llm.cloud_setup() + result_history = utils.deserialize_from_dill(args.result_history) + result = agent.execute(result_history) + utils.serialize_to_dill(result, args.result_new) + + @abstractmethod + def _initial_prompt(self, results: list[Result]) -> Prompt: + """The initial prompt of the agent.""" + + @abstractmethod + def execute(self, result_history: list[Result]) -> Result: + """Executes the agent based on previous result.""" if __name__ == "__main__": - # For cloud experiments. - BaseAgent.cloud_main() + # For cloud experiments. 
+ BaseAgent.cloud_main() diff --git a/agent/coverage_analyzer.py b/agent/coverage_analyzer.py index 13ee991244..fd9da89173 100644 --- a/agent/coverage_analyzer.py +++ b/agent/coverage_analyzer.py @@ -26,120 +26,141 @@ from results import AnalysisResult, CoverageResult, Result, RunResult from tool.container_tool import ProjectContainerTool -INVALID_PRMOT_PATH = os.path.join('prompts', 'agent', - 'coverage-analyzer-invalid-response.txt') +INVALID_PRMOT_PATH = os.path.join( + "prompts", "agent", "coverage-analyzer-invalid-response.txt" +) class CoverageAnalyzer(BaseAgent): - """The Agent to refine a compilable fuzz target for higher coverage.""" - - def _initial_prompt(self, results: list[Result]) -> Prompt: - """Constructs initial prompt of the agent.""" - last_result = results[-1] - benchmark = last_result.benchmark - - if not isinstance(last_result, RunResult): - logger.error('The last result in %s is not RunResult: %s', - self.name, - results, - trial=self.trial) - return Prompt() - - builder = CoverageAnalyzerTemplateBuilder(self.llm, benchmark, last_result) - prompt = builder.build(example_pair=[], - tool_guides=self.inspect_tool.tutorial(), - project_dir=self.inspect_tool.project_dir) - # TODO: A different file name/dir. - prompt.save(self.args.work_dirs.prompt) - - return prompt - - def _container_handle_conclusion(self, cur_round: int, response: str, - coverage_result: CoverageResult, - prompt: Prompt) -> Optional[Prompt]: - """Runs a compilation tool to validate the new fuzz target and build script - from LLM.""" - conclusion = self._parse_tag(response, 'conclusion') - if not conclusion: - return prompt - logger.info('----- ROUND %02d Received conclusion -----', - cur_round, - trial=self.trial) - - coverage_result.improve_required = conclusion.strip().lower() == 'true' - coverage_result.insight = self._parse_tag(response, 'insights') - coverage_result.suggestions = self._parse_tag(response, 'suggestions') - - return None - - def _container_tool_reaction( - self, cur_round: int, response: str, run_result: RunResult, - coverage_result: CoverageResult) -> Optional[Prompt]: - """Validates LLM conclusion or executes its command.""" - del run_result - prompt = prompt_builder.DefaultTemplateBuilder(self.llm, None).build([]) - - prompt = self._container_handle_bash_commands(response, self.inspect_tool, - prompt) - # Only report conclusion when no more bash investigation is required. - if not prompt.gettext(): - # Then build fuzz target. - prompt = self._container_handle_conclusion(cur_round, response, - coverage_result, prompt) - if prompt is None: - # Succeeded. + """The Agent to refine a compilable fuzz target for higher coverage.""" + + def _initial_prompt(self, results: list[Result]) -> Prompt: + """Constructs initial prompt of the agent.""" + last_result = results[-1] + benchmark = last_result.benchmark + + if not isinstance(last_result, RunResult): + logger.error( + "The last result in %s is not RunResult: %s", + self.name, + results, + trial=self.trial, + ) + return Prompt() + + builder = CoverageAnalyzerTemplateBuilder(self.llm, benchmark, last_result) + prompt = builder.build( + example_pair=[], + tool_guides=self.inspect_tool.tutorial(), + project_dir=self.inspect_tool.project_dir, + ) + # TODO: A different file name/dir. 
+ prompt.save(self.args.work_dirs.prompt) + + return prompt + + def _container_handle_conclusion( + self, + cur_round: int, + response: str, + coverage_result: CoverageResult, + prompt: Prompt, + ) -> Optional[Prompt]: + """Runs a compilation tool to validate the new fuzz target and build script + from LLM.""" + conclusion = self._parse_tag(response, "conclusion") + if not conclusion: + return prompt + logger.info( + "----- ROUND %02d Received conclusion -----", cur_round, trial=self.trial + ) + + coverage_result.improve_required = conclusion.strip().lower() == "true" + coverage_result.insight = self._parse_tag(response, "insights") + coverage_result.suggestions = self._parse_tag(response, "suggestions") + return None - # Finally check invalid responses. - if not response or not prompt.get(): - prompt = self._container_handle_invalid_tool_usage( - self.inspect_tool, cur_round, response, prompt) - with open(INVALID_PRMOT_PATH, 'r') as prompt_file: - prompt.append(prompt_file.read()) - - return prompt - - def execute(self, result_history: list[Result]) -> AnalysisResult: - """Executes the agent to analyze the root cause to the low coverage.""" - WorkDirs(self.args.work_dirs.base, keep=True) - last_result = result_history[-1] - assert isinstance(last_result, RunResult) - - logger.info('Executing %s', self.name, trial=last_result.trial) - benchmark = last_result.benchmark - # TODO(dongge): Use the generated fuzz target and build script here. - self.inspect_tool = ProjectContainerTool(benchmark, name='inspect') - self.inspect_tool.write_to_file(content=last_result.fuzz_target_source, - file_path=benchmark.target_path) - if last_result.build_script_source: - self.inspect_tool.write_to_file( - content=last_result.build_script_source, - file_path=self.inspect_tool.build_script_path) - self.inspect_tool.compile(extra_commands=' && rm -rf /out/* > /dev/null') - cur_round = 1 - coverage_result = CoverageResult() - prompt = self._initial_prompt(result_history) - - try: - client = self.llm.get_chat_client(model=self.llm.get_model()) - while prompt and cur_round < self.max_round: - response = self.chat_llm(cur_round, - client=client, - prompt=prompt, - trial=last_result.trial) - prompt = self._container_tool_reaction(cur_round, response, last_result, - coverage_result) - cur_round += 1 - finally: - # Cleanup: stop and remove the container - logger.debug('Stopping and removing the inspect container %s', - self.inspect_tool.container_id, - trial=last_result.trial) - self.inspect_tool.terminate() - - analysis_result = AnalysisResult( - author=self, - run_result=last_result, - coverage_result=coverage_result, - chat_history={self.name: coverage_result.to_dict()}) - return analysis_result + def _container_tool_reaction( + self, + cur_round: int, + response: str, + run_result: RunResult, + coverage_result: CoverageResult, + ) -> Optional[Prompt]: + """Validates LLM conclusion or executes its command.""" + del run_result + prompt = prompt_builder.DefaultTemplateBuilder(self.llm, None).build([]) + + prompt = self._container_handle_bash_commands( + response, self.inspect_tool, prompt + ) + # Only report conclusion when no more bash investigation is required. + if not prompt.gettext(): + # Then build fuzz target. + prompt = self._container_handle_conclusion( + cur_round, response, coverage_result, prompt + ) + if prompt is None: + # Succeeded. + return None + + # Finally check invalid responses. 
+ if not response or not prompt.get(): + prompt = self._container_handle_invalid_tool_usage( + self.inspect_tool, cur_round, response, prompt + ) + with open(INVALID_PRMOT_PATH, "r") as prompt_file: + prompt.append(prompt_file.read()) + + return prompt + + def execute(self, result_history: list[Result]) -> AnalysisResult: + """Executes the agent to analyze the root cause to the low coverage.""" + WorkDirs(self.args.work_dirs.base, keep=True) + last_result = result_history[-1] + assert isinstance(last_result, RunResult) + + logger.info("Executing %s", self.name, trial=last_result.trial) + benchmark = last_result.benchmark + # TODO(dongge): Use the generated fuzz target and build script here. + self.inspect_tool = ProjectContainerTool(benchmark, name="inspect") + self.inspect_tool.write_to_file( + content=last_result.fuzz_target_source, file_path=benchmark.target_path + ) + if last_result.build_script_source: + self.inspect_tool.write_to_file( + content=last_result.build_script_source, + file_path=self.inspect_tool.build_script_path, + ) + self.inspect_tool.compile(extra_commands=" && rm -rf /out/* > /dev/null") + cur_round = 1 + coverage_result = CoverageResult() + prompt = self._initial_prompt(result_history) + + try: + client = self.llm.get_chat_client(model=self.llm.get_model()) + while prompt and cur_round < self.max_round: + response = self.chat_llm( + cur_round, client=client, prompt=prompt, trial=last_result.trial + ) + prompt = self._container_tool_reaction( + cur_round, response, last_result, coverage_result + ) + cur_round += 1 + finally: + # Cleanup: stop and remove the container + logger.debug( + "Stopping and removing the inspect container %s", + self.inspect_tool.container_id, + trial=last_result.trial, + ) + self.inspect_tool.terminate() + + analysis_result = AnalysisResult( + author=self, + run_result=last_result, + coverage_result=coverage_result, + chat_history={self.name: coverage_result.to_dict()}, + ) + return analysis_result diff --git a/agent/crash_analyzer.py b/agent/crash_analyzer.py index 8437bbac7a..37888ec230 100644 --- a/agent/crash_analyzer.py +++ b/agent/crash_analyzer.py @@ -36,185 +36,221 @@ class CrashAnalyzer(BaseAgent): - """The Agent to analyze a runtime crash and provide insight to fuzz target.""" - - def __init__(self, - trial: int, - llm: LLM, - args: argparse.Namespace, - tools: Optional[list[BaseTool]] = None, - name: str = '', - artifact_path: str = '') -> None: - super().__init__(trial, llm, args, tools, name) - self.artifact_path = artifact_path - - def _initial_prompt(self, results: list[Result]) -> Prompt: - """Constructs initial prompt of the agent.""" - last_result = results[-1] - - if isinstance(last_result, RunResult): - crash_analyzer_prompt_builder = \ - prompt_builder.CrashAnalyzerTemplateBuilder( - model=self.llm, - benchmark=last_result.benchmark) - prompt = crash_analyzer_prompt_builder.build_crash_analyzer_prompt( - last_result.benchmark, last_result.fuzz_target_source, - last_result.run_error, last_result.crash_func) - return prompt - - logger.error("Expected a RunResult object in results list", - trial=self.trial) - return prompt_builder.CrashAnalyzerTemplateBuilder(self.llm).build([]) - - def _format_lldb_execution_result( - self, - lldb_command: str, - process: sp.CompletedProcess, - previous_prompt: Optional[Prompt] = None) -> str: - """Formats a prompt based on lldb execution result.""" - if previous_prompt: - previous_prompt_text = previous_prompt.get() - else: - previous_prompt_text = '' - stdout = 
self.llm.truncate_prompt(process.stdout, - previous_prompt_text).strip() - stderr = self.llm.truncate_prompt(process.stderr, - stdout + previous_prompt_text).strip() - return (f'\n{lldb_command.strip()}\n\n' - f'\n{stdout}\n\n' - f'\n{stderr}\n\n') - - def _container_handle_lldb_command(self, response: str, tool: LLDBTool, - prompt: Prompt) -> Prompt: - """Handles the command from LLM with lldb tool.""" - prompt_text = '' - for command in self._parse_tags(response, 'lldb'): - process = tool.execute_in_screen(command) - prompt_text += self._format_lldb_execution_result( - command, process, previous_prompt=prompt) + '\n' - prompt.append(prompt_text) - return prompt - - def _container_handle_conclusion(self, cur_round: int, response: str, - crash_result: CrashResult) -> None: - """Parses LLM conclusion, analysis and suggestion.""" - logger.info('----- ROUND %02d Received conclusion -----', + """The Agent to analyze a runtime crash and provide insight to fuzz target.""" + + def __init__( + self, + trial: int, + llm: LLM, + args: argparse.Namespace, + tools: Optional[list[BaseTool]] = None, + name: str = "", + artifact_path: str = "", + ) -> None: + super().__init__(trial, llm, args, tools, name) + self.artifact_path = artifact_path + + def _initial_prompt(self, results: list[Result]) -> Prompt: + """Constructs initial prompt of the agent.""" + last_result = results[-1] + + if isinstance(last_result, RunResult): + crash_analyzer_prompt_builder = prompt_builder.CrashAnalyzerTemplateBuilder( + model=self.llm, benchmark=last_result.benchmark + ) + prompt = crash_analyzer_prompt_builder.build_crash_analyzer_prompt( + last_result.benchmark, + last_result.fuzz_target_source, + last_result.run_error, + last_result.crash_func, + ) + return prompt + + logger.error("Expected a RunResult object in results list", trial=self.trial) + return prompt_builder.CrashAnalyzerTemplateBuilder(self.llm).build([]) + + def _format_lldb_execution_result( + self, + lldb_command: str, + process: sp.CompletedProcess, + previous_prompt: Optional[Prompt] = None, + ) -> str: + """Formats a prompt based on lldb execution result.""" + if previous_prompt: + previous_prompt_text = previous_prompt.get() + else: + previous_prompt_text = "" + stdout = self.llm.truncate_prompt(process.stdout, previous_prompt_text).strip() + stderr = self.llm.truncate_prompt( + process.stderr, stdout + previous_prompt_text + ).strip() + return ( + f"\n{lldb_command.strip()}\n\n" + f"\n{stdout}\n\n" + f"\n{stderr}\n\n" + ) + + def _container_handle_lldb_command( + self, response: str, tool: LLDBTool, prompt: Prompt + ) -> Prompt: + """Handles the command from LLM with lldb tool.""" + prompt_text = "" + for command in self._parse_tags(response, "lldb"): + process = tool.execute_in_screen(command) + prompt_text += ( + self._format_lldb_execution_result( + command, process, previous_prompt=prompt + ) + + "\n" + ) + prompt.append(prompt_text) + return prompt + + def _container_handle_conclusion( + self, cur_round: int, response: str, crash_result: CrashResult + ) -> None: + """Parses LLM conclusion, analysis and suggestion.""" + logger.info( + "----- ROUND %02d Received conclusion -----", cur_round, trial=self.trial + ) + + conclusion = self._parse_tag(response, "conclusion") + if conclusion == "Crash is caused by bug in fuzz driver.": + crash_result.true_bug = False + elif conclusion == "Crash is caused by bug in project.": + crash_result.true_bug = True + else: + logger.error( + "***** Failed to match conclusion in %02d rounds *****", + cur_round, + 
trial=self.trial, + ) + + crash_result.insight = self._parse_tag(response, "analysis and suggestion") + if not crash_result.insight: + logger.error( + "Round %02d No analysis and suggestion in conclusion: %s", cur_round, - trial=self.trial) - - conclusion = self._parse_tag(response, 'conclusion') - if conclusion == 'Crash is caused by bug in fuzz driver.': - crash_result.true_bug = False - elif conclusion == 'Crash is caused by bug in project.': - crash_result.true_bug = True - else: - logger.error('***** Failed to match conclusion in %02d rounds *****', - cur_round, - trial=self.trial) - - crash_result.insight = self._parse_tag(response, 'analysis and suggestion') - if not crash_result.insight: - logger.error('Round %02d No analysis and suggestion in conclusion: %s', - cur_round, - response, - trial=self.trial) - - def _container_tool_reaction(self, cur_round: int, response: str, - crash_result: CrashResult) -> Optional[Prompt]: - """Validates LLM conclusion or executes its command.""" - if self._parse_tag(response, 'conclusion'): - return self._container_handle_conclusion(cur_round, response, - crash_result) - prompt = prompt_builder.CrashAnalyzerTemplateBuilder(self.llm, - None).build([]) - if self._parse_tag(response, 'lldb'): - return self._container_handle_lldb_command(response, self.analyze_tool, - prompt) - if self._parse_tag(response, 'bash'): - return self._container_handle_bash_command(response, self.check_tool, - prompt) - return None - - def execute(self, result_history: list[Result]) -> AnalysisResult: - """Executes the agent based on previous run result.""" - WorkDirs(self.args.work_dirs.base, keep=True) - last_result = result_history[-1] - benchmark = last_result.benchmark - logger.info('Executing Crash Analyzer', trial=self.trial) - assert isinstance(last_result, RunResult) - - if not os.path.exists(last_result.artifact_path): - logger.error('Artifact path %s does not exist', - last_result.artifact_path, - trial=self.trial) - - # TODO(dongge): Move these to oss_fuzz_checkout. - generated_target_name = os.path.basename(benchmark.target_path) - sample_id = os.path.splitext(generated_target_name)[0] - generated_oss_fuzz_project = ( - f'{benchmark.id}-{sample_id}-lldb-{self.trial:02d}') - generated_oss_fuzz_project = oss_fuzz_checkout.rectify_docker_tag( - generated_oss_fuzz_project) - - # TODO(dongge): Write to OSS-Fuzz project dir files directly. 
- fuzz_target_path = os.path.join(last_result.work_dirs.fuzz_targets, - f'{self.trial:02d}.fuzz_target') - with open(fuzz_target_path, 'w') as ft_file: - ft_file.write(last_result.fuzz_target_source) - if last_result.build_script_source: - build_script_path = os.path.join(last_result.work_dirs.fuzz_targets, - f'{self.trial:02d}.build_script') - with open(build_script_path, 'w') as ft_file: - ft_file.write(last_result.build_script_source) - else: - build_script_path = '' - - evaluator_lib.Evaluator.create_ossfuzz_project_with_lldb( - benchmark, generated_oss_fuzz_project, fuzz_target_path, last_result, - build_script_path, last_result.artifact_path) - - self.analyze_tool = LLDBTool(benchmark, - result=last_result, - name='lldb', - project_name=generated_oss_fuzz_project) - self.analyze_tool.execute('compile > /dev/null') - # Launch LLDB and load fuzz target binary - self.analyze_tool.execute(f'screen -dmS lldb_session -L ' - f'-Logfile /tmp/lldb_log.txt ' - f'lldb /out/{last_result.benchmark.target_name}') - self.check_tool = ProjectContainerTool( - benchmark, name='check', project_name=generated_oss_fuzz_project) - self.check_tool.compile(extra_commands=' && rm -rf /out/* > /dev/null') - prompt = self._initial_prompt(result_history) - prompt.add_problem(self.analyze_tool.tutorial()) - prompt.add_problem(self.check_tool.tutorial()) - crash_result = CrashResult(benchmark=benchmark, - trial=last_result.trial, - work_dirs=last_result.work_dirs, - author=self, - chat_history={self.name: ''}) - cur_round = 1 - try: - client = self.llm.get_chat_client(model=self.llm.get_model()) - while prompt and cur_round < MAX_ROUND: - response = self.chat_llm(cur_round=cur_round, - client=client, - prompt=prompt, - trial=self.trial) - prompt = self._container_tool_reaction(cur_round, response, - crash_result) - cur_round += 1 - self._sleep_random_duration(trial=self.trial) - finally: - # Cleanup: stop the container - logger.debug('Stopping the crash analyze container %s', - self.analyze_tool.container_id, - trial=self.trial) - self.analyze_tool.terminate() - - analysis_result = AnalysisResult( - author=self, - run_result=last_result, - crash_result=crash_result, - chat_history={self.name: crash_result.to_dict()}) - return analysis_result + response, + trial=self.trial, + ) + + def _container_tool_reaction( + self, cur_round: int, response: str, crash_result: CrashResult + ) -> Optional[Prompt]: + """Validates LLM conclusion or executes its command.""" + if self._parse_tag(response, "conclusion"): + return self._container_handle_conclusion(cur_round, response, crash_result) + prompt = prompt_builder.CrashAnalyzerTemplateBuilder(self.llm, None).build([]) + if self._parse_tag(response, "lldb"): + return self._container_handle_lldb_command( + response, self.analyze_tool, prompt + ) + if self._parse_tag(response, "bash"): + return self._container_handle_bash_command( + response, self.check_tool, prompt + ) + return None + + def execute(self, result_history: list[Result]) -> AnalysisResult: + """Executes the agent based on previous run result.""" + WorkDirs(self.args.work_dirs.base, keep=True) + last_result = result_history[-1] + benchmark = last_result.benchmark + logger.info("Executing Crash Analyzer", trial=self.trial) + assert isinstance(last_result, RunResult) + + if not os.path.exists(last_result.artifact_path): + logger.error( + "Artifact path %s does not exist", + last_result.artifact_path, + trial=self.trial, + ) + + # TODO(dongge): Move these to oss_fuzz_checkout. 
+ generated_target_name = os.path.basename(benchmark.target_path) + sample_id = os.path.splitext(generated_target_name)[0] + generated_oss_fuzz_project = f"{benchmark.id}-{sample_id}-lldb-{self.trial:02d}" + generated_oss_fuzz_project = oss_fuzz_checkout.rectify_docker_tag( + generated_oss_fuzz_project + ) + + # TODO(dongge): Write to OSS-Fuzz project dir files directly. + fuzz_target_path = os.path.join( + last_result.work_dirs.fuzz_targets, f"{self.trial:02d}.fuzz_target" + ) + with open(fuzz_target_path, "w") as ft_file: + ft_file.write(last_result.fuzz_target_source) + if last_result.build_script_source: + build_script_path = os.path.join( + last_result.work_dirs.fuzz_targets, f"{self.trial:02d}.build_script" + ) + with open(build_script_path, "w") as ft_file: + ft_file.write(last_result.build_script_source) + else: + build_script_path = "" + + evaluator_lib.Evaluator.create_ossfuzz_project_with_lldb( + benchmark, + generated_oss_fuzz_project, + fuzz_target_path, + last_result, + build_script_path, + last_result.artifact_path, + ) + + self.analyze_tool = LLDBTool( + benchmark, + result=last_result, + name="lldb", + project_name=generated_oss_fuzz_project, + ) + self.analyze_tool.execute("compile > /dev/null") + # Launch LLDB and load fuzz target binary + self.analyze_tool.execute( + f"screen -dmS lldb_session -L " + f"-Logfile /tmp/lldb_log.txt " + f"lldb /out/{last_result.benchmark.target_name}" + ) + self.check_tool = ProjectContainerTool( + benchmark, name="check", project_name=generated_oss_fuzz_project + ) + self.check_tool.compile(extra_commands=" && rm -rf /out/* > /dev/null") + prompt = self._initial_prompt(result_history) + prompt.add_problem(self.analyze_tool.tutorial()) + prompt.add_problem(self.check_tool.tutorial()) + crash_result = CrashResult( + benchmark=benchmark, + trial=last_result.trial, + work_dirs=last_result.work_dirs, + author=self, + chat_history={self.name: ""}, + ) + cur_round = 1 + try: + client = self.llm.get_chat_client(model=self.llm.get_model()) + while prompt and cur_round < MAX_ROUND: + response = self.chat_llm( + cur_round=cur_round, client=client, prompt=prompt, trial=self.trial + ) + prompt = self._container_tool_reaction( + cur_round, response, crash_result + ) + cur_round += 1 + self._sleep_random_duration(trial=self.trial) + finally: + # Cleanup: stop the container + logger.debug( + "Stopping the crash analyze container %s", + self.analyze_tool.container_id, + trial=self.trial, + ) + self.analyze_tool.terminate() + + analysis_result = AnalysisResult( + author=self, + run_result=last_result, + crash_result=crash_result, + chat_history={self.name: crash_result.to_dict()}, + ) + return analysis_result diff --git a/agent/enhancer.py b/agent/enhancer.py index 8abb12d721..97386926fb 100644 --- a/agent/enhancer.py +++ b/agent/enhancer.py @@ -16,69 +16,82 @@ """ import logger from agent.prototyper import Prototyper -from llm_toolkit.prompt_builder import (CoverageEnhancerTemplateBuilder, - EnhancerTemplateBuilder, - JvmFixingBuilder) +from llm_toolkit.prompt_builder import ( + CoverageEnhancerTemplateBuilder, + EnhancerTemplateBuilder, + JvmFixingBuilder, +) from llm_toolkit.prompts import Prompt, TextPrompt from results import AnalysisResult, BuildResult, Result class Enhancer(Prototyper): - """The Agent to refine a compilable fuzz target for higher coverage.""" + """The Agent to refine a compilable fuzz target for higher coverage.""" - def _initial_prompt(self, results: list[Result]) -> Prompt: - """Constructs initial prompt of the agent.""" - 
last_result = results[-1] - benchmark = last_result.benchmark + def _initial_prompt(self, results: list[Result]) -> Prompt: + """Constructs initial prompt of the agent.""" + last_result = results[-1] + benchmark = last_result.benchmark - if not isinstance(last_result, AnalysisResult): - logger.error('The last result in Enhancer is not AnalysisResult: %s', - results, - trial=self.trial) - return Prompt() + if not isinstance(last_result, AnalysisResult): + logger.error( + "The last result in Enhancer is not AnalysisResult: %s", + results, + trial=self.trial, + ) + return Prompt() - last_build_result = None - for result in results[::-1]: - if isinstance(result, BuildResult): - last_build_result = result - break - if not last_build_result: - logger.error('Unable to find the last build result in Enhancer : %s', - results, - trial=self.trial) - return Prompt() + last_build_result = None + for result in results[::-1]: + if isinstance(result, BuildResult): + last_build_result = result + break + if not last_build_result: + logger.error( + "Unable to find the last build result in Enhancer : %s", + results, + trial=self.trial, + ) + return Prompt() - if benchmark.language == 'jvm': - # TODO: Do this in a separate agent for JVM coverage. - builder = JvmFixingBuilder(self.llm, benchmark, - last_result.run_result.fuzz_target_source, []) - prompt = builder.build([], None, None) - else: - # TODO(dongge): Refine this logic. - if last_result.semantic_result: - error_desc, errors = last_result.semantic_result.get_error_info() - builder = EnhancerTemplateBuilder(self.llm, benchmark, - last_build_result, error_desc, errors) - elif last_result.coverage_result: - builder = CoverageEnhancerTemplateBuilder( - self.llm, - benchmark, - last_build_result, - coverage_result=last_result.coverage_result) - else: - logger.error( - 'Last result does not contain either semantic result or ' - 'coverage result', - trial=self.trial) - # TODO(dongge): Give some default initial prompt. - prompt = TextPrompt( - 'Last result does not contain either semantic result or ' - 'coverage result') - return prompt - prompt = builder.build(example_pair=[], - tool_guides=self.inspect_tool.tutorial(), - project_dir=self.inspect_tool.project_dir) - # TODO: A different file name/dir. - prompt.save(self.args.work_dirs.prompt) + if benchmark.language == "jvm": + # TODO: Do this in a separate agent for JVM coverage. + builder = JvmFixingBuilder( + self.llm, benchmark, last_result.run_result.fuzz_target_source, [] + ) + prompt = builder.build([], None, None) + else: + # TODO(dongge): Refine this logic. + if last_result.semantic_result: + error_desc, errors = last_result.semantic_result.get_error_info() + builder = EnhancerTemplateBuilder( + self.llm, benchmark, last_build_result, error_desc, errors + ) + elif last_result.coverage_result: + builder = CoverageEnhancerTemplateBuilder( + self.llm, + benchmark, + last_build_result, + coverage_result=last_result.coverage_result, + ) + else: + logger.error( + "Last result does not contain either semantic result or " + "coverage result", + trial=self.trial, + ) + # TODO(dongge): Give some default initial prompt. + prompt = TextPrompt( + "Last result does not contain either semantic result or " + "coverage result" + ) + return prompt + prompt = builder.build( + example_pair=[], + tool_guides=self.inspect_tool.tutorial(), + project_dir=self.inspect_tool.project_dir, + ) + # TODO: A different file name/dir. 
+ prompt.save(self.args.work_dirs.prompt) - return prompt + return prompt diff --git a/agent/function_analyzer.py b/agent/function_analyzer.py index b87ad80af2..8535c063c8 100644 --- a/agent/function_analyzer.py +++ b/agent/function_analyzer.py @@ -35,150 +35,159 @@ def get_function_source_tool(project_name: str, function_signature: str): - """ - Retrieves a function's source using the project name and function signature. + """ + Retrieves a function's source using the project name and function signature. - Args: - project_name (str): The name of the project. - function_signature (str): The signature of the function. + Args: + project_name (str): The name of the project. + function_signature (str): The signature of the function. - Returns: - str: The source code of the function if found, otherwise an empty string. - """ + Returns: + str: The source code of the function if found, otherwise an empty string. + """ - function_code = introspector.query_introspector_function_source( - project_name, function_signature) - - if function_code: - logger.info("Function with signature '%s' found and extracted.", - function_signature) - else: - logger.info( - "Error: Function with signature '%s' not found in project '%s'.", - function_signature, project_name) - - return function_code - - -class FunctionAnalyzer(BaseAgent): - """An LLM agent to analyze a function and identify its implicit requirements. - The results of this analysis will be used by the writer agents to - generate correct fuzz target for the function. - """ - - def initialize(self, benchmark: benchmarklib.Benchmark): - """Initialize the function analyzer agent with the given benchmark.""" - - self.benchmark = benchmark - - # Initialize the prompt builder - self.prompt_builder = prompt_builder.FunctionAnalyzerTemplateBuilder( - self.llm, self.benchmark) - - # Get the agent's instructions - analyzer_instruction = self.prompt_builder.build_instruction() - - # Create the agent using the ADK library - function_analyzer = Agent( - name=self.name, - # TODO: Get the model name from args. - # Currently, the default names are incompatible with the ADK library. - model='gemini-2.0-flash', - description=( - "Agent to analyze a function and identify its requirements."), - instruction=analyzer_instruction.get(), - tools=[get_function_source_tool]) - - # Get user id and session id - # TODO: Figure out how to get this data - user_id = "user" - session_id = "session" - - # Create the session service - session_service = InMemorySessionService() - session_service.create_session( - app_name=self.name, - user_id=user_id, - session_id=session_id, + function_code = introspector.query_introspector_function_source( + project_name, function_signature ) - # Create the runner - self.runner = Runner( - agent=function_analyzer, - app_name=self.name, - session_service=session_service, - ) - - logger.info( - "Function Analyzer Agent created, with name: %s, and session id: %s", - self.name, session_id) - - def call_agent(self, query: str, runner: Runner, user_id: str, - session_id: str) -> PreWritingResult: - """Call the agent asynchronously with the given query.""" - - logger.info(">>> User query: %s", query) - - content = types.Content(role='user', parts=[types.Part(text=query)]) - - final_response_text = "Agent did not produce a final response." 
- - result_available = False - - for event in runner.run( - user_id=user_id, - session_id=session_id, - new_message=content, - ): - if event.is_final_response(): - if event.content and event.content.parts: - final_response_text = event.content.parts[0].text - result_available = True - elif event.actions and event.actions.escalate: - error_message = event.error_message or 'No specific message.' - final_response_text = f"Agent escalated: {error_message}" - break - - logger.info("<<< Agent response: %s", final_response_text) - - if result_available and final_response_text: - # Get the requirements from the response - requirements = self._parse_tags(final_response_text, 'requirement') + if function_code: + logger.info( + "Function with signature '%s' found and extracted.", function_signature + ) else: - requirements = [] - - # Prepare the result - result = PreWritingResult( - benchmark=self.benchmark, - trial=self.trial, - work_dirs=self.args.work_dir, - result_available=result_available, - requirements=requirements, - ) - - return result + logger.info( + "Error: Function with signature '%s' not found in project '%s'.", + function_signature, + project_name, + ) - def execute(self, result_history: list[Result]) -> PreWritingResult: - """Execute the agent with the given results.""" + return function_code - # Call the agent asynchronously and return the result - prompt = self._initial_prompt(result_history) - query = prompt.gettext() - user_id = "user" - session_id = "session" - result = self.call_agent(query, self.runner, user_id, session_id) - if result.result_available: - # Save the result to the history - result_history.append(result) - - return result - - def _initial_prompt(self, results: list[Result]) -> Prompt: - """Create the initial prompt for the agent.""" - - prompt = self.prompt_builder.build( - project_name=self.benchmark.project, - function_signature=self.benchmark.function_signature) - - return prompt +class FunctionAnalyzer(BaseAgent): + """An LLM agent to analyze a function and identify its implicit requirements. + The results of this analysis will be used by the writer agents to + generate correct fuzz target for the function. + """ + + def initialize(self, benchmark: benchmarklib.Benchmark): + """Initialize the function analyzer agent with the given benchmark.""" + + self.benchmark = benchmark + + # Initialize the prompt builder + self.prompt_builder = prompt_builder.FunctionAnalyzerTemplateBuilder( + self.llm, self.benchmark + ) + + # Get the agent's instructions + analyzer_instruction = self.prompt_builder.build_instruction() + + # Create the agent using the ADK library + function_analyzer = Agent( + name=self.name, + # TODO: Get the model name from args. + # Currently, the default names are incompatible with the ADK library. 
+ model="gemini-2.0-flash", + description=("Agent to analyze a function and identify its requirements."), + instruction=analyzer_instruction.get(), + tools=[get_function_source_tool], + ) + + # Get user id and session id + # TODO: Figure out how to get this data + user_id = "user" + session_id = "session" + + # Create the session service + session_service = InMemorySessionService() + session_service.create_session( + app_name=self.name, + user_id=user_id, + session_id=session_id, + ) + + # Create the runner + self.runner = Runner( + agent=function_analyzer, + app_name=self.name, + session_service=session_service, + ) + + logger.info( + "Function Analyzer Agent created, with name: %s, and session id: %s", + self.name, + session_id, + ) + + def call_agent( + self, query: str, runner: Runner, user_id: str, session_id: str + ) -> PreWritingResult: + """Call the agent asynchronously with the given query.""" + + logger.info(">>> User query: %s", query) + + content = types.Content(role="user", parts=[types.Part(text=query)]) + + final_response_text = "Agent did not produce a final response." + + result_available = False + + for event in runner.run( + user_id=user_id, + session_id=session_id, + new_message=content, + ): + if event.is_final_response(): + if event.content and event.content.parts: + final_response_text = event.content.parts[0].text + result_available = True + elif event.actions and event.actions.escalate: + error_message = event.error_message or "No specific message." + final_response_text = f"Agent escalated: {error_message}" + break + + logger.info("<<< Agent response: %s", final_response_text) + + if result_available and final_response_text: + # Get the requirements from the response + requirements = self._parse_tags(final_response_text, "requirement") + else: + requirements = [] + + # Prepare the result + result = PreWritingResult( + benchmark=self.benchmark, + trial=self.trial, + work_dirs=self.args.work_dir, + result_available=result_available, + requirements=requirements, + ) + + return result + + def execute(self, result_history: list[Result]) -> PreWritingResult: + """Execute the agent with the given results.""" + + # Call the agent asynchronously and return the result + prompt = self._initial_prompt(result_history) + query = prompt.gettext() + user_id = "user" + session_id = "session" + result = self.call_agent(query, self.runner, user_id, session_id) + + if result.result_available: + # Save the result to the history + result_history.append(result) + + return result + + def _initial_prompt(self, results: list[Result]) -> Prompt: + """Create the initial prompt for the agent.""" + + prompt = self.prompt_builder.build( + project_name=self.benchmark.project, + function_signature=self.benchmark.function_signature, + ) + + return prompt diff --git a/agent/one_prompt_enhancer.py b/agent/one_prompt_enhancer.py index 7840c9e767..30a3516693 100644 --- a/agent/one_prompt_enhancer.py +++ b/agent/one_prompt_enhancer.py @@ -23,70 +23,77 @@ class OnePromptEnhancer(OnePromptPrototyper): - """The Agent to generate a simple but valid fuzz target from scratch.""" + """The Agent to generate a simple but valid fuzz target from scratch.""" - def _initial_prompt(self, results: list[Result]) -> Prompt: - """Constructs initial prompt of the agent.""" - last_result = results[-1] - benchmark = last_result.benchmark + def _initial_prompt(self, results: list[Result]) -> Prompt: + """Constructs initial prompt of the agent.""" + last_result = results[-1] + benchmark = last_result.benchmark - if not 
isinstance(last_result, AnalysisResult): - logger.error('The last result in Enhancer is not AnalysisResult: %s', - results, - trial=self.trial) - return Prompt() + if not isinstance(last_result, AnalysisResult): + logger.error( + "The last result in Enhancer is not AnalysisResult: %s", + results, + trial=self.trial, + ) + return Prompt() - if benchmark.language == 'jvm': - # TODO: Do this in a separate agent for JVM coverage. - builder = JvmFixingBuilder(self.llm, benchmark, - last_result.run_result.fuzz_target_source, []) - prompt = builder.build([], None, None) - else: - # TODO(dongge): Refine this logic. - builder = DefaultTemplateBuilder(self.llm) - if last_result.semantic_result: - error_desc, errors = last_result.semantic_result.get_error_info() - prompt = builder.build_fixer_prompt(benchmark, - last_result.fuzz_target_source, - error_desc, - errors, - context='', - instruction='') - else: - prompt = builder.build_fixer_prompt( - benchmark=benchmark, - raw_code=last_result.fuzz_target_source, - error_desc='', - errors=[], - coverage_result=last_result.coverage_result, - context='', - instruction='') - # TODO: A different file name/dir. - prompt.save(self.args.work_dirs.prompt) + if benchmark.language == "jvm": + # TODO: Do this in a separate agent for JVM coverage. + builder = JvmFixingBuilder( + self.llm, benchmark, last_result.run_result.fuzz_target_source, [] + ) + prompt = builder.build([], None, None) + else: + # TODO(dongge): Refine this logic. + builder = DefaultTemplateBuilder(self.llm) + if last_result.semantic_result: + error_desc, errors = last_result.semantic_result.get_error_info() + prompt = builder.build_fixer_prompt( + benchmark, + last_result.fuzz_target_source, + error_desc, + errors, + context="", + instruction="", + ) + else: + prompt = builder.build_fixer_prompt( + benchmark=benchmark, + raw_code=last_result.fuzz_target_source, + error_desc="", + errors=[], + coverage_result=last_result.coverage_result, + context="", + instruction="", + ) + # TODO: A different file name/dir. 
+ prompt.save(self.args.work_dirs.prompt) - return prompt + return prompt - def execute(self, result_history: list[Result]) -> BuildResult: - """Executes the agent based on previous result.""" - last_result = result_history[-1] - logger.info('Executing One Prompt Enhancer', trial=last_result.trial) - # Use keep to avoid deleting files, such as benchmark.yaml - WorkDirs(self.args.work_dirs.base, keep=True) + def execute(self, result_history: list[Result]) -> BuildResult: + """Executes the agent based on previous result.""" + last_result = result_history[-1] + logger.info("Executing One Prompt Enhancer", trial=last_result.trial) + # Use keep to avoid deleting files, such as benchmark.yaml + WorkDirs(self.args.work_dirs.base, keep=True) - prompt = self._initial_prompt(result_history) - cur_round = 1 - build_result = BuildResult(benchmark=last_result.benchmark, - trial=last_result.trial, - work_dirs=last_result.work_dirs, - author=self, - chat_history={self.name: prompt.gettext()}) + prompt = self._initial_prompt(result_history) + cur_round = 1 + build_result = BuildResult( + benchmark=last_result.benchmark, + trial=last_result.trial, + work_dirs=last_result.work_dirs, + author=self, + chat_history={self.name: prompt.gettext()}, + ) - while prompt and cur_round <= self.max_round: - self._generate_fuzz_target(prompt, result_history, build_result, - cur_round) + while prompt and cur_round <= self.max_round: + self._generate_fuzz_target(prompt, result_history, build_result, cur_round) - self._validate_fuzz_target(cur_round, build_result) - prompt = self._advice_fuzz_target(build_result, cur_round) - cur_round += 1 + self._validate_fuzz_target(cur_round, build_result) + prompt = self._advice_fuzz_target(build_result, cur_round) + cur_round += 1 - return build_result + return build_result diff --git a/agent/one_prompt_prototyper.py b/agent/one_prompt_prototyper.py index 86c178b88c..c59b54e789 100644 --- a/agent/one_prompt_prototyper.py +++ b/agent/one_prompt_prototyper.py @@ -26,246 +26,309 @@ from data_prep.project_context.context_introspector import ContextRetriever from experiment.benchmark import Benchmark from experiment.workdir import WorkDirs -from llm_toolkit import (code_fixer, models, output_parser, prompt_builder, - prompts) +from llm_toolkit import ( + code_fixer, + models, + output_parser, + prompt_builder, + prompts, +) from llm_toolkit.prompts import Prompt from results import BuildResult, Result from tool.container_tool import ProjectContainerTool class OnePromptPrototyper(BaseAgent): - """The Agent to generate a simple but valid fuzz target from scratch.""" + """The Agent to generate a simple but valid fuzz target from scratch.""" - def _prompt_builder(self, - results: list[Result]) -> prompt_builder.PromptBuilder: - """Returns the prompt builder based on language and customization.""" - last_result = results[-1] - benchmark = last_result.benchmark - # If this is a test benchmark then we will use a test prompt builder. - if benchmark.test_file_path: - logger.info('Generating a target for test case: %s', - benchmark.test_file_path, - trial=last_result.trial) - return prompt_builder.TestToHarnessConverter(self.llm, benchmark, - self.args.template_directory) - # TODO: Do these in separate agents. 
- if benchmark.language == 'jvm': - # For Java projects - return prompt_builder.DefaultJvmTemplateBuilder( - self.llm, benchmark, self.args.template_directory) - if benchmark.language == 'python': - # For Python projects - return prompt_builder.DefaultPythonTemplateBuilder( - self.llm, benchmark, self.args.template_directory) - if benchmark.language == 'rust': - # For Rust projects - return prompt_builder.DefaultRustTemplateBuilder( - self.llm, benchmark, self.args.template_directory) + def _prompt_builder(self, results: list[Result]) -> prompt_builder.PromptBuilder: + """Returns the prompt builder based on language and customization.""" + last_result = results[-1] + benchmark = last_result.benchmark + # If this is a test benchmark then we will use a test prompt builder. + if benchmark.test_file_path: + logger.info( + "Generating a target for test case: %s", + benchmark.test_file_path, + trial=last_result.trial, + ) + return prompt_builder.TestToHarnessConverter( + self.llm, benchmark, self.args.template_directory + ) + # TODO: Do these in separate agents. + if benchmark.language == "jvm": + # For Java projects + return prompt_builder.DefaultJvmTemplateBuilder( + self.llm, benchmark, self.args.template_directory + ) + if benchmark.language == "python": + # For Python projects + return prompt_builder.DefaultPythonTemplateBuilder( + self.llm, benchmark, self.args.template_directory + ) + if benchmark.language == "rust": + # For Rust projects + return prompt_builder.DefaultRustTemplateBuilder( + self.llm, benchmark, self.args.template_directory + ) - if self.args.prompt_builder == 'CSpecific': - return prompt_builder.CSpecificBuilder(self.llm, benchmark, - self.args.template_directory) - # Use default - return prompt_builder.DefaultTemplateBuilder(self.llm, benchmark, - self.args.template_directory) + if self.args.prompt_builder == "CSpecific": + return prompt_builder.CSpecificBuilder( + self.llm, benchmark, self.args.template_directory + ) + # Use default + return prompt_builder.DefaultTemplateBuilder( + self.llm, benchmark, self.args.template_directory + ) - def _initial_prompt(self, results: list[Result]) -> Prompt: - """Constructs initial prompt of the agent.""" - last_result = results[-1] - benchmark = last_result.benchmark + def _initial_prompt(self, results: list[Result]) -> Prompt: + """Constructs initial prompt of the agent.""" + last_result = results[-1] + benchmark = last_result.benchmark - if benchmark.use_project_examples: - project_examples = project_targets.generate_data( - benchmark.project, - benchmark.language, - cloud_experiment_bucket=self.args.cloud_experiment_bucket) - else: - project_examples = [] + if benchmark.use_project_examples: + project_examples = project_targets.generate_data( + benchmark.project, + benchmark.language, + cloud_experiment_bucket=self.args.cloud_experiment_bucket, + ) + else: + project_examples = [] - if self.args.context: - retriever = ContextRetriever(benchmark) - context_info = retriever.get_context_info() - else: - context_info = {} + if self.args.context: + retriever = ContextRetriever(benchmark) + context_info = retriever.get_context_info() + else: + context_info = {} - builder = self._prompt_builder(results) - prompt = builder.build(prompt_builder.EXAMPLES.get(benchmark.language, []), - project_example_content=project_examples, - project_context_content=context_info) - prompt.save(self.args.work_dirs.prompt) - return prompt + builder = self._prompt_builder(results) + prompt = builder.build( + prompt_builder.EXAMPLES.get(benchmark.language, 
[]), + project_example_content=project_examples, + project_context_content=context_info, + ) + prompt.save(self.args.work_dirs.prompt) + return prompt - def execute(self, result_history: list[Result]) -> BuildResult: - """Executes the agent based on previous result.""" - last_result = result_history[-1] - logger.info('Executing %s', self.name, trial=last_result.trial) - # Use keep to avoid deleting files, such as benchmark.yaml - WorkDirs(self.args.work_dirs.base, keep=True) + def execute(self, result_history: list[Result]) -> BuildResult: + """Executes the agent based on previous result.""" + last_result = result_history[-1] + logger.info("Executing %s", self.name, trial=last_result.trial) + # Use keep to avoid deleting files, such as benchmark.yaml + WorkDirs(self.args.work_dirs.base, keep=True) - prompt = self._initial_prompt(result_history) - cur_round = 1 - build_result = BuildResult(benchmark=last_result.benchmark, - trial=last_result.trial, - work_dirs=last_result.work_dirs, - author=self, - chat_history={self.name: prompt.gettext()}) + prompt = self._initial_prompt(result_history) + cur_round = 1 + build_result = BuildResult( + benchmark=last_result.benchmark, + trial=last_result.trial, + work_dirs=last_result.work_dirs, + author=self, + chat_history={self.name: prompt.gettext()}, + ) - while prompt and cur_round <= self.max_round: - self._generate_fuzz_target(prompt, result_history, build_result, - cur_round) - self._validate_fuzz_target(cur_round, build_result) - prompt = self._advice_fuzz_target(build_result, cur_round) - cur_round += 1 + while prompt and cur_round <= self.max_round: + self._generate_fuzz_target(prompt, result_history, build_result, cur_round) + self._validate_fuzz_target(cur_round, build_result) + prompt = self._advice_fuzz_target(build_result, cur_round) + cur_round += 1 - return build_result + return build_result - def _advice_fuzz_target(self, build_result: BuildResult, - cur_round: int) -> Optional[Prompt]: - """Returns a prompt to fix fuzz target based on its build result errors.""" - if build_result.success: - logger.info('***** %s succeded in %02d rounds *****', - self.name, - cur_round, - trial=build_result.trial) - return None - fixer_model = models.LLM.setup(ai_binary=self.args.ai_binary, - name=self.llm.name, - num_samples=1, - temperature=self.args.temperature) + def _advice_fuzz_target( + self, build_result: BuildResult, cur_round: int + ) -> Optional[Prompt]: + """Returns a prompt to fix fuzz target based on its build result errors.""" + if build_result.success: + logger.info( + "***** %s succeded in %02d rounds *****", + self.name, + cur_round, + trial=build_result.trial, + ) + return None + fixer_model = models.LLM.setup( + ai_binary=self.args.ai_binary, + name=self.llm.name, + num_samples=1, + temperature=self.args.temperature, + ) - errors = code_fixer.extract_error_from_lines( - build_result.compile_log.split('\n'), - os.path.basename(build_result.benchmark.target_path), - build_result.benchmark.language) - build_result.compile_error = '\n'.join(errors) - if build_result.benchmark.language == 'jvm': - builder = prompt_builder.JvmFixingBuilder( - fixer_model, build_result.benchmark, build_result.fuzz_target_source, - build_result.compile_error.split('\n')) - prompt = builder.build([], None, None) - else: - builder = prompt_builder.DefaultTemplateBuilder(fixer_model) + errors = code_fixer.extract_error_from_lines( + build_result.compile_log.split("\n"), + os.path.basename(build_result.benchmark.target_path), + build_result.benchmark.language, + 
) + build_result.compile_error = "\n".join(errors) + if build_result.benchmark.language == "jvm": + builder = prompt_builder.JvmFixingBuilder( + fixer_model, + build_result.benchmark, + build_result.fuzz_target_source, + build_result.compile_error.split("\n"), + ) + prompt = builder.build([], None, None) + else: + builder = prompt_builder.DefaultTemplateBuilder(fixer_model) - context = code_fixer.collect_context(build_result.benchmark, errors) - instruction = code_fixer.collect_instructions( - build_result.benchmark, errors, build_result.fuzz_target_source) - prompt = builder.build_fixer_prompt(build_result.benchmark, - build_result.fuzz_target_source, - '', - errors, - context=context, - instruction=instruction) + context = code_fixer.collect_context(build_result.benchmark, errors) + instruction = code_fixer.collect_instructions( + build_result.benchmark, errors, build_result.fuzz_target_source + ) + prompt = builder.build_fixer_prompt( + build_result.benchmark, + build_result.fuzz_target_source, + "", + errors, + context=context, + instruction=instruction, + ) - return prompt + return prompt - def _generate_fuzz_target(self, prompt: prompts.Prompt, - result_history: list[Result], - build_result: BuildResult, cur_round: int) -> None: - """Generates and iterates fuzz target with LLM.""" - benchmark = build_result.benchmark + def _generate_fuzz_target( + self, + prompt: prompts.Prompt, + result_history: list[Result], + build_result: BuildResult, + cur_round: int, + ) -> None: + """Generates and iterates fuzz target with LLM.""" + benchmark = build_result.benchmark - logger.info('Generating targets for %s %s using %s..', - benchmark.project, - benchmark.function_signature, - self.llm.name, - trial=build_result.trial) + logger.info( + "Generating targets for %s %s using %s..", + benchmark.project, + benchmark.function_signature, + self.llm.name, + trial=build_result.trial, + ) - target_code = self.ask_llm(cur_round, prompt, self.trial) - target_code = output_parser.filter_code(target_code) - target_code = self._prompt_builder( - result_history).post_process_generated_code(target_code) - build_result.fuzz_target_source = target_code + target_code = self.ask_llm(cur_round, prompt, self.trial) + target_code = output_parser.filter_code(target_code) + target_code = self._prompt_builder(result_history).post_process_generated_code( + target_code + ) + build_result.fuzz_target_source = target_code - def _validate_fuzz_target(self, cur_round: int, - build_result: BuildResult) -> None: - """Validates the new fuzz target by recompiling it.""" - benchmark = build_result.benchmark - compilation_tool = ProjectContainerTool(benchmark=benchmark) + def _validate_fuzz_target(self, cur_round: int, build_result: BuildResult) -> None: + """Validates the new fuzz target by recompiling it.""" + benchmark = build_result.benchmark + compilation_tool = ProjectContainerTool(benchmark=benchmark) - # Replace fuzz target and build script in the container. - replace_file_content_command = ( - 'cat << "OFG_EOF" > {file_path}\n{file_content}\nOFG_EOF') - compilation_tool.execute( - replace_file_content_command.format( - file_path=benchmark.target_path, - file_content=build_result.fuzz_target_source)) + # Replace fuzz target and build script in the container. 
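# Illustrative sketch (separate from the diff): how the heredoc template
# defined just below expands. The path and content values here are
# hypothetical.
_REPLACE_FILE_CONTENT_COMMAND = 'cat << "OFG_EOF" > {file_path}\n{file_content}\nOFG_EOF'


def render_replace_command(file_path: str, file_content: str) -> str:
    """Returns the shell command that overwrites file_path with file_content."""
    return _REPLACE_FILE_CONTENT_COMMAND.format(
        file_path=file_path, file_content=file_content
    )


# render_replace_command("/src/fuzz_target.cc", "// generated target ...")
# yields a quoted heredoc, so the target source is written into the container
# verbatim, without shell expansion.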
+ replace_file_content_command = ( + 'cat << "OFG_EOF" > {file_path}\n{file_content}\nOFG_EOF' + ) + compilation_tool.execute( + replace_file_content_command.format( + file_path=benchmark.target_path, + file_content=build_result.fuzz_target_source, + ) + ) - if build_result.build_script_source: - compilation_tool.execute( - replace_file_content_command.format( - file_path='/src/build.sh', - file_content=build_result.build_script_source)) + if build_result.build_script_source: + compilation_tool.execute( + replace_file_content_command.format( + file_path="/src/build.sh", + file_content=build_result.build_script_source, + ) + ) - # Recompile. - logger.info('===== ROUND %02d Recompile =====', - cur_round, - trial=build_result.trial) - start_time = time.time() - compile_process = compilation_tool.compile() - end_time = time.time() - logger.debug('ROUND %02d compilation time: %s', - cur_round, - timedelta(seconds=end_time - start_time), - trial=build_result.trial) - compile_succeed = compile_process.returncode == 0 - logger.debug('ROUND %02d Fuzz target compiles: %s', - cur_round, - compile_succeed, - trial=build_result.trial) + # Recompile. + logger.info( + "===== ROUND %02d Recompile =====", cur_round, trial=build_result.trial + ) + start_time = time.time() + compile_process = compilation_tool.compile() + end_time = time.time() + logger.debug( + "ROUND %02d compilation time: %s", + cur_round, + timedelta(seconds=end_time - start_time), + trial=build_result.trial, + ) + compile_succeed = compile_process.returncode == 0 + logger.debug( + "ROUND %02d Fuzz target compiles: %s", + cur_round, + compile_succeed, + trial=build_result.trial, + ) - # Double-check binary. - ls_result = compilation_tool.execute(f'ls /out/{benchmark.target_name}') - binary_exists = ls_result.returncode == 0 - logger.debug('ROUND %02d Final fuzz target binary exists: %s', - cur_round, - binary_exists, - trial=build_result.trial) + # Double-check binary. + ls_result = compilation_tool.execute(f"ls /out/{benchmark.target_name}") + binary_exists = ls_result.returncode == 0 + logger.debug( + "ROUND %02d Final fuzz target binary exists: %s", + cur_round, + binary_exists, + trial=build_result.trial, + ) - # Validate if function-under-test is referenced by the fuzz target. - function_referenced = self._validate_fuzz_target_references_function( - compilation_tool, benchmark, cur_round, build_result.trial) + # Validate if function-under-test is referenced by the fuzz target. 
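# Illustrative sketch (separate from the diff): the reference check invoked
# below disassembles LLVMFuzzerTestOneInput and searches the output for the
# function name. A standalone approximation, assuming a local binary path
# rather than running objdump inside the project container:
import subprocess


def references_function(binary_path: str, function_name: str) -> bool:
    """Best-effort check that the harness disassembly mentions function_name."""
    result = subprocess.run(
        ["objdump", "--disassemble=LLVMFuzzerTestOneInput", "-d", binary_path],
        capture_output=True,
        text=True,
        check=False,
    )
    return result.returncode == 0 and function_name in result.stdout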
+ function_referenced = self._validate_fuzz_target_references_function( + compilation_tool, benchmark, cur_round, build_result.trial + ) - compilation_tool.terminate() - self._update_build_result(build_result, - compile_process=compile_process, - compiles=compile_succeed, - binary_exists=binary_exists, - referenced=function_referenced) + compilation_tool.terminate() + self._update_build_result( + build_result, + compile_process=compile_process, + compiles=compile_succeed, + binary_exists=binary_exists, + referenced=function_referenced, + ) - def _validate_fuzz_target_references_function( - self, compilation_tool: ProjectContainerTool, benchmark: Benchmark, - cur_round: int, trial: int) -> bool: - """Validates if the LLM generated fuzz target assembly code references - function-under-test.""" + def _validate_fuzz_target_references_function( + self, + compilation_tool: ProjectContainerTool, + benchmark: Benchmark, + cur_round: int, + trial: int, + ) -> bool: + """Validates if the LLM generated fuzz target assembly code references + function-under-test.""" - # LLVMFuzzerTestOneInput and binary dumps are only valid - # for C/C++ projects. - # Temporary skipping this check for other language. - if benchmark.language in ['jvm', 'python', 'rust']: - return True + # LLVMFuzzerTestOneInput and binary dumps are only valid + # for C/C++ projects. + # Temporary skipping this check for other language. + if benchmark.language in ["jvm", "python", "rust"]: + return True - disassemble_result = compilation_tool.execute( - 'objdump --disassemble=LLVMFuzzerTestOneInput -d ' - f'/out/{benchmark.target_name}') - function_referenced = (disassemble_result.returncode == 0 and - benchmark.function_name in disassemble_result.stdout) - logger.debug('ROUND %02d Final fuzz target function referenced: %s', - cur_round, - function_referenced, - trial=trial) - if not function_referenced: - logger.debug('ROUND %02d Final fuzz target function not referenced', - cur_round, - trial=trial) - return function_referenced + disassemble_result = compilation_tool.execute( + "objdump --disassemble=LLVMFuzzerTestOneInput -d " + f"/out/{benchmark.target_name}" + ) + function_referenced = ( + disassemble_result.returncode == 0 + and benchmark.function_name in disassemble_result.stdout + ) + logger.debug( + "ROUND %02d Final fuzz target function referenced: %s", + cur_round, + function_referenced, + trial=trial, + ) + if not function_referenced: + logger.debug( + "ROUND %02d Final fuzz target function not referenced", + cur_round, + trial=trial, + ) + return function_referenced - def _update_build_result(self, build_result: BuildResult, - compile_process: sp.CompletedProcess, compiles: bool, - binary_exists: bool, referenced: bool) -> None: - """Updates the build result with the latest info.""" - build_result.compiles = compiles - build_result.binary_exists = binary_exists - build_result.compile_error = compile_process.stderr - build_result.compile_log = self._format_bash_execution_result( - compile_process) - build_result.is_function_referenced = referenced + def _update_build_result( + self, + build_result: BuildResult, + compile_process: sp.CompletedProcess, + compiles: bool, + binary_exists: bool, + referenced: bool, + ) -> None: + """Updates the build result with the latest info.""" + build_result.compiles = compiles + build_result.binary_exists = binary_exists + build_result.compile_error = compile_process.stderr + build_result.compile_log = self._format_bash_execution_result(compile_process) + build_result.is_function_referenced 
= referenced diff --git a/agent/prototyper.py b/agent/prototyper.py index 8bfbac9245..61eab7a548 100644 --- a/agent/prototyper.py +++ b/agent/prototyper.py @@ -34,421 +34,504 @@ class Prototyper(BaseAgent): - """The Agent to generate a simple but valid fuzz target from scratch.""" - - def _initial_prompt(self, results: list[Result]) -> Prompt: - """Constructs initial prompt of the agent.""" - benchmark = results[-1].benchmark - - if benchmark.use_project_examples: - project_examples = project_targets.generate_data( - benchmark.project, - benchmark.language, - cloud_experiment_bucket=self.args.cloud_experiment_bucket) - else: - project_examples = [] - - if self.args.context: - retriever = ContextRetriever(benchmark) - context_info = retriever.get_context_info() - else: - context_info = {} - - builder = prompt_builder.PrototyperTemplateBuilder( - model=self.llm, - benchmark=benchmark, - ) - prompt = builder.build(example_pair=prompt_builder.EXAMPLES.get( - benchmark.file_type.value.lower(), []), - project_example_content=project_examples, - project_context_content=context_info, - tool_guides=self.inspect_tool.tutorial(), - project_dir=self.inspect_tool.project_dir) - return prompt - - def _update_fuzz_target_and_build_script(self, response: str, - build_result: BuildResult) -> None: - """Updates fuzz target and build script in build_result with LLM response. - """ - fuzz_target_source = self._filter_code( - self._parse_tag(response, 'fuzz target')) - build_result.fuzz_target_source = fuzz_target_source - - build_script_source = self._filter_code( - self._parse_tag(response, 'build script')) - # Sometimes LLM adds chronos, which makes no sense for new build scripts. - build_result.build_script_source = build_script_source.replace( - 'source /src/chronos.sh', '') - - def _update_build_result(self, build_result: BuildResult, - compile_process: sp.CompletedProcess, compiles: bool, - binary_exists: bool, referenced: bool) -> None: - """Updates the build result with the latest info.""" - build_result.compiles = compiles - build_result.compile_error = compile_process.stderr - build_result.compile_log = self._format_bash_execution_result( - compile_process) - build_result.binary_exists = binary_exists - build_result.is_function_referenced = referenced - - def _validate_fuzz_target_and_build_script( - self, cur_round: int, build_result: BuildResult - ) -> tuple[Optional[BuildResult], Optional[BuildResult]]: - """Validates the new fuzz target and build script.""" - # Steps: - # 1. Recompile without modifying the build script, in case LLM is wrong. - # 2. Recompile with the modified build script, if any. - build_result_alt = None - if build_result.build_script_source: - build_result_alt = copy.deepcopy(build_result) - logger.info('First compile fuzz target without modifying build script.', - trial=build_result_alt.trial) - build_result_alt.build_script_source = '' - self._validate_fuzz_target_and_build_script_via_compile( - cur_round, build_result_alt) - - # No need to run expensive build_result, when *_alt is perfect. - if build_result_alt and build_result_alt.success: - return build_result_alt, None - - # New fuzz target + has new build.sh. - logger.info('Compile fuzz target with modified build script.', - trial=build_result.trial) - self._validate_fuzz_target_and_build_script_via_compile( - cur_round, build_result) - - # Although build_result_alt is not perfect, LLM may still learn from it. 
- return build_result_alt, build_result - - def _validate_fuzz_target_references_function( - self, compilation_tool: ProjectContainerTool, benchmark: Benchmark, - cur_round: int, trial: int) -> bool: - """Validates if the LLM generated fuzz target assembly code references - function-under-test.""" - - # LLVMFuzzerTestOneInput and binary dumps are only valid - # for C/C++ projects. - # Temporary skipping this check for other language. - if benchmark.language in ['jvm', 'python', 'rust']: - return True - - disassemble_result = compilation_tool.execute( - 'objdump --disassemble=LLVMFuzzerTestOneInput -d ' - f'/out/{benchmark.target_name}') - function_referenced = (disassemble_result.returncode == 0 and - benchmark.function_name in disassemble_result.stdout) - logger.debug('ROUND %02d Final fuzz target function referenced: %s', - cur_round, - function_referenced, - trial=trial) - if not function_referenced: - logger.debug('ROUND %02d Final fuzz target function not referenced', - cur_round, - trial=trial) - return function_referenced - - def _validate_fuzz_target_and_build_script_via_compile( - self, cur_round: int, build_result: BuildResult) -> None: - """Validates the new fuzz target and build script by recompiling them.""" - benchmark = build_result.benchmark - compilation_tool = ProjectContainerTool(benchmark=benchmark) - - # Replace fuzz target and build script in the container. - compilation_tool.write_to_file(content=build_result.fuzz_target_source, - file_path=benchmark.target_path) - if build_result.build_script_source: - compilation_tool.write_to_file( - content=build_result.build_script_source, - file_path=compilation_tool.build_script_path) - - # Recompile. - logger.info('===== ROUND %02d Recompile =====', + """The Agent to generate a simple but valid fuzz target from scratch.""" + + def _initial_prompt(self, results: list[Result]) -> Prompt: + """Constructs initial prompt of the agent.""" + benchmark = results[-1].benchmark + + if benchmark.use_project_examples: + project_examples = project_targets.generate_data( + benchmark.project, + benchmark.language, + cloud_experiment_bucket=self.args.cloud_experiment_bucket, + ) + else: + project_examples = [] + + if self.args.context: + retriever = ContextRetriever(benchmark) + context_info = retriever.get_context_info() + else: + context_info = {} + + builder = prompt_builder.PrototyperTemplateBuilder( + model=self.llm, + benchmark=benchmark, + ) + prompt = builder.build( + example_pair=prompt_builder.EXAMPLES.get( + benchmark.file_type.value.lower(), [] + ), + project_example_content=project_examples, + project_context_content=context_info, + tool_guides=self.inspect_tool.tutorial(), + project_dir=self.inspect_tool.project_dir, + ) + return prompt + + def _update_fuzz_target_and_build_script( + self, response: str, build_result: BuildResult + ) -> None: + """Updates fuzz target and build script in build_result with LLM response.""" + fuzz_target_source = self._filter_code(self._parse_tag(response, "fuzz target")) + build_result.fuzz_target_source = fuzz_target_source + + build_script_source = self._filter_code( + self._parse_tag(response, "build script") + ) + # Sometimes LLM adds chronos, which makes no sense for new build scripts. 
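# Illustrative sketch (separate from the diff): a worked example of the
# chronos clean-up performed by the statement below, using a made-up build
# script string.
raw_script = "source /src/chronos.sh\n./configure && make -j8"
cleaned_script = raw_script.replace("source /src/chronos.sh", "")
assert cleaned_script == "\n./configure && make -j8"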
+ build_result.build_script_source = build_script_source.replace( + "source /src/chronos.sh", "" + ) + + def _update_build_result( + self, + build_result: BuildResult, + compile_process: sp.CompletedProcess, + compiles: bool, + binary_exists: bool, + referenced: bool, + ) -> None: + """Updates the build result with the latest info.""" + build_result.compiles = compiles + build_result.compile_error = compile_process.stderr + build_result.compile_log = self._format_bash_execution_result(compile_process) + build_result.binary_exists = binary_exists + build_result.is_function_referenced = referenced + + def _validate_fuzz_target_and_build_script( + self, cur_round: int, build_result: BuildResult + ) -> tuple[Optional[BuildResult], Optional[BuildResult]]: + """Validates the new fuzz target and build script.""" + # Steps: + # 1. Recompile without modifying the build script, in case LLM is wrong. + # 2. Recompile with the modified build script, if any. + build_result_alt = None + if build_result.build_script_source: + build_result_alt = copy.deepcopy(build_result) + logger.info( + "First compile fuzz target without modifying build script.", + trial=build_result_alt.trial, + ) + build_result_alt.build_script_source = "" + self._validate_fuzz_target_and_build_script_via_compile( + cur_round, build_result_alt + ) + + # No need to run expensive build_result, when *_alt is perfect. + if build_result_alt and build_result_alt.success: + return build_result_alt, None + + # New fuzz target + has new build.sh. + logger.info( + "Compile fuzz target with modified build script.", trial=build_result.trial + ) + self._validate_fuzz_target_and_build_script_via_compile(cur_round, build_result) + + # Although build_result_alt is not perfect, LLM may still learn from it. + return build_result_alt, build_result + + def _validate_fuzz_target_references_function( + self, + compilation_tool: ProjectContainerTool, + benchmark: Benchmark, + cur_round: int, + trial: int, + ) -> bool: + """Validates if the LLM generated fuzz target assembly code references + function-under-test.""" + + # LLVMFuzzerTestOneInput and binary dumps are only valid + # for C/C++ projects. + # Temporary skipping this check for other language. + if benchmark.language in ["jvm", "python", "rust"]: + return True + + disassemble_result = compilation_tool.execute( + "objdump --disassemble=LLVMFuzzerTestOneInput -d " + f"/out/{benchmark.target_name}" + ) + function_referenced = ( + disassemble_result.returncode == 0 + and benchmark.function_name in disassemble_result.stdout + ) + logger.debug( + "ROUND %02d Final fuzz target function referenced: %s", + cur_round, + function_referenced, + trial=trial, + ) + if not function_referenced: + logger.debug( + "ROUND %02d Final fuzz target function not referenced", cur_round, - trial=build_result.trial) - start_time = time.time() - compile_process = compilation_tool.compile() - end_time = time.time() - logger.debug('ROUND %02d compilation time: %s', - cur_round, - timedelta(seconds=end_time - start_time), - trial=build_result.trial) - compile_succeed = compile_process.returncode == 0 - logger.debug('ROUND %02d Fuzz target compiles: %s', - cur_round, - compile_succeed, - trial=build_result.trial) - - # Double-check binary. 
- ls_result = compilation_tool.execute(f'ls /out/{benchmark.target_name}') - binary_exists = ls_result.returncode == 0 - logger.debug('ROUND %02d Final fuzz target binary exists: %s', - cur_round, - binary_exists, - trial=build_result.trial) - - # Validate if function-under-test is referenced by the fuzz target. - function_referenced = self._validate_fuzz_target_references_function( - compilation_tool, benchmark, cur_round, build_result.trial) - - compilation_tool.terminate() - self._update_build_result(build_result, - compile_process=compile_process, - compiles=compile_succeed, - binary_exists=binary_exists, - referenced=function_referenced) - - def _generate_prompt_from_build_result( - self, build_result_alt: Optional[BuildResult], - build_result_ori: Optional[BuildResult], build_result: BuildResult, - prompt: Prompt, cur_round: int) -> tuple[BuildResult, Optional[Prompt]]: - """Selects which build result to use and generates a prompt accordingly.""" - - # Case 1: Successful. - if build_result_alt and build_result_alt.success: - # Preference 1: New fuzz target + default build.sh can compile, save - # binary to expected path, and reference function-under-test. - logger.info( - 'Default /src/build.sh works perfectly, no need for a new ' - 'buid script', - trial=build_result.trial) - logger.info('***** %s succeeded in %02d rounds *****', - self.name, - cur_round, - trial=build_result.trial) - return build_result_alt, None - - if build_result_ori and build_result_ori.success: - # Preference 2: New fuzz target + new build.sh can compile, save - # binary to expected path, and reference function-under-test. - logger.info('***** %s succeeded in %02d rounds *****', - self.name, - cur_round, - trial=build_result.trial) - return build_result_ori, None - - # Case 2: Binary exits, meaning not referencing function-under-test. - function_signature = build_result.benchmark.function_signature - fuzz_target_source = build_result.fuzz_target_source - build_script_source = build_result.build_script_source - compile_log = self.llm.truncate_prompt(build_result.compile_log, - extra_text=prompt.get()).strip() - prompt_text = ( - "The fuzz target's `LLVMFuzzerTestOneInput` did not invoke the " - f'function-under-test `{function_signature}`:\n' - f'\n{fuzz_target_source}\n\n' - '{BUILD_TEXT}\n' - f'\n{compile_log}\n\n' - 'That is NOT enough. YOU MUST MODIFY THE FUZZ TARGET to CALL ' - f'FUNCTION `{function_signature}` **EXPLICITLY OR IMPLICITLY** in ' - '`LLVMFuzzerTestOneInput` to generate a valid fuzz target.\nStudy the ' - 'source code for function usages to know how.\n') - if build_result_alt and build_result_alt.binary_exists: - # Preference 3: New fuzz target + default build.sh can compile and save - # binary to expected path, but does not reference function-under-test. - prompt_text = prompt_text.replace( - '{BUILD_TEXT}', - 'Althoug `/src/build.bk.sh` compiles and saves the binary to the ' - 'correct path:') - # NOTE: Unsafe to say the following, because /src/build.sh may miss a - # library required by the function-under-test, and the fuzz target did not - # invoke the function-under-test either. - # prompt_text += ( - # 'In addition, given the default /src/build.sh works perfectly, you ' - # 'do not have to generate a new build script and can leave ' - # ' empty.') - prompt_text += ( - 'When you have a solution later, make sure you output the FULL fuzz ' - 'target. YOU MUST NOT OMIT ANY CODE even if it is the same as before.' 
- '\n') - prompt.append(prompt_text) - return build_result_alt, prompt - if (build_result_ori and build_result_ori.binary_exists and - not build_result_ori.build_script_source): - # Preference 4.1: New fuzz target + default build.sh can compile and save - # binary to expected path, but does not reference function-under-test. - prompt_text = prompt_text.replace( - '{BUILD_TEXT}', - 'Althoug `/src/build.bk.sh` compiles and saves the binary to the ' - 'correct path:') - prompt_text += ( - 'When you have a solution later, make sure you output the FULL fuzz ' - 'target. YOU MUST NOT OMIT ANY CODE even if it is the same as before.' - '\n') - prompt.append(prompt_text) - return build_result_ori, prompt - if build_result_ori and build_result_ori.binary_exists: - # Preference 4.2: New fuzz target + New build.sh can compile and save - # binary to expected path, but does not reference function-under-test. - prompt_text = prompt_text.replace( - '{BUILD_TEXT}', - 'Althoug your build script compiles and saves the binary to the ' - 'correct path:\n' - f'\n{build_script_source}\n\n') - prompt_text += ( - 'When you have a solution later, make sure you output the FULL fuzz ' - 'target (and the FULL build script, if any). YOU MUST NOT OMIT ANY ' - 'CODE even if it is the same as before.\n') - prompt.append(prompt_text) - return build_result_ori, prompt - - # Case 3: Compiles, meaning the binary is not saved. - binary_path = os.path.join('/out', build_result.benchmark.target_name) - if (build_result_ori and build_result_ori.compiles and - build_result_ori.build_script_source): - # Preference 5.1: New fuzz target + new build.sh can compile, but does - # not save binary to expected path. - prompt_text = ( - 'The fuzz target and build script compiles successfully, but the ' - 'final fuzz target binary was not saved to the expected path at ' - f'`{binary_path}`.\n' - f'\n{fuzz_target_source}\n\n' - f'\n{build_script_source}\n\n' - f'\n{compile_log}\n\n' - 'YOU MUST MODIFY THE BUILD SCRIPT to ensure the binary is saved to ' - f'{binary_path}.\n') - prompt_text += ( - 'When you have a solution later, make sure you output the FULL fuzz ' - 'target (and the FULL build script, if any). YOU MUST NOT OMIT ANY ' - 'CODE even if it is the same as before.\n') - prompt.append(prompt_text) - return build_result_ori, prompt - if (build_result_ori and build_result_ori.compiles and - not build_result_ori.build_script_source): - # Preference 5.2: New fuzz target + default build.sh can compile, but does - # not save binary to expected path, indicating benchmark data error. - logger.error( - 'The human-written build.sh does not save the fuzz target binary to ' - 'expected path /out/%s, indicating incorrect info in benchmark YAML.', - build_result.benchmark.target_name, - trial=build_result.trial) - prompt_text = ( - 'The fuzz target compiles successfully with /src/build.bk.sh, but the' - ' final fuzz target binary was not saved to the expected path at ' - f'`{binary_path}`.\n' - f'\n{fuzz_target_source}\n\n' - f'\n{compile_log}\n\n' - 'YOU MUST MODIFY THE BUILD SCRIPT to ensure the binary is saved to ' - f'{binary_path}.\n') - prompt_text += ( - 'When you have a solution later, make sure you output the FULL fuzz ' - 'target (and the FULL build script, if any). 
YOU MUST NOT OMIT ANY ' - 'CODE even if it is the same as before.\n') - prompt.append(prompt_text) - return build_result_ori, prompt - if build_result_alt and build_result_alt.compiles: - # Preference 6: New fuzz target + default build.sh can compile, but does - # not save binary to expected path, indicating benchmark data error. - logger.error( - 'The human-written build.sh does not save the fuzz target binary to ' - 'expected path /out/%s, indicating incorrect info in benchmark YAML.', - build_result.benchmark.target_name, - trial=build_result.trial) - prompt_text = ( - 'The fuzz target compiles successfully with /src/build.bk.sh, but the' - ' final fuzz target binary was not saved to the expected path at ' - f'`{binary_path}`.\n' - f'\n{fuzz_target_source}\n\n' - f'\n{compile_log}\n\n' - 'YOU MUST MODIFY THE BUILD SCRIPT to ensure the binary is saved to ' - f'{binary_path}.\n') - prompt_text += ( - 'When you have a solution later, make sure you output the FULL fuzz ' - 'target (and the FULL build script, if any). YOU MUST NOT OMIT ANY ' - 'CODE even if it is the same as before.\n') - prompt.append(prompt_text) - return build_result_alt, prompt - - # Preference 7: New fuzz target + both `build.sh`s cannot compile. No need - # to mention the default build.sh. - # return build_result - builder = prompt_builder.PrototyperFixerTemplateBuilder( - model=self.llm, - benchmark=build_result.benchmark, - build_result=build_result, - compile_log=compile_log, - initial=prompt.get()) - prompt = builder.build(example_pair=[], - project_dir=self.inspect_tool.project_dir) - return build_result, prompt - - def _container_handle_conclusion(self, cur_round: int, response: str, - build_result: BuildResult, - prompt: Prompt) -> Optional[Prompt]: - """Runs a compilation tool to validate the new fuzz target and build script - from LLM.""" - if not self._parse_tag(response, 'fuzz target'): - return prompt - logger.info('----- ROUND %02d Received conclusion -----', + trial=trial, + ) + return function_referenced + + def _validate_fuzz_target_and_build_script_via_compile( + self, cur_round: int, build_result: BuildResult + ) -> None: + """Validates the new fuzz target and build script by recompiling them.""" + benchmark = build_result.benchmark + compilation_tool = ProjectContainerTool(benchmark=benchmark) + + # Replace fuzz target and build script in the container. + compilation_tool.write_to_file( + content=build_result.fuzz_target_source, file_path=benchmark.target_path + ) + if build_result.build_script_source: + compilation_tool.write_to_file( + content=build_result.build_script_source, + file_path=compilation_tool.build_script_path, + ) + + # Recompile. + logger.info( + "===== ROUND %02d Recompile =====", cur_round, trial=build_result.trial + ) + start_time = time.time() + compile_process = compilation_tool.compile() + end_time = time.time() + logger.debug( + "ROUND %02d compilation time: %s", + cur_round, + timedelta(seconds=end_time - start_time), + trial=build_result.trial, + ) + compile_succeed = compile_process.returncode == 0 + logger.debug( + "ROUND %02d Fuzz target compiles: %s", + cur_round, + compile_succeed, + trial=build_result.trial, + ) + + # Double-check binary. + ls_result = compilation_tool.execute(f"ls /out/{benchmark.target_name}") + binary_exists = ls_result.returncode == 0 + logger.debug( + "ROUND %02d Final fuzz target binary exists: %s", + cur_round, + binary_exists, + trial=build_result.trial, + ) + + # Validate if function-under-test is referenced by the fuzz target. 
+ function_referenced = self._validate_fuzz_target_references_function( + compilation_tool, benchmark, cur_round, build_result.trial + ) + + compilation_tool.terminate() + self._update_build_result( + build_result, + compile_process=compile_process, + compiles=compile_succeed, + binary_exists=binary_exists, + referenced=function_referenced, + ) + + def _generate_prompt_from_build_result( + self, + build_result_alt: Optional[BuildResult], + build_result_ori: Optional[BuildResult], + build_result: BuildResult, + prompt: Prompt, + cur_round: int, + ) -> tuple[BuildResult, Optional[Prompt]]: + """Selects which build result to use and generates a prompt accordingly.""" + + # Case 1: Successful. + if build_result_alt and build_result_alt.success: + # Preference 1: New fuzz target + default build.sh can compile, save + # binary to expected path, and reference function-under-test. + logger.info( + "Default /src/build.sh works perfectly, no need for a new " + "buid script", + trial=build_result.trial, + ) + logger.info( + "***** %s succeeded in %02d rounds *****", + self.name, cur_round, - trial=build_result.trial) - - self._update_fuzz_target_and_build_script(response, build_result) - - build_result_alt, build_result_ori = ( - self._validate_fuzz_target_and_build_script(cur_round, build_result)) - - # Updates build_result with _alt or _ori, depending on their status. - final_build_result, prompt_final = self._generate_prompt_from_build_result( - build_result_alt, build_result_ori, build_result, prompt, cur_round) - # Ensure build_result is consistent with the one selected. - if final_build_result is not None: - build_result.__dict__.update(final_build_result.__dict__) - - return prompt_final - - def _container_tool_reaction(self, cur_round: int, response: str, - build_result: BuildResult) -> Optional[Prompt]: - """Validates LLM conclusion or executes its command.""" - prompt = prompt_builder.DefaultTemplateBuilder(self.llm, None).build([]) - - if response: - prompt = self._container_handle_bash_commands(response, self.inspect_tool, - prompt) - - # Then build fuzz target. - prompt = self._container_handle_conclusion(cur_round, response, - build_result, prompt) - if prompt is None: - # Succeeded. - return None - - # Finally check invalid responses. 
- if not response or not prompt.get(): - prompt = self._container_handle_invalid_tool_usage( - self.inspect_tool, cur_round, response, prompt) - - return prompt - - def execute(self, result_history: list[Result]) -> BuildResult: - """Executes the agent based on previous result.""" - # Use keep to avoid deleting files, such as benchmark.yaml - WorkDirs(self.args.work_dirs.base, keep=True) - last_result = result_history[-1] - logger.info('Executing %s', self.name, trial=last_result.trial) - benchmark = last_result.benchmark - self.inspect_tool = ProjectContainerTool(benchmark, name='inspect') - self.inspect_tool.compile(extra_commands=' && rm -rf /out/* > /dev/null') - cur_round = 1 - build_result = BuildResult(benchmark=benchmark, - trial=last_result.trial, - work_dirs=last_result.work_dirs, - author=self, - chat_history={self.name: ''}) - prompt = self._initial_prompt(result_history) - try: - client = self.llm.get_chat_client(model=self.llm.get_model()) - while prompt and cur_round < self.max_round: - response = self.chat_llm(cur_round, - client=client, - prompt=prompt, - trial=last_result.trial) - prompt = self._container_tool_reaction(cur_round, response, - build_result) - cur_round += 1 - finally: - # Cleanup: stop and remove the container - logger.debug('Stopping and removing the inspect container %s', - self.inspect_tool.container_id, - trial=last_result.trial) - self.inspect_tool.terminate() - return build_result + trial=build_result.trial, + ) + return build_result_alt, None + + if build_result_ori and build_result_ori.success: + # Preference 2: New fuzz target + new build.sh can compile, save + # binary to expected path, and reference function-under-test. + logger.info( + "***** %s succeeded in %02d rounds *****", + self.name, + cur_round, + trial=build_result.trial, + ) + return build_result_ori, None + + # Case 2: Binary exits, meaning not referencing function-under-test. + function_signature = build_result.benchmark.function_signature + fuzz_target_source = build_result.fuzz_target_source + build_script_source = build_result.build_script_source + compile_log = self.llm.truncate_prompt( + build_result.compile_log, extra_text=prompt.get() + ).strip() + prompt_text = ( + "The fuzz target's `LLVMFuzzerTestOneInput` did not invoke the " + f"function-under-test `{function_signature}`:\n" + f"\n{fuzz_target_source}\n\n" + "{BUILD_TEXT}\n" + f"\n{compile_log}\n\n" + "That is NOT enough. YOU MUST MODIFY THE FUZZ TARGET to CALL " + f"FUNCTION `{function_signature}` **EXPLICITLY OR IMPLICITLY** in " + "`LLVMFuzzerTestOneInput` to generate a valid fuzz target.\nStudy the " + "source code for function usages to know how.\n" + ) + if build_result_alt and build_result_alt.binary_exists: + # Preference 3: New fuzz target + default build.sh can compile and save + # binary to expected path, but does not reference function-under-test. + prompt_text = prompt_text.replace( + "{BUILD_TEXT}", + "Althoug `/src/build.bk.sh` compiles and saves the binary to the " + "correct path:", + ) + # NOTE: Unsafe to say the following, because /src/build.sh may miss a + # library required by the function-under-test, and the fuzz target did not + # invoke the function-under-test either. + # prompt_text += ( + # 'In addition, given the default /src/build.sh works perfectly, you ' + # 'do not have to generate a new build script and can leave ' + # ' empty.') + prompt_text += ( + "When you have a solution later, make sure you output the FULL fuzz " + "target. YOU MUST NOT OMIT ANY CODE even if it is the same as before." 
+ "\n" + ) + prompt.append(prompt_text) + return build_result_alt, prompt + if ( + build_result_ori + and build_result_ori.binary_exists + and not build_result_ori.build_script_source + ): + # Preference 4.1: New fuzz target + default build.sh can compile and save + # binary to expected path, but does not reference function-under-test. + prompt_text = prompt_text.replace( + "{BUILD_TEXT}", + "Althoug `/src/build.bk.sh` compiles and saves the binary to the " + "correct path:", + ) + prompt_text += ( + "When you have a solution later, make sure you output the FULL fuzz " + "target. YOU MUST NOT OMIT ANY CODE even if it is the same as before." + "\n" + ) + prompt.append(prompt_text) + return build_result_ori, prompt + if build_result_ori and build_result_ori.binary_exists: + # Preference 4.2: New fuzz target + New build.sh can compile and save + # binary to expected path, but does not reference function-under-test. + prompt_text = prompt_text.replace( + "{BUILD_TEXT}", + "Althoug your build script compiles and saves the binary to the " + "correct path:\n" + f"\n{build_script_source}\n\n", + ) + prompt_text += ( + "When you have a solution later, make sure you output the FULL fuzz " + "target (and the FULL build script, if any). YOU MUST NOT OMIT ANY " + "CODE even if it is the same as before.\n" + ) + prompt.append(prompt_text) + return build_result_ori, prompt + + # Case 3: Compiles, meaning the binary is not saved. + binary_path = os.path.join("/out", build_result.benchmark.target_name) + if ( + build_result_ori + and build_result_ori.compiles + and build_result_ori.build_script_source + ): + # Preference 5.1: New fuzz target + new build.sh can compile, but does + # not save binary to expected path. + prompt_text = ( + "The fuzz target and build script compiles successfully, but the " + "final fuzz target binary was not saved to the expected path at " + f"`{binary_path}`.\n" + f"\n{fuzz_target_source}\n\n" + f"\n{build_script_source}\n\n" + f"\n{compile_log}\n\n" + "YOU MUST MODIFY THE BUILD SCRIPT to ensure the binary is saved to " + f"{binary_path}.\n" + ) + prompt_text += ( + "When you have a solution later, make sure you output the FULL fuzz " + "target (and the FULL build script, if any). YOU MUST NOT OMIT ANY " + "CODE even if it is the same as before.\n" + ) + prompt.append(prompt_text) + return build_result_ori, prompt + if ( + build_result_ori + and build_result_ori.compiles + and not build_result_ori.build_script_source + ): + # Preference 5.2: New fuzz target + default build.sh can compile, but does + # not save binary to expected path, indicating benchmark data error. + logger.error( + "The human-written build.sh does not save the fuzz target binary to " + "expected path /out/%s, indicating incorrect info in benchmark YAML.", + build_result.benchmark.target_name, + trial=build_result.trial, + ) + prompt_text = ( + "The fuzz target compiles successfully with /src/build.bk.sh, but the" + " final fuzz target binary was not saved to the expected path at " + f"`{binary_path}`.\n" + f"\n{fuzz_target_source}\n\n" + f"\n{compile_log}\n\n" + "YOU MUST MODIFY THE BUILD SCRIPT to ensure the binary is saved to " + f"{binary_path}.\n" + ) + prompt_text += ( + "When you have a solution later, make sure you output the FULL fuzz " + "target (and the FULL build script, if any). 
YOU MUST NOT OMIT ANY " + "CODE even if it is the same as before.\n" + ) + prompt.append(prompt_text) + return build_result_ori, prompt + if build_result_alt and build_result_alt.compiles: + # Preference 6: New fuzz target + default build.sh can compile, but does + # not save binary to expected path, indicating benchmark data error. + logger.error( + "The human-written build.sh does not save the fuzz target binary to " + "expected path /out/%s, indicating incorrect info in benchmark YAML.", + build_result.benchmark.target_name, + trial=build_result.trial, + ) + prompt_text = ( + "The fuzz target compiles successfully with /src/build.bk.sh, but the" + " final fuzz target binary was not saved to the expected path at " + f"`{binary_path}`.\n" + f"\n{fuzz_target_source}\n\n" + f"\n{compile_log}\n\n" + "YOU MUST MODIFY THE BUILD SCRIPT to ensure the binary is saved to " + f"{binary_path}.\n" + ) + prompt_text += ( + "When you have a solution later, make sure you output the FULL fuzz " + "target (and the FULL build script, if any). YOU MUST NOT OMIT ANY " + "CODE even if it is the same as before.\n" + ) + prompt.append(prompt_text) + return build_result_alt, prompt + + # Preference 7: New fuzz target + both `build.sh`s cannot compile. No need + # to mention the default build.sh. + # return build_result + builder = prompt_builder.PrototyperFixerTemplateBuilder( + model=self.llm, + benchmark=build_result.benchmark, + build_result=build_result, + compile_log=compile_log, + initial=prompt.get(), + ) + prompt = builder.build( + example_pair=[], project_dir=self.inspect_tool.project_dir + ) + return build_result, prompt + + def _container_handle_conclusion( + self, cur_round: int, response: str, build_result: BuildResult, prompt: Prompt + ) -> Optional[Prompt]: + """Runs a compilation tool to validate the new fuzz target and build script + from LLM.""" + if not self._parse_tag(response, "fuzz target"): + return prompt + logger.info( + "----- ROUND %02d Received conclusion -----", + cur_round, + trial=build_result.trial, + ) + + self._update_fuzz_target_and_build_script(response, build_result) + + build_result_alt, build_result_ori = ( + self._validate_fuzz_target_and_build_script(cur_round, build_result) + ) + + # Updates build_result with _alt or _ori, depending on their status. + final_build_result, prompt_final = self._generate_prompt_from_build_result( + build_result_alt, build_result_ori, build_result, prompt, cur_round + ) + # Ensure build_result is consistent with the one selected. + if final_build_result is not None: + build_result.__dict__.update(final_build_result.__dict__) + + return prompt_final + + def _container_tool_reaction( + self, cur_round: int, response: str, build_result: BuildResult + ) -> Optional[Prompt]: + """Validates LLM conclusion or executes its command.""" + prompt = prompt_builder.DefaultTemplateBuilder(self.llm, None).build([]) + + if response: + prompt = self._container_handle_bash_commands( + response, self.inspect_tool, prompt + ) + + # Then build fuzz target. + prompt = self._container_handle_conclusion( + cur_round, response, build_result, prompt + ) + if prompt is None: + # Succeeded. + return None + + # Finally check invalid responses. 
+ if not response or not prompt.get(): + prompt = self._container_handle_invalid_tool_usage( + self.inspect_tool, cur_round, response, prompt + ) + + return prompt + + def execute(self, result_history: list[Result]) -> BuildResult: + """Executes the agent based on previous result.""" + # Use keep to avoid deleting files, such as benchmark.yaml + WorkDirs(self.args.work_dirs.base, keep=True) + last_result = result_history[-1] + logger.info("Executing %s", self.name, trial=last_result.trial) + benchmark = last_result.benchmark + self.inspect_tool = ProjectContainerTool(benchmark, name="inspect") + self.inspect_tool.compile(extra_commands=" && rm -rf /out/* > /dev/null") + cur_round = 1 + build_result = BuildResult( + benchmark=benchmark, + trial=last_result.trial, + work_dirs=last_result.work_dirs, + author=self, + chat_history={self.name: ""}, + ) + prompt = self._initial_prompt(result_history) + try: + client = self.llm.get_chat_client(model=self.llm.get_model()) + while prompt and cur_round < self.max_round: + response = self.chat_llm( + cur_round, client=client, prompt=prompt, trial=last_result.trial + ) + prompt = self._container_tool_reaction( + cur_round, response, build_result + ) + cur_round += 1 + finally: + # Cleanup: stop and remove the container + logger.debug( + "Stopping and removing the inspect container %s", + self.inspect_tool.container_id, + trial=last_result.trial, + ) + self.inspect_tool.terminate() + return build_result diff --git a/agent/semantic_analyzer.py b/agent/semantic_analyzer.py index f92e0637f3..ff052dc780 100644 --- a/agent/semantic_analyzer.py +++ b/agent/semantic_analyzer.py @@ -25,270 +25,344 @@ from results import AnalysisResult, Result, RunResult # Regex for extract function name. -FUNC_NAME = re.compile(r'(?:^|\s|\b)([\w:]+::)*(\w+)(?:<[^>]*>)?(?=\(|$)') +FUNC_NAME = re.compile(r"(?:^|\s|\b)([\w:]+::)*(\w+)(?:<[^>]*>)?(?=\(|$)") # Regex for extract line number, -LINE_NUMBER = re.compile(r':(\d+):') +LINE_NUMBER = re.compile(r":(\d+):") LIBFUZZER_MODULES_LOADED_REGEX = re.compile( - r'^INFO:\s+Loaded\s+\d+\s+(modules|PC tables)\s+\((\d+)\s+.*\).*') -LIBFUZZER_COV_REGEX = re.compile(r'.*cov: (\d+) ft:') -LIBFUZZER_CRASH_TYPE_REGEX = re.compile(r'.*Test unit written to.*') -LIBFUZZER_COV_LINE_PREFIX = re.compile(r'^#(\d+)') -LIBFUZZER_STACK_FRAME_LINE_PREFIX = re.compile(r'^\s+#\d+') -CRASH_EXCLUSIONS = re.compile(r'.*(slow-unit-|timeout-|leak-|oom-).*') -CRASH_STACK_WITH_SOURCE_INFO = re.compile(r'in.*:\d+:\d+$') - -LIBFUZZER_LOG_STACK_FRAME_LLVM = '/src/llvm-project/compiler-rt' -LIBFUZZER_LOG_STACK_FRAME_LLVM2 = '/work/llvm-stage2/projects/compiler-rt' -LIBFUZZER_LOG_STACK_FRAME_CPP = '/usr/local/bin/../include/c++' + r"^INFO:\s+Loaded\s+\d+\s+(modules|PC tables)\s+\((\d+)\s+.*\).*" +) +LIBFUZZER_COV_REGEX = re.compile(r".*cov: (\d+) ft:") +LIBFUZZER_CRASH_TYPE_REGEX = re.compile(r".*Test unit written to.*") +LIBFUZZER_COV_LINE_PREFIX = re.compile(r"^#(\d+)") +LIBFUZZER_STACK_FRAME_LINE_PREFIX = re.compile(r"^\s+#\d+") +CRASH_EXCLUSIONS = re.compile(r".*(slow-unit-|timeout-|leak-|oom-).*") +CRASH_STACK_WITH_SOURCE_INFO = re.compile(r"in.*:\d+:\d+$") + +LIBFUZZER_LOG_STACK_FRAME_LLVM = "/src/llvm-project/compiler-rt" +LIBFUZZER_LOG_STACK_FRAME_LLVM2 = "/work/llvm-stage2/projects/compiler-rt" +LIBFUZZER_LOG_STACK_FRAME_CPP = "/usr/local/bin/../include/c++" EARLY_FUZZING_ROUND_THRESHOLD = 3 ParseResult = namedtuple( - 'ParseResult', - ['cov_pcs', 'total_pcs', 'crashes', 'crash_info', 'semantic_check_result']) + "ParseResult", + ["cov_pcs", "total_pcs", 
"crashes", "crash_info", "semantic_check_result"], +) class SemanticAnalyzer(BaseAgent): - """The Agent to generate a simple but valid fuzz target from scratch.""" - - def _initial_prompt(self, results: list[Result]) -> Prompt: - """Constructs initial prompt of the agent.""" - del results - return Prompt() - - def execute(self, result_history: list[Result]) -> AnalysisResult: - """Executes the agent based on previous result.""" - last_result = result_history[-1] - assert isinstance(last_result, RunResult) - - _, _, _, _, semantic_result = self._parse_libfuzzer_logs( - last_result.run_log, last_result.benchmark.project) - - analysis_result = AnalysisResult( - author=self, - run_result=last_result, - semantic_result=semantic_result, - chat_history={self.name: semantic_result.to_dict()}) - return analysis_result - - def _parse_libfuzzer_logs(self, - fuzzlog, - project_name: str, - check_cov_increase: bool = True) -> ParseResult: - """Parses libFuzzer logs.""" - lines = None - # Some crashes can mess up the libfuzzer output and raise decode error. - lines = fuzzlog.split('\n') - - cov_pcs, total_pcs, crashes = 0, 0, False - - for line in lines: - m = LIBFUZZER_MODULES_LOADED_REGEX.match(line) - if m: - total_pcs = int(m.group(2)) - continue - - m = LIBFUZZER_COV_REGEX.match(line) - if m: - cov_pcs = int(m.group(1)) - continue - - m = LIBFUZZER_CRASH_TYPE_REGEX.match(line) - if m and not CRASH_EXCLUSIONS.match(line): - # TODO(@happy-qop): Handling oom, slow cases in semantic checks & fix. - crashes = True - continue - - initcov, donecov, lastround = self._parse_fuzz_cov_info_from_libfuzzer_logs( - lines) - - # NOTE: Crashes from incorrect fuzz targets will not be counted finally. - - if crashes: - symptom = SemanticCheckResult.extract_symptom(fuzzlog) - crash_stacks = self._parse_stacks_from_libfuzzer_logs(lines) - crash_func = self._parse_func_from_stacks(project_name, crash_stacks) - crash_info = SemanticCheckResult.extract_crash_info(fuzzlog) - - # FP case 1: Common fuzz target errors. - # Null-deref, normally indicating inadequate parameter initialization or - # wrong function usage. - if symptom == 'null-deref': - return ParseResult( - cov_pcs, total_pcs, True, crash_info, - SemanticCheckResult(SemanticCheckResult.NULL_DEREF, symptom, - crash_stacks, crash_func)) - - # Signal, normally indicating assertion failure due to inadequate - # parameter initialization or wrong function usage. 
- if symptom == 'signal': - return ParseResult( - cov_pcs, total_pcs, True, crash_info, - SemanticCheckResult(SemanticCheckResult.SIGNAL, symptom, - crash_stacks, crash_func)) + """The Agent to generate a simple but valid fuzz target from scratch.""" + + def _initial_prompt(self, results: list[Result]) -> Prompt: + """Constructs initial prompt of the agent.""" + del results + return Prompt() + + def execute(self, result_history: list[Result]) -> AnalysisResult: + """Executes the agent based on previous result.""" + last_result = result_history[-1] + assert isinstance(last_result, RunResult) + + _, _, _, _, semantic_result = self._parse_libfuzzer_logs( + last_result.run_log, last_result.benchmark.project + ) + + analysis_result = AnalysisResult( + author=self, + run_result=last_result, + semantic_result=semantic_result, + chat_history={self.name: semantic_result.to_dict()}, + ) + return analysis_result + + def _parse_libfuzzer_logs( + self, fuzzlog, project_name: str, check_cov_increase: bool = True + ) -> ParseResult: + """Parses libFuzzer logs.""" + lines = None + # Some crashes can mess up the libfuzzer output and raise decode error. + lines = fuzzlog.split("\n") + + cov_pcs, total_pcs, crashes = 0, 0, False + + for line in lines: + m = LIBFUZZER_MODULES_LOADED_REGEX.match(line) + if m: + total_pcs = int(m.group(2)) + continue + + m = LIBFUZZER_COV_REGEX.match(line) + if m: + cov_pcs = int(m.group(1)) + continue + + m = LIBFUZZER_CRASH_TYPE_REGEX.match(line) + if m and not CRASH_EXCLUSIONS.match(line): + # TODO(@happy-qop): Handling oom, slow cases in semantic checks & fix. + crashes = True + continue + + initcov, donecov, lastround = self._parse_fuzz_cov_info_from_libfuzzer_logs( + lines + ) + + # NOTE: Crashes from incorrect fuzz targets will not be counted finally. + + if crashes: + symptom = SemanticCheckResult.extract_symptom(fuzzlog) + crash_stacks = self._parse_stacks_from_libfuzzer_logs(lines) + crash_func = self._parse_func_from_stacks(project_name, crash_stacks) + crash_info = SemanticCheckResult.extract_crash_info(fuzzlog) + + # FP case 1: Common fuzz target errors. + # Null-deref, normally indicating inadequate parameter initialization or + # wrong function usage. + if symptom == "null-deref": + return ParseResult( + cov_pcs, + total_pcs, + True, + crash_info, + SemanticCheckResult( + SemanticCheckResult.NULL_DEREF, + symptom, + crash_stacks, + crash_func, + ), + ) + + # Signal, normally indicating assertion failure due to inadequate + # parameter initialization or wrong function usage. + if symptom == "signal": + return ParseResult( + cov_pcs, + total_pcs, + True, + crash_info, + SemanticCheckResult( + SemanticCheckResult.SIGNAL, symptom, crash_stacks, crash_func + ), + ) + + # Exit, normally indicating the fuzz target exited in a controlled manner, + # blocking its bug discovery. + if symptom.endswith("fuzz target exited"): + return ParseResult( + cov_pcs, + total_pcs, + True, + crash_info, + SemanticCheckResult( + SemanticCheckResult.EXIT, symptom, crash_stacks, crash_func + ), + ) + + # Fuzz target modified constants. + if symptom.endswith("fuzz target overwrites its const input"): + return ParseResult( + cov_pcs, + total_pcs, + True, + crash_info, + SemanticCheckResult( + SemanticCheckResult.OVERWRITE_CONST, + symptom, + crash_stacks, + crash_func, + ), + ) + + # OOM, normally indicating malloc's parameter is too large, e.g., because + # of using parameter `size`. 
+ # TODO(dongge): Refine this, 1) Merge this with the other oom case found + # from reproducer name; 2) Capture the actual number in (malloc(\d+)). + if "out-of-memory" in symptom or "out of memory" in symptom: + return ParseResult( + cov_pcs, + total_pcs, + True, + crash_info, + SemanticCheckResult( + SemanticCheckResult.FP_OOM, symptom, crash_stacks, crash_func + ), + ) + + # FP case 2: fuzz target crashes at init or first few rounds. + if lastround is None or lastround <= EARLY_FUZZING_ROUND_THRESHOLD: + # No cov line has been identified or only INITED round has been passed. + # This is very likely the false positive cases. + return ParseResult( + cov_pcs, + total_pcs, + True, + crash_info, + SemanticCheckResult( + SemanticCheckResult.FP_NEAR_INIT_CRASH, + symptom, + crash_stacks, + crash_func, + ), + ) + + # FP case 3: no func in 1st thread stack belongs to testing proj. + if len(crash_stacks) > 0: + first_stack = crash_stacks[0] + for stack_frame in first_stack: + if self._stack_func_is_of_testing_project(stack_frame): + if "LLVMFuzzerTestOneInput" in stack_frame: + return ParseResult( + cov_pcs, + total_pcs, + True, + crash_info, + SemanticCheckResult( + SemanticCheckResult.FP_TARGET_CRASH, + symptom, + crash_stacks, + crash_func, + ), + ) + break + + return ParseResult( + cov_pcs, + total_pcs, + True, + crash_info, + SemanticCheckResult( + SemanticCheckResult.NO_SEMANTIC_ERR, + symptom, + crash_stacks, + crash_func, + ), + ) + + if check_cov_increase and initcov == donecov and lastround is not None: + # Another error fuzz target case: no cov increase. + # A special case is initcov == donecov == None, which indicates no + # interesting inputs were found. This may happen if the target rejected + # all inputs we tried. + return ParseResult( + cov_pcs, + total_pcs, + False, + "", + SemanticCheckResult(SemanticCheckResult.NO_COV_INCREASE), + ) - # Exit, normally indicating the fuzz target exited in a controlled manner, - # blocking its bug discovery. - if symptom.endswith('fuzz target exited'): - return ParseResult( - cov_pcs, total_pcs, True, crash_info, - SemanticCheckResult(SemanticCheckResult.EXIT, symptom, crash_stacks, - crash_func)) - - # Fuzz target modified constants. - if symptom.endswith('fuzz target overwrites its const input'): - return ParseResult( - cov_pcs, total_pcs, True, crash_info, - SemanticCheckResult(SemanticCheckResult.OVERWRITE_CONST, symptom, - crash_stacks, crash_func)) - - # OOM, normally indicating malloc's parameter is too large, e.g., because - # of using parameter `size`. - # TODO(dongge): Refine this, 1) Merge this with the other oom case found - # from reproducer name; 2) Capture the actual number in (malloc(\d+)). - if 'out-of-memory' in symptom or 'out of memory' in symptom: - return ParseResult( - cov_pcs, total_pcs, True, crash_info, - SemanticCheckResult(SemanticCheckResult.FP_OOM, symptom, - crash_stacks, crash_func)) - - # FP case 2: fuzz target crashes at init or first few rounds. - if lastround is None or lastround <= EARLY_FUZZING_ROUND_THRESHOLD: - # No cov line has been identified or only INITED round has been passed. - # This is very likely the false positive cases. return ParseResult( - cov_pcs, total_pcs, True, crash_info, - SemanticCheckResult(SemanticCheckResult.FP_NEAR_INIT_CRASH, symptom, - crash_stacks, crash_func)) - - # FP case 3: no func in 1st thread stack belongs to testing proj. 
- if len(crash_stacks) > 0: - first_stack = crash_stacks[0] - for stack_frame in first_stack: - if self._stack_func_is_of_testing_project(stack_frame): - if 'LLVMFuzzerTestOneInput' in stack_frame: - return ParseResult( - cov_pcs, total_pcs, True, crash_info, - SemanticCheckResult(SemanticCheckResult.FP_TARGET_CRASH, - symptom, crash_stacks, crash_func)) - break - - return ParseResult( - cov_pcs, total_pcs, True, crash_info, - SemanticCheckResult(SemanticCheckResult.NO_SEMANTIC_ERR, symptom, - crash_stacks, crash_func)) - - if check_cov_increase and initcov == donecov and lastround is not None: - # Another error fuzz target case: no cov increase. - # A special case is initcov == donecov == None, which indicates no - # interesting inputs were found. This may happen if the target rejected - # all inputs we tried. - return ParseResult( - cov_pcs, total_pcs, False, '', - SemanticCheckResult(SemanticCheckResult.NO_COV_INCREASE)) - - return ParseResult(cov_pcs, total_pcs, crashes, '', - SemanticCheckResult(SemanticCheckResult.NO_SEMANTIC_ERR)) - - def _parse_fuzz_cov_info_from_libfuzzer_logs( - self, - lines: list[str]) -> tuple[Optional[int], Optional[int], Optional[int]]: - """Parses cov of INITED & DONE, and round number from libFuzzer logs.""" - initcov, donecov, lastround = None, None, None - - for line in lines: - if line.startswith('#'): - # Parses cov line to get the round number. - match = LIBFUZZER_COV_LINE_PREFIX.match(line) - roundno = int(match.group(1)) if match else None - - if roundno is not None: - lastround = roundno - if 'INITED' in line and 'cov: ' in line: - initcov = int(line.split('cov: ')[1].split(' ft:')[0]) - elif 'DONE' in line and 'cov: ' in line: - donecov = int(line.split('cov: ')[1].split(' ft:')[0]) - - return initcov, donecov, lastround - - def _stack_func_is_of_testing_project(self, stack_frame: str) -> bool: - return (bool(CRASH_STACK_WITH_SOURCE_INFO.match(stack_frame)) and - LIBFUZZER_LOG_STACK_FRAME_LLVM not in stack_frame and - LIBFUZZER_LOG_STACK_FRAME_LLVM2 not in stack_frame and - LIBFUZZER_LOG_STACK_FRAME_CPP not in stack_frame) - - def _parse_stacks_from_libfuzzer_logs(self, - lines: list[str]) -> list[list[str]]: - """Parses stack traces from libFuzzer logs.""" - # TODO (dongge): Use stack parsing from ClusterFuzz. - # There can have over one thread stack in a log. - stacks = [] - - # A stack -> a sequence of stack frame lines. - stack, stack_parsing = [], False - for line in lines: - is_stack_frame_line = LIBFUZZER_STACK_FRAME_LINE_PREFIX.match( - line) is not None - if (not stack_parsing) and is_stack_frame_line: - # First line. - stack_parsing = True - stack = [line.strip()] - elif stack_parsing and is_stack_frame_line: - # Middle line(s). - stack.append(line.strip()) - elif stack_parsing and (not is_stack_frame_line): - # Last line. - stack_parsing = False - stacks.append(stack) - - # Last stack. - if stack_parsing: - stacks.append(stack) - - return stacks - - def _parse_func_from_stacks(self, project_name: str, - stacks: list[list[str]]) -> dict: - """Parses project functions from stack traces.""" - func_info = defaultdict(set) - - for stack in stacks: - for line in stack: - # Use 3 spaces to divide each line of crash info into four parts. - # Only parse the fourth part, which includes the function name, - # file path, and line number. 
- parts = line.split(' ', 3) - if len(parts) < 4: - continue - func_and_file_path = parts[3] - if project_name not in func_and_file_path: - continue - func_name, _, file_path = func_and_file_path.partition(' /') - if func_name == 'LLVMFuzzerTestOneInput': - line_match = LINE_NUMBER.search(file_path) - if line_match: - line_number = int(line_match.group(1)) - func_info[func_name].add(line_number) - else: - logger.warning('Failed to parse line number from %s in project %s', - func_name, - project_name, - trial=self.trial) - break - if project_name in file_path: - func_match = FUNC_NAME.search(func_name) - line_match = LINE_NUMBER.search(file_path) - if func_match and line_match: - func_name = func_match.group(2) - line_number = int(line_match.group(1)) - func_info[func_name].add(line_number) - else: - logger.warning( - 'Failed to parse function name from %s in project %s', - func_name, - project_name, - trial=self.trial) - - return { - func_name: list(line_numbers) - for func_name, line_numbers in func_info.items() - } + cov_pcs, + total_pcs, + crashes, + "", + SemanticCheckResult(SemanticCheckResult.NO_SEMANTIC_ERR), + ) + + def _parse_fuzz_cov_info_from_libfuzzer_logs( + self, lines: list[str] + ) -> tuple[Optional[int], Optional[int], Optional[int]]: + """Parses cov of INITED & DONE, and round number from libFuzzer logs.""" + initcov, donecov, lastround = None, None, None + + for line in lines: + if line.startswith("#"): + # Parses cov line to get the round number. + match = LIBFUZZER_COV_LINE_PREFIX.match(line) + roundno = int(match.group(1)) if match else None + + if roundno is not None: + lastround = roundno + if "INITED" in line and "cov: " in line: + initcov = int(line.split("cov: ")[1].split(" ft:")[0]) + elif "DONE" in line and "cov: " in line: + donecov = int(line.split("cov: ")[1].split(" ft:")[0]) + + return initcov, donecov, lastround + + def _stack_func_is_of_testing_project(self, stack_frame: str) -> bool: + return ( + bool(CRASH_STACK_WITH_SOURCE_INFO.match(stack_frame)) + and LIBFUZZER_LOG_STACK_FRAME_LLVM not in stack_frame + and LIBFUZZER_LOG_STACK_FRAME_LLVM2 not in stack_frame + and LIBFUZZER_LOG_STACK_FRAME_CPP not in stack_frame + ) + + def _parse_stacks_from_libfuzzer_logs(self, lines: list[str]) -> list[list[str]]: + """Parses stack traces from libFuzzer logs.""" + # TODO (dongge): Use stack parsing from ClusterFuzz. + # There can have over one thread stack in a log. + stacks = [] + + # A stack -> a sequence of stack frame lines. + stack, stack_parsing = [], False + for line in lines: + is_stack_frame_line = ( + LIBFUZZER_STACK_FRAME_LINE_PREFIX.match(line) is not None + ) + if (not stack_parsing) and is_stack_frame_line: + # First line. + stack_parsing = True + stack = [line.strip()] + elif stack_parsing and is_stack_frame_line: + # Middle line(s). + stack.append(line.strip()) + elif stack_parsing and (not is_stack_frame_line): + # Last line. + stack_parsing = False + stacks.append(stack) + + # Last stack. + if stack_parsing: + stacks.append(stack) + + return stacks + + def _parse_func_from_stacks( + self, project_name: str, stacks: list[list[str]] + ) -> dict: + """Parses project functions from stack traces.""" + func_info = defaultdict(set) + + for stack in stacks: + for line in stack: + # Use 3 spaces to divide each line of crash info into four parts. + # Only parse the fourth part, which includes the function name, + # file path, and line number. 
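+                # e.g. a stripped frame "#4 0xdeadbeef in ns::func /src/proj/file.cc:42:9"
+                # splits into parts[3] == "ns::func /src/proj/file.cc:42:9".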
+ parts = line.split(" ", 3) + if len(parts) < 4: + continue + func_and_file_path = parts[3] + if project_name not in func_and_file_path: + continue + func_name, _, file_path = func_and_file_path.partition(" /") + if func_name == "LLVMFuzzerTestOneInput": + line_match = LINE_NUMBER.search(file_path) + if line_match: + line_number = int(line_match.group(1)) + func_info[func_name].add(line_number) + else: + logger.warning( + "Failed to parse line number from %s in project %s", + func_name, + project_name, + trial=self.trial, + ) + break + if project_name in file_path: + func_match = FUNC_NAME.search(func_name) + line_match = LINE_NUMBER.search(file_path) + if func_match and line_match: + func_name = func_match.group(2) + line_number = int(line_match.group(1)) + func_info[func_name].add(line_number) + else: + logger.warning( + "Failed to parse function name from %s in project %s", + func_name, + project_name, + trial=self.trial, + ) + + return { + func_name: list(line_numbers) + for func_name, line_numbers in func_info.items() + } diff --git a/agent_tests/function_analyzer_test.py b/agent_tests/function_analyzer_test.py index 8b2c2af73f..fe04d63007 100644 --- a/agent_tests/function_analyzer_test.py +++ b/agent_tests/function_analyzer_test.py @@ -25,58 +25,54 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -RESULTS_DIR = './results' +RESULTS_DIR = "./results" def parse_args() -> argparse.Namespace: - """Parses command line arguments.""" - parser = argparse.ArgumentParser( - description='Evaluate the function analyzer agent.') + """Parses command line arguments.""" + parser = argparse.ArgumentParser( + description="Evaluate the function analyzer agent." + ) - parser.add_argument('-y', - '--benchmark-yaml', - type=str, - required=True, - help='A benchmark YAML file.') + parser.add_argument( + "-y", "--benchmark-yaml", type=str, required=True, help="A benchmark YAML file." + ) - parser.add_argument('-w', '--work-dir', default=RESULTS_DIR) + parser.add_argument("-w", "--work-dir", default=RESULTS_DIR) - parser.add_argument('-mr', - '--max-round', - type=int, - default=100, - help='Max trial round for agents.') + parser.add_argument( + "-mr", "--max-round", type=int, default=100, help="Max trial round for agents." 
+ ) - parsed_args = parser.parse_args() + parsed_args = parser.parse_args() - return parsed_args + return parsed_args if __name__ == "__main__": - model = models.LLM.setup(ai_binary='', name='vertex_ai_gemini-1-5-chat') + model = models.LLM.setup(ai_binary="", name="vertex_ai_gemini-1-5-chat") - args = parse_args() + args = parse_args() - function_analyzer = FunctionAnalyzer(trial=1, llm=model, args=args) + function_analyzer = FunctionAnalyzer(trial=1, llm=model, args=args) - benchmarks: List[Benchmark] = benchmarklib.Benchmark.from_yaml( - args.benchmark_yaml) + benchmarks: List[Benchmark] = benchmarklib.Benchmark.from_yaml(args.benchmark_yaml) - if len(benchmarks) == 0: - raise ValueError("No benchmarks found in the YAML file.") + if len(benchmarks) == 0: + raise ValueError("No benchmarks found in the YAML file.") - test_benchmark = benchmarks[0] - logger.info("Loaded benchmark for function: %s", test_benchmark.function_name) + test_benchmark = benchmarks[0] + logger.info("Loaded benchmark for function: %s", test_benchmark.function_name) - # Initialize the function analyzer with the first benchmark - function_analyzer.initialize(test_benchmark) + # Initialize the function analyzer with the first benchmark + function_analyzer.initialize(test_benchmark) - # Run the function analyzer - result = function_analyzer.execute([]) + # Run the function analyzer + result = function_analyzer.execute([]) - # Print the result - logger.info("Function Analyzer Result:") - logger.info("Result available: %s", result.result_available) - if result.result_available: - logger.info("Requirements: %s", result.requirements) + # Print the result + logger.info("Function Analyzer Result:") + logger.info("Result available: %s", result.result_available) + if result.result_available: + logger.info("Requirements: %s", result.requirements) diff --git a/ci/ci_trial_build.py b/ci/ci_trial_build.py index a840539d2d..4d099b8c9f 100644 --- a/ci/ci_trial_build.py +++ b/ci/ci_trial_build.py @@ -23,66 +23,66 @@ import github # type: ignore import request_pr_exp -TRIGGER_COMMAND = '/gcbrun' -TRIAL_BUILD_COMMAND_STR = f'{TRIGGER_COMMAND} exp ' -SKIP_COMMAND_STR = f'{TRIGGER_COMMAND} skip' +TRIGGER_COMMAND = "/gcbrun" +TRIAL_BUILD_COMMAND_STR = f"{TRIGGER_COMMAND} exp " +SKIP_COMMAND_STR = f"{TRIGGER_COMMAND} skip" def get_comments(pull_request_number): - """Returns comments on the GitHub Pull request referenced by - |pull_request_number|.""" - github_obj = github.Github() - repo = github_obj.get_repo('google/oss-fuzz-gen') - pull = repo.get_pull(pull_request_number) - pull_comments = list(pull.get_comments()) - issue = repo.get_issue(pull_request_number) - issue_comments = list(issue.get_comments()) - # Github only returns comments if from the pull object when a pull request is - # open. If it is a draft, it will only return comments from the issue object. - return pull_comments + issue_comments + """Returns comments on the GitHub Pull request referenced by + |pull_request_number|.""" + github_obj = github.Github() + repo = github_obj.get_repo("google/oss-fuzz-gen") + pull = repo.get_pull(pull_request_number) + pull_comments = list(pull.get_comments()) + issue = repo.get_issue(pull_request_number) + issue_comments = list(issue.get_comments()) + # Github only returns comments if from the pull object when a pull request is + # open. If it is a draft, it will only return comments from the issue object. 
+ return pull_comments + issue_comments def get_latest_gcbrun_command(comments): - """Gets the last /gcbrun comment from comments.""" - for comment in reversed(comments): - # This seems to get comments on code too. - body = comment.body - if body.startswith(SKIP_COMMAND_STR): - return None - if not body.startswith(TRIAL_BUILD_COMMAND_STR): - continue - if len(body) == len(TRIAL_BUILD_COMMAND_STR): - return None - return body[len(TRIAL_BUILD_COMMAND_STR):].strip().split(' ') - return None + """Gets the last /gcbrun comment from comments.""" + for comment in reversed(comments): + # This seems to get comments on code too. + body = comment.body + if body.startswith(SKIP_COMMAND_STR): + return None + if not body.startswith(TRIAL_BUILD_COMMAND_STR): + continue + if len(body) == len(TRIAL_BUILD_COMMAND_STR): + return None + return body[len(TRIAL_BUILD_COMMAND_STR) :].strip().split(" ") + return None def exec_command_from_github(pull_request_number): - """Executes the gcbrun command for trial_build.py in the most recent command - on |pull_request_number|.""" - comments = get_comments(pull_request_number) - command = get_latest_gcbrun_command(comments) - if command is None: - logging.info('Trial build not requested.') - return None + """Executes the gcbrun command for trial_build.py in the most recent command + on |pull_request_number|.""" + comments = get_comments(pull_request_number) + command = get_latest_gcbrun_command(comments) + if command is None: + logging.info("Trial build not requested.") + return None - # Set the branch so that the trial_build builds the projects from the PR - # branch. - command = ['-p', str(pull_request_number)] + command - command = [c for c in command if c] - logging.info('Command: %s.', command) - return request_pr_exp.main(command) + # Set the branch so that the trial_build builds the projects from the PR + # branch. 
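+    # Prepend "-p <PR number>" so request_pr_exp.main() targets this PR; empty
+    # strings left over from splitting the comment are filtered out below.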
+ command = ["-p", str(pull_request_number)] + command + command = [c for c in command if c] + logging.info("Command: %s.", command) + return request_pr_exp.main(command) def main(): - """Entrypoint for GitHub CI into trial_build.py""" - logging.basicConfig(level=logging.INFO) - pull_request_number = int(os.environ['PULL_REQUEST_NUMBER']) - result = exec_command_from_github(pull_request_number) - if result or result is None: - return 0 - return 1 + """Entrypoint for GitHub CI into trial_build.py""" + logging.basicConfig(level=logging.INFO) + pull_request_number = int(os.environ["PULL_REQUEST_NUMBER"]) + result = exec_command_from_github(pull_request_number) + if result or result is None: + return 0 + return 1 -if __name__ == '__main__': - sys.exit(main()) +if __name__ == "__main__": + sys.exit(main()) diff --git a/ci/request_pr_exp.py b/ci/request_pr_exp.py index a24b27d42f..dcc3eb692a 100644 --- a/ci/request_pr_exp.py +++ b/ci/request_pr_exp.py @@ -35,16 +35,17 @@ # Configure logging to display all messages at or above INFO level logging.basicConfig(level=logging.INFO) -DEFAULT_CLUSTER = 'llm-experiment' -LARGE_CLUSTER = 'llm-experiment-large' -DEFAULT_LOCATION = 'us-central1-c' -LARGE_LOCATION = 'us-central1' -TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), 'k8s', 'pr-exp.yaml') -LARGE_TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), 'k8s', - 'large-pr-exp.yaml') -BENCHMARK_SET = 'comparison' -LLM_NAME = 'vertex_ai_gemini-1-5' -LLM_CHAT_NAME = 'vertex_ai_gemini-1-5-chat' +DEFAULT_CLUSTER = "llm-experiment" +LARGE_CLUSTER = "llm-experiment-large" +DEFAULT_LOCATION = "us-central1-c" +LARGE_LOCATION = "us-central1" +TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "k8s", "pr-exp.yaml") +LARGE_TEMPLATE_PATH = os.path.join( + os.path.dirname(__file__), "k8s", "large-pr-exp.yaml" +) +BENCHMARK_SET = "comparison" +LLM_NAME = "vertex_ai_gemini-1-5" +LLM_CHAT_NAME = "vertex_ai_gemini-1-5-chat" EXP_DELAY = 0 FUZZING_TIMEOUT = 300 REQUEST_CPU = 6 @@ -56,339 +57,379 @@ VARY_TEMPERATURE = True MAX_ROUND = 100 -PR_LINK_PREFIX = 'https://github.com/google/oss-fuzz-gen/pull' -JOB_LINK_PREFIX = ('https://console.cloud.google.com/kubernetes/job/' - '//default') -REPORT_LINK_PREFIX = 'https://llm-exp.oss-fuzz.com/Result-reports/ofg-pr' +PR_LINK_PREFIX = "https://github.com/google/oss-fuzz-gen/pull" +JOB_LINK_PREFIX = ( + "https://console.cloud.google.com/kubernetes/job/" "//default" +) +REPORT_LINK_PREFIX = "https://llm-exp.oss-fuzz.com/Result-reports/ofg-pr" # Use storage.cloud.google.com if we want to give external access. 
-BUCKET_LINK_PREFIX = ('https://console.cloud.google.com/storage/browser/' - 'oss-fuzz-gcb-experiment-run-logs/Result-reports/ofg-pr') -BUCKET_GS_LINK_PREFIX = ( - 'gs://oss-fuzz-gcb-experiment-run-logs/Result-reports/ofg-pr') +BUCKET_LINK_PREFIX = ( + "https://console.cloud.google.com/storage/browser/" + "oss-fuzz-gcb-experiment-run-logs/Result-reports/ofg-pr" +) +BUCKET_GS_LINK_PREFIX = "gs://oss-fuzz-gcb-experiment-run-logs/Result-reports/ofg-pr" -DEFAULT_VERTEX_AI_LOCATION = 'us-central1' +DEFAULT_VERTEX_AI_LOCATION = "us-central1" VERTEX_AI_LOCATIONS = { - 'vertex_ai_gemini-pro': - 'asia-east1,asia-east2,asia-northeast1,asia-northeast3,asia-south1,asia-southeast1,australia-southeast1,europe-central2,europe-north1,europe-southwest1,europe-west1,europe-west2,europe-west3,europe-west4,europe-west6,europe-west8,europe-west9,southamerica-east1,us-central1,us-east1,us-east4,us-east5,us-south1,us-west1,us-west4', - 'vertex_ai_gemini-ultra': - 'asia-east1,asia-east2,asia-northeast1,asia-northeast3,asia-south1,asia-southeast1,australia-southeast1,europe-central2,europe-north1,europe-southwest1,europe-west1,europe-west2,europe-west3,europe-west4,europe-west6,europe-west8,europe-west9,southamerica-east1,us-central1,us-east1,us-east4,us-east5,us-south1,us-west1,us-west4', - 'vertex_ai_gemini-1-5': - 'asia-east1,asia-east2,asia-northeast1,asia-northeast3,asia-south1,asia-southeast1,australia-southeast1,europe-central2,europe-north1,europe-southwest1,europe-west1,europe-west2,europe-west3,europe-west4,europe-west6,europe-west8,europe-west9,southamerica-east1,us-central1,us-east1,us-east4,us-east5,us-south1,us-west1,us-west4', - 'vertex_ai_gemini-1-5-chat': - 'asia-east1,asia-east2,asia-northeast1,asia-northeast3,asia-south1,asia-southeast1,australia-southeast1,europe-central2,europe-north1,europe-southwest1,europe-west1,europe-west2,europe-west3,europe-west4,europe-west6,europe-west8,europe-west9,southamerica-east1,us-central1,us-east1,us-east4,us-east5,us-south1,us-west1,us-west4', - 'vertex_ai_gemini-2-flash': - 'europe-central2,europe-north1,europe-southwest1,europe-west1,europe-west4,europe-west8,europe-west9,us-central1,us-east1,us-east4,us-east5,us-south1,us-west1,us-west4', - 'vertex_ai_gemini-2-flash-chat': - 'europe-central2,europe-north1,europe-southwest1,europe-west1,europe-west4,europe-west8,europe-west9,us-central1,us-east1,us-east4,us-east5,us-south1,us-west1,us-west4' + "vertex_ai_gemini-pro": "asia-east1,asia-east2,asia-northeast1,asia-northeast3,asia-south1,asia-southeast1,australia-southeast1,europe-central2,europe-north1,europe-southwest1,europe-west1,europe-west2,europe-west3,europe-west4,europe-west6,europe-west8,europe-west9,southamerica-east1,us-central1,us-east1,us-east4,us-east5,us-south1,us-west1,us-west4", + "vertex_ai_gemini-ultra": "asia-east1,asia-east2,asia-northeast1,asia-northeast3,asia-south1,asia-southeast1,australia-southeast1,europe-central2,europe-north1,europe-southwest1,europe-west1,europe-west2,europe-west3,europe-west4,europe-west6,europe-west8,europe-west9,southamerica-east1,us-central1,us-east1,us-east4,us-east5,us-south1,us-west1,us-west4", + "vertex_ai_gemini-1-5": "asia-east1,asia-east2,asia-northeast1,asia-northeast3,asia-south1,asia-southeast1,australia-southeast1,europe-central2,europe-north1,europe-southwest1,europe-west1,europe-west2,europe-west3,europe-west4,europe-west6,europe-west8,europe-west9,southamerica-east1,us-central1,us-east1,us-east4,us-east5,us-south1,us-west1,us-west4", + "vertex_ai_gemini-1-5-chat": 
"asia-east1,asia-east2,asia-northeast1,asia-northeast3,asia-south1,asia-southeast1,australia-southeast1,europe-central2,europe-north1,europe-southwest1,europe-west1,europe-west2,europe-west3,europe-west4,europe-west6,europe-west8,europe-west9,southamerica-east1,us-central1,us-east1,us-east4,us-east5,us-south1,us-west1,us-west4", + "vertex_ai_gemini-2-flash": "europe-central2,europe-north1,europe-southwest1,europe-west1,europe-west4,europe-west8,europe-west9,us-central1,us-east1,us-east4,us-east5,us-south1,us-west1,us-west4", + "vertex_ai_gemini-2-flash-chat": "europe-central2,europe-north1,europe-southwest1,europe-west1,europe-west4,europe-west8,europe-west9,us-central1,us-east1,us-east4,us-east5,us-south1,us-west1,us-west4", } def _parse_args(cmd) -> argparse.Namespace: - """Parses the command line arguments.""" - parser = argparse.ArgumentParser( - description= - 'Requests a GKE experiment with the given PR ID from OSS-Fuzz-Gen.') - parser.add_argument( - '-c', - '--cluster', - type=str, - default=DEFAULT_CLUSTER, - help=f'The cluster name to run GKE jobs, default: {DEFAULT_CLUSTER}') - parser.add_argument( - '-l', - '--location', - type=str, - default=DEFAULT_LOCATION, - help=f'The cluster location to run GKE jobs, default: {DEFAULT_LOCATION}') - parser.add_argument( - '-t', - '--gke-template', - type=str, - default=TEMPLATE_PATH, - help=f'The template to request GKE job, default: {TEMPLATE_PATH}') - parser.add_argument( - '-p', - '--pr-id', - type=int, - required=True, - help='The PR ID from OSS-Fuzz-Gen. Wait until the CI finishes building.') - parser.add_argument( - '-n', - '--name-suffix', - required=True, - type=str, - help=('Experiment name suffix (e.g., your name), this will be used in ' - 'GKE job and result report.')) - parser.add_argument( - '-b', - '--benchmark-set', - type=str, - default=BENCHMARK_SET, - help=f'Experiment benchmark set, default: {BENCHMARK_SET}.') - parser.add_argument('-m', - '--llm', - type=str, - default=LLM_NAME, - help=f'Large Language Model name, default: {LLM_NAME}.') - parser.add_argument( - '-ll', - '--llm-locations', - type=str, - help=( - 'Comma-separated list of locations where the LLM is available. 
' - 'If not provided, default locations will be used based on the LLM.')) - parser.add_argument( - '-d', - '--delay', - type=int, - default=EXP_DELAY, - help=('Delay each benchmark experiment by N seconds, default: ' - f'{EXP_DELAY}.')) - parser.add_argument( - '-to', - '--fuzzing-timeout', - type=int, - default=FUZZING_TIMEOUT, - help=f'Fuzzing timeout in seconds, default: {FUZZING_TIMEOUT} seconds.') - parser.add_argument( - '-rc', - '--request-cpus', - type=int, - default=REQUEST_CPU, - help=f'CPU requested for experiment, default: {REQUEST_CPU}.') - parser.add_argument( - '-rm', - '--request-memory', - type=int, - default=REQUEST_MEM, - help=f'Memory requested for experiment in Gi, default: {REQUEST_MEM} Gi.') - parser.add_argument( - '-i', - '--local-introspector', - action='store_true', - help='If set will use a local version of fuzz introspector\'s webapp') - parser.add_argument( - '-ns', - '--num-samples', - type=int, - default=NUM_SAMPLES, - help=f'The number of samples to request from LLM, default: {NUM_SAMPLES}') - parser.add_argument( - '-nf', - '--llm-fix-limit', - type=int, - default=NUM_FIXES, - help=f'The number of fixes to request from LLM, default: {NUM_FIXES}') - parser.add_argument( - '-vt', - '--vary-temperature', - type=bool, - default=VARY_TEMPERATURE, - help=('Use different temperatures for each sample, default: ' - f'{VARY_TEMPERATURE}')) - parser.add_argument('-mr', - '--max-round', - type=int, - default=MAX_ROUND, - help=f'Max trial round for agents, default: {MAX_ROUND}.') - parser.add_argument('-ag', - '--agent', - action='store_true', - default=False, - help='Enables agent enhancement.') - parser.add_argument('-lg', - '--large', - action='store_true', - default=False, - help=('(Use sparingly) Do a large experiment with ' - 'many more cores available.')) - parser.add_argument('-rd', - '--redirect-outs', - action='store_true', - default=False, - help='Redirects experiments stdout/stderr to file') - - # Allow piping arbitrary args to run_all_experiments.py - args, additional_args = parser.parse_known_args(cmd) - args.additional_args = additional_args - - assert os.path.isfile( - args.gke_template), (f'GKE template does not exist: {args.gke_template}') - - # Construct experiment name and save it under args for simplicity. - args.experiment_name = f'{args.pr_id}' - if args.name_suffix: - args.experiment_name = f'{args.experiment_name}-{args.name_suffix}' - - # Use Chat model by default in agent-enhance experiments. 
- if args.agent and args.llm == LLM_NAME: - args.llm = LLM_CHAT_NAME - - if not args.llm_locations: - args.llm_locations = VERTEX_AI_LOCATIONS.get(args.llm, - DEFAULT_VERTEX_AI_LOCATION) - - if args.large: - args.location = LARGE_LOCATION - args.cluster = LARGE_CLUSTER - args.request_cpus = LARGE_REQUEST_CPU - args.request_memory = LARGE_REQUEST_MEM - args.gke_template = LARGE_TEMPLATE_PATH - - if (args.max_round == 100 and - any(args.name_suffix.startswith(suffix) for suffix in ['ascc-', 'dgk-'])): - args.max_round = 10 - - if args.additional_args: - logging.info("Additional args: %s", args.additional_args) - - return args - - -def _remove_existing_job_bucket(gke_job_name: str, bucket_link: str, - bucket_gs_link: str): - """Removes existing GKE job and gcloud bucket.""" - logging.info('Deleting GKE job: %s', gke_job_name) - del_job = sp.run(['kubectl', 'delete', 'job', gke_job_name], - stdin=sp.DEVNULL, - stdout=sp.PIPE, - stderr=sp.PIPE, - check=False) - if del_job.returncode: - stdout = del_job.stdout.decode('utf-8') - stderr = del_job.stderr.decode('utf-8') - if 'Error from server (NotFound)' in stderr: - logging.warning(stderr) - else: - logging.error('Failed to delete GKE job: %s.', gke_job_name) - logging.error('STDOUT:\n %s', stdout) - logging.error('STDERR:\n %s', stderr) - sys.exit(1) - - # Wait for 5 seconds to ensure job is deleted and not writing to bucket. - time.sleep(5) - - logging.info('Deleting gcloud bucket: %s', bucket_gs_link) - del_bucket = sp.run(['gsutil', '-m', 'rm', '-r', bucket_gs_link], - stdin=sp.DEVNULL, - stdout=sp.PIPE, - stderr=sp.PIPE, - check=False) - if del_bucket.returncode: - logging.error('Failed to rm gcloud bucket directory:\n %s', bucket_link) - logging.error('STDOUT:\n %s', del_bucket.stdout.decode('utf-8')) - logging.error('STDERR:\n %s', del_bucket.stderr.decode('utf-8')) + """Parses the command line arguments.""" + parser = argparse.ArgumentParser( + description="Requests a GKE experiment with the given PR ID from OSS-Fuzz-Gen." + ) + parser.add_argument( + "-c", + "--cluster", + type=str, + default=DEFAULT_CLUSTER, + help=f"The cluster name to run GKE jobs, default: {DEFAULT_CLUSTER}", + ) + parser.add_argument( + "-l", + "--location", + type=str, + default=DEFAULT_LOCATION, + help=f"The cluster location to run GKE jobs, default: {DEFAULT_LOCATION}", + ) + parser.add_argument( + "-t", + "--gke-template", + type=str, + default=TEMPLATE_PATH, + help=f"The template to request GKE job, default: {TEMPLATE_PATH}", + ) + parser.add_argument( + "-p", + "--pr-id", + type=int, + required=True, + help="The PR ID from OSS-Fuzz-Gen. Wait until the CI finishes building.", + ) + parser.add_argument( + "-n", + "--name-suffix", + required=True, + type=str, + help=( + "Experiment name suffix (e.g., your name), this will be used in " + "GKE job and result report." + ), + ) + parser.add_argument( + "-b", + "--benchmark-set", + type=str, + default=BENCHMARK_SET, + help=f"Experiment benchmark set, default: {BENCHMARK_SET}.", + ) + parser.add_argument( + "-m", + "--llm", + type=str, + default=LLM_NAME, + help=f"Large Language Model name, default: {LLM_NAME}.", + ) + parser.add_argument( + "-ll", + "--llm-locations", + type=str, + help=( + "Comma-separated list of locations where the LLM is available. " + "If not provided, default locations will be used based on the LLM." + ), + ) + parser.add_argument( + "-d", + "--delay", + type=int, + default=EXP_DELAY, + help=( + "Delay each benchmark experiment by N seconds, default: " f"{EXP_DELAY}." 
+ ), + ) + parser.add_argument( + "-to", + "--fuzzing-timeout", + type=int, + default=FUZZING_TIMEOUT, + help=f"Fuzzing timeout in seconds, default: {FUZZING_TIMEOUT} seconds.", + ) + parser.add_argument( + "-rc", + "--request-cpus", + type=int, + default=REQUEST_CPU, + help=f"CPU requested for experiment, default: {REQUEST_CPU}.", + ) + parser.add_argument( + "-rm", + "--request-memory", + type=int, + default=REQUEST_MEM, + help=f"Memory requested for experiment in Gi, default: {REQUEST_MEM} Gi.", + ) + parser.add_argument( + "-i", + "--local-introspector", + action="store_true", + help="If set will use a local version of fuzz introspector's webapp", + ) + parser.add_argument( + "-ns", + "--num-samples", + type=int, + default=NUM_SAMPLES, + help=f"The number of samples to request from LLM, default: {NUM_SAMPLES}", + ) + parser.add_argument( + "-nf", + "--llm-fix-limit", + type=int, + default=NUM_FIXES, + help=f"The number of fixes to request from LLM, default: {NUM_FIXES}", + ) + parser.add_argument( + "-vt", + "--vary-temperature", + type=bool, + default=VARY_TEMPERATURE, + help=( + "Use different temperatures for each sample, default: " + f"{VARY_TEMPERATURE}" + ), + ) + parser.add_argument( + "-mr", + "--max-round", + type=int, + default=MAX_ROUND, + help=f"Max trial round for agents, default: {MAX_ROUND}.", + ) + parser.add_argument( + "-ag", + "--agent", + action="store_true", + default=False, + help="Enables agent enhancement.", + ) + parser.add_argument( + "-lg", + "--large", + action="store_true", + default=False, + help=( + "(Use sparingly) Do a large experiment with " "many more cores available." + ), + ) + parser.add_argument( + "-rd", + "--redirect-outs", + action="store_true", + default=False, + help="Redirects experiments stdout/stderr to file", + ) + + # Allow piping arbitrary args to run_all_experiments.py + args, additional_args = parser.parse_known_args(cmd) + args.additional_args = additional_args + + assert os.path.isfile( + args.gke_template + ), f"GKE template does not exist: {args.gke_template}" + + # Construct experiment name and save it under args for simplicity. + args.experiment_name = f"{args.pr_id}" + if args.name_suffix: + args.experiment_name = f"{args.experiment_name}-{args.name_suffix}" + + # Use Chat model by default in agent-enhance experiments. 
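+    # (Only overrides the model when the user kept the non-chat default.)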
+ if args.agent and args.llm == LLM_NAME: + args.llm = LLM_CHAT_NAME + + if not args.llm_locations: + args.llm_locations = VERTEX_AI_LOCATIONS.get( + args.llm, DEFAULT_VERTEX_AI_LOCATION + ) + + if args.large: + args.location = LARGE_LOCATION + args.cluster = LARGE_CLUSTER + args.request_cpus = LARGE_REQUEST_CPU + args.request_memory = LARGE_REQUEST_MEM + args.gke_template = LARGE_TEMPLATE_PATH + + if args.max_round == 100 and any( + args.name_suffix.startswith(suffix) for suffix in ["ascc-", "dgk-"] + ): + args.max_round = 10 + + if args.additional_args: + logging.info("Additional args: %s", args.additional_args) + + return args + + +def _remove_existing_job_bucket( + gke_job_name: str, bucket_link: str, bucket_gs_link: str +): + """Removes existing GKE job and gcloud bucket.""" + logging.info("Deleting GKE job: %s", gke_job_name) + del_job = sp.run( + ["kubectl", "delete", "job", gke_job_name], + stdin=sp.DEVNULL, + stdout=sp.PIPE, + stderr=sp.PIPE, + check=False, + ) + if del_job.returncode: + stdout = del_job.stdout.decode("utf-8") + stderr = del_job.stderr.decode("utf-8") + if "Error from server (NotFound)" in stderr: + logging.warning(stderr) + else: + logging.error("Failed to delete GKE job: %s.", gke_job_name) + logging.error("STDOUT:\n %s", stdout) + logging.error("STDERR:\n %s", stderr) + sys.exit(1) + + # Wait for 5 seconds to ensure job is deleted and not writing to bucket. + time.sleep(5) + + logging.info("Deleting gcloud bucket: %s", bucket_gs_link) + del_bucket = sp.run( + ["gsutil", "-m", "rm", "-r", bucket_gs_link], + stdin=sp.DEVNULL, + stdout=sp.PIPE, + stderr=sp.PIPE, + check=False, + ) + if del_bucket.returncode: + logging.error("Failed to rm gcloud bucket directory:\n %s", bucket_link) + logging.error("STDOUT:\n %s", del_bucket.stdout.decode("utf-8")) + logging.error("STDERR:\n %s", del_bucket.stderr.decode("utf-8")) def _prepare_experiment_info(args: argparse.Namespace) -> tuple[str, str, str]: - """ - Prepares and logs the key experiment information for easier accesses. - """ - # GKE job name. - gke_job_name = f'ofg-pr-{args.experiment_name}' - - # GKE job link. - gke_job_link = f'{JOB_LINK_PREFIX}/ofg-pr-{args.experiment_name}' - gke_job_link = gke_job_link.replace('', args.location) - gke_job_link = gke_job_link.replace('', args.cluster) - - # PR link. - ofg_pr_link = f'{PR_LINK_PREFIX}/{args.pr_id}' - - # Report link. - report_link = ( - f'{REPORT_LINK_PREFIX}/{datetime.now().strftime("%Y-%m-%d")}-' - f'{args.pr_id}-{args.name_suffix}-{args.benchmark_set}/index.html') - - # Bucket links. - bucket_link = (f'{BUCKET_LINK_PREFIX}/{datetime.now().strftime("%Y-%m-%d")}-' - f'{args.pr_id}-{args.name_suffix}-{args.benchmark_set}') - bucket_gs_link = ( - f'{BUCKET_GS_LINK_PREFIX}/{datetime.now().strftime("%Y-%m-%d")}-' - f'{args.pr_id}-{args.name_suffix}-{args.benchmark_set}') - - logging.info( - 'FORCE mode enable, will first remove existing GKE job and bucket.') - - logging.info( - 'Requesting a GKE experiment named %s:\nPR: %s\nJOB: %s\nREPORT: %s\n' - 'BUCKET: %s\nBUCKET GS: `%s`\n', - gke_job_name, - ofg_pr_link, - gke_job_link, - report_link, - bucket_link, - bucket_gs_link, - ) - return gke_job_name, bucket_link, bucket_gs_link + """ + Prepares and logs the key experiment information for easier accesses. + """ + # GKE job name. + gke_job_name = f"ofg-pr-{args.experiment_name}" + + # GKE job link. 
+ gke_job_link = f"{JOB_LINK_PREFIX}/ofg-pr-{args.experiment_name}" + gke_job_link = gke_job_link.replace("", args.location) + gke_job_link = gke_job_link.replace("", args.cluster) + + # PR link. + ofg_pr_link = f"{PR_LINK_PREFIX}/{args.pr_id}" + + # Report link. + report_link = ( + f'{REPORT_LINK_PREFIX}/{datetime.now().strftime("%Y-%m-%d")}-' + f"{args.pr_id}-{args.name_suffix}-{args.benchmark_set}/index.html" + ) + + # Bucket links. + bucket_link = ( + f'{BUCKET_LINK_PREFIX}/{datetime.now().strftime("%Y-%m-%d")}-' + f"{args.pr_id}-{args.name_suffix}-{args.benchmark_set}" + ) + bucket_gs_link = ( + f'{BUCKET_GS_LINK_PREFIX}/{datetime.now().strftime("%Y-%m-%d")}-' + f"{args.pr_id}-{args.name_suffix}-{args.benchmark_set}" + ) + + logging.info("FORCE mode enable, will first remove existing GKE job and bucket.") + + logging.info( + "Requesting a GKE experiment named %s:\nPR: %s\nJOB: %s\nREPORT: %s\n" + "BUCKET: %s\nBUCKET GS: `%s`\n", + gke_job_name, + ofg_pr_link, + gke_job_link, + report_link, + bucket_link, + bucket_gs_link, + ) + return gke_job_name, bucket_link, bucket_gs_link def _get_gke_credential(args: argparse.Namespace): - """Authenticates gcloud account.""" - try: - sp.run([ - 'gcloud', - 'container', - 'clusters', - 'get-credentials', - args.cluster, - '--location', - args.location, - ], - check=False) - except Exception as e: - logging.error('Failed to authenticate gcloud: %s', e) + """Authenticates gcloud account.""" + try: + sp.run( + [ + "gcloud", + "container", + "clusters", + "get-credentials", + args.cluster, + "--location", + args.location, + ], + check=False, + ) + except Exception as e: + logging.error("Failed to authenticate gcloud: %s", e) def _fill_template(args: argparse.Namespace) -> str: - """Fills the GKE template with |args| and returns the result YAML path.""" - exp_env_vars = os.environ.copy() - exp_env_vars['PR_ID'] = str(args.pr_id) - exp_env_vars['GKE_EXP_BENCHMARK'] = args.benchmark_set - exp_env_vars['GKE_EXP_LLM'] = args.llm - exp_env_vars['GKE_EXP_VERTEX_AI_LOCATIONS'] = args.llm_locations - exp_env_vars['GKE_EXP_DELAY'] = args.delay - exp_env_vars['GKE_EXP_FUZZING_TIMEOUT'] = str(args.fuzzing_timeout) - exp_env_vars['GKE_EXP_NAME'] = args.experiment_name - exp_env_vars['GKE_EXP_REQ_CPU'] = args.request_cpus - exp_env_vars['GKE_EXP_REQ_MEM'] = f'{args.request_memory}Gi' - exp_env_vars[ - 'GKE_EXP_LOCAL_INTROSPECTOR'] = f'{args.local_introspector}'.lower() - exp_env_vars['GKE_EXP_NUM_SAMPLES'] = f'{args.num_samples}' - exp_env_vars['GKE_EXP_LLM_FIX_LIMIT'] = f'{args.llm_fix_limit}' - exp_env_vars['GKE_EXP_VARY_TEMPERATURE'] = f'{args.vary_temperature}'.lower() - exp_env_vars['GKE_EXP_AGENT'] = f'{args.agent}'.lower() - exp_env_vars['GKE_REDIRECT_OUTS'] = f'{args.redirect_outs}'.lower() - exp_env_vars['GKE_EXP_MAX_ROUND'] = args.max_round - - # Add additional args as a space-separated string - exp_env_vars['GKE_EXP_ADDITIONAL_ARGS'] = ' '.join(args.additional_args) - - with open(args.gke_template, 'r') as file: - yaml_template = file.read() - - substituted_content = Template(yaml_template).safe_substitute(exp_env_vars) - substituted_file_path = f'{os.path.splitext(args.gke_template)[0]}-sub.yaml' - - with open(substituted_file_path, 'w') as substituted_file: - substituted_file.write(substituted_content) - - return substituted_file_path + """Fills the GKE template with |args| and returns the result YAML path.""" + exp_env_vars = os.environ.copy() + exp_env_vars["PR_ID"] = str(args.pr_id) + exp_env_vars["GKE_EXP_BENCHMARK"] = args.benchmark_set + 
exp_env_vars["GKE_EXP_LLM"] = args.llm + exp_env_vars["GKE_EXP_VERTEX_AI_LOCATIONS"] = args.llm_locations + exp_env_vars["GKE_EXP_DELAY"] = args.delay + exp_env_vars["GKE_EXP_FUZZING_TIMEOUT"] = str(args.fuzzing_timeout) + exp_env_vars["GKE_EXP_NAME"] = args.experiment_name + exp_env_vars["GKE_EXP_REQ_CPU"] = args.request_cpus + exp_env_vars["GKE_EXP_REQ_MEM"] = f"{args.request_memory}Gi" + exp_env_vars["GKE_EXP_LOCAL_INTROSPECTOR"] = f"{args.local_introspector}".lower() + exp_env_vars["GKE_EXP_NUM_SAMPLES"] = f"{args.num_samples}" + exp_env_vars["GKE_EXP_LLM_FIX_LIMIT"] = f"{args.llm_fix_limit}" + exp_env_vars["GKE_EXP_VARY_TEMPERATURE"] = f"{args.vary_temperature}".lower() + exp_env_vars["GKE_EXP_AGENT"] = f"{args.agent}".lower() + exp_env_vars["GKE_REDIRECT_OUTS"] = f"{args.redirect_outs}".lower() + exp_env_vars["GKE_EXP_MAX_ROUND"] = args.max_round + + # Add additional args as a space-separated string + exp_env_vars["GKE_EXP_ADDITIONAL_ARGS"] = " ".join(args.additional_args) + + with open(args.gke_template, "r") as file: + yaml_template = file.read() + + substituted_content = Template(yaml_template).safe_substitute(exp_env_vars) + substituted_file_path = f"{os.path.splitext(args.gke_template)[0]}-sub.yaml" + + with open(substituted_file_path, "w") as substituted_file: + substituted_file.write(substituted_content) + + return substituted_file_path def _request_experiment(substituted_file_path: str): - """Requests an GKE experiment with |args| settings.""" - sp.run(['kubectl', 'create', '-f', substituted_file_path], check=True) + """Requests an GKE experiment with |args| settings.""" + sp.run(["kubectl", "create", "-f", substituted_file_path], check=True) def main(cmd=None): - """The main function.""" - args = _parse_args(cmd) - gke_job_name, bucket_link, bucket_gs_link = _prepare_experiment_info(args) - _get_gke_credential(args) - _remove_existing_job_bucket(gke_job_name, bucket_link, bucket_gs_link) - _request_experiment(_fill_template(args)) + """The main function.""" + args = _parse_args(cmd) + gke_job_name, bucket_link, bucket_gs_link = _prepare_experiment_info(args) + _get_gke_credential(args) + _remove_existing_job_bucket(gke_job_name, bucket_link, bucket_gs_link) + _request_experiment(_fill_template(args)) if __name__ == "__main__": - sys.exit(main()) + sys.exit(main()) diff --git a/common/cloud_builder.py b/common/cloud_builder.py index 79847da5a4..7f45ce6532 100644 --- a/common/cloud_builder.py +++ b/common/cloud_builder.py @@ -34,459 +34,508 @@ from agent.base_agent import BaseAgent from results import Result, RunResult -OF_REPO = 'https://github.com/google/oss-fuzz.git' +OF_REPO = "https://github.com/google/oss-fuzz.git" OFG_ROOT_DIR = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) -REGION = os.getenv('CLOUD_BUILD_LOCATION', 'us-west2') +REGION = os.getenv("CLOUD_BUILD_LOCATION", "us-west2") REGIONAL_CLIENT_OPTIONS = google.api_core.client_options.ClientOptions( - api_endpoint=f'https://{REGION}-cloudbuild.googleapis.com/') + api_endpoint=f"https://{REGION}-cloudbuild.googleapis.com/" +) _CHAT_HISTORY_PREFIX_PATTERN = r'^Step\s+#(\d+)\s+-\s+"agent-step":\s+' -_CHAT_HISTORY_START_MARKER = '' +_CHAT_HISTORY_START_MARKER = "" class CloudBuilder: - """A worker to execute llm-agents workflow in Google Cloud Build, providing a - scalable and distributed alternative to local executions: - - Request, monitor, and manage Google Cloud Build jobs. - - Execute agent in the cloud environment, replicating the local conditions. 
- - Transfer data and results between local and cloud environment. - """ - - def __init__(self, args: argparse.Namespace) -> None: - self.tags = ['ofg', 'agent', args.cloud_experiment_name] - self.exp_args = args - self.credentials, self.project_id = default() - assert self.project_id, 'Cloud experiment requires a Google cloud project.' - assert hasattr( - self.credentials, - 'refresh'), ('Cloud experiment requires a service account email') - assert hasattr(self.credentials, 'service_account_email'), ( - 'Cloud experiment requires a service account email') - - try: - # TODO(dongge): Understand why this crashes in local experiments. - self.credentials.refresh(Request()) # type: ignore - except: - pass - self.bucket_name = args.cloud_experiment_bucket - self.bucket = storage.Client().bucket(self.bucket_name) - - # pylint: disable=no-member - self.builds = cloud_build( - 'cloudbuild', - 'v1', - credentials=self.credentials, - cache_discovery=False, - client_options=REGIONAL_CLIENT_OPTIONS).projects().builds() - self.storage_client = storage.Client(credentials=self.credentials) - - def _upload_files(self, archive_name: str, target_dir: str, - files_to_upload: list[str]) -> str: - """Archive and upload files to GCS.""" - valid_files = [] - for f in files_to_upload: - file_path = os.path.join(target_dir, f) - if os.path.exists(file_path): - valid_files.append(f) - else: - logging.error("File does not exist: %s", file_path) - - valid_files.sort() - - with tempfile.TemporaryDirectory() as tmpdirname: - archive_path = os.path.join(tmpdirname, archive_name) - tar_command = ['tar', '-czf', archive_path] + valid_files - logging.error("Archive path: %s (exists: %s)", archive_path, - os.path.exists(archive_path)) - logging.error("Tar command: %s", ' '.join(tar_command)) - - try: - result = subprocess.run(tar_command, - cwd=target_dir, - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True) - logging.error("subprocess stdout:\n%s", result.stdout) - logging.error("subprocess stderr:\n%s", result.stderr) - except subprocess.CalledProcessError as e: - logging.error("Tar command failed with return code %d", e.returncode) - logging.error("stdout:\n%s", e.stdout) - logging.error("stderr:\n%s", e.stderr) - raise - - if os.path.exists(archive_path): - logging.info("Successfully created archive: %s", archive_path) - else: - logging.error("Failed to create archive: %s", archive_path) - return self._upload_to_gcs(archive_path) - - def _upload_to_gcs(self, local_file_path: str) -> str: - """Uploads a file to Google Cloud Storage.""" - dest_file_name = os.path.basename(local_file_path) - self.bucket.blob(dest_file_name).upload_from_filename(local_file_path) - bucket_file_url = f'gs://{self.bucket_name}/{dest_file_name}' - logging.info('Uploaded %s to %s', local_file_path, bucket_file_url) - return bucket_file_url - - def _prepare_and_upload_archive(self, result_history: list[Result]) -> str: - """Archives and uploads local OFG repo to cloud build.""" - dir_files = set( - os.path.relpath(os.path.join(root, file)) - for root, _, files in os.walk(OFG_ROOT_DIR) - for file in files) - git_files = set( - subprocess.check_output(['git', 'ls-files'], - cwd=OFG_ROOT_DIR, - text=True).splitlines()) - - # Separate out any files in data-dir - tmp_files = [] - for gf in git_files: - if 'data-dir' in gf: - continue - tmp_files.append(gf) - git_files = set(tmp_files) - result_files = set( - os.path.relpath(os.path.join(root, file)) - for root, _, files in os.walk(result_history[-1].work_dirs.base) - for file 
in files) - file_to_upload = list((dir_files & git_files) | result_files) - - return self._upload_files(f'ofg-repo-{uuid.uuid4().hex}.tar.gz', - OFG_ROOT_DIR, file_to_upload) - - def _upload_oss_fuzz_data(self) -> str: - """Archives and uploads OSS_FUZZ_DATA_DIR, if any.""" - oss_fuzz_data_dir = os.getenv('OSS_FUZZ_DATA_DIR') - if not oss_fuzz_data_dir: - return '' - - # Upload all the files in the oss_fuzz_data_dir, _upload_files - # calls tar as the target_dir as cwd. - files_to_upload = ['.'] - return self._upload_files(f'oss-fuzz-{uuid.uuid4().hex}.tar.gz', - oss_fuzz_data_dir, files_to_upload) - - def _upload_fi_oss_fuzz_data(self) -> str: - """Upload data used by OFG from scratch.""" - data_dir = '/experiment/data-dir' - if not os.path.isdir(data_dir): - return '' - - # Upload all the files in the oss_fuzz_data_dir, _upload_files - # calls tar as the target_dir as cwd. - files_to_upload = ['.'] - return self._upload_files(f'data-dir-{uuid.uuid4().hex}.tar.gz', data_dir, - files_to_upload) - - def _request_cloud_build(self, ofg_repo_url: str, agent_dill_url: str, - results_dill_url: str, artifact_url: str, - artifact_path: str, oss_fuzz_data_url: str, - data_dir_url: str, new_result_filename: str) -> str: - """Requests Cloud Build to execute the operation.""" - - # Used for injecting additional OSS-Fuzz project integrations not in - # upstream OSS-Fuzz. - oss_fuzz_data_dir = '' - data_env_set = 'OSS_FUZZ_DATA_DIR_NOT_SET=1' - if oss_fuzz_data_url: - oss_fuzz_data_dir = '/workspace/oss-fuzz-data' - data_env_set = 'OSS_FUZZ_DATA_DIR=/workspace/oss-fuzz-data' - - target_data_dir = '' - if data_dir_url: - target_data_dir = '/workspace/data-dir' - - cloud_build_config = { - 'steps': [ - # Step 1: Download the dill and artifact files from GCS bucket. - { - 'name': 'bash', - 'dir': '/workspace', - 'args': ['-c', 'mkdir -p dills'] - }, - { - 'name': 'gcr.io/cloud-builders/gsutil', - 'dir': '/workspace', - 'args': ['cp', agent_dill_url, 'dills/agent.pkl'] - }, - { - 'name': 'gcr.io/cloud-builders/gsutil', - 'dir': '/workspace', - 'args': ['cp', results_dill_url, 'dills/result_history.pkl'] - }, - { - 'name': 'gcr.io/cloud-builders/gsutil', - 'entrypoint': 'bash', - 'args': [ - '-c', - f'mkdir -p /workspace/host/{os.path.dirname(artifact_path)}' - ], - 'allowFailure': True, - }, - { - 'name': 'gcr.io/cloud-builders/gsutil', - 'dir': '/workspace', - 'args': [ - 'cp', artifact_url, f'/workspace/host/{artifact_path}' - ], - 'allowFailure': True, - }, - # Step 2: Prepare OFG and OF repos. - { - 'name': - 'gcr.io/cloud-builders/gsutil', - 'entrypoint': - 'bash', - 'args': [ - '-c', f'gsutil cp {ofg_repo_url} /tmp/ofg-repo.tar.gz && ' - 'mkdir /workspace/ofg && ' - f'tar -xzf /tmp/ofg-repo.tar.gz -C /workspace/ofg' - ] - }, - # Step 3: Prepare agent base image. - { - 'name': 'gcr.io/cloud-builders/docker', - 'args': [ - 'build', '.', '-t', - ('us-central1-docker.pkg.dev/oss-fuzz/oss-fuzz-gen/' - 'agent-image'), '-f', 'Dockerfile.cloudbuild-agent' - ], - 'dir': '/workspace/ofg/', - }, - # Step 4: Prepare OSS-Fuzz repo. 
- { - 'name': 'gcr.io/cloud-builders/gsutil', - 'entrypoint': 'bash', - 'args': [ - '-c', f'test -n "{oss_fuzz_data_url}" && ' - f'gsutil cp {oss_fuzz_data_url} ' - '/tmp/oss-fuzz-data.tar.gz && ' - f'mkdir {oss_fuzz_data_dir} && ' - f'tar -xzf /tmp/oss-fuzz-data.tar.gz -C {oss_fuzz_data_dir}' - ], - 'allowFailure': True, - }, - { - 'name': 'gcr.io/cloud-builders/gsutil', - 'entrypoint': 'bash', - 'args': [ - '-c', f'test -n "{data_dir_url}" && ' - f'gsutil cp {data_dir_url} /tmp/data-dir.tar.gz && ' - f'mkdir {target_data_dir} && ' - f'tar -xzf /tmp/data-dir.tar.gz -C {target_data_dir}' - ], - 'allowFailure': True, - }, - { - 'name': - 'gcr.io/cloud-builders/docker', - 'dir': - '/workspace/ofg/', - 'args': [ - 'run', '--rm', '-v', '/workspace:/workspace', '-e', - data_env_set, - ('us-central1-docker.pkg.dev/oss-fuzz/oss-fuzz-gen/' - 'agent-image'), 'python3.11', '-c', - 'import os; from experiment import oss_fuzz_checkout; ' - 'oss_fuzz_checkout.clone_oss_fuzz("oss-fuzz"); ' - 'oss_fuzz_checkout.postprocess_oss_fuzz(); ' - ], - }, - # Step 5: Run the Python script with the dill files. - { - 'id': - 'agent-step', - 'name': - 'gcr.io/cloud-builders/docker', - 'args': [ - 'run', - '--rm', - '-v', - '/workspace:/workspace', - '-v', - '/workspace/host/experiment:/experiment', - '-v', - '/var/run/docker.sock:/var/run/docker.sock', - '-e', - 'VERTEX_AI_LOCATIONS=' + - os.getenv("VERTEX_AI_LOCATIONS", ""), - '--network=cloudbuild', - # Built from this repo's `Dockerfile.cloudbuild-agent`. - ('us-central1-docker.pkg.dev/oss-fuzz/oss-fuzz-gen/' - 'agent-image'), - 'python3.11', - '-m', - 'agent.base_agent', - '--agent', - '/workspace/dills/agent.pkl', - '--result-history', - '/workspace/dills/result_history.pkl', - '--result-new', - '/workspace/dills/new_result.pkl' - ], - }, - # Step 6: Upload the result to GCS bucket - { - 'name': 'bash', - 'dir': '/workspace', - 'args': ['ls', '/workspace/dills/'] - }, - { - 'name': - 'gcr.io/cloud-builders/gsutil', - 'dir': - '/workspace', - 'args': [ - 'cp', '/workspace/dills/new_result.pkl', - f'gs://{self.bucket_name}/{new_result_filename}' - ] - } - ], - 'tags': self.tags, - 'timeout': '10800s', # 3 hours - 'logsBucket': f'gs://{self.bucket_name}', - 'serviceAccount': - f'projects/{self.project_id}/serviceAccounts/' - f'{self.credentials.service_account_email}' # type: ignore - } - pool_name = os.getenv('GCB_BUILDPOOL_NAME') - if pool_name: - cloud_build_config.setdefault('options', {})['pool'] = {'name': pool_name} - logging.info(cloud_build_config) - - # Convert to YAML string and submit the Cloud Build request - build_info = self.builds.create(projectId=self.project_id, - body=cloud_build_config).execute() - build_id = build_info.get('metadata', {}).get('build', {}).get('id', '') - - logging.info('Created Cloud Build ID %s at %s', build_id, REGION) - return build_id - - def _wait_for_build(self, build_id: str) -> str: - """Wait for a GCB build.""" - prev_status = status = None - while status in [None, 'WORKING', 'QUEUED']: - try: - status = self.builds.get(projectId=self.project_id, - id=build_id).execute().get('status') - if status != prev_status: - logging.info('Cloud Build %s Status: %s', build_id, status) - prev_status = status - except (googleapiclient.errors.HttpError, BrokenPipeError) as e: - logging.warning('Failed to check cloud build status %s: %s', build_id, - e) - time.sleep(60) # Avoid rate limiting. 
- return status or '' - - def _cancel_build(self, build_id: str) -> None: - """Cancel a GCB build""" - self.builds.cancel(projectId=self.project_id, id=build_id).execute() - - def _extract_chat_history(self, full_log: str) -> str: - """Extracts the agent chat history from cloud build log.""" - in_chat = False - chat_history = [] - for log_line in full_log.splitlines(): - if not re.match(_CHAT_HISTORY_PREFIX_PATTERN, log_line): - continue - if _CHAT_HISTORY_START_MARKER in log_line: - in_chat = True - if in_chat: - stripped_line = re.sub(_CHAT_HISTORY_PREFIX_PATTERN, '', log_line) - chat_history.append(stripped_line) - return '\n'.join(chat_history) - - def _get_build_log(self, build_id: str) -> str: - """Downloads the build log""" - log_file_uri = f'log-{build_id}.txt' - try: - bucket = self.storage_client.bucket(self.bucket_name) - blob = bucket.blob(log_file_uri) - log_content = self._extract_chat_history(blob.download_as_text()) - logging.warning(log_content) - return log_content - except NotFound as e: - logging.error('Cloud build log %s not found: %s', log_file_uri, e) - return f'Cloud build log {log_file_uri} not found: {e}.' - - def _download_from_gcs(self, destination_file_name: str) -> None: - """Downloads the result file from GCS.""" - source_blob_name = os.path.basename(destination_file_name) - blob = self.bucket.blob(source_blob_name) - blob.download_to_filename(destination_file_name) - logging.info('Downloaded %s to %s', source_blob_name, destination_file_name) - - def run(self, agent: BaseAgent, result_history: list[Result], - dill_dir: str) -> Any: - """Runs agent on cloud build.""" - # Step 0: Add task-specific tags. - # TODO(dongge): More tags, e.g., benchmark name. - self.tags += [ - str(agent), - str(result_history[-1].benchmark.project), - str(result_history[-1].benchmark.function_name), - str(result_history[-1].trial) - ] - # Step1: Generate dill files. - agent_dill = utils.serialize_to_dill( - agent, os.path.join(dill_dir, f'{uuid.uuid4().hex}.pkl')) - results_dill = utils.serialize_to_dill( - result_history, os.path.join(dill_dir, f'{uuid.uuid4().hex}.pkl')) - # TODO(dongge): Encrypt dill files? - - # Step 2: Upload OFG repo and dill files to GCS. - ofg_url = self._prepare_and_upload_archive(result_history) - agent_url = self._upload_to_gcs(agent_dill) - results_url = self._upload_to_gcs(results_dill) - artifact_url = '' - artifact_path = '' - if isinstance(result_history[-1], RunResult): - artifact_path = result_history[-1].artifact_path - if artifact_path: - logging.info('Found artifact_path: %s in RunResult.', artifact_path) - artifact_url = self._upload_to_gcs(artifact_path) - logging.info('Uploaded artifact to %s', artifact_url) - else: - logging.error('No artifact_path found in RunResult.') - oss_fuzz_data_url = self._upload_oss_fuzz_data() - data_dir_url = self._upload_fi_oss_fuzz_data() - - # Step 3: Request Cloud Build. - new_result_filename = f'{uuid.uuid4().hex}.pkl' - build_id = self._request_cloud_build(ofg_url, agent_url, results_url, - artifact_url, artifact_path, - oss_fuzz_data_url, data_dir_url, - new_result_filename) - - # Step 4: Download new result dill. 
- cloud_build_log = '' - new_result_dill = os.path.join(dill_dir, new_result_filename) - try: - cloud_build_final_status = self._wait_for_build(build_id) - if cloud_build_final_status == 'SUCCESS': - self._download_from_gcs(new_result_dill) - else: - logging.error('Cloud build %s failed with status: %s', build_id, - cloud_build_final_status) - cloud_build_log += (f'Cloud build {build_id} failed with status: ' - f'{cloud_build_final_status}.\n') - except (KeyboardInterrupt, SystemExit) as e: - self._cancel_build(build_id) - logging.error('Cloud build %s cancled: %s', build_id, e) - cloud_build_log += f'Cloud build {build_id} cancled: {e}.\n' - - cloud_build_log += self._get_build_log(build_id) - - # Step 5: Deserialize dilld file. - result = utils.deserialize_from_dill(new_result_dill) - if not result: - cloud_build_log += f'Failed to deserialize from dill {new_result_dill}.\n' - last_result = result_history[-1] - result = Result(benchmark=last_result.benchmark, - trial=last_result.trial, - work_dirs=last_result.work_dirs, - author=agent) - result.chat_history = {agent.name: cloud_build_log} - - return result + """A worker to execute llm-agents workflow in Google Cloud Build, providing a + scalable and distributed alternative to local executions: + - Request, monitor, and manage Google Cloud Build jobs. + - Execute agent in the cloud environment, replicating the local conditions. + - Transfer data and results between local and cloud environment. + """ + + def __init__(self, args: argparse.Namespace) -> None: + self.tags = ["ofg", "agent", args.cloud_experiment_name] + self.exp_args = args + self.credentials, self.project_id = default() + assert self.project_id, "Cloud experiment requires a Google cloud project." + assert hasattr( + self.credentials, "refresh" + ), "Cloud experiment requires a service account email" + assert hasattr( + self.credentials, "service_account_email" + ), "Cloud experiment requires a service account email" + + try: + # TODO(dongge): Understand why this crashes in local experiments. 
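[Note: the credential bootstrap above is the stock google-auth Application Default Credentials flow; a minimal standalone sketch (library names only, no OFG specifics) would be:]

    from google.auth import default
    from google.auth.transport.requests import Request

    credentials, project_id = default()  # Application Default Credentials
    if hasattr(credentials, "refresh"):
        credentials.refresh(Request())  # fetch or renew the access token

[Locally, default() may hand back end-user credentials rather than a service account, which is presumably why the refresh below is tolerated to fail inside the bare try/except.]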
+ self.credentials.refresh(Request()) # type: ignore + except: + pass + self.bucket_name = args.cloud_experiment_bucket + self.bucket = storage.Client().bucket(self.bucket_name) + + # pylint: disable=no-member + self.builds = ( + cloud_build( + "cloudbuild", + "v1", + credentials=self.credentials, + cache_discovery=False, + client_options=REGIONAL_CLIENT_OPTIONS, + ) + .projects() + .builds() + ) + self.storage_client = storage.Client(credentials=self.credentials) + + def _upload_files( + self, archive_name: str, target_dir: str, files_to_upload: list[str] + ) -> str: + """Archive and upload files to GCS.""" + valid_files = [] + for f in files_to_upload: + file_path = os.path.join(target_dir, f) + if os.path.exists(file_path): + valid_files.append(f) + else: + logging.error("File does not exist: %s", file_path) + + valid_files.sort() + + with tempfile.TemporaryDirectory() as tmpdirname: + archive_path = os.path.join(tmpdirname, archive_name) + tar_command = ["tar", "-czf", archive_path] + valid_files + logging.error( + "Archive path: %s (exists: %s)", + archive_path, + os.path.exists(archive_path), + ) + logging.error("Tar command: %s", " ".join(tar_command)) + + try: + result = subprocess.run( + tar_command, + cwd=target_dir, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + logging.error("subprocess stdout:\n%s", result.stdout) + logging.error("subprocess stderr:\n%s", result.stderr) + except subprocess.CalledProcessError as e: + logging.error("Tar command failed with return code %d", e.returncode) + logging.error("stdout:\n%s", e.stdout) + logging.error("stderr:\n%s", e.stderr) + raise + + if os.path.exists(archive_path): + logging.info("Successfully created archive: %s", archive_path) + else: + logging.error("Failed to create archive: %s", archive_path) + return self._upload_to_gcs(archive_path) + + def _upload_to_gcs(self, local_file_path: str) -> str: + """Uploads a file to Google Cloud Storage.""" + dest_file_name = os.path.basename(local_file_path) + self.bucket.blob(dest_file_name).upload_from_filename(local_file_path) + bucket_file_url = f"gs://{self.bucket_name}/{dest_file_name}" + logging.info("Uploaded %s to %s", local_file_path, bucket_file_url) + return bucket_file_url + + def _prepare_and_upload_archive(self, result_history: list[Result]) -> str: + """Archives and uploads local OFG repo to cloud build.""" + dir_files = set( + os.path.relpath(os.path.join(root, file)) + for root, _, files in os.walk(OFG_ROOT_DIR) + for file in files + ) + git_files = set( + subprocess.check_output( + ["git", "ls-files"], cwd=OFG_ROOT_DIR, text=True + ).splitlines() + ) + + # Separate out any files in data-dir + tmp_files = [] + for gf in git_files: + if "data-dir" in gf: + continue + tmp_files.append(gf) + git_files = set(tmp_files) + result_files = set( + os.path.relpath(os.path.join(root, file)) + for root, _, files in os.walk(result_history[-1].work_dirs.base) + for file in files + ) + file_to_upload = list((dir_files & git_files) | result_files) + + return self._upload_files( + f"ofg-repo-{uuid.uuid4().hex}.tar.gz", OFG_ROOT_DIR, file_to_upload + ) + + def _upload_oss_fuzz_data(self) -> str: + """Archives and uploads OSS_FUZZ_DATA_DIR, if any.""" + oss_fuzz_data_dir = os.getenv("OSS_FUZZ_DATA_DIR") + if not oss_fuzz_data_dir: + return "" + + # Upload all the files in the oss_fuzz_data_dir, _upload_files + # calls tar as the target_dir as cwd. 
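[Note: _upload_to_gcs above and _download_from_gcs later in this class are thin wrappers over google-cloud-storage; a minimal sketch of the same round trip, with a hypothetical bucket and object name:]

    from google.cloud import storage

    bucket = storage.Client().bucket("my-experiment-bucket")  # hypothetical bucket
    bucket.blob("ofg-repo.tar.gz").upload_from_filename("/tmp/ofg-repo.tar.gz")
    bucket.blob("ofg-repo.tar.gz").download_to_filename("/tmp/ofg-repo-copy.tar.gz")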
+ files_to_upload = ["."] + return self._upload_files( + f"oss-fuzz-{uuid.uuid4().hex}.tar.gz", oss_fuzz_data_dir, files_to_upload + ) + + def _upload_fi_oss_fuzz_data(self) -> str: + """Upload data used by OFG from scratch.""" + data_dir = "/experiment/data-dir" + if not os.path.isdir(data_dir): + return "" + + # Upload all the files in the oss_fuzz_data_dir, _upload_files + # calls tar as the target_dir as cwd. + files_to_upload = ["."] + return self._upload_files( + f"data-dir-{uuid.uuid4().hex}.tar.gz", data_dir, files_to_upload + ) + + def _request_cloud_build( + self, + ofg_repo_url: str, + agent_dill_url: str, + results_dill_url: str, + artifact_url: str, + artifact_path: str, + oss_fuzz_data_url: str, + data_dir_url: str, + new_result_filename: str, + ) -> str: + """Requests Cloud Build to execute the operation.""" + + # Used for injecting additional OSS-Fuzz project integrations not in + # upstream OSS-Fuzz. + oss_fuzz_data_dir = "" + data_env_set = "OSS_FUZZ_DATA_DIR_NOT_SET=1" + if oss_fuzz_data_url: + oss_fuzz_data_dir = "/workspace/oss-fuzz-data" + data_env_set = "OSS_FUZZ_DATA_DIR=/workspace/oss-fuzz-data" + + target_data_dir = "" + if data_dir_url: + target_data_dir = "/workspace/data-dir" + + cloud_build_config = { + "steps": [ + # Step 1: Download the dill and artifact files from GCS bucket. + {"name": "bash", "dir": "/workspace", "args": ["-c", "mkdir -p dills"]}, + { + "name": "gcr.io/cloud-builders/gsutil", + "dir": "/workspace", + "args": ["cp", agent_dill_url, "dills/agent.pkl"], + }, + { + "name": "gcr.io/cloud-builders/gsutil", + "dir": "/workspace", + "args": ["cp", results_dill_url, "dills/result_history.pkl"], + }, + { + "name": "gcr.io/cloud-builders/gsutil", + "entrypoint": "bash", + "args": [ + "-c", + f"mkdir -p /workspace/host/{os.path.dirname(artifact_path)}", + ], + "allowFailure": True, + }, + { + "name": "gcr.io/cloud-builders/gsutil", + "dir": "/workspace", + "args": ["cp", artifact_url, f"/workspace/host/{artifact_path}"], + "allowFailure": True, + }, + # Step 2: Prepare OFG and OF repos. + { + "name": "gcr.io/cloud-builders/gsutil", + "entrypoint": "bash", + "args": [ + "-c", + f"gsutil cp {ofg_repo_url} /tmp/ofg-repo.tar.gz && " + "mkdir /workspace/ofg && " + f"tar -xzf /tmp/ofg-repo.tar.gz -C /workspace/ofg", + ], + }, + # Step 3: Prepare agent base image. + { + "name": "gcr.io/cloud-builders/docker", + "args": [ + "build", + ".", + "-t", + ( + "us-central1-docker.pkg.dev/oss-fuzz/oss-fuzz-gen/" + "agent-image" + ), + "-f", + "Dockerfile.cloudbuild-agent", + ], + "dir": "/workspace/ofg/", + }, + # Step 4: Prepare OSS-Fuzz repo. 
+ { + "name": "gcr.io/cloud-builders/gsutil", + "entrypoint": "bash", + "args": [ + "-c", + f'test -n "{oss_fuzz_data_url}" && ' + f"gsutil cp {oss_fuzz_data_url} " + "/tmp/oss-fuzz-data.tar.gz && " + f"mkdir {oss_fuzz_data_dir} && " + f"tar -xzf /tmp/oss-fuzz-data.tar.gz -C {oss_fuzz_data_dir}", + ], + "allowFailure": True, + }, + { + "name": "gcr.io/cloud-builders/gsutil", + "entrypoint": "bash", + "args": [ + "-c", + f'test -n "{data_dir_url}" && ' + f"gsutil cp {data_dir_url} /tmp/data-dir.tar.gz && " + f"mkdir {target_data_dir} && " + f"tar -xzf /tmp/data-dir.tar.gz -C {target_data_dir}", + ], + "allowFailure": True, + }, + { + "name": "gcr.io/cloud-builders/docker", + "dir": "/workspace/ofg/", + "args": [ + "run", + "--rm", + "-v", + "/workspace:/workspace", + "-e", + data_env_set, + ( + "us-central1-docker.pkg.dev/oss-fuzz/oss-fuzz-gen/" + "agent-image" + ), + "python3.11", + "-c", + "import os; from experiment import oss_fuzz_checkout; " + 'oss_fuzz_checkout.clone_oss_fuzz("oss-fuzz"); ' + "oss_fuzz_checkout.postprocess_oss_fuzz(); ", + ], + }, + # Step 5: Run the Python script with the dill files. + { + "id": "agent-step", + "name": "gcr.io/cloud-builders/docker", + "args": [ + "run", + "--rm", + "-v", + "/workspace:/workspace", + "-v", + "/workspace/host/experiment:/experiment", + "-v", + "/var/run/docker.sock:/var/run/docker.sock", + "-e", + "VERTEX_AI_LOCATIONS=" + os.getenv("VERTEX_AI_LOCATIONS", ""), + "--network=cloudbuild", + # Built from this repo's `Dockerfile.cloudbuild-agent`. + ( + "us-central1-docker.pkg.dev/oss-fuzz/oss-fuzz-gen/" + "agent-image" + ), + "python3.11", + "-m", + "agent.base_agent", + "--agent", + "/workspace/dills/agent.pkl", + "--result-history", + "/workspace/dills/result_history.pkl", + "--result-new", + "/workspace/dills/new_result.pkl", + ], + }, + # Step 6: Upload the result to GCS bucket + { + "name": "bash", + "dir": "/workspace", + "args": ["ls", "/workspace/dills/"], + }, + { + "name": "gcr.io/cloud-builders/gsutil", + "dir": "/workspace", + "args": [ + "cp", + "/workspace/dills/new_result.pkl", + f"gs://{self.bucket_name}/{new_result_filename}", + ], + }, + ], + "tags": self.tags, + "timeout": "10800s", # 3 hours + "logsBucket": f"gs://{self.bucket_name}", + "serviceAccount": f"projects/{self.project_id}/serviceAccounts/" + f"{self.credentials.service_account_email}", # type: ignore + } + pool_name = os.getenv("GCB_BUILDPOOL_NAME") + if pool_name: + cloud_build_config.setdefault("options", {})["pool"] = {"name": pool_name} + logging.info(cloud_build_config) + + # Convert to YAML string and submit the Cloud Build request + build_info = self.builds.create( + projectId=self.project_id, body=cloud_build_config + ).execute() + build_id = build_info.get("metadata", {}).get("build", {}).get("id", "") + + logging.info("Created Cloud Build ID %s at %s", build_id, REGION) + return build_id + + def _wait_for_build(self, build_id: str) -> str: + """Wait for a GCB build.""" + prev_status = status = None + while status in [None, "WORKING", "QUEUED"]: + try: + status = ( + self.builds.get(projectId=self.project_id, id=build_id) + .execute() + .get("status") + ) + if status != prev_status: + logging.info("Cloud Build %s Status: %s", build_id, status) + prev_status = status + except (googleapiclient.errors.HttpError, BrokenPipeError) as e: + logging.warning( + "Failed to check cloud build status %s: %s", build_id, e + ) + time.sleep(60) # Avoid rate limiting. 
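[Note: the builds handle polled above appears to come from googleapiclient.discovery — the cloud_build("cloudbuild", "v1", ...) call in __init__ matches discovery.build's signature. Under that assumption, a standalone status check looks roughly like:]

    from googleapiclient.discovery import build

    builds = build("cloudbuild", "v1", cache_discovery=False).projects().builds()
    status = (
        builds.get(projectId="my-project", id="some-build-id")  # hypothetical IDs
        .execute()
        .get("status")
    )
    # QUEUED / WORKING are the in-progress states this loop keeps polling on;
    # SUCCESS, FAILURE, TIMEOUT and CANCELLED (among others) are terminal.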
+ return status or "" + + def _cancel_build(self, build_id: str) -> None: + """Cancel a GCB build""" + self.builds.cancel(projectId=self.project_id, id=build_id).execute() + + def _extract_chat_history(self, full_log: str) -> str: + """Extracts the agent chat history from cloud build log.""" + in_chat = False + chat_history = [] + for log_line in full_log.splitlines(): + if not re.match(_CHAT_HISTORY_PREFIX_PATTERN, log_line): + continue + if _CHAT_HISTORY_START_MARKER in log_line: + in_chat = True + if in_chat: + stripped_line = re.sub(_CHAT_HISTORY_PREFIX_PATTERN, "", log_line) + chat_history.append(stripped_line) + return "\n".join(chat_history) + + def _get_build_log(self, build_id: str) -> str: + """Downloads the build log""" + log_file_uri = f"log-{build_id}.txt" + try: + bucket = self.storage_client.bucket(self.bucket_name) + blob = bucket.blob(log_file_uri) + log_content = self._extract_chat_history(blob.download_as_text()) + logging.warning(log_content) + return log_content + except NotFound as e: + logging.error("Cloud build log %s not found: %s", log_file_uri, e) + return f"Cloud build log {log_file_uri} not found: {e}." + + def _download_from_gcs(self, destination_file_name: str) -> None: + """Downloads the result file from GCS.""" + source_blob_name = os.path.basename(destination_file_name) + blob = self.bucket.blob(source_blob_name) + blob.download_to_filename(destination_file_name) + logging.info("Downloaded %s to %s", source_blob_name, destination_file_name) + + def run(self, agent: BaseAgent, result_history: list[Result], dill_dir: str) -> Any: + """Runs agent on cloud build.""" + # Step 0: Add task-specific tags. + # TODO(dongge): More tags, e.g., benchmark name. + self.tags += [ + str(agent), + str(result_history[-1].benchmark.project), + str(result_history[-1].benchmark.function_name), + str(result_history[-1].trial), + ] + # Step1: Generate dill files. + agent_dill = utils.serialize_to_dill( + agent, os.path.join(dill_dir, f"{uuid.uuid4().hex}.pkl") + ) + results_dill = utils.serialize_to_dill( + result_history, os.path.join(dill_dir, f"{uuid.uuid4().hex}.pkl") + ) + # TODO(dongge): Encrypt dill files? + + # Step 2: Upload OFG repo and dill files to GCS. + ofg_url = self._prepare_and_upload_archive(result_history) + agent_url = self._upload_to_gcs(agent_dill) + results_url = self._upload_to_gcs(results_dill) + artifact_url = "" + artifact_path = "" + if isinstance(result_history[-1], RunResult): + artifact_path = result_history[-1].artifact_path + if artifact_path: + logging.info("Found artifact_path: %s in RunResult.", artifact_path) + artifact_url = self._upload_to_gcs(artifact_path) + logging.info("Uploaded artifact to %s", artifact_url) + else: + logging.error("No artifact_path found in RunResult.") + oss_fuzz_data_url = self._upload_oss_fuzz_data() + data_dir_url = self._upload_fi_oss_fuzz_data() + + # Step 3: Request Cloud Build. + new_result_filename = f"{uuid.uuid4().hex}.pkl" + build_id = self._request_cloud_build( + ofg_url, + agent_url, + results_url, + artifact_url, + artifact_path, + oss_fuzz_data_url, + data_dir_url, + new_result_filename, + ) + + # Step 4: Download new result dill. 
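[Note: step 1 above serializes the agent and result history with utils.serialize_to_dill, and step 5 below reads the new result back with utils.deserialize_from_dill. Assuming those helpers are thin wrappers over the dill package (they are not shown in this diff), the round trip is essentially:]

    import dill

    agent = {"name": "prototyper"}  # hypothetical stand-in for the BaseAgent instance
    with open("/tmp/agent.pkl", "wb") as f:
        dill.dump(agent, f)
    with open("/tmp/agent.pkl", "rb") as f:
        restored = dill.load(f)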
+ cloud_build_log = "" + new_result_dill = os.path.join(dill_dir, new_result_filename) + try: + cloud_build_final_status = self._wait_for_build(build_id) + if cloud_build_final_status == "SUCCESS": + self._download_from_gcs(new_result_dill) + else: + logging.error( + "Cloud build %s failed with status: %s", + build_id, + cloud_build_final_status, + ) + cloud_build_log += ( + f"Cloud build {build_id} failed with status: " + f"{cloud_build_final_status}.\n" + ) + except (KeyboardInterrupt, SystemExit) as e: + self._cancel_build(build_id) + logging.error("Cloud build %s cancled: %s", build_id, e) + cloud_build_log += f"Cloud build {build_id} cancled: {e}.\n" + + cloud_build_log += self._get_build_log(build_id) + + # Step 5: Deserialize dilld file. + result = utils.deserialize_from_dill(new_result_dill) + if not result: + cloud_build_log += f"Failed to deserialize from dill {new_result_dill}.\n" + last_result = result_history[-1] + result = Result( + benchmark=last_result.benchmark, + trial=last_result.trial, + work_dirs=last_result.work_dirs, + author=agent, + ) + result.chat_history = {agent.name: cloud_build_log} + + return result diff --git a/data_prep/introspector.py b/data_prep/introspector.py index e36fa0a1fa..2aa7eabfe8 100755 --- a/data_prep/introspector.py +++ b/data_prep/introspector.py @@ -35,1139 +35,1188 @@ logger = logging.getLogger(__name__) -T = TypeVar('T', str, list, dict, int) # Generic type. +T = TypeVar("T", str, list, dict, int) # Generic type. TIMEOUT = 45 MAX_RETRY = 5 -BENCHMARK_ROOT: str = './benchmark-sets' -BENCHMARK_DIR: str = f'{BENCHMARK_ROOT}/comparison' -GENERATED_BENCHMARK: str = 'generated-benchmark-' +BENCHMARK_ROOT: str = "./benchmark-sets" +BENCHMARK_DIR: str = f"{BENCHMARK_ROOT}/comparison" +GENERATED_BENCHMARK: str = "generated-benchmark-" -USE_FI_TO_GET_TARGETS = bool(int(os.getenv('OSS_FI_TO_GET_TARGETS', '1'))) +USE_FI_TO_GET_TARGETS = bool(int(os.getenv("OSS_FI_TO_GET_TARGETS", "1"))) # By default exclude static functions when identifying fuzz target candidates # to generate benchmarks. 
ORACLE_AVOID_STATIC_FUNCTIONS = bool( - int(os.getenv('OSS_FUZZ_AVOID_STATIC_FUNCTIONS', '1'))) + int(os.getenv("OSS_FUZZ_AVOID_STATIC_FUNCTIONS", "1")) +) ORACLE_ONLY_REFERENCED_FUNCTIONS = bool( - int(os.getenv('OSS_FUZZ_ONLY_REFERENCED_FUNCTIONS', '0'))) + int(os.getenv("OSS_FUZZ_ONLY_REFERENCED_FUNCTIONS", "0")) +) ORACLE_ONLY_FUNCTIONS_WITH_HEADER_DECLARATIONS = bool( - int(os.getenv('OSS_FUZZ_ONLY_FUNCS_WITH_HEADER_DECLARATION', '1'))) - -DEFAULT_INTROSPECTOR_ENDPOINT = 'https://introspector.oss-fuzz.com/api' -INTROSPECTOR_ENDPOINT = '' -INTROSPECTOR_CFG = '' -INTROSPECTOR_ORACLE_FAR_REACH = '' -INTROSPECTOR_ORACLE_KEYWORD = '' -INTROSPECTOR_ORACLE_EASY_PARAMS = '' -INTROSPECTOR_ORACLE_ALL_PUBLIC_CANDIDATES = '' -INTROSPECTOR_ORACLE_OPTIMAL = '' -INTROSPECTOR_ORACLE_ALL_TESTS = '' -INTROSPECTOR_FUNCTION_SOURCE = '' -INTROSPECTOR_PROJECT_SOURCE = '' -INTROSPECTOR_XREF = '' -INTROSPECTOR_TYPE = '' -INTROSPECTOR_FUNC_SIG = '' -INTROSPECTOR_ADDR_TYPE = '' -INTROSPECTOR_ALL_HEADER_FILES = '' -INTROSPECTOR_ALL_FUNC_TYPES = '' -INTROSPECTOR_ALL_TYPE_DEFINITION = '' -INTROSPECTOR_TEST_SOURCE = '' -INTROSPECTOR_HARNESS_SOURCE_AND_EXEC = '' -INTROSPECTOR_LANGUAGE_STATS = '' -INTROSPECTOR_GET_TARGET_FUNCTION = '' -INTROSPECTOR_CHECK_MACRO = '' - -INTROSPECTOR_HEADERS_FOR_FUNC = '' -INTROSPECTOR_SAMPLE_XREFS = '' -INTROSPECTOR_ALL_JVM_SOURCE_PATH = '' -INTROSPECTOR_FUNCTION_WITH_MATCHING_RETURN_TYPE = '' -INTROSPECTOR_JVM_PROPERTIES = '' -INTROSPECTOR_JVM_PUBLIC_CLASSES = '' + int(os.getenv("OSS_FUZZ_ONLY_FUNCS_WITH_HEADER_DECLARATION", "1")) +) + +DEFAULT_INTROSPECTOR_ENDPOINT = "https://introspector.oss-fuzz.com/api" +INTROSPECTOR_ENDPOINT = "" +INTROSPECTOR_CFG = "" +INTROSPECTOR_ORACLE_FAR_REACH = "" +INTROSPECTOR_ORACLE_KEYWORD = "" +INTROSPECTOR_ORACLE_EASY_PARAMS = "" +INTROSPECTOR_ORACLE_ALL_PUBLIC_CANDIDATES = "" +INTROSPECTOR_ORACLE_OPTIMAL = "" +INTROSPECTOR_ORACLE_ALL_TESTS = "" +INTROSPECTOR_FUNCTION_SOURCE = "" +INTROSPECTOR_PROJECT_SOURCE = "" +INTROSPECTOR_XREF = "" +INTROSPECTOR_TYPE = "" +INTROSPECTOR_FUNC_SIG = "" +INTROSPECTOR_ADDR_TYPE = "" +INTROSPECTOR_ALL_HEADER_FILES = "" +INTROSPECTOR_ALL_FUNC_TYPES = "" +INTROSPECTOR_ALL_TYPE_DEFINITION = "" +INTROSPECTOR_TEST_SOURCE = "" +INTROSPECTOR_HARNESS_SOURCE_AND_EXEC = "" +INTROSPECTOR_LANGUAGE_STATS = "" +INTROSPECTOR_GET_TARGET_FUNCTION = "" +INTROSPECTOR_CHECK_MACRO = "" + +INTROSPECTOR_HEADERS_FOR_FUNC = "" +INTROSPECTOR_SAMPLE_XREFS = "" +INTROSPECTOR_ALL_JVM_SOURCE_PATH = "" +INTROSPECTOR_FUNCTION_WITH_MATCHING_RETURN_TYPE = "" +INTROSPECTOR_JVM_PROPERTIES = "" +INTROSPECTOR_JVM_PUBLIC_CLASSES = "" def get_oracle_dict() -> Dict[str, Any]: - """Returns the oracles available to identify targets.""" - # Do this in a function to allow for forward-declaration of functions below. - oracle_dict = { - 'far-reach-low-coverage': query_introspector_for_far_reach_low_cov, - 'low-cov-with-fuzz-keyword': query_introspector_for_keyword_targets, - 'easy-params-far-reach': query_introspector_for_easy_param_targets, - 'optimal-targets': query_introspector_for_optimal_targets, - 'test-migration': query_introspector_for_tests, - 'all-public-candidates': query_introspector_all_public_candidates, - } - return oracle_dict + """Returns the oracles available to identify targets.""" + # Do this in a function to allow for forward-declaration of functions below. 
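[Note: the oracle names registered in the dict below are the values accepted by query_introspector_for_targets, defined later in this file. A usage sketch, with a hypothetical project name:]

    from data_prep import introspector

    introspector.set_introspector_endpoints(introspector.DEFAULT_INTROSPECTOR_ENDPOINT)
    targets = introspector.query_introspector_for_targets(
        "libxml2", "far-reach-low-coverage"
    )
    print(f"{len(targets)} candidate functions")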
+ oracle_dict = { + "far-reach-low-coverage": query_introspector_for_far_reach_low_cov, + "low-cov-with-fuzz-keyword": query_introspector_for_keyword_targets, + "easy-params-far-reach": query_introspector_for_easy_param_targets, + "optimal-targets": query_introspector_for_optimal_targets, + "test-migration": query_introspector_for_tests, + "all-public-candidates": query_introspector_all_public_candidates, + } + return oracle_dict def set_introspector_endpoints(endpoint): - """Sets URLs for Fuzz Introspector endpoints to local or remote endpoints.""" - global INTROSPECTOR_ENDPOINT, INTROSPECTOR_CFG, INTROSPECTOR_FUNC_SIG, \ - INTROSPECTOR_FUNCTION_SOURCE, INTROSPECTOR_PROJECT_SOURCE, \ - INTROSPECTOR_XREF, INTROSPECTOR_TYPE, INTROSPECTOR_ORACLE_FAR_REACH, \ - INTROSPECTOR_ORACLE_KEYWORD, INTROSPECTOR_ADDR_TYPE, \ - INTROSPECTOR_ALL_HEADER_FILES, INTROSPECTOR_ALL_FUNC_TYPES, \ - INTROSPECTOR_SAMPLE_XREFS, INTROSPECTOR_ORACLE_EASY_PARAMS, \ - INTROSPECTOR_ORACLE_ALL_PUBLIC_CANDIDATES, \ - INTROSPECTOR_ALL_JVM_SOURCE_PATH, INTROSPECTOR_ORACLE_OPTIMAL, \ - INTROSPECTOR_HEADERS_FOR_FUNC, \ - INTROSPECTOR_FUNCTION_WITH_MATCHING_RETURN_TYPE, \ - INTROSPECTOR_ORACLE_ALL_TESTS, INTROSPECTOR_JVM_PROPERTIES, \ - INTROSPECTOR_TEST_SOURCE, INTROSPECTOR_HARNESS_SOURCE_AND_EXEC, \ - INTROSPECTOR_JVM_PUBLIC_CLASSES, INTROSPECTOR_LANGUAGE_STATS, \ - INTROSPECTOR_GET_TARGET_FUNCTION, INTROSPECTOR_ALL_TYPE_DEFINITION, \ - INTROSPECTOR_CHECK_MACRO - - INTROSPECTOR_ENDPOINT = endpoint - - INTROSPECTOR_CFG = f'{INTROSPECTOR_ENDPOINT}/annotated-cfg' - INTROSPECTOR_ORACLE_FAR_REACH = ( - f'{INTROSPECTOR_ENDPOINT}/far-reach-but-low-coverage') - INTROSPECTOR_ORACLE_KEYWORD = ( - f'{INTROSPECTOR_ENDPOINT}/far-reach-low-cov-fuzz-keyword') - INTROSPECTOR_ORACLE_EASY_PARAMS = ( - f'{INTROSPECTOR_ENDPOINT}/easy-params-far-reach') - INTROSPECTOR_ORACLE_ALL_PUBLIC_CANDIDATES = ( - f'{INTROSPECTOR_ENDPOINT}/all-public-candidates') - INTROSPECTOR_ORACLE_OPTIMAL = f'{INTROSPECTOR_ENDPOINT}/optimal-targets' - INTROSPECTOR_FUNCTION_SOURCE = f'{INTROSPECTOR_ENDPOINT}/function-source-code' - INTROSPECTOR_PROJECT_SOURCE = f'{INTROSPECTOR_ENDPOINT}/project-source-code' - INTROSPECTOR_TEST_SOURCE = f'{INTROSPECTOR_ENDPOINT}/project-test-code' - INTROSPECTOR_XREF = f'{INTROSPECTOR_ENDPOINT}/all-cross-references' - INTROSPECTOR_TYPE = f'{INTROSPECTOR_ENDPOINT}/type-info' - INTROSPECTOR_FUNC_SIG = f'{INTROSPECTOR_ENDPOINT}/function-signature' - INTROSPECTOR_ADDR_TYPE = ( - f'{INTROSPECTOR_ENDPOINT}/addr-to-recursive-dwarf-info') - INTROSPECTOR_ALL_HEADER_FILES = f'{INTROSPECTOR_ENDPOINT}/all-header-files' - INTROSPECTOR_ALL_FUNC_TYPES = f'{INTROSPECTOR_ENDPOINT}/func-debug-types' - INTROSPECTOR_ALL_TYPE_DEFINITION = ( - f'{INTROSPECTOR_ENDPOINT}/full-type-definition') - INTROSPECTOR_HEADERS_FOR_FUNC = ( - f'{INTROSPECTOR_ENDPOINT}/get-header-files-needed-for-function') - INTROSPECTOR_SAMPLE_XREFS = ( - f'{INTROSPECTOR_ENDPOINT}/sample-cross-references') - INTROSPECTOR_ALL_JVM_SOURCE_PATH = ( - f'{INTROSPECTOR_ENDPOINT}/all-project-source-files') - INTROSPECTOR_FUNCTION_WITH_MATCHING_RETURN_TYPE = ( - f'{INTROSPECTOR_ENDPOINT}/function-with-matching-return-type') - INTROSPECTOR_ORACLE_ALL_TESTS = f'{INTROSPECTOR_ENDPOINT}/project-tests' - INTROSPECTOR_JVM_PROPERTIES = f'{INTROSPECTOR_ENDPOINT}/jvm-method-properties' - INTROSPECTOR_HARNESS_SOURCE_AND_EXEC = ( - f'{INTROSPECTOR_ENDPOINT}/harness-source-and-executable') - INTROSPECTOR_JVM_PUBLIC_CLASSES = ( - f'{INTROSPECTOR_ENDPOINT}/all-public-classes') - INTROSPECTOR_LANGUAGE_STATS = ( 
- f'{INTROSPECTOR_ENDPOINT}/database-language-stats') - INTROSPECTOR_GET_TARGET_FUNCTION = ( - f'{INTROSPECTOR_ENDPOINT}/get-target-function') - INTROSPECTOR_CHECK_MACRO = f'{INTROSPECTOR_ENDPOINT}/check_macro' + """Sets URLs for Fuzz Introspector endpoints to local or remote endpoints.""" + global INTROSPECTOR_ENDPOINT, INTROSPECTOR_CFG, INTROSPECTOR_FUNC_SIG, INTROSPECTOR_FUNCTION_SOURCE, INTROSPECTOR_PROJECT_SOURCE, INTROSPECTOR_XREF, INTROSPECTOR_TYPE, INTROSPECTOR_ORACLE_FAR_REACH, INTROSPECTOR_ORACLE_KEYWORD, INTROSPECTOR_ADDR_TYPE, INTROSPECTOR_ALL_HEADER_FILES, INTROSPECTOR_ALL_FUNC_TYPES, INTROSPECTOR_SAMPLE_XREFS, INTROSPECTOR_ORACLE_EASY_PARAMS, INTROSPECTOR_ORACLE_ALL_PUBLIC_CANDIDATES, INTROSPECTOR_ALL_JVM_SOURCE_PATH, INTROSPECTOR_ORACLE_OPTIMAL, INTROSPECTOR_HEADERS_FOR_FUNC, INTROSPECTOR_FUNCTION_WITH_MATCHING_RETURN_TYPE, INTROSPECTOR_ORACLE_ALL_TESTS, INTROSPECTOR_JVM_PROPERTIES, INTROSPECTOR_TEST_SOURCE, INTROSPECTOR_HARNESS_SOURCE_AND_EXEC, INTROSPECTOR_JVM_PUBLIC_CLASSES, INTROSPECTOR_LANGUAGE_STATS, INTROSPECTOR_GET_TARGET_FUNCTION, INTROSPECTOR_ALL_TYPE_DEFINITION, INTROSPECTOR_CHECK_MACRO + + INTROSPECTOR_ENDPOINT = endpoint + + INTROSPECTOR_CFG = f"{INTROSPECTOR_ENDPOINT}/annotated-cfg" + INTROSPECTOR_ORACLE_FAR_REACH = ( + f"{INTROSPECTOR_ENDPOINT}/far-reach-but-low-coverage" + ) + INTROSPECTOR_ORACLE_KEYWORD = ( + f"{INTROSPECTOR_ENDPOINT}/far-reach-low-cov-fuzz-keyword" + ) + INTROSPECTOR_ORACLE_EASY_PARAMS = f"{INTROSPECTOR_ENDPOINT}/easy-params-far-reach" + INTROSPECTOR_ORACLE_ALL_PUBLIC_CANDIDATES = ( + f"{INTROSPECTOR_ENDPOINT}/all-public-candidates" + ) + INTROSPECTOR_ORACLE_OPTIMAL = f"{INTROSPECTOR_ENDPOINT}/optimal-targets" + INTROSPECTOR_FUNCTION_SOURCE = f"{INTROSPECTOR_ENDPOINT}/function-source-code" + INTROSPECTOR_PROJECT_SOURCE = f"{INTROSPECTOR_ENDPOINT}/project-source-code" + INTROSPECTOR_TEST_SOURCE = f"{INTROSPECTOR_ENDPOINT}/project-test-code" + INTROSPECTOR_XREF = f"{INTROSPECTOR_ENDPOINT}/all-cross-references" + INTROSPECTOR_TYPE = f"{INTROSPECTOR_ENDPOINT}/type-info" + INTROSPECTOR_FUNC_SIG = f"{INTROSPECTOR_ENDPOINT}/function-signature" + INTROSPECTOR_ADDR_TYPE = f"{INTROSPECTOR_ENDPOINT}/addr-to-recursive-dwarf-info" + INTROSPECTOR_ALL_HEADER_FILES = f"{INTROSPECTOR_ENDPOINT}/all-header-files" + INTROSPECTOR_ALL_FUNC_TYPES = f"{INTROSPECTOR_ENDPOINT}/func-debug-types" + INTROSPECTOR_ALL_TYPE_DEFINITION = f"{INTROSPECTOR_ENDPOINT}/full-type-definition" + INTROSPECTOR_HEADERS_FOR_FUNC = ( + f"{INTROSPECTOR_ENDPOINT}/get-header-files-needed-for-function" + ) + INTROSPECTOR_SAMPLE_XREFS = f"{INTROSPECTOR_ENDPOINT}/sample-cross-references" + INTROSPECTOR_ALL_JVM_SOURCE_PATH = ( + f"{INTROSPECTOR_ENDPOINT}/all-project-source-files" + ) + INTROSPECTOR_FUNCTION_WITH_MATCHING_RETURN_TYPE = ( + f"{INTROSPECTOR_ENDPOINT}/function-with-matching-return-type" + ) + INTROSPECTOR_ORACLE_ALL_TESTS = f"{INTROSPECTOR_ENDPOINT}/project-tests" + INTROSPECTOR_JVM_PROPERTIES = f"{INTROSPECTOR_ENDPOINT}/jvm-method-properties" + INTROSPECTOR_HARNESS_SOURCE_AND_EXEC = ( + f"{INTROSPECTOR_ENDPOINT}/harness-source-and-executable" + ) + INTROSPECTOR_JVM_PUBLIC_CLASSES = f"{INTROSPECTOR_ENDPOINT}/all-public-classes" + INTROSPECTOR_LANGUAGE_STATS = f"{INTROSPECTOR_ENDPOINT}/database-language-stats" + INTROSPECTOR_GET_TARGET_FUNCTION = f"{INTROSPECTOR_ENDPOINT}/get-target-function" + INTROSPECTOR_CHECK_MACRO = f"{INTROSPECTOR_ENDPOINT}/check_macro" def _construct_url(api: str, params: dict) -> str: - """Constructs an encoded url for the |api| with |params|.""" - return 
api + '?' + urlencode(params) + """Constructs an encoded url for the |api| with |params|.""" + return api + "?" + urlencode(params) def _query_introspector(api: str, params: dict) -> Optional[requests.Response]: - """Queries FuzzIntrospector API and returns the json payload, - returns an empty dict if unable to get data.""" - for attempt_num in range(1, MAX_RETRY + 1): - try: - resp = requests.get(api, params, timeout=TIMEOUT) - if not resp.ok: - logger.error( - 'Failed to get data from FI:\n' - '%s\n' - '-----------Response received------------\n' - '%s\n' - '------------End of response-------------', resp.url, - resp.content.decode('utf-8').strip()) - break - return resp - except requests.exceptions.Timeout as err: - if attempt_num == MAX_RETRY: - logger.error( - 'Failed to get data from FI due to timeout, max retry exceeded:\n' - '%s\n' - 'Error: %s', _construct_url(api, params), err) - break - delay = 5 * 2**attempt_num + random.randint(1, 10) - logger.warning( - 'Failed to get data from FI due to timeout on attempt %d:\n' - '%s\n' - 'retry in %ds...', attempt_num, _construct_url(api, params), delay) - time.sleep(delay) - except requests.exceptions.RequestException as err: - logger.error( - 'Failed to get data from FI due to unexpected error:\n' - '%s\n' - 'Error: %s', _construct_url(api, params), err) - break - - return None - - -def _get_data(resp: Optional[requests.Response], key: str, - default_value: T) -> T: - """Gets the value specified by |key| from a Request |resp|.""" - if not resp: - return default_value + """Queries FuzzIntrospector API and returns the json payload, + returns an empty dict if unable to get data.""" + for attempt_num in range(1, MAX_RETRY + 1): + try: + resp = requests.get(api, params, timeout=TIMEOUT) + if not resp.ok: + logger.error( + "Failed to get data from FI:\n" + "%s\n" + "-----------Response received------------\n" + "%s\n" + "------------End of response-------------", + resp.url, + resp.content.decode("utf-8").strip(), + ) + break + return resp + except requests.exceptions.Timeout as err: + if attempt_num == MAX_RETRY: + logger.error( + "Failed to get data from FI due to timeout, max retry exceeded:\n" + "%s\n" + "Error: %s", + _construct_url(api, params), + err, + ) + break + delay = 5 * 2**attempt_num + random.randint(1, 10) + logger.warning( + "Failed to get data from FI due to timeout on attempt %d:\n" + "%s\n" + "retry in %ds...", + attempt_num, + _construct_url(api, params), + delay, + ) + time.sleep(delay) + except requests.exceptions.RequestException as err: + logger.error( + "Failed to get data from FI due to unexpected error:\n" + "%s\n" + "Error: %s", + _construct_url(api, params), + err, + ) + break - try: - data = resp.json() - except requests.exceptions.InvalidJSONError: - logger.error( - 'Unable to parse response from FI:\n' - '%s\n' - '-----------Response received------------\n' - '%s\n' - '------------End of response-------------', resp.url, - resp.content.decode('utf-8').strip()) - return default_value - - # To handle the case that some FI query could return empty list, - # empty dict or boolean value False - content = data.get(key) - if content or key in data.keys(): - return content + return None - logger.error('Failed to get %s from FI:\n' - '%s\n' - '%s', key, resp.url, data) - return default_value +def _get_data(resp: Optional[requests.Response], key: str, default_value: T) -> T: + """Gets the value specified by |key| from a Request |resp|.""" + if not resp: + return default_value -def query_introspector_for_tests(project: 
str) -> list[str]: - """Gets the list of test files in the target project.""" - resp = _query_introspector(INTROSPECTOR_ORACLE_ALL_TESTS, { - 'project': project, - }) - return _get_data(resp, 'test-file-list', []) + try: + data = resp.json() + except requests.exceptions.InvalidJSONError: + logger.error( + "Unable to parse response from FI:\n" + "%s\n" + "-----------Response received------------\n" + "%s\n" + "------------End of response-------------", + resp.url, + resp.content.decode("utf-8").strip(), + ) + return default_value + + # To handle the case that some FI query could return empty list, + # empty dict or boolean value False + content = data.get(key) + if content or key in data.keys(): + return content + + logger.error("Failed to get %s from FI:\n" "%s\n" "%s", key, resp.url, data) + return default_value -def query_introspector_for_harness_intrinsics( - project: str) -> list[dict[str, str]]: - """Gets the list of test files in the target project.""" - resp = _query_introspector(INTROSPECTOR_HARNESS_SOURCE_AND_EXEC, { - 'project': project, - }) - return _get_data(resp, 'pairs', []) +def query_introspector_for_tests(project: str) -> list[str]: + """Gets the list of test files in the target project.""" + resp = _query_introspector( + INTROSPECTOR_ORACLE_ALL_TESTS, + { + "project": project, + }, + ) + return _get_data(resp, "test-file-list", []) + + +def query_introspector_for_harness_intrinsics(project: str) -> list[dict[str, str]]: + """Gets the list of test files in the target project.""" + resp = _query_introspector( + INTROSPECTOR_HARNESS_SOURCE_AND_EXEC, + { + "project": project, + }, + ) + return _get_data(resp, "pairs", []) def query_introspector_oracle(project: str, oracle_api: str) -> list[dict]: - """Queries a fuzz target oracle API from Fuzz Introspector.""" - resp = _query_introspector( - oracle_api, { - 'project': - project, - 'exclude-static-functions': - ORACLE_AVOID_STATIC_FUNCTIONS, - 'only-referenced-functions': - ORACLE_ONLY_REFERENCED_FUNCTIONS, - 'only-with-header-file-declaration': - ORACLE_ONLY_FUNCTIONS_WITH_HEADER_DECLARATIONS, - }) - return _get_data(resp, 'functions', []) + """Queries a fuzz target oracle API from Fuzz Introspector.""" + resp = _query_introspector( + oracle_api, + { + "project": project, + "exclude-static-functions": ORACLE_AVOID_STATIC_FUNCTIONS, + "only-referenced-functions": ORACLE_ONLY_REFERENCED_FUNCTIONS, + "only-with-header-file-declaration": ORACLE_ONLY_FUNCTIONS_WITH_HEADER_DECLARATIONS, + }, + ) + return _get_data(resp, "functions", []) def query_introspector_for_optimal_targets(project: str) -> list[dict]: - """Queries Fuzz Introspector for optimal target analysis.""" - return query_introspector_oracle(project, INTROSPECTOR_ORACLE_OPTIMAL) + """Queries Fuzz Introspector for optimal target analysis.""" + return query_introspector_oracle(project, INTROSPECTOR_ORACLE_OPTIMAL) def query_introspector_for_keyword_targets(project: str) -> list[dict]: - """Queries FuzzIntrospector for targets with interesting fuzz keywords.""" - return query_introspector_oracle(project, INTROSPECTOR_ORACLE_KEYWORD) + """Queries FuzzIntrospector for targets with interesting fuzz keywords.""" + return query_introspector_oracle(project, INTROSPECTOR_ORACLE_KEYWORD) def query_introspector_for_easy_param_targets(project: str) -> list[dict]: - """Queries Fuzz Introspector for targets that have fuzzer-friendly params, - such as data buffers.""" - return query_introspector_oracle(project, INTROSPECTOR_ORACLE_EASY_PARAMS) + """Queries Fuzz Introspector for 
targets that have fuzzer-friendly params, + such as data buffers.""" + return query_introspector_oracle(project, INTROSPECTOR_ORACLE_EASY_PARAMS) def query_introspector_all_public_candidates(project: str) -> list[dict]: - """Queries Fuzz Introspector for all public accessible function or - constructor candidates. - """ - return query_introspector_oracle(project, - INTROSPECTOR_ORACLE_ALL_PUBLIC_CANDIDATES) + """Queries Fuzz Introspector for all public accessible function or + constructor candidates. + """ + return query_introspector_oracle(project, INTROSPECTOR_ORACLE_ALL_PUBLIC_CANDIDATES) def query_introspector_for_targets(project, target_oracle) -> list[Dict]: - """Queries introspector for target functions.""" - query_func = get_oracle_dict().get(target_oracle, None) - if not query_func: - logger.error('No such oracle "%s"', target_oracle) - sys.exit(1) - return query_func(project) + """Queries introspector for target functions.""" + query_func = get_oracle_dict().get(target_oracle, None) + if not query_func: + logger.error('No such oracle "%s"', target_oracle) + sys.exit(1) + return query_func(project) def query_introspector_cfg(project: str) -> dict: - """Queries FuzzIntrospector API for CFG.""" - resp = _query_introspector(INTROSPECTOR_CFG, {'project': project}) - return _get_data(resp, 'project', {}) + """Queries FuzzIntrospector API for CFG.""" + resp = _query_introspector(INTROSPECTOR_CFG, {"project": project}) + return _get_data(resp, "project", {}) def query_introspector_source_file_path(project: str, func_sig: str) -> str: - """Queries FuzzIntrospector API for file path of |func_sig|.""" - resp = _query_introspector(INTROSPECTOR_FUNCTION_SOURCE, { - 'project': project, - 'function_signature': func_sig - }) - return _get_data(resp, 'filepath', '') + """Queries FuzzIntrospector API for file path of |func_sig|.""" + resp = _query_introspector( + INTROSPECTOR_FUNCTION_SOURCE, + {"project": project, "function_signature": func_sig}, + ) + return _get_data(resp, "filepath", "") def query_introspector_function_source(project: str, func_sig: str) -> str: - """Queries FuzzIntrospector API for source code of |func_sig|.""" - resp = _query_introspector(INTROSPECTOR_FUNCTION_SOURCE, { - 'project': project, - 'function_signature': func_sig - }) - return _get_data(resp, 'source', '') + """Queries FuzzIntrospector API for source code of |func_sig|.""" + resp = _query_introspector( + INTROSPECTOR_FUNCTION_SOURCE, + {"project": project, "function_signature": func_sig}, + ) + return _get_data(resp, "source", "") def query_introspector_function_line(project: str, func_sig: str) -> list: - """Queries FuzzIntrospector API for source line of |func_sig|.""" - resp = _query_introspector(INTROSPECTOR_FUNCTION_SOURCE, { - 'project': project, - 'function_signature': func_sig - }) - return [_get_data(resp, 'src_begin', 0), _get_data(resp, 'src_end', 0)] + """Queries FuzzIntrospector API for source line of |func_sig|.""" + resp = _query_introspector( + INTROSPECTOR_FUNCTION_SOURCE, + {"project": project, "function_signature": func_sig}, + ) + return [_get_data(resp, "src_begin", 0), _get_data(resp, "src_end", 0)] def query_introspector_function_props(project: str, func_sig: str) -> dict: - """Queries FuzzIntrospector API for additional properties of |func_sig|.""" - resp = _query_introspector(INTROSPECTOR_JVM_PROPERTIES, { - 'project': project, - 'function_signature': func_sig - }) - return { - 'exceptions': _get_data(resp, 'exceptions', []), - 'is-jvm-static': _get_data(resp, 'is-jvm-static', False), - 
'need-close': _get_data(resp, 'need-close', False) - } + """Queries FuzzIntrospector API for additional properties of |func_sig|.""" + resp = _query_introspector( + INTROSPECTOR_JVM_PROPERTIES, + {"project": project, "function_signature": func_sig}, + ) + return { + "exceptions": _get_data(resp, "exceptions", []), + "is-jvm-static": _get_data(resp, "is-jvm-static", False), + "need-close": _get_data(resp, "need-close", False), + } def query_introspector_public_classes(project: str) -> list[str]: - """Queries FuzzIntrospector API for all public classes of |project|.""" - resp = _query_introspector(INTROSPECTOR_JVM_PUBLIC_CLASSES, - {'project': project}) - return _get_data(resp, 'classes', []) + """Queries FuzzIntrospector API for all public classes of |project|.""" + resp = _query_introspector(INTROSPECTOR_JVM_PUBLIC_CLASSES, {"project": project}) + return _get_data(resp, "classes", []) -def query_introspector_source_code(project: str, - filepath: str, - begin_line: int = 0, - end_line: int = 10000) -> str: - """Queries FuzzIntrospector API for source code of a +def query_introspector_source_code( + project: str, filepath: str, begin_line: int = 0, end_line: int = 10000 +) -> str: + """Queries FuzzIntrospector API for source code of a file |filepath| between |begin_line| and |end_line|.""" - resp = _query_introspector( - INTROSPECTOR_PROJECT_SOURCE, { - 'project': project, - 'filepath': filepath, - 'begin_line': begin_line, - 'end_line': end_line, - }) + resp = _query_introspector( + INTROSPECTOR_PROJECT_SOURCE, + { + "project": project, + "filepath": filepath, + "begin_line": begin_line, + "end_line": end_line, + }, + ) - return _get_data(resp, 'source_code', '') + return _get_data(resp, "source_code", "") def query_introspector_test_source(project: str, filepath: str) -> str: - """Queries the source code of a test file from.""" - resp = _query_introspector(INTROSPECTOR_TEST_SOURCE, { - 'project': project, - 'filepath': filepath - }) - return _get_data(resp, 'source_code', '') + """Queries the source code of a test file from.""" + resp = _query_introspector( + INTROSPECTOR_TEST_SOURCE, {"project": project, "filepath": filepath} + ) + return _get_data(resp, "source_code", "") def query_introspector_header_files(project: str) -> List[str]: - """Queries for the header files used in a given project.""" - resp = _query_introspector(INTROSPECTOR_ALL_HEADER_FILES, - {'project': project}) - all_header_files = _get_data(resp, 'all-header-files', []) - return all_header_files + """Queries for the header files used in a given project.""" + resp = _query_introspector(INTROSPECTOR_ALL_HEADER_FILES, {"project": project}) + all_header_files = _get_data(resp, "all-header-files", []) + return all_header_files def query_introspector_sample_xrefs(project: str, func_sig: str) -> List[str]: - """Queries for sample references in the form of source code.""" - resp = _query_introspector(INTROSPECTOR_SAMPLE_XREFS, { - 'project': project, - 'function_signature': func_sig - }) - return _get_data(resp, 'source-code-refs', []) + """Queries for sample references in the form of source code.""" + resp = _query_introspector( + INTROSPECTOR_SAMPLE_XREFS, {"project": project, "function_signature": func_sig} + ) + return _get_data(resp, "source-code-refs", []) def query_introspector_jvm_source_path(project: str) -> List[str]: - """Queries for all java source paths of a given project.""" - resp = _query_introspector(INTROSPECTOR_ALL_JVM_SOURCE_PATH, - {'project': project}) - return _get_data(resp, 'src_path', []) + """Queries 
for all java source paths of a given project.""" + resp = _query_introspector(INTROSPECTOR_ALL_JVM_SOURCE_PATH, {"project": project}) + return _get_data(resp, "src_path", []) def query_introspector_matching_function_constructor_type( - project: str, return_type: str, is_function: bool) -> List[Dict[str, Any]]: - """Queries for all functions or all constructors that returns a given type - in a given project.""" - simple_types_should_not_process = [ - 'byte', 'char', 'boolean', 'short', 'long', 'int', 'float', 'double', - 'void', 'java.lang.String', 'java.lang.CharSequence' - ] - if return_type in simple_types_should_not_process: - # Avoid querying introspector for simple object types as this API is - # not meant to be used for creating simple object. - return [] - - resp = _query_introspector(INTROSPECTOR_FUNCTION_WITH_MATCHING_RETURN_TYPE, { - 'project': project, - 'return-type': return_type - }) - - if is_function: - return _get_data(resp, 'functions', []) - - return _get_data(resp, 'constructors', []) - - -def query_introspector_header_files_to_include(project: str, - func_sig: str) -> List[str]: - """Queries Fuzz Introspector header files where a function is likely - declared.""" - resp = _query_introspector(INTROSPECTOR_HEADERS_FOR_FUNC, { - 'project': project, - 'function_signature': func_sig - }) - arg_types = _get_data(resp, 'headers-to-include', []) - return arg_types - - -def query_introspector_function_debug_arg_types(project: str, - func_sig: str) -> List[str]: - """Queries FuzzIntrospector function arguments extracted by way of debug - info.""" - resp = _query_introspector(INTROSPECTOR_ALL_FUNC_TYPES, { - 'project': project, - 'function_signature': func_sig - }) - arg_types = _get_data(resp, 'arg-types', []) - return arg_types + project: str, return_type: str, is_function: bool +) -> List[Dict[str, Any]]: + """Queries for all functions or all constructors that returns a given type + in a given project.""" + simple_types_should_not_process = [ + "byte", + "char", + "boolean", + "short", + "long", + "int", + "float", + "double", + "void", + "java.lang.String", + "java.lang.CharSequence", + ] + if return_type in simple_types_should_not_process: + # Avoid querying introspector for simple object types as this API is + # not meant to be used for creating simple object. 
+ return [] + + resp = _query_introspector( + INTROSPECTOR_FUNCTION_WITH_MATCHING_RETURN_TYPE, + {"project": project, "return-type": return_type}, + ) + + if is_function: + return _get_data(resp, "functions", []) + + return _get_data(resp, "constructors", []) + + +def query_introspector_header_files_to_include( + project: str, func_sig: str +) -> List[str]: + """Queries Fuzz Introspector header files where a function is likely + declared.""" + resp = _query_introspector( + INTROSPECTOR_HEADERS_FOR_FUNC, + {"project": project, "function_signature": func_sig}, + ) + arg_types = _get_data(resp, "headers-to-include", []) + return arg_types + + +def query_introspector_function_debug_arg_types( + project: str, func_sig: str +) -> List[str]: + """Queries FuzzIntrospector function arguments extracted by way of debug + info.""" + resp = _query_introspector( + INTROSPECTOR_ALL_FUNC_TYPES, + {"project": project, "function_signature": func_sig}, + ) + arg_types = _get_data(resp, "arg-types", []) + return arg_types def query_introspector_type_definition(project: str) -> List[dict]: - """Queries FuzzIntrospector for a full list of custom type definition - including, union, struct, typedef, enum and macro definition.""" - resp = _query_introspector(INTROSPECTOR_ALL_TYPE_DEFINITION, { - 'project': project, - }) - result = _get_data(resp, 'project', {}) - return result.get('typedef_list', []) - - -def query_introspector_macro_block(project: str, - source_path: str, - line_start: int = 0, - line_end: int = 99999) -> List[dict]: - """Queries FuzzIntrospector for a full list of custom type definition - including, union, struct, typedef, enum and macro definition.""" - resp = _query_introspector( - INTROSPECTOR_CHECK_MACRO, { - 'project': project, - 'source': source_path, - 'start': line_start, - 'end': line_end - }) - result = _get_data(resp, 'project', {}) - return result.get('macro_block_info', []) - - -def query_introspector_cross_references(project: str, - func_sig: str) -> list[str]: - """Queries FuzzIntrospector API for source code of functions - which reference |func_sig|.""" - resp = _query_introspector(INTROSPECTOR_XREF, { - 'project': project, - 'function_signature': func_sig - }) - call_sites = _get_data(resp, 'callsites', []) - - xref_source = [] - for cs in call_sites: - name = cs.get('src_func') - sig = query_introspector_function_signature(project, name) - source = query_introspector_function_source(project, sig) - xref_source.append(source) - return xref_source + """Queries FuzzIntrospector for a full list of custom type definition + including, union, struct, typedef, enum and macro definition.""" + resp = _query_introspector( + INTROSPECTOR_ALL_TYPE_DEFINITION, + { + "project": project, + }, + ) + result = _get_data(resp, "project", {}) + return result.get("typedef_list", []) + + +def query_introspector_macro_block( + project: str, source_path: str, line_start: int = 0, line_end: int = 99999 +) -> List[dict]: + """Queries FuzzIntrospector for a full list of custom type definition + including, union, struct, typedef, enum and macro definition.""" + resp = _query_introspector( + INTROSPECTOR_CHECK_MACRO, + { + "project": project, + "source": source_path, + "start": line_start, + "end": line_end, + }, + ) + result = _get_data(resp, "project", {}) + return result.get("macro_block_info", []) + + +def query_introspector_cross_references(project: str, func_sig: str) -> list[str]: + """Queries FuzzIntrospector API for source code of functions + which reference |func_sig|.""" + resp = 
_query_introspector( + INTROSPECTOR_XREF, {"project": project, "function_signature": func_sig} + ) + call_sites = _get_data(resp, "callsites", []) + + xref_source = [] + for cs in call_sites: + name = cs.get("src_func") + sig = query_introspector_function_signature(project, name) + source = query_introspector_function_source(project, sig) + xref_source.append(source) + return xref_source def query_introspector_language_stats() -> dict: - """Queries introspector for language stats""" + """Queries introspector for language stats""" - resp = _query_introspector(INTROSPECTOR_LANGUAGE_STATS, {}) - return _get_data(resp, 'stats', {}) + resp = _query_introspector(INTROSPECTOR_LANGUAGE_STATS, {}) + return _get_data(resp, "stats", {}) def query_introspector_type_info(project: str, type_name: str) -> list[dict]: - """Queries FuzzIntrospector API for information of |type_name|.""" - resp = _query_introspector(INTROSPECTOR_TYPE, { - 'project': project, - 'type_name': type_name - }) - return _get_data(resp, 'type_data', []) + """Queries FuzzIntrospector API for information of |type_name|.""" + resp = _query_introspector( + INTROSPECTOR_TYPE, {"project": project, "type_name": type_name} + ) + return _get_data(resp, "type_data", []) -def query_introspector_function_signature(project: str, - function_name: str) -> str: - """Queries FuzzIntrospector API for signature of |function_name|.""" - resp = _query_introspector(INTROSPECTOR_FUNC_SIG, { - 'project': project, - 'function': function_name - }) - return _get_data(resp, 'signature', '') +def query_introspector_function_signature(project: str, function_name: str) -> str: + """Queries FuzzIntrospector API for signature of |function_name|.""" + resp = _query_introspector( + INTROSPECTOR_FUNC_SIG, {"project": project, "function": function_name} + ) + return _get_data(resp, "signature", "") def query_introspector_addr_type_info(project: str, addr: str) -> str: - """Queries FuzzIntrospector API for type information for a type - identified by its address used during compilation.""" - resp = _query_introspector(INTROSPECTOR_ADDR_TYPE, { - 'project': project, - 'addr': addr - }) + """Queries FuzzIntrospector API for type information for a type + identified by its address used during compilation.""" + resp = _query_introspector( + INTROSPECTOR_ADDR_TYPE, {"project": project, "addr": addr} + ) - return _get_data(resp, 'dwarf-map', '') + return _get_data(resp, "dwarf-map", "") def get_next_generated_benchmarks_dir() -> str: - """Retuns the next folder to be used for generated benchmarks.""" - max_idx = -1 - # When generating benchmarks dynamically sometimes we may not have a - # benchmark folder, as the command will be run from an arbitrary directory. - # Create the benchmark folder if this is the case. - if not os.path.isdir(BENCHMARK_ROOT): - os.makedirs(BENCHMARK_ROOT) - for benchmark_folder in os.listdir(BENCHMARK_ROOT): - try: - max_idx = max(max_idx, - int(benchmark_folder.replace(GENERATED_BENCHMARK, ''))) - except (ValueError, TypeError) as _: - pass - max_idx += 1 - return os.path.join(BENCHMARK_ROOT, f'{GENERATED_BENCHMARK}{max_idx}') + """Retuns the next folder to be used for generated benchmarks.""" + max_idx = -1 + # When generating benchmarks dynamically sometimes we may not have a + # benchmark folder, as the command will be run from an arbitrary directory. + # Create the benchmark folder if this is the case. 
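[Note: a quick worked example of the naming scheme implemented below, with hypothetical existing folders:]

    # With ./benchmark-sets/generated-benchmark-0 and generated-benchmark-4 already
    # present, max_idx becomes 4 and the call returns the next free slot:
    get_next_generated_benchmarks_dir()  # -> "./benchmark-sets/generated-benchmark-5"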
+ if not os.path.isdir(BENCHMARK_ROOT): + os.makedirs(BENCHMARK_ROOT) + for benchmark_folder in os.listdir(BENCHMARK_ROOT): + try: + max_idx = max( + max_idx, int(benchmark_folder.replace(GENERATED_BENCHMARK, "")) + ) + except (ValueError, TypeError) as _: + pass + max_idx += 1 + return os.path.join(BENCHMARK_ROOT, f"{GENERATED_BENCHMARK}{max_idx}") def query_introspector_target_function(project: str, function: str) -> dict: - resp = _query_introspector(INTROSPECTOR_GET_TARGET_FUNCTION, { - 'project': project, - 'function': function - }) + resp = _query_introspector( + INTROSPECTOR_GET_TARGET_FUNCTION, {"project": project, "function": function} + ) - return _get_data(resp, 'function', {}) + return _get_data(resp, "function", {}) def query_introspector_for_far_reach_low_cov(project): - functions = query_introspector_oracle(project, INTROSPECTOR_ORACLE_FAR_REACH) - return functions + functions = query_introspector_oracle(project, INTROSPECTOR_ORACLE_FAR_REACH) + return functions def demangle(name: str) -> str: - return subprocess.run(['c++filt', name], - check=True, - capture_output=True, - stdin=subprocess.DEVNULL, - text=True).stdout.strip() + return subprocess.run( + ["c++filt", name], + check=True, + capture_output=True, + stdin=subprocess.DEVNULL, + text=True, + ).stdout.strip() def clean_type(name: str) -> str: - """Fix comment function type mistakes from FuzzIntrospector.""" - if name == 'N/A': - # Seems to be a bug in introspector: - # https://github.com/ossf/fuzz-introspector/issues/1188 - return 'bool ' - - name = name.replace('struct.', 'struct ') - name = name.replace('class.', '') - name = name.replace('__1::basic_', '') - name = name.replace('__1::', '') - # Introspector sometimes includes numeric suffixes to struct names. - name = re.sub(r'(\.\d+)+(\s*\*)$', r'\2', name) - name.strip() - return name + """Fix comment function type mistakes from FuzzIntrospector.""" + if name == "N/A": + # Seems to be a bug in introspector: + # https://github.com/ossf/fuzz-introspector/issues/1188 + return "bool " + + name = name.replace("struct.", "struct ") + name = name.replace("class.", "") + name = name.replace("__1::basic_", "") + name = name.replace("__1::", "") + # Introspector sometimes includes numeric suffixes to struct names. + name = re.sub(r"(\.\d+)+(\s*\*)$", r"\2", name) + name.strip() + return name def _get_raw_return_type(function: dict, project: str) -> str: - """Returns the raw function type.""" - return_type = function.get('return-type') or function.get('return_type', '') - if not return_type: - logger.error( - 'Missing return type in project: %s\n' - ' raw_function_name: %s', project, - get_raw_function_name(function, project)) - return return_type + """Returns the raw function type.""" + return_type = function.get("return-type") or function.get("return_type", "") + if not return_type: + logger.error( + "Missing return type in project: %s\n" " raw_function_name: %s", + project, + get_raw_function_name(function, project), + ) + return return_type def _get_clean_return_type(function: dict, project: str) -> str: - """Returns the cleaned function type.""" - raw_return_type = _get_raw_return_type(function, project).strip() - if raw_return_type == 'N/A': - # Bug in introspector: Unable to distinguish between bool and void right - # now. More likely to be void for function return arguments. 
- return 'void' - return clean_type(raw_return_type) + """Returns the cleaned function type.""" + raw_return_type = _get_raw_return_type(function, project).strip() + if raw_return_type == "N/A": + # Bug in introspector: Unable to distinguish between bool and void right + # now. More likely to be void for function return arguments. + return "void" + return clean_type(raw_return_type) def get_raw_function_name(function: dict, project: str) -> str: - """Returns the raw function name.""" - raw_name = (function.get('raw-function-name') or - function.get('raw_function_name', '')) - if not raw_name: - logger.error('No raw function name in project: %s for function: %s', - project, function) - return raw_name + """Returns the raw function name.""" + raw_name = function.get("raw-function-name") or function.get( + "raw_function_name", "" + ) + if not raw_name: + logger.error( + "No raw function name in project: %s for function: %s", project, function + ) + return raw_name def _get_clean_arg_types(function: dict, project: str) -> list[str]: - """Returns the cleaned function argument types.""" - raw_arg_types = (function.get('arg-types') or - function.get('function_arguments', [])) - if not raw_arg_types: - logger.error( - 'Missing argument types in project: %s\n' - ' raw_function_name: %s', project, - get_raw_function_name(function, project)) - return [clean_type(arg_type) for arg_type in raw_arg_types] + """Returns the cleaned function argument types.""" + raw_arg_types = function.get("arg-types") or function.get("function_arguments", []) + if not raw_arg_types: + logger.error( + "Missing argument types in project: %s\n" " raw_function_name: %s", + project, + get_raw_function_name(function, project), + ) + return [clean_type(arg_type) for arg_type in raw_arg_types] def _get_arg_count(function: dict) -> int: - """Count the number of arguments for this function.""" - raw_arg_types = (function.get('arg-types') or - function.get('function_arguments', [])) - return len(raw_arg_types) + """Count the number of arguments for this function.""" + raw_arg_types = function.get("arg-types") or function.get("function_arguments", []) + return len(raw_arg_types) def _get_arg_names(function: dict, project: str, language: str) -> list[str]: - """Returns the function argument names.""" - if language == 'jvm': - # The fuzz-introspector front end of JVM projects cannot get the original - # argument name. Thus the argument name here uses arg{Count} as arugment - # name reference. - jvm_args = _get_clean_arg_types(function, project) - arg_names = [f'arg{i}' for i in range(len(jvm_args))] - else: - arg_names = (function.get('arg-names') or - function.get('function_argument_names', [])) - if not arg_names: - logger.error( - 'Missing argument names in project: %s\n' - ' raw_function_name: %s', project, - get_raw_function_name(function, project)) - return arg_names + """Returns the function argument names.""" + if language == "jvm": + # The fuzz-introspector front end of JVM projects cannot get the original + # argument name. Thus the argument name here uses arg{Count} as arugment + # name reference. 
+ jvm_args = _get_clean_arg_types(function, project) + arg_names = [f"arg{i}" for i in range(len(jvm_args))] + else: + arg_names = function.get("arg-names") or function.get( + "function_argument_names", [] + ) + if not arg_names: + logger.error( + "Missing argument names in project: %s\n" " raw_function_name: %s", + project, + get_raw_function_name(function, project), + ) + return arg_names def get_function_signature(function: dict, project: str) -> str: - """Returns the function signature.""" - function_signature = function.get('function_signature', '') - if function_signature == "N/A": - # For JVM projects, the full function signature are the raw function name - return get_raw_function_name(function, project) - if not function_signature: - logger.error( - 'Missing function signature in project: %s\n' - ' raw_function_name: %s', project, - get_raw_function_name(function, project)) - return function_signature + """Returns the function signature.""" + function_signature = function.get("function_signature", "") + if function_signature == "N/A": + # For JVM projects, the full function signature are the raw function name + return get_raw_function_name(function, project) + if not function_signature: + logger.error( + "Missing function signature in project: %s\n" " raw_function_name: %s", + project, + get_raw_function_name(function, project), + ) + return function_signature # TODO(dongge): Remove this function when FI fixes it. def _parse_type_from_raw_tagged_type(tagged_type: str, language: str) -> str: - """Returns type name from |tagged_type| such as struct.TypeA""" - # Assume: Types do not contain dot(.). - # (ascchan): This assumption is wrong on Java projects because - # most full qulified classes name of Java projects have dot(.) to - # identify the package name of the classes. Thus for Java projects, - # this action needed to be skipped until this function is removed. - if language == 'jvm': - return tagged_type - return tagged_type.split('.')[-1] - - -def _group_function_params(param_types: list[str], param_names: list[str], - language: str) -> list[dict[str, str]]: - """Groups the type and name of each parameter.""" - return [{ - 'type': _parse_type_from_raw_tagged_type(param_type, language), - 'name': param_name - } for param_type, param_name in zip(param_types, param_names)] - - -def _select_top_functions_from_oracle(project: str, limit: int, - target_oracle: str, - target_oracles: list[str]) -> OrderedDict: - """Selects the top |limit| functions from |target_oracle|.""" - if target_oracle not in target_oracles or target_oracle == 'test-migration': - return OrderedDict() - - logger.info('Extracting functions using oracle %s.', target_oracle) - functions = query_introspector_for_targets(project, target_oracle)[:limit] - - return OrderedDict((func['function_signature'], func) for func in functions) - - -def _combine_functions(a: list[str], b: list[str], c: list[str], - limit: int) -> list[str]: - """Combines functions from three oracles. Prioritize on a, but include one of - b and c if any.""" - head = a[:limit - 2] - b_in_head = any(i in b for i in head) - c_in_head = any(i in c for i in head) - # Result contains items from b and c and is long enough. 
- if b_in_head and c_in_head and len(a) >= limit: - return a - - all_functions = a + b + c - - if b_in_head or not b: - add_from_b = [] - else: - add_from_b = [i for i in a[3:] if i in b] - add_from_b = [add_from_b[0]] if add_from_b else [b[0]] - - if c_in_head or not c: - add_from_c = [] - else: - add_from_c = [i for i in a[3:] if i in c] - add_from_c = [add_from_c[0]] if add_from_c else [c[0]] - - combined = set(head + add_from_b + add_from_c) - # Result contains items from b and c, append more until long enough. - for func in all_functions: - if len(combined) >= limit: - continue - combined.add(func) - return list(combined) - - -def _select_functions_from_jvm_oracles(project: str, limit: int, - target_oracles: list[str]) -> list[dict]: - """Selects functions from oracles designated for jvm projects, with - jvm-public-candidates as the prioritised oracle""" - all_functions = OrderedDict() - - if 'jvm-public-candidates' in target_oracles: - # JPC is the primary oracle for JVM projects. If it does exist, all other - # oracles are ignored because the results from all other oracles are subsets - # of the results from JPC oracle for JVM projects. - target_oracles = ['jvm-public-candidates'] - - for target_oracle in target_oracles: - tmp_functions = _select_top_functions_from_oracle(project, limit, - target_oracle, - target_oracles) - all_functions.update(tmp_functions) - - return list(all_functions.values())[:limit] - - -def _select_functions_from_oracles(project: str, limit: int, - target_oracles: list[str]) -> list[dict]: - """Selects function-under-test from oracles.""" - all_functions = OrderedDict() - frlc_targets = _select_top_functions_from_oracle(project, limit, - 'far-reach-low-coverage', - target_oracles) - # FRLC is the primary oracle. If it does not exist, follow oracle order and - # deduplicate. - if not frlc_targets: - for target_oracle in target_oracles: - tmp_functions = _select_top_functions_from_oracle(project, limit, - target_oracle, - target_oracles) - all_functions.update(tmp_functions) + """Returns type name from |tagged_type| such as struct.TypeA""" + # Assume: Types do not contain dot(.). + # (ascchan): This assumption is wrong on Java projects because + # most full qulified classes name of Java projects have dot(.) to + # identify the package name of the classes. Thus for Java projects, + # this action needed to be skipped until this function is removed. + if language == "jvm": + return tagged_type + return tagged_type.split(".")[-1] + + +def _group_function_params( + param_types: list[str], param_names: list[str], language: str +) -> list[dict[str, str]]: + """Groups the type and name of each parameter.""" + return [ + { + "type": _parse_type_from_raw_tagged_type(param_type, language), + "name": param_name, + } + for param_type, param_name in zip(param_types, param_names) + ] - return list(all_functions.values())[:limit] - # Selection rule: Prioritize on far-reach-low-coverage, but include one of - # optimal-targets, easy-params-far-reach if any. 
- all_functions.update(frlc_targets) +def _select_top_functions_from_oracle( + project: str, limit: int, target_oracle: str, target_oracles: list[str] +) -> OrderedDict: + """Selects the top |limit| functions from |target_oracle|.""" + if target_oracle not in target_oracles or target_oracle == "test-migration": + return OrderedDict() + + logger.info("Extracting functions using oracle %s.", target_oracle) + functions = query_introspector_for_targets(project, target_oracle)[:limit] + + return OrderedDict((func["function_signature"], func) for func in functions) + + +def _combine_functions( + a: list[str], b: list[str], c: list[str], limit: int +) -> list[str]: + """Combines functions from three oracles. Prioritize on a, but include one of + b and c if any.""" + head = a[: limit - 2] + b_in_head = any(i in b for i in head) + c_in_head = any(i in c for i in head) + # Result contains items from b and c and is long enough. + if b_in_head and c_in_head and len(a) >= limit: + return a + + all_functions = a + b + c + + if b_in_head or not b: + add_from_b = [] + else: + add_from_b = [i for i in a[3:] if i in b] + add_from_b = [add_from_b[0]] if add_from_b else [b[0]] + + if c_in_head or not c: + add_from_c = [] + else: + add_from_c = [i for i in a[3:] if i in c] + add_from_c = [add_from_c[0]] if add_from_c else [c[0]] + + combined = set(head + add_from_b + add_from_c) + # Result contains items from b and c, append more until long enough. + for func in all_functions: + if len(combined) >= limit: + continue + combined.add(func) + return list(combined) + + +def _select_functions_from_jvm_oracles( + project: str, limit: int, target_oracles: list[str] +) -> list[dict]: + """Selects functions from oracles designated for jvm projects, with + jvm-public-candidates as the prioritised oracle""" + all_functions = OrderedDict() + + if "jvm-public-candidates" in target_oracles: + # JPC is the primary oracle for JVM projects. If it does exist, all other + # oracles are ignored because the results from all other oracles are subsets + # of the results from JPC oracle for JVM projects. + target_oracles = ["jvm-public-candidates"] - epfr_targets = _select_top_functions_from_oracle(project, limit, - 'easy-params-far-reach', - target_oracles) - all_functions.update(epfr_targets) + for target_oracle in target_oracles: + tmp_functions = _select_top_functions_from_oracle( + project, limit, target_oracle, target_oracles + ) + all_functions.update(tmp_functions) - ot_targets = _select_top_functions_from_oracle(project, limit, - 'optimal-targets', - target_oracles) - all_functions.update(ot_targets) + return list(all_functions.values())[:limit] - selected_singatures = _combine_functions(list(frlc_targets.keys()), - list(epfr_targets.keys()), - list(ot_targets.keys()), limit) - return [all_functions[func] for func in selected_singatures] +def _select_functions_from_oracles( + project: str, limit: int, target_oracles: list[str] +) -> list[dict]: + """Selects function-under-test from oracles.""" + all_functions = OrderedDict() + frlc_targets = _select_top_functions_from_oracle( + project, limit, "far-reach-low-coverage", target_oracles + ) + # FRLC is the primary oracle. If it does not exist, follow oracle order and + # deduplicate. 
+ if not frlc_targets: + for target_oracle in target_oracles: + tmp_functions = _select_top_functions_from_oracle( + project, limit, target_oracle, target_oracles + ) + all_functions.update(tmp_functions) + + return list(all_functions.values())[:limit] + + # Selection rule: Prioritize on far-reach-low-coverage, but include one of + # optimal-targets, easy-params-far-reach if any. + all_functions.update(frlc_targets) + + epfr_targets = _select_top_functions_from_oracle( + project, limit, "easy-params-far-reach", target_oracles + ) + all_functions.update(epfr_targets) + + ot_targets = _select_top_functions_from_oracle( + project, limit, "optimal-targets", target_oracles + ) + all_functions.update(ot_targets) + + selected_singatures = _combine_functions( + list(frlc_targets.keys()), + list(epfr_targets.keys()), + list(ot_targets.keys()), + limit, + ) + + return [all_functions[func] for func in selected_singatures] def _get_harness_intrinsics( - project, - filenames, - language='') -> tuple[Optional[str], Optional[str], Dict[str, str]]: - """Returns a harness source path and executable from a given project.""" - if USE_FI_TO_GET_TARGETS and language != 'jvm' and language != 'python': - harnesses = query_introspector_for_harness_intrinsics(project) - if not harnesses: - logger.error('No harness/source pairs found in project.') - return None, None, {} - - harness_dict = harnesses[0] - harness = harness_dict['source'] - target_name = harness_dict['executable'] - interesting_files = {} - else: - harnesses, interesting_files = project_src.search_source( - project, filenames, language) - harness = pick_one(harnesses) - if not harness: - logger.error('No fuzz target found in project %s.', project) - return None, None, {} - target_name = get_target_name(project, harness) - - logger.info('Fuzz target file found for project %s: %s', project, harness) - logger.info('Fuzz target binary found for project %s: %s', project, - target_name) - - return harness, target_name, interesting_files + project, filenames, language="" +) -> tuple[Optional[str], Optional[str], Dict[str, str]]: + """Returns a harness source path and executable from a given project.""" + if USE_FI_TO_GET_TARGETS and language != "jvm" and language != "python": + harnesses = query_introspector_for_harness_intrinsics(project) + if not harnesses: + logger.error("No harness/source pairs found in project.") + return None, None, {} + + harness_dict = harnesses[0] + harness = harness_dict["source"] + target_name = harness_dict["executable"] + interesting_files = {} + else: + harnesses, interesting_files = project_src.search_source( + project, filenames, language + ) + harness = pick_one(harnesses) + if not harness: + logger.error("No fuzz target found in project %s.", project) + return None, None, {} + target_name = get_target_name(project, harness) + + logger.info("Fuzz target file found for project %s: %s", project, harness) + logger.info("Fuzz target binary found for project %s: %s", project, target_name) + + return harness, target_name, interesting_files def populate_benchmarks_using_test_migration( - project: str, language: str, limit: int) -> list[benchmarklib.Benchmark]: - """Populates benchmarks using tests for test-to-harness conversion.""" - - harness, target_name, _ = _get_harness_intrinsics(project, [], language) - if not harness: - return [] - - logger.info('Using harness path %s', harness) - potential_benchmarks = [] - test_files = query_introspector_for_tests(project) - for test_file in test_files: - potential_benchmarks.append( - 
benchmarklib.Benchmark(benchmark_id='cli', - project=project, - language=language, - function_signature='test-file', - function_name='test-file', - return_type='test', - params=[], - target_path=harness, - preferred_target_name=target_name, - test_file_path=test_file)) - return potential_benchmarks[:limit] + project: str, language: str, limit: int +) -> list[benchmarklib.Benchmark]: + """Populates benchmarks using tests for test-to-harness conversion.""" + harness, target_name, _ = _get_harness_intrinsics(project, [], language) + if not harness: + return [] + + logger.info("Using harness path %s", harness) + potential_benchmarks = [] + test_files = query_introspector_for_tests(project) + for test_file in test_files: + potential_benchmarks.append( + benchmarklib.Benchmark( + benchmark_id="cli", + project=project, + language=language, + function_signature="test-file", + function_name="test-file", + return_type="test", + params=[], + target_path=harness, + preferred_target_name=target_name, + test_file_path=test_file, + ) + ) + return potential_benchmarks[:limit] -def generate_benchmark_for_targeted_function(project: str, function_name: str): - """generates a benchmark for a single function.""" - function_dict = query_introspector_target_function(project, function_name) - project_lang = oss_fuzz_checkout.get_project_language(project) - - harness, target_name, _ = _get_harness_intrinsics(project, [], project_lang) - if not harness: - return '' - target_benchmarks = [ - benchmarklib.Benchmark( - benchmark_id='cli', - project=project, - language=project_lang, - function_signature=function_dict.get('function_signature', ''), - function_name=get_raw_function_name(function_dict, project), - return_type=_get_clean_return_type(function_dict, project), - params=_group_function_params( - _get_clean_arg_types(function_dict, project), - _get_arg_names(function_dict, project, project_lang), - project_lang), - target_path=harness, - preferred_target_name=target_name, - function_dict=function_dict) - ] - - benchmark_dir = get_next_generated_benchmarks_dir() - os.makedirs(benchmark_dir) - benchmarklib.Benchmark.to_yaml(target_benchmarks, outdir=benchmark_dir) - return benchmark_dir - - -def populate_benchmarks_using_introspector(project: str, language: str, - limit: int, - target_oracles: List[str]): - """Populates benchmark YAML files from the data from FuzzIntrospector.""" - - potential_benchmarks = [] - for target_oracle in target_oracles: - if 'test-migration' in target_oracle: - potential_benchmarks.extend( - populate_benchmarks_using_test_migration(project, language, limit)) - - if language == 'jvm': - functions = _select_functions_from_jvm_oracles(project, limit, - target_oracles) - else: - functions = _select_functions_from_oracles(project, limit, target_oracles) - - if not functions: - return potential_benchmarks - if language == 'jvm': - filenames = [ - f'{function["function_filename"].split("$")[0].replace(".", "/")}.java' - for function in functions - ] - elif language == 'python': - filenames = [ - (f'{function["function_filename"].replace("...", "").replace(".", "/")}' - '.py') for function in functions - ] - else: - filenames = [ - os.path.basename(function['function_filename']) - for function in functions - ] +def generate_benchmark_for_targeted_function(project: str, function_name: str): + """generates a benchmark for a single function.""" + function_dict = query_introspector_target_function(project, function_name) + project_lang = oss_fuzz_checkout.get_project_language(project) - harness, 
target_name, interesting = _get_harness_intrinsics( - project, filenames, language) - if not harness: - return [] - - for function in functions: - if _get_arg_count(function) == 0: - # Skipping functions / methods that does not take in any arguments. - # Those functions / methods are not fuzz-worthy. - continue - - filename = os.path.basename(function['function_filename']) - - if language == 'python': - if filename.startswith('...'): - # Filename of python fuzzers always starts with ... - # Skipping them - continue - if _get_arg_count(function) == 1 and _get_arg_names( - function, project, language)[0] == 'self': - # If a python function has only 1 arugment and the argument name - # is 'self', it means that it is an instance function with no - # arguments. Thus skipping it. - continue - - elif language == 'jvm': - # Retrieve list of source file from introspector - src_path_list = query_introspector_jvm_source_path(project) - if src_path_list: - # For all JVM projects, the full class name is stored in the filename - # field. The full class name includes the package of the class and that - # forms part of the directory pattern of the source file that is needed - # for checking. For example, the source file of class a.b.c.d is always - # stored as /a/b/c/d.java - if filename.endswith('.java'): - src_file = filename - else: - src_file = f'{filename.replace(".", "/")}.java' - - if not any(src_path.endswith(src_file) for src_path in src_path_list): - logger.error('error: %s %s', filename, interesting.keys()) - continue - - elif (language not in ['rust'] and interesting and - filename not in [os.path.basename(i) for i in interesting.keys()]): - # TODO: Bazel messes up paths to include "/proc/self/cwd/..." - logger.error('error: %s %s', filename, interesting.keys()) - continue - - function_signature = get_function_signature(function, project) - if not function_signature: - continue - logger.info('Function signature to fuzz: %s', function_signature) - potential_benchmarks.append( + harness, target_name, _ = _get_harness_intrinsics(project, [], project_lang) + if not harness: + return "" + target_benchmarks = [ benchmarklib.Benchmark( - benchmark_id='cli', + benchmark_id="cli", project=project, - language=language, - function_signature=function_signature, - function_name=get_raw_function_name(function, project), - return_type=_get_clean_return_type(function, project), + language=project_lang, + function_signature=function_dict.get("function_signature", ""), + function_name=get_raw_function_name(function_dict, project), + return_type=_get_clean_return_type(function_dict, project), params=_group_function_params( - _get_clean_arg_types(function, project), - _get_arg_names(function, project, language), language), + _get_clean_arg_types(function_dict, project), + _get_arg_names(function_dict, project, project_lang), + project_lang, + ), target_path=harness, preferred_target_name=target_name, - function_dict=function)) + function_dict=function_dict, + ) + ] + + benchmark_dir = get_next_generated_benchmarks_dir() + os.makedirs(benchmark_dir) + benchmarklib.Benchmark.to_yaml(target_benchmarks, outdir=benchmark_dir) + return benchmark_dir - if len(potential_benchmarks) >= (limit * len(target_oracles)): - break - logger.info('Length of potential targets: %d', len(potential_benchmarks)) - return potential_benchmarks +def populate_benchmarks_using_introspector( + project: str, language: str, limit: int, target_oracles: List[str] +): + """Populates benchmark YAML files from the data from FuzzIntrospector.""" + + 
potential_benchmarks = [] + for target_oracle in target_oracles: + if "test-migration" in target_oracle: + potential_benchmarks.extend( + populate_benchmarks_using_test_migration(project, language, limit) + ) + + if language == "jvm": + functions = _select_functions_from_jvm_oracles(project, limit, target_oracles) + else: + functions = _select_functions_from_oracles(project, limit, target_oracles) + + if not functions: + return potential_benchmarks + + if language == "jvm": + filenames = [ + f'{function["function_filename"].split("$")[0].replace(".", "/")}.java' + for function in functions + ] + elif language == "python": + filenames = [ + ( + f'{function["function_filename"].replace("...", "").replace(".", "/")}' + ".py" + ) + for function in functions + ] + else: + filenames = [ + os.path.basename(function["function_filename"]) for function in functions + ] + + harness, target_name, interesting = _get_harness_intrinsics( + project, filenames, language + ) + if not harness: + return [] + + for function in functions: + if _get_arg_count(function) == 0: + # Skipping functions / methods that does not take in any arguments. + # Those functions / methods are not fuzz-worthy. + continue + + filename = os.path.basename(function["function_filename"]) + + if language == "python": + if filename.startswith("..."): + # Filename of python fuzzers always starts with ... + # Skipping them + continue + if ( + _get_arg_count(function) == 1 + and _get_arg_names(function, project, language)[0] == "self" + ): + # If a python function has only 1 arugment and the argument name + # is 'self', it means that it is an instance function with no + # arguments. Thus skipping it. + continue + + elif language == "jvm": + # Retrieve list of source file from introspector + src_path_list = query_introspector_jvm_source_path(project) + if src_path_list: + # For all JVM projects, the full class name is stored in the filename + # field. The full class name includes the package of the class and that + # forms part of the directory pattern of the source file that is needed + # for checking. For example, the source file of class a.b.c.d is always + # stored as /a/b/c/d.java + if filename.endswith(".java"): + src_file = filename + else: + src_file = f'{filename.replace(".", "/")}.java' + + if not any(src_path.endswith(src_file) for src_path in src_path_list): + logger.error("error: %s %s", filename, interesting.keys()) + continue + + elif ( + language not in ["rust"] + and interesting + and filename not in [os.path.basename(i) for i in interesting.keys()] + ): + # TODO: Bazel messes up paths to include "/proc/self/cwd/..." 
+ logger.error("error: %s %s", filename, interesting.keys()) + continue + + function_signature = get_function_signature(function, project) + if not function_signature: + continue + logger.info("Function signature to fuzz: %s", function_signature) + potential_benchmarks.append( + benchmarklib.Benchmark( + benchmark_id="cli", + project=project, + language=language, + function_signature=function_signature, + function_name=get_raw_function_name(function, project), + return_type=_get_clean_return_type(function, project), + params=_group_function_params( + _get_clean_arg_types(function, project), + _get_arg_names(function, project, language), + language, + ), + target_path=harness, + preferred_target_name=target_name, + function_dict=function, + ) + ) + + if len(potential_benchmarks) >= (limit * len(target_oracles)): + break + logger.info("Length of potential targets: %d", len(potential_benchmarks)) + + return potential_benchmarks def pick_one(d: dict): - if not d: - return None - return list(d.keys())[0] + if not d: + return None + return list(d.keys())[0] def get_target_name(project_name: str, harness: str) -> Optional[str]: - """Gets the matching target name.""" - summary = query_introspector_cfg(project_name) - for annotated in summary.get('annotated_cfg', []): - if annotated['source_file'] == harness: - return annotated['fuzzer_name'] + """Gets the matching target name.""" + summary = query_introspector_cfg(project_name) + for annotated in summary.get("annotated_cfg", []): + if annotated["source_file"] == harness: + return annotated["fuzzer_name"] - return None + return None ##### Helper logic for downloading fuzz introspector reports. # Download introspector report. def _identify_latest_report(project_name: str): - """Returns the latest summary in the FuzzIntrospector bucket.""" - client = storage.Client.create_anonymous_client() - bucket = client.get_bucket('oss-fuzz-introspector') - blobs = bucket.list_blobs(prefix=project_name) - summaries = sorted( - [blob.name for blob in blobs if blob.name.endswith('summary.json')]) - if summaries: - return ('https://storage.googleapis.com/oss-fuzz-introspector/' - f'{summaries[-1]}') - logger.error('Error: %s has no summary.', project_name) - return None + """Returns the latest summary in the FuzzIntrospector bucket.""" + client = storage.Client.create_anonymous_client() + bucket = client.get_bucket("oss-fuzz-introspector") + blobs = bucket.list_blobs(prefix=project_name) + summaries = sorted( + [blob.name for blob in blobs if blob.name.endswith("summary.json")] + ) + if summaries: + return ( + "https://storage.googleapis.com/oss-fuzz-introspector/" f"{summaries[-1]}" + ) + logger.error("Error: %s has no summary.", project_name) + return None def _extract_introspector_report(project_name): - """Queries and extracts FuzzIntrospector report data of |project_name|.""" - project_url = _identify_latest_report(project_name) - if not project_url: - return None - # Read the introspector artifact. - try: - raw_introspector_json_request = requests.get(project_url, timeout=10) - introspector_report = json.loads(raw_introspector_json_request.text) - except: - return None - return introspector_report + """Queries and extracts FuzzIntrospector report data of |project_name|.""" + project_url = _identify_latest_report(project_name) + if not project_url: + return None + # Read the introspector artifact. 
+ try: + raw_introspector_json_request = requests.get(project_url, timeout=10) + introspector_report = json.loads(raw_introspector_json_request.text) + except: + return None + return introspector_report def _contains_function(funcs: List[Dict], target_func: Dict): - """Returns True if |funcs| contains |target_func|, vice versa.""" - key_fields = ['function-name', 'source-file', 'return-type', 'arg-list'] - for func in funcs: - if all(func.get(field) == target_func.get(field) for field in key_fields): - return True - return False + """Returns True if |funcs| contains |target_func|, vice versa.""" + key_fields = ["function-name", "source-file", "return-type", "arg-list"] + for func in funcs: + if all(func.get(field) == target_func.get(field) for field in key_fields): + return True + return False def _postprocess_function(target_func: dict, project_name: str): - """Post-processes target function.""" - target_func['return-type'] = _get_clean_return_type(target_func, project_name) - target_func['function-name'] = demangle(target_func['function-name']) + """Post-processes target function.""" + target_func["return-type"] = _get_clean_return_type(target_func, project_name) + target_func["function-name"] = demangle(target_func["function-name"]) def get_project_funcs(project_name: str) -> Dict[str, List[Dict]]: - """Fetches the latest fuzz targets and function signatures of |project_name| + """Fetches the latest fuzz targets and function signatures of |project_name| from FuzzIntrospector.""" - introspector_json_report = _extract_introspector_report(project_name) - if introspector_json_report is None: - logger.error('No fuzz introspector report is found.') - return {} - - if introspector_json_report.get('analyses') is None: - logger.error('Error: introspector_json_report has no "analyses"') - return {} - if introspector_json_report.get('analyses').get('AnnotatedCFG') is None: - logger.error( - 'Error: introspector_json_report["analyses"] has no "AnnotatedCFG"') - return {} - - # Group functions by target files. - annotated_cfg = introspector_json_report.get('analyses').get('AnnotatedCFG') - fuzz_target_funcs = {} - for fuzzer in annotated_cfg: - for target_func in annotated_cfg[fuzzer]['destinations']: - # Remove functions where there are no source file, e.g. libc functions - if target_func['source-file'] == '': - continue - - # Group functions by fuzz target source code file, because there may - # be multiple functions in the same fuzz target file. - fuzz_target_file = annotated_cfg[fuzzer]['src_file'] - if fuzz_target_file not in fuzz_target_funcs: - fuzz_target_funcs[fuzz_target_file] = [] - if _contains_function(fuzz_target_funcs[fuzz_target_file], target_func): - continue - _postprocess_function(target_func, project_name) - fuzz_target_funcs[fuzz_target_file].append(target_func) - - # Sort functions in each target file by their complexity. - # Assume the most complex functions are the ones under test, - # put them at the beginning. 
- for file, funcs in fuzz_target_funcs.items(): - fuzz_target_funcs[file] = sorted( - funcs, key=lambda f: f.get('cyclomatic-complexity'), reverse=True) - return fuzz_target_funcs + introspector_json_report = _extract_introspector_report(project_name) + if introspector_json_report is None: + logger.error("No fuzz introspector report is found.") + return {} + + if introspector_json_report.get("analyses") is None: + logger.error('Error: introspector_json_report has no "analyses"') + return {} + if introspector_json_report.get("analyses").get("AnnotatedCFG") is None: + logger.error( + 'Error: introspector_json_report["analyses"] has no "AnnotatedCFG"' + ) + return {} + + # Group functions by target files. + annotated_cfg = introspector_json_report.get("analyses").get("AnnotatedCFG") + fuzz_target_funcs = {} + for fuzzer in annotated_cfg: + for target_func in annotated_cfg[fuzzer]["destinations"]: + # Remove functions where there are no source file, e.g. libc functions + if target_func["source-file"] == "": + continue + + # Group functions by fuzz target source code file, because there may + # be multiple functions in the same fuzz target file. + fuzz_target_file = annotated_cfg[fuzzer]["src_file"] + if fuzz_target_file not in fuzz_target_funcs: + fuzz_target_funcs[fuzz_target_file] = [] + if _contains_function(fuzz_target_funcs[fuzz_target_file], target_func): + continue + _postprocess_function(target_func, project_name) + fuzz_target_funcs[fuzz_target_file].append(target_func) + + # Sort functions in each target file by their complexity. + # Assume the most complex functions are the ones under test, + # put them at the beginning. + for file, funcs in fuzz_target_funcs.items(): + fuzz_target_funcs[file] = sorted( + funcs, key=lambda f: f.get("cyclomatic-complexity"), reverse=True + ) + return fuzz_target_funcs def _parse_arguments() -> argparse.Namespace: - """Parses command line args.""" - parser = argparse.ArgumentParser( - description='Parse arguments to generate benchmarks.') - - parser.add_argument('project', help='Name of the project.', type=str) - parser.add_argument('-m', - '--max-functions', - type=int, - default=3, - help='Number of benchmarks to generate.') - parser.add_argument('-o', - '--out', - type=str, - default='', - help='Output directory.') - parser.add_argument('-e', - '--endpoint', - type=str, - default=DEFAULT_INTROSPECTOR_ENDPOINT, - help='Fuzz Introspecor API endpoint.') - parser.add_argument('-t', - '--target-oracle', - type=str, - nargs='+', - default=['optimal-targets', 'far-reach-low-coverage'], - help='Oracles used to determine interesting targets.') - - return parser.parse_args() + """Parses command line args.""" + parser = argparse.ArgumentParser( + description="Parse arguments to generate benchmarks." + ) + + parser.add_argument("project", help="Name of the project.", type=str) + parser.add_argument( + "-m", + "--max-functions", + type=int, + default=3, + help="Number of benchmarks to generate.", + ) + parser.add_argument("-o", "--out", type=str, default="", help="Output directory.") + parser.add_argument( + "-e", + "--endpoint", + type=str, + default=DEFAULT_INTROSPECTOR_ENDPOINT, + help="Fuzz Introspecor API endpoint.", + ) + parser.add_argument( + "-t", + "--target-oracle", + type=str, + nargs="+", + default=["optimal-targets", "far-reach-low-coverage"], + help="Oracles used to determine interesting targets.", + ) + + return parser.parse_args() # Set default endpoint. 
set_introspector_endpoints(DEFAULT_INTROSPECTOR_ENDPOINT) -if __name__ == '__main__': - logging.basicConfig(level=logging.INFO) - - args = _parse_arguments() - if args.out: - os.makedirs(args.out, exist_ok=True) - - set_introspector_endpoints(args.endpoint) - - try: - oss_fuzz_checkout.clone_oss_fuzz() - oss_fuzz_checkout.postprocess_oss_fuzz() - except subprocess.CalledProcessError as e: - logger.error('Failed to prepare OSS-Fuzz directory for project %s: %s', - args.project, e) - cur_project_language = oss_fuzz_checkout.get_project_language(args.project) - benchmarks = populate_benchmarks_using_introspector(args.project, - cur_project_language, - args.max_functions, - args.target_oracle) - if benchmarks: - benchmarklib.Benchmark.to_yaml(benchmarks, outdir=args.out) - else: - logger.error('Nothing found for %s', args.project) - sys.exit(1) +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + + args = _parse_arguments() + if args.out: + os.makedirs(args.out, exist_ok=True) + + set_introspector_endpoints(args.endpoint) + + try: + oss_fuzz_checkout.clone_oss_fuzz() + oss_fuzz_checkout.postprocess_oss_fuzz() + except subprocess.CalledProcessError as e: + logger.error( + "Failed to prepare OSS-Fuzz directory for project %s: %s", args.project, e + ) + cur_project_language = oss_fuzz_checkout.get_project_language(args.project) + benchmarks = populate_benchmarks_using_introspector( + args.project, cur_project_language, args.max_functions, args.target_oracle + ) + if benchmarks: + benchmarklib.Benchmark.to_yaml(benchmarks, outdir=args.out) + else: + logger.error("Nothing found for %s", args.project) + sys.exit(1) diff --git a/data_prep/parse_training_data.py b/data_prep/parse_training_data.py index 8d1bcdab57..a7290e0b0e 100755 --- a/data_prep/parse_training_data.py +++ b/data_prep/parse_training_data.py @@ -37,337 +37,368 @@ logger = logging.getLogger(__name__) STORAGE_CLIENT = storage.Client() -FUZZ_TARGET_FIXING_DIR_PATTERN = r'\d+-F\d+' +FUZZ_TARGET_FIXING_DIR_PATTERN = r"\d+-F\d+" class Benchmark: - """The result directory of a benchmark.""" - - def __init__(self, benchmark_dir: str) -> None: - self.benchmark_dir = os.path.abspath(benchmark_dir) - self.benchmark = os.path.basename(benchmark_dir).replace('output-', '', 1) - - @property - def prompt(self) -> str: - """Returns the prompt used by the benchmark.""" - prompt_path = os.path.join(self.benchmark_dir, 'prompt.txt') - if not os.path.isfile(prompt_path): - logger.warning('Prompt does not exist: %s', prompt_path) - return '' - with open(prompt_path) as prompt_file: - return prompt_file.read() - - def _get_code_fixing_dirs(self, fixed_target_dir): - """Gets the directories for fixing fuzz targets.""" - return [ - item for item in os.listdir(fixed_target_dir) - if (os.path.isdir(os.path.join(fixed_target_dir, item)) and - re.match(FUZZ_TARGET_FIXING_DIR_PATTERN, item)) - ] - - @property - def targets(self) -> Dict[str, List[str]]: - """Returns the generated targets of a benchmark in a directory, mapping - the instance ID to a list of targets generated and fixed by LLM.""" - all_targets = {} - raw_target_dir = os.path.join(self.benchmark_dir, 'raw_targets') - if not os.path.isdir(raw_target_dir): - logger.warning('Raw target dir does not exist: %s', raw_target_dir) - return {} - raw_targets = [ - instance for instance in os.listdir(raw_target_dir) - if not instance.endswith('rawoutput') - ] - for instance in raw_targets: - raw_target_path = os.path.join(raw_target_dir, instance) - with open(raw_target_path) as target_file: 
- all_targets[os.path.splitext(instance)[0]] = [target_file.read()] - - fixed_target_dir = os.path.join(self.benchmark_dir, 'fixed_targets') - if not os.path.isdir(fixed_target_dir): - logger.warning('Fixed target dir does not exist: %s', fixed_target_dir) - return {} - fix_dirs = self._get_code_fixing_dirs(fixed_target_dir) - for fix_dir in sorted(fix_dirs): - instance, _ = fix_dir.split('-F') - code_path = [ - os.path.join(fixed_target_dir, fix_dir, f) - for f in os.listdir(os.path.join(fixed_target_dir, fix_dir)) - if not (f == 'prompt.txt' and f.endswith('rawoutput')) - ][0] - with open(code_path) as code_file: - fixed_code = code_file.read() - if not all_targets.get(instance): - logger.warning('Benchmark instance does not exist: %s - %s', - self.benchmark_dir, instance) - continue - all_targets[instance].append(fixed_code) - return all_targets - - @property - def status(self) -> Dict[str, Dict[str, Any]]: - """Returns the status of all instances of the benchmark, mapping the - instance ID to its status JSON.""" - all_status = {} - status_dir = os.path.join(self.benchmark_dir, 'status') - if not os.path.isdir(status_dir): - logger.warning('Status dir does not exist: %s', status_dir) - return {} - for instance in os.listdir(status_dir): - status_json_path = os.path.join(status_dir, instance, 'result.json') - if not os.path.isfile(status_json_path): - logger.info('Missing result JSON of benchmark instance: %s - %s', - self.benchmark, instance) - continue - with open(status_json_path) as file: - try: - all_status[instance] = json.load(file) - except Exception as e: - logger.warning(e) - logger.warning(status_json_path) - - return all_status - - @property - def is_valid_benchmark(self) -> bool: - """Checks if this has a valid benchmark directory.""" - path = self.benchmark_dir - expected_components = [ - 'raw_targets', 'status', 'fixed_targets', 'prompt.txt' - ] - return all( - os.path.exists(os.path.join(path, component)) - for component in expected_components) - - @staticmethod - def final_score(stat: Dict[str, Any], coverage: bool) -> float: - """Evaluates the final score of a benchmark instance.""" - return stat.get('line_coverage_diff', 0.0) if coverage else float( - stat.get('compiles', 0.0)) - - def organize_group_pointwise(self, - coverage: bool = False - ) -> List[Dict[str, str | List[float]]]: - """Organizes grouped pointwise training data for reward model.""" - data = [] - all_targets = self.targets - prompt = self.prompt - for instance, stat in self.status.items(): - targets = all_targets.get(instance, []) - if not targets: - continue - scores = [0.0] * (len(targets) - 1) + [self.final_score(stat, coverage)] - datum = { - 'prompt': prompt, - 'target': targets, - 'score': [scores], - } - data.append(datum) - return data - - def organize_ungroup_pointwise(self, - coverage: bool = False - ) -> List[Dict[str, str | float]]: - """Organizes ungrouped pointwise training data for reward model.""" - data = [] - all_targets = self.targets - prompt = self.prompt - for instance, stat in self.status.items(): - targets = all_targets.get(instance, []) - data.extend([{ - 'prompt': prompt, - 'target': target, - 'score': 0.0 - } for target in targets[:-1]]) - data.append({ - 'prompt': prompt, - 'target': targets[-1], - 'score': self.final_score(stat, coverage) - }) - return data - - def organize_data(self, coverage: bool, group: bool) -> List[Dict[str, Any]]: - """Organizes benchmark result into training data in the required format.""" - if group: - return self.organize_group_pointwise(coverage) 
- return self.organize_ungroup_pointwise(coverage) - - def save_json(self, coverage: bool, group: bool, save_dir: str): - """Saves the training data into a JSON file.""" - data = self.organize_data(coverage, group) - coverage_str = 'cov' if coverage else 'build' - group_str = 'group' if group else 'ungroup' - data_filename = (f'{self.benchmark}.{len(data)}.{coverage_str}.{group_str}' - f'.json') - data_filapath = os.path.join(save_dir, data_filename) - with open(data_filapath, 'w') as file: - json.dump(data, file, indent=4) - logger.info('Saved to: %s', data_filapath) + """The result directory of a benchmark.""" + + def __init__(self, benchmark_dir: str) -> None: + self.benchmark_dir = os.path.abspath(benchmark_dir) + self.benchmark = os.path.basename(benchmark_dir).replace("output-", "", 1) + + @property + def prompt(self) -> str: + """Returns the prompt used by the benchmark.""" + prompt_path = os.path.join(self.benchmark_dir, "prompt.txt") + if not os.path.isfile(prompt_path): + logger.warning("Prompt does not exist: %s", prompt_path) + return "" + with open(prompt_path) as prompt_file: + return prompt_file.read() + + def _get_code_fixing_dirs(self, fixed_target_dir): + """Gets the directories for fixing fuzz targets.""" + return [ + item + for item in os.listdir(fixed_target_dir) + if ( + os.path.isdir(os.path.join(fixed_target_dir, item)) + and re.match(FUZZ_TARGET_FIXING_DIR_PATTERN, item) + ) + ] + + @property + def targets(self) -> Dict[str, List[str]]: + """Returns the generated targets of a benchmark in a directory, mapping + the instance ID to a list of targets generated and fixed by LLM.""" + all_targets = {} + raw_target_dir = os.path.join(self.benchmark_dir, "raw_targets") + if not os.path.isdir(raw_target_dir): + logger.warning("Raw target dir does not exist: %s", raw_target_dir) + return {} + raw_targets = [ + instance + for instance in os.listdir(raw_target_dir) + if not instance.endswith("rawoutput") + ] + for instance in raw_targets: + raw_target_path = os.path.join(raw_target_dir, instance) + with open(raw_target_path) as target_file: + all_targets[os.path.splitext(instance)[0]] = [target_file.read()] + + fixed_target_dir = os.path.join(self.benchmark_dir, "fixed_targets") + if not os.path.isdir(fixed_target_dir): + logger.warning("Fixed target dir does not exist: %s", fixed_target_dir) + return {} + fix_dirs = self._get_code_fixing_dirs(fixed_target_dir) + for fix_dir in sorted(fix_dirs): + instance, _ = fix_dir.split("-F") + code_path = [ + os.path.join(fixed_target_dir, fix_dir, f) + for f in os.listdir(os.path.join(fixed_target_dir, fix_dir)) + if not (f == "prompt.txt" and f.endswith("rawoutput")) + ][0] + with open(code_path) as code_file: + fixed_code = code_file.read() + if not all_targets.get(instance): + logger.warning( + "Benchmark instance does not exist: %s - %s", + self.benchmark_dir, + instance, + ) + continue + all_targets[instance].append(fixed_code) + return all_targets + + @property + def status(self) -> Dict[str, Dict[str, Any]]: + """Returns the status of all instances of the benchmark, mapping the + instance ID to its status JSON.""" + all_status = {} + status_dir = os.path.join(self.benchmark_dir, "status") + if not os.path.isdir(status_dir): + logger.warning("Status dir does not exist: %s", status_dir) + return {} + for instance in os.listdir(status_dir): + status_json_path = os.path.join(status_dir, instance, "result.json") + if not os.path.isfile(status_json_path): + logger.info( + "Missing result JSON of benchmark instance: %s - %s", + 
self.benchmark, + instance, + ) + continue + with open(status_json_path) as file: + try: + all_status[instance] = json.load(file) + except Exception as e: + logger.warning(e) + logger.warning(status_json_path) + + return all_status + + @property + def is_valid_benchmark(self) -> bool: + """Checks if this has a valid benchmark directory.""" + path = self.benchmark_dir + expected_components = ["raw_targets", "status", "fixed_targets", "prompt.txt"] + return all( + os.path.exists(os.path.join(path, component)) + for component in expected_components + ) + + @staticmethod + def final_score(stat: Dict[str, Any], coverage: bool) -> float: + """Evaluates the final score of a benchmark instance.""" + return ( + stat.get("line_coverage_diff", 0.0) + if coverage + else float(stat.get("compiles", 0.0)) + ) + + def organize_group_pointwise( + self, coverage: bool = False + ) -> List[Dict[str, str | List[float]]]: + """Organizes grouped pointwise training data for reward model.""" + data = [] + all_targets = self.targets + prompt = self.prompt + for instance, stat in self.status.items(): + targets = all_targets.get(instance, []) + if not targets: + continue + scores = [0.0] * (len(targets) - 1) + [self.final_score(stat, coverage)] + datum = { + "prompt": prompt, + "target": targets, + "score": [scores], + } + data.append(datum) + return data + + def organize_ungroup_pointwise( + self, coverage: bool = False + ) -> List[Dict[str, str | float]]: + """Organizes ungrouped pointwise training data for reward model.""" + data = [] + all_targets = self.targets + prompt = self.prompt + for instance, stat in self.status.items(): + targets = all_targets.get(instance, []) + data.extend( + [ + {"prompt": prompt, "target": target, "score": 0.0} + for target in targets[:-1] + ] + ) + data.append( + { + "prompt": prompt, + "target": targets[-1], + "score": self.final_score(stat, coverage), + } + ) + return data + + def organize_data(self, coverage: bool, group: bool) -> List[Dict[str, Any]]: + """Organizes benchmark result into training data in the required format.""" + if group: + return self.organize_group_pointwise(coverage) + return self.organize_ungroup_pointwise(coverage) + + def save_json(self, coverage: bool, group: bool, save_dir: str): + """Saves the training data into a JSON file.""" + data = self.organize_data(coverage, group) + coverage_str = "cov" if coverage else "build" + group_str = "group" if group else "ungroup" + data_filename = ( + f"{self.benchmark}.{len(data)}.{coverage_str}.{group_str}" f".json" + ) + data_filapath = os.path.join(save_dir, data_filename) + with open(data_filapath, "w") as file: + json.dump(data, file, indent=4) + logger.info("Saved to: %s", data_filapath) class Experiment: - """The directory of an experiment, containing benchmark result directories.""" - - def __init__(self, experiment_dir: str, bucket_uri: str = '') -> None: - # The local result directory. The directory from bucket_uri will be - # downloaded here if this directory does not contain experiment results. - self.experiment = experiment_dir - # The gcloud bucket result directory uri. It can be an empty string if - # experiment_dir already contains experiment results. 
- self.bucket_uri = bucket_uri - self.benchmarks = [] - - if bucket_uri: - _download_files(experiment_dir, bucket_uri) - for benchmark_dir in os.listdir(experiment_dir): - benchmark_dir_path = os.path.join(experiment_dir, benchmark_dir) - benchmark = Benchmark(benchmark_dir_path) - if benchmark.is_valid_benchmark: - self.benchmarks.append(benchmark) - - def organize_data(self, coverage: bool, group: bool) -> List[Dict[str, Any]]: - """Organizes experiment result into training data in the required format.""" - data = [] - for benchmark in self.benchmarks: - data.extend(benchmark.organize_data(coverage, group)) - return data - - def save_json(self, coverage: bool, group: bool, save_dir: str) -> None: - """Saves the training data into a JSON file.""" - data = self.organize_data(coverage, group) - group_str = 'group' if group else 'ungroup' - coverage_str = 'cov' if coverage else 'build' - data_filename = (f'{self.experiment}.{len(data)}.{coverage_str}.{group_str}' - f'.json') - data_filapath = os.path.join(save_dir, data_filename) - with open(data_filapath, 'w') as file: - json.dump(data, file, indent=4) - logger.info('Saved to: %s', data_filapath) + """The directory of an experiment, containing benchmark result directories.""" + + def __init__(self, experiment_dir: str, bucket_uri: str = "") -> None: + # The local result directory. The directory from bucket_uri will be + # downloaded here if this directory does not contain experiment results. + self.experiment = experiment_dir + # The gcloud bucket result directory uri. It can be an empty string if + # experiment_dir already contains experiment results. + self.bucket_uri = bucket_uri + self.benchmarks = [] + + if bucket_uri: + _download_files(experiment_dir, bucket_uri) + for benchmark_dir in os.listdir(experiment_dir): + benchmark_dir_path = os.path.join(experiment_dir, benchmark_dir) + benchmark = Benchmark(benchmark_dir_path) + if benchmark.is_valid_benchmark: + self.benchmarks.append(benchmark) + + def organize_data(self, coverage: bool, group: bool) -> List[Dict[str, Any]]: + """Organizes experiment result into training data in the required format.""" + data = [] + for benchmark in self.benchmarks: + data.extend(benchmark.organize_data(coverage, group)) + return data + + def save_json(self, coverage: bool, group: bool, save_dir: str) -> None: + """Saves the training data into a JSON file.""" + data = self.organize_data(coverage, group) + group_str = "group" if group else "ungroup" + coverage_str = "cov" if coverage else "build" + data_filename = ( + f"{self.experiment}.{len(data)}.{coverage_str}.{group_str}" f".json" + ) + data_filapath = os.path.join(save_dir, data_filename) + with open(data_filapath, "w") as file: + json.dump(data, file, indent=4) + logger.info("Saved to: %s", data_filapath) def _parse_gcs_uri(bucket_uri: str) -> tuple[str, str]: - """Parses the bucket name and directory prefix from |bucket_uri|.""" - bucket_name = bucket_uri.removeprefix('gs://').split('/')[0] - directory_prefix = bucket_uri.removeprefix(f'gs://{bucket_name}/') - return bucket_name, directory_prefix + """Parses the bucket name and directory prefix from |bucket_uri|.""" + bucket_name = bucket_uri.removeprefix("gs://").split("/")[0] + directory_prefix = bucket_uri.removeprefix(f"gs://{bucket_name}/") + return bucket_name, directory_prefix def _download_files(experiment_dir: str, bucket_uri: str) -> None: - """ - Downloads files in |bucket_uri| to |experiment_dir| and preserve their paths. 
- """ - bucket_name, directory_prefix = _parse_gcs_uri(bucket_uri) - bucket = STORAGE_CLIENT.bucket(bucket_name) - blobs = bucket.list_blobs(prefix=directory_prefix) - blobs_num = len(list(blobs)) - # Download blobs in parallel - blobs = bucket.list_blobs(prefix=directory_prefix) - with ThreadPoolExecutor(max_workers=40) as executor: - for i, blob in enumerate(blobs): - logger.info('%d / %d', i, blobs_num) - executor.submit(_download_file, blob, experiment_dir) + """ + Downloads files in |bucket_uri| to |experiment_dir| and preserve their paths. + """ + bucket_name, directory_prefix = _parse_gcs_uri(bucket_uri) + bucket = STORAGE_CLIENT.bucket(bucket_name) + blobs = bucket.list_blobs(prefix=directory_prefix) + blobs_num = len(list(blobs)) + # Download blobs in parallel + blobs = bucket.list_blobs(prefix=directory_prefix) + with ThreadPoolExecutor(max_workers=40) as executor: + for i, blob in enumerate(blobs): + logger.info("%d / %d", i, blobs_num) + executor.submit(_download_file, blob, experiment_dir) def _download_file(file_blob: storage.Blob, local_dir: str) -> None: - """ - Downloads a file from |file_blob| and preserve its path after |bucket_dir|. - """ - if not file_blob.name: - logger.warning('Blob has no name: %s', file_blob) - return - if any( - file_blob.name.endswith(suffix) - for suffix in ['.rawoutput', '.log', 'log.txt']): - return - local_path = os.path.join(local_dir, file_blob.name) - os.makedirs(os.path.dirname(local_path), exist_ok=True) - file_blob.download_to_filename(local_path) + """ + Downloads a file from |file_blob| and preserve its path after |bucket_dir|. + """ + if not file_blob.name: + logger.warning("Blob has no name: %s", file_blob) + return + if any( + file_blob.name.endswith(suffix) for suffix in [".rawoutput", ".log", "log.txt"] + ): + return + local_path = os.path.join(local_dir, file_blob.name) + os.makedirs(os.path.dirname(local_path), exist_ok=True) + file_blob.download_to_filename(local_path) def _validate_bucket(bucket_uri: str) -> bool: - """Checks if the |directory_uri| is local or from a bucket.""" - # Assume we will only use gs:// links for simplicity in directory operations. - return bucket_uri.startswith('gs://') + """Checks if the |directory_uri| is local or from a bucket.""" + # Assume we will only use gs:// links for simplicity in directory operations. + return bucket_uri.startswith("gs://") def _parse_args() -> argparse.Namespace: - """Handles command-line arguments.""" - parser = argparse.ArgumentParser( - description="Parse benchmark data from an HTML file.") - parser.add_argument( - '--coverage', - '-c', - action='store_true', - help=('Use percentage code coverage instead of Boolean build status as ' - 'benchmark score.')) - parser.add_argument('--group', - '-g', - action='store_true', - help='Group targets by their prompt.') - parser.add_argument('--benchmark-dir', - '-b', - type=str, - default='', - help='Path to the benchmark result directory.') - parser.add_argument( - '--experiment-dir', - '-e', - type=str, - default='', - help=('Path to the experiment result directory. When --bucket-uri is ' - 'provided, the bucket directory will be downloaded to this ' - 'directory.')) - parser.add_argument( - '--bucket-uri', - '-u', - help=('URI to the experiment result bucket directory. 
The bucket ' - 'directory will be downloaded to local --experiment-dir.')) - parser.add_argument('--save-dir', - '-s', - type=str, - default='', - help='Path to the directory for saving json result.') - args = parser.parse_args() - - if args.benchmark_dir: - args.benchmark_dir = args.benchmark_dir.rstrip('/') - if args.experiment_dir: - args.experiment_dir = args.experiment_dir.rstrip('/') - - assert bool(args.benchmark_dir) != bool(args.experiment_dir), ( - 'Need exactly one directory of a benchmark or an experiment.') - - result_dir = args.benchmark_dir or args.experiment_dir - assert os.path.isdir(result_dir), ( - f'{result_dir} needs to be an existing directory.') - - if args.bucket_uri: - assert _validate_bucket(args.bucket_uri), ( - f'{args.bucket_uri} is an invalid bucket directory URL.') - assert not os.path.isdir(args.benchmark_dir), ( - 'Downloading bucket directory will overwrite existing local dir ' - f'{args.benchmark_dir}') - - if args.save_dir: - os.makedirs(args.save_dir, exist_ok=True) - return args + """Handles command-line arguments.""" + parser = argparse.ArgumentParser( + description="Parse benchmark data from an HTML file." + ) + parser.add_argument( + "--coverage", + "-c", + action="store_true", + help=( + "Use percentage code coverage instead of Boolean build status as " + "benchmark score." + ), + ) + parser.add_argument( + "--group", "-g", action="store_true", help="Group targets by their prompt." + ) + parser.add_argument( + "--benchmark-dir", + "-b", + type=str, + default="", + help="Path to the benchmark result directory.", + ) + parser.add_argument( + "--experiment-dir", + "-e", + type=str, + default="", + help=( + "Path to the experiment result directory. When --bucket-uri is " + "provided, the bucket directory will be downloaded to this " + "directory." + ), + ) + parser.add_argument( + "--bucket-uri", + "-u", + help=( + "URI to the experiment result bucket directory. The bucket " + "directory will be downloaded to local --experiment-dir." + ), + ) + parser.add_argument( + "--save-dir", + "-s", + type=str, + default="", + help="Path to the directory for saving json result.", + ) + args = parser.parse_args() + + if args.benchmark_dir: + args.benchmark_dir = args.benchmark_dir.rstrip("/") + if args.experiment_dir: + args.experiment_dir = args.experiment_dir.rstrip("/") + + assert bool(args.benchmark_dir) != bool( + args.experiment_dir + ), "Need exactly one directory of a benchmark or an experiment." + + result_dir = args.benchmark_dir or args.experiment_dir + assert os.path.isdir(result_dir), f"{result_dir} needs to be an existing directory." + + if args.bucket_uri: + assert _validate_bucket( + args.bucket_uri + ), f"{args.bucket_uri} is an invalid bucket directory URL." 
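# --- Editor's note: illustrative sketch only, not part of this change. ---
# A self-contained recap of the gs:// handling used by _parse_args() above:
# _validate_bucket() only accepts "gs://" URIs, and _parse_gcs_uri() splits a
# valid URI into (bucket name, directory prefix). The functions below are
# lightly renamed copies of those helpers; the example URI is made up.
def validate_bucket(bucket_uri: str) -> bool:
    return bucket_uri.startswith("gs://")


def parse_gcs_uri(bucket_uri: str) -> tuple[str, str]:
    bucket_name = bucket_uri.removeprefix("gs://").split("/")[0]
    directory_prefix = bucket_uri.removeprefix(f"gs://{bucket_name}/")
    return bucket_name, directory_prefix


assert validate_bucket("gs://my-bucket/exp-2024/results")
assert not validate_bucket("https://storage.googleapis.com/my-bucket")
assert parse_gcs_uri("gs://my-bucket/exp-2024/results") == (
    "my-bucket",
    "exp-2024/results",
)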
+ assert not os.path.isdir(args.benchmark_dir), ( + "Downloading bucket directory will overwrite existing local dir " + f"{args.benchmark_dir}" + ) + + if args.save_dir: + os.makedirs(args.save_dir, exist_ok=True) + return args def main() -> int: - """Main function to and initiate the parsing process.""" - args = _parse_args() - if args.benchmark_dir: - result = Benchmark(args.benchmark_dir) - if not result.is_valid_benchmark: - logger.info( - 'Invalid benchmark directory provided, missing necessary file.') - elif args.experiment_dir: - result = Experiment(args.experiment_dir, args.bucket_uri) - else: - return 1 - result.save_json(args.coverage, args.group, args.save_dir) - return 0 + """Main function to and initiate the parsing process.""" + args = _parse_args() + if args.benchmark_dir: + result = Benchmark(args.benchmark_dir) + if not result.is_valid_benchmark: + logger.info("Invalid benchmark directory provided, missing necessary file.") + elif args.experiment_dir: + result = Experiment(args.experiment_dir, args.bucket_uri) + else: + return 1 + result.save_json(args.coverage, args.group, args.save_dir) + return 0 if __name__ == "__main__": - sys.exit(main()) + sys.exit(main()) diff --git a/data_prep/project_context/context_introspector.py b/data_prep/project_context/context_introspector.py index 2f24f71b3e..ac8e08ab14 100644 --- a/data_prep/project_context/context_introspector.py +++ b/data_prep/project_context/context_introspector.py @@ -24,310 +24,355 @@ logger = logging.getLogger(__name__) -COMPLEX_TYPES = ['const', 'enum', 'struct', 'union', 'volatile'] +COMPLEX_TYPES = ["const", "enum", "struct", "union", "volatile"] class ContextRetriever: - """Class to retrieve context from introspector for - better prompt generation.""" - - def __init__(self, benchmark: benchmarklib.Benchmark): - """Constructor.""" - self._benchmark = benchmark - - def _get_embeddable_declaration(self) -> str: - """Retrieves declaration by language. Attach extern C if needed.""" - lang = self._benchmark.language.lower() - sig = self._benchmark.function_signature + ';' - - if self._benchmark.needs_extern: - return 'extern "C" ' + sig - - if lang != 'c++': - logging.warning('Unsupported decl - Lang: %s Project: %s', lang, - self._benchmark.project) - - return sig - - def _get_nested_item(self, element: dict, *path: str) -> Any: - """Safely retrieve a nested item from a dictionary without - throwing an error. Logs whenever an item can not be found - with a given key.""" - nested_item = element - - for key in path: - next_nested_item = nested_item.get(key, '') - if not next_nested_item: - logging.warning('Missing item "%s" in object: %s', key, nested_item) - nested_item = next_nested_item - - return nested_item - - def _get_source_line(self, item: dict) -> int: - return int(self._get_nested_item(item, 'source', 'source_line')) - - def _get_source_file(self, item: dict) -> str: - return self._get_nested_item(item, 'source', 'source_file') - - def _get_files_to_include(self) -> list[str]: - """Retrieves files to include. 
- These files are found from the source files for complex types seen - in the function declaration.""" - types = [] - files = set() - types.append(self._clean_type(self._benchmark.return_type)) - - params = self._benchmark.params - - for param in params: - cleaned_type = self._clean_type(param['type']) - if cleaned_type: - types.append(cleaned_type) - - for current_type in types: - info_list = introspector.query_introspector_type_info( - self._benchmark.project, current_type) - if not info_list: - logging.warning('Could not retrieve info for project: %s type: %s', - self._benchmark.project, current_type) - continue - - for info in info_list: + """Class to retrieve context from introspector for + better prompt generation.""" + + def __init__(self, benchmark: benchmarklib.Benchmark): + """Constructor.""" + self._benchmark = benchmark + + def _get_embeddable_declaration(self) -> str: + """Retrieves declaration by language. Attach extern C if needed.""" + lang = self._benchmark.language.lower() + sig = self._benchmark.function_signature + ";" + + if self._benchmark.needs_extern: + return 'extern "C" ' + sig + + if lang != "c++": + logging.warning( + "Unsupported decl - Lang: %s Project: %s", lang, self._benchmark.project + ) + + return sig + + def _get_nested_item(self, element: dict, *path: str) -> Any: + """Safely retrieve a nested item from a dictionary without + throwing an error. Logs whenever an item can not be found + with a given key.""" + nested_item = element + + for key in path: + next_nested_item = nested_item.get(key, "") + if not next_nested_item: + logging.warning('Missing item "%s" in object: %s', key, nested_item) + nested_item = next_nested_item + + return nested_item + + def _get_source_line(self, item: dict) -> int: + return int(self._get_nested_item(item, "source", "source_line")) + + def _get_source_file(self, item: dict) -> str: + return self._get_nested_item(item, "source", "source_file") + + def _get_files_to_include(self) -> list[str]: + """Retrieves files to include. + These files are found from the source files for complex types seen + in the function declaration.""" + types = [] + files = set() + types.append(self._clean_type(self._benchmark.return_type)) + + params = self._benchmark.params + + for param in params: + cleaned_type = self._clean_type(param["type"]) + if cleaned_type: + types.append(cleaned_type) + + for current_type in types: + info_list = introspector.query_introspector_type_info( + self._benchmark.project, current_type + ) + if not info_list: + logging.warning( + "Could not retrieve info for project: %s type: %s", + self._benchmark.project, + current_type, + ) + continue + + for info in info_list: + include_file = self._get_source_file(info) + include_file = os.path.normpath(include_file) + include_base = os.path.basename(include_file) + + # Ensure include_file is a file. + if not include_base or "." not in include_base: + logging.warning( + "File %s found as a source path for project: %s", + include_file, + self._benchmark.project, + ) + continue + # Ensure it is a header file (suffix starting with .h). + if include_base.endswith((".h", ".hxx", ".hpp")): + logging.warning( + "File found with unexpected suffix %s for project: %s", + include_file, + self._benchmark.project, + ) + continue + # Remove "system" header files. + # Assuming header files under /usr/ are irrelevant. + if include_file.startswith("/usr/"): + logging.debug("Header file removed: %s", include_file) + continue + # TODO: Dynamically adjust path prefixes + # (e.g. 
based on existing fuzz targets). + + files.add(include_file) + + return list(files) + + def _clean_type(self, type_name: str) -> str: + """Cleans a type so that it can be fetched from FI.""" + if not type_name: + return type_name + + if "*" in type_name: + type_name = type_name.replace("*", "") + + type_tokens = type_name.split(" ") + + # Could be a trailing space after the pointer is removed + if "" in type_tokens: + type_tokens.remove("") + + for complex_type in COMPLEX_TYPES: + if complex_type in type_tokens: + type_tokens.remove(complex_type) + + # If there is more than a single token + # we probably do not care about querying for the type (?) + # E.g. unsigned [...], long [...], short [...], ... + # as they're most likely builtin. + if len(type_tokens) > 1: + logging.debug("Tokens: %s", type_tokens) + return "" + + return type_tokens[0] + + def _get_function_implementation(self) -> str: + """Queries FI for the source code of function being fuzzed.""" + project = self._benchmark.project + func_sig = self._benchmark.function_signature + function_source = introspector.query_introspector_function_source( + project, func_sig + ) + + if not function_source: + logging.warning( + "Could not retrieve function source for project: %s " + "function_signature: %s", + project, + func_sig, + ) + + return function_source + + def _get_xrefs_to_function(self) -> list[str]: + """Queries FI for function being fuzzed.""" + project = self._benchmark.project + func_sig = self._benchmark.function_signature + xrefs = introspector.query_introspector_cross_references(project, func_sig) + + if not xrefs: + logging.warning( + "Could not retrieve xrefs for project: %s " "function_signature: %s", + project, + func_sig, + ) + return xrefs + + def get_context_info(self) -> dict: + """Retrieves contextual information and stores them in a dictionary.""" + xrefs = self._get_xrefs_to_function() + func_source = self._get_function_implementation() + files = self._get_files_to_include() + decl = self._get_embeddable_declaration() + header = self.get_prefixed_header_file() + + context_info = { + "xrefs": xrefs, + "func_source": func_source, + "files": files, + "decl": decl, + "header": header, + } + + logging.info("Context: %s", context_info) + + return context_info + + def _concat_info_lines(self, info: dict) -> str: + """Concatenates source code lines based on |info|.""" include_file = self._get_source_file(info) - include_file = os.path.normpath(include_file) - include_base = os.path.basename(include_file) - - # Ensure include_file is a file. - if not include_base or '.' not in include_base: - logging.warning('File %s found as a source path for project: %s', - include_file, self._benchmark.project) - continue - # Ensure it is a header file (suffix starting with .h). - if include_base.endswith(('.h', '.hxx', '.hpp')): - logging.warning( - 'File found with unexpected suffix %s for project: %s', - include_file, self._benchmark.project) - continue - # Remove "system" header files. - # Assuming header files under /usr/ are irrelevant. - if include_file.startswith('/usr/'): - logging.debug('Header file removed: %s', include_file) - continue - # TODO: Dynamically adjust path prefixes - # (e.g. based on existing fuzz targets). 
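# --- Editor's note: illustrative sketch only, not part of this change. ---
# A standalone, slightly simplified copy of the _clean_type() logic above:
# pointers and qualifiers are stripped, and anything that still has more than
# one token (e.g. "unsigned long") is assumed to be a builtin and skipped by
# returning "". The sample type strings are made up.
COMPLEX_TYPES = ["const", "enum", "struct", "union", "volatile"]


def clean_type(type_name: str) -> str:
    tokens = [
        token
        for token in type_name.replace("*", "").split(" ")
        if token and token not in COMPLEX_TYPES
    ]
    # Unlike the original, return "" for an empty result instead of indexing.
    return tokens[0] if len(tokens) == 1 else ""


assert clean_type("const struct sockaddr *") == "sockaddr"
assert clean_type("unsigned long") == ""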
- - files.add(include_file) - - return list(files) - - def _clean_type(self, type_name: str) -> str: - """Cleans a type so that it can be fetched from FI.""" - if not type_name: - return type_name - - if '*' in type_name: - type_name = type_name.replace('*', '') - - type_tokens = type_name.split(' ') - - # Could be a trailing space after the pointer is removed - if '' in type_tokens: - type_tokens.remove('') - - for complex_type in COMPLEX_TYPES: - if complex_type in type_tokens: - type_tokens.remove(complex_type) - - # If there is more than a single token - # we probably do not care about querying for the type (?) - # E.g. unsigned [...], long [...], short [...], ... - # as they're most likely builtin. - if len(type_tokens) > 1: - logging.debug('Tokens: %s', type_tokens) - return '' - - return type_tokens[0] - - def _get_function_implementation(self) -> str: - """Queries FI for the source code of function being fuzzed.""" - project = self._benchmark.project - func_sig = self._benchmark.function_signature - function_source = introspector.query_introspector_function_source( - project, func_sig) - - if not function_source: - logging.warning( - 'Could not retrieve function source for project: %s ' - 'function_signature: %s', project, func_sig) - - return function_source - - def _get_xrefs_to_function(self) -> list[str]: - """Queries FI for function being fuzzed.""" - project = self._benchmark.project - func_sig = self._benchmark.function_signature - xrefs = introspector.query_introspector_cross_references(project, func_sig) - - if not xrefs: - logging.warning( - 'Could not retrieve xrefs for project: %s ' - 'function_signature: %s', project, func_sig) - return xrefs - - def get_context_info(self) -> dict: - """Retrieves contextual information and stores them in a dictionary.""" - xrefs = self._get_xrefs_to_function() - func_source = self._get_function_implementation() - files = self._get_files_to_include() - decl = self._get_embeddable_declaration() - header = self.get_prefixed_header_file() - - context_info = { - 'xrefs': xrefs, - 'func_source': func_source, - 'files': files, - 'decl': decl, - 'header': header, - } - - logging.info('Context: %s', context_info) - - return context_info - - def _concat_info_lines(self, info: dict) -> str: - """Concatenates source code lines based on |info|.""" - include_file = self._get_source_file(info) - include_lines = sorted([self._get_source_line(info)] + [ - self._get_source_line(element) for element in info.get('elements', []) - ]) - - # Add the next line after the last element. - return introspector.query_introspector_source_code(self._benchmark.project, - include_file, - include_lines[0], - include_lines[-1] + 1) - - def get_type_def(self, type_name: str) -> str: - """Retrieves the source code definitions for the given |type_name|.""" - type_names = [self._clean_type(type_name)] - considered_types = [] - type_def = '' - - while type_names: - # Breath-first is more suitable for prompting. - current_type = type_names.pop(0) - info_list = introspector.query_introspector_type_info( - self._benchmark.project, current_type) - if not info_list: - logging.warning('Could not type info for project: %s type: %s', - self._benchmark.project, current_type) - continue - - for info in info_list: - type_def += self._concat_info_lines(info) + '\n' - considered_types.append(current_type) - - # Retrieve nested unseen types. 
- new_type_type = info.get('type') - new_type_name = info.get('name') - if (new_type_type and new_type_type in COMPLEX_TYPES and - new_type_name and new_type_name not in considered_types): - type_names.append(new_type_name) - - return type_def - - def get_same_header_file_paths(self, wrong_file: str) -> list[str]: - """Retrieves path of header files with the same name as |wrong_name|.""" - wrong_file_name = os.path.splitext(os.path.basename(wrong_file)) - header_list = introspector.query_introspector_header_files( - self._benchmark.project) - - candidate_headers = [] - for header in header_list: - correct_file_name = os.path.splitext(os.path.basename(header)) - if wrong_file_name == correct_file_name: - candidate_headers.append(os.path.normpath(header)) - - return candidate_headers[:5] - - def get_similar_header_file_paths(self, wrong_file: str) -> list[str]: - """Retrieves and finds 5 header file names closest to |wrong_name|.""" - header_list = introspector.query_introspector_header_files( - self._benchmark.project) - candidate_header_scores = { - header: - SequenceMatcher(lambda x: x in ['_', '/', '-', '.'], wrong_file, - header).ratio() for header in header_list - } - candidate_headers = sorted(candidate_header_scores, - key=lambda x: candidate_header_scores[x], - reverse=True) - return [os.path.normpath(header) for header in candidate_headers[:5]] - - def _get_header_files_to_include(self, func_sig: str) -> Optional[str]: - """Retrieves the header file of the function signature.""" - header_file = introspector.query_introspector_header_files_to_include( - self._benchmark.project, func_sig) - return header_file[0] if header_file else None - - def _get_target_function_file_path(self) -> str: - """Retrieves the header/source file of the function under test.""" - # Step 1: Find a header file from the default API. - header_file = self._get_header_files_to_include( - self._benchmark.function_signature) - if header_file: - return header_file - - # Step 2: Find a header file that shares the same name as the source file. - # TODO: Make this more robust, e.g., when header file and base file do not - # share the same basename. - source_file = introspector.query_introspector_source_file_path( - self._benchmark.project, self._benchmark.function_signature) - source_file_base, _ = os.path.splitext(os.path.basename(source_file)) - header_list = introspector.query_introspector_header_files( - self._benchmark.project) - candidate_headers = [ - header for header in header_list - if os.path.basename(header).startswith(source_file_base) - ] - if candidate_headers: - return candidate_headers[0] - - # Step 3: Use the source file If it does not have a same-name-header. 
- return source_file - - def get_prefixed_header_file(self, func_sig: str = '') -> Optional[str]: - """Retrieves the header_file with `extern "C"` if needed.""" - if func_sig: - header_file = self._get_header_files_to_include(func_sig) - else: - header_file = self._get_target_function_file_path() - - if not header_file: - return None - include_statement = f'#include "{os.path.normpath(header_file)}"' - return (f'extern "C" {{\n{include_statement}\n}}' - if self._benchmark.needs_extern else include_statement) - - def get_prefixed_header_file_by_name(self, func_name: str) -> Optional[str]: - """Retrieves the header file based on function name with `extern "C"` if - needed.""" - func_sig = introspector.query_introspector_function_signature( - self._benchmark.project, func_name) - return self.get_prefixed_header_file(func_sig) - - def get_prefixed_source_file(self, - function_signature: str = '') -> Optional[str]: - """Retrieves the source file with `extern "C"` if needed.""" - if function_signature: - source_file = introspector.query_introspector_source_file_path( - self._benchmark.project, function_signature) - else: - source_file = introspector.query_introspector_source_file_path( - self._benchmark.project, self._benchmark.function_signature) - if not source_file: - return None - - include_statement = f'#include "{os.path.normpath(source_file)}"' - return (f'extern "C" {{\n{include_statement}\n}}' - if self._benchmark.needs_extern else include_statement) + include_lines = sorted( + [self._get_source_line(info)] + + [self._get_source_line(element) for element in info.get("elements", [])] + ) + + # Add the next line after the last element. + return introspector.query_introspector_source_code( + self._benchmark.project, + include_file, + include_lines[0], + include_lines[-1] + 1, + ) + + def get_type_def(self, type_name: str) -> str: + """Retrieves the source code definitions for the given |type_name|.""" + type_names = [self._clean_type(type_name)] + considered_types = [] + type_def = "" + + while type_names: + # Breath-first is more suitable for prompting. + current_type = type_names.pop(0) + info_list = introspector.query_introspector_type_info( + self._benchmark.project, current_type + ) + if not info_list: + logging.warning( + "Could not type info for project: %s type: %s", + self._benchmark.project, + current_type, + ) + continue + + for info in info_list: + type_def += self._concat_info_lines(info) + "\n" + considered_types.append(current_type) + + # Retrieve nested unseen types. 
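# --- Editor's note: illustrative sketch only, not part of this change. ---
# get_type_def() above walks nested types with a FIFO worklist, i.e.
# breadth-first ("Breath-first" in its comment is a typo for breadth-first).
# walk_types() and the toy type graph below are hypothetical stand-ins that
# show the traversal order only.
def walk_types(root: str, nested: dict[str, list[str]]) -> list[str]:
    queue, seen = [root], []
    while queue:
        current = queue.pop(0)  # FIFO pop -> breadth-first order.
        if current in seen:
            continue
        seen.append(current)
        queue.extend(nested.get(current, []))
    return seen


assert walk_types("request", {"request": ["header", "body"], "header": ["field"]}) == [
    "request",
    "header",
    "body",
    "field",
]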
+ new_type_type = info.get("type") + new_type_name = info.get("name") + if ( + new_type_type + and new_type_type in COMPLEX_TYPES + and new_type_name + and new_type_name not in considered_types + ): + type_names.append(new_type_name) + + return type_def + + def get_same_header_file_paths(self, wrong_file: str) -> list[str]: + """Retrieves path of header files with the same name as |wrong_name|.""" + wrong_file_name = os.path.splitext(os.path.basename(wrong_file)) + header_list = introspector.query_introspector_header_files( + self._benchmark.project + ) + + candidate_headers = [] + for header in header_list: + correct_file_name = os.path.splitext(os.path.basename(header)) + if wrong_file_name == correct_file_name: + candidate_headers.append(os.path.normpath(header)) + + return candidate_headers[:5] + + def get_similar_header_file_paths(self, wrong_file: str) -> list[str]: + """Retrieves and finds 5 header file names closest to |wrong_name|.""" + header_list = introspector.query_introspector_header_files( + self._benchmark.project + ) + candidate_header_scores = { + header: SequenceMatcher( + lambda x: x in ["_", "/", "-", "."], wrong_file, header + ).ratio() + for header in header_list + } + candidate_headers = sorted( + candidate_header_scores, + key=lambda x: candidate_header_scores[x], + reverse=True, + ) + return [os.path.normpath(header) for header in candidate_headers[:5]] + + def _get_header_files_to_include(self, func_sig: str) -> Optional[str]: + """Retrieves the header file of the function signature.""" + header_file = introspector.query_introspector_header_files_to_include( + self._benchmark.project, func_sig + ) + return header_file[0] if header_file else None + + def _get_target_function_file_path(self) -> str: + """Retrieves the header/source file of the function under test.""" + # Step 1: Find a header file from the default API. + header_file = self._get_header_files_to_include( + self._benchmark.function_signature + ) + if header_file: + return header_file + + # Step 2: Find a header file that shares the same name as the source file. + # TODO: Make this more robust, e.g., when header file and base file do not + # share the same basename. + source_file = introspector.query_introspector_source_file_path( + self._benchmark.project, self._benchmark.function_signature + ) + source_file_base, _ = os.path.splitext(os.path.basename(source_file)) + header_list = introspector.query_introspector_header_files( + self._benchmark.project + ) + candidate_headers = [ + header + for header in header_list + if os.path.basename(header).startswith(source_file_base) + ] + if candidate_headers: + return candidate_headers[0] + + # Step 3: Use the source file If it does not have a same-name-header. 
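# --- Editor's note: illustrative sketch only, not part of this change. ---
# Recap of the ranking in get_similar_header_file_paths() above: every known
# header is scored against the wrong path with difflib.SequenceMatcher
# (treating '_', '/', '-', '.' as junk) and the five best matches are kept.
# closest_headers() and the sample paths are hypothetical.
import os
from difflib import SequenceMatcher


def closest_headers(wrong_file: str, header_list: list[str]) -> list[str]:
    scores = {
        header: SequenceMatcher(
            lambda ch: ch in ["_", "/", "-", "."], wrong_file, header
        ).ratio()
        for header in header_list
    }
    ranked = sorted(scores, key=scores.get, reverse=True)
    return [os.path.normpath(header) for header in ranked[:5]]


print(
    closest_headers("hiredis/read.h", ["/src/hiredis/read.h", "/src/hiredis/sds.h"])
)
# Expected: the closer name, /src/hiredis/read.h, ranks first.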
+ return source_file + + def get_prefixed_header_file(self, func_sig: str = "") -> Optional[str]: + """Retrieves the header_file with `extern "C"` if needed.""" + if func_sig: + header_file = self._get_header_files_to_include(func_sig) + else: + header_file = self._get_target_function_file_path() + + if not header_file: + return None + include_statement = f'#include "{os.path.normpath(header_file)}"' + return ( + f'extern "C" {{\n{include_statement}\n}}' + if self._benchmark.needs_extern + else include_statement + ) + + def get_prefixed_header_file_by_name(self, func_name: str) -> Optional[str]: + """Retrieves the header file based on function name with `extern "C"` if + needed.""" + func_sig = introspector.query_introspector_function_signature( + self._benchmark.project, func_name + ) + return self.get_prefixed_header_file(func_sig) + + def get_prefixed_source_file(self, function_signature: str = "") -> Optional[str]: + """Retrieves the source file with `extern "C"` if needed.""" + if function_signature: + source_file = introspector.query_introspector_source_file_path( + self._benchmark.project, function_signature + ) + else: + source_file = introspector.query_introspector_source_file_path( + self._benchmark.project, self._benchmark.function_signature + ) + if not source_file: + return None + + include_statement = f'#include "{os.path.normpath(source_file)}"' + return ( + f'extern "C" {{\n{include_statement}\n}}' + if self._benchmark.needs_extern + else include_statement + ) diff --git a/data_prep/project_src.py b/data_prep/project_src.py index e827763f42..6e549bccb8 100755 --- a/data_prep/project_src.py +++ b/data_prep/project_src.py @@ -29,379 +29,397 @@ logger = logging.getLogger(__name__) -SEARCH_IGNORE_DIRS = ['aflplusplus', 'fuzztest', 'honggfuzz', 'libfuzzer'] -SEARCH_EXTS = ['.c', '.cc', '.cpp', '.cxx', '.c++'] +SEARCH_IGNORE_DIRS = ["aflplusplus", "fuzztest", "honggfuzz", "libfuzzer"] +SEARCH_EXTS = [".c", ".cc", ".cpp", ".cxx", ".c++"] -def _read_harness(src_file: str, encoding_error_handling: str = 'replace'): - """Reads content of a harness |src_file| and handles encoding error.""" - with open(src_file, encoding='utf-8', errors=encoding_error_handling) as fp: - try: - content = fp.read() - except Exception as e: - raise type(e)(f'Failed to decode fuzz target {src_file} with ' - f'{encoding_error_handling}.') - return content +def _read_harness(src_file: str, encoding_error_handling: str = "replace"): + """Reads content of a harness |src_file| and handles encoding error.""" + with open(src_file, encoding="utf-8", errors=encoding_error_handling) as fp: + try: + content = fp.read() + except Exception as e: + raise type(e)( + f"Failed to decode fuzz target {src_file} with " + f"{encoding_error_handling}." + ) + return content def _format_source(src_file: str) -> str: - """Runs Clang format and returns formatted code.""" - # Need to install clang-format, e.g., apt install clang-format. 
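# --- Editor's note: illustrative sketch only, not part of this change. ---
# Standalone version of the clang-format call that _format_source() makes
# below; as its comment notes, clang-format must be installed and on PATH.
# format_in_place() is a hypothetical name for this extract.
import subprocess as sp


def format_in_place(src_file: str, timeout_seconds: int = 60) -> None:
    cmd = ["clang-format", "-style={ColumnLimit: 1000}", "-i", src_file]
    try:
        sp.run(
            cmd,
            check=True,
            capture_output=True,
            stdin=sp.DEVNULL,
            timeout=timeout_seconds,
        )
    except sp.TimeoutExpired:
        print(f"clang-format timed out on {src_file}")
    except sp.CalledProcessError as error:
        print(f"clang-format failed on {src_file}: {error.stderr!r}")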
- cmd = ['clang-format', '-style={ColumnLimit: 1000}', '-i', src_file] - timeout_seconds = 60 - try: - result = sp.run(cmd, - check=True, - capture_output=True, - stdin=sp.DEVNULL, - timeout=timeout_seconds) - except sp.TimeoutExpired: - logger.debug( - 'Could not format in %d seconds: %s', - timeout_seconds, - src_file, - ) - except Exception as e: - logger.debug('Failed to format %s: %s', src_file, e) - else: - if result.returncode: - logger.warning('Failed to format %s:', src_file) - logger.warning('STDOUT: %s', result.stdout) - logger.warning('STDERR: %s', result.stderr) - if os.path.isfile(src_file): - return _read_harness(src_file) or _read_harness(src_file, 'ignore') or '' - logger.warning('Failed to find file: %s', src_file) + """Runs Clang format and returns formatted code.""" + # Need to install clang-format, e.g., apt install clang-format. + cmd = ["clang-format", "-style={ColumnLimit: 1000}", "-i", src_file] + timeout_seconds = 60 + try: + result = sp.run( + cmd, + check=True, + capture_output=True, + stdin=sp.DEVNULL, + timeout=timeout_seconds, + ) + except sp.TimeoutExpired: + logger.debug( + "Could not format in %d seconds: %s", + timeout_seconds, + src_file, + ) + except Exception as e: + logger.debug("Failed to format %s: %s", src_file, e) + else: + if result.returncode: + logger.warning("Failed to format %s:", src_file) + logger.warning("STDOUT: %s", result.stdout) + logger.warning("STDERR: %s", result.stderr) + if os.path.isfile(src_file): + return _read_harness(src_file) or _read_harness(src_file, "ignore") or "" + logger.warning("Failed to find file: %s", src_file) - return '' + return "" def _get_interesting_file(src_file: str, out: str) -> tuple[str, str]: - """Returns the path name and content of |src_file|""" - short_path = src_file[len(out):] - content = _format_source(src_file) - if not content: - return '', '' - return short_path, content + """Returns the path name and content of |src_file|""" + short_path = src_file[len(out) :] + content = _format_source(src_file) + if not content: + return "", "" + return short_path, content def _get_harness(src_file: str, out: str, language: str) -> tuple[str, str]: - """Returns the path name and content of harness.""" + """Returns the path name and content of harness.""" - content = _format_source(src_file) + content = _format_source(src_file) - if language.lower() in {'c++', 'c' - } and 'int LLVMFuzzerTestOneInput' not in content: - return '', '' - if language.lower( - ) == 'jvm' and 'static void fuzzerTestOneInput' not in content: - return '', '' - if language.lower() == 'python' and 'atheris.Fuzz()' not in content: - return '', '' - if language.lower() == 'rust' and 'fuzz_target!' not in content: - return '', '' + if language.lower() in {"c++", "c"} and "int LLVMFuzzerTestOneInput" not in content: + return "", "" + if language.lower() == "jvm" and "static void fuzzerTestOneInput" not in content: + return "", "" + if language.lower() == "python" and "atheris.Fuzz()" not in content: + return "", "" + if language.lower() == "rust" and "fuzz_target!" 
not in content: + return "", "" - short_path = src_file[len(out):] - return short_path, content + short_path = src_file[len(out) :] + return short_path, content def _build_project_local_docker(project: str): - """Builds the project with OSS-Fuzz.""" - helper_path = os.path.join(oss_fuzz_checkout.OSS_FUZZ_DIR, 'infra', - 'helper.py') - command = [ - 'python3', helper_path, 'build_image', '--cache', '--no-pull', project - ] - logger.info('Building project image: %s', ' '.join(command)) - result = sp.run(command, - stdout=sp.PIPE, - stderr=sp.STDOUT, - stdin=sp.DEVNULL, - check=False) - if result.returncode: - logger.error('Failed to build OSS-Fuzz image for %s:', project) - logger.error('return code %d: %s', result.returncode, result.stdout) - raise Exception('Failed to build OSS-Fuzz image for {project}') - logger.info('Done building image.') - - -def _copy_project_src(project: str, - out: str, - cloud_experiment_bucket: str = '', - language: str = 'c++'): - """Copies /|src| from cloud if bucket is available or from local image.""" - if cloud_experiment_bucket: - logger.info( - 'Retrieving human-written fuzz targets of %s from Google Cloud Build.', - project) - bucket_dirname = _build_project_on_cloud(project, cloud_experiment_bucket) - _copy_project_src_from_cloud(bucket_dirname, out, cloud_experiment_bucket) - else: - logger.info( - 'Retrieving human-written fuzz targets of %s from local Docker build.', - project) - _build_project_local_docker(project) - _copy_project_src_from_local(project, out, language) + """Builds the project with OSS-Fuzz.""" + helper_path = os.path.join(oss_fuzz_checkout.OSS_FUZZ_DIR, "infra", "helper.py") + command = ["python3", helper_path, "build_image", "--cache", "--no-pull", project] + logger.info("Building project image: %s", " ".join(command)) + result = sp.run( + command, stdout=sp.PIPE, stderr=sp.STDOUT, stdin=sp.DEVNULL, check=False + ) + if result.returncode: + logger.error("Failed to build OSS-Fuzz image for %s:", project) + logger.error("return code %d: %s", result.returncode, result.stdout) + raise Exception("Failed to build OSS-Fuzz image for {project}") + logger.info("Done building image.") + + +def _copy_project_src( + project: str, out: str, cloud_experiment_bucket: str = "", language: str = "c++" +): + """Copies /|src| from cloud if bucket is available or from local image.""" + if cloud_experiment_bucket: + logger.info( + "Retrieving human-written fuzz targets of %s from Google Cloud Build.", + project, + ) + bucket_dirname = _build_project_on_cloud(project, cloud_experiment_bucket) + _copy_project_src_from_cloud(bucket_dirname, out, cloud_experiment_bucket) + else: + logger.info( + "Retrieving human-written fuzz targets of %s from local Docker build.", + project, + ) + _build_project_local_docker(project) + _copy_project_src_from_local(project, out, language) def _build_project_on_cloud(project: str, cloud_experiment_bucket: str) -> str: - """Builds project image on cloud and copies /src.""" - # project => cloud_experiment_name - uid = project + '-' + str(uuid.uuid4()) - search_regex = '-o'.join(f' -name "*{ext}" ' for ext in SEARCH_EXTS) - ignore_regex = ' '.join( - f'! 
-path "/src/{bad_dir}/*"' for bad_dir in SEARCH_IGNORE_DIRS) - cp_command = (f'find /src \\({search_regex}\\) {ignore_regex} ' - '-exec cp --parents {} /workspace/out/ \\;') - cloud_build_command = [ - f'./{oss_fuzz_checkout.VENV_DIR}/bin/python3', - 'infra/build/functions/project_experiment.py', - f'--project={project}', - f'--command={cp_command}', - f"--upload_output=gs://{cloud_experiment_bucket}/{uid}", - f'--experiment_name={uid}', - ] - cloud_build_result = sp.run(cloud_build_command, - capture_output=True, - stdin=sp.DEVNULL, - text=True, - check=False, - cwd=oss_fuzz_checkout.OSS_FUZZ_DIR) - if (cloud_build_result.returncode or - 'failed: step exited with non-zero status' in cloud_build_result.stdout): - logger.error('Failed to upload /src/ in OSS-Fuzz image of %s:', project) - logger.error('STDOUT: %s', cloud_build_result.stdout) - logger.error('STDERR: %s', cloud_build_result.stderr) - raise Exception( - f'Failed to run cloud build command: {" ".join(cloud_build_command)}') - - return uid - - -def _copy_project_src_from_cloud(bucket_dirname: str, out: str, - cloud_experiment_bucket: str): - """Copies /src from |bucket_dirname|.""" - storage_client = storage.Client() - bucket = storage_client.bucket(cloud_experiment_bucket) - blobs = bucket.list_blobs(prefix=bucket_dirname) - # Download each file in the directory - for blob in blobs: - # Ignore directories - if blob.name.endswith('/'): - continue - # Create a local path that mirrors the structure in the bucket. - relative_path = blob.name[len(bucket_dirname) + 1:] - local_file_path = os.path.join(out, 'src', relative_path) - # Create local directories if they don't exist - local_dir = os.path.dirname(local_file_path) - os.makedirs(local_dir, exist_ok=True) - - # Download the file - blob.download_to_filename(local_file_path) - logger.info('Downloaded %s to %s', blob.name, local_file_path) - blob.delete() - logger.info('Deleted %s from the bucket.', blob.name) + """Builds project image on cloud and copies /src.""" + # project => cloud_experiment_name + uid = project + "-" + str(uuid.uuid4()) + search_regex = "-o".join(f' -name "*{ext}" ' for ext in SEARCH_EXTS) + ignore_regex = " ".join( + f'! 
-path "/src/{bad_dir}/*"' for bad_dir in SEARCH_IGNORE_DIRS + ) + cp_command = ( + f"find /src \\({search_regex}\\) {ignore_regex} " + "-exec cp --parents {} /workspace/out/ \\;" + ) + cloud_build_command = [ + f"./{oss_fuzz_checkout.VENV_DIR}/bin/python3", + "infra/build/functions/project_experiment.py", + f"--project={project}", + f"--command={cp_command}", + f"--upload_output=gs://{cloud_experiment_bucket}/{uid}", + f"--experiment_name={uid}", + ] + cloud_build_result = sp.run( + cloud_build_command, + capture_output=True, + stdin=sp.DEVNULL, + text=True, + check=False, + cwd=oss_fuzz_checkout.OSS_FUZZ_DIR, + ) + if ( + cloud_build_result.returncode + or "failed: step exited with non-zero status" in cloud_build_result.stdout + ): + logger.error("Failed to upload /src/ in OSS-Fuzz image of %s:", project) + logger.error("STDOUT: %s", cloud_build_result.stdout) + logger.error("STDERR: %s", cloud_build_result.stderr) + raise Exception( + f'Failed to run cloud build command: {" ".join(cloud_build_command)}' + ) + + return uid + + +def _copy_project_src_from_cloud( + bucket_dirname: str, out: str, cloud_experiment_bucket: str +): + """Copies /src from |bucket_dirname|.""" + storage_client = storage.Client() + bucket = storage_client.bucket(cloud_experiment_bucket) + blobs = bucket.list_blobs(prefix=bucket_dirname) + # Download each file in the directory + for blob in blobs: + # Ignore directories + if blob.name.endswith("/"): + continue + # Create a local path that mirrors the structure in the bucket. + relative_path = blob.name[len(bucket_dirname) + 1 :] + local_file_path = os.path.join(out, "src", relative_path) + # Create local directories if they don't exist + local_dir = os.path.dirname(local_file_path) + os.makedirs(local_dir, exist_ok=True) + + # Download the file + blob.download_to_filename(local_file_path) + logger.info("Downloaded %s to %s", blob.name, local_file_path) + blob.delete() + logger.info("Deleted %s from the bucket.", blob.name) def _copy_project_src_from_local(project: str, out: str, language: str): - """Runs the project's OSS-Fuzz image to copy /|src| to /|out|.""" - timestamp = time.time() - run_container = [ - 'docker', - 'run', - '-d', - '--rm', - '--shm-size=2g', - '--platform', - 'linux/amd64', - '-e', - 'FUZZING_ENGINE=libfuzzer', - '-e' - 'SANITIZER=address', - '-e', - 'ARCHITECTURE=x86_64', - '-e', - f'PROJECT_NAME={project}', - '-e', - 'HELPER=True', - '-e', - f'FUZZING_LANGUAGE={language}', - '--name', - f'{project}-container-{timestamp}', - f'gcr.io/oss-fuzz/{project}', - ] - result = sp.run(run_container, - capture_output=True, - stdin=sp.DEVNULL, - check=False) - if result.returncode and 'Conflict' in str(result.stderr): - # When running in multi-threading environment, the timestamp suffix - # for the container name maybe the same and cause Conflict error. - # Sleep for random seconds from 1 to 30 and retry to avoid conflict. 
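# --- Editor's note: illustrative sketch only, not part of this change. ---
# Shape of the retry just below: two workers can pick the same timestamp-based
# container name, Docker reports a name conflict, so the code sleeps a random
# 1-30 s, regenerates the name, and tries once more. start_container() is a
# hypothetical callable standing in for the `docker run` invocation above; it
# is expected to return a subprocess.CompletedProcess-like result.
import random
import time


def run_with_conflict_retry(start_container):
    result = start_container(time.time())
    if result.returncode and "Conflict" in str(result.stderr):
        time.sleep(random.randint(1, 30))
        result = start_container(time.time())
    return result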
- logger.warning('Failed to run OSS-Fuzz on %s, retry in 1~30 sec', project) - time.sleep(random.randint(1, 30)) - - # Update timestamp suffix and re run + """Runs the project's OSS-Fuzz image to copy /|src| to /|out|.""" timestamp = time.time() - run_container[-2] = f'{project}-container-{timestamp}' - result = sp.run(run_container, - capture_output=True, - stdin=sp.DEVNULL, - check=False) - - if result.returncode: - # Still fail from conclict or other errors - logger.error('Failed to run OSS-Fuzz image of %s:', project) - logger.error('STDOUT: %s', result.stdout) - logger.error('STDERR: %s', result.stderr) - raise Exception(f'Failed to run docker command: {" ".join(run_container)}') - - try: - copy_src = ['docker', 'cp', f'{project}-container-{timestamp}:/src', out] - result = sp.run(copy_src, - capture_output=True, - stdin=sp.DEVNULL, - check=False) - if result.returncode: - logger.error('Failed to copy /src from OSS-Fuzz image of %s:', project) - logger.error('STDOUT: %s', result.stdout) - logger.error('STDERR: %s', result.stderr) - raise Exception(f'Failed to run docker command: {" ".join(copy_src)}') - logger.info('Done copying %s /src to %s.', project, out) - finally: - # Shut down the container that was just started. - result = sp.run( - ['docker', 'container', 'stop', f'{project}-container-{timestamp}'], - capture_output=True, - stdin=sp.DEVNULL, - check=False) + run_container = [ + "docker", + "run", + "-d", + "--rm", + "--shm-size=2g", + "--platform", + "linux/amd64", + "-e", + "FUZZING_ENGINE=libfuzzer", + "-e" "SANITIZER=address", + "-e", + "ARCHITECTURE=x86_64", + "-e", + f"PROJECT_NAME={project}", + "-e", + "HELPER=True", + "-e", + f"FUZZING_LANGUAGE={language}", + "--name", + f"{project}-container-{timestamp}", + f"gcr.io/oss-fuzz/{project}", + ] + result = sp.run(run_container, capture_output=True, stdin=sp.DEVNULL, check=False) + if result.returncode and "Conflict" in str(result.stderr): + # When running in multi-threading environment, the timestamp suffix + # for the container name maybe the same and cause Conflict error. + # Sleep for random seconds from 1 to 30 and retry to avoid conflict. + logger.warning("Failed to run OSS-Fuzz on %s, retry in 1~30 sec", project) + time.sleep(random.randint(1, 30)) + + # Update timestamp suffix and re run + timestamp = time.time() + run_container[-2] = f"{project}-container-{timestamp}" + result = sp.run( + run_container, capture_output=True, stdin=sp.DEVNULL, check=False + ) + if result.returncode: - logger.error('Failed to stop container image: %s-container', project) - logger.error('STDOUT: %s', result.stdout) - logger.error('STDERR: %s', result.stderr) - - -def _identify_fuzz_targets(out: str, interesting_filenames: list[str], - language: str) -> tuple[list[str], list[str]]: - """ - Identifies fuzz target file contents and |interesting_filenames| in |out|. - """ - logger.debug('len(interesting_filenames): %d', len(interesting_filenames)) - - interesting_filepaths = [] - potential_harnesses = [] - - for root, _, filenames in os.walk(out): - is_bad = False - for ignore_dir in SEARCH_IGNORE_DIRS: - # Exclude engine source. 
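# --- Editor's note: illustrative sketch only, not part of this change. ---
# The directory walk below skips anything under the bundled fuzzing-engine
# sources. is_engine_source() restates that check; the sample paths are made up.
SEARCH_IGNORE_DIRS = ["aflplusplus", "fuzztest", "honggfuzz", "libfuzzer"]


def is_engine_source(root: str) -> bool:
    return any(f"out/src/{ignore_dir}" in root for ignore_dir in SEARCH_IGNORE_DIRS)


assert is_engine_source("/tmp/exp/out/src/libfuzzer/FuzzerDriver.cpp")
assert not is_engine_source("/tmp/exp/out/src/myproject/fuzz")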
- if f'out/src/{ignore_dir}' in root: - is_bad = True - break - if is_bad: - continue - for filename in filenames: - if not benchmark.get_file_type(filename): - continue - path = os.path.join(root, filename) - if language == 'jvm': - # For JVM - if path.endswith(tuple(interesting_filenames)): - interesting_filepaths.append(path) - if path.endswith('.java'): - potential_harnesses.append(path) - elif language == 'python': - # For Python - if path.endswith(tuple(interesting_filenames)): - interesting_filepaths.append(path) - if path.endswith('.py'): - potential_harnesses.append(path) - elif language == 'rust': - # For Rust - if path.endswith(tuple(interesting_filenames)): - interesting_filepaths.append(path) - if path.endswith('.rs'): - potential_harnesses.append(path) - else: - # For C/C++ - short_path = path[len(out):] - if short_path in interesting_filenames: - interesting_filepaths.append(path) - # TODO(dongge): Figure out why the path does not match Bazel projects. - if os.path.basename(short_path) in interesting_filenames: - interesting_filepaths.append(path) - - if any(path.endswith(suffix) for suffix in SEARCH_EXTS): - potential_harnesses.append(path) - - return potential_harnesses, interesting_filepaths - - -def _parse_fuzz_targets(project: str, out: str, potential_harnesses: list[str], - interesting_filepaths: list[str], - language: str) -> tuple[dict[str, str], dict[str, str]]: - """ - Parses fuzz target file contents and |interesting_filenames| in |out|. - """ - interesting_files = {} - for src_file in interesting_filepaths: - short_path, content = _get_interesting_file(src_file, out) - if short_path == content == '': - continue - interesting_files[short_path] = content - - fuzz_targets = {} - for harness in potential_harnesses: - short_path, content = _get_harness(harness, out, language) - if short_path == content == '': - continue - fuzz_targets[short_path] = content - # Sometimes you will get /src/$DEPENDENCY/$FUZZER (e.g. /src/cJSON when - # fuzzing mosquitto). OSS-Fuzz is too popular. - pruned = {k: v for k, v in fuzz_targets.items() if project in k} - fuzz_targets = pruned or fuzz_targets - - return fuzz_targets, interesting_files + # Still fail from conclict or other errors + logger.error("Failed to run OSS-Fuzz image of %s:", project) + logger.error("STDOUT: %s", result.stdout) + logger.error("STDERR: %s", result.stderr) + raise Exception(f'Failed to run docker command: {" ".join(run_container)}') + + try: + copy_src = ["docker", "cp", f"{project}-container-{timestamp}:/src", out] + result = sp.run(copy_src, capture_output=True, stdin=sp.DEVNULL, check=False) + if result.returncode: + logger.error("Failed to copy /src from OSS-Fuzz image of %s:", project) + logger.error("STDOUT: %s", result.stdout) + logger.error("STDERR: %s", result.stderr) + raise Exception(f'Failed to run docker command: {" ".join(copy_src)}') + logger.info("Done copying %s /src to %s.", project, out) + finally: + # Shut down the container that was just started. + result = sp.run( + ["docker", "container", "stop", f"{project}-container-{timestamp}"], + capture_output=True, + stdin=sp.DEVNULL, + check=False, + ) + if result.returncode: + logger.error("Failed to stop container image: %s-container", project) + logger.error("STDOUT: %s", result.stdout) + logger.error("STDERR: %s", result.stderr) + + +def _identify_fuzz_targets( + out: str, interesting_filenames: list[str], language: str +) -> tuple[list[str], list[str]]: + """ + Identifies fuzz target file contents and |interesting_filenames| in |out|. 
+ """ + logger.debug("len(interesting_filenames): %d", len(interesting_filenames)) + + interesting_filepaths = [] + potential_harnesses = [] + + for root, _, filenames in os.walk(out): + is_bad = False + for ignore_dir in SEARCH_IGNORE_DIRS: + # Exclude engine source. + if f"out/src/{ignore_dir}" in root: + is_bad = True + break + if is_bad: + continue + for filename in filenames: + if not benchmark.get_file_type(filename): + continue + path = os.path.join(root, filename) + if language == "jvm": + # For JVM + if path.endswith(tuple(interesting_filenames)): + interesting_filepaths.append(path) + if path.endswith(".java"): + potential_harnesses.append(path) + elif language == "python": + # For Python + if path.endswith(tuple(interesting_filenames)): + interesting_filepaths.append(path) + if path.endswith(".py"): + potential_harnesses.append(path) + elif language == "rust": + # For Rust + if path.endswith(tuple(interesting_filenames)): + interesting_filepaths.append(path) + if path.endswith(".rs"): + potential_harnesses.append(path) + else: + # For C/C++ + short_path = path[len(out) :] + if short_path in interesting_filenames: + interesting_filepaths.append(path) + # TODO(dongge): Figure out why the path does not match Bazel projects. + if os.path.basename(short_path) in interesting_filenames: + interesting_filepaths.append(path) + + if any(path.endswith(suffix) for suffix in SEARCH_EXTS): + potential_harnesses.append(path) + + return potential_harnesses, interesting_filepaths + + +def _parse_fuzz_targets( + project: str, + out: str, + potential_harnesses: list[str], + interesting_filepaths: list[str], + language: str, +) -> tuple[dict[str, str], dict[str, str]]: + """ + Parses fuzz target file contents and |interesting_filenames| in |out|. + """ + interesting_files = {} + for src_file in interesting_filepaths: + short_path, content = _get_interesting_file(src_file, out) + if short_path == content == "": + continue + interesting_files[short_path] = content + + fuzz_targets = {} + for harness in potential_harnesses: + short_path, content = _get_harness(harness, out, language) + if short_path == content == "": + continue + fuzz_targets[short_path] = content + # Sometimes you will get /src/$DEPENDENCY/$FUZZER (e.g. /src/cJSON when + # fuzzing mosquitto). OSS-Fuzz is too popular. 
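# --- Editor's note: illustrative sketch only, not part of this change. ---
# The pruning just below keeps only targets whose path mentions the project
# and falls back to the full set when nothing matches. prune_to_project() and
# the sample paths are hypothetical.
def prune_to_project(fuzz_targets: dict[str, str], project: str) -> dict[str, str]:
    pruned = {key: value for key, value in fuzz_targets.items() if project in key}
    return pruned or fuzz_targets


targets = {"/src/cJSON/fuzz_parse.c": "...", "/src/mosquitto/fuzz_read.c": "..."}
assert list(prune_to_project(targets, "mosquitto")) == ["/src/mosquitto/fuzz_read.c"]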
+ pruned = {k: v for k, v in fuzz_targets.items() if project in k} + fuzz_targets = pruned or fuzz_targets + + return fuzz_targets, interesting_files def _copy_fuzz_targets(harness_path: str, dest_dir: str, project: str): - """Copies the harness from |harness_path| to ./|dest_dir|/|project|/.""" - if not dest_dir: - return - dest_dir = os.path.join(dest_dir, project) - os.makedirs(dest_dir, exist_ok=True) - command = ['cp', harness_path, dest_dir] - result = sp.run(command, capture_output=True, stdin=sp.DEVNULL, check=True) - if result.returncode: - logger.error('Failed to copy harness from %s to %s: %s %s.', harness_path, - dest_dir, result.stdout, result.stderr) - raise Exception(f'Failed to copy harness from {harness_path} to {dest_dir}', - harness_path, dest_dir) - - logger.info('Retrieved fuzz targets from %s:\n %s', project, - '\n '.join(os.listdir(dest_dir))) + """Copies the harness from |harness_path| to ./|dest_dir|/|project|/.""" + if not dest_dir: + return + dest_dir = os.path.join(dest_dir, project) + os.makedirs(dest_dir, exist_ok=True) + command = ["cp", harness_path, dest_dir] + result = sp.run(command, capture_output=True, stdin=sp.DEVNULL, check=True) + if result.returncode: + logger.error( + "Failed to copy harness from %s to %s: %s %s.", + harness_path, + dest_dir, + result.stdout, + result.stderr, + ) + raise Exception( + f"Failed to copy harness from {harness_path} to {dest_dir}", + harness_path, + dest_dir, + ) + + logger.info( + "Retrieved fuzz targets from %s:\n %s", + project, + "\n ".join(os.listdir(dest_dir)), + ) def search_source( project: str, interesting_filenames: list, language: str, - result_dir: str = '', - cloud_experiment_bucket: str = '', + result_dir: str = "", + cloud_experiment_bucket: str = "", ) -> tuple[Dict[str, str], Dict[str, str]]: - """Searches source code of the target OSS-Fuzz project for the files listed + """Searches source code of the target OSS-Fuzz project for the files listed in |interesting_filenames|. 
Returns a dictionary of fuzz targets (path: contents) and a dictionary of interesting files to their contents.""" - with tempfile.TemporaryDirectory() as temp_dir: - out = os.path.join(temp_dir, 'out') - os.makedirs(out) - - _copy_project_src(project, out, cloud_experiment_bucket, language) - - potential_harnesses, interesting_filepaths = _identify_fuzz_targets( - out, interesting_filenames, language) - fuzz_targets, interesting_files = _parse_fuzz_targets( - project, out, potential_harnesses, interesting_filepaths, language) - - for short_path in fuzz_targets.keys(): - _copy_fuzz_targets(os.path.join(out, short_path[1:]), result_dir, project) - return fuzz_targets, interesting_files + with tempfile.TemporaryDirectory() as temp_dir: + out = os.path.join(temp_dir, "out") + os.makedirs(out) + + _copy_project_src(project, out, cloud_experiment_bucket, language) + + potential_harnesses, interesting_filepaths = _identify_fuzz_targets( + out, interesting_filenames, language + ) + fuzz_targets, interesting_files = _parse_fuzz_targets( + project, out, potential_harnesses, interesting_filepaths, language + ) + + for short_path in fuzz_targets.keys(): + _copy_fuzz_targets(os.path.join(out, short_path[1:]), result_dir, project) + return fuzz_targets, interesting_files diff --git a/data_prep/project_targets.py b/data_prep/project_targets.py index f5b67246b8..1d8826f47f 100755 --- a/data_prep/project_targets.py +++ b/data_prep/project_targets.py @@ -33,342 +33,356 @@ logger = logging.getLogger(__name__) -OSS_FUZZ_EXP_BUCKET = 'oss-fuzz-llm-public' +OSS_FUZZ_EXP_BUCKET = "oss-fuzz-llm-public" # TODO(dongge): Use tmp dir. -OSS_FUZZ_PATH = os.path.join(os.path.dirname(__file__), '..', 'oss-fuzz') +OSS_FUZZ_PATH = os.path.join(os.path.dirname(__file__), "..", "oss-fuzz") def _get_fuzz_target_dir(project_name: str) -> str: - """Returns the directory that contains the fuzz targets of |project_name|. 
- """ - data_dir = os.path.abspath( - os.path.join(os.path.dirname(__file__), '..', 'oss-fuzz-data')) - fuzz_target_dir = os.path.join(data_dir, 'fuzz_targets') - os.makedirs(fuzz_target_dir, exist_ok=True) - - project_fuzz_target_dir = os.path.join(fuzz_target_dir, project_name) - os.makedirs(project_fuzz_target_dir, exist_ok=True) - - storage_client = storage.Client.create_anonymous_client() - bucket = storage_client.bucket(OSS_FUZZ_EXP_BUCKET) - project_prefix = os.path.join('human_written_targets', project_name) - blobs = bucket.list_blobs(prefix=project_prefix) - for blob in blobs: - file_relpath = blob.name.replace(f'{project_prefix}/', '') - filedir = os.path.dirname( - os.path.join(project_fuzz_target_dir, file_relpath)) - os.makedirs(filedir, exist_ok=True) - blob.download_to_filename( - os.path.join(project_fuzz_target_dir, file_relpath)) - - return project_fuzz_target_dir - - -def _match_target_path_content(target_paths: List[str], - fuzz_target_dir: str) -> Dict[str, str]: - """Returns a dictionary with |target_paths| as keys and its file content - from |fuzz_target_dir| as values.""" - path_contents = {} - # Walk through the directory - for dirpath, _, filenames in os.walk(fuzz_target_dir): - for filename in filenames: - # Compute the relative file path - relative_path = os.path.relpath(os.path.join(dirpath, filename), - fuzz_target_dir) - for target_path in target_paths: - if os.path.basename(target_path) != os.path.basename(relative_path): - continue - - file_path = os.path.join(fuzz_target_dir, relative_path) - with open(file_path) as file: - content = file.read() - path_contents[target_path] = filter_target_lines(content) - - return path_contents + """Returns the directory that contains the fuzz targets of |project_name|.""" + data_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "oss-fuzz-data") + ) + fuzz_target_dir = os.path.join(data_dir, "fuzz_targets") + os.makedirs(fuzz_target_dir, exist_ok=True) + + project_fuzz_target_dir = os.path.join(fuzz_target_dir, project_name) + os.makedirs(project_fuzz_target_dir, exist_ok=True) + + storage_client = storage.Client.create_anonymous_client() + bucket = storage_client.bucket(OSS_FUZZ_EXP_BUCKET) + project_prefix = os.path.join("human_written_targets", project_name) + blobs = bucket.list_blobs(prefix=project_prefix) + for blob in blobs: + file_relpath = blob.name.replace(f"{project_prefix}/", "") + filedir = os.path.dirname(os.path.join(project_fuzz_target_dir, file_relpath)) + os.makedirs(filedir, exist_ok=True) + blob.download_to_filename(os.path.join(project_fuzz_target_dir, file_relpath)) + + return project_fuzz_target_dir + + +def _match_target_path_content( + target_paths: List[str], fuzz_target_dir: str +) -> Dict[str, str]: + """Returns a dictionary with |target_paths| as keys and its file content + from |fuzz_target_dir| as values.""" + path_contents = {} + # Walk through the directory + for dirpath, _, filenames in os.walk(fuzz_target_dir): + for filename in filenames: + # Compute the relative file path + relative_path = os.path.relpath( + os.path.join(dirpath, filename), fuzz_target_dir + ) + for target_path in target_paths: + if os.path.basename(target_path) != os.path.basename(relative_path): + continue + + file_path = os.path.join(fuzz_target_dir, relative_path) + with open(file_path) as file: + content = file.read() + path_contents[target_path] = filter_target_lines(content) + + return path_contents def _bucket_match_target_content_signatures( - target_funcs: Dict[str, List[Dict]], 
fuzz_target_dir: str, - project_name: str) -> Dict[str, List[str]]: - """Returns a list of dictionary with function signatures as keys and + target_funcs: Dict[str, List[Dict]], fuzz_target_dir: str, project_name: str +) -> Dict[str, List[str]]: + """Returns a list of dictionary with function signatures as keys and its fuzz target content as values.""" - if not target_funcs: - logger.info('Error: No fuzz target functions available.') - return {} - if not os.path.isdir(fuzz_target_dir): - logger.info( - 'Error: Fuzz target directory does not exist ({fuzz_target_dir})') - return {} - - target_path_contents = _match_target_path_content(list(target_funcs.keys()), - fuzz_target_dir) - target_content_signature_dict = {} - for target_path, functions in target_funcs.items(): - content = target_path_contents.get(target_path) - # Some projects' `target_path` is different from the actual - # path in container, due to relocation in build process. - # E.g., target_path is /src/hiredis/format_command_fuzzer.c, different - # from the actual path /src/hiredis/fuzzing/format_command_fuzzer.c in - # https://storage.googleapis.com/oss-fuzz-introspector/hiredis/inspector-report/20240120/summary.json - if not content: - adjusted_target_paths = [ - t_path for t_path in target_path_contents - if os.path.basename(t_path) == os.path.basename(target_path) - ] - if adjusted_target_paths: - adjusted_target_path = adjusted_target_paths[0] - content = target_path_contents.get(adjusted_target_path) - if not content: - return {} - if content not in target_content_signature_dict: - target_content_signature_dict[content] = [] - - signatures = [ - introspector.query_introspector_function_signature( + if not target_funcs: + logger.info("Error: No fuzz target functions available.") + return {} + if not os.path.isdir(fuzz_target_dir): + logger.info("Error: Fuzz target directory does not exist ({fuzz_target_dir})") + return {} + + target_path_contents = _match_target_path_content( + list(target_funcs.keys()), fuzz_target_dir + ) + target_content_signature_dict = {} + for target_path, functions in target_funcs.items(): + content = target_path_contents.get(target_path) + # Some projects' `target_path` is different from the actual + # path in container, due to relocation in build process. 
+ # E.g., target_path is /src/hiredis/format_command_fuzzer.c, different + # from the actual path /src/hiredis/fuzzing/format_command_fuzzer.c in + # https://storage.googleapis.com/oss-fuzz-introspector/hiredis/inspector-report/20240120/summary.json + if not content: + adjusted_target_paths = [ + t_path + for t_path in target_path_contents + if os.path.basename(t_path) == os.path.basename(target_path) + ] + if adjusted_target_paths: + adjusted_target_path = adjusted_target_paths[0] + content = target_path_contents.get(adjusted_target_path) + if not content: + return {} + if content not in target_content_signature_dict: + target_content_signature_dict[content] = [] + + signatures = [ + introspector.query_introspector_function_signature( + project_name, + introspector.get_raw_function_name(func_info, project_name), + ) + for func_info in functions + ] + target_content_signature_dict[content].extend(signatures) + + return target_content_signature_dict + + +def generate_data( + project_name: str, + language: str, + sig_per_target: int = 1, + max_samples: int = 1, + cloud_experiment_bucket: str = "", +): + """Generates project-specific fuzz targets examples.""" + target_funcs = introspector.get_project_funcs(project_name) + project_fuzz_target_dir = _get_fuzz_target_dir(project_name) + target_content_signature_dict = _bucket_match_target_content_signatures( + target_funcs, project_fuzz_target_dir, project_name + ) + + if target_content_signature_dict: + logger.info( + "Downloaded human-written fuzz targets of %s from Google Cloud Bucket: " + "%s", project_name, - introspector.get_raw_function_name(func_info, project_name)) - for func_info in functions - ] - target_content_signature_dict[content].extend(signatures) - - return target_content_signature_dict - - -def generate_data(project_name: str, - language: str, - sig_per_target: int = 1, - max_samples: int = 1, - cloud_experiment_bucket: str = ''): - """Generates project-specific fuzz targets examples.""" - target_funcs = introspector.get_project_funcs(project_name) - project_fuzz_target_dir = _get_fuzz_target_dir(project_name) - target_content_signature_dict = _bucket_match_target_content_signatures( - target_funcs, project_fuzz_target_dir, project_name) - - if target_content_signature_dict: - logger.info( - 'Downloaded human-written fuzz targets of %s from Google Cloud Bucket: ' - '%s', project_name, OSS_FUZZ_EXP_BUCKET) - else: - logger.info( - 'Failed to download human-written fuzz target of %s from Google Cloud ' - 'Bucket: %s.', project_name, OSS_FUZZ_EXP_BUCKET) - logger.info('Will try to build from Google Cloud or local docker image.') - target_content_signature_dict = _match_target_content_signatures( - target_funcs, project_name, language, cloud_experiment_bucket) - if not target_content_signature_dict: - return [] - - # Ensures the most complex fuzz target is always at the end. 
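# NOTE: illustrative sketch, not part of the original change. Sorting the target
# contents by length (shortest first) and slicing from the end keeps the longest,
# i.e. presumably most complex, fuzz target among the returned samples:
contents_demo = sorted(["short target", "a much longer fuzz target body"], key=len)
assert contents_demo[-1] == "a much longer fuzz target body"
# With max_samples=1, the final sig_contents[-1:] slice keeps only the pair built
# from that longest entry.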
- contents = sorted(target_content_signature_dict.keys(), key=len) - sig_contents = [] - for i in range(sig_per_target): - for content in contents: - sigs = target_content_signature_dict.get(content, []) - if i >= len(sigs): - continue - sig_contents.append([sigs[i], content]) - - return sig_contents[-max_samples:] + OSS_FUZZ_EXP_BUCKET, + ) + else: + logger.info( + "Failed to download human-written fuzz target of %s from Google Cloud " + "Bucket: %s.", + project_name, + OSS_FUZZ_EXP_BUCKET, + ) + logger.info("Will try to build from Google Cloud or local docker image.") + target_content_signature_dict = _match_target_content_signatures( + target_funcs, project_name, language, cloud_experiment_bucket + ) + if not target_content_signature_dict: + return [] + + # Ensures the most complex fuzz target is always at the end. + contents = sorted(target_content_signature_dict.keys(), key=len) + sig_contents = [] + for i in range(sig_per_target): + for content in contents: + sigs = target_content_signature_dict.get(content, []) + if i >= len(sigs): + continue + sig_contents.append([sigs[i], content]) + + return sig_contents[-max_samples:] def _remove_header_comments(code: str) -> str: - """Removes comments and empty lines in the code.""" - # Remove multi-line comments. - multi_line_comment = re.compile(r'/\*.*?\*/', re.DOTALL) - code = re.sub(multi_line_comment, '', code) + """Removes comments and empty lines in the code.""" + # Remove multi-line comments. + multi_line_comment = re.compile(r"/\*.*?\*/", re.DOTALL) + code = re.sub(multi_line_comment, "", code) - # Remove single-line comments. - single_line_comment = re.compile(r'(?:^|\s+)//.*\n') - code = re.sub(single_line_comment, '\n', code) + # Remove single-line comments. + single_line_comment = re.compile(r"(?:^|\s+)//.*\n") + code = re.sub(single_line_comment, "\n", code) - # Remove empty lines. - empty_line = re.compile(r'\n+\s*\n+') - code = re.sub(empty_line, '\n', code) + # Remove empty lines. + empty_line = re.compile(r"\n+\s*\n+") + code = re.sub(empty_line, "\n", code) - # Trim all newlines and spaces. - code.lstrip('\n ') - code.rstrip('\n ') - return code + # Trim all newlines and spaces. + code.lstrip("\n ") + code.rstrip("\n ") + return code def _remove_header(code: str) -> str: - """Removes header comments (e.g. copyright) only before the first #include. - """ - # Split the code at the first #include. - parts = code.split('#include', 1) - header = parts[0] - content = '#include' + parts[1] if len(parts) > 1 else '' - return _remove_header_comments(header) + content + """Removes header comments (e.g. copyright) only before the first #include.""" + # Split the code at the first #include. 
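# NOTE: annotation, not part of the original change. In _remove_header_comments
# above, str.lstrip()/str.rstrip() return new strings, so the two unassigned
# calls do not modify `code`; this quirk exists in both the old and the new
# version of this formatting-only patch. _remove_header itself only cleans the
# region before the first #include, e.g. for a hypothetical input:
example_source = "/* Copyright 2024 The Authors */\n#include <stdio.h>\nint main(void) { return 0; }\n"
# _remove_header(example_source) keeps everything from "#include" onwards and
# drops the license comment that precedes it.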
+ parts = code.split("#include", 1) + header = parts[0] + content = "#include" + parts[1] if len(parts) > 1 else "" + return _remove_header_comments(header) + content def filter_target_lines(target_content: str) -> str: - """Remove non-interesting lines in the target_content.""" - target_content = _remove_header(target_content) - return target_content + """Remove non-interesting lines in the target_content.""" + target_content = _remove_header(target_content) + return target_content def _match_target_content_signatures( target_funcs: Dict[str, List[Dict]], project_name: str, language: str, - cloud_experiment_bucket: str = '') -> Dict[str, List[str]]: - """Returns a list of dictionary with function signatures as keys and + cloud_experiment_bucket: str = "", +) -> Dict[str, List[str]]: + """Returns a list of dictionary with function signatures as keys and its fuzz target content as values.""" - if not target_funcs: - logger.info('Error: No fuzz target functions available.') - return {} - - source_content = project_src.search_source( - project_name, [], - language, - cloud_experiment_bucket=cloud_experiment_bucket) - - if not source_content[0]: - logger.info('Error: No fuzz target found for project %s', project_name) - return {} - - target_path_contents = source_content[0] - - target_content_signature_dict = {} - for target_path, functions in target_funcs.items(): - content = target_path_contents.get(target_path) - # Some projects' `target_path` is different from the actual - # path in container, due to relocation in build process. - # E.g., target_path is /src/hiredis/format_command_fuzzer.c, - # from the actual path /src/hiredis/fuzzing/format_command_fuzzer.c in - # https://storage.googleapis.com/oss-fuzz-introspector/hiredis/inspector-report/20240120/summary.json - if not content: - adjusted_target_paths = [ - t_path for t_path in target_path_contents - if os.path.basename(t_path) == os.path.basename(target_path) - ] - if adjusted_target_paths: - adjusted_target_path = adjusted_target_paths[0] - content = target_path_contents.get(adjusted_target_path) - if not content: - return {} - if content not in target_content_signature_dict: - target_content_signature_dict[content] = [] - - signatures = [ - introspector.query_introspector_function_signature( - project_name, - introspector.get_raw_function_name(func_info, project_name)) - for func_info in functions - ] - target_content_signature_dict[content].extend(signatures) - - return target_content_signature_dict + if not target_funcs: + logger.info("Error: No fuzz target functions available.") + return {} + + source_content = project_src.search_source( + project_name, [], language, cloud_experiment_bucket=cloud_experiment_bucket + ) + + if not source_content[0]: + logger.info("Error: No fuzz target found for project %s", project_name) + return {} + + target_path_contents = source_content[0] + + target_content_signature_dict = {} + for target_path, functions in target_funcs.items(): + content = target_path_contents.get(target_path) + # Some projects' `target_path` is different from the actual + # path in container, due to relocation in build process. 
+ # E.g., target_path is /src/hiredis/format_command_fuzzer.c, + # from the actual path /src/hiredis/fuzzing/format_command_fuzzer.c in + # https://storage.googleapis.com/oss-fuzz-introspector/hiredis/inspector-report/20240120/summary.json + if not content: + adjusted_target_paths = [ + t_path + for t_path in target_path_contents + if os.path.basename(t_path) == os.path.basename(target_path) + ] + if adjusted_target_paths: + adjusted_target_path = adjusted_target_paths[0] + content = target_path_contents.get(adjusted_target_path) + if not content: + return {} + if content not in target_content_signature_dict: + target_content_signature_dict[content] = [] + + signatures = [ + introspector.query_introspector_function_signature( + project_name, + introspector.get_raw_function_name(func_info, project_name), + ) + for func_info in functions + ] + target_content_signature_dict[content].extend(signatures) + + return target_content_signature_dict def _parse_arguments(): - """Parses command line args.""" - parser = argparse.ArgumentParser( - description='Parse project-related arguments') - - # project_name argument - parser.add_argument('-p', - '--project-name', - type=str, - required=True, - help='Name of the project') - - # result_path argument - parser.add_argument('-r', - '--result-path', - type=str, - help='Path to store the results') - - # number of signatures per target argument - parser.add_argument( - '-n', - '--num-signature-per-target', - type=int, - default=1, - help='Number of signatures per fuzz target (default is 1 if unspecified)') - - # maximum number of samples per project argument - parser.add_argument( - '-m', - '--max-samples', - type=int, - default=0, - help='Maximum number of samples per project (default is 0 if unspecified)' - ) - - # number of threads argument - parser.add_argument('-t', - '--num-threads', - type=int, - default=4, - help='Number of threads to use') - - parser.add_argument('-cb', - '--cloud-experiment-bucket', - type=str, - default='', - help='A gcloud bucket to store experiment files.') - - parser.add_argument('-l', - '--language', - type=str, - default='c++', - help='Language of projects.') - - parsed_args = parser.parse_args() - if not parsed_args.result_path: - parsed_args.result_path = f'{parsed_args.project_name}.json' - return parsed_args - - -def _generate_project_training_data(project_name: str, - sig_per_target, - max_samples, - language, - cloud_experiment_bucket: str = ''): - """Generate project training data.""" - try: - return generate_data(project_name, language, sig_per_target, max_samples, - cloud_experiment_bucket) - except Exception as e: - logger.info('Project %s failed:\n%s', project_name, e) - return None + """Parses command line args.""" + parser = argparse.ArgumentParser(description="Parse project-related arguments") + + # project_name argument + parser.add_argument( + "-p", "--project-name", type=str, required=True, help="Name of the project" + ) + + # result_path argument + parser.add_argument( + "-r", "--result-path", type=str, help="Path to store the results" + ) + + # number of signatures per target argument + parser.add_argument( + "-n", + "--num-signature-per-target", + type=int, + default=1, + help="Number of signatures per fuzz target (default is 1 if unspecified)", + ) + + # maximum number of samples per project argument + parser.add_argument( + "-m", + "--max-samples", + type=int, + default=0, + help="Maximum number of samples per project (default is 0 if unspecified)", + ) + + # number of threads argument + 
parser.add_argument( + "-t", "--num-threads", type=int, default=4, help="Number of threads to use" + ) + + parser.add_argument( + "-cb", + "--cloud-experiment-bucket", + type=str, + default="", + help="A gcloud bucket to store experiment files.", + ) + + parser.add_argument( + "-l", "--language", type=str, default="c++", help="Language of projects." + ) + + parsed_args = parser.parse_args() + if not parsed_args.result_path: + parsed_args.result_path = f"{parsed_args.project_name}.json" + return parsed_args + + +def _generate_project_training_data( + project_name: str, + sig_per_target, + max_samples, + language, + cloud_experiment_bucket: str = "", +): + """Generate project training data.""" + try: + return generate_data( + project_name, language, sig_per_target, max_samples, cloud_experiment_bucket + ) + except Exception as e: + logger.info("Project %s failed:\n%s", project_name, e) + return None def main(): - args = _parse_arguments() - project_name = args.project_name - result_path = args.result_path - sig_per_target = args.num_signature_per_target - max_samples = args.max_samples - num_threads = args.num_threads - - all_projects = [] - if project_name == 'all': - all_projects = oss_fuzz_checkout.list_c_cpp_projects() - else: - all_projects = [project_name] - - training_data = [] - configs = [[ - project, - sig_per_target, - max_samples, - args.language, - args.cloud_experiment_bucket, - ] for project in all_projects] - with ThreadPool(num_threads) as p: - for data in p.starmap(_generate_project_training_data, configs): - if data is None: - continue - training_data.extend(data) - - result_name, result_ext = os.path.splitext(result_path) - result_path = f'{result_name}_{len(training_data)}{result_ext}' - with open(result_path, 'w+') as file: - json.dump(training_data, file, indent=4) - - -if __name__ == '__main__': - sys.exit(main()) + args = _parse_arguments() + project_name = args.project_name + result_path = args.result_path + sig_per_target = args.num_signature_per_target + max_samples = args.max_samples + num_threads = args.num_threads + + all_projects = [] + if project_name == "all": + all_projects = oss_fuzz_checkout.list_c_cpp_projects() + else: + all_projects = [project_name] + + training_data = [] + configs = [ + [ + project, + sig_per_target, + max_samples, + args.language, + args.cloud_experiment_bucket, + ] + for project in all_projects + ] + with ThreadPool(num_threads) as p: + for data in p.starmap(_generate_project_training_data, configs): + if data is None: + continue + training_data.extend(data) + + result_name, result_ext = os.path.splitext(result_path) + result_path = f"{result_name}_{len(training_data)}{result_ext}" + with open(result_path, "w+") as file: + json.dump(training_data, file, indent=4) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/data_prep/target_collector.py b/data_prep/target_collector.py index 15df4424cd..155851018d 100755 --- a/data_prep/target_collector.py +++ b/data_prep/target_collector.py @@ -32,39 +32,42 @@ def _extract_introspector_report(project_name, date_str): - project_url = ('https://storage.googleapis.com/oss-fuzz-introspector/' - f'{project_name}/inspector-report/{date_str}/summary.json') - # Read the introspector artifact. 
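# NOTE: illustrative sketch with hypothetical values, not part of the original
# change. _get_targets below only relies on the downloaded summary.json having
# roughly this shape (field names taken from this file):
example_report = {
    "analyses": {
        "AnnotatedCFG": {
            "fuzzer_a": {"src_file": "/src/project/fuzz_a.c"},
            "fuzzer_b": {"src_file": "/src/project/fuzz_b.c"},
        }
    }
}
annotated_cfg_demo = example_report["analyses"]["AnnotatedCFG"]
assert {annotated_cfg_demo[f]["src_file"] for f in annotated_cfg_demo} == {
    "/src/project/fuzz_a.c",
    "/src/project/fuzz_b.c",
}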
- try: - raw_introspector_json_request = requests.get(project_url, timeout=10) - introspector_report = json.loads(raw_introspector_json_request.text) - except: - return None - return introspector_report + project_url = ( + "https://storage.googleapis.com/oss-fuzz-introspector/" + f"{project_name}/inspector-report/{date_str}/summary.json" + ) + # Read the introspector artifact. + try: + raw_introspector_json_request = requests.get(project_url, timeout=10) + introspector_report = json.loads(raw_introspector_json_request.text) + except: + return None + return introspector_report def _get_targets(project_name: str) -> Set[str]: - """Fetches the latest fuzz targets and function signatures of |project_name| + """Fetches the latest fuzz targets and function signatures of |project_name| from FuzzIntrospector.""" - yesterday = datetime.date.today() - datetime.timedelta(days=2) - introspector_json_report = _extract_introspector_report( - project_name, yesterday.strftime('%Y%m%d')) - if introspector_json_report is None: - logger.info('Error: No fuzz introspector report is found.') - return set() + yesterday = datetime.date.today() - datetime.timedelta(days=2) + introspector_json_report = _extract_introspector_report( + project_name, yesterday.strftime("%Y%m%d") + ) + if introspector_json_report is None: + logger.info("Error: No fuzz introspector report is found.") + return set() - annotated_cfg = introspector_json_report['analyses']['AnnotatedCFG'] - return set(annotated_cfg[fuzzer]['src_file'] for fuzzer in annotated_cfg) + annotated_cfg = introspector_json_report["analyses"]["AnnotatedCFG"] + return set(annotated_cfg[fuzzer]["src_file"] for fuzzer in annotated_cfg) def main() -> None: - """Installs tools, gets signatures, and writes them to the result file.""" - project_name = sys.argv[1] - targets = _get_targets(project_name) - os.makedirs(f'/work/out/{project_name}', exist_ok=True) - for target in targets: - shutil.copy(target, f'/work/out/{project_name}') + """Installs tools, gets signatures, and writes them to the result file.""" + project_name = sys.argv[1] + targets = _get_targets(project_name) + os.makedirs(f"/work/out/{project_name}", exist_ok=True) + for target in targets: + shutil.copy(target, f"/work/out/{project_name}") -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/experiment/benchmark.py b/experiment/benchmark.py index cc74d20b09..b94a5fe991 100644 --- a/experiment/benchmark.py +++ b/experiment/benchmark.py @@ -25,272 +25,286 @@ class FileType(Enum): - """File types of target files.""" - C = 'C' - CPP = 'C++' - JAVA = 'Java' - NONE = '' + """File types of target files.""" + + C = "C" + CPP = "C++" + JAVA = "Java" + NONE = "" # Define a custom representer for quoting strings def quoted_string_presenter(dumper, data): - if '\n' in data: - return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') - return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='"') + if "\n" in data: + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style='"') class Benchmark: - """Represents a benchmark.""" - - @classmethod - def to_yaml(cls, - benchmarks: list[Benchmark], - outdir: str = './', - out_basename: str = ''): - """Converts and saves selected fields of a benchmark to a YAML file.""" - # Register the custom representer - yaml.add_representer(str, quoted_string_presenter) - result: dict[str, Any] = { - 'project': benchmarks[0].project, - 
'language': benchmarks[0].language, - 'target_path': benchmarks[0].target_path, - 'target_name': benchmarks[0].target_name, - } - for benchmark in benchmarks: - if benchmark.test_file_path: - if 'test_files' not in result: - result['test_files'] = [] - result['test_files'].append( - {'test_file_path': benchmark.test_file_path}) - else: - if 'functions' not in result: - result['functions'] = [] - result['functions'].append({ - 'signature': benchmark.function_signature, - 'name': benchmark.function_name, - 'return_type': benchmark.return_type, - 'params': benchmark.params - }) - - if not out_basename: - out_basename = f'{benchmarks[0].project}.yaml' - with open(os.path.join(outdir, out_basename), 'w') as file: - yaml.dump(result, file, default_flow_style=False, width=sys.maxsize) - - @classmethod - def from_yaml(cls, benchmark_path: str) -> List: - """Constructs a benchmark based on a yaml file.""" - benchmarks = [] - with open(benchmark_path, 'r') as benchmark_file: - data = yaml.safe_load(benchmark_file) - if not data: - return [] - - project_name = data.get('project', '') - use_context = data.get('use_context', False) - use_project_examples = data.get('use_project_examples', True) - cppify_headers = data.get('cppify_headers', False) - commit = data.get('commit') - functions = data.get('functions', []) - - test_files = data.get('test_files', []) - if test_files: - for test_file in test_files: - max_len = os.pathconf('/', 'PC_NAME_MAX') - len('output-') - test_file_path = test_file.get('test_file_path') - normalized_test_path = test_file_path.replace('/', '_').replace( - '.', '_').replace('-', '_') - truncated_id = f'{project_name}-{normalized_test_path}'[:max_len] - - benchmarks.append( - cls( - truncated_id.lower(), - data['project'], - data['language'], - '', - '', - '', - [], - data['target_path'], - data.get('target_name', ''), - test_file_path=test_file_path, - )) - - if functions: - # function type benchmark - for function in functions: - # Long raw_function_names (particularly for c++ projects) may exceed - # filesystem limits on file path/name length when creating WorkDir. - max_len = os.pathconf('/', 'PC_NAME_MAX') - len('output-') - # Docker tag name cannot exceed 127 characters, and will be suffixed by - # '-experiment'. 
- docker_name_len = 127 - len('-03-experiment') - max_len = min(max_len, docker_name_len) - truncated_id = f'{project_name}-{function.get("name")}'[:max_len] - benchmarks.append( - cls(truncated_id.lower(), - data['project'], - data['language'], - function.get('signature'), - function.get('name'), - function.get('return_type'), - function.get('params'), - data['target_path'], - data.get('target_name'), - use_project_examples=use_project_examples, - cppify_headers=cppify_headers, - commit=commit, - use_context=use_context, - function_dict=function)) - - return benchmarks - - def __init__(self, - benchmark_id: str, - project: str, - language: str, - function_signature: str, - function_name: str, - return_type: str, - params: list[dict[str, str]], - target_path: str, - preferred_target_name: Optional[str] = None, - use_project_examples=True, - cppify_headers=False, - use_context=False, - commit=None, - function_dict: Optional[dict] = None, - test_file_path: str = ''): - self.id = benchmark_id - self.project = project - self.language = language - self.function_signature = function_signature - self.function_name = function_name - self.return_type = return_type - self.params = params - self.function_dict = function_dict - self.target_path = target_path - self._preferred_target_name = preferred_target_name - self.use_project_examples = use_project_examples - self.use_context = use_context - self.cppify_headers = cppify_headers - self.commit = commit - self.test_file_path = test_file_path - - if self.language == 'jvm': - # For java projects, in order to differentiate between overloaded methods - # the full signature is being used as function_name. The full signature - # is following the format of: - # [() - # The benchmark id uses the function_signature directly and is used as - # the name of the result directory. In order to avoid confusion in the - # directory name remove special characters in the id coming from the - # function signature. Additional special characters exist for - # constructors which will be shown as because constructors do not - # have names. - self.function_signature = self.function_name - self.id = self.id.replace('<', '').replace('>', '') - self.id = self.id.replace('[', '').replace(']', '') - self.id = self.id.replace('(', '_').replace(')', '').replace(',', '_') - - if self.language == 'python': - # For python projects, classes and methods name could begins with - # underscore character. This could affect the benchmark_id and cause - # OSS-Fuzz build failed if dot and underscore character is put together. - # Special handling of benchmark_id is needed to avoid this situation. - # For example, zipp._difference in zip project will have benchmark id of - # zipp-zipp._difference and the pattern '._' cause OSS-Fuzz failed to - # recognise the project name and needed to be replaced by - # zipp-zipp.difference. - self.id = self.id.replace('._', '.') - - if self.language == 'rust': - # For rust projects, double colon (::) is sometime used to identify - # crate, impl or trait name of a function. This could affect the - # benchmark_id and cause OSS-Fuzz build failed. - # Special handling of benchmark_id is needed to avoid this situation. 
- self.id = self.id.replace('::', '-') - - def __repr__(self): - return (f'Benchmark') - - @property - def target_name(self): - """Returns target_name if it is defined, + """Represents a benchmark.""" + + @classmethod + def to_yaml( + cls, benchmarks: list[Benchmark], outdir: str = "./", out_basename: str = "" + ): + """Converts and saves selected fields of a benchmark to a YAML file.""" + # Register the custom representer + yaml.add_representer(str, quoted_string_presenter) + result: dict[str, Any] = { + "project": benchmarks[0].project, + "language": benchmarks[0].language, + "target_path": benchmarks[0].target_path, + "target_name": benchmarks[0].target_name, + } + for benchmark in benchmarks: + if benchmark.test_file_path: + if "test_files" not in result: + result["test_files"] = [] + result["test_files"].append( + {"test_file_path": benchmark.test_file_path} + ) + else: + if "functions" not in result: + result["functions"] = [] + result["functions"].append( + { + "signature": benchmark.function_signature, + "name": benchmark.function_name, + "return_type": benchmark.return_type, + "params": benchmark.params, + } + ) + + if not out_basename: + out_basename = f"{benchmarks[0].project}.yaml" + with open(os.path.join(outdir, out_basename), "w") as file: + yaml.dump(result, file, default_flow_style=False, width=sys.maxsize) + + @classmethod + def from_yaml(cls, benchmark_path: str) -> List: + """Constructs a benchmark based on a yaml file.""" + benchmarks = [] + with open(benchmark_path, "r") as benchmark_file: + data = yaml.safe_load(benchmark_file) + if not data: + return [] + + project_name = data.get("project", "") + use_context = data.get("use_context", False) + use_project_examples = data.get("use_project_examples", True) + cppify_headers = data.get("cppify_headers", False) + commit = data.get("commit") + functions = data.get("functions", []) + + test_files = data.get("test_files", []) + if test_files: + for test_file in test_files: + max_len = os.pathconf("/", "PC_NAME_MAX") - len("output-") + test_file_path = test_file.get("test_file_path") + normalized_test_path = ( + test_file_path.replace("/", "_").replace(".", "_").replace("-", "_") + ) + truncated_id = f"{project_name}-{normalized_test_path}"[:max_len] + + benchmarks.append( + cls( + truncated_id.lower(), + data["project"], + data["language"], + "", + "", + "", + [], + data["target_path"], + data.get("target_name", ""), + test_file_path=test_file_path, + ) + ) + + if functions: + # function type benchmark + for function in functions: + # Long raw_function_names (particularly for c++ projects) may exceed + # filesystem limits on file path/name length when creating WorkDir. + max_len = os.pathconf("/", "PC_NAME_MAX") - len("output-") + # Docker tag name cannot exceed 127 characters, and will be suffixed by + # '-experiment'. 
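# NOTE: annotation, not part of the original change. With the common PC_NAME_MAX
# of 255 (an assumption; the real value comes from os.pathconf at runtime), the
# Docker tag limit below is the binding constraint:
assert min(255 - len("output-"), 127 - len("-03-experiment")) == 113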
+ docker_name_len = 127 - len("-03-experiment") + max_len = min(max_len, docker_name_len) + truncated_id = f'{project_name}-{function.get("name")}'[:max_len] + benchmarks.append( + cls( + truncated_id.lower(), + data["project"], + data["language"], + function.get("signature"), + function.get("name"), + function.get("return_type"), + function.get("params"), + data["target_path"], + data.get("target_name"), + use_project_examples=use_project_examples, + cppify_headers=cppify_headers, + commit=commit, + use_context=use_context, + function_dict=function, + ) + ) + + return benchmarks + + def __init__( + self, + benchmark_id: str, + project: str, + language: str, + function_signature: str, + function_name: str, + return_type: str, + params: list[dict[str, str]], + target_path: str, + preferred_target_name: Optional[str] = None, + use_project_examples=True, + cppify_headers=False, + use_context=False, + commit=None, + function_dict: Optional[dict] = None, + test_file_path: str = "", + ): + self.id = benchmark_id + self.project = project + self.language = language + self.function_signature = function_signature + self.function_name = function_name + self.return_type = return_type + self.params = params + self.function_dict = function_dict + self.target_path = target_path + self._preferred_target_name = preferred_target_name + self.use_project_examples = use_project_examples + self.use_context = use_context + self.cppify_headers = cppify_headers + self.commit = commit + self.test_file_path = test_file_path + + if self.language == "jvm": + # For java projects, in order to differentiate between overloaded methods + # the full signature is being used as function_name. The full signature + # is following the format of: + # [() + # The benchmark id uses the function_signature directly and is used as + # the name of the result directory. In order to avoid confusion in the + # directory name remove special characters in the id coming from the + # function signature. Additional special characters exist for + # constructors which will be shown as because constructors do not + # have names. + self.function_signature = self.function_name + self.id = self.id.replace("<", "").replace(">", "") + self.id = self.id.replace("[", "").replace("]", "") + self.id = self.id.replace("(", "_").replace(")", "").replace(",", "_") + + if self.language == "python": + # For python projects, classes and methods name could begins with + # underscore character. This could affect the benchmark_id and cause + # OSS-Fuzz build failed if dot and underscore character is put together. + # Special handling of benchmark_id is needed to avoid this situation. + # For example, zipp._difference in zip project will have benchmark id of + # zipp-zipp._difference and the pattern '._' cause OSS-Fuzz failed to + # recognise the project name and needed to be replaced by + # zipp-zipp.difference. + self.id = self.id.replace("._", ".") + + if self.language == "rust": + # For rust projects, double colon (::) is sometime used to identify + # crate, impl or trait name of a function. This could affect the + # benchmark_id and cause OSS-Fuzz build failed. + # Special handling of benchmark_id is needed to avoid this situation. 
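# NOTE: illustrative sketch with a hypothetical id, not part of the original
# change. The replacement below turns Rust path separators into plain dashes:
assert "myproj-my_crate::parser::parse".replace("::", "-") == "myproj-my_crate-parser-parse"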
+ self.id = self.id.replace("::", "-") + + def __repr__(self): + return ( + f"Benchmark" + ) + + @property + def target_name(self): + """Returns target_name if it is defined, otherwise use the basename of the target path.""" - return (self._preferred_target_name or - os.path.splitext(os.path.basename(self.target_path))[0]) - - @property - def file_type(self) -> FileType: - """Returns the file type of the benchmark.""" - return get_file_type(self.target_path) - - @property - def is_c_target(self) -> bool: - """Validates if the project is written in C.""" - return self.file_type.value.lower() == 'c' - - @property - def is_cpp_target(self) -> bool: - """Validates if the project is written in C++.""" - return self.file_type.value.lower() == 'c++' - - @property - def is_java_target(self) -> bool: - """Validates if the project is written in Java.""" - return self.file_type.value.lower() == 'java' - - @property - def is_c_project(self) -> bool: - """Validates if the project is written in C.""" - return self.language.lower() == 'c' - - @property - def is_cpp_project(self) -> bool: - """Validates if the project is written in C++.""" - return self.language.lower() == 'c++' - - @property - def is_java_project(self) -> bool: - """Validates if the project is written in Java.""" - return self.language.lower() == 'jvm' - - @property - def needs_extern(self) -> bool: - """Checks if it is C++ fuzz target for a C project, which needs `extern`.""" - return self.is_cpp_target and self.is_c_project + return ( + self._preferred_target_name + or os.path.splitext(os.path.basename(self.target_path))[0] + ) + + @property + def file_type(self) -> FileType: + """Returns the file type of the benchmark.""" + return get_file_type(self.target_path) + + @property + def is_c_target(self) -> bool: + """Validates if the project is written in C.""" + return self.file_type.value.lower() == "c" + + @property + def is_cpp_target(self) -> bool: + """Validates if the project is written in C++.""" + return self.file_type.value.lower() == "c++" + + @property + def is_java_target(self) -> bool: + """Validates if the project is written in Java.""" + return self.file_type.value.lower() == "java" + + @property + def is_c_project(self) -> bool: + """Validates if the project is written in C.""" + return self.language.lower() == "c" + + @property + def is_cpp_project(self) -> bool: + """Validates if the project is written in C++.""" + return self.language.lower() == "c++" + + @property + def is_java_project(self) -> bool: + """Validates if the project is written in Java.""" + return self.language.lower() == "jvm" + + @property + def needs_extern(self) -> bool: + """Checks if it is C++ fuzz target for a C project, which needs `extern`.""" + return self.is_cpp_target and self.is_c_project def get_file_type(file_path: str) -> FileType: - """Returns the file type based on the extension of |file_name|.""" - if file_path.endswith('.c'): - return FileType.C - cpp_extensions = ['.cc', '.cpp', '.cxx', '.c++', '.h', '.hpp'] - if any(file_path.endswith(ext) for ext in cpp_extensions): - return FileType.CPP - if file_path.endswith('.java'): - return FileType.JAVA - return FileType.NONE + """Returns the file type based on the extension of |file_name|.""" + if file_path.endswith(".c"): + return FileType.C + cpp_extensions = [".cc", ".cpp", ".cxx", ".c++", ".h", ".hpp"] + if any(file_path.endswith(ext) for ext in cpp_extensions): + return FileType.CPP + if file_path.endswith(".java"): + return FileType.JAVA + return FileType.NONE def is_c_file(file_path: 
str) -> bool: - """Validates if |file_path| is a C file by its extension.""" - return get_file_type(file_path) == FileType.C + """Validates if |file_path| is a C file by its extension.""" + return get_file_type(file_path) == FileType.C def is_cpp_file(file_path: str) -> bool: - """Validates if |file_path| is a C++ file by its extension.""" - return get_file_type(file_path) == FileType.CPP + """Validates if |file_path| is a C++ file by its extension.""" + return get_file_type(file_path) == FileType.CPP def is_java_file(file_path: str) -> bool: - """Validates if |file_path| is a Java file by its extension.""" - return get_file_type(file_path) == FileType.JAVA + """Validates if |file_path| is a Java file by its extension.""" + return get_file_type(file_path) == FileType.JAVA diff --git a/experiment/builder_runner.py b/experiment/builder_runner.py index 671dca0546..177462d433 100644 --- a/experiment/builder_runner.py +++ b/experiment/builder_runner.py @@ -46,1100 +46,1328 @@ CLOUD_EXP_MAX_ATTEMPT = 5 LIBFUZZER_MODULES_LOADED_REGEX = re.compile( - r'^INFO:\s+Loaded\s+\d+\s+(modules|PC tables)\s+\((\d+)\s+.*\).*') -LIBFUZZER_COV_REGEX = re.compile(r'.*cov: (\d+) ft:') -LIBFUZZER_CRASH_TYPE_REGEX = re.compile(r'.*Test unit written to.*') -LIBFUZZER_COV_LINE_PREFIX = re.compile(r'^#(\d+)') -LIBFUZZER_STACK_FRAME_LINE_PREFIX = re.compile(r'^\s+#\d+') -CRASH_EXCLUSIONS = re.compile(r'.*(slow-unit-|timeout-|leak-|oom-).*') -CRASH_STACK_WITH_SOURCE_INFO = re.compile(r'in.*:\d+:\d+$') - -LIBFUZZER_LOG_STACK_FRAME_LLVM = '/src/llvm-project/compiler-rt' -LIBFUZZER_LOG_STACK_FRAME_LLVM2 = '/work/llvm-stage2/projects/compiler-rt' -LIBFUZZER_LOG_STACK_FRAME_CPP = '/usr/local/bin/../include/c++' + r"^INFO:\s+Loaded\s+\d+\s+(modules|PC tables)\s+\((\d+)\s+.*\).*" +) +LIBFUZZER_COV_REGEX = re.compile(r".*cov: (\d+) ft:") +LIBFUZZER_CRASH_TYPE_REGEX = re.compile(r".*Test unit written to.*") +LIBFUZZER_COV_LINE_PREFIX = re.compile(r"^#(\d+)") +LIBFUZZER_STACK_FRAME_LINE_PREFIX = re.compile(r"^\s+#\d+") +CRASH_EXCLUSIONS = re.compile(r".*(slow-unit-|timeout-|leak-|oom-).*") +CRASH_STACK_WITH_SOURCE_INFO = re.compile(r"in.*:\d+:\d+$") + +LIBFUZZER_LOG_STACK_FRAME_LLVM = "/src/llvm-project/compiler-rt" +LIBFUZZER_LOG_STACK_FRAME_LLVM2 = "/work/llvm-stage2/projects/compiler-rt" +LIBFUZZER_LOG_STACK_FRAME_CPP = "/usr/local/bin/../include/c++" EARLY_FUZZING_ROUND_THRESHOLD = 3 -ParseResult = namedtuple('ParseResult', [ - 'cov_pcs', 'total_pcs', 'crashes', 'crash_info', 'artifact_name', - 'semantic_check_result' -]) +ParseResult = namedtuple( + "ParseResult", + [ + "cov_pcs", + "total_pcs", + "crashes", + "crash_info", + "artifact_name", + "semantic_check_result", + ], +) @dataclasses.dataclass class BuildResult: - """Results of compilation & link.""" + """Results of compilation & link.""" - succeeded: bool = False - errors: list[str] = dataclasses.field(default_factory=list) - log_path: str = '' + succeeded: bool = False + errors: list[str] = dataclasses.field(default_factory=list) + log_path: str = "" - def to_dict(self): - return dataclasses.asdict(self) + def to_dict(self): + return dataclasses.asdict(self) @dataclasses.dataclass class RunResult: - """Checked results of conducting short-term fuzzing.""" - - succeeded: bool = False - coverage_summary: dict = dataclasses.field(default_factory=dict) - coverage: Optional[textcov.Textcov] = None - log_path: str = '' - corpus_path: str = '' - coverage_report_path: str = '' - reproducer_path: str = '' - artifact_path: str = '' - artifact_name: str = '' - sanitizer: str = 
'' - cov_pcs: int = 0 - total_pcs: int = 0 - crashes: bool = False - crash_info: str = '' - triage: str = TriageResult.NOT_APPLICABLE - semantic_check: SemanticCheckResult = SemanticCheckResult( - SemanticCheckResult.NOT_APPLICABLE) - - def to_dict(self): - return dataclasses.asdict(self) + """Checked results of conducting short-term fuzzing.""" + + succeeded: bool = False + coverage_summary: dict = dataclasses.field(default_factory=dict) + coverage: Optional[textcov.Textcov] = None + log_path: str = "" + corpus_path: str = "" + coverage_report_path: str = "" + reproducer_path: str = "" + artifact_path: str = "" + artifact_name: str = "" + sanitizer: str = "" + cov_pcs: int = 0 + total_pcs: int = 0 + crashes: bool = False + crash_info: str = "" + triage: str = TriageResult.NOT_APPLICABLE + semantic_check: SemanticCheckResult = SemanticCheckResult( + SemanticCheckResult.NOT_APPLICABLE + ) + + def to_dict(self): + return dataclasses.asdict(self) class BuilderRunner: - """Builder and runner.""" - - # Regex for extract function name. - FUNC_NAME = re.compile(r'(?:^|\s|\b)([\w:]+::)*(\w+)(?:<[^>]*>)?(?=\(|$)') - # Regex for extract line number, - LINE_NUMBER = re.compile(r':(\d+):') - - def __init__(self, - benchmark: Benchmark, - work_dirs: WorkDirs, - run_timeout: int = RUN_TIMEOUT, - fixer_model_name: str = DefaultModel.name): - self.benchmark = benchmark - self.work_dirs = work_dirs - self.run_timeout = run_timeout - self.fixer_model_name = fixer_model_name - - def _libfuzzer_args(self) -> list[str]: - return [ - '-print_final_stats=1', - f'-max_total_time={self.run_timeout}', - # Without this flag, libFuzzer only consider short inputs in short - # experiments, which lowers the coverage for quick performance tests. - '-len_control=0', - # Timeout per testcase. - '-timeout=30', - '-detect_leaks=0', - ] - - def _get_minimum_func_name(self, func_sig: str) -> str: - """Extracts the minimum function name from function signature, - without name space, return type, params, templates.""" - pattern = (r'(?:[a-zA-Z_]\w*::)*([a-zA-Z_]\w*|operator[^(\s]*)(?:\s*<.*>)?' - r'\s*\(') - match = re.search(pattern, func_sig) - if not match: - return func_sig - - function_name = match.group(1).strip() - return function_name.removeprefix('operator') - - def _contains_target_jvm_method(self, target_path: str) -> bool: - """Validates if the LLM-generated code contains the target jvm methods.""" - signature = self.benchmark.function_signature - - # For test to harness approach, the target signature does not - # exist, no need to do this pre-check - if not signature or not '].' in signature: - return True - - with open(target_path) as generated_code_file: - code = generated_code_file.read() - - # This regex is used to identify legitimate Java variable names - # or instance method calls (which could return a needed variable). - # This is necessary because the method name of a Java method also - # includes its parameter list in order to distinguish between - # overloaded methods. Thus it need to use the regex to identify - # if there are method calls with unknown variable names that match - # the target method. 
- base_arg_regex = r'[\s]*[a-zA-Z_$][a-zA-Z_$0-9(),.]*' - name = signature.split('].')[1].split('(')[0] - arg_count = len(signature.split('(')[1].split(')')[0].split(',')) - - if '' in name: - # Always return true for Java constructors because it is not possible - # to match all possible ways to call the constructors - return True - - pattern = rf'({name}\({", ".join([base_arg_regex] * arg_count)}\))' - match = re.search(pattern, ''.join(code.splitlines()).replace(' ', '')) - - return bool(match) - - def _contains_target_function(self, target_path: str) -> bool: - """Validates if the LLM-generated code contains the target function.""" - with open(target_path) as generated_code_file: - generated_code = generated_code_file.read() - - min_func_name = self._get_minimum_func_name( - self.benchmark.function_signature) - - return min_func_name in generated_code - - def _contains_target_python_function(self, target_path: str) -> bool: - """Validates if the LLM-generated code contains the target function for - python projects.""" - with open(target_path) as generated_code_file: - generated_code = generated_code_file.read() - - min_func_name = self.benchmark.function_signature.rsplit('.', 1)[-1] - - return min_func_name in generated_code - - def _contains_target_rust_function(self, target_path: str) -> bool: - """Validates if the LLM-generated code contains the target function for - rust projects.""" - with open(target_path) as generated_code_file: - generated_code = generated_code_file.read() - - min_func_name = self._get_minimum_func_name( - self.benchmark.function_signature) - - # Retrieve function name only with crate, triat, impl or mod tag - min_func_name = min_func_name.rsplit('::', 1)[-1] - min_func_name = min_func_name.rsplit('.', 1)[-1] - - return min_func_name in generated_code - - def _pre_build_check(self, target_path: str, - build_result: BuildResult) -> bool: - """Checks the generated target before building and running it.""" - # No need to build the fuzz target if it does not contain the target - # function. - if self.benchmark.language == 'jvm': - result = self._contains_target_jvm_method(target_path) - elif self.benchmark.language == 'python': - result = self._contains_target_python_function(target_path) - elif self.benchmark.language == 'rust': - result = self._contains_target_rust_function(target_path) - else: - # C/C++ pre-build check is done in agents. - return True - - if not result: - build_result.errors = [ - (f'The target function `{self.benchmark.function_signature}`' - ' was not called by the fuzz target ' - '`LLVMFuzzerTestOneInput`.' - 'YOU MUST CALL FUNCTION ' - f'`{self.benchmark.function_signature}` INSIDE FUNCTION ' - '`LLVMFuzzerTestOneInput`.') - ] - logger.warning('Missing target function: %s does not contain %s', - target_path, self.benchmark.function_signature) - - return result - - def _parse_stacks_from_libfuzzer_logs(self, - lines: list[str]) -> list[list[str]]: - """Parses stack traces from libFuzzer logs.""" - # TODO (dongge): Use stack parsing from ClusterFuzz. - # There can have over one thread stack in a log. - stacks = [] - - # A stack -> a sequence of stack frame lines. - stack, stack_parsing = [], False - for line in lines: - is_stack_frame_line = LIBFUZZER_STACK_FRAME_LINE_PREFIX.match( - line) is not None - if (not stack_parsing) and is_stack_frame_line: - # First line. - stack_parsing = True - stack = [line.strip()] - elif stack_parsing and is_stack_frame_line: - # Middle line(s). 
- stack.append(line.strip()) - elif stack_parsing and (not is_stack_frame_line): - # Last line. - stack_parsing = False - stacks.append(stack) - - # Last stack. - if stack_parsing: - stacks.append(stack) - - return stacks - - def _parse_func_from_stacks(self, project_name: str, - stacks: list[list[str]]) -> dict: - """Parses project functions from stack traces.""" - func_info = defaultdict(set) - - for stack in stacks: - for line in stack: - # Use 3 spaces to divide each line of crash info into four parts. - # Only parse the fourth part, which includes the function name, - # file path, and line number. - parts = line.split(' ', 3) - if len(parts) < 4: - continue - func_and_file_path = parts[3] - if project_name not in func_and_file_path: - continue - func_name, _, file_path = func_and_file_path.partition(' /') - if func_name == 'LLVMFuzzerTestOneInput': - line_match = self.LINE_NUMBER.search(file_path) - if line_match: - line_number = int(line_match.group(1)) - func_info[func_name].add(line_number) - else: - logger.warning('Failed to parse line number from %s in project %s', - func_name, project_name) - break - if project_name in file_path: - func_match = self.FUNC_NAME.search(func_name) - line_match = self.LINE_NUMBER.search(file_path) - if func_match and line_match: - func_name = func_match.group(2) - line_number = int(line_match.group(1)) - func_info[func_name].add(line_number) - else: + """Builder and runner.""" + + # Regex for extract function name. + FUNC_NAME = re.compile(r"(?:^|\s|\b)([\w:]+::)*(\w+)(?:<[^>]*>)?(?=\(|$)") + # Regex for extract line number, + LINE_NUMBER = re.compile(r":(\d+):") + + def __init__( + self, + benchmark: Benchmark, + work_dirs: WorkDirs, + run_timeout: int = RUN_TIMEOUT, + fixer_model_name: str = DefaultModel.name, + ): + self.benchmark = benchmark + self.work_dirs = work_dirs + self.run_timeout = run_timeout + self.fixer_model_name = fixer_model_name + + def _libfuzzer_args(self) -> list[str]: + return [ + "-print_final_stats=1", + f"-max_total_time={self.run_timeout}", + # Without this flag, libFuzzer only consider short inputs in short + # experiments, which lowers the coverage for quick performance tests. + "-len_control=0", + # Timeout per testcase. + "-timeout=30", + "-detect_leaks=0", + ] + + def _get_minimum_func_name(self, func_sig: str) -> str: + """Extracts the minimum function name from function signature, + without name space, return type, params, templates.""" + pattern = ( + r"(?:[a-zA-Z_]\w*::)*([a-zA-Z_]\w*|operator[^(\s]*)(?:\s*<.*>)?" r"\s*\(" + ) + match = re.search(pattern, func_sig) + if not match: + return func_sig + + function_name = match.group(1).strip() + return function_name.removeprefix("operator") + + def _contains_target_jvm_method(self, target_path: str) -> bool: + """Validates if the LLM-generated code contains the target jvm methods.""" + signature = self.benchmark.function_signature + + # For test to harness approach, the target signature does not + # exist, no need to do this pre-check + if not signature or not "]." in signature: + return True + + with open(target_path) as generated_code_file: + code = generated_code_file.read() + + # This regex is used to identify legitimate Java variable names + # or instance method calls (which could return a needed variable). + # This is necessary because the method name of a Java method also + # includes its parameter list in order to distinguish between + # overloaded methods. 
Thus it need to use the regex to identify + # if there are method calls with unknown variable names that match + # the target method. + base_arg_regex = r"[\s]*[a-zA-Z_$][a-zA-Z_$0-9(),.]*" + name = signature.split("].")[1].split("(")[0] + arg_count = len(signature.split("(")[1].split(")")[0].split(",")) + + if "" in name: + # Always return true for Java constructors because it is not possible + # to match all possible ways to call the constructors + return True + + pattern = rf'({name}\({", ".join([base_arg_regex] * arg_count)}\))' + match = re.search(pattern, "".join(code.splitlines()).replace(" ", "")) + + return bool(match) + + def _contains_target_function(self, target_path: str) -> bool: + """Validates if the LLM-generated code contains the target function.""" + with open(target_path) as generated_code_file: + generated_code = generated_code_file.read() + + min_func_name = self._get_minimum_func_name(self.benchmark.function_signature) + + return min_func_name in generated_code + + def _contains_target_python_function(self, target_path: str) -> bool: + """Validates if the LLM-generated code contains the target function for + python projects.""" + with open(target_path) as generated_code_file: + generated_code = generated_code_file.read() + + min_func_name = self.benchmark.function_signature.rsplit(".", 1)[-1] + + return min_func_name in generated_code + + def _contains_target_rust_function(self, target_path: str) -> bool: + """Validates if the LLM-generated code contains the target function for + rust projects.""" + with open(target_path) as generated_code_file: + generated_code = generated_code_file.read() + + min_func_name = self._get_minimum_func_name(self.benchmark.function_signature) + + # Retrieve function name only with crate, triat, impl or mod tag + min_func_name = min_func_name.rsplit("::", 1)[-1] + min_func_name = min_func_name.rsplit(".", 1)[-1] + + return min_func_name in generated_code + + def _pre_build_check(self, target_path: str, build_result: BuildResult) -> bool: + """Checks the generated target before building and running it.""" + # No need to build the fuzz target if it does not contain the target + # function. + if self.benchmark.language == "jvm": + result = self._contains_target_jvm_method(target_path) + elif self.benchmark.language == "python": + result = self._contains_target_python_function(target_path) + elif self.benchmark.language == "rust": + result = self._contains_target_rust_function(target_path) + else: + # C/C++ pre-build check is done in agents. + return True + + if not result: + build_result.errors = [ + ( + f"The target function `{self.benchmark.function_signature}`" + " was not called by the fuzz target " + "`LLVMFuzzerTestOneInput`." + "YOU MUST CALL FUNCTION " + f"`{self.benchmark.function_signature}` INSIDE FUNCTION " + "`LLVMFuzzerTestOneInput`." + ) + ] logger.warning( - 'Failed to parse function name from %s in project %s', - func_name, project_name) - - return func_info - - def _parse_fuzz_cov_info_from_libfuzzer_logs( - self, - lines: list[str]) -> tuple[Optional[int], Optional[int], Optional[int]]: - """Parses cov of INITED & DONE, and round number from libFuzzer logs.""" - initcov, donecov, lastround = None, None, None - - for line in lines: - if line.startswith('#'): - # Parses cov line to get the round number. 
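# NOTE: annotation with paraphrased log lines, not part of the original change.
# libFuzzer status lines look roughly like "#2  INITED cov: 123 ft: 145 ..." and
# "#4096  DONE cov: 210 ft: 300 ...", so the prefix regex recovers the round
# number while the "cov: " splits below pick out the coverage values:
import re

assert re.compile(r"^#(\d+)").match("#4096  DONE cov: 210 ft: 300").group(1) == "4096"
assert "#4096  DONE cov: 210 ft: 300".split("cov: ")[1].split(" ft:")[0] == "210"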
- match = LIBFUZZER_COV_LINE_PREFIX.match(line) - roundno = int(match.group(1)) if match else None - - if roundno is not None: - lastround = roundno - if 'INITED' in line and 'cov: ' in line: - initcov = int(line.split('cov: ')[1].split(' ft:')[0]) - elif 'DONE' in line and 'cov: ' in line: - donecov = int(line.split('cov: ')[1].split(' ft:')[0]) - - return initcov, donecov, lastround - - def _stack_func_is_of_testing_project(self, stack_frame: str) -> bool: - return (bool(CRASH_STACK_WITH_SOURCE_INFO.match(stack_frame)) and - LIBFUZZER_LOG_STACK_FRAME_LLVM not in stack_frame and - LIBFUZZER_LOG_STACK_FRAME_LLVM2 not in stack_frame and - LIBFUZZER_LOG_STACK_FRAME_CPP not in stack_frame) - - def _parse_libfuzzer_logs(self, - log_handle, - project_name: str, - check_cov_increase: bool = True) -> ParseResult: - """Parses libFuzzer logs.""" - lines = None - try: - fuzzlog = log_handle.read(-1) - # Some crashes can mess up the libfuzzer output and raise decode error. - fuzzlog = fuzzlog.decode('utf-8', errors='ignore') - lines = fuzzlog.split('\n') - except MemoryError as e: - # Some logs from abnormal fuzz targets are too large to be parsed. - logger.error('%s is too large to parse: %s', log_handle.name, e) - return ParseResult(0, 0, False, '', '', - SemanticCheckResult(SemanticCheckResult.LOG_MESS_UP)) - - cov_pcs, total_pcs, crashes = 0, 0, False - - for line in lines: - m = LIBFUZZER_MODULES_LOADED_REGEX.match(line) - if m: - total_pcs = int(m.group(2)) - continue - - m = LIBFUZZER_COV_REGEX.match(line) - if m: - cov_pcs = int(m.group(1)) - continue - - m = LIBFUZZER_CRASH_TYPE_REGEX.match(line) - if m and not CRASH_EXCLUSIONS.match(line): - # TODO(@happy-qop): Handling oom, slow cases in semantic checks & fix. - crashes = True - continue - - initcov, donecov, lastround = self._parse_fuzz_cov_info_from_libfuzzer_logs( - lines) - - # NOTE: Crashes from incorrect fuzz targets will not be counted finally. - - if crashes: - symptom = SemanticCheckResult.extract_symptom(fuzzlog) - crash_stacks = self._parse_stacks_from_libfuzzer_logs(lines) - crash_func = self._parse_func_from_stacks(project_name, crash_stacks) - crash_info = SemanticCheckResult.extract_crash_info(fuzzlog) - artifact_name = SemanticCheckResult.extract_artifact_name(fuzzlog) - - # FP case 1: Common fuzz target errors. - # Null-deref, normally indicating inadequate parameter initialization or - # wrong function usage. - if symptom == 'null-deref': - return ParseResult( - cov_pcs, total_pcs, True, crash_info, artifact_name, - SemanticCheckResult(SemanticCheckResult.NULL_DEREF, symptom, - crash_stacks, crash_func)) + "Missing target function: %s does not contain %s", + target_path, + self.benchmark.function_signature, + ) + + return result + + def _parse_stacks_from_libfuzzer_logs(self, lines: list[str]) -> list[list[str]]: + """Parses stack traces from libFuzzer logs.""" + # TODO (dongge): Use stack parsing from ClusterFuzz. + # There can have over one thread stack in a log. + stacks = [] + + # A stack -> a sequence of stack frame lines. + stack, stack_parsing = [], False + for line in lines: + is_stack_frame_line = ( + LIBFUZZER_STACK_FRAME_LINE_PREFIX.match(line) is not None + ) + if (not stack_parsing) and is_stack_frame_line: + # First line. + stack_parsing = True + stack = [line.strip()] + elif stack_parsing and is_stack_frame_line: + # Middle line(s). + stack.append(line.strip()) + elif stack_parsing and (not is_stack_frame_line): + # Last line. + stack_parsing = False + stacks.append(stack) + + # Last stack. 
+ if stack_parsing: + stacks.append(stack) + + return stacks + + def _parse_func_from_stacks( + self, project_name: str, stacks: list[list[str]] + ) -> dict: + """Parses project functions from stack traces.""" + func_info = defaultdict(set) + + for stack in stacks: + for line in stack: + # Use 3 spaces to divide each line of crash info into four parts. + # Only parse the fourth part, which includes the function name, + # file path, and line number. + parts = line.split(" ", 3) + if len(parts) < 4: + continue + func_and_file_path = parts[3] + if project_name not in func_and_file_path: + continue + func_name, _, file_path = func_and_file_path.partition(" /") + if func_name == "LLVMFuzzerTestOneInput": + line_match = self.LINE_NUMBER.search(file_path) + if line_match: + line_number = int(line_match.group(1)) + func_info[func_name].add(line_number) + else: + logger.warning( + "Failed to parse line number from %s in project %s", + func_name, + project_name, + ) + break + if project_name in file_path: + func_match = self.FUNC_NAME.search(func_name) + line_match = self.LINE_NUMBER.search(file_path) + if func_match and line_match: + func_name = func_match.group(2) + line_number = int(line_match.group(1)) + func_info[func_name].add(line_number) + else: + logger.warning( + "Failed to parse function name from %s in project %s", + func_name, + project_name, + ) + + return func_info + + def _parse_fuzz_cov_info_from_libfuzzer_logs( + self, lines: list[str] + ) -> tuple[Optional[int], Optional[int], Optional[int]]: + """Parses cov of INITED & DONE, and round number from libFuzzer logs.""" + initcov, donecov, lastround = None, None, None + + for line in lines: + if line.startswith("#"): + # Parses cov line to get the round number. + match = LIBFUZZER_COV_LINE_PREFIX.match(line) + roundno = int(match.group(1)) if match else None + + if roundno is not None: + lastround = roundno + if "INITED" in line and "cov: " in line: + initcov = int(line.split("cov: ")[1].split(" ft:")[0]) + elif "DONE" in line and "cov: " in line: + donecov = int(line.split("cov: ")[1].split(" ft:")[0]) + + return initcov, donecov, lastround + + def _stack_func_is_of_testing_project(self, stack_frame: str) -> bool: + return ( + bool(CRASH_STACK_WITH_SOURCE_INFO.match(stack_frame)) + and LIBFUZZER_LOG_STACK_FRAME_LLVM not in stack_frame + and LIBFUZZER_LOG_STACK_FRAME_LLVM2 not in stack_frame + and LIBFUZZER_LOG_STACK_FRAME_CPP not in stack_frame + ) + + def _parse_libfuzzer_logs( + self, log_handle, project_name: str, check_cov_increase: bool = True + ) -> ParseResult: + """Parses libFuzzer logs.""" + lines = None + try: + fuzzlog = log_handle.read(-1) + # Some crashes can mess up the libfuzzer output and raise decode error. + fuzzlog = fuzzlog.decode("utf-8", errors="ignore") + lines = fuzzlog.split("\n") + except MemoryError as e: + # Some logs from abnormal fuzz targets are too large to be parsed. + logger.error("%s is too large to parse: %s", log_handle.name, e) + return ParseResult( + 0, + 0, + False, + "", + "", + SemanticCheckResult(SemanticCheckResult.LOG_MESS_UP), + ) + + cov_pcs, total_pcs, crashes = 0, 0, False + + for line in lines: + m = LIBFUZZER_MODULES_LOADED_REGEX.match(line) + if m: + total_pcs = int(m.group(2)) + continue + + m = LIBFUZZER_COV_REGEX.match(line) + if m: + cov_pcs = int(m.group(1)) + continue + + m = LIBFUZZER_CRASH_TYPE_REGEX.match(line) + if m and not CRASH_EXCLUSIONS.match(line): + # TODO(@happy-qop): Handling oom, slow cases in semantic checks & fix. 
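# NOTE: annotation, not part of the original change. A genuine crash produces a
# log line like "Test unit written to ./crash-<hash>" (the usual libFuzzer
# naming), which matches LIBFUZZER_CRASH_TYPE_REGEX, while artifacts named
# slow-unit-*, timeout-*, leak-* or oom-* are filtered out by CRASH_EXCLUSIONS
# so they are not counted as crashes here.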
+ crashes = True + continue + + initcov, donecov, lastround = self._parse_fuzz_cov_info_from_libfuzzer_logs( + lines + ) + + # NOTE: Crashes from incorrect fuzz targets will not be counted finally. + + if crashes: + symptom = SemanticCheckResult.extract_symptom(fuzzlog) + crash_stacks = self._parse_stacks_from_libfuzzer_logs(lines) + crash_func = self._parse_func_from_stacks(project_name, crash_stacks) + crash_info = SemanticCheckResult.extract_crash_info(fuzzlog) + artifact_name = SemanticCheckResult.extract_artifact_name(fuzzlog) + + # FP case 1: Common fuzz target errors. + # Null-deref, normally indicating inadequate parameter initialization or + # wrong function usage. + if symptom == "null-deref": + return ParseResult( + cov_pcs, + total_pcs, + True, + crash_info, + artifact_name, + SemanticCheckResult( + SemanticCheckResult.NULL_DEREF, + symptom, + crash_stacks, + crash_func, + ), + ) + + # Signal, normally indicating assertion failure due to inadequate + # parameter initialization or wrong function usage. + if symptom == "signal": + return ParseResult( + cov_pcs, + total_pcs, + True, + crash_info, + artifact_name, + SemanticCheckResult( + SemanticCheckResult.SIGNAL, symptom, crash_stacks, crash_func + ), + ) + + # Exit, normally indicating the fuzz target exited in a controlled manner, + # blocking its bug discovery. + if symptom.endswith("fuzz target exited"): + return ParseResult( + cov_pcs, + total_pcs, + True, + crash_info, + artifact_name, + SemanticCheckResult( + SemanticCheckResult.EXIT, symptom, crash_stacks, crash_func + ), + ) + + # Fuzz target modified constants. + if symptom.endswith("fuzz target overwrites its const input"): + return ParseResult( + cov_pcs, + total_pcs, + True, + crash_info, + artifact_name, + SemanticCheckResult( + SemanticCheckResult.OVERWRITE_CONST, + symptom, + crash_stacks, + crash_func, + ), + ) + + # OOM, normally indicating malloc's parameter is too large, e.g., because + # of using parameter `size`. + # TODO(dongge): Refine this, 1) Merge this with the other oom case found + # from reproducer name; 2) Capture the actual number in (malloc(\d+)). + if "out-of-memory" in symptom or "out of memory" in symptom: + return ParseResult( + cov_pcs, + total_pcs, + True, + crash_info, + artifact_name, + SemanticCheckResult( + SemanticCheckResult.FP_OOM, symptom, crash_stacks, crash_func + ), + ) + + # FP case 2: fuzz target crashes at init or first few rounds. + if lastround is None or lastround <= EARLY_FUZZING_ROUND_THRESHOLD: + # No cov line has been identified or only INITED round has been passed. + # This is very likely the false positive cases. + return ParseResult( + cov_pcs, + total_pcs, + True, + crash_info, + artifact_name, + SemanticCheckResult( + SemanticCheckResult.FP_NEAR_INIT_CRASH, + symptom, + crash_stacks, + crash_func, + ), + ) + + # FP case 3: no func in 1st thread stack belongs to testing proj. 
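
The chain of early returns above is essentially a priority-ordered mapping from crash symptom (plus the last observed fuzzing round) to a semantic verdict. A condensed sketch of that decision order follows; the string labels and the threshold value are illustrative, whereas the real code returns ParseResult/SemanticCheckResult objects and uses EARLY_FUZZING_ROUND_THRESHOLD.

# Assumed threshold; the project defines EARLY_FUZZING_ROUND_THRESHOLD elsewhere.
EARLY_ROUND = 3


def classify_crash(symptom: str, lastround) -> str:
    """Applies the false-positive heuristics in the same order as above."""
    if symptom == "null-deref":
        return "NULL_DEREF"
    if symptom == "signal":
        return "SIGNAL"
    if symptom.endswith("fuzz target exited"):
        return "EXIT"
    if symptom.endswith("fuzz target overwrites its const input"):
        return "OVERWRITE_CONST"
    if "out-of-memory" in symptom or "out of memory" in symptom:
        return "FP_OOM"
    if lastround is None or lastround <= EARLY_ROUND:
        return "FP_NEAR_INIT_CRASH"
    return "NEEDS_STACK_CHECK"  # Fall through to the stack-frame heuristic.


if __name__ == "__main__":
    print(classify_crash("null-deref", 100))              # NULL_DEREF
    print(classify_crash("heap-buffer-overflow", 2))      # FP_NEAR_INIT_CRASH
    print(classify_crash("heap-buffer-overflow", 5000))   # NEEDS_STACK_CHECK
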
+ if len(crash_stacks) > 0: + first_stack = crash_stacks[0] + for stack_frame in first_stack: + if self._stack_func_is_of_testing_project(stack_frame): + if "LLVMFuzzerTestOneInput" in stack_frame: + return ParseResult( + cov_pcs, + total_pcs, + True, + crash_info, + artifact_name, + SemanticCheckResult( + SemanticCheckResult.FP_TARGET_CRASH, + symptom, + crash_stacks, + crash_func, + ), + ) + break + + return ParseResult( + cov_pcs, + total_pcs, + True, + crash_info, + artifact_name, + SemanticCheckResult( + SemanticCheckResult.NO_SEMANTIC_ERR, + symptom, + crash_stacks, + crash_func, + ), + ) + + if check_cov_increase and initcov == donecov and lastround is not None: + # Another error fuzz target case: no cov increase. + # A special case is initcov == donecov == None, which indicates no + # interesting inputs were found. This may happen if the target rejected + # all inputs we tried. + return ParseResult( + cov_pcs, + total_pcs, + False, + "", + "", + SemanticCheckResult(SemanticCheckResult.NO_COV_INCREASE), + ) - # Signal, normally indicating assertion failure due to inadequate - # parameter initialization or wrong function usage. - if symptom == 'signal': return ParseResult( - cov_pcs, total_pcs, True, crash_info, artifact_name, - SemanticCheckResult(SemanticCheckResult.SIGNAL, symptom, - crash_stacks, crash_func)) + cov_pcs, + total_pcs, + crashes, + "", + "", + SemanticCheckResult(SemanticCheckResult.NO_SEMANTIC_ERR), + ) + + def _copy_crash_file( + self, outdir: str, artifact_dir: str, run_result: RunResult + ) -> None: + """Copies the first crash file to the artifact directory.""" + # Only consider testcases starting with 'crash-' + crash_files = [ + f + for f in os.listdir(outdir) + if f.startswith("crash-") and os.path.isfile(os.path.join(outdir, f)) + ] + if len(crash_files) != 0: + crash_file = crash_files[0] + src = os.path.join(outdir, crash_file) + dst = os.path.join(artifact_dir, crash_file) + run_result.artifact_path = dst + shutil.copy2(src, dst) + logger.info("Copied crash file %s to %s", crash_file, artifact_dir) + + def build_and_run( + self, + generated_project: str, + target_path: str, + iteration: int, + language: str, + cloud_build_tags: Optional[list[str]] = None, + trial: int = 0, + ) -> tuple[BuildResult, Optional[RunResult]]: + """Builds and runs the fuzz target for fuzzing.""" + del cloud_build_tags + build_result = BuildResult() + + if not self._pre_build_check(target_path, build_result): + logger.warning("Pre-build check failure: %s", build_result) + return build_result, None + + try: + return self.build_and_run_local( + generated_project, target_path, iteration, build_result, language, trial + ) + except Exception as err: + logger.warning( + "Error occurred when building and running fuzz target locally" + "(attempt %d) %s: %s", + iteration, + err, + traceback.format_exc(), + ) + raise err + + def build_and_run_local( + self, + generated_project: str, + target_path: str, + iteration: int, + build_result: BuildResult, + language: str, + trial: int = 0, + ) -> tuple[BuildResult, Optional[RunResult]]: + """Builds and runs the fuzz target locally for fuzzing.""" + project_name = self.benchmark.project + benchmark_target_name = os.path.basename(target_path) + project_target_name = os.path.basename(self.benchmark.target_path) + benchmark_log_path = self.work_dirs.build_logs_target( + benchmark_target_name, iteration, trial + ) + build_result.succeeded = self.build_target_local( + generated_project, benchmark_log_path + ) + if not build_result.succeeded: + errors = 
code_fixer.extract_error_message( + benchmark_log_path, project_target_name, language + ) + build_result.errors = errors + return build_result, None + + # TODO(Dongge): Split Builder and Runner: + # Make the rest lines in an independent function. + run_result = RunResult() + + run_log_path = os.path.join(self.work_dirs.run_logs, f"{trial:02d}.log") + self.run_target_local(generated_project, benchmark_target_name, run_log_path) + artifact_dir = self.work_dirs.artifact(benchmark_target_name, iteration, trial) + outdir = get_build_artifact_dir(generated_project, "out") + self._copy_crash_file(outdir, artifact_dir, run_result) + + run_result.coverage, run_result.coverage_summary = self.get_coverage_local( + generated_project, benchmark_target_name + ) + + run_result.log_path = run_log_path + + # Parse libfuzzer logs to get fuzz target runtime details. + with open(run_log_path, "rb") as f: + # In many case JVM/python projects won't have much cov + # difference in short running. Adding the flag for JVM/python + # projects to temporary skip the checking of coverage change. + # Also skipping for rust projects in initial implementation. + flag = not self.benchmark.language in ["jvm", "python", "rust"] + ( + run_result.cov_pcs, + run_result.total_pcs, + run_result.crashes, + run_result.crash_info, + run_result.artifact_name, + run_result.semantic_check, + ) = self._parse_libfuzzer_logs(f, project_name, flag) + + return build_result, run_result + + def run_target_local( + self, generated_project: str, benchmark_target_name: str, log_path: str + ): + """Runs a target in the fixed target directory.""" + # If target name is not overridden, use the basename of the target path + # in the Dockerfile. + logger.info("Running %s", generated_project) + corpus_dir = self.work_dirs.corpus(benchmark_target_name) + command = [ + "python3", + "infra/helper.py", + "run_fuzzer", + "--corpus-dir", + corpus_dir, + generated_project, + self.benchmark.target_name, + "--", + ] + self._libfuzzer_args() + + with open(log_path, "w") as f: + proc = sp.Popen( + command, + stdin=sp.DEVNULL, + stdout=f, + stderr=sp.STDOUT, + cwd=oss_fuzz_checkout.OSS_FUZZ_DIR, + ) + + # TODO(ochang): Handle the timeout exception. + try: + proc.wait(timeout=self.run_timeout + 5) + except sp.TimeoutExpired: + logger.info("%s timed out during fuzzing.", generated_project) + # Try continuing and parsing the logs even in case of timeout. + + if proc.returncode != 0: + logger.info("********** Failed to run %s. **********", generated_project) + else: + logger.info("Successfully run %s.", generated_project) + + def build_target_local( + self, generated_project: str, log_path: str, sanitizer: str = "address" + ) -> bool: + """Builds a target with OSS-Fuzz.""" + + logger.info("Building %s with %s", generated_project, sanitizer) + + if oss_fuzz_checkout.ENABLE_CACHING and oss_fuzz_checkout.is_image_cached( + self.benchmark.project, sanitizer + ): + logger.info("We should use cached instance.") + # Rewrite for caching. 
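
run_target_local above starts the fuzzer via sp.Popen, waits with a five-second grace period beyond the run timeout, and still parses the log when the wait times out. A generic sketch of that pattern; the command and log path are placeholders, and unlike the code above the sketch also kills the child so the example itself terminates.

import subprocess as sp


def run_with_deadline(command, log_path, timeout_seconds):
    """Runs |command|, captures combined output, and tolerates a timeout."""
    with open(log_path, "w") as log_file:
        proc = sp.Popen(
            command, stdin=sp.DEVNULL, stdout=log_file, stderr=sp.STDOUT
        )
        try:
            # Allow a small grace period beyond the intended run time.
            proc.wait(timeout=timeout_seconds + 5)
        except sp.TimeoutExpired:
            # Keep going: the partial log is still worth parsing.
            proc.kill()
            proc.wait()
    return proc.returncode


if __name__ == "__main__":
    # Placeholder command; the real caller invokes OSS-Fuzz's infra/helper.py.
    code = run_with_deadline(["sleep", "10"], "/tmp/run.log", timeout_seconds=1)
    print("return code:", code)
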
+ oss_fuzz_checkout.rewrite_project_to_cached_project( + self.benchmark.project, generated_project, sanitizer + ) + + # Prepare build + oss_fuzz_checkout.prepare_build( + self.benchmark.project, sanitizer, generated_project + ) + + else: + logger.info("The project does not have any cache") + + # Build the image + command = [ + "docker", + "build", + "-t", + f"gcr.io/oss-fuzz/{generated_project}", + os.path.join(oss_fuzz_checkout.OSS_FUZZ_DIR, "projects", generated_project), + ] + with open(log_path, "w+") as log_file: + try: + sp.run( + command, + cwd=oss_fuzz_checkout.OSS_FUZZ_DIR, + stdin=sp.DEVNULL, + stdout=log_file, + stderr=sp.STDOUT, + check=True, + ) + except sp.CalledProcessError as e: + logger.info("Failed to build image for %s: %s", generated_project, e) + return False + + outdir = get_build_artifact_dir(generated_project, "out") + workdir = get_build_artifact_dir(generated_project, "work") + command = [ + "docker", + "run", + "--rm", + "--privileged", + "--shm-size=2g", + "--platform", + "linux/amd64", + "-i", + "-e", + "FUZZING_ENGINE=libfuzzer", + "-e", + f"SANITIZER={sanitizer}", + "-e", + "ARCHITECTURE=x86_64", + "-e", + f"PROJECT_NAME={generated_project}", + "-e", + f"FUZZING_LANGUAGE={self.benchmark.language}", + "-v", + f"{outdir}:/out", + "-v", + f"{workdir}:/work", + ] + # Avoid permissions errors. + os.makedirs(outdir, exist_ok=True) + os.makedirs(workdir, exist_ok=True) + command.extend(["--entrypoint", "/bin/bash"]) + command.append(f"gcr.io/oss-fuzz/{generated_project}") + + pre_build_command = [] + post_build_command = [] + + # Cleanup mounted dirs. + pre_build_command.extend(["rm", "-rf", "/out/*", "/work/*", "&&"]) + + if self.benchmark.commit: + # TODO(metzman): Try to use build_specified_commit here. + for repo, commit in self.benchmark.commit.items(): + pre_build_command.extend( + [ + "git", + "-C", + repo, + "fetch", + "--unshallow", + "-f", + "||", + "true", + "&&", + ] + ) + pre_build_command.extend( + ["git", "-C", repo, "checkout", commit, "-f", "&&"] + ) + + post_build_command.extend(["&&", "chmod", "777", "-R", "/out/*"]) + + build_command = pre_build_command + ["compile"] + post_build_command + build_bash_command = ["-c", " ".join(build_command)] + command.extend(build_bash_command) + with open(log_path, "w+") as log_file: + try: + sp.run( + command, + cwd=oss_fuzz_checkout.OSS_FUZZ_DIR, + stdin=sp.DEVNULL, + stdout=log_file, + stderr=sp.STDOUT, + check=True, + ) + except sp.CalledProcessError: + logger.info( + "Failed to build fuzzer for %s with %s", + generated_project, + sanitizer, + ) + return False + + logger.info( + "Successfully build fuzzer for %s with %s", generated_project, sanitizer + ) + return True - # Exit, normally indicating the fuzz target exited in a controlled manner, - # blocking its bug discovery. - if symptom.endswith('fuzz target exited'): - return ParseResult( - cov_pcs, total_pcs, True, crash_info, artifact_name, - SemanticCheckResult(SemanticCheckResult.EXIT, symptom, crash_stacks, - crash_func)) + def _get_coverage_text_filename(self, project_name: str) -> str: + """Get the filename of the text coverage file. 
This is language + dependent.""" + lang_to_textcov_basename = { + "jvm": "jacoco.xml", + "python": "all_cov.json", + "c++": f"{self.benchmark.target_name}.covreport", + "c": f"{self.benchmark.target_name}.covreport", + "rust": f"{self.benchmark.target_name}.covreport", + } + + return os.path.join( + get_build_artifact_dir(project_name, "out"), + "textcov_reports", + lang_to_textcov_basename[self.benchmark.language], + ) + + def _extract_local_textcoverage_data(self, project_name: str) -> textcov.Textcov: + """Returns the textcoverage from a local coverage run.""" + local_textcov_location = self._get_coverage_text_filename(project_name) + language_modes = { + "jvm": "r", + "python": "r", + "c": "rb", + "c++": "rb", + "rust": "rb", + } + with open( + local_textcov_location, language_modes.get(self.benchmark.language, "rb") + ) as f: + if self.benchmark.language == "jvm": + new_textcov = textcov.Textcov.from_jvm_file(f) + elif self.benchmark.language == "python": + new_textcov = textcov.Textcov.from_python_file(f) + elif self.benchmark.language == "rust": + new_textcov = textcov.Textcov.from_rust_file(f) + else: + target_basename = os.path.basename(self.benchmark.target_path) + new_textcov = textcov.Textcov.from_file( + f, + ignore_function_patterns=[ + # Don't include other functions defined in the target code. + re.compile(r"^" + re.escape(target_basename) + ":") + ], + ) + return new_textcov + + def get_coverage_local( + self, generated_project: str, benchmark_target_name: str + ) -> tuple[Optional[textcov.Textcov], Any]: + """Builds the generate project with coverage sanitizer, runs OSS-Fuzz + coverage extraction and then returns the generated coverage reports, in + the form of the text coverage as well as the summary.json.""" + sample_id = os.path.splitext(benchmark_target_name)[0] + log_path = os.path.join(self.work_dirs.build_logs, f"{sample_id}-coverage.log") + logger.info("Building project for coverage") + built_coverage = self.build_target_local( + generated_project, log_path, sanitizer="coverage" + ) + if not built_coverage: + logger.info("Failed to make coverage build for %s", generated_project) + return None, None + + logger.info("Extracting coverage") + corpus_dir = self.work_dirs.corpus(benchmark_target_name) + command = [ + "python3", + "infra/helper.py", + "coverage", + "--corpus-dir", + corpus_dir, + "--fuzz-target", + self.benchmark.target_name, + "--no-serve", + "--port", + "", + generated_project, + ] + + try: + sp.run( + command, + capture_output=True, + cwd=oss_fuzz_checkout.OSS_FUZZ_DIR, + stdin=sp.DEVNULL, + check=True, + ) + except sp.CalledProcessError as e: + logger.info( + "Failed to generate coverage for %s:\n%s\n%s", + generated_project, + e.stdout, + e.stderr, + ) + return None, None + + # Get the local text coverage, which includes the specific lines + # exercised in the target project. + local_textcov = self._extract_local_textcoverage_data(generated_project) + + # Copy the code coverage to a folder in the results directory so + # the coverage can be displayed in the result HTML page. 
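
_get_coverage_text_filename and _extract_local_textcoverage_data above reduce to two per-language lookup tables: which report file to open and in which mode. A stripped-down sketch of that dispatch; the parser callables are hypothetical stand-ins for the textcov.Textcov.from_* constructors, and the .covreport basenames are hard-coded here rather than derived from the target name.

import os

# File name and open mode per language; values mirror the tables above.
_REPORT_BASENAME = {
    "jvm": "jacoco.xml",
    "python": "all_cov.json",
    "c": "target.covreport",    # The real code uses f"{target_name}.covreport".
    "c++": "target.covreport",
    "rust": "target.covreport",
}
_OPEN_MODE = {"jvm": "r", "python": "r", "c": "rb", "c++": "rb", "rust": "rb"}


def load_text_coverage(report_dir: str, language: str, parsers: dict):
    """Opens the right report file and hands it to a language-specific parser.

    |parsers| maps language -> callable(file_obj); these are hypothetical
    stand-ins for Textcov.from_jvm_file / from_python_file / from_rust_file /
    from_file.
    """
    path = os.path.join(report_dir, _REPORT_BASENAME[language])
    mode = _OPEN_MODE.get(language, "rb")
    with open(path, mode) as report:
        return parsers[language](report)
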
+ coverage_report = os.path.join( + get_build_artifact_dir(generated_project, "out"), "report" + ) + destination_coverage = self.work_dirs.code_coverage_report( + benchmark_target_name + ) + shutil.copytree(coverage_report, destination_coverage, dirs_exist_ok=True) + + textcov_dir = os.path.join( + get_build_artifact_dir(generated_project, "out"), "textcov_reports" + ) + dst_textcov = os.path.join( + self.work_dirs.code_coverage_report(benchmark_target_name), "textcov" + ) + shutil.copytree(textcov_dir, dst_textcov, dirs_exist_ok=True) + + coverage_summary = os.path.join( + get_build_artifact_dir(generated_project, "out"), + "report", + "linux", + "summary.json", + ) + with open(coverage_summary) as f: + coverage_summary = json.load(f) + + return local_textcov, coverage_summary - # Fuzz target modified constants. - if symptom.endswith('fuzz target overwrites its const input'): - return ParseResult( - cov_pcs, total_pcs, True, crash_info, artifact_name, - SemanticCheckResult(SemanticCheckResult.OVERWRITE_CONST, symptom, - crash_stacks, crash_func)) - - # OOM, normally indicating malloc's parameter is too large, e.g., because - # of using parameter `size`. - # TODO(dongge): Refine this, 1) Merge this with the other oom case found - # from reproducer name; 2) Capture the actual number in (malloc(\d+)). - if 'out-of-memory' in symptom or 'out of memory' in symptom: - return ParseResult( - cov_pcs, total_pcs, True, crash_info, artifact_name, - SemanticCheckResult(SemanticCheckResult.FP_OOM, symptom, - crash_stacks, crash_func)) - - # FP case 2: fuzz target crashes at init or first few rounds. - if lastround is None or lastround <= EARLY_FUZZING_ROUND_THRESHOLD: - # No cov line has been identified or only INITED round has been passed. - # This is very likely the false positive cases. - return ParseResult( - cov_pcs, total_pcs, True, crash_info, artifact_name, - SemanticCheckResult(SemanticCheckResult.FP_NEAR_INIT_CRASH, symptom, - crash_stacks, crash_func)) - - # FP case 3: no func in 1st thread stack belongs to testing proj. - if len(crash_stacks) > 0: - first_stack = crash_stacks[0] - for stack_frame in first_stack: - if self._stack_func_is_of_testing_project(stack_frame): - if 'LLVMFuzzerTestOneInput' in stack_frame: - return ParseResult( - cov_pcs, total_pcs, True, crash_info, artifact_name, - SemanticCheckResult(SemanticCheckResult.FP_TARGET_CRASH, - symptom, crash_stacks, crash_func)) - break - - return ParseResult( - cov_pcs, total_pcs, True, crash_info, artifact_name, - SemanticCheckResult(SemanticCheckResult.NO_SEMANTIC_ERR, symptom, - crash_stacks, crash_func)) - - if check_cov_increase and initcov == donecov and lastround is not None: - # Another error fuzz target case: no cov increase. - # A special case is initcov == donecov == None, which indicates no - # interesting inputs were found. This may happen if the target rejected - # all inputs we tried. 
- return ParseResult( - cov_pcs, total_pcs, False, '', '', - SemanticCheckResult(SemanticCheckResult.NO_COV_INCREASE)) - - return ParseResult(cov_pcs, total_pcs, crashes, '', '', - SemanticCheckResult(SemanticCheckResult.NO_SEMANTIC_ERR)) - - def _copy_crash_file(self, outdir: str, artifact_dir: str, - run_result: RunResult) -> None: - """Copies the first crash file to the artifact directory.""" - # Only consider testcases starting with 'crash-' - crash_files = [ - f for f in os.listdir(outdir) - if f.startswith('crash-') and os.path.isfile(os.path.join(outdir, f)) - ] - if len(crash_files) != 0: - crash_file = crash_files[0] - src = os.path.join(outdir, crash_file) - dst = os.path.join(artifact_dir, crash_file) - run_result.artifact_path = dst - shutil.copy2(src, dst) - logger.info('Copied crash file %s to %s', crash_file, artifact_dir) - - def build_and_run( - self, - generated_project: str, - target_path: str, - iteration: int, - language: str, - cloud_build_tags: Optional[list[str]] = None, - trial: int = 0, - ) -> tuple[BuildResult, Optional[RunResult]]: - """Builds and runs the fuzz target for fuzzing.""" - del cloud_build_tags - build_result = BuildResult() - - if not self._pre_build_check(target_path, build_result): - logger.warning('Pre-build check failure: %s', build_result) - return build_result, None - - try: - return self.build_and_run_local(generated_project, target_path, iteration, - build_result, language, trial) - except Exception as err: - logger.warning( - 'Error occurred when building and running fuzz target locally' - '(attempt %d) %s: %s', iteration, err, traceback.format_exc()) - raise err - - def build_and_run_local( - self, - generated_project: str, - target_path: str, - iteration: int, - build_result: BuildResult, - language: str, - trial: int = 0, - ) -> tuple[BuildResult, Optional[RunResult]]: - """Builds and runs the fuzz target locally for fuzzing.""" - project_name = self.benchmark.project - benchmark_target_name = os.path.basename(target_path) - project_target_name = os.path.basename(self.benchmark.target_path) - benchmark_log_path = self.work_dirs.build_logs_target( - benchmark_target_name, iteration, trial) - build_result.succeeded = self.build_target_local(generated_project, - benchmark_log_path) - if not build_result.succeeded: - errors = code_fixer.extract_error_message(benchmark_log_path, - project_target_name, language) - build_result.errors = errors - return build_result, None - - # TODO(Dongge): Split Builder and Runner: - # Make the rest lines in an independent function. - run_result = RunResult() - - run_log_path = os.path.join(self.work_dirs.run_logs, f'{trial:02d}.log') - self.run_target_local(generated_project, benchmark_target_name, - run_log_path) - artifact_dir = self.work_dirs.artifact(benchmark_target_name, iteration, - trial) - outdir = get_build_artifact_dir(generated_project, 'out') - self._copy_crash_file(outdir, artifact_dir, run_result) - - run_result.coverage, run_result.coverage_summary = (self.get_coverage_local( - generated_project, benchmark_target_name)) - - run_result.log_path = run_log_path - - # Parse libfuzzer logs to get fuzz target runtime details. - with open(run_log_path, 'rb') as f: - # In many case JVM/python projects won't have much cov - # difference in short running. Adding the flag for JVM/python - # projects to temporary skip the checking of coverage change. - # Also skipping for rust projects in initial implementation. 
- flag = not self.benchmark.language in ['jvm', 'python', 'rust'] - run_result.cov_pcs, run_result.total_pcs, \ - run_result.crashes, run_result.crash_info, \ - run_result.artifact_name, run_result.semantic_check = \ - self._parse_libfuzzer_logs(f, project_name, flag) - - return build_result, run_result - - def run_target_local(self, generated_project: str, benchmark_target_name: str, - log_path: str): - """Runs a target in the fixed target directory.""" - # If target name is not overridden, use the basename of the target path - # in the Dockerfile. - logger.info('Running %s', generated_project) - corpus_dir = self.work_dirs.corpus(benchmark_target_name) - command = [ - 'python3', 'infra/helper.py', 'run_fuzzer', '--corpus-dir', corpus_dir, - generated_project, self.benchmark.target_name, '--' - ] + self._libfuzzer_args() - - with open(log_path, 'w') as f: - proc = sp.Popen(command, - stdin=sp.DEVNULL, - stdout=f, - stderr=sp.STDOUT, - cwd=oss_fuzz_checkout.OSS_FUZZ_DIR) - - # TODO(ochang): Handle the timeout exception. - try: - proc.wait(timeout=self.run_timeout + 5) - except sp.TimeoutExpired: - logger.info('%s timed out during fuzzing.', generated_project) - # Try continuing and parsing the logs even in case of timeout. - - if proc.returncode != 0: - logger.info('********** Failed to run %s. **********', generated_project) - else: - logger.info('Successfully run %s.', generated_project) - - def build_target_local(self, - generated_project: str, - log_path: str, - sanitizer: str = 'address') -> bool: - """Builds a target with OSS-Fuzz.""" - - logger.info('Building %s with %s', generated_project, sanitizer) - - if oss_fuzz_checkout.ENABLE_CACHING and oss_fuzz_checkout.is_image_cached( - self.benchmark.project, sanitizer): - logger.info('We should use cached instance.') - # Rewrite for caching. - oss_fuzz_checkout.rewrite_project_to_cached_project( - self.benchmark.project, generated_project, sanitizer) - - # Prepare build - oss_fuzz_checkout.prepare_build(self.benchmark.project, sanitizer, - generated_project) - - else: - logger.info('The project does not have any cache') - - # Build the image - command = [ - 'docker', 'build', '-t', f'gcr.io/oss-fuzz/{generated_project}', - os.path.join(oss_fuzz_checkout.OSS_FUZZ_DIR, 'projects', - generated_project) - ] - with open(log_path, 'w+') as log_file: - try: - sp.run(command, - cwd=oss_fuzz_checkout.OSS_FUZZ_DIR, - stdin=sp.DEVNULL, - stdout=log_file, - stderr=sp.STDOUT, - check=True) - except sp.CalledProcessError as e: - logger.info('Failed to build image for %s: %s', generated_project, e) - return False - outdir = get_build_artifact_dir(generated_project, 'out') - workdir = get_build_artifact_dir(generated_project, 'work') - command = [ - 'docker', - 'run', - '--rm', - '--privileged', - '--shm-size=2g', - '--platform', - 'linux/amd64', - '-i', - '-e', - 'FUZZING_ENGINE=libfuzzer', - '-e', - f'SANITIZER={sanitizer}', - '-e', - 'ARCHITECTURE=x86_64', - '-e', - f'PROJECT_NAME={generated_project}', - '-e', - f'FUZZING_LANGUAGE={self.benchmark.language}', - '-v', - f'{outdir}:/out', - '-v', - f'{workdir}:/work', - ] - # Avoid permissions errors. - os.makedirs(outdir, exist_ok=True) - os.makedirs(workdir, exist_ok=True) - command.extend(['--entrypoint', '/bin/bash']) - command.append(f'gcr.io/oss-fuzz/{generated_project}') - - pre_build_command = [] - post_build_command = [] - - # Cleanup mounted dirs. 
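
build_target_local (removed form above, reformatted form earlier in this diff) chains cleanup, optional git checkouts of pinned commits, `compile`, and a permissions fix into one `bash -c` string joined by "&&". A small sketch of that command assembly; the repository path and commit below are hypothetical.

def build_compile_command(commits: dict) -> list:
    """Returns the argv tail passed to `bash -c` inside the build container."""
    pre = ["rm", "-rf", "/out/*", "/work/*", "&&"]
    for repo, commit in commits.items():
        # `|| true` keeps the build going when the repo is already unshallowed.
        pre += ["git", "-C", repo, "fetch", "--unshallow", "-f", "||", "true", "&&"]
        pre += ["git", "-C", repo, "checkout", commit, "-f", "&&"]
    post = ["&&", "chmod", "777", "-R", "/out/*"]
    return ["-c", " ".join(pre + ["compile"] + post)]


if __name__ == "__main__":
    # Hypothetical pinned commit, for illustration only.
    print(build_compile_command({"/src/libxml2": "deadbeef"}))
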
- pre_build_command.extend(['rm', '-rf', '/out/*', '/work/*', '&&']) - - if self.benchmark.commit: - # TODO(metzman): Try to use build_specified_commit here. - for repo, commit in self.benchmark.commit.items(): - pre_build_command.extend([ - 'git', '-C', repo, 'fetch', '--unshallow', '-f', '||', 'true', '&&' - ]) - pre_build_command.extend( - ['git', '-C', repo, 'checkout', commit, '-f', '&&']) - - post_build_command.extend(['&&', 'chmod', '777', '-R', '/out/*']) - - build_command = pre_build_command + ['compile'] + post_build_command - build_bash_command = ['-c', ' '.join(build_command)] - command.extend(build_bash_command) - with open(log_path, 'w+') as log_file: - try: - sp.run(command, - cwd=oss_fuzz_checkout.OSS_FUZZ_DIR, - stdin=sp.DEVNULL, - stdout=log_file, - stderr=sp.STDOUT, - check=True) - except sp.CalledProcessError: - logger.info('Failed to build fuzzer for %s with %s', generated_project, - sanitizer) +class CloudBuilderRunner(BuilderRunner): + """Cloud BuilderRunner.""" + + def __init__(self, *args, experiment_name: str, experiment_bucket: str, **kwargs): + self.experiment_name = experiment_name + self.experiment_bucket = experiment_bucket + super().__init__(*args, **kwargs) + + @staticmethod + def _run_with_retry_control(target_path: str, *args, **kwargs) -> bool: + """sp.run() with controllable retry and customized exponential backoff.""" + # List of (error_str, exp_backoff_func). + retryable_errors = [ + # As mentioned in pr #100. + ("RESOURCE_EXHAUSTED", lambda x: 5 * 2**x + random.randint(50, 90)), + # As mentioned in pr #151. + ( + "BrokenPipeError: [Errno 32] Broken pipe", + lambda x: 5 * 2**x + random.randint(1, 5), + ), + # Service Unavailable. + ("Service Unavailable", lambda x: 5 * 2**x + random.randint(1, 5)), + # Temp workaround for issue #12. + ( + "You do not currently have an active account selected", + lambda x: 5 * 2**x, + ), + # Workaround for issue #85. + ("gcloud crashed (OSError): unexpected end of data", lambda x: 5 * 2**x), + ] + + for attempt_id in range(1, CLOUD_EXP_MAX_ATTEMPT + 1): + try: + sp.run(*args, check=True, **kwargs) + return True + except sp.CalledProcessError as e: + # Replace \n for single log entry on cloud. + stdout = e.stdout.decode("utf-8").replace("\n", "\t") + stderr = e.stderr.decode("utf-8").replace("\n", "\t") + + delay = next( + ( + delay_f(attempt_id) + for err, delay_f in retryable_errors + if err in stdout + stderr + ), + 0, + ) + + if not delay or attempt_id == CLOUD_EXP_MAX_ATTEMPT: + logger.error( + "Failed to evaluate %s on cloud, attempt %d:\n%s\n%s", + os.path.realpath(target_path), + attempt_id, + stdout, + stderr, + ) + break + + logger.warning( + "Failed to evaluate %s on cloud, attempt %d, retry in %ds:\n" + "%s\n%s", + os.path.realpath(target_path), + attempt_id, + delay, + stdout, + stderr, + ) + time.sleep(delay) + logger.info("Evaluate %s on cloud.", os.path.realpath(target_path)) + return False - logger.info('Successfully build fuzzer for %s with %s', generated_project, - sanitizer) - return True - - def _get_coverage_text_filename(self, project_name: str) -> str: - """Get the filename of the text coverage file. 
This is language - dependent.""" - lang_to_textcov_basename = { - 'jvm': 'jacoco.xml', - 'python': 'all_cov.json', - 'c++': f'{self.benchmark.target_name}.covreport', - 'c': f'{self.benchmark.target_name}.covreport', - 'rust': f'{self.benchmark.target_name}.covreport', - } - - return os.path.join(get_build_artifact_dir(project_name, - 'out'), 'textcov_reports', - lang_to_textcov_basename[self.benchmark.language]) - - def _extract_local_textcoverage_data(self, - project_name: str) -> textcov.Textcov: - """Returns the textcoverage from a local coverage run.""" - local_textcov_location = self._get_coverage_text_filename(project_name) - language_modes = { - 'jvm': 'r', - 'python': 'r', - 'c': 'rb', - 'c++': 'rb', - 'rust': 'rb', - } - with open(local_textcov_location, - language_modes.get(self.benchmark.language, 'rb')) as f: - if self.benchmark.language == 'jvm': - new_textcov = textcov.Textcov.from_jvm_file(f) - elif self.benchmark.language == 'python': - new_textcov = textcov.Textcov.from_python_file(f) - elif self.benchmark.language == 'rust': - new_textcov = textcov.Textcov.from_rust_file(f) - else: - target_basename = os.path.basename(self.benchmark.target_path) - new_textcov = textcov.Textcov.from_file( - f, - ignore_function_patterns=[ - # Don't include other functions defined in the target code. - re.compile(r'^' + re.escape(target_basename) + ':') - ]) - return new_textcov - - def get_coverage_local( - self, generated_project: str, - benchmark_target_name: str) -> tuple[Optional[textcov.Textcov], Any]: - """Builds the generate project with coverage sanitizer, runs OSS-Fuzz - coverage extraction and then returns the generated coverage reports, in - the form of the text coverage as well as the summary.json.""" - sample_id = os.path.splitext(benchmark_target_name)[0] - log_path = os.path.join(self.work_dirs.build_logs, - f'{sample_id}-coverage.log') - logger.info('Building project for coverage') - built_coverage = self.build_target_local(generated_project, - log_path, - sanitizer='coverage') - if not built_coverage: - logger.info('Failed to make coverage build for %s', generated_project) - return None, None - - logger.info('Extracting coverage') - corpus_dir = self.work_dirs.corpus(benchmark_target_name) - command = [ - 'python3', - 'infra/helper.py', - 'coverage', - '--corpus-dir', - corpus_dir, - '--fuzz-target', - self.benchmark.target_name, - '--no-serve', - '--port', - '', - generated_project, - ] - - try: - sp.run(command, - capture_output=True, - cwd=oss_fuzz_checkout.OSS_FUZZ_DIR, - stdin=sp.DEVNULL, - check=True) - except sp.CalledProcessError as e: - logger.info('Failed to generate coverage for %s:\n%s\n%s', - generated_project, e.stdout, e.stderr) - return None, None - - # Get the local text coverage, which includes the specific lines - # exercised in the target project. - local_textcov = self._extract_local_textcoverage_data(generated_project) - - # Copy the code coverage to a folder in the results directory so - # the coverage can be displayed in the result HTML page. 
- coverage_report = os.path.join( - get_build_artifact_dir(generated_project, 'out'), 'report') - destination_coverage = self.work_dirs.code_coverage_report( - benchmark_target_name) - shutil.copytree(coverage_report, destination_coverage, dirs_exist_ok=True) - - textcov_dir = os.path.join(get_build_artifact_dir(generated_project, 'out'), - 'textcov_reports') - dst_textcov = os.path.join( - self.work_dirs.code_coverage_report(benchmark_target_name), 'textcov') - shutil.copytree(textcov_dir, dst_textcov, dirs_exist_ok=True) - - coverage_summary = os.path.join( - get_build_artifact_dir(generated_project, 'out'), 'report', 'linux', - 'summary.json') - with open(coverage_summary) as f: - coverage_summary = json.load(f) - - return local_textcov, coverage_summary + def build_and_run( + self, + generated_project: str, + target_path: str, + iteration: int, + language: str, + cloud_build_tags: Optional[list[str]] = None, + trial: int = 0, + ) -> tuple[BuildResult, Optional[RunResult]]: + """Builds and runs the fuzz target for fuzzing.""" + build_result = BuildResult() + if not self._pre_build_check(target_path, build_result): + logger.warning("Pre-build check failure: %s", build_result) + return build_result, None + + try: + return self.build_and_run_cloud( + generated_project, + target_path, + iteration, + build_result, + language, + cloud_build_tags, + trial, + ) + except Exception as err: + logger.warning( + "Error occurred when building and running fuzz target on cloud" + "(attempt %d) %s: %s", + iteration, + err, + traceback.format_exc(), + ) + traceback.print_exc() + raise err + + def build_and_run_cloud( + self, + generated_project: str, + target_path: str, + iteration: int, + build_result: BuildResult, + language: str, + cloud_build_tags: Optional[list[str]] = None, + trial: int = 0, + ) -> tuple[BuildResult, Optional[RunResult]]: + """Builds and runs the fuzz target locally for fuzzing.""" + logger.info("Evaluating %s on cloud.", os.path.realpath(target_path)) + + project_name = self.benchmark.project + + uid = self.experiment_name + str(uuid.uuid4()) + run_log_name = f"{uid}.run.log" + run_log_path = f"gs://{self.experiment_bucket}/{run_log_name}" + + build_log_name = f"{uid}.build.log" + build_log_path = f"gs://{self.experiment_bucket}/{build_log_name}" + + corpus_name = f"{uid}.corpus.zip" + corpus_path = f"gs://{self.experiment_bucket}/{corpus_name}" + + coverage_name = f"{uid}.coverage" + coverage_path = f"gs://{self.experiment_bucket}/{coverage_name}" + + reproducer_name = f"{uid}.reproducer" + reproducer_path = f"gs://{self.experiment_bucket}/{reproducer_name}" + + logger.info( + "Servie account key: %s", os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") + ) + command = [ + f"./{oss_fuzz_checkout.VENV_DIR}/bin/python3", + "infra/build/functions/target_experiment.py", + f"--project={generated_project}", + f"--target={self.benchmark.target_name}", + f"--upload_build_log={build_log_path}", + f"--upload_output_log={run_log_path}", + f"--upload_coverage={coverage_path}", + f"--upload_reproducer={reproducer_path}", + f"--upload_corpus={corpus_path}", + f"--experiment_name={self.experiment_name}", + f"--real_project={project_name}", + ] + + # TODO(dongge): Reenable caching when build script is not modified. + # Current caching is not applicable when OFG modifies the build script, + # There is no simple way to check if the build script has been modified, + # but this feature should be added later. + # and fails to build the project (particularly with coverage sanitizer). 
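
The cloud path wraps the target_experiment.py invocation in _run_with_retry_control, introduced earlier in this diff, which matches stdout/stderr against known transient errors and backs off exponentially with jitter. A generic sketch of that retry loop; the attempt limit and error table here are illustrative, not the project's exact values.

import random
import subprocess as sp
import time

MAX_ATTEMPTS = 5  # Assumed limit; the project uses CLOUD_EXP_MAX_ATTEMPT.

# (substring to look for, attempt -> delay in seconds)
RETRYABLE = [
    ("RESOURCE_EXHAUSTED", lambda n: 5 * 2**n + random.randint(50, 90)),
    ("Service Unavailable", lambda n: 5 * 2**n + random.randint(1, 5)),
]


def run_with_retry(command) -> bool:
    """Retries |command|, but only on known-transient errors."""
    for attempt in range(1, MAX_ATTEMPTS + 1):
        try:
            sp.run(command, check=True, capture_output=True)
            return True
        except sp.CalledProcessError as e:
            output = (e.stdout or b"").decode() + (e.stderr or b"").decode()
            delay = next(
                (delay_f(attempt) for err, delay_f in RETRYABLE if err in output), 0
            )
            if not delay or attempt == MAX_ATTEMPTS:
                return False  # Non-retryable error or out of attempts.
            time.sleep(delay)
    return False
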
+ # if oss_fuzz_checkout.ENABLE_CACHING and ( + # oss_fuzz_checkout.is_image_cached(project_name, 'address') and + # oss_fuzz_checkout.is_image_cached(project_name, 'coverage')): + # logger.info('Using cached image for %s', project_name) + # command.append('--use_cached_image') + + # # Overwrite the Dockerfile to be caching friendly + # # We hardcode 'address' here, but this is irrelevant and will be + # # overridden later via a Docker argument. + # oss_fuzz_checkout.rewrite_project_to_cached_project( + # project_name, generated_project, 'address') + # oss_fuzz_checkout.prepare_build(project_name, 'address', + # generated_project) + + if cloud_build_tags: + command += ["--tags"] + cloud_build_tags + command += ["--"] + self._libfuzzer_args() + + logger.info("Command: %s", command) + + if not self._run_with_retry_control( + os.path.realpath(target_path), command, cwd=oss_fuzz_checkout.OSS_FUZZ_DIR + ): + return build_result, None + + logger.info("Evaluated %s on cloud.", os.path.realpath(target_path)) + + storage_client = storage.Client() + bucket = storage_client.bucket(self.experiment_bucket) + + build_result.log_path = build_log_path + + generated_target_name = os.path.basename(target_path) + with open( + self.work_dirs.build_logs_target(generated_target_name, iteration, trial), + "wb", + ) as f: + blob = bucket.blob(build_log_name) + if blob.exists(): + logger.info( + "Downloading cloud build log of %s: %s to %s", + os.path.realpath(target_path), + build_log_name, + f, + ) + blob.download_to_file(f) + else: + logger.warning( + "Cannot find cloud build log of %s: %s", + os.path.realpath(target_path), + build_log_name, + ) + + # TODO(Dongge): Split Builder and Runner: + # Set build_result.succeeded based on existence of fuzz target binary. + # Separate the rest lines into an independent function. + run_log_path = os.path.join(self.work_dirs.run_logs, f"{trial:02d}.log") + with open(run_log_path, "wb") as f: + blob = bucket.blob(run_log_name) + if blob.exists(): + build_result.succeeded = True + logger.info( + "Downloading cloud run log of %s: %s to %s", + os.path.realpath(target_path), + run_log_name, + f, + ) + blob.download_to_file(f) + else: + logger.warning( + "Cannot find cloud run log of %s: %s", + os.path.realpath(target_path), + run_log_name, + ) + + if not build_result.succeeded: + errors = code_fixer.extract_error_message( + self.work_dirs.build_logs_target( + generated_target_name, iteration, trial + ), + os.path.basename(self.benchmark.target_path), + language, + ) + build_result.errors = errors + logger.info( + "Cloud evaluation of %s indicates a failure: %s", + os.path.realpath(target_path), + errors, + ) + return build_result, None + logger.info( + "Cloud evaluation of %s indicates a success.", os.path.realpath(target_path) + ) + + corpus_dir = self.work_dirs.corpus(generated_target_name) + with open(os.path.join(corpus_dir, "corpus.zip"), "wb") as f: + blob = bucket.blob(corpus_name) + if blob.exists(): + blob.download_to_file(f) + + run_result = RunResult( + corpus_path=corpus_path, + coverage_report_path=coverage_path, + reproducer_path=reproducer_path, + log_path=run_log_path, + ) + + blob = bucket.blob(f"{coverage_name}/report/linux/summary.json") + if blob.exists(): + # Download summary.json to our workdir. 
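
Each artifact pulled back from the experiment bucket (build log, run log, corpus, summary.json) follows the same pattern: build a blob handle, check blob.exists(), then download_to_file into a local file. A minimal sketch of that helper with google-cloud-storage; the bucket and object names are placeholders.

from google.cloud import storage


def download_if_exists(bucket_name: str, blob_name: str, local_path: str) -> bool:
    """Downloads gs://<bucket>/<blob> to |local_path| if the object exists."""
    client = storage.Client()
    blob = client.bucket(bucket_name).blob(blob_name)
    if not blob.exists():
        return False
    with open(local_path, "wb") as f:
        blob.download_to_file(f)
    return True


# Example with placeholder names: fetch a run log produced by the cloud build.
# download_if_exists("my-experiment-bucket", "abc123.run.log", "/tmp/run.log")
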
+ cov_summary_folder = os.path.join( + self.work_dirs.code_coverage_report(generated_target_name), + "report/linux/", + ) + os.makedirs(cov_summary_folder, exist_ok=True) + coverage_summary_file = os.path.join(cov_summary_folder, "summary.json") + with open(coverage_summary_file, "wb") as f: + blob.download_to_file(f) + + # Load the coverage summary + with open(coverage_summary_file, "r") as f: + run_result.coverage_summary = json.load(f) + target_basename = os.path.basename(self.benchmark.target_path) -class CloudBuilderRunner(BuilderRunner): - """Cloud BuilderRunner.""" - - def __init__(self, *args, experiment_name: str, experiment_bucket: str, - **kwargs): - self.experiment_name = experiment_name - self.experiment_bucket = experiment_bucket - super().__init__(*args, **kwargs) - - @staticmethod - def _run_with_retry_control(target_path: str, *args, **kwargs) -> bool: - """sp.run() with controllable retry and customized exponential backoff.""" - # List of (error_str, exp_backoff_func). - retryable_errors = [ - # As mentioned in pr #100. - ('RESOURCE_EXHAUSTED', lambda x: 5 * 2**x + random.randint(50, 90)), - # As mentioned in pr #151. - ('BrokenPipeError: [Errno 32] Broken pipe', - lambda x: 5 * 2**x + random.randint(1, 5)), - # Service Unavailable. - ('Service Unavailable', lambda x: 5 * 2**x + random.randint(1, 5)), - # Temp workaround for issue #12. - ('You do not currently have an active account selected', - lambda x: 5 * 2**x), - # Workaround for issue #85. - ('gcloud crashed (OSError): unexpected end of data', lambda x: 5 * 2**x - ), - ] - - for attempt_id in range(1, CLOUD_EXP_MAX_ATTEMPT + 1): - try: - sp.run(*args, check=True, **kwargs) - return True - except sp.CalledProcessError as e: - # Replace \n for single log entry on cloud. - stdout = e.stdout.decode('utf-8').replace('\n', '\t') - stderr = e.stderr.decode('utf-8').replace('\n', '\t') - - delay = next((delay_f(attempt_id) - for err, delay_f in retryable_errors - if err in stdout + stderr), 0) - - if not delay or attempt_id == CLOUD_EXP_MAX_ATTEMPT: - logger.error('Failed to evaluate %s on cloud, attempt %d:\n%s\n%s', - os.path.realpath(target_path), attempt_id, stdout, - stderr) - break - - logger.warning( - 'Failed to evaluate %s on cloud, attempt %d, retry in %ds:\n' - '%s\n%s', os.path.realpath(target_path), attempt_id, delay, stdout, - stderr) - time.sleep(delay) - logger.info('Evaluate %s on cloud.', os.path.realpath(target_path)) - - return False - - def build_and_run( - self, - generated_project: str, - target_path: str, - iteration: int, - language: str, - cloud_build_tags: Optional[list[str]] = None, - trial: int = 0, - ) -> tuple[BuildResult, Optional[RunResult]]: - """Builds and runs the fuzz target for fuzzing.""" - build_result = BuildResult() - if not self._pre_build_check(target_path, build_result): - logger.warning('Pre-build check failure: %s', build_result) - return build_result, None - - try: - return self.build_and_run_cloud(generated_project, target_path, iteration, - build_result, language, cloud_build_tags, - trial) - except Exception as err: - logger.warning( - 'Error occurred when building and running fuzz target on cloud' - '(attempt %d) %s: %s', iteration, err, traceback.format_exc()) - traceback.print_exc() - raise err - - def build_and_run_cloud( - self, - generated_project: str, - target_path: str, - iteration: int, - build_result: BuildResult, - language: str, - cloud_build_tags: Optional[list[str]] = None, - trial: int = 0, - ) -> tuple[BuildResult, Optional[RunResult]]: - """Builds and runs 
the fuzz target locally for fuzzing.""" - logger.info('Evaluating %s on cloud.', os.path.realpath(target_path)) - - project_name = self.benchmark.project - - uid = self.experiment_name + str(uuid.uuid4()) - run_log_name = f'{uid}.run.log' - run_log_path = f'gs://{self.experiment_bucket}/{run_log_name}' - - build_log_name = f'{uid}.build.log' - build_log_path = f'gs://{self.experiment_bucket}/{build_log_name}' - - corpus_name = f'{uid}.corpus.zip' - corpus_path = f'gs://{self.experiment_bucket}/{corpus_name}' - - coverage_name = f'{uid}.coverage' - coverage_path = f'gs://{self.experiment_bucket}/{coverage_name}' - - reproducer_name = f'{uid}.reproducer' - reproducer_path = f'gs://{self.experiment_bucket}/{reproducer_name}' - - logger.info('Servie account key: %s', - os.environ.get('GOOGLE_APPLICATION_CREDENTIALS')) - command = [ - f'./{oss_fuzz_checkout.VENV_DIR}/bin/python3', - 'infra/build/functions/target_experiment.py', - f'--project={generated_project}', - f'--target={self.benchmark.target_name}', - f'--upload_build_log={build_log_path}', - f'--upload_output_log={run_log_path}', - f'--upload_coverage={coverage_path}', - f'--upload_reproducer={reproducer_path}', - f'--upload_corpus={corpus_path}', - f'--experiment_name={self.experiment_name}', - f'--real_project={project_name}', - ] - - # TODO(dongge): Reenable caching when build script is not modified. - # Current caching is not applicable when OFG modifies the build script, - # There is no simple way to check if the build script has been modified, - # but this feature should be added later. - # and fails to build the project (particularly with coverage sanitizer). - # if oss_fuzz_checkout.ENABLE_CACHING and ( - # oss_fuzz_checkout.is_image_cached(project_name, 'address') and - # oss_fuzz_checkout.is_image_cached(project_name, 'coverage')): - # logger.info('Using cached image for %s', project_name) - # command.append('--use_cached_image') - - # # Overwrite the Dockerfile to be caching friendly - # # We hardcode 'address' here, but this is irrelevant and will be - # # overridden later via a Docker argument. - # oss_fuzz_checkout.rewrite_project_to_cached_project( - # project_name, generated_project, 'address') - # oss_fuzz_checkout.prepare_build(project_name, 'address', - # generated_project) - - if cloud_build_tags: - command += ['--tags'] + cloud_build_tags - command += ['--'] + self._libfuzzer_args() - - logger.info('Command: %s', command) - - if not self._run_with_retry_control(os.path.realpath(target_path), - command, - cwd=oss_fuzz_checkout.OSS_FUZZ_DIR): - return build_result, None - - logger.info('Evaluated %s on cloud.', os.path.realpath(target_path)) - - storage_client = storage.Client() - bucket = storage_client.bucket(self.experiment_bucket) - - build_result.log_path = build_log_path - - generated_target_name = os.path.basename(target_path) - with open( - self.work_dirs.build_logs_target(generated_target_name, iteration, - trial), 'wb') as f: - blob = bucket.blob(build_log_name) - if blob.exists(): - logger.info('Downloading cloud build log of %s: %s to %s', - os.path.realpath(target_path), build_log_name, f) - blob.download_to_file(f) - else: - logger.warning('Cannot find cloud build log of %s: %s', - os.path.realpath(target_path), build_log_name) - - # TODO(Dongge): Split Builder and Runner: - # Set build_result.succeeded based on existence of fuzz target binary. - # Separate the rest lines into an independent function. 
- run_log_path = os.path.join(self.work_dirs.run_logs, f'{trial:02d}.log') - with open(run_log_path, 'wb') as f: - blob = bucket.blob(run_log_name) - if blob.exists(): - build_result.succeeded = True - logger.info('Downloading cloud run log of %s: %s to %s', - os.path.realpath(target_path), run_log_name, f) - blob.download_to_file(f) - else: - logger.warning('Cannot find cloud run log of %s: %s', - os.path.realpath(target_path), run_log_name) - - if not build_result.succeeded: - errors = code_fixer.extract_error_message( - self.work_dirs.build_logs_target(generated_target_name, iteration, - trial), - os.path.basename(self.benchmark.target_path), language) - build_result.errors = errors - logger.info('Cloud evaluation of %s indicates a failure: %s', - os.path.realpath(target_path), errors) - return build_result, None - logger.info('Cloud evaluation of %s indicates a success.', - os.path.realpath(target_path)) - - corpus_dir = self.work_dirs.corpus(generated_target_name) - with open(os.path.join(corpus_dir, 'corpus.zip'), 'wb') as f: - blob = bucket.blob(corpus_name) - if blob.exists(): - blob.download_to_file(f) - - run_result = RunResult(corpus_path=corpus_path, - coverage_report_path=coverage_path, - reproducer_path=reproducer_path, - log_path=run_log_path) - - blob = bucket.blob(f'{coverage_name}/report/linux/summary.json') - if blob.exists(): - # Download summary.json to our workdir. - cov_summary_folder = os.path.join( - self.work_dirs.code_coverage_report(generated_target_name), - 'report/linux/') - os.makedirs(cov_summary_folder, exist_ok=True) - coverage_summary_file = os.path.join(cov_summary_folder, 'summary.json') - with open(coverage_summary_file, 'wb') as f: - blob.download_to_file(f) - - # Load the coverage summary - with open(coverage_summary_file, 'r') as f: - run_result.coverage_summary = json.load(f) - - target_basename = os.path.basename(self.benchmark.target_path) - - # Load coverage reports. - textcov_blob_path = self._get_cloud_textcov_path(coverage_name) - if self.benchmark.language == 'jvm': - blob = bucket.blob(textcov_blob_path) - if blob.exists(): - with blob.open() as f: - run_result.coverage = textcov.Textcov.from_jvm_file(f) - self._copy_textcov_to_workdir(bucket, textcov_blob_path, - generated_target_name) - elif self.benchmark.language == 'python': - blob = bucket.blob(textcov_blob_path) - if blob.exists(): - with blob.open() as f: - run_result.coverage = textcov.Textcov.from_python_file(f) - self._copy_textcov_to_workdir(bucket, textcov_blob_path, - generated_target_name) - elif self.benchmark.language == 'rust': - blob = bucket.blob(textcov_blob_path) - if blob.exists(): - with blob.open() as f: - run_result.coverage = textcov.Textcov.from_rust_file(f) - self._copy_textcov_to_workdir(bucket, textcov_blob_path, - generated_target_name) - else: - # C/C++ - blob = bucket.blob(textcov_blob_path) - if blob.exists(): - with blob.open('rb') as f: - run_result.coverage = textcov.Textcov.from_file( - f, - ignore_function_patterns=[ - # Don't include other functions defined in the target code. - re.compile(r'^' + re.escape(target_basename) + ':') - ]) - self._copy_textcov_to_workdir(bucket, textcov_blob_path, - generated_target_name) - - # Parse libfuzzer logs to get fuzz target runtime details. 
- with open(run_log_path, 'rb') as f: - run_result.cov_pcs, run_result.total_pcs, \ - run_result.crashes, run_result.crash_info, \ - run_result.artifact_name, run_result.semantic_check = \ - self._parse_libfuzzer_logs(f, project_name) - - artifact_dir = self.work_dirs.artifact(generated_target_name, iteration, - trial) - blobs = list(bucket.list_blobs(prefix=f'{reproducer_name}/artifacts/')) - if blobs: - blob = blobs[0] - artifact_path = os.path.join(artifact_dir, os.path.basename(blob.name)) - # TOOD: Some try-catch here. - blob.download_to_filename(artifact_path) - run_result.artifact_path = artifact_path - else: - logger.warning('Cloud evaluation of %s failed to downlod artifact:%s', - os.path.realpath(target_path), - f'{reproducer_name}/artifacts/') - - return build_result, run_result - - def _copy_textcov_to_workdir(self, bucket, textcov_blob_path: str, - generated_target_name: str) -> None: - """Stores a given textcov blob into the workdir.""" - blob = bucket.blob(textcov_blob_path) - textcov_dir = os.path.join( - self.work_dirs.code_coverage_report(generated_target_name), 'textcov') - os.makedirs(textcov_dir, exist_ok=True) - textcov_dst = os.path.join(textcov_dir, os.path.basename(textcov_blob_path)) - with open(textcov_dst, 'wb') as f: - blob.download_to_file(f) - - def _get_cloud_textcov_path(self, coverage_name: str) -> str: - """Extracts textcov blob path for this benchmark.""" - if self.benchmark.language == 'jvm': - return f'{coverage_name}/textcov_reports/jacoco.xml' - if self.benchmark.language == 'python': - return f'{coverage_name}/textcov_reports/all_cov.json' - - # For C/C++/Rust - return (f'{coverage_name}/textcov_reports/{self.benchmark.target_name}' - '.covreport') + # Load coverage reports. + textcov_blob_path = self._get_cloud_textcov_path(coverage_name) + if self.benchmark.language == "jvm": + blob = bucket.blob(textcov_blob_path) + if blob.exists(): + with blob.open() as f: + run_result.coverage = textcov.Textcov.from_jvm_file(f) + self._copy_textcov_to_workdir( + bucket, textcov_blob_path, generated_target_name + ) + elif self.benchmark.language == "python": + blob = bucket.blob(textcov_blob_path) + if blob.exists(): + with blob.open() as f: + run_result.coverage = textcov.Textcov.from_python_file(f) + self._copy_textcov_to_workdir( + bucket, textcov_blob_path, generated_target_name + ) + elif self.benchmark.language == "rust": + blob = bucket.blob(textcov_blob_path) + if blob.exists(): + with blob.open() as f: + run_result.coverage = textcov.Textcov.from_rust_file(f) + self._copy_textcov_to_workdir( + bucket, textcov_blob_path, generated_target_name + ) + else: + # C/C++ + blob = bucket.blob(textcov_blob_path) + if blob.exists(): + with blob.open("rb") as f: + run_result.coverage = textcov.Textcov.from_file( + f, + ignore_function_patterns=[ + # Don't include other functions defined in the target code. + re.compile(r"^" + re.escape(target_basename) + ":") + ], + ) + self._copy_textcov_to_workdir( + bucket, textcov_blob_path, generated_target_name + ) + + # Parse libfuzzer logs to get fuzz target runtime details. 
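
For C/C++ the text coverage is filtered with ignore_function_patterns so that functions defined in the fuzz target file itself do not count toward project coverage; the pattern is simply the escaped target basename followed by a colon. A tiny sketch of that filter, assuming function names are reported as "<file>:<function>".

import re


def make_target_filter(target_basename: str):
    """Returns a predicate that drops functions defined in the fuzz target."""
    pattern = re.compile(r"^" + re.escape(target_basename) + ":")
    return lambda function_name: not pattern.match(function_name)


if __name__ == "__main__":
    keep = make_target_filter("fuzz_parse.c")
    print(keep("fuzz_parse.c:LLVMFuzzerTestOneInput"))  # False: defined in target.
    print(keep("parser.c:parse_header"))                # True: project code.
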
+ with open(run_log_path, "rb") as f: + ( + run_result.cov_pcs, + run_result.total_pcs, + run_result.crashes, + run_result.crash_info, + run_result.artifact_name, + run_result.semantic_check, + ) = self._parse_libfuzzer_logs(f, project_name) + + artifact_dir = self.work_dirs.artifact(generated_target_name, iteration, trial) + blobs = list(bucket.list_blobs(prefix=f"{reproducer_name}/artifacts/")) + if blobs: + blob = blobs[0] + artifact_path = os.path.join(artifact_dir, os.path.basename(blob.name)) + # TOOD: Some try-catch here. + blob.download_to_filename(artifact_path) + run_result.artifact_path = artifact_path + else: + logger.warning( + "Cloud evaluation of %s failed to downlod artifact:%s", + os.path.realpath(target_path), + f"{reproducer_name}/artifacts/", + ) + + return build_result, run_result + + def _copy_textcov_to_workdir( + self, bucket, textcov_blob_path: str, generated_target_name: str + ) -> None: + """Stores a given textcov blob into the workdir.""" + blob = bucket.blob(textcov_blob_path) + textcov_dir = os.path.join( + self.work_dirs.code_coverage_report(generated_target_name), "textcov" + ) + os.makedirs(textcov_dir, exist_ok=True) + textcov_dst = os.path.join(textcov_dir, os.path.basename(textcov_blob_path)) + with open(textcov_dst, "wb") as f: + blob.download_to_file(f) + + def _get_cloud_textcov_path(self, coverage_name: str) -> str: + """Extracts textcov blob path for this benchmark.""" + if self.benchmark.language == "jvm": + return f"{coverage_name}/textcov_reports/jacoco.xml" + if self.benchmark.language == "python": + return f"{coverage_name}/textcov_reports/all_cov.json" + + # For C/C++/Rust + return ( + f"{coverage_name}/textcov_reports/{self.benchmark.target_name}" ".covreport" + ) def get_build_artifact_dir(generated_project: str, build_artifact: str) -> str: - """ - Returns the |build_artifact| absolute directory path for |generated_project|. - """ - return os.path.join(oss_fuzz_checkout.OSS_FUZZ_DIR, 'build', build_artifact, - generated_project) + """ + Returns the |build_artifact| absolute directory path for |generated_project|. + """ + return os.path.join( + oss_fuzz_checkout.OSS_FUZZ_DIR, "build", build_artifact, generated_project + ) diff --git a/experiment/evaluator.py b/experiment/evaluator.py index 6990279dfb..d6aa598697 100644 --- a/experiment/evaluator.py +++ b/experiment/evaluator.py @@ -34,604 +34,678 @@ logger = logging.getLogger(__name__) -LLM_FIX_LIMIT = int(os.getenv('LLM_FIX_LIMIT', '5')) -GENERATE_CORPUS = bool(os.getenv('LLM_GENERATE_CORPUS', '')) +LLM_FIX_LIMIT = int(os.getenv("LLM_FIX_LIMIT", "5")) +GENERATE_CORPUS = bool(os.getenv("LLM_GENERATE_CORPUS", "")) -OSS_FUZZ_COVERAGE_BUCKET = 'oss-fuzz-coverage' -OSS_FUZZ_INTROSPECTOR_BUCKET = 'oss-fuzz-introspector' +OSS_FUZZ_COVERAGE_BUCKET = "oss-fuzz-coverage" +OSS_FUZZ_INTROSPECTOR_BUCKET = "oss-fuzz-introspector" @dataclasses.dataclass class Result: - """Evaluation result.""" - finished: bool = False - compiles: bool = False - crashes: bool = False - coverage: float = 0.0 - line_coverage_diff: float = 0.0 - coverage_report_path: str = '' - reproducer_path: str = '' - # Grammatically correct but has false positive or no cov increase at all. - is_semantic_error: bool = False - semantic_error: str = '' - triage: str = '' - textcov_diff: textcov.Textcov = dataclasses.field( - default_factory=textcov.Textcov) - # Deprecated renamed fields. Keeping them for backward compatibility. 
- # TODO https://github.com/google/oss-fuzz-gen/issues/215 - is_driver_fuzz_err: bool = dataclasses.field(kw_only=True, default=False) - driver_fuzz_err: str = dataclasses.field(kw_only=True, default='') - compile_error: str = '' - compile_log: str = '' - - def __post_init__(self, *args, **kwargs): # pylint: disable=unused-argument - if self.is_driver_fuzz_err: - self.is_semantic_error = self.is_driver_fuzz_err - if self.driver_fuzz_err: - self.semantic_error = self.driver_fuzz_err - - def to_dict(self): - return dataclasses.asdict(self) + """Evaluation result.""" + + finished: bool = False + compiles: bool = False + crashes: bool = False + coverage: float = 0.0 + line_coverage_diff: float = 0.0 + coverage_report_path: str = "" + reproducer_path: str = "" + # Grammatically correct but has false positive or no cov increase at all. + is_semantic_error: bool = False + semantic_error: str = "" + triage: str = "" + textcov_diff: textcov.Textcov = dataclasses.field(default_factory=textcov.Textcov) + # Deprecated renamed fields. Keeping them for backward compatibility. + # TODO https://github.com/google/oss-fuzz-gen/issues/215 + is_driver_fuzz_err: bool = dataclasses.field(kw_only=True, default=False) + driver_fuzz_err: str = dataclasses.field(kw_only=True, default="") + compile_error: str = "" + compile_log: str = "" + + def __post_init__(self, *args, **kwargs): # pylint: disable=unused-argument + if self.is_driver_fuzz_err: + self.is_semantic_error = self.is_driver_fuzz_err + if self.driver_fuzz_err: + self.semantic_error = self.driver_fuzz_err + + def to_dict(self): + return dataclasses.asdict(self) def load_existing_textcov(project: str) -> textcov.Textcov: - """Loads existing textcovs.""" - storage_client = storage.Client.create_anonymous_client() - bucket = storage_client.bucket(OSS_FUZZ_COVERAGE_BUCKET) - blobs = storage_client.list_blobs(bucket, - prefix=f'{project}/textcov_reports/', - delimiter='/') - # Iterate through all blobs first to get the prefixes (i.e. "subdirectories"). - for blob in blobs: - continue - - if not blobs.prefixes: # type: ignore - # No existing coverage reports. - logger.info('No existing coverage report. Using empty.') - return textcov.Textcov() - - # Find the latest generated textcov date. - latest_dir = sorted(blobs.prefixes)[-1] # type: ignore - blobs = storage_client.list_blobs(bucket, prefix=latest_dir) - - # Download and merge them. - existing_textcov = textcov.Textcov() - for blob in blobs: - if not blob.name.endswith('.covreport'): - continue - - logger.info('Loading existing textcov from %s', blob.name) - with blob.open('rb') as f: - existing_textcov.merge(textcov.Textcov.from_file(f)) - - return existing_textcov + """Loads existing textcovs.""" + storage_client = storage.Client.create_anonymous_client() + bucket = storage_client.bucket(OSS_FUZZ_COVERAGE_BUCKET) + blobs = storage_client.list_blobs( + bucket, prefix=f"{project}/textcov_reports/", delimiter="/" + ) + # Iterate through all blobs first to get the prefixes (i.e. "subdirectories"). + for blob in blobs: + continue + + if not blobs.prefixes: # type: ignore + # No existing coverage reports. + logger.info("No existing coverage report. Using empty.") + return textcov.Textcov() + + # Find the latest generated textcov date. + latest_dir = sorted(blobs.prefixes)[-1] # type: ignore + blobs = storage_client.list_blobs(bucket, prefix=latest_dir) + + # Download and merge them. 
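
The reformatted Result dataclass above keeps the deprecated is_driver_fuzz_err/driver_fuzz_err names working by forwarding them to the new fields in __post_init__. A trimmed-down sketch of that backward-compatibility pattern; the Report class name is only for the example, and kw_only fields require Python 3.10+.

import dataclasses


@dataclasses.dataclass
class Report:
    is_semantic_error: bool = False
    semantic_error: str = ""
    # Deprecated aliases kept so old callers and serialized results still load.
    is_driver_fuzz_err: bool = dataclasses.field(kw_only=True, default=False)
    driver_fuzz_err: str = dataclasses.field(kw_only=True, default="")

    def __post_init__(self):
        # Forward old names to the new ones so downstream code sees one truth.
        if self.is_driver_fuzz_err:
            self.is_semantic_error = self.is_driver_fuzz_err
        if self.driver_fuzz_err:
            self.semantic_error = self.driver_fuzz_err


if __name__ == "__main__":
    r = Report(is_driver_fuzz_err=True, driver_fuzz_err="no coverage increase")
    print(r.is_semantic_error, r.semantic_error)  # True no coverage increase
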
+ existing_textcov = textcov.Textcov() + for blob in blobs: + if not blob.name.endswith(".covreport"): + continue + + logger.info("Loading existing textcov from %s", blob.name) + with blob.open("rb") as f: + existing_textcov.merge(textcov.Textcov.from_file(f)) + + return existing_textcov def load_existing_jvm_textcov(project: str) -> textcov.Textcov: - """Loads existing textcovs for JVM project.""" - storage_client = storage.Client.create_anonymous_client() - bucket = storage_client.bucket(OSS_FUZZ_COVERAGE_BUCKET) - blobs = storage_client.list_blobs(bucket, - prefix=f'{project}/reports/', - delimiter='/') - # Iterate through all blobs first to get the prefixes (i.e. "subdirectories"). - for blob in blobs: - continue - - if not blobs.prefixes: # type: ignore - # No existing coverage reports. - logger.info('No existing coverage report. Using empty.') - return textcov.Textcov() - - latest_dir = sorted(blobs.prefixes)[-1] # type: ignore - blob = bucket.blob(f'{latest_dir}linux/jacoco.xml') - logger.info('Loading existing jacoco.xml textcov from %s', blob.name) - with blob.open() as f: - return textcov.Textcov.from_jvm_file(f) + """Loads existing textcovs for JVM project.""" + storage_client = storage.Client.create_anonymous_client() + bucket = storage_client.bucket(OSS_FUZZ_COVERAGE_BUCKET) + blobs = storage_client.list_blobs( + bucket, prefix=f"{project}/reports/", delimiter="/" + ) + # Iterate through all blobs first to get the prefixes (i.e. "subdirectories"). + for blob in blobs: + continue + + if not blobs.prefixes: # type: ignore + # No existing coverage reports. + logger.info("No existing coverage report. Using empty.") + return textcov.Textcov() + + latest_dir = sorted(blobs.prefixes)[-1] # type: ignore + blob = bucket.blob(f"{latest_dir}linux/jacoco.xml") + logger.info("Loading existing jacoco.xml textcov from %s", blob.name) + with blob.open() as f: + return textcov.Textcov.from_jvm_file(f) def load_existing_python_textcov(project: str) -> textcov.Textcov: - """Loads existing textcovs for python project.""" - storage_client = storage.Client.create_anonymous_client() - bucket = storage_client.bucket(OSS_FUZZ_INTROSPECTOR_BUCKET) - blobs = storage_client.list_blobs(bucket, - prefix=f'{project}/inspector-report/', - delimiter='/') - # Iterate through all blobs first to get the prefixes (i.e. "subdirectories"). - for blob in blobs: - continue - - if not blobs.prefixes: # type: ignore - # No existing coverage reports. - logger.info('No existing coverage report. Using empty.') - return textcov.Textcov() - - latest_dir = sorted(blobs.prefixes)[-1] # type: ignore - blob = bucket.blob(f'{latest_dir}all_cov.json') - logger.info('Loading existing all_cov.json textcov from %s', blob.name) - with blob.open() as f: - return textcov.Textcov.from_python_file(f) + """Loads existing textcovs for python project.""" + storage_client = storage.Client.create_anonymous_client() + bucket = storage_client.bucket(OSS_FUZZ_INTROSPECTOR_BUCKET) + blobs = storage_client.list_blobs( + bucket, prefix=f"{project}/inspector-report/", delimiter="/" + ) + # Iterate through all blobs first to get the prefixes (i.e. "subdirectories"). + for blob in blobs: + continue + if not blobs.prefixes: # type: ignore + # No existing coverage reports. + logger.info("No existing coverage report. 
Using empty.") + return textcov.Textcov() -def load_existing_rust_textcov(project: str) -> textcov.Textcov: - """Loads existing textcovs for rust project.""" - storage_client = storage.Client.create_anonymous_client() - bucket = storage_client.bucket(OSS_FUZZ_INTROSPECTOR_BUCKET) - blobs = storage_client.list_blobs(bucket, - prefix=f'{project}/inspector-report/', - delimiter='/') - # Iterate through all blobs first to get the prefixes (i.e. "subdirectories"). - for blob in blobs: - continue - - if not blobs.prefixes: # type: ignore - # No existing coverage reports. - logger.info('No existing coverage report. Using empty.') - return textcov.Textcov() - - # Find the latest generated textcov date. - latest_dir = sorted(blobs.prefixes)[-1] # type: ignore - blobs = storage_client.list_blobs(bucket, prefix=latest_dir) - - # Download and merge them. - existing_textcov = textcov.Textcov() - for blob in blobs: - if not blob.name.endswith('.covreport'): - continue - - logger.info('Loading existing textcov from %s', blob.name) - with blob.open('rb') as f: - existing_textcov.merge(textcov.Textcov.from_rust_file(f)) - - return existing_textcov + latest_dir = sorted(blobs.prefixes)[-1] # type: ignore + blob = bucket.blob(f"{latest_dir}all_cov.json") + logger.info("Loading existing all_cov.json textcov from %s", blob.name) + with blob.open() as f: + return textcov.Textcov.from_python_file(f) -def load_existing_coverage_summary(project: str) -> dict: - """Load existing summary.json.""" - storage_client = storage.Client.create_anonymous_client() - bucket = storage_client.bucket(OSS_FUZZ_COVERAGE_BUCKET) - blobs = storage_client.list_blobs(bucket, - prefix=f'{project}/reports/', - delimiter='/') - # Iterate through all blobs first to get the prefixes (i.e. "subdirectories"). - for blob in blobs: - continue - - if not blobs.prefixes: # type: ignore - # No existing coverage reports. - logger.info('No existing coverage reports, using empty one.') - return {} - - latest_dir = sorted(blobs.prefixes)[-1] # type: ignore - blob = bucket.blob(f'{latest_dir}linux/summary.json') - logger.info('Loading existing summary.json from %s', blob.name) - with blob.open() as f: - return json.load(f) - - -def compute_total_lines_without_fuzz_targets(coverage_summary: dict, - fuzz_target_base_name: str) -> int: - """Counts the total number of lines excluding the fuzz target.""" - # TODO(dongge): Exclude all fuzz targets if there are multiple. - return sum([ - f['summary']['lines']['count'] - for f in coverage_summary['data'][0]['files'] - if fuzz_target_base_name not in f['filename'] - ]) +def load_existing_rust_textcov(project: str) -> textcov.Textcov: + """Loads existing textcovs for rust project.""" + storage_client = storage.Client.create_anonymous_client() + bucket = storage_client.bucket(OSS_FUZZ_INTROSPECTOR_BUCKET) + blobs = storage_client.list_blobs( + bucket, prefix=f"{project}/inspector-report/", delimiter="/" + ) + # Iterate through all blobs first to get the prefixes (i.e. "subdirectories"). + for blob in blobs: + continue + if not blobs.prefixes: # type: ignore + # No existing coverage reports. + logger.info("No existing coverage report. Using empty.") + return textcov.Textcov() -# TODO(Dongge): Make this universally available. -class _Logger: - """Log evaluation progress.""" + # Find the latest generated textcov date. 
+ latest_dir = sorted(blobs.prefixes)[-1] # type: ignore + blobs = storage_client.list_blobs(bucket, prefix=latest_dir) - def __init__( - self, - status_path: str, - ): - self._log = open(os.path.join(status_path, 'log.txt'), 'w') - self._result_path = os.path.join(status_path, 'result.json') + # Download and merge them. + existing_textcov = textcov.Textcov() + for blob in blobs: + if not blob.name.endswith(".covreport"): + continue - def log(self, *args, **kwargs): - logger.info(*args, *kwargs) - print(*args, *kwargs, file=self._log) - self._log.flush() + logger.info("Loading existing textcov from %s", blob.name) + with blob.open("rb") as f: + existing_textcov.merge(textcov.Textcov.from_rust_file(f)) - def return_result(self, result: Result): - with open(self._result_path, 'w') as f: - json.dump(result.to_dict(), f) + return existing_textcov - return result + +def load_existing_coverage_summary(project: str) -> dict: + """Load existing summary.json.""" + storage_client = storage.Client.create_anonymous_client() + bucket = storage_client.bucket(OSS_FUZZ_COVERAGE_BUCKET) + blobs = storage_client.list_blobs( + bucket, prefix=f"{project}/reports/", delimiter="/" + ) + # Iterate through all blobs first to get the prefixes (i.e. "subdirectories"). + for blob in blobs: + continue + + if not blobs.prefixes: # type: ignore + # No existing coverage reports. + logger.info("No existing coverage reports, using empty one.") + return {} + + latest_dir = sorted(blobs.prefixes)[-1] # type: ignore + blob = bucket.blob(f"{latest_dir}linux/summary.json") + logger.info("Loading existing summary.json from %s", blob.name) + with blob.open() as f: + return json.load(f) + + +def compute_total_lines_without_fuzz_targets( + coverage_summary: dict, fuzz_target_base_name: str +) -> int: + """Counts the total number of lines excluding the fuzz target.""" + # TODO(dongge): Exclude all fuzz targets if there are multiple. + return sum( + [ + f["summary"]["lines"]["count"] + for f in coverage_summary["data"][0]["files"] + if fuzz_target_base_name not in f["filename"] + ] + ) -class Evaluator: - """Target evaluator.""" - - def __init__(self, runner: builder_runner.BuilderRunner, benchmark: Benchmark, - work_dirs: WorkDirs): - self.builder_runner = runner - self.benchmark = benchmark - self.work_dirs = work_dirs - - def build_log_path(self, generated_target_name: str, iteration: int, - trial: int): - return os.path.join( - self.work_dirs.run_logs, - f'{generated_target_name}-F{iteration}-{trial:02d}.log') - - def run_log_path(self, generated_target_name: str, trial: int): - return os.path.join(self.work_dirs.run_logs, - f'{generated_target_name}-{trial:02d}.log') - - @staticmethod - def create_ossfuzz_project(benchmark: Benchmark, - name: str, - target_file: str, - build_script_path: str = '') -> str: - """Creates an OSS-Fuzz project with the generated target. 
The new project - will replicate an existing project to |name| but replace its fuzz target - and build script with the new |target_file| and |build_script_path|.""" - logger.info('target file: %s', target_file) - generated_project_path = oss_fuzz_checkout.create_ossfuzz_project( - benchmark, name) - - # Copy generated fuzzers to generated_project_path - shutil.copyfile( - target_file, - os.path.join(generated_project_path, os.path.basename(target_file))) - - # Add additional statement in dockerfile to overwrite with generated fuzzer - with open(os.path.join(generated_project_path, 'Dockerfile'), 'a') as f: - f.write(f'\nCOPY {os.path.basename(target_file)} ' - f'{benchmark.target_path}\n') - - if not build_script_path or os.path.getsize(build_script_path) == 0: - return name - - # Copy generated build script to generated_project_path - shutil.copyfile( - build_script_path, - os.path.join(generated_project_path, - os.path.basename('agent-build.sh'))) - - # Add additional statement in dockerfile to overwrite with generated - # build script - with open(os.path.join(generated_project_path, 'Dockerfile'), 'a') as f: - f.write('\nRUN cp /src/build.sh /src/build.bk.sh\n') - with open(os.path.join(generated_project_path, 'Dockerfile'), 'a') as f: - f.write('\nCOPY agent-build.sh /src/build.sh\n') - - return name - - @staticmethod - def create_ossfuzz_project_with_lldb(benchmark: Benchmark, - name: str, - target_file: str, - run_result: results.RunResult, - build_script_path: str = '', - artifact_path: str = '') -> str: - """Creates an OSS-Fuzz project with the generated target and new dockerfile. - The new project will replicate an existing project |name| but replace its - fuzz target and build script with the new |target_file| and - |build_script_path| and modify its dockerfile.""" - Evaluator.create_ossfuzz_project(benchmark, name, target_file, - build_script_path) - generated_project_path = os.path.join(oss_fuzz_checkout.OSS_FUZZ_DIR, - 'projects', name) - - shutil.copyfile( - artifact_path, - os.path.join(generated_project_path, os.path.basename(artifact_path))) - # Add additional statement in dockerfile to copy testcase, - # enable -g, install lldb and screen - with open(os.path.join(generated_project_path, 'Dockerfile'), 'a') as f: - f.write( - '\nRUN mkdir -p /artifact\n' - f'\nCOPY {os.path.basename(run_result.artifact_path)} /artifact/\n' - '\nENV CFLAGS="${CFLAGS} -g -O0"\n' - '\nENV CXXFLAGS="${CXXFLAGS} -g -O0"\n' - '\nRUN apt-get update\n' - '\nRUN apt-get install -y lldb\n' - '\nRUN apt-get install -y screen\n') - - return name - - def _fix_generated_fuzz_target(self, ai_binary: str, - generated_oss_fuzz_project: str, - target_path: str, iteration: int, - build_result: BuildResult, - run_result: Optional[RunResult], - dual_logger: _Logger, language: str): - """Fixes the generated fuzz target.""" - error_desc, errors = '', [] - if build_result.succeeded: - if language != 'jvm': - if run_result: - error_desc, errors = run_result.semantic_check.get_error_info() - else: - dual_logger.log(f'Warning: Build succeed but no run_result in ' - f'{generated_oss_fuzz_project}.') - else: - error_desc, errors = None, build_result.errors - - code_fixer.llm_fix(ai_binary, target_path, self.benchmark, iteration, - error_desc, errors, self.builder_runner.fixer_model_name, - language) - shutil.copyfile( - target_path, - os.path.join(oss_fuzz_checkout.OSS_FUZZ_DIR, 'projects', - generated_oss_fuzz_project, os.path.basename(target_path))) - - def triage_crash( - self, - ai_binary: str, - 
generated_oss_fuzz_project: str, - driver_path: str, - run_result: RunResult, - dual_logger: _Logger, - ) -> str: - """Triages the crash.""" - if run_result.crash_info: - crash_info = run_result.crash_info - crash_func = run_result.semantic_check.crash_func - return crash_triager.llm_triage( - ai_binary, - driver_path, - self.benchmark, - crash_info, - crash_func, - self.builder_runner.fixer_model_name, - ) - - dual_logger.log(f'Warning: no crash info in {generated_oss_fuzz_project}.') - return TriageResult.NOT_APPLICABLE - - def extend_build_with_corpus(self, ai_binary, target_path, - generated_oss_fuzz_project): - """Extends an OSS-Fuzz project with corpus generated programmatically.""" - generated_project_path = os.path.join(oss_fuzz_checkout.OSS_FUZZ_DIR, - 'projects', - generated_oss_fuzz_project) - generated_corp = corpus_generator.get_script( - ai_binary, self.builder_runner.fixer_model_name, target_path, - self.benchmark) - - corpus_generator_path = os.path.join(generated_project_path, 'corp_gen.py') - with open(corpus_generator_path, 'w') as f: - f.write(generated_corp) - - with open(os.path.join(generated_project_path, 'Dockerfile'), 'a') as f: - f.write('COPY corp_gen.py $SRC/corp_gen.py\n') - target_harness_file = os.path.basename(self.benchmark.target_path) - target_harness_file = os.path.splitext(target_harness_file)[0] - corpus_dst = '/src/generated-corpus/*' - with open(os.path.join(generated_project_path, 'build.sh'), 'a') as f: - f.write('\n# Generate a corpus for the modified harness.') - f.write('\nmkdir -p /src/generated-corpus') - f.write('\npushd /src/generated-corpus') - f.write('\npython3 $SRC/corp_gen.py') - f.write('\npopd') - f.write(f'\nzip $OUT/{target_harness_file}_seed_corpus.zip {corpus_dst}') - - def check_target(self, ai_binary, target_path: str) -> Result: - """Builds and runs a target.""" - generated_target_name = os.path.basename(target_path) - sample_id = os.path.splitext(generated_target_name)[0] - generated_oss_fuzz_project = f'{self.benchmark.id}-{sample_id}' - generated_oss_fuzz_project = oss_fuzz_checkout.rectify_docker_tag( - generated_oss_fuzz_project) - Evaluator.create_ossfuzz_project(self.benchmark, generated_oss_fuzz_project, - target_path) - - status_path = os.path.join(self.work_dirs.status, sample_id) - os.makedirs(status_path, exist_ok=True) - - dual_logger = _Logger(status_path) - - # Try building and running the new target. - - # TODO: Log build failure. - # TODO: Log run success/failure. - - if GENERATE_CORPUS: - self.extend_build_with_corpus(ai_binary, target_path, - generated_oss_fuzz_project) - - # Loop of evaluating and fixing fuzz target. - llm_fix_count = 0 - while True: - # 1. Evaluating generated driver. - try: - build_result, run_result = self.builder_runner.build_and_run( - generated_oss_fuzz_project, target_path, llm_fix_count, - self.benchmark.language) - except Exception as e: - dual_logger.log( - 'Exception occurred when building and running fuzz target ' - f'in attempt {llm_fix_count}: {e}') - build_result = BuildResult() - run_result = None - - # 2. Calculate coverage percentage and coverage diff - coverage_summary = None - total_lines = 0 - coverage_percent = 0.0 - coverage_diff = 0.0 - if run_result: - # Gets line coverage (diff) details. - coverage_summary = self._load_existing_coverage_summary() - - if self.benchmark.language in ['python', 'jvm'] and run_result.coverage: - # The Jacoco.xml coverage report used to generate summary.json on - # OSS-Fuzz for JVM projects does not trace the source file location. 
- # Thus the conversion may miss some classes because they are not - # present during coverage report generation. This fix gets the total - # line calculation from the jacoco.xml report of the current run - # directly and compares it with the total_lines retrieved from - # summary.json. Then the larger total_lines is used which is assumed - # to be more accurate. This is the same case for python project which - # the total line is determined from the all_cov.json file. - total_lines = run_result.coverage.total_lines - elif coverage_summary: - total_lines = compute_total_lines_without_fuzz_targets( - coverage_summary, generated_target_name) - else: - total_lines = 0 +# TODO(Dongge): Make this universally available. +class _Logger: + """Log evaluation progress.""" - if run_result.total_pcs: - coverage_percent = run_result.cov_pcs / run_result.total_pcs - else: - dual_logger.log( - f'Warning: total_pcs == 0 in {generated_oss_fuzz_project}.') - coverage_percent = 0.0 + def __init__( + self, + status_path: str, + ): + self._log = open(os.path.join(status_path, "log.txt"), "w") + self._result_path = os.path.join(status_path, "result.json") - existing_textcov = self.load_existing_textcov() - if run_result.coverage: - run_result.coverage.subtract_covered_lines(existing_textcov) + def log(self, *args, **kwargs): + logger.info(*args, *kwargs) + print(*args, *kwargs, file=self._log) + self._log.flush() - if total_lines and run_result.coverage: - coverage_diff = run_result.coverage.covered_lines / total_lines - else: - dual_logger.log( - f'Warning: total_lines == 0 in {generated_oss_fuzz_project}.') - coverage_diff = 0.0 - - if self.benchmark.language == 'jvm': - # For JVM, the generation is consider success if either is true - # 1) Build success and run crashed (expected for exceptions) - # 2) Build success, run success and coverage diff > 0 - gen_succ = build_result.succeeded and run_result - if gen_succ and run_result and run_result.succeeded: - gen_succ = gen_succ and (coverage_diff > 0) - else: - # Should not concern run_result.succeeded for generation otherwise - # it may make a good fuzz target bad. - # Should concern run_result.succeeded for analyzes to know semantic - # errors - gen_succ = build_result.succeeded - - if gen_succ or llm_fix_count >= LLM_FIX_LIMIT: - # Exit cond 1: successfully generate the fuzz target. - # Exit cond 2: fix limit is reached. - break - - # 2. Fixing generated driver - llm_fix_count += 1 - dual_logger.log(f'Fixing {target_path} with ' - f'{self.builder_runner.fixer_model_name}, ' - f'attempt {llm_fix_count}.') - try: - self._fix_generated_fuzz_target(ai_binary, generated_oss_fuzz_project, - target_path, llm_fix_count, - build_result, run_result, dual_logger, - self.benchmark.language) - except Exception as e: - dual_logger.log('Exception occurred when fixing fuzz target in attempt ' - f'{llm_fix_count}: {e}') - break - - # Logs and returns the result. 
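For reference, the two coverage metrics that check_target derives from a run reduce to simple ratios; the numbers below are invented purely to illustrate the formulas:

# Fraction of instrumented program counters the fuzz run reached.
cov_pcs, total_pcs = 120, 400
coverage_percent = cov_pcs / total_pcs  # 0.30

# New lines hit after subtracting the project's existing coverage, divided by
# the project's total line count excluding the fuzz target itself.
covered_lines, total_lines = 75, 1500
coverage_diff = covered_lines / total_lines  # 0.05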
- if not build_result.succeeded: - dual_logger.log(f'Failed to build {target_path} with ' - f'{self.builder_runner.fixer_model_name} in ' - f'{llm_fix_count} iterations of fixing.') - return dual_logger.return_result( - Result(False, - False, - False, - 0.0, - 0.0, - '', - '', - False, - SemanticCheckResult.NOT_APPLICABLE, - TriageResult.NOT_APPLICABLE, - compile_error=build_result.log_path, - compile_log=build_result.log_path)) - - dual_logger.log(f'Successfully built {target_path} with ' - f'{self.builder_runner.fixer_model_name} in ' - f'{llm_fix_count} iterations of fixing.') - - if not run_result: - dual_logger.log( - f'Warning: no run result in {generated_oss_fuzz_project}.') - return dual_logger.return_result( - Result(False, - True, - False, - 0.0, - 0.0, - '', - '', - False, - SemanticCheckResult.NOT_APPLICABLE, - TriageResult.NOT_APPLICABLE, - compile_error=build_result.log_path, - compile_log=build_result.log_path)) - - # Triage the crash with LLM - dual_logger.log(f'Triaging the crash related to {target_path} with ' - f'{self.builder_runner.fixer_model_name}.') - run_result.triage = self.triage_crash( - ai_binary, - generated_oss_fuzz_project, - target_path, - run_result, - dual_logger, - ) + def return_result(self, result: Result): + with open(self._result_path, "w") as f: + json.dump(result.to_dict(), f) - if run_result.coverage_summary is None or run_result.coverage is None: - dual_logger.log( - f'Warning: No cov info in run result of {generated_oss_fuzz_project}.' - ) - return dual_logger.return_result( - Result(False, - True, - run_result.crashes, - 0.0, - 0.0, - '', - '', - not run_result.succeeded, - run_result.semantic_check.type, - run_result.triage, - compile_error=build_result.log_path, - compile_log=build_result.log_path)) - - dual_logger.log( - f'Result for {generated_oss_fuzz_project}: ' - f'crashes={run_result.crashes}, coverage={coverage_percent} ' - f'({run_result.cov_pcs}/{run_result.total_pcs}), ' - f'coverage diff={coverage_diff} ' - f'({run_result.coverage.covered_lines}/{total_lines})') - return dual_logger.return_result( - Result(False, - True, - run_result.crashes, - coverage_percent, - coverage_diff, - run_result.coverage_report_path, - run_result.reproducer_path, - not run_result.succeeded, - run_result.semantic_check.type, - run_result.triage, - run_result.coverage, - compile_error=build_result.log_path, - compile_log=build_result.log_path)) - - def _load_existing_coverage_summary(self) -> dict: - """Load existing summary.json.""" - return load_existing_coverage_summary(self.benchmark.project) + return result - def load_existing_textcov(self) -> textcov.Textcov: - """Loads existing textcovs.""" - if self.benchmark.language == 'jvm': - return load_existing_jvm_textcov(self.benchmark.project) - if self.benchmark.language == 'python': - return load_existing_python_textcov(self.benchmark.project) +class Evaluator: + """Target evaluator.""" + + def __init__( + self, + runner: builder_runner.BuilderRunner, + benchmark: Benchmark, + work_dirs: WorkDirs, + ): + self.builder_runner = runner + self.benchmark = benchmark + self.work_dirs = work_dirs + + def build_log_path(self, generated_target_name: str, iteration: int, trial: int): + return os.path.join( + self.work_dirs.run_logs, + f"{generated_target_name}-F{iteration}-{trial:02d}.log", + ) + + def run_log_path(self, generated_target_name: str, trial: int): + return os.path.join( + self.work_dirs.run_logs, f"{generated_target_name}-{trial:02d}.log" + ) + + @staticmethod + def create_ossfuzz_project( + 
benchmark: Benchmark, name: str, target_file: str, build_script_path: str = "" + ) -> str: + """Creates an OSS-Fuzz project with the generated target. The new project + will replicate an existing project to |name| but replace its fuzz target + and build script with the new |target_file| and |build_script_path|.""" + logger.info("target file: %s", target_file) + generated_project_path = oss_fuzz_checkout.create_ossfuzz_project( + benchmark, name + ) + + # Copy generated fuzzers to generated_project_path + shutil.copyfile( + target_file, + os.path.join(generated_project_path, os.path.basename(target_file)), + ) + + # Add additional statement in dockerfile to overwrite with generated fuzzer + with open(os.path.join(generated_project_path, "Dockerfile"), "a") as f: + f.write( + f"\nCOPY {os.path.basename(target_file)} " f"{benchmark.target_path}\n" + ) + + if not build_script_path or os.path.getsize(build_script_path) == 0: + return name + + # Copy generated build script to generated_project_path + shutil.copyfile( + build_script_path, + os.path.join(generated_project_path, os.path.basename("agent-build.sh")), + ) + + # Add additional statement in dockerfile to overwrite with generated + # build script + with open(os.path.join(generated_project_path, "Dockerfile"), "a") as f: + f.write("\nRUN cp /src/build.sh /src/build.bk.sh\n") + with open(os.path.join(generated_project_path, "Dockerfile"), "a") as f: + f.write("\nCOPY agent-build.sh /src/build.sh\n") + + return name + + @staticmethod + def create_ossfuzz_project_with_lldb( + benchmark: Benchmark, + name: str, + target_file: str, + run_result: results.RunResult, + build_script_path: str = "", + artifact_path: str = "", + ) -> str: + """Creates an OSS-Fuzz project with the generated target and new dockerfile. + The new project will replicate an existing project |name| but replace its + fuzz target and build script with the new |target_file| and + |build_script_path| and modify its dockerfile.""" + Evaluator.create_ossfuzz_project( + benchmark, name, target_file, build_script_path + ) + generated_project_path = os.path.join( + oss_fuzz_checkout.OSS_FUZZ_DIR, "projects", name + ) + + shutil.copyfile( + artifact_path, + os.path.join(generated_project_path, os.path.basename(artifact_path)), + ) + # Add additional statement in dockerfile to copy testcase, + # enable -g, install lldb and screen + with open(os.path.join(generated_project_path, "Dockerfile"), "a") as f: + f.write( + "\nRUN mkdir -p /artifact\n" + f"\nCOPY {os.path.basename(run_result.artifact_path)} /artifact/\n" + '\nENV CFLAGS="${CFLAGS} -g -O0"\n' + '\nENV CXXFLAGS="${CXXFLAGS} -g -O0"\n' + "\nRUN apt-get update\n" + "\nRUN apt-get install -y lldb\n" + "\nRUN apt-get install -y screen\n" + ) + + return name + + def _fix_generated_fuzz_target( + self, + ai_binary: str, + generated_oss_fuzz_project: str, + target_path: str, + iteration: int, + build_result: BuildResult, + run_result: Optional[RunResult], + dual_logger: _Logger, + language: str, + ): + """Fixes the generated fuzz target.""" + error_desc, errors = "", [] + if build_result.succeeded: + if language != "jvm": + if run_result: + error_desc, errors = run_result.semantic_check.get_error_info() + else: + dual_logger.log( + f"Warning: Build succeed but no run_result in " + f"{generated_oss_fuzz_project}." 
+ ) + else: + error_desc, errors = None, build_result.errors + + code_fixer.llm_fix( + ai_binary, + target_path, + self.benchmark, + iteration, + error_desc, + errors, + self.builder_runner.fixer_model_name, + language, + ) + shutil.copyfile( + target_path, + os.path.join( + oss_fuzz_checkout.OSS_FUZZ_DIR, + "projects", + generated_oss_fuzz_project, + os.path.basename(target_path), + ), + ) + + def triage_crash( + self, + ai_binary: str, + generated_oss_fuzz_project: str, + driver_path: str, + run_result: RunResult, + dual_logger: _Logger, + ) -> str: + """Triages the crash.""" + if run_result.crash_info: + crash_info = run_result.crash_info + crash_func = run_result.semantic_check.crash_func + return crash_triager.llm_triage( + ai_binary, + driver_path, + self.benchmark, + crash_info, + crash_func, + self.builder_runner.fixer_model_name, + ) + + dual_logger.log(f"Warning: no crash info in {generated_oss_fuzz_project}.") + return TriageResult.NOT_APPLICABLE + + def extend_build_with_corpus( + self, ai_binary, target_path, generated_oss_fuzz_project + ): + """Extends an OSS-Fuzz project with corpus generated programmatically.""" + generated_project_path = os.path.join( + oss_fuzz_checkout.OSS_FUZZ_DIR, "projects", generated_oss_fuzz_project + ) + generated_corp = corpus_generator.get_script( + ai_binary, self.builder_runner.fixer_model_name, target_path, self.benchmark + ) + + corpus_generator_path = os.path.join(generated_project_path, "corp_gen.py") + with open(corpus_generator_path, "w") as f: + f.write(generated_corp) + + with open(os.path.join(generated_project_path, "Dockerfile"), "a") as f: + f.write("COPY corp_gen.py $SRC/corp_gen.py\n") + target_harness_file = os.path.basename(self.benchmark.target_path) + target_harness_file = os.path.splitext(target_harness_file)[0] + corpus_dst = "/src/generated-corpus/*" + with open(os.path.join(generated_project_path, "build.sh"), "a") as f: + f.write("\n# Generate a corpus for the modified harness.") + f.write("\nmkdir -p /src/generated-corpus") + f.write("\npushd /src/generated-corpus") + f.write("\npython3 $SRC/corp_gen.py") + f.write("\npopd") + f.write(f"\nzip $OUT/{target_harness_file}_seed_corpus.zip {corpus_dst}") + + def check_target(self, ai_binary, target_path: str) -> Result: + """Builds and runs a target.""" + generated_target_name = os.path.basename(target_path) + sample_id = os.path.splitext(generated_target_name)[0] + generated_oss_fuzz_project = f"{self.benchmark.id}-{sample_id}" + generated_oss_fuzz_project = oss_fuzz_checkout.rectify_docker_tag( + generated_oss_fuzz_project + ) + Evaluator.create_ossfuzz_project( + self.benchmark, generated_oss_fuzz_project, target_path + ) + + status_path = os.path.join(self.work_dirs.status, sample_id) + os.makedirs(status_path, exist_ok=True) + + dual_logger = _Logger(status_path) + + # Try building and running the new target. + + # TODO: Log build failure. + # TODO: Log run success/failure. + + if GENERATE_CORPUS: + self.extend_build_with_corpus( + ai_binary, target_path, generated_oss_fuzz_project + ) + + # Loop of evaluating and fixing fuzz target. + llm_fix_count = 0 + while True: + # 1. Evaluating generated driver. + try: + build_result, run_result = self.builder_runner.build_and_run( + generated_oss_fuzz_project, + target_path, + llm_fix_count, + self.benchmark.language, + ) + except Exception as e: + dual_logger.log( + "Exception occurred when building and running fuzz target " + f"in attempt {llm_fix_count}: {e}" + ) + build_result = BuildResult() + run_result = None + + # 2. 
Calculate coverage percentage and coverage diff + coverage_summary = None + total_lines = 0 + coverage_percent = 0.0 + coverage_diff = 0.0 + if run_result: + # Gets line coverage (diff) details. + coverage_summary = self._load_existing_coverage_summary() + + if self.benchmark.language in ["python", "jvm"] and run_result.coverage: + # The Jacoco.xml coverage report used to generate summary.json on + # OSS-Fuzz for JVM projects does not trace the source file location. + # Thus the conversion may miss some classes because they are not + # present during coverage report generation. This fix gets the total + # line calculation from the jacoco.xml report of the current run + # directly and compares it with the total_lines retrieved from + # summary.json. Then the larger total_lines is used which is assumed + # to be more accurate. This is the same case for python project which + # the total line is determined from the all_cov.json file. + total_lines = run_result.coverage.total_lines + elif coverage_summary: + total_lines = compute_total_lines_without_fuzz_targets( + coverage_summary, generated_target_name + ) + else: + total_lines = 0 + + if run_result.total_pcs: + coverage_percent = run_result.cov_pcs / run_result.total_pcs + else: + dual_logger.log( + f"Warning: total_pcs == 0 in {generated_oss_fuzz_project}." + ) + coverage_percent = 0.0 + + existing_textcov = self.load_existing_textcov() + if run_result.coverage: + run_result.coverage.subtract_covered_lines(existing_textcov) + + if total_lines and run_result.coverage: + coverage_diff = run_result.coverage.covered_lines / total_lines + else: + dual_logger.log( + f"Warning: total_lines == 0 in {generated_oss_fuzz_project}." + ) + coverage_diff = 0.0 + + if self.benchmark.language == "jvm": + # For JVM, the generation is consider success if either is true + # 1) Build success and run crashed (expected for exceptions) + # 2) Build success, run success and coverage diff > 0 + gen_succ = build_result.succeeded and run_result + if gen_succ and run_result and run_result.succeeded: + gen_succ = gen_succ and (coverage_diff > 0) + else: + # Should not concern run_result.succeeded for generation otherwise + # it may make a good fuzz target bad. + # Should concern run_result.succeeded for analyzes to know semantic + # errors + gen_succ = build_result.succeeded + + if gen_succ or llm_fix_count >= LLM_FIX_LIMIT: + # Exit cond 1: successfully generate the fuzz target. + # Exit cond 2: fix limit is reached. + break + + # 2. Fixing generated driver + llm_fix_count += 1 + dual_logger.log( + f"Fixing {target_path} with " + f"{self.builder_runner.fixer_model_name}, " + f"attempt {llm_fix_count}." + ) + try: + self._fix_generated_fuzz_target( + ai_binary, + generated_oss_fuzz_project, + target_path, + llm_fix_count, + build_result, + run_result, + dual_logger, + self.benchmark.language, + ) + except Exception as e: + dual_logger.log( + "Exception occurred when fixing fuzz target in attempt " + f"{llm_fix_count}: {e}" + ) + break + + # Logs and returns the result. + if not build_result.succeeded: + dual_logger.log( + f"Failed to build {target_path} with " + f"{self.builder_runner.fixer_model_name} in " + f"{llm_fix_count} iterations of fixing." 
+ ) + return dual_logger.return_result( + Result( + False, + False, + False, + 0.0, + 0.0, + "", + "", + False, + SemanticCheckResult.NOT_APPLICABLE, + TriageResult.NOT_APPLICABLE, + compile_error=build_result.log_path, + compile_log=build_result.log_path, + ) + ) - if self.benchmark.language == 'rust': - return load_existing_rust_textcov(self.benchmark.project) + dual_logger.log( + f"Successfully built {target_path} with " + f"{self.builder_runner.fixer_model_name} in " + f"{llm_fix_count} iterations of fixing." + ) + + if not run_result: + dual_logger.log(f"Warning: no run result in {generated_oss_fuzz_project}.") + return dual_logger.return_result( + Result( + False, + True, + False, + 0.0, + 0.0, + "", + "", + False, + SemanticCheckResult.NOT_APPLICABLE, + TriageResult.NOT_APPLICABLE, + compile_error=build_result.log_path, + compile_log=build_result.log_path, + ) + ) + + # Triage the crash with LLM + dual_logger.log( + f"Triaging the crash related to {target_path} with " + f"{self.builder_runner.fixer_model_name}." + ) + run_result.triage = self.triage_crash( + ai_binary, + generated_oss_fuzz_project, + target_path, + run_result, + dual_logger, + ) + + if run_result.coverage_summary is None or run_result.coverage is None: + dual_logger.log( + f"Warning: No cov info in run result of {generated_oss_fuzz_project}." + ) + return dual_logger.return_result( + Result( + False, + True, + run_result.crashes, + 0.0, + 0.0, + "", + "", + not run_result.succeeded, + run_result.semantic_check.type, + run_result.triage, + compile_error=build_result.log_path, + compile_log=build_result.log_path, + ) + ) - return load_existing_textcov(self.benchmark.project) + dual_logger.log( + f"Result for {generated_oss_fuzz_project}: " + f"crashes={run_result.crashes}, coverage={coverage_percent} " + f"({run_result.cov_pcs}/{run_result.total_pcs}), " + f"coverage diff={coverage_diff} " + f"({run_result.coverage.covered_lines}/{total_lines})" + ) + return dual_logger.return_result( + Result( + False, + True, + run_result.crashes, + coverage_percent, + coverage_diff, + run_result.coverage_report_path, + run_result.reproducer_path, + not run_result.succeeded, + run_result.semantic_check.type, + run_result.triage, + run_result.coverage, + compile_error=build_result.log_path, + compile_log=build_result.log_path, + ) + ) + + def _load_existing_coverage_summary(self) -> dict: + """Load existing summary.json.""" + return load_existing_coverage_summary(self.benchmark.project) + + def load_existing_textcov(self) -> textcov.Textcov: + """Loads existing textcovs.""" + if self.benchmark.language == "jvm": + return load_existing_jvm_textcov(self.benchmark.project) + + if self.benchmark.language == "python": + return load_existing_python_textcov(self.benchmark.project) + + if self.benchmark.language == "rust": + return load_existing_rust_textcov(self.benchmark.project) + + return load_existing_textcov(self.benchmark.project) diff --git a/experiment/fuzz_target_error.py b/experiment/fuzz_target_error.py index cf305a8a97..076f700247 100644 --- a/experiment/fuzz_target_error.py +++ b/experiment/fuzz_target_error.py @@ -22,169 +22,197 @@ class SemanticCheckResult: - """Fuzz target semantic check results.""" - NOT_APPLICABLE = '-' - NO_SEMANTIC_ERR = 'NO_SEMANTIC_ERR' - LOG_MESS_UP = 'LOG_MESS_UP' - FP_NEAR_INIT_CRASH = 'FP_NEAR_INIT_CRASH' - FP_TARGET_CRASH = 'FP_TARGET_CRASH' - FP_MEMLEAK = 'FP_MEMLEAK' - FP_OOM = 'FP_OOM' - FP_TIMEOUT = 'FP_TIMEOUT' - NO_COV_INCREASE = 'NO_COV_INCREASE' - NULL_DEREF = 'NULL_DEREF' - SIGNAL = 
'SIGNAL' - EXIT = 'EXIT' - OVERWRITE_CONST = 'OVERWRITE_CONST' - - # Regex for extract crash symptoms. - # Matches over 18 types of ASAN errors symptoms - # e.g. ERROR: AddressSanitizer: attempting use-after-free on xxx - # e.g. ERROR: AddressSanitizer: attempting stack-overflow on xxx - # e.g. ERROR: AddressSanitizer: attempting negative-size-param on xxx - # Full list here: - # https://github.com/occia/fuzzdrivergpt/blob/35b0e957a61be8bd506017cda621a50e75f5acdb/validation/libVR.py#L466-L485. - SYMPTOM_ASAN = re.compile(r'ERROR: AddressSanitizer: (.*)\n') - # Matches 'ERROR: libFuzzer: timeout after xxx' - SYMPTOM_LIBFUZZER = re.compile(r'ERROR: libFuzzer: (.*)\n') - # E.g., matches 'SCARINESS: 10 (null-deref)' - SYMPTOM_SCARINESS = re.compile(r'SCARINESS:\s*\d+\s*\((.*)\)\n') - - # Regex for extract crash information. - INFO_CRASH = re.compile(r'ERROR: (.*?)(?=SUMMARY)', re.DOTALL) - # Regex for extract artifact file name. - ARTIFACT_NAME = re.compile(r'(?<=written to ./)crash-[\w]+') - - NO_COV_INCREASE_MSG_PREFIX = 'No code coverage increasement' - - @classmethod - def extract_symptom(cls, fuzzlog: str) -> str: - """Extracts crash symptom from fuzzing logs.""" - # Need to catch this before ASAN. - match = cls.SYMPTOM_SCARINESS.search(fuzzlog) - if match: - return match.group(1).strip() - - match = cls.SYMPTOM_ASAN.search(fuzzlog) - if match: - return f'ASAN-{match.group(0).strip()}' - - match = cls.SYMPTOM_LIBFUZZER.search(fuzzlog) - if match: - return f'libFuzzer-{match.group(0).strip()}' - - return '' - - @classmethod - def is_no_cov_increase_err(cls, error_desc: Optional[str]) -> bool: - return (error_desc is not None) and error_desc.startswith( - cls.NO_COV_INCREASE_MSG_PREFIX) - - @classmethod - def extract_crash_info(cls, fuzzlog: str) -> str: - """Extracts crash information from fuzzing logs.""" - match = cls.INFO_CRASH.search(fuzzlog) - if match: - return match.group(1) - - logging.warning('Failed to match crash information.') - return '' - - @classmethod - def extract_artifact_name(cls, fuzzlog: str) -> str: - """Extracts artifact file name from fuzzing logs.""" - match = cls.ARTIFACT_NAME.search(fuzzlog) - if match: - return match.group(0).strip() - - logging.warning('Failed to match artifact file name.') - return 'testcase' - - def __init__(self, - err_type: str, - crash_symptom: str = '', - crash_stacks: Optional[list[list[str]]] = None, - crash_func: Optional[dict] = None): - self.type = err_type - self.crash_symptom = crash_symptom - self.crash_stacks = crash_stacks if crash_stacks else [] - self.crash_func = crash_func if crash_func else {} - - def __repr__(self) -> str: - return (f'{self.__class__.__name__}' - f'({", ".join(f"{k}={v!r}" for k, v in vars(self).items())})') - - def _get_error_desc(self) -> str: - """Returns one sentence error description used in fix prompt.""" - if self.type == self.LOG_MESS_UP: - # TODO(happy-qop): Add detailed description for this error type. - return 'Overlong fuzzing log.' 
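extract_symptom tries SYMPTOM_SCARINESS first and then falls back to the ASAN and libFuzzer patterns. A self-contained check of the ASAN branch, using an invented log line:

import re

SYMPTOM_ASAN = re.compile(r"ERROR: AddressSanitizer: (.*)\n")
fuzzlog = "==1==ERROR: AddressSanitizer: heap-buffer-overflow on address 0x602000000011\n"
match = SYMPTOM_ASAN.search(fuzzlog)
print(f"ASAN-{match.group(0).strip()}")
# ASAN-ERROR: AddressSanitizer: heap-buffer-overflow on address 0x602000000011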
- if self.type == self.FP_NEAR_INIT_CRASH: - return (f'Fuzzing crashed immediately at runtime ({self.crash_symptom})' - ', indicating fuzz target code for invoking the function under' - ' test is incorrect or unrobust.') - if self.type == self.FP_TARGET_CRASH: - return (f'Fuzzing has crashes ({self.crash_symptom}) caused by fuzz ' - 'target code, indicating its usage for the function under ' - 'test is incorrect or unrobust.') - if self.type == self.FP_MEMLEAK: - return ('Memory leak detected, indicating some memory was not freed ' - 'by the fuzz target.') - if self.type == self.FP_OOM: - return ('Out-of-memory error detected, suggesting the fuzz target ' - 'incorrectly allocates too much memory or has a memory leak.') - if self.type == self.FP_TIMEOUT: - return ('Fuzz target timed out at runtime, indicating its usage for ' - 'the function under test is incorrect or unrobust.') - if self.type == self.NO_COV_INCREASE: - # TODO(dongge): Append the implementation of the function under test. - return (self.NO_COV_INCREASE_MSG_PREFIX + ', indicating the fuzz target' - ' ineffectively invokes the function under test.') - if self.type == self.NULL_DEREF: - return ('Accessing a null pointer, indicating improper parameter ' - 'initialization or incorrect function usages in the fuzz target.') - if self.type == self.SIGNAL: - return ('Abort with signal, indicating the fuzz target has violated some ' - 'assertion in the project, likely due to improper parameter ' - 'initialization or incorrect function usages.') - if self.type == self.EXIT: - return ('Fuzz target exited in a controlled manner without showing any ' - 'sign of memory corruption, likely due to the fuzz target is not ' - 'well designed to effectively find memory corruption ' - 'vulnerability in the function-under-test.') - if self.type == self.OVERWRITE_CONST: - return ('Fuzz target modified a const data. To fix this, ensure that all ' - 'input data passed to the fuzz target is treated as read-only ' - 'and not modified. Copy the input data to a separate buffer if ' - 'any modifications are necessary.') - - return '' - - def _get_error_detail(self) -> list[str]: - """Returns detailed error description used in fix prompt.""" - if self.type not in [ - self.FP_NEAR_INIT_CRASH, self.FP_TARGET_CRASH, self.FP_TIMEOUT - ]: - return [] - - detail = ['Crash stacks:'] - for index, stack in enumerate(self.crash_stacks): - detail.append(f'Stack {index}:') - detail.extend(stack) - return detail - - def get_error_info(self) -> tuple[str, list[str]]: - return self._get_error_desc(), self._get_error_detail() - - @property - def has_err(self) -> bool: - return self.type not in (self.NOT_APPLICABLE, self.NO_SEMANTIC_ERR) - - def to_dict(self): - return { - 'has_err': self.has_err, - 'err_type': self.type, - 'crash_symptom': self.crash_symptom, - 'crash_stacks': self.crash_stacks, - 'crash_func': self.crash_func, - } + """Fuzz target semantic check results.""" + + NOT_APPLICABLE = "-" + NO_SEMANTIC_ERR = "NO_SEMANTIC_ERR" + LOG_MESS_UP = "LOG_MESS_UP" + FP_NEAR_INIT_CRASH = "FP_NEAR_INIT_CRASH" + FP_TARGET_CRASH = "FP_TARGET_CRASH" + FP_MEMLEAK = "FP_MEMLEAK" + FP_OOM = "FP_OOM" + FP_TIMEOUT = "FP_TIMEOUT" + NO_COV_INCREASE = "NO_COV_INCREASE" + NULL_DEREF = "NULL_DEREF" + SIGNAL = "SIGNAL" + EXIT = "EXIT" + OVERWRITE_CONST = "OVERWRITE_CONST" + + # Regex for extract crash symptoms. + # Matches over 18 types of ASAN errors symptoms + # e.g. ERROR: AddressSanitizer: attempting use-after-free on xxx + # e.g. 
ERROR: AddressSanitizer: attempting stack-overflow on xxx + # e.g. ERROR: AddressSanitizer: attempting negative-size-param on xxx + # Full list here: + # https://github.com/occia/fuzzdrivergpt/blob/35b0e957a61be8bd506017cda621a50e75f5acdb/validation/libVR.py#L466-L485. + SYMPTOM_ASAN = re.compile(r"ERROR: AddressSanitizer: (.*)\n") + # Matches 'ERROR: libFuzzer: timeout after xxx' + SYMPTOM_LIBFUZZER = re.compile(r"ERROR: libFuzzer: (.*)\n") + # E.g., matches 'SCARINESS: 10 (null-deref)' + SYMPTOM_SCARINESS = re.compile(r"SCARINESS:\s*\d+\s*\((.*)\)\n") + + # Regex for extract crash information. + INFO_CRASH = re.compile(r"ERROR: (.*?)(?=SUMMARY)", re.DOTALL) + # Regex for extract artifact file name. + ARTIFACT_NAME = re.compile(r"(?<=written to ./)crash-[\w]+") + + NO_COV_INCREASE_MSG_PREFIX = "No code coverage increasement" + + @classmethod + def extract_symptom(cls, fuzzlog: str) -> str: + """Extracts crash symptom from fuzzing logs.""" + # Need to catch this before ASAN. + match = cls.SYMPTOM_SCARINESS.search(fuzzlog) + if match: + return match.group(1).strip() + + match = cls.SYMPTOM_ASAN.search(fuzzlog) + if match: + return f"ASAN-{match.group(0).strip()}" + + match = cls.SYMPTOM_LIBFUZZER.search(fuzzlog) + if match: + return f"libFuzzer-{match.group(0).strip()}" + + return "" + + @classmethod + def is_no_cov_increase_err(cls, error_desc: Optional[str]) -> bool: + return (error_desc is not None) and error_desc.startswith( + cls.NO_COV_INCREASE_MSG_PREFIX + ) + + @classmethod + def extract_crash_info(cls, fuzzlog: str) -> str: + """Extracts crash information from fuzzing logs.""" + match = cls.INFO_CRASH.search(fuzzlog) + if match: + return match.group(1) + + logging.warning("Failed to match crash information.") + return "" + + @classmethod + def extract_artifact_name(cls, fuzzlog: str) -> str: + """Extracts artifact file name from fuzzing logs.""" + match = cls.ARTIFACT_NAME.search(fuzzlog) + if match: + return match.group(0).strip() + + logging.warning("Failed to match artifact file name.") + return "testcase" + + def __init__( + self, + err_type: str, + crash_symptom: str = "", + crash_stacks: Optional[list[list[str]]] = None, + crash_func: Optional[dict] = None, + ): + self.type = err_type + self.crash_symptom = crash_symptom + self.crash_stacks = crash_stacks if crash_stacks else [] + self.crash_func = crash_func if crash_func else {} + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}" + f'({", ".join(f"{k}={v!r}" for k, v in vars(self).items())})' + ) + + def _get_error_desc(self) -> str: + """Returns one sentence error description used in fix prompt.""" + if self.type == self.LOG_MESS_UP: + # TODO(happy-qop): Add detailed description for this error type. + return "Overlong fuzzing log." + if self.type == self.FP_NEAR_INIT_CRASH: + return ( + f"Fuzzing crashed immediately at runtime ({self.crash_symptom})" + ", indicating fuzz target code for invoking the function under" + " test is incorrect or unrobust." + ) + if self.type == self.FP_TARGET_CRASH: + return ( + f"Fuzzing has crashes ({self.crash_symptom}) caused by fuzz " + "target code, indicating its usage for the function under " + "test is incorrect or unrobust." + ) + if self.type == self.FP_MEMLEAK: + return ( + "Memory leak detected, indicating some memory was not freed " + "by the fuzz target." + ) + if self.type == self.FP_OOM: + return ( + "Out-of-memory error detected, suggesting the fuzz target " + "incorrectly allocates too much memory or has a memory leak." 
+ ) + if self.type == self.FP_TIMEOUT: + return ( + "Fuzz target timed out at runtime, indicating its usage for " + "the function under test is incorrect or unrobust." + ) + if self.type == self.NO_COV_INCREASE: + # TODO(dongge): Append the implementation of the function under test. + return ( + self.NO_COV_INCREASE_MSG_PREFIX + ", indicating the fuzz target" + " ineffectively invokes the function under test." + ) + if self.type == self.NULL_DEREF: + return ( + "Accessing a null pointer, indicating improper parameter " + "initialization or incorrect function usages in the fuzz target." + ) + if self.type == self.SIGNAL: + return ( + "Abort with signal, indicating the fuzz target has violated some " + "assertion in the project, likely due to improper parameter " + "initialization or incorrect function usages." + ) + if self.type == self.EXIT: + return ( + "Fuzz target exited in a controlled manner without showing any " + "sign of memory corruption, likely due to the fuzz target is not " + "well designed to effectively find memory corruption " + "vulnerability in the function-under-test." + ) + if self.type == self.OVERWRITE_CONST: + return ( + "Fuzz target modified a const data. To fix this, ensure that all " + "input data passed to the fuzz target is treated as read-only " + "and not modified. Copy the input data to a separate buffer if " + "any modifications are necessary." + ) + + return "" + + def _get_error_detail(self) -> list[str]: + """Returns detailed error description used in fix prompt.""" + if self.type not in [ + self.FP_NEAR_INIT_CRASH, + self.FP_TARGET_CRASH, + self.FP_TIMEOUT, + ]: + return [] + + detail = ["Crash stacks:"] + for index, stack in enumerate(self.crash_stacks): + detail.append(f"Stack {index}:") + detail.extend(stack) + return detail + + def get_error_info(self) -> tuple[str, list[str]]: + return self._get_error_desc(), self._get_error_detail() + + @property + def has_err(self) -> bool: + return self.type not in (self.NOT_APPLICABLE, self.NO_SEMANTIC_ERR) + + def to_dict(self): + return { + "has_err": self.has_err, + "err_type": self.type, + "crash_symptom": self.crash_symptom, + "crash_stacks": self.crash_stacks, + "crash_func": self.crash_func, + } diff --git a/experiment/oss_fuzz_checkout.py b/experiment/oss_fuzz_checkout.py index 6c6bb19343..011ebd7ab9 100644 --- a/experiment/oss_fuzz_checkout.py +++ b/experiment/oss_fuzz_checkout.py @@ -30,472 +30,493 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -BUILD_DIR: str = 'build' -GLOBAL_TEMP_DIR: str = '' -ENABLE_CACHING = bool(int(os.getenv('OFG_USE_CACHING', '1'))) +BUILD_DIR: str = "build" +GLOBAL_TEMP_DIR: str = "" +ENABLE_CACHING = bool(int(os.getenv("OFG_USE_CACHING", "1"))) # Assume OSS-Fuzz is at repo root dir by default. # This will change if temp_dir is used. OSS_FUZZ_DIR: str = os.path.join( - os.path.dirname(os.path.dirname(os.path.realpath(__file__))), 'oss-fuzz') -CLEAN_UP_OSS_FUZZ = bool(int(os.getenv('OFG_CLEAN_UP_OSS_FUZZ', '1'))) + os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "oss-fuzz" +) +CLEAN_UP_OSS_FUZZ = bool(int(os.getenv("OFG_CLEAN_UP_OSS_FUZZ", "1"))) -VENV_DIR: str = 'venv' +VENV_DIR: str = "venv" def _remove_temp_oss_fuzz_repo(): - """Deletes the temporary OSS-Fuzz directory.""" - # Ensure we aren't deleting a real repo someone cares about. 
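A note on the flag parsing just above: OFG_USE_CACHING and OFG_CLEAN_UP_OSS_FUZZ go through bool(int(...)), so only a value of "0" turns them off and leaving them unset keeps them on, whereas LLM_GENERATE_CORPUS in evaluator.py goes through bool(...) alone, so any non-empty value, including "0", enables it. The values below are hypothetical:

import os

os.environ["OFG_USE_CACHING"] = "0"
print(bool(int(os.getenv("OFG_USE_CACHING", "1"))))        # False: caching disabled.

os.environ.pop("OFG_CLEAN_UP_OSS_FUZZ", None)
print(bool(int(os.getenv("OFG_CLEAN_UP_OSS_FUZZ", "1"))))  # True: cleanup by default.

os.environ["LLM_GENERATE_CORPUS"] = "0"
print(bool(os.getenv("LLM_GENERATE_CORPUS", "")))          # True: any non-empty string counts.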
- assert not OSS_FUZZ_DIR.endswith('oss-fuzz') - try: - shutil.rmtree(OSS_FUZZ_DIR) - except PermissionError as e: - logger.warning('No permission to remove %s: %s', OSS_FUZZ_DIR, e) - except FileNotFoundError as e: - logger.warning('No OSS-Fuzz directory %s: %s', OSS_FUZZ_DIR, e) + """Deletes the temporary OSS-Fuzz directory.""" + # Ensure we aren't deleting a real repo someone cares about. + assert not OSS_FUZZ_DIR.endswith("oss-fuzz") + try: + shutil.rmtree(OSS_FUZZ_DIR) + except PermissionError as e: + logger.warning("No permission to remove %s: %s", OSS_FUZZ_DIR, e) + except FileNotFoundError as e: + logger.warning("No OSS-Fuzz directory %s: %s", OSS_FUZZ_DIR, e) def _set_temp_oss_fuzz_repo(): - """Creates a temporary directory for OSS-Fuzz repo and update |OSS_FUZZ_DIR|. - """ - # Holding the temp directory in a global object to ensure it won't be deleted - # before program ends. - global GLOBAL_TEMP_DIR - GLOBAL_TEMP_DIR = tempfile.mkdtemp() - global OSS_FUZZ_DIR - OSS_FUZZ_DIR = GLOBAL_TEMP_DIR - atexit.register(_remove_temp_oss_fuzz_repo) - _clone_oss_fuzz_repo() - - -def _clone_oss_fuzz_repo(): - """Clones OSS-Fuzz to |OSS_FUZZ_DIR|.""" - clone_command = [ - 'git', 'clone', 'https://github.com/google/oss-fuzz', '--depth', '1', - OSS_FUZZ_DIR - ] - proc = sp.Popen(clone_command, - stdout=sp.PIPE, - stderr=sp.PIPE, - stdin=sp.DEVNULL) - stdout, stderr = proc.communicate() - if proc.returncode != 0: - logger.info(stdout) - logger.info(stderr) - - -def clone_oss_fuzz(oss_fuzz_dir: str = ''): - """Clones the OSS-Fuzz repository.""" - if oss_fuzz_dir: + """Creates a temporary directory for OSS-Fuzz repo and update |OSS_FUZZ_DIR|.""" + # Holding the temp directory in a global object to ensure it won't be deleted + # before program ends. + global GLOBAL_TEMP_DIR + GLOBAL_TEMP_DIR = tempfile.mkdtemp() global OSS_FUZZ_DIR - OSS_FUZZ_DIR = oss_fuzz_dir - else: - _set_temp_oss_fuzz_repo() - - if not os.path.exists(OSS_FUZZ_DIR): + OSS_FUZZ_DIR = GLOBAL_TEMP_DIR + atexit.register(_remove_temp_oss_fuzz_repo) _clone_oss_fuzz_repo() - if CLEAN_UP_OSS_FUZZ: - clean_command = ['git', 'clean', '-fxd', '-e', VENV_DIR, '-e', BUILD_DIR] - sp.run(clean_command, - capture_output=True, - stdin=sp.DEVNULL, - check=True, - cwd=OSS_FUZZ_DIR) - - # Sync oss-fuzz data if needed. 
- if os.environ.get('OSS_FUZZ_DATA_DIR', ''): - src_projects = os.path.join(os.environ['OSS_FUZZ_DATA_DIR'], 'projects') - logger.info('OSS_FUZZ_DATA_DIR: %s', os.environ['OSS_FUZZ_DATA_DIR']) - logger.info('src_projects: %s', src_projects) - for proj in os.listdir(src_projects): - src_project = os.path.join(src_projects, proj) - dst_project = os.path.join(OSS_FUZZ_DIR, 'projects', proj) - logger.info('Copying: %s to %s', src_project, dst_project) - shutil.copytree(src_project, dst_project) + +def _clone_oss_fuzz_repo(): + """Clones OSS-Fuzz to |OSS_FUZZ_DIR|.""" + clone_command = [ + "git", + "clone", + "https://github.com/google/oss-fuzz", + "--depth", + "1", + OSS_FUZZ_DIR, + ] + proc = sp.Popen(clone_command, stdout=sp.PIPE, stderr=sp.PIPE, stdin=sp.DEVNULL) + stdout, stderr = proc.communicate() + if proc.returncode != 0: + logger.info(stdout) + logger.info(stderr) + + +def clone_oss_fuzz(oss_fuzz_dir: str = ""): + """Clones the OSS-Fuzz repository.""" + if oss_fuzz_dir: + global OSS_FUZZ_DIR + OSS_FUZZ_DIR = oss_fuzz_dir + else: + _set_temp_oss_fuzz_repo() + + if not os.path.exists(OSS_FUZZ_DIR): + _clone_oss_fuzz_repo() + + if CLEAN_UP_OSS_FUZZ: + clean_command = ["git", "clean", "-fxd", "-e", VENV_DIR, "-e", BUILD_DIR] + sp.run( + clean_command, + capture_output=True, + stdin=sp.DEVNULL, + check=True, + cwd=OSS_FUZZ_DIR, + ) + + # Sync oss-fuzz data if needed. + if os.environ.get("OSS_FUZZ_DATA_DIR", ""): + src_projects = os.path.join(os.environ["OSS_FUZZ_DATA_DIR"], "projects") + logger.info("OSS_FUZZ_DATA_DIR: %s", os.environ["OSS_FUZZ_DATA_DIR"]) + logger.info("src_projects: %s", src_projects) + for proj in os.listdir(src_projects): + src_project = os.path.join(src_projects, proj) + dst_project = os.path.join(OSS_FUZZ_DIR, "projects", proj) + logger.info("Copying: %s to %s", src_project, dst_project) + shutil.copytree(src_project, dst_project) def postprocess_oss_fuzz() -> None: - """Prepares the oss-fuzz directory for experiments.""" - # Write .gcloudignore to make submitting to GCB faster. - with open(os.path.join(OSS_FUZZ_DIR, '.gcloudignore'), 'w') as f: - f.write('__pycache__\n') - f.write('build\n') - f.write('.git\n') - f.write('.pytest_cache\n') - f.write('venv\n') - - # Set up dependencies to run OSS-Fuzz build scripts - if os.path.exists(os.path.join(OSS_FUZZ_DIR, VENV_DIR)): - return - - # If already in a virtualenv environment assume all is set up - venv_path = os.path.split(os.environ.get('VIRTUAL_ENV', '')) - if venv_path and venv_path[0].endswith(os.path.split(OSS_FUZZ_DIR)[-1]): - return - - result = sp.run(['python3', '-m', 'venv', VENV_DIR], - check=True, - capture_output=True, - stdin=sp.DEVNULL, - cwd=OSS_FUZZ_DIR) - result = sp.run([ - f'./{VENV_DIR}/bin/pip', 'install', '-r', - 'infra/build/functions/requirements.txt' - ], - check=True, - cwd=OSS_FUZZ_DIR, - stdin=sp.DEVNULL, - capture_output=True) - if result.returncode: - logger.info('Failed to postprocess OSS-Fuzz (%s)', OSS_FUZZ_DIR) - logger.info('stdout: %s', result.stdout) - logger.info('stderr: %s', result.stderr) + """Prepares the oss-fuzz directory for experiments.""" + # Write .gcloudignore to make submitting to GCB faster. 
+ with open(os.path.join(OSS_FUZZ_DIR, ".gcloudignore"), "w") as f: + f.write("__pycache__\n") + f.write("build\n") + f.write(".git\n") + f.write(".pytest_cache\n") + f.write("venv\n") + + # Set up dependencies to run OSS-Fuzz build scripts + if os.path.exists(os.path.join(OSS_FUZZ_DIR, VENV_DIR)): + return + + # If already in a virtualenv environment assume all is set up + venv_path = os.path.split(os.environ.get("VIRTUAL_ENV", "")) + if venv_path and venv_path[0].endswith(os.path.split(OSS_FUZZ_DIR)[-1]): + return + + result = sp.run( + ["python3", "-m", "venv", VENV_DIR], + check=True, + capture_output=True, + stdin=sp.DEVNULL, + cwd=OSS_FUZZ_DIR, + ) + result = sp.run( + [ + f"./{VENV_DIR}/bin/pip", + "install", + "-r", + "infra/build/functions/requirements.txt", + ], + check=True, + cwd=OSS_FUZZ_DIR, + stdin=sp.DEVNULL, + capture_output=True, + ) + if result.returncode: + logger.info("Failed to postprocess OSS-Fuzz (%s)", OSS_FUZZ_DIR) + logger.info("stdout: %s", result.stdout) + logger.info("stderr: %s", result.stderr) def list_c_cpp_projects() -> list[str]: - """Returns a list of all c/c++ projects from oss-fuzz.""" - projects = [] - clone_oss_fuzz() - projects_dir = os.path.join(OSS_FUZZ_DIR, 'projects') - for project in os.listdir(projects_dir): - project_yaml_path = os.path.join(projects_dir, project, 'project.yaml') - with open(project_yaml_path) as yaml_file: - config = yaml_file.read() - if 'language: c' in config: - projects.append(project) - return sorted(projects) + """Returns a list of all c/c++ projects from oss-fuzz.""" + projects = [] + clone_oss_fuzz() + projects_dir = os.path.join(OSS_FUZZ_DIR, "projects") + for project in os.listdir(projects_dir): + project_yaml_path = os.path.join(projects_dir, project, "project.yaml") + with open(project_yaml_path) as yaml_file: + config = yaml_file.read() + if "language: c" in config: + projects.append(project) + return sorted(projects) def get_project_language(project: str) -> str: - """Returns the |project| language read from its project.yaml.""" - project_yaml_path = os.path.join(OSS_FUZZ_DIR, 'projects', project, - 'project.yaml') - if not os.path.isfile(project_yaml_path): - logger.warning('Failed to find the project yaml of %s, assuming it is C++', - project) - return 'C++' + """Returns the |project| language read from its project.yaml.""" + project_yaml_path = os.path.join(OSS_FUZZ_DIR, "projects", project, "project.yaml") + if not os.path.isfile(project_yaml_path): + logger.warning( + "Failed to find the project yaml of %s, assuming it is C++", project + ) + return "C++" - with open(project_yaml_path, 'r') as benchmark_file: - data = yaml.safe_load(benchmark_file) - return data.get('language', 'C++') + with open(project_yaml_path, "r") as benchmark_file: + data = yaml.safe_load(benchmark_file) + return data.get("language", "C++") def get_project_repository(project: str) -> str: - """Returns the |project| repository read from its project.yaml.""" - project_yaml_path = os.path.join(OSS_FUZZ_DIR, 'projects', project, - 'project.yaml') - if not os.path.isfile(project_yaml_path): - logger.warning( - 'Failed to find the project yaml of %s, return empty repository', - project) - return '' + """Returns the |project| repository read from its project.yaml.""" + project_yaml_path = os.path.join(OSS_FUZZ_DIR, "projects", project, "project.yaml") + if not os.path.isfile(project_yaml_path): + logger.warning( + "Failed to find the project yaml of %s, return empty repository", project + ) + return "" - with open(project_yaml_path, 'r') as 
benchmark_file: - data = yaml.safe_load(benchmark_file) - return data.get('main_repo', '') + with open(project_yaml_path, "r") as benchmark_file: + data = yaml.safe_load(benchmark_file) + return data.get("main_repo", "") def _get_project_cache_name(project: str) -> str: - """Gets name of cached container for a project.""" - return f'gcr.io.oss-fuzz.{project}_cache' + """Gets name of cached container for a project.""" + return f"gcr.io.oss-fuzz.{project}_cache" def _get_project_cache_image_name(project: str, sanitizer: str) -> str: - """Gets name of cached Docker image for a project and a respective - sanitizer.""" - return ('us-central1-docker.pkg.dev/oss-fuzz/oss-fuzz-gen/' - f'{project}-ofg-cached-{sanitizer}') + """Gets name of cached Docker image for a project and a respective + sanitizer.""" + return ( + "us-central1-docker.pkg.dev/oss-fuzz/oss-fuzz-gen/" + f"{project}-ofg-cached-{sanitizer}" + ) def _has_cache_build_script(project: str) -> bool: - """Checks if a project has cached fuzzer build script.""" - cached_build_script = os.path.join('fuzzer_build_script', project) - return os.path.isfile(cached_build_script) + """Checks if a project has cached fuzzer build script.""" + cached_build_script = os.path.join("fuzzer_build_script", project) + return os.path.isfile(cached_build_script) def _prepare_image_cache(project: str) -> bool: - """Prepares cached images of fuzzer build containers.""" - # Only create a cached image if we have a post-build build script - if not _has_cache_build_script(project): - logger.info('No cached script for %s', project) - return False - logger.info('%s has a cached build script', project) - - cached_container_name = _get_project_cache_name(project) - adjusted_env = os.environ | { - 'OSS_FUZZ_SAVE_CONTAINERS_NAME': cached_container_name - } - - logger.info('Creating a cached images') - for sanitizer in ['address', 'coverage']: - if is_image_cached(project, sanitizer): - logger.info('%s::%s is already cached, reusing existing cache.', project, - sanitizer) - continue - - # Pull the cache first - pull_cmd = [ - 'docker', 'pull', - _get_project_cache_image_name(project, sanitizer) - ] - try: - sp.run(pull_cmd, check=True) - logger.info('Successfully pulled cache image for %s', project) - except sp.CalledProcessError: - logger.info('Failed pulling image for %s', project) + """Prepares cached images of fuzzer build containers.""" + # Only create a cached image if we have a post-build build script + if not _has_cache_build_script(project): + logger.info("No cached script for %s", project) + return False + logger.info("%s has a cached build script", project) + + cached_container_name = _get_project_cache_name(project) + adjusted_env = os.environ | {"OSS_FUZZ_SAVE_CONTAINERS_NAME": cached_container_name} + + logger.info("Creating a cached images") + for sanitizer in ["address", "coverage"]: + if is_image_cached(project, sanitizer): + logger.info( + "%s::%s is already cached, reusing existing cache.", project, sanitizer + ) + continue + + # Pull the cache first + pull_cmd = ["docker", "pull", _get_project_cache_image_name(project, sanitizer)] + try: + sp.run(pull_cmd, check=True) + logger.info("Successfully pulled cache image for %s", project) + except sp.CalledProcessError: + logger.info("Failed pulling image for %s", project) + + if is_image_cached(project, sanitizer): + logger.info("pulled image for %s::%s", project, sanitizer) + continue + + # If pull did not work, create cached image by building using OSS-Fuzz + # with set variable. 
Fail if this does not work. + command = [ + "python3", + "infra/helper.py", + "build_fuzzers", + project, + "--sanitizer", + sanitizer, + ] + try: + sp.run(command, cwd=OSS_FUZZ_DIR, env=adjusted_env, check=True) + except sp.CalledProcessError: + logger.info("Failed to build fuzzer for %s.", project) + return False + + # Commit the container to an image + cached_image_name = _get_project_cache_image_name(project, sanitizer) + + command = ["docker", "commit", cached_container_name, cached_image_name] + try: + sp.run(command, check=True) + except sp.CalledProcessError: + logger.info("Could not rename image.") + return False + logger.info("Created cached image %s", cached_image_name) + + # Delete the container we created + command = ["docker", "container", "rm", cached_container_name] + try: + sp.run(command, check=True) + except sp.CalledProcessError: + logger.info("Could not rename image.") + return True - if is_image_cached(project, sanitizer): - logger.info('pulled image for %s::%s', project, sanitizer) - continue - # If pull did not work, create cached image by building using OSS-Fuzz - # with set variable. Fail if this does not work. - command = [ - 'python3', 'infra/helper.py', 'build_fuzzers', project, '--sanitizer', - sanitizer - ] - try: - sp.run(command, cwd=OSS_FUZZ_DIR, env=adjusted_env, check=True) - except sp.CalledProcessError: - logger.info('Failed to build fuzzer for %s.', project) - return False +def prepare_cached_images(experiment_targets: list[benchmarklib.Benchmark]) -> None: + """Builds cached Docker images for a set of targets.""" + all_projects = set() + for benchmark in experiment_targets: + all_projects.add(benchmark.project) - # Commit the container to an image - cached_image_name = _get_project_cache_image_name(project, sanitizer) + logger.info("Preparing cache for %d projects", len(all_projects)) + + for project in all_projects: + _prepare_image_cache(project) - command = ['docker', 'commit', cached_container_name, cached_image_name] - try: - sp.run(command, check=True) - except sp.CalledProcessError: - logger.info('Could not rename image.') - return False - logger.info('Created cached image %s', cached_image_name) - # Delete the container we created - command = ['docker', 'container', 'rm', cached_container_name] +def is_image_cached(project_name: str, sanitizer: str) -> bool: + """Checks whether a project has a cached Docker image post fuzzer + building.""" + cached_image_name = _get_project_cache_image_name(project_name, sanitizer) try: - sp.run(command, check=True) + sp.run( + ["docker", "manifest", "inspect", cached_image_name], + check=True, + stdin=sp.DEVNULL, + stdout=sp.DEVNULL, + stderr=sp.STDOUT, + ) + return True except sp.CalledProcessError: - logger.info('Could not rename image.') - return True - - -def prepare_cached_images( - experiment_targets: list[benchmarklib.Benchmark]) -> None: - """Builds cached Docker images for a set of targets.""" - all_projects = set() - for benchmark in experiment_targets: - all_projects.add(benchmark.project) - - logger.info('Preparing cache for %d projects', len(all_projects)) + return False - for project in all_projects: - _prepare_image_cache(project) +def rewrite_project_to_cached_project( + project_name: str, generated_project: str, sanitizer: str +) -> None: + """Rewrites Dockerfile of a project to enable cached build scripts.""" + cached_image_name = _get_project_cache_image_name(project_name, sanitizer) + generated_project_folder = os.path.join(OSS_FUZZ_DIR, "projects", generated_project) -def 
is_image_cached(project_name: str, sanitizer: str) -> bool: - """Checks whether a project has a cached Docker image post fuzzer - building.""" - cached_image_name = _get_project_cache_image_name(project_name, sanitizer) - try: - sp.run( - ['docker', 'manifest', 'inspect', cached_image_name], - check=True, - stdin=sp.DEVNULL, - stdout=sp.DEVNULL, - stderr=sp.STDOUT, + cached_dockerfile = os.path.join( + generated_project_folder, f"Dockerfile_{sanitizer}_cached" + ) + if os.path.isfile(cached_dockerfile): + logger.info("Already converted") + return + + # Check if there is an original Dockerfile, because we should use that in + # case,as otherwise the "Dockerfile" may be a copy of another sanitizer. + original_dockerfile = os.path.join(generated_project_folder, "Dockerfile_original") + if not os.path.isfile(original_dockerfile): + dockerfile = os.path.join(generated_project_folder, "Dockerfile") + shutil.copy(dockerfile, original_dockerfile) + + with open(original_dockerfile, "r") as f: + docker_content = f.read() + + arg_line = "ARG CACHE_IMAGE=" + cached_image_name + docker_content = arg_line + "\n" + docker_content + docker_content = re.sub( + r"FROM gcr.io/oss-fuzz-base/base-builder.*", "FROM $CACHE_IMAGE", docker_content ) - return True - except sp.CalledProcessError: - return False - - -def rewrite_project_to_cached_project(project_name: str, generated_project: str, - sanitizer: str) -> None: - """Rewrites Dockerfile of a project to enable cached build scripts.""" - cached_image_name = _get_project_cache_image_name(project_name, sanitizer) - generated_project_folder = os.path.join(OSS_FUZZ_DIR, 'projects', - generated_project) - - cached_dockerfile = os.path.join(generated_project_folder, - f'Dockerfile_{sanitizer}_cached') - if os.path.isfile(cached_dockerfile): - logger.info('Already converted') - return - - # Check if there is an original Dockerfile, because we should use that in - # case,as otherwise the "Dockerfile" may be a copy of another sanitizer. - original_dockerfile = os.path.join(generated_project_folder, - 'Dockerfile_original') - if not os.path.isfile(original_dockerfile): - dockerfile = os.path.join(generated_project_folder, 'Dockerfile') - shutil.copy(dockerfile, original_dockerfile) - - with open(original_dockerfile, 'r') as f: - docker_content = f.read() - - arg_line = 'ARG CACHE_IMAGE=' + cached_image_name - docker_content = arg_line + '\n' + docker_content - docker_content = re.sub(r'FROM gcr.io/oss-fuzz-base/base-builder.*', - 'FROM $CACHE_IMAGE', docker_content) - - # Now comment out everything except: - # - The first FROM. - # - The ARG we just added. - # - The last 2 COPY commands (for the build script and the target we added). - arg_line = -1 - from_line = -1 - copy_fuzzer_line = -1 - copy_build_line = -1 - - for line_idx, line in enumerate(docker_content.split('\n')): - if line.startswith('ARG') and arg_line == -1: - arg_line = line_idx - if line.startswith('FROM') and from_line == -1: - from_line = line_idx - if line.startswith('COPY'): - copy_fuzzer_line = copy_build_line - copy_build_line = line_idx - - lines_to_keep = {arg_line, from_line, copy_fuzzer_line, copy_build_line} - new_content = '' - for line_idx, line in enumerate(docker_content.split('\n')): - if line_idx not in lines_to_keep: - new_content += f'# {line}\n' - else: - new_content += f'{line}\n' - # Overwrite the existing one - with open(cached_dockerfile, 'w') as f: - f.write(new_content) + # Now comment out everything except: + # - The first FROM. + # - The ARG we just added. 
+ # - The last 2 COPY commands (for the build script and the target we added). + arg_line = -1 + from_line = -1 + copy_fuzzer_line = -1 + copy_build_line = -1 + + for line_idx, line in enumerate(docker_content.split("\n")): + if line.startswith("ARG") and arg_line == -1: + arg_line = line_idx + if line.startswith("FROM") and from_line == -1: + from_line = line_idx + if line.startswith("COPY"): + copy_fuzzer_line = copy_build_line + copy_build_line = line_idx + + lines_to_keep = {arg_line, from_line, copy_fuzzer_line, copy_build_line} + new_content = "" + for line_idx, line in enumerate(docker_content.split("\n")): + if line_idx not in lines_to_keep: + new_content += f"# {line}\n" + else: + new_content += f"{line}\n" + + # Overwrite the existing one + with open(cached_dockerfile, "w") as f: + f.write(new_content) def prepare_build(project_name, sanitizer, generated_project): - """Prepares the correct Dockerfile to be used for cached builds.""" - generated_project_folder = os.path.join(OSS_FUZZ_DIR, 'projects', - generated_project) - if not ENABLE_CACHING: - return - dockerfile_to_use = os.path.join(generated_project_folder, 'Dockerfile') - original_dockerfile = os.path.join(generated_project_folder, - 'Dockerfile_original') - if is_image_cached(project_name, sanitizer): - logger.info('Using cached dockerfile') - cached_dockerfile = os.path.join(generated_project_folder, - f'Dockerfile_{sanitizer}_cached') - shutil.copy(cached_dockerfile, dockerfile_to_use) - else: - logger.info('Using original dockerfile') - shutil.copy(original_dockerfile, dockerfile_to_use) + """Prepares the correct Dockerfile to be used for cached builds.""" + generated_project_folder = os.path.join(OSS_FUZZ_DIR, "projects", generated_project) + if not ENABLE_CACHING: + return + dockerfile_to_use = os.path.join(generated_project_folder, "Dockerfile") + original_dockerfile = os.path.join(generated_project_folder, "Dockerfile_original") + if is_image_cached(project_name, sanitizer): + logger.info("Using cached dockerfile") + cached_dockerfile = os.path.join( + generated_project_folder, f"Dockerfile_{sanitizer}_cached" + ) + shutil.copy(cached_dockerfile, dockerfile_to_use) + else: + logger.info("Using original dockerfile") + shutil.copy(original_dockerfile, dockerfile_to_use) def _build_image(project_name: str) -> str: - """Builds project image in OSS-Fuzz""" - adjusted_env = os.environ | { - 'FUZZING_LANGUAGE': get_project_language(project_name) - } - command = [ - 'python3', 'infra/helper.py', 'build_image', '--pull', project_name - ] - try: - sp.run(command, - cwd=OSS_FUZZ_DIR, - env=adjusted_env, - stdout=sp.PIPE, - stderr=sp.PIPE, - check=True) - logger.info('Successfully build project image for %s', project_name) - return f'gcr.io/oss-fuzz/{project_name}' - except sp.CalledProcessError as e: - logger.error('Failed to build project image for %s: %s', project_name, - e.stderr.decode('utf-8')) - return '' + """Builds project image in OSS-Fuzz""" + adjusted_env = os.environ | {"FUZZING_LANGUAGE": get_project_language(project_name)} + command = ["python3", "infra/helper.py", "build_image", "--pull", project_name] + try: + sp.run( + command, + cwd=OSS_FUZZ_DIR, + env=adjusted_env, + stdout=sp.PIPE, + stderr=sp.PIPE, + check=True, + ) + logger.info("Successfully build project image for %s", project_name) + return f"gcr.io/oss-fuzz/{project_name}" + except sp.CalledProcessError as e: + logger.error( + "Failed to build project image for %s: %s", + project_name, + e.stderr.decode("utf-8"), + ) + return "" def 
rectify_docker_tag(docker_tag: str) -> str: - # Replace "::" and any character not \w, _, or . with "-". - valid_docker_tag = re.sub(r'::', '-', docker_tag) - valid_docker_tag = re.sub(r'[^\w_.]', '-', valid_docker_tag) - # Docker fails with tags containing -_ or _-. - valid_docker_tag = re.sub(r'[-_]{2,}', '-', valid_docker_tag) - return valid_docker_tag - - -def create_ossfuzz_project(benchmark: benchmarklib.Benchmark, - generated_oss_fuzz_project: str) -> str: - """Creates an OSS-Fuzz project by replicating an existing project.""" - generated_project_path = os.path.join(OSS_FUZZ_DIR, 'projects', - generated_oss_fuzz_project) - if os.path.exists(generated_project_path): - logger.info('Project %s already exists.', generated_project_path) - return generated_project_path + # Replace "::" and any character not \w, _, or . with "-". + valid_docker_tag = re.sub(r"::", "-", docker_tag) + valid_docker_tag = re.sub(r"[^\w_.]", "-", valid_docker_tag) + # Docker fails with tags containing -_ or _-. + valid_docker_tag = re.sub(r"[-_]{2,}", "-", valid_docker_tag) + return valid_docker_tag + + +def create_ossfuzz_project( + benchmark: benchmarklib.Benchmark, generated_oss_fuzz_project: str +) -> str: + """Creates an OSS-Fuzz project by replicating an existing project.""" + generated_project_path = os.path.join( + OSS_FUZZ_DIR, "projects", generated_oss_fuzz_project + ) + if os.path.exists(generated_project_path): + logger.info("Project %s already exists.", generated_project_path) + return generated_project_path - oss_fuzz_project_path = os.path.join(OSS_FUZZ_DIR, 'projects', - benchmark.project) - shutil.copytree(oss_fuzz_project_path, generated_project_path) - return generated_project_path + oss_fuzz_project_path = os.path.join(OSS_FUZZ_DIR, "projects", benchmark.project) + shutil.copytree(oss_fuzz_project_path, generated_project_path) + return generated_project_path def prepare_project_image(benchmark: benchmarklib.Benchmark) -> str: - """Prepares original image of the |project|'s fuzz target build container.""" - project = benchmark.project - image_name = f'gcr.io/oss-fuzz/{project}' - generated_oss_fuzz_project = f'{benchmark.id}-{uuid.uuid4().hex}' - generated_oss_fuzz_project = rectify_docker_tag(generated_oss_fuzz_project) - create_ossfuzz_project(benchmark, generated_oss_fuzz_project) - - if not ENABLE_CACHING: - logger.warning('Disabled caching when building image for %s', project) - elif is_image_cached(project, 'address'): - logger.info('Will use cached instance.') - # Rewrite for caching. 
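For orientation, the cache handling in this hunk reduces to probing the registry for a prebuilt image and then swapping in the matching Dockerfile. A minimal standalone sketch of that flow follows, assuming the Dockerfile naming shown above; the helper names are illustrative, not the project's API:

import shutil
import subprocess as sp

def image_is_cached(image_name: str) -> bool:
    # A cached image is considered available if its manifest can be inspected.
    try:
        sp.run(
            ["docker", "manifest", "inspect", image_name],
            check=True,
            stdin=sp.DEVNULL,
            stdout=sp.DEVNULL,
            stderr=sp.STDOUT,
        )
        return True
    except sp.CalledProcessError:
        return False

def select_dockerfile(project_dir: str, sanitizer: str, image_name: str) -> None:
    # Copy either the sanitizer-specific cached Dockerfile or the original
    # one into place before building.
    source = (
        f"{project_dir}/Dockerfile_{sanitizer}_cached"
        if image_is_cached(image_name)
        else f"{project_dir}/Dockerfile_original"
    )
    shutil.copy(source, f"{project_dir}/Dockerfile")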
- rewrite_project_to_cached_project(project, generated_oss_fuzz_project, - 'address') - # Prepare build - prepare_build(project, 'address', generated_oss_fuzz_project) - # Build the image - logger.info('Using cached project image for %s: %s', - generated_oss_fuzz_project, image_name) - else: - logger.warning('Unable to find cached project image for %s', project) - return _build_image(generated_oss_fuzz_project) - - -def create_ossfuzz_project_by_name(original_name: str, - generated_oss_fuzz_project: str) -> str: - """Creates an OSS-Fuzz project by replicating an existing project.""" - generated_project_path = os.path.join(OSS_FUZZ_DIR, 'projects', - generated_oss_fuzz_project) - if os.path.exists(generated_project_path): - logger.info('Project %s already exists.', generated_project_path) - return generated_project_path + """Prepares original image of the |project|'s fuzz target build container.""" + project = benchmark.project + image_name = f"gcr.io/oss-fuzz/{project}" + generated_oss_fuzz_project = f"{benchmark.id}-{uuid.uuid4().hex}" + generated_oss_fuzz_project = rectify_docker_tag(generated_oss_fuzz_project) + create_ossfuzz_project(benchmark, generated_oss_fuzz_project) + + if not ENABLE_CACHING: + logger.warning("Disabled caching when building image for %s", project) + elif is_image_cached(project, "address"): + logger.info("Will use cached instance.") + # Rewrite for caching. + rewrite_project_to_cached_project( + project, generated_oss_fuzz_project, "address" + ) + # Prepare build + prepare_build(project, "address", generated_oss_fuzz_project) + # Build the image + logger.info( + "Using cached project image for %s: %s", + generated_oss_fuzz_project, + image_name, + ) + else: + logger.warning("Unable to find cached project image for %s", project) + return _build_image(generated_oss_fuzz_project) + + +def create_ossfuzz_project_by_name( + original_name: str, generated_oss_fuzz_project: str +) -> str: + """Creates an OSS-Fuzz project by replicating an existing project.""" + generated_project_path = os.path.join( + OSS_FUZZ_DIR, "projects", generated_oss_fuzz_project + ) + if os.path.exists(generated_project_path): + logger.info("Project %s already exists.", generated_project_path) + return generated_project_path - oss_fuzz_project_path = os.path.join(OSS_FUZZ_DIR, 'projects', original_name) - shutil.copytree(oss_fuzz_project_path, generated_project_path) - return generated_project_path + oss_fuzz_project_path = os.path.join(OSS_FUZZ_DIR, "projects", original_name) + shutil.copytree(oss_fuzz_project_path, generated_project_path) + return generated_project_path def prepare_project_image_by_name(project_name: str) -> str: - """Prepares original image of the |project_name|'s fuzz target build - container.""" - project = project_name - image_name = f'gcr.io/oss-fuzz/{project}' - generated_oss_fuzz_project = f'{project_name}-{uuid.uuid4().hex}' - generated_oss_fuzz_project = rectify_docker_tag(generated_oss_fuzz_project) - create_ossfuzz_project_by_name(project, generated_oss_fuzz_project) - - if not ENABLE_CACHING: - logger.warning('Disabled caching when building image for %s', project) - elif is_image_cached(project, 'address'): - logger.info('Will use cached instance.') - # Rewrite for caching. 
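As a usage note, the Docker-tag sanitization applied to generated project names can be reproduced in isolation. The substitutions below mirror rectify_docker_tag above; the sample input is made up:

import re

def sanitize_docker_tag(tag: str) -> str:
    # Same three substitutions: drop '::', replace characters Docker rejects,
    # then collapse runs such as '--' or '-_' that Docker also refuses.
    tag = re.sub(r"::", "-", tag)
    tag = re.sub(r"[^\w_.]", "-", tag)
    return re.sub(r"[-_]{2,}", "-", tag)

print(sanitize_docker_tag("mybench::c++ target"))  # -> mybench-c-target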
- rewrite_project_to_cached_project(project, generated_oss_fuzz_project, - 'address') - # Prepare build - prepare_build(project, 'address', generated_oss_fuzz_project) - # Build the image - logger.info('Using cached project image for %s: %s', - generated_oss_fuzz_project, image_name) - else: - logger.warning('Unable to find cached project image for %s', project) - return _build_image(generated_oss_fuzz_project) + """Prepares original image of the |project_name|'s fuzz target build + container.""" + project = project_name + image_name = f"gcr.io/oss-fuzz/{project}" + generated_oss_fuzz_project = f"{project_name}-{uuid.uuid4().hex}" + generated_oss_fuzz_project = rectify_docker_tag(generated_oss_fuzz_project) + create_ossfuzz_project_by_name(project, generated_oss_fuzz_project) + + if not ENABLE_CACHING: + logger.warning("Disabled caching when building image for %s", project) + elif is_image_cached(project, "address"): + logger.info("Will use cached instance.") + # Rewrite for caching. + rewrite_project_to_cached_project( + project, generated_oss_fuzz_project, "address" + ) + # Prepare build + prepare_build(project, "address", generated_oss_fuzz_project) + # Build the image + logger.info( + "Using cached project image for %s: %s", + generated_oss_fuzz_project, + image_name, + ) + else: + logger.warning("Unable to find cached project image for %s", project) + return _build_image(generated_oss_fuzz_project) diff --git a/experiment/textcov.py b/experiment/textcov.py index 0fc2f531ab..2fcc81ef64 100644 --- a/experiment/textcov.py +++ b/experiment/textcov.py @@ -29,547 +29,562 @@ logger = logging.getLogger(__name__) # No spaces at the beginning, and ends with a ":". -FUNCTION_PATTERN = re.compile(r'^([^\s].*):$') -LINE_PATTERN = re.compile(r'^\s*\d+\|\s*([\d\.a-zA-Z]+)\|(.*)') +FUNCTION_PATTERN = re.compile(r"^([^\s].*):$") +LINE_PATTERN = re.compile(r"^\s*\d+\|\s*([\d\.a-zA-Z]+)\|(.*)") JVM_CLASS_MAPPING = { - 'Z': 'boolean', - 'B': 'byte', - 'C': 'char', - 'D': 'double', - 'F': 'float', - 'I': 'int', - 'J': 'long', - 'S': 'short' + "Z": "boolean", + "B": "byte", + "C": "char", + "D": "double", + "F": "float", + "I": "int", + "J": "long", + "S": "short", } -JVM_SKIPPED_METHOD = [ - 'fuzzerTestOneInput', 'fuzzerInitialize', 'fuzzerTearDown' -] +JVM_SKIPPED_METHOD = ["fuzzerTestOneInput", "fuzzerInitialize", "fuzzerTearDown"] def demangle(data: str) -> str: - """Demangles a string containing mangled C++ symbols.""" - return subprocess.check_output(['c++filt'], input=data, encoding='utf-8') + """Demangles a string containing mangled C++ symbols.""" + return subprocess.check_output(["c++filt"], input=data, encoding="utf-8") def _discard_fuzz_target_lines(covreport_content: str) -> str: - """Removes fuzz target lines from the coverage report.""" - # When comparing project code coverage contributed by fuzz targets, it's - # fairer to only consider lines in the project and not the code of targets. - # Assumption 1: llvm-cov separates lines from different files with an empty - # line by default in the coverage report. - # Assumption 2: All and only fuzz targets contain - # 'LLVMFuzzerTestOneInput'(C/C++) or 'fuzz_target' (Rust). 
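The two assumptions above describe the intended filtering: split the llvm-cov report on blank lines and keep only the sections that mention neither fuzz-target marker. A minimal sketch of that filtering, standalone rather than the project's helper:

def keep_project_sections(covreport_content: str) -> str:
    # llvm-cov separates per-file sections with a blank line (assumption 1);
    # sections containing either marker are treated as fuzz-target code
    # (assumption 2) and dropped, so only project code is counted.
    sections = covreport_content.split("\n\n")
    kept = [
        sec
        for sec in sections
        if "LLVMFuzzerTestOneInput" not in sec and "fuzz_target" not in sec
    ]
    return "\n\n".join(kept)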
- project_file_contents = [ - sec for sec in covreport_content.split('\n\n') - if 'LLVMFuzzerTestOneInput' not in sec or 'fuzz_target' not in sec - ] - return '\n\n'.join(project_file_contents) + """Removes fuzz target lines from the coverage report.""" + # When comparing project code coverage contributed by fuzz targets, it's + # fairer to only consider lines in the project and not the code of targets. + # Assumption 1: llvm-cov separates lines from different files with an empty + # line by default in the coverage report. + # Assumption 2: All and only fuzz targets contain + # 'LLVMFuzzerTestOneInput'(C/C++) or 'fuzz_target' (Rust). + project_file_contents = [ + sec + for sec in covreport_content.split("\n\n") + if "LLVMFuzzerTestOneInput" not in sec or "fuzz_target" not in sec + ] + return "\n\n".join(project_file_contents) def normalize_template_args(name: str) -> str: - """Normalizes template arguments.""" - return re.sub(r'<.*>', '<>', name) + """Normalizes template arguments.""" + return re.sub(r"<.*>", "<>", name) def _parse_hitcount(data: str) -> float: - """Parse a hitcount.""" - # From https://github.com/llvm/llvm-project/blob/3f3620e5c9ee0f7b64afc39e5a26c6f4cc5e7b37/llvm/tools/llvm-cov/SourceCoverageView.cpp#L102 - multipliers = { - 'k': 1000, - 'M': 1000000, - 'G': 1000000000, - 'T': 1000000000000, - 'P': 1000000000000000, - 'E': 1000000000000000000, - 'Z': 1000000000000000000000, - 'Y': 1000000000000000000000000, - } - - if data[-1].isdigit(): - # Simple number < 1000. - return int(data) - - if data[-1] in multipliers: - # E.g. "11.4k" - return float(data[:-1]) * multipliers[data[-1]] - - raise ValueError(f'Suffix {data[-1]} is not supported') + """Parse a hitcount.""" + # From https://github.com/llvm/llvm-project/blob/3f3620e5c9ee0f7b64afc39e5a26c6f4cc5e7b37/llvm/tools/llvm-cov/SourceCoverageView.cpp#L102 + multipliers = { + "k": 1000, + "M": 1000000, + "G": 1000000000, + "T": 1000000000000, + "P": 1000000000000000, + "E": 1000000000000000000, + "Z": 1000000000000000000000, + "Y": 1000000000000000000000000, + } + + if data[-1].isdigit(): + # Simple number < 1000. + return int(data) + + if data[-1] in multipliers: + # E.g. "11.4k" + return float(data[:-1]) * multipliers[data[-1]] + + raise ValueError(f"Suffix {data[-1]} is not supported") @dataclasses.dataclass class Line: - """Represents a line.""" - contents: str = '' - hit_count: float = 0 + """Represents a line.""" + + contents: str = "" + hit_count: float = 0 @dataclasses.dataclass class Function: - """Represents a function in a textcov.""" - name: str = '' - # Line contents -> Line object. We key on line contents to account for - # potential line number movements. - lines: dict[str, Line] = dataclasses.field(default_factory=dict) - - def merge(self, other: Function): - for line in other.lines.values(): - if line.contents in self.lines: - self.lines[line.contents].hit_count += line.hit_count - else: - self.lines[line.contents] = Line(contents=line.contents, - hit_count=line.hit_count) - - @property - def covered_lines(self): - return sum(1 for l in self.lines.values() if l.hit_count > 0) - - def subtract_covered_lines(self, other: Function, language: str = 'c++'): - """Subtract covered lines.""" - - if language == 'jvm': - for line_no, line in self.lines.items(): - other_line = other.lines.get(line_no) - if other_line and other_line.hit_count > 0: - line.hit_count = 0 - else: - # For our analysis purposes, we completely delete any lines that are - # hit by the other, rather than subtracting hitcounts. 
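Because lines are keyed by their stripped contents rather than by line number, merged hit counts survive code that merely moves between runs. A toy illustration of that merge rule, using plain dicts in place of the dataclasses:

def merge_hits(ours: dict[str, int], theirs: dict[str, int]) -> dict[str, int]:
    # Accumulate hit counts per line content; unknown lines are added as-is.
    merged = dict(ours)
    for contents, hits in theirs.items():
        merged[contents] = merged.get(contents, 0) + hits
    return merged

run_a = {"x += 1;": 3, "return x;": 3}
run_b = {"x += 1;": 2, "return x;": 2}   # same code, shifted line numbers
print(merge_hits(run_a, run_b))          # {'x += 1;': 5, 'return x;': 5}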
- for line in other.lines.values(): - if line.hit_count and line.contents in self.lines: - del self.lines[line.contents] + """Represents a function in a textcov.""" + + name: str = "" + # Line contents -> Line object. We key on line contents to account for + # potential line number movements. + lines: dict[str, Line] = dataclasses.field(default_factory=dict) + + def merge(self, other: Function): + for line in other.lines.values(): + if line.contents in self.lines: + self.lines[line.contents].hit_count += line.hit_count + else: + self.lines[line.contents] = Line( + contents=line.contents, hit_count=line.hit_count + ) + + @property + def covered_lines(self): + return sum(1 for l in self.lines.values() if l.hit_count > 0) + + def subtract_covered_lines(self, other: Function, language: str = "c++"): + """Subtract covered lines.""" + + if language == "jvm": + for line_no, line in self.lines.items(): + other_line = other.lines.get(line_no) + if other_line and other_line.hit_count > 0: + line.hit_count = 0 + else: + # For our analysis purposes, we completely delete any lines that are + # hit by the other, rather than subtracting hitcounts. + for line in other.lines.values(): + if line.hit_count and line.contents in self.lines: + del self.lines[line.contents] @dataclasses.dataclass class File: - """Represents a file in a textcov, only for Python.""" - name: str = '' - # Line contents -> Line object. We key on line contents to account for - # potential line number movements. - lines: dict[str, Line] = dataclasses.field(default_factory=dict) + """Represents a file in a textcov, only for Python.""" - def merge(self, other: File): - for line in other.lines.values(): - if line.contents in self.lines: - self.lines[line.contents].hit_count += line.hit_count - else: - self.lines[line.contents] = Line(contents=line.contents, - hit_count=line.hit_count) + name: str = "" + # Line contents -> Line object. We key on line contents to account for + # potential line number movements. + lines: dict[str, Line] = dataclasses.field(default_factory=dict) - @property - def covered_lines(self): - return sum(1 for l in self.lines.values() if l.hit_count > 0) + def merge(self, other: File): + for line in other.lines.values(): + if line.contents in self.lines: + self.lines[line.contents].hit_count += line.hit_count + else: + self.lines[line.contents] = Line( + contents=line.contents, hit_count=line.hit_count + ) - def subtract_covered_lines(self, other: File): - """Subtract covered lines.""" + @property + def covered_lines(self): + return sum(1 for l in self.lines.values() if l.hit_count > 0) - for line_no, line in self.lines.items(): - other_line = other.lines.get(line_no) - if other_line and other_line.hit_count > 0: - line.hit_count = 0 + def subtract_covered_lines(self, other: File): + """Subtract covered lines.""" + + for line_no, line in self.lines.items(): + other_line = other.lines.get(line_no) + if other_line and other_line.hit_count > 0: + line.hit_count = 0 @dataclasses.dataclass class Textcov: - """Textcov.""" - # Function name -> Function object. - # For JVM / C / C++ / Rust - functions: dict[str, Function] = dataclasses.field(default_factory=dict) - # File name -> File object. 
- # For Python - files: dict[str, File] = dataclasses.field(default_factory=dict) - language: str = 'c++' - - @classmethod - def _read_file_with_fallback(cls, - file_handle: BinaryIO, - sample_size: int = 1000) -> str: - """Reads file_handle assuming its encoding is utf-8, detects the encoding - if otherwise.""" - file_content = file_handle.read() - - try: - # Try decoding the file content with UTF-8 encoding - return file_content.decode('utf-8') - except UnicodeDecodeError: - # If UTF-8 decoding fails, detect the file's encoding - raw_data = file_content[:sample_size] - result = chardet.detect(raw_data) - encoding = result['encoding'] - if encoding is None: - logger.warning('Failed to decode.') - raise UnicodeDecodeError("chardet", raw_data, 0, len(raw_data), - "Cannot detect encoding") - - # Decode the file content with the detected encoding - return file_content.decode(encoding) - - @classmethod - def from_file( - cls, - file_handle, - ignore_function_patterns: Optional[List[re.Pattern]] = None) -> Textcov: - """Read a textcov from a file handle.""" - if ignore_function_patterns is None: - ignore_function_patterns = [] - - textcov = cls() - textcov.language = 'c++' - - current_function_name: str = '' - current_function: Function = Function() - try: - demangled = demangle(cls._read_file_with_fallback(file_handle)) - except Exception as e: - logger.warning('Decoding failure: %s', e) - demangled = '' - demangled = _discard_fuzz_target_lines(demangled) - - for line in demangled.split('\n'): - match = FUNCTION_PATTERN.match(line) - if match: - # Normalize templates. - current_function_name = normalize_template_args(match.group(1)) - if any( - p.match(current_function_name) for p in ignore_function_patterns): - # Ignore this function. - current_function_name = '' - continue - - if current_function_name in textcov.functions: - current_function = textcov.functions[current_function_name] + """Textcov.""" + + # Function name -> Function object. + # For JVM / C / C++ / Rust + functions: dict[str, Function] = dataclasses.field(default_factory=dict) + # File name -> File object. 
+ # For Python + files: dict[str, File] = dataclasses.field(default_factory=dict) + language: str = "c++" + + @classmethod + def _read_file_with_fallback( + cls, file_handle: BinaryIO, sample_size: int = 1000 + ) -> str: + """Reads file_handle assuming its encoding is utf-8, detects the encoding + if otherwise.""" + file_content = file_handle.read() + + try: + # Try decoding the file content with UTF-8 encoding + return file_content.decode("utf-8") + except UnicodeDecodeError: + # If UTF-8 decoding fails, detect the file's encoding + raw_data = file_content[:sample_size] + result = chardet.detect(raw_data) + encoding = result["encoding"] + if encoding is None: + logger.warning("Failed to decode.") + raise UnicodeDecodeError( + "chardet", raw_data, 0, len(raw_data), "Cannot detect encoding" + ) + + # Decode the file content with the detected encoding + return file_content.decode(encoding) + + @classmethod + def from_file( + cls, file_handle, ignore_function_patterns: Optional[List[re.Pattern]] = None + ) -> Textcov: + """Read a textcov from a file handle.""" + if ignore_function_patterns is None: + ignore_function_patterns = [] + + textcov = cls() + textcov.language = "c++" + + current_function_name: str = "" + current_function: Function = Function() + try: + demangled = demangle(cls._read_file_with_fallback(file_handle)) + except Exception as e: + logger.warning("Decoding failure: %s", e) + demangled = "" + demangled = _discard_fuzz_target_lines(demangled) + + for line in demangled.split("\n"): + match = FUNCTION_PATTERN.match(line) + if match: + # Normalize templates. + current_function_name = normalize_template_args(match.group(1)) + if any( + p.match(current_function_name) for p in ignore_function_patterns + ): + # Ignore this function. + current_function_name = "" + continue + + if current_function_name in textcov.functions: + current_function = textcov.functions[current_function_name] + else: + current_function = Function(name=current_function_name) + textcov.functions[current_function_name] = current_function + + continue + + if not current_function_name: + # No current functions. This can happen if we're currently in an + # ignored function. 
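The parsing loop here is driven entirely by the two regexes defined at the top of the module; the short, self-contained example below shows what they extract from a made-up llvm-cov text snippet:

import re

FUNCTION_PATTERN = re.compile(r"^([^\s].*):$")
LINE_PATTERN = re.compile(r"^\s*\d+\|\s*([\d\.a-zA-Z]+)\|(.*)")

snippet = [
    "my_lib::parse(char const*):",    # function header, no leading space
    "    7|      3|  count += 1;",    # source line 7, hit 3 times
    "    8|  11.4k|  return count;",  # suffixed hit counts are allowed
]

for text in snippet:
    func = FUNCTION_PATTERN.match(text)
    line = LINE_PATTERN.match(text)
    if func:
        print("function:", func.group(1))
    elif line:
        print("hits:", line.group(1), "code:", line.group(2).strip())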
+ continue + + match = LINE_PATTERN.match(line) + if match: + hit_count = _parse_hitcount(match.group(1)) + # Ignore whitespace differences + line_contents = match.group(2).strip() + + if line_contents in current_function.lines: + current_function.lines[line_contents].hit_count += hit_count + else: + current_function.lines[line_contents] = Line( + contents=line_contents, hit_count=hit_count + ) + + continue + return textcov + + @classmethod + def from_python_file(cls, file_handle) -> Textcov: + """Read a textcov from a all_cov.json file for python.""" + textcov = cls() + textcov.language = "python" + coverage_report = json.load(file_handle) + + # Process coverage report file by file + for file, data in coverage_report.get("files", {}).items(): + # Retrieve pure file and directory name + filename = file.replace("/pythoncovmergedfiles/medio/medio/", "") + filename = filename.split("site-packages/", 1)[-1] + current_file = File(name=filename) + + # Process line coverage information + covered_lines = data.get("executed_lines", []) + missed_lines = data.get("missing_lines", []) + for line_no in covered_lines: + line = f"Line{line_no}" + current_file.lines[line] = Line(contents=line, hit_count=1) + for line_no in missed_lines: + line = f"Line{line_no}" + current_file.lines[line] = Line(contents=line, hit_count=0) + + textcov.files[filename] = current_file + + return textcov + + @classmethod + def from_jvm_file(cls, file_handle) -> Textcov: + """Read a textcov from a jacoco.xml file.""" + textcov = cls() + textcov.language = "jvm" + jacoco_report = ET.parse(file_handle) + + # Process source file information + line_coverage_dict = {} + for item in jacoco_report.iter(): + if item.tag == "sourcefile": + line_coverage = [] + for line_item in item: + if line_item.tag == "line": + line_no = int(line_item.attrib["nr"]) + if line_item.attrib["mi"] == "0": + line_coverage.append((line_no, True)) + else: + line_coverage.append((line_no, False)) + line_coverage_dict[item.attrib["name"]] = line_coverage + + # Process methods + class_method_items = [] + for item in jacoco_report.iter(): + if item.tag == "class": + # Skip fuzzer classes + if textcov.is_fuzzer_class(item): + continue + + # Get line coverage information for this class + sourcefilename = item.attrib.get("sourcefilename") + if not sourcefilename: + # Fail safe for invalid jacoco.xml with no sourcefilename + continue + coverage = line_coverage_dict.get(sourcefilename, []) + + # Get class name and skip fuzzing and testing classes + class_name = item.attrib.get("name", "").replace("/", ".") + if ( + not class_name + or "test" in class_name.lower() + or "fuzzer" in class_name.lower() + ): + continue + + for method_item in item: + if method_item.tag == "method": + if method_item.attrib["name"] not in JVM_SKIPPED_METHOD: + class_method_items.append( + (class_name, method_item, coverage) + ) + + for class_name, method_item, coverage in class_method_items: + method_dict = method_item.attrib + method_name = method_dict["name"] + + # Determine start index in coverage list + start_line = int(method_dict.get("line", "-1")) + start_index = -1 + for count, item in enumerate(coverage): + if item[0] == start_line: + start_index = count + break + + # Failed to retrieve coverage information, skipping this method + if start_index == -1: + continue + + # Process all arguments type from shortern Java Class naming + args = textcov.determine_jvm_arguments_type(method_dict["desc"]) + + # Save method + full_method_name = f'[{class_name}].{method_name}({",".join(args)})' + 
current_method = Function(name=full_method_name) + + # Retrieve line coverage information + total_line = 0 + for cov_data in method_item: + if cov_data.attrib["type"] == "LINE": + total_line = int(cov_data.attrib["covered"]) + int( + cov_data.attrib["missed"] + ) + + for count in range(start_index, start_index + total_line): + if count >= len(coverage): + # Fail safe + break + line_no, is_reached = coverage[count] + line = f"Line{line_no}" + if is_reached: + current_method.lines[line] = Line(contents=line, hit_count=1) + else: + current_method.lines[line] = Line(contents=line, hit_count=0) + + textcov.functions[full_method_name] = current_method + + return textcov + + @classmethod + def from_rust_file( + cls, file_handle, ignore_function_patterns: Optional[List[re.Pattern]] = None + ) -> Textcov: + """Read a textcov from a file handle for rust project.""" + if ignore_function_patterns is None: + ignore_function_patterns = [] + + textcov = cls() + textcov.language = "rust" + + current_function_name: str = "" + current_function: Function = Function() + + try: + file_content = cls._read_file_with_fallback(file_handle) + demangled = rust_demangler.demangle(file_content) + except Exception as e: + logger.warning("Decoding failure: %s", e) + demangled = "" + demangled = _discard_fuzz_target_lines(demangled) + + for line in demangled.split("\n"): + match = FUNCTION_PATTERN.match(line) + if match: + # Normalize templates. + current_function_name = normalize_template_args(match.group(1)) + if any( + p.match(current_function_name) for p in ignore_function_patterns + ): + # Ignore this function. + current_function_name = "" + continue + + if current_function_name in textcov.functions: + current_function = textcov.functions[current_function_name] + else: + current_function = Function(name=current_function_name) + textcov.functions[current_function_name] = current_function + + continue + + if not current_function_name: + # No current functions. This can happen if we're currently in an + # ignored function. + continue + + match = LINE_PATTERN.match(line) + if match: + hit_count = _parse_hitcount(match.group(1)) + # Ignore whitespace differences + line_contents = match.group(2).strip() + + if line_contents in current_function.lines: + current_function.lines[line_contents].hit_count += hit_count + else: + current_function.lines[line_contents] = Line( + contents=line_contents, hit_count=hit_count + ) + + continue + return textcov + + def to_file(self, filename: str) -> None: + """Writes covered functions/files and lines to |filename|.""" + file_content = "" + + if self.language == "python": + target = self.files else: - current_function = Function(name=current_function_name) - textcov.functions[current_function_name] = current_function - - continue - - if not current_function_name: - # No current functions. This can happen if we're currently in an - # ignored function. 
- continue - - match = LINE_PATTERN.match(line) - if match: - hit_count = _parse_hitcount(match.group(1)) - # Ignore whitespace differences - line_contents = match.group(2).strip() - - if line_contents in current_function.lines: - current_function.lines[line_contents].hit_count += hit_count + target = self.functions + + for func_obj in target.values(): + for line_content, line_obj in func_obj.lines.items(): + file_content += f"{line_content}\n" if line_obj.hit_count else "" + + with open(filename, "w") as file: + file.write(file_content) + + def merge(self, other: Textcov): + """Merge another textcov""" + # The default language for Textcov is set to c++ + # This logic fixes the language of Textcov object when + # merging an existing Textcov with different language + if self.language != other.language and self.language == "c++": + self.language = other.language + + if self.language == "python": + for file in other.files.values(): + if file.name not in self.files: + self.files[file.name] = File(name=file.name) + self.files[file.name].merge(file) else: - current_function.lines[line_contents] = Line(contents=line_contents, - hit_count=hit_count) - - continue - return textcov - - @classmethod - def from_python_file(cls, file_handle) -> Textcov: - """Read a textcov from a all_cov.json file for python.""" - textcov = cls() - textcov.language = 'python' - coverage_report = json.load(file_handle) - - # Process coverage report file by file - for file, data in coverage_report.get('files', {}).items(): - # Retrieve pure file and directory name - filename = file.replace('/pythoncovmergedfiles/medio/medio/', '') - filename = filename.split('site-packages/', 1)[-1] - current_file = File(name=filename) - - # Process line coverage information - covered_lines = data.get('executed_lines', []) - missed_lines = data.get('missing_lines', []) - for line_no in covered_lines: - line = f'Line{line_no}' - current_file.lines[line] = Line(contents=line, hit_count=1) - for line_no in missed_lines: - line = f'Line{line_no}' - current_file.lines[line] = Line(contents=line, hit_count=0) - - textcov.files[filename] = current_file - - return textcov - - @classmethod - def from_jvm_file(cls, file_handle) -> Textcov: - """Read a textcov from a jacoco.xml file.""" - textcov = cls() - textcov.language = 'jvm' - jacoco_report = ET.parse(file_handle) - - # Process source file information - line_coverage_dict = {} - for item in jacoco_report.iter(): - if item.tag == 'sourcefile': - line_coverage = [] - for line_item in item: - if line_item.tag == 'line': - line_no = int(line_item.attrib['nr']) - if line_item.attrib['mi'] == '0': - line_coverage.append((line_no, True)) - else: - line_coverage.append((line_no, False)) - line_coverage_dict[item.attrib['name']] = line_coverage - - # Process methods - class_method_items = [] - for item in jacoco_report.iter(): - if item.tag == 'class': - # Skip fuzzer classes - if textcov.is_fuzzer_class(item): - continue - - # Get line coverage information for this class - sourcefilename = item.attrib.get('sourcefilename') - if not sourcefilename: - # Fail safe for invalid jacoco.xml with no sourcefilename - continue - coverage = line_coverage_dict.get(sourcefilename, []) - - # Get class name and skip fuzzing and testing classes - class_name = item.attrib.get('name', '').replace('/', '.') - if not class_name or 'test' in class_name.lower( - ) or 'fuzzer' in class_name.lower(): - continue - - for method_item in item: - if method_item.tag == 'method': - if method_item.attrib['name'] not in 
JVM_SKIPPED_METHOD: - class_method_items.append((class_name, method_item, coverage)) - - for class_name, method_item, coverage in class_method_items: - method_dict = method_item.attrib - method_name = method_dict['name'] - - # Determine start index in coverage list - start_line = int(method_dict.get('line', '-1')) - start_index = -1 - for count, item in enumerate(coverage): - if item[0] == start_line: - start_index = count - break - - # Failed to retrieve coverage information, skipping this method - if start_index == -1: - continue - - # Process all arguments type from shortern Java Class naming - args = textcov.determine_jvm_arguments_type(method_dict['desc']) - - # Save method - full_method_name = f'[{class_name}].{method_name}({",".join(args)})' - current_method = Function(name=full_method_name) - - # Retrieve line coverage information - total_line = 0 - for cov_data in method_item: - if cov_data.attrib['type'] == 'LINE': - total_line = int(cov_data.attrib['covered']) + int( - cov_data.attrib['missed']) - - for count in range(start_index, start_index + total_line): - if count >= len(coverage): - # Fail safe - break - line_no, is_reached = coverage[count] - line = f'Line{line_no}' - if is_reached: - current_method.lines[line] = Line(contents=line, hit_count=1) - else: - current_method.lines[line] = Line(contents=line, hit_count=0) - - textcov.functions[full_method_name] = current_method - - return textcov - - @classmethod - def from_rust_file( - cls, - file_handle, - ignore_function_patterns: Optional[List[re.Pattern]] = None) -> Textcov: - """Read a textcov from a file handle for rust project.""" - if ignore_function_patterns is None: - ignore_function_patterns = [] - - textcov = cls() - textcov.language = 'rust' - - current_function_name: str = '' - current_function: Function = Function() - - try: - file_content = cls._read_file_with_fallback(file_handle) - demangled = rust_demangler.demangle(file_content) - except Exception as e: - logger.warning('Decoding failure: %s', e) - demangled = '' - demangled = _discard_fuzz_target_lines(demangled) - - for line in demangled.split('\n'): - match = FUNCTION_PATTERN.match(line) - if match: - # Normalize templates. - current_function_name = normalize_template_args(match.group(1)) - if any( - p.match(current_function_name) for p in ignore_function_patterns): - # Ignore this function. - current_function_name = '' - continue - - if current_function_name in textcov.functions: - current_function = textcov.functions[current_function_name] - else: - current_function = Function(name=current_function_name) - textcov.functions[current_function_name] = current_function - - continue - - if not current_function_name: - # No current functions. This can happen if we're currently in an - # ignored function. 
- continue - - match = LINE_PATTERN.match(line) - if match: - hit_count = _parse_hitcount(match.group(1)) - # Ignore whitespace differences - line_contents = match.group(2).strip() - - if line_contents in current_function.lines: - current_function.lines[line_contents].hit_count += hit_count - else: - current_function.lines[line_contents] = Line(contents=line_contents, - hit_count=hit_count) - - continue - return textcov - - def to_file(self, filename: str) -> None: - """Writes covered functions/files and lines to |filename|.""" - file_content = '' - - if self.language == 'python': - target = self.files - else: - target = self.functions - - for func_obj in target.values(): - for line_content, line_obj in func_obj.lines.items(): - file_content += f'{line_content}\n' if line_obj.hit_count else '' - - with open(filename, 'w') as file: - file.write(file_content) - - def merge(self, other: Textcov): - """Merge another textcov""" - # The default language for Textcov is set to c++ - # This logic fixes the language of Textcov object when - # merging an existing Textcov with different language - if self.language != other.language and self.language == 'c++': - self.language = other.language - - if self.language == 'python': - for file in other.files.values(): - if file.name not in self.files: - self.files[file.name] = File(name=file.name) - self.files[file.name].merge(file) - else: - for function in other.functions.values(): - if function.name not in self.functions: - self.functions[function.name] = Function(name=function.name) - self.functions[function.name].merge(function) - - def subtract_covered_lines(self, other: Textcov): - """Diff another textcov""" - if self.language == 'python': - for file in other.files.values(): - if file.name in self.files: - self.files[file.name].subtract_covered_lines(file) - else: - for function in other.functions.values(): - if function.name in self.functions: - self.functions[function.name].subtract_covered_lines( - function, self.language) - - @property - def covered_lines(self): - if self.language == 'python': - return sum(f.covered_lines for f in self.files.values()) - - return sum(f.covered_lines for f in self.functions.values()) - - @property - def total_lines(self): - if self.language == 'python': - return sum(len(f.lines) for f in self.files.values()) - - return sum(len(f.lines) for f in self.functions.values()) - - def is_fuzzer_class(self, class_item) -> bool: - """Determine if the class_item is a fuzzer class.""" - return bool(class_item.find('./method[@name=\"fuzzerTestOneInput\"]')) - - def determine_jvm_arguments_type(self, desc: str) -> List[str]: - """ - Determine list of jvm arguments type for each method. - - The desc tag for each jvm method in the jacoco.xml coverage - report is in basic Java class name specification following - the format of "({Arguments}){ReturnType}". The basic java - class name specification use single upper case letter for - primitive types (and void type) and L{full_class_name}; for - object arguments. The JVM_CLASS_MAPPING give the mapping of - the single upper case letter of each primitive types. - - For example, for a method - "public void test(String,int,String[],boolean,int...)" - - The desc value of the above method will be - "(Ljava.lang.String;ILjava.lang.String;[]ZI[])V". - - This method is necessary to match the full method name with - the one given in the jacoco.xml report with full argument list. 
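As background for the descriptor handling described in this docstring, below is a simplified standalone decoder for standard JVM method descriptors. It reuses the primitive-type mapping from the top of this module but is only an illustration, not the project's parser:

JVM_CLASS_MAPPING = {
    "Z": "boolean", "B": "byte", "C": "char", "D": "double",
    "F": "float", "I": "int", "J": "long", "S": "short",
}

def decode_descriptor_args(desc: str) -> list[str]:
    # Walk the "(...)" part of a JVM method descriptor and emit Java-style
    # argument types; 'L...;' is an object type, '[' marks an array dimension.
    args: list[str] = []
    i, dims = 1, 0                     # skip the opening '('
    while desc[i] != ")":
        c = desc[i]
        if c == "[":
            dims += 1
            i += 1
            continue
        if c == "L":
            end = desc.index(";", i)
            type_name = desc[i + 1:end].replace("/", ".")
            i = end + 1
        else:
            type_name = JVM_CLASS_MAPPING[c]
            i += 1
        args.append(type_name + "[]" * dims)
        dims = 0
    return args

print(decode_descriptor_args("(Ljava/lang/String;I[Z)V"))
# ['java.lang.String', 'int', 'boolean[]']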
- """ - args = [] - arg = '' - start = False - next_arg = '' - array_count = 0 - for c in desc: - if c == '(': - continue - if c == ')': - break - - if start: - if c == ';': - start = False - next_arg = arg.replace('/', '.') + for function in other.functions.values(): + if function.name not in self.functions: + self.functions[function.name] = Function(name=function.name) + self.functions[function.name].merge(function) + + def subtract_covered_lines(self, other: Textcov): + """Diff another textcov""" + if self.language == "python": + for file in other.files.values(): + if file.name in self.files: + self.files[file.name].subtract_covered_lines(file) else: - arg = arg + c - else: - if c == 'L': - start = True - if next_arg: - next_arg += '[]' * array_count - array_count = 0 + for function in other.functions.values(): + if function.name in self.functions: + self.functions[function.name].subtract_covered_lines( + function, self.language + ) + + @property + def covered_lines(self): + if self.language == "python": + return sum(f.covered_lines for f in self.files.values()) + + return sum(f.covered_lines for f in self.functions.values()) + + @property + def total_lines(self): + if self.language == "python": + return sum(len(f.lines) for f in self.files.values()) + + return sum(len(f.lines) for f in self.functions.values()) + + def is_fuzzer_class(self, class_item) -> bool: + """Determine if the class_item is a fuzzer class.""" + return bool(class_item.find('./method[@name="fuzzerTestOneInput"]')) + + def determine_jvm_arguments_type(self, desc: str) -> List[str]: + """ + Determine list of jvm arguments type for each method. + + The desc tag for each jvm method in the jacoco.xml coverage + report is in basic Java class name specification following + the format of "({Arguments}){ReturnType}". The basic java + class name specification use single upper case letter for + primitive types (and void type) and L{full_class_name}; for + object arguments. The JVM_CLASS_MAPPING give the mapping of + the single upper case letter of each primitive types. + + For example, for a method + "public void test(String,int,String[],boolean,int...)" + + The desc value of the above method will be + "(Ljava.lang.String;ILjava.lang.String;[]ZI[])V". + + This method is necessary to match the full method name with + the one given in the jacoco.xml report with full argument list. 
+ """ + args = [] + arg = "" + start = False + next_arg = "" + array_count = 0 + for c in desc: + if c == "(": + continue + if c == ")": + break + + if start: + if c == ";": + start = False + next_arg = arg.replace("/", ".") + else: + arg = arg + c + else: + if c == "L": + start = True + if next_arg: + next_arg += "[]" * array_count + array_count = 0 + args.append(next_arg) + arg = "" + next_arg = "" + elif c == "[": + array_count += 1 + else: + if c in JVM_CLASS_MAPPING: + if next_arg: + next_arg += "[]" * array_count + array_count = 0 + args.append(next_arg) + next_arg = JVM_CLASS_MAPPING[c] + + if next_arg: + next_arg += "[]" * array_count args.append(next_arg) - arg = '' - next_arg = '' - elif c == '[': - array_count += 1 - else: - if c in JVM_CLASS_MAPPING: - if next_arg: - next_arg += '[]' * array_count - array_count = 0 - args.append(next_arg) - next_arg = JVM_CLASS_MAPPING[c] - - if next_arg: - next_arg += '[]' * array_count - args.append(next_arg) - return args + return args diff --git a/experiment/workdir.py b/experiment/workdir.py index b218956409..66e4dd658a 100644 --- a/experiment/workdir.py +++ b/experiment/workdir.py @@ -22,110 +22,112 @@ class WorkDirs: - """Working directories.""" - - RUN_LOG_NAME_PATTERN = re.compile(r'.*-F(\d+).log') - - def __init__(self, base_dir, keep: bool = False): - self._base_dir = os.path.realpath(base_dir) - if os.path.exists(self._base_dir) and not keep: - # Clear existing directory. - rmtree(self._base_dir, ignore_errors=True) - - os.makedirs(self._base_dir, exist_ok=True) - os.makedirs(self.status, exist_ok=True) - os.makedirs(self.raw_targets, exist_ok=True) - os.makedirs(self.fixed_targets, exist_ok=True) - os.makedirs(self.build_logs, exist_ok=True) - os.makedirs(self.run_logs, exist_ok=True) - os.makedirs(self._corpus_base, exist_ok=True) - os.makedirs(self.dills, exist_ok=True) - os.makedirs(self.fuzz_targets, exist_ok=True) - os.makedirs(self._artifact_base, exist_ok=True) - - def __repr__(self) -> str: - return self._base_dir - - @property - def base(self) -> str: - return self._base_dir - - @property - def _corpus_base(self) -> str: - return os.path.join(self._base_dir, 'corpora') - - @property - def _artifact_base(self) -> str: - return os.path.join(self._base_dir, 'artifacts') - - def corpus(self, sample_id) -> str: - corpus_dir = os.path.join(self._corpus_base, str(sample_id)) - os.makedirs(corpus_dir, exist_ok=True) - return corpus_dir - - def artifact(self, generated_target_name: str, iteration: int, - trial: int) -> str: - artifact_dir = os.path.join( - self._artifact_base, - f'{generated_target_name}-F{iteration}-{trial:02d}') - os.makedirs(artifact_dir, exist_ok=True) - return artifact_dir - - def code_coverage_report(self, benchmark) -> str: - coverage_dir = os.path.join(self._base_dir, 'code-coverage-reports') - os.makedirs(coverage_dir, exist_ok=True) - - benchmark_coverage = os.path.join(coverage_dir, benchmark) - return benchmark_coverage - - @property - def status(self) -> str: - return os.path.join(self._base_dir, 'status') - - @property - def prompt(self) -> str: - return os.path.join(self._base_dir, 'prompt.txt') - - @property - def fuzz_targets(self) -> str: - return os.path.join(self._base_dir, 'fuzz_targets') - - # TODO(dongge): Deprecate this. - @property - def raw_targets(self) -> str: - return os.path.join(self._base_dir, 'raw_targets') - - # TODO(dongge): Deprecate this. 
- @property - def fixed_targets(self) -> str: - return os.path.join(self._base_dir, 'fixed_targets') - - @property - def build_logs(self) -> str: - return os.path.join(self._base_dir, 'logs', 'build') - - @property - def dills(self) -> str: - return os.path.join(self._base_dir, 'dills') - - @property - def run_logs(self) -> str: - return os.path.join(self._base_dir, 'logs', 'run') - - def build_logs_target(self, generated_target_name: str, iteration: int, - trial: int) -> str: - return os.path.join( - self.build_logs, - f'{generated_target_name}-F{iteration}-{trial:02d}.log') - - def run_logs_target(self, generated_target_name: str, iteration: int, - trial: int) -> str: - return os.path.join( - self.run_logs, f'{generated_target_name}-F{iteration}-{trial:02d}.log') - - @classmethod - def get_run_log_iteration(cls, filename: str) -> Optional[int]: - match = cls.RUN_LOG_NAME_PATTERN.match(filename) - if match: - return int(match.group(1)) - return None + """Working directories.""" + + RUN_LOG_NAME_PATTERN = re.compile(r".*-F(\d+).log") + + def __init__(self, base_dir, keep: bool = False): + self._base_dir = os.path.realpath(base_dir) + if os.path.exists(self._base_dir) and not keep: + # Clear existing directory. + rmtree(self._base_dir, ignore_errors=True) + + os.makedirs(self._base_dir, exist_ok=True) + os.makedirs(self.status, exist_ok=True) + os.makedirs(self.raw_targets, exist_ok=True) + os.makedirs(self.fixed_targets, exist_ok=True) + os.makedirs(self.build_logs, exist_ok=True) + os.makedirs(self.run_logs, exist_ok=True) + os.makedirs(self._corpus_base, exist_ok=True) + os.makedirs(self.dills, exist_ok=True) + os.makedirs(self.fuzz_targets, exist_ok=True) + os.makedirs(self._artifact_base, exist_ok=True) + + def __repr__(self) -> str: + return self._base_dir + + @property + def base(self) -> str: + return self._base_dir + + @property + def _corpus_base(self) -> str: + return os.path.join(self._base_dir, "corpora") + + @property + def _artifact_base(self) -> str: + return os.path.join(self._base_dir, "artifacts") + + def corpus(self, sample_id) -> str: + corpus_dir = os.path.join(self._corpus_base, str(sample_id)) + os.makedirs(corpus_dir, exist_ok=True) + return corpus_dir + + def artifact(self, generated_target_name: str, iteration: int, trial: int) -> str: + artifact_dir = os.path.join( + self._artifact_base, f"{generated_target_name}-F{iteration}-{trial:02d}" + ) + os.makedirs(artifact_dir, exist_ok=True) + return artifact_dir + + def code_coverage_report(self, benchmark) -> str: + coverage_dir = os.path.join(self._base_dir, "code-coverage-reports") + os.makedirs(coverage_dir, exist_ok=True) + + benchmark_coverage = os.path.join(coverage_dir, benchmark) + return benchmark_coverage + + @property + def status(self) -> str: + return os.path.join(self._base_dir, "status") + + @property + def prompt(self) -> str: + return os.path.join(self._base_dir, "prompt.txt") + + @property + def fuzz_targets(self) -> str: + return os.path.join(self._base_dir, "fuzz_targets") + + # TODO(dongge): Deprecate this. + @property + def raw_targets(self) -> str: + return os.path.join(self._base_dir, "raw_targets") + + # TODO(dongge): Deprecate this. 
+ @property + def fixed_targets(self) -> str: + return os.path.join(self._base_dir, "fixed_targets") + + @property + def build_logs(self) -> str: + return os.path.join(self._base_dir, "logs", "build") + + @property + def dills(self) -> str: + return os.path.join(self._base_dir, "dills") + + @property + def run_logs(self) -> str: + return os.path.join(self._base_dir, "logs", "run") + + def build_logs_target( + self, generated_target_name: str, iteration: int, trial: int + ) -> str: + return os.path.join( + self.build_logs, f"{generated_target_name}-F{iteration}-{trial:02d}.log" + ) + + def run_logs_target( + self, generated_target_name: str, iteration: int, trial: int + ) -> str: + return os.path.join( + self.run_logs, f"{generated_target_name}-F{iteration}-{trial:02d}.log" + ) + + @classmethod + def get_run_log_iteration(cls, filename: str) -> Optional[int]: + match = cls.RUN_LOG_NAME_PATTERN.match(filename) + if match: + return int(match.group(1)) + return None diff --git a/experimental/build_fixer/build_fix.py b/experimental/build_fixer/build_fix.py index 5f2a9adbfb..d4bf92836b 100644 --- a/experimental/build_fixer/build_fix.py +++ b/experimental/build_fixer/build_fix.py @@ -36,690 +36,739 @@ from tool.base_tool import BaseTool from tool.container_tool import ProjectContainerTool -FIXER_TOOLS = [{ - 'type': - 'function', - 'name': - 'test_build_script', - 'description': - 'Tests a build script against target project. Use this for tesing build scripts that you suspect might work.', - 'parameters': { - 'type': 'object', - 'properties': { - 'build_script': { - 'type': 'string', - 'description': 'Bash script that builds the project.' - } +FIXER_TOOLS = [ + { + "type": "function", + "name": "test_build_script", + "description": "Tests a build script against target project. Use this for tesing build scripts that you suspect might work.", + "parameters": { + "type": "object", + "properties": { + "build_script": { + "type": "string", + "description": "Bash script that builds the project.", + } + }, + "required": ["build_script"], + "additionalProperties": False, }, - 'required': ['build_script'], - 'additionalProperties': False - } -}, { - 'type': - 'function', - 'name': - 'test_build_script_and_dockerfile', - 'description': - 'Tests a build script and Dockerfile against target project.', - 'parameters': { - 'type': 'object', - 'properties': { - 'build_script': { - 'type': 'string', - 'description': 'Bash script that builds the project.' + }, + { + "type": "function", + "name": "test_build_script_and_dockerfile", + "description": "Tests a build script and Dockerfile against target project.", + "parameters": { + "type": "object", + "properties": { + "build_script": { + "type": "string", + "description": "Bash script that builds the project.", + }, + "dockerfile": { + "type": "string", + "description": "Dockerfile that builds the project.", + }, }, - 'dockerfile': { - 'type': 'string', - 'description': 'Dockerfile that builds the project.' - } + "required": ["build_script", "dockerfile"], + "additionalProperties": False, }, - 'required': ['build_script', 'dockerfile'], - 'additionalProperties': False - } -}, { - 'type': - 'function', - 'name': - 'run_commands_in_container', - 'description': - 'Runs a command string in the project container. 
Use this for exploring the target project, such as running commands to inspect the project or its dependencies.', - 'parameters': { - 'type': 'object', - 'properties': { - 'command': { - 'type': - 'string', - 'description': - 'Bash commands separated by \';\' to run in the container.' - } + }, + { + "type": "function", + "name": "run_commands_in_container", + "description": "Runs a command string in the project container. Use this for exploring the target project, such as running commands to inspect the project or its dependencies.", + "parameters": { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "Bash commands separated by ';' to run in the container.", + } + }, + "required": ["command"], + "additionalProperties": False, }, - 'required': ['command'], - 'additionalProperties': False - } -}] + }, +] class BuildFixAgent(BaseAgent): - """Agent for fixing OSS-Fuzz project builds.""" - - def __init__(self, - llm: LLM, - project_name, - work_dirs, - args, - use_tools: bool = True): - super().__init__(trial=1, llm=llm, args=args) - self.project_name = project_name - self.original_project_name = project_name - self.work_dirs = work_dirs - self.last_status = False - self.last_result = '' - self.compiles = False - self.check_all_passed = False - self.initial_error_result = '' - self.trial = 0 - - self.use_tools = use_tools - - self.success_build_script = '' - - self.project_language = oss_fuzz_checkout.get_project_language( - self.project_name) - - def _strip_license_from_file(self, file_content: str) -> str: - """Strips the license header from a file content.""" - # Strip first comments in a file. - new_content = '' - past_license = False - for line in file_content.splitlines(): - if past_license: - new_content += line + '\n' - continue - - if '#################' in line: - past_license = True - continue - - if line.startswith('#') and 'bash' not in line or 'python' not in line: - continue - new_content += line + '\n' - return new_content - - def _initial_prompt(self, results: list[Result], is_tools: bool = True): # pylint: disable=unused-argument - """Creates the initial prompt for the build fixer agent.""" - with open( - os.path.join(oss_fuzz_checkout.OSS_FUZZ_DIR, 'projects', - self.project_name, 'build.sh'), 'r') as f: - build_script = self._strip_license_from_file(f.read()) - - with open( - os.path.join(oss_fuzz_checkout.OSS_FUZZ_DIR, 'projects', - self.project_name, 'Dockerfile'), 'r') as f: - dockerfile = self._strip_license_from_file(f.read()) - - prompt = self.llm.prompt_type()(None) - - if is_tools: - template_prompt = templates.BUILD_FIX_PROBLEM_TOOLS - else: - template_prompt = templates.BUILD_FIX_PROBLEM - template_prompt = template_prompt.replace('{DOCKERFILE}', dockerfile) - template_prompt = template_prompt.replace('{BUILD_SCRIPT}', build_script) - template_prompt = template_prompt.replace('{LOGS}', - self.initial_error_result[-300:]) - template_prompt = template_prompt.replace('{MAX_DISCOVERY_ROUND}', - str(self.args.max_round)) - - if self.project_language.lower() == 'python': - template_prompt = template_prompt.replace('{LANGUAGE_SPECIFICS}', - templates.PYTHON_SPECIFICS) - elif self.project_language.lower() in ['c', 'c++']: - template_prompt = template_prompt.replace('{LANGUAGE_SPECIFICS}', - templates.C_CPP_SPECIFICS) - else: - template_prompt = template_prompt.replace('{LANGUAGE_SPECIFICS}', '') - #prompt.add_priming(template_prompt) - - prompt.add_priming(templates.BUILD_FIXER_LLM_PRIMING) - prompt.add_problem(template_prompt) - 
return prompt - - def execute(self, result_history: list[Result]) -> BuildResult: - """Executes the build fixer agent. - Creates a container tool and performs an initial build attempt. - The output of the build is then used to generate a prompt, - and the agent then goes into the iterative process. - """ - - # Prepare an initial image build. - result_name = oss_fuzz_checkout.prepare_project_image_by_name( - self.project_name) - - if not result_name: - logger.info(f'Failed to prepare project image for {self.project_name}.', - trial=self.trial) - sys.exit(1) - - self.project_name = result_name.split('/')[-1] - benchmark = Benchmark(self.project_name, self.project_name, '', '', '', '', - [], '') - - # Initial run of compile. - self.inspect_tool = ProjectContainerTool(benchmark, name='inspect') - result = self.inspect_tool.compile( - extra_commands=' && rm -rf /out/* > /dev/null') - - # If the build succeeded, we can exit - if result.returncode == 0: - logger.info(f'Build succeeded for {self.project_name}.', trial=self.trial) - logger.info('Nothing to fix.', trial=self.trial) - self.inspect_tool.terminate() - sys.exit(0) - - self.initial_error_result = result.stderr - - # Prepare initial prompt. - prompt = self._initial_prompt(result_history, self.use_tools) - build_result = BuildResult(benchmark=benchmark, - trial=0, - work_dirs=self.work_dirs, - author=self, - chat_history={self.name: ''}) - if self.use_tools: - self._agent_run_function_based_loop(prompt, build_result) - else: - self._agent_raw_loop(prompt, build_result) - return build_result - - def _test_buildscript_and_dockerfile(self, tool_call, build_script, - dockerfile): - """Tests a build script and Dockerfile against the target project.""" - build_fuzzers_result, target_dst = self._test_build_fuzzers( - build_script, dockerfile) - if build_fuzzers_result.returncode != 0: - logger.info('Build failed.', trial=self.trial) - parsed_stdout = build_fuzzers_result.stdout - parsed_stdout = self._simple_truncate_build_output(parsed_stdout) - - logger.info('Parsed stdout: %s', parsed_stdout, trial=self.trial) - - # Prepare for next iteration by adding messages to the chat. - self.llm.messages.append(tool_call) - self.llm.messages.append({ - 'type': 'function_call_output', - 'call_id': tool_call.call_id, - 'output': str(parsed_stdout) - }) - self.working_prompt = None - - else: - logger.info('Build succeeded.', trial=self.trial) - # Testing fuzzers run. - test_run_result = self._test_check_fuzzers(target_dst) - if test_run_result.returncode == 0: - logger.info('Fuzzers run successfully.', trial=self.trial) - self.success_build_script = build_script - self.success_dockerfile = dockerfile + """Agent for fixing OSS-Fuzz project builds.""" + + def __init__(self, llm: LLM, project_name, work_dirs, args, use_tools: bool = True): + super().__init__(trial=1, llm=llm, args=args) + self.project_name = project_name + self.original_project_name = project_name + self.work_dirs = work_dirs + self.last_status = False + self.last_result = "" + self.compiles = False + self.check_all_passed = False + self.initial_error_result = "" + self.trial = 0 + + self.use_tools = use_tools + + self.success_build_script = "" + + self.project_language = oss_fuzz_checkout.get_project_language( + self.project_name + ) + + def _strip_license_from_file(self, file_content: str) -> str: + """Strips the license header from a file content.""" + # Strip first comments in a file. 
+ new_content = "" + past_license = False + for line in file_content.splitlines(): + if past_license: + new_content += line + "\n" + continue + + if "#################" in line: + past_license = True + continue + + if line.startswith("#") and "bash" not in line or "python" not in line: + continue + new_content += line + "\n" + return new_content + + def _initial_prompt( + self, results: list[Result], is_tools: bool = True + ): # pylint: disable=unused-argument + """Creates the initial prompt for the build fixer agent.""" + with open( + os.path.join( + oss_fuzz_checkout.OSS_FUZZ_DIR, + "projects", + self.project_name, + "build.sh", + ), + "r", + ) as f: + build_script = self._strip_license_from_file(f.read()) + + with open( + os.path.join( + oss_fuzz_checkout.OSS_FUZZ_DIR, + "projects", + self.project_name, + "Dockerfile", + ), + "r", + ) as f: + dockerfile = self._strip_license_from_file(f.read()) + + prompt = self.llm.prompt_type()(None) + + if is_tools: + template_prompt = templates.BUILD_FIX_PROBLEM_TOOLS + else: + template_prompt = templates.BUILD_FIX_PROBLEM + template_prompt = template_prompt.replace("{DOCKERFILE}", dockerfile) + template_prompt = template_prompt.replace("{BUILD_SCRIPT}", build_script) + template_prompt = template_prompt.replace( + "{LOGS}", self.initial_error_result[-300:] + ) + template_prompt = template_prompt.replace( + "{MAX_DISCOVERY_ROUND}", str(self.args.max_round) + ) + + if self.project_language.lower() == "python": + template_prompt = template_prompt.replace( + "{LANGUAGE_SPECIFICS}", templates.PYTHON_SPECIFICS + ) + elif self.project_language.lower() in ["c", "c++"]: + template_prompt = template_prompt.replace( + "{LANGUAGE_SPECIFICS}", templates.C_CPP_SPECIFICS + ) + else: + template_prompt = template_prompt.replace("{LANGUAGE_SPECIFICS}", "") + # prompt.add_priming(template_prompt) + + prompt.add_priming(templates.BUILD_FIXER_LLM_PRIMING) + prompt.add_problem(template_prompt) + return prompt + + def execute(self, result_history: list[Result]) -> BuildResult: + """Executes the build fixer agent. + Creates a container tool and performs an initial build attempt. + The output of the build is then used to generate a prompt, + and the agent then goes into the iterative process. + """ + + # Prepare an initial image build. + result_name = oss_fuzz_checkout.prepare_project_image_by_name(self.project_name) + + if not result_name: + logger.info( + f"Failed to prepare project image for {self.project_name}.", + trial=self.trial, + ) + sys.exit(1) + + self.project_name = result_name.split("/")[-1] + benchmark = Benchmark( + self.project_name, self.project_name, "", "", "", "", [], "" + ) + + # Initial run of compile. + self.inspect_tool = ProjectContainerTool(benchmark, name="inspect") + result = self.inspect_tool.compile( + extra_commands=" && rm -rf /out/* > /dev/null" + ) + + # If the build succeeded, we can exit + if result.returncode == 0: + logger.info(f"Build succeeded for {self.project_name}.", trial=self.trial) + logger.info("Nothing to fix.", trial=self.trial) + self.inspect_tool.terminate() + sys.exit(0) + + self.initial_error_result = result.stderr + + # Prepare initial prompt. 
+ prompt = self._initial_prompt(result_history, self.use_tools) + build_result = BuildResult( + benchmark=benchmark, + trial=0, + work_dirs=self.work_dirs, + author=self, + chat_history={self.name: ""}, + ) + if self.use_tools: + self._agent_run_function_based_loop(prompt, build_result) + else: + self._agent_raw_loop(prompt, build_result) + return build_result + + def _test_buildscript_and_dockerfile(self, tool_call, build_script, dockerfile): + """Tests a build script and Dockerfile against the target project.""" + build_fuzzers_result, target_dst = self._test_build_fuzzers( + build_script, dockerfile + ) + if build_fuzzers_result.returncode != 0: + logger.info("Build failed.", trial=self.trial) + parsed_stdout = build_fuzzers_result.stdout + parsed_stdout = self._simple_truncate_build_output(parsed_stdout) + + logger.info("Parsed stdout: %s", parsed_stdout, trial=self.trial) + + # Prepare for next iteration by adding messages to the chat. + self.llm.messages.append(tool_call) + self.llm.messages.append( + { + "type": "function_call_output", + "call_id": tool_call.call_id, + "output": str(parsed_stdout), + } + ) + self.working_prompt = None - self.exit_condition_met = True - else: - logger.info('Fuzzers run failed.', trial=self.trial) - prompt_text = test_run_result.stdout - # Prepare for next iteration by adding messages to the chat. + else: + logger.info("Build succeeded.", trial=self.trial) + # Testing fuzzers run. + test_run_result = self._test_check_fuzzers(target_dst) + if test_run_result.returncode == 0: + logger.info("Fuzzers run successfully.", trial=self.trial) + self.success_build_script = build_script + self.success_dockerfile = dockerfile + + self.exit_condition_met = True + else: + logger.info("Fuzzers run failed.", trial=self.trial) + prompt_text = test_run_result.stdout + # Prepare for next iteration by adding messages to the chat. + self.llm.messages.append(tool_call) + self.llm.messages.append( + { + "type": "function_call_output", + "call_id": tool_call.call_id, + "output": str(prompt_text), + } + ) + + self.working_prompt = None + + def _func_handle_run_commands_in_container(self, tool_call, command_string): + """Runs a command string in the project container.""" + + # Execute the command directly, then return the formatted result + commands = command_string + logger.info("LLM Requested commands: %s", commands, trial=self.trial) + result = self.inspect_tool.execute(commands) + prompt_text = self._format_bash_execution_result( + result, previous_prompt=self.working_prompt + ) + + prompt_text = self._simple_truncate_build_output(prompt_text) + + # Extend messages to prepare for next iteration. 
self.llm.messages.append(tool_call) - self.llm.messages.append({ - 'type': 'function_call_output', - 'call_id': tool_call.call_id, - 'output': str(prompt_text) - }) - + self.llm.messages.append( + { + "type": "function_call_output", + "call_id": tool_call.call_id, + "output": str(prompt_text), + } + ) self.working_prompt = None - def _func_handle_run_commands_in_container(self, tool_call, command_string): - """Runs a command string in the project container.""" - - # Execute the command directly, then return the formatted result - commands = command_string - logger.info('LLM Requested commands: %s', commands, trial=self.trial) - result = self.inspect_tool.execute(commands) - prompt_text = self._format_bash_execution_result( - result, previous_prompt=self.working_prompt) - - prompt_text = self._simple_truncate_build_output(prompt_text) - - # Extend messages to prepare for next iteration. - self.llm.messages.append(tool_call) - self.llm.messages.append({ - 'type': 'function_call_output', - 'call_id': tool_call.call_id, - 'output': str(prompt_text) - }) - self.working_prompt = None - - def _log_success(self): - """Utility funciton to log success of fixing.""" - logger.info('Succeeded fixing build script', trial=self.trial) - logger.info('-' * 25 + ' Build script: ' + '-' * 25, trial=self.trial) - logger.info(self.success_build_script, trial=self.trial) - logger.info('-' * 60, trial=self.trial) - - def _load_tool_arguments(self, tool_call: Any) -> Optional[dict]: - """Loads the arguments for a tool call.""" - try: - return json.loads(tool_call.arguments) - except json.JSONDecodeError as e: - logger.error('Failed to decode tool call arguments: %s', - e, - trial=self.trial) - - # Getting here means the arguments were not valid JSON. - # This happens sometimes, and to overcome this we extract - # the arguments using some simple manual parsing. - args = {} - - # 1: find the relevant function - # 2: For each argument of the function extract that - # keyword from the response. - for function_tool in FIXER_TOOLS: - if function_tool['name'] == tool_call.name: - for arg in function_tool['parameters']['properties']: - # Extract the argument value from the response. - val = self._extract_argument_from_broken_json(tool_call.arguments, - arg) - args[arg] = val - - if len(args) != len(function_tool['parameters']['properties']): - return None - return args - - def _extract_argument_from_broken_json(self, raw_response, key): - """Extracts a single argument from a broken JSON response.""" - # Find the first key - search_word = f'"{key}":' - location_idx = raw_response.find(search_word) - start_idx = location_idx + len(search_word) - - # Find the next two quotes, and take everything within them. 
- quote_locations = [] - for idx in range(len(raw_response[start_idx:])): - if raw_response[idx + start_idx] == '"': - # If this is escaped, discount - if raw_response[idx + start_idx - 1] == '\\': - continue - # We have a quote - quote_locations.append(idx + start_idx) - if len(quote_locations) == 2: - return raw_response[quote_locations[0] + 1:quote_locations[1]] - return None - - def _dispatch_tool_call(self, tool_call: Any) -> int: - """Dispatches a function call to the appropriate handler.""" - arguments = self._load_tool_arguments(tool_call) - if arguments is None: - return 0 - if tool_call.name == 'test_build_script_and_dockerfile': - self._test_buildscript_and_dockerfile(tool_call, - arguments['build_script'], - arguments['dockerfile']) - return 1 - if tool_call.name == 'test_build_script': - self._test_buildscript_and_dockerfile(tool_call, - arguments['build_script'], '') - return 1 - if tool_call.name == 'run_commands_in_container': - self._func_handle_run_commands_in_container(tool_call, - arguments['command']) - return 1 - - logger.info('Unsupported tool call: %s', tool_call.name, trial=self.trial) - return 0 - - def _agent_run_function_based_loop( - self, prompt: Optional[Prompt], build_result: BuildResult) -> None: # pylint: disable=unused-argument - """Runs the agent loop using a function-based approach.""" - self.working_prompt = prompt - # Agent loop - try: - client = self.llm.get_chat_client(model=self.llm.get_model()) - - cur_round = 0 - self.exit_condition_met = False - # Function execution and LLM communication loop. - while self.exit_condition_met is False: - logger.info(f'Agent Round {cur_round}', trial=self.trial) - - # Increment the round counter, but trigger exit condition if max - # rounds reached. - if cur_round > self.args.max_round: - logger.info('Max discovery rounds reached (%s).', - self.args.max_round, - trial=self.trial) - break - cur_round += 1 - - # Send prompt to LLM and get response. - logger.info('Sending prompt to LLM', trial=self.trial) - response = self.chat_llm_with_tools(client, self.working_prompt, - FIXER_TOOLS, self.trial) - - if not response: - logger.info('LLM did not return a response, skipping this round.', - trial=self.trial) - continue - - # Handle LLM tool calls. - tools_analysed = 0 - logger.info('Iterating response output', trial=self.trial) - for tool_call in response.output: - logger.info('- Response out:' + str(tool_call), trial=self.trial) - if tool_call.type != 'function_call': - continue - - logger.info('Handling tool call %s', tool_call.name, trial=self.trial) - logger.info('Tool call arguments: %s', - tool_call.arguments, - trial=self.trial) - tools_analysed += self._dispatch_tool_call(tool_call) - - # If no tool calls were made prepare LLM response saying we do not - # understand the message received. - if tools_analysed == 0 and not self.exit_condition_met: - logger.info( - 'Did not execute any tool calls and there is no exit condition.', - trial=self.trial) - self.working_prompt = self.llm.prompt_type()(None) - self.working_prompt.add_problem( - 'I was unable to interpret your last message. Use tool ' - 'calls to direct this process instead of messages.') - - # Post LLM communication and function execution loop. - # Log details on success. - if self.exit_condition_met: - self._log_success() - - # TODO (David): Add handling for "why did we not succeed" case. 
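The dispatch logic above routes each LLM tool call to a private handler by name and reports 1 for a handled call and 0 for an unsupported one. As a design note, the same protocol can be expressed with a lookup table; the sketch below is illustrative wiring only and reuses the handler methods and argument keys that appear in this file.

```python
# Sketch of a table-driven equivalent of the if-chain in _dispatch_tool_call.
def dispatch_tool_call(agent, tool_call, arguments: dict) -> int:
    handlers = {
        "test_build_script_and_dockerfile": lambda a: agent._test_buildscript_and_dockerfile(
            tool_call, a["build_script"], a["dockerfile"]),
        "test_build_script": lambda a: agent._test_buildscript_and_dockerfile(
            tool_call, a["build_script"], ""),
        "run_commands_in_container": lambda a: agent._func_handle_run_commands_in_container(
            tool_call, a["command"]),
    }
    handler = handlers.get(tool_call.name)
    if handler is None:
        return 0  # unsupported tool call, mirrors the 0 return above
    handler(arguments)
    return 1
```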
- finally: - self.inspect_tool.terminate() - - def _agent_raw_loop(self, prompt: Optional[Prompt], - build_result: BuildResult) -> None: - """Runs the agent loop, sending prompts to the LLM and handling - responses.""" - # Agent loop - self.trial = 0 - try: - client = self.llm.get_chat_client(model=self.llm.get_model()) - while prompt: - logger.info(f'Agent Round {self.trial}', trial=self.trial) - # Pass prompt history to LLM and get response. - logger.info('Sending prompt to LLM', trial=self.trial) - response = self.chat_llm(self.trial, - client=client, - prompt=prompt, - trial=self.trial) - - # Handle LLM response. - logger.info('Handling LLM response', trial=self.trial) - prompt = self._handle_llm_reponse(response, build_result) - if not prompt: - break - if self.trial >= self.args.max_round: - logger.info(f'Max discovery rounds reached ({self.args.max_round}).', - trial=self.trial) - break - self.trial += 1 - finally: - self.inspect_tool.terminate() - - def _parse_tag(self, response: str, tag: str) -> str: - """Parses the tag from LLM response.""" - patterns = [rf'<{tag}>(.*?)', rf'```{tag}(.*?)```'] - - # Matches both xml and code style tags - for pattern in patterns: - match = re.search(pattern, response, re.DOTALL) - if match: - return match.group(1).strip() - - return '' - - def _parse_tags(self, response: str, tag: str) -> list[str]: - """Parses the tags from LLM response.""" - patterns = [rf'<{tag}>(.*?)', rf'```{tag}(.*?)```'] - found_matches = [] - - # Matches both xml and code style tags - for pattern in patterns: - matches = re.findall(pattern, response, re.DOTALL) - found_matches.extend([content.strip() for content in matches]) - - return found_matches - - def _test_build_fuzzers( - self, - build_script: str, - dockerfile: str = '') -> tuple[subprocess.CompletedProcess, str]: - """Runs OSS-Fuzz's build_fuzzers command with the provided build script.""" - target_dst = self.original_project_name + '-copy-' + str( - uuid.uuid4().hex)[:8] - shutil.copytree( - os.path.join(oss_fuzz_checkout.OSS_FUZZ_DIR, 'projects', - self.original_project_name), - os.path.join(oss_fuzz_checkout.OSS_FUZZ_DIR, 'projects', target_dst)) - - self.success_build_script = build_script - # Overwrite the build script with the new one - with open( - os.path.join(oss_fuzz_checkout.OSS_FUZZ_DIR, 'projects', target_dst, - 'build.sh'), 'w') as f: - f.write(build_script) - - if dockerfile: - # Overwrite the Dockerfile with the new one - with open( - os.path.join(oss_fuzz_checkout.OSS_FUZZ_DIR, 'projects', target_dst, - 'Dockerfile'), 'w') as f: - f.write(dockerfile) - - # Build project - cmd = ['python3', 'infra/helper.py', 'build_fuzzers', target_dst] - result = subprocess.run(cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=False, - text=True, - encoding='utf-8', - errors='ignore', - cwd=oss_fuzz_checkout.OSS_FUZZ_DIR) - return result, target_dst - - def _test_check_fuzzers(self, target_dst) -> subprocess.CompletedProcess: - """Runs OSS-Fuzz's check_build command to evaluate build fuzzers.""" - - cmd = ['python3', 'infra/helper.py', 'check_build', target_dst] - result = subprocess.run(cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=False, - text=True, - encoding='utf-8', - errors='ignore', - cwd=oss_fuzz_checkout.OSS_FUZZ_DIR) - return result - - def _simple_truncate_build_output(self, output: str) -> str: - """Truncates the build output to a manageable size.""" - if len(output) > 8000: - return output[:1500] + '\n... 
(truncated)' + output[-6500:] - return output - - def _parse_llm_reponse_and_operate(self, response: str, tool: BaseTool, - prompt: Prompt) -> Prompt: - """Parses and LLM response and takes appropriate action. This includes - parsing bash commands to be executed in the container tool or extracting - the build script and testing it for compilation.""" - # Initialise variables - prompt_text = '' - success = False - self.invalid = False - self.missing_binary = False - - logger.info('=' * 80, trial=self.trial) - logger.info(response, trial=self.trial) - logger.info('=' * 80, trial=self.trial) - - # Retrieve data from response - build_script = self._parse_tag(response, 'bash') - commands = '; '.join(self._parse_tags(response, 'command')) - - if commands: - logger.info('LLM Requested commands: %s', commands, trial=self.trial) - self.discovery_stage = True - - # Execute the command directly, then return the formatted result - result = tool.execute(commands) - prompt_text = self._format_bash_execution_result(result, - previous_prompt=prompt) - if result.returncode == 0: - success = True - elif build_script: - logger.info('LLM Provided build script.', trial=self.trial) - self.discovery_stage = False - - # Fix shebang to ensure docker image failing is reflected. - lines = build_script.split('\n') - if lines[0].startswith("#!"): - lines[0] = "#!/bin/bash -eu" - else: - lines = ["#!/bin/bash -eu"] + lines - build_script = '\n'.join(lines) - - build_result, target_dst = self._test_build_fuzzers(build_script) - if build_result.returncode != 0: - logger.info('Build failed.', trial=self.trial) - parsed_stdout = build_result.stdout - tag = '---------------------------------------------------------------' - - parsed_stdout = tag.join(parsed_stdout.split(tag)[3:]) - prompt_text = 'Build failed, this is the output:\n' - parsed_stdout = self._simple_truncate_build_output(parsed_stdout) - prompt_text += f'{parsed_stdout}' - self.compiles = False - self.check_all_passed = False + def _log_success(self): + """Utility funciton to log success of fixing.""" + logger.info("Succeeded fixing build script", trial=self.trial) + logger.info("-" * 25 + " Build script: " + "-" * 25, trial=self.trial) + logger.info(self.success_build_script, trial=self.trial) + logger.info("-" * 60, trial=self.trial) + + def _load_tool_arguments(self, tool_call: Any) -> Optional[dict]: + """Loads the arguments for a tool call.""" + try: + return json.loads(tool_call.arguments) + except json.JSONDecodeError as e: + logger.error( + "Failed to decode tool call arguments: %s", e, trial=self.trial + ) + + # Getting here means the arguments were not valid JSON. + # This happens sometimes, and to overcome this we extract + # the arguments using some simple manual parsing. + args = {} + + # 1: find the relevant function + # 2: For each argument of the function extract that + # keyword from the response. + for function_tool in FIXER_TOOLS: + if function_tool["name"] == tool_call.name: + for arg in function_tool["parameters"]["properties"]: + # Extract the argument value from the response. 
+ val = self._extract_argument_from_broken_json( + tool_call.arguments, arg + ) + args[arg] = val + + if len(args) != len(function_tool["parameters"]["properties"]): + return None + return args + + def _extract_argument_from_broken_json(self, raw_response, key): + """Extracts a single argument from a broken JSON response.""" + # Find the first key + search_word = f'"{key}":' + location_idx = raw_response.find(search_word) + start_idx = location_idx + len(search_word) + + # Find the next two quotes, and take everything within them. + quote_locations = [] + for idx in range(len(raw_response[start_idx:])): + if raw_response[idx + start_idx] == '"': + # If this is escaped, discount + if raw_response[idx + start_idx - 1] == "\\": + continue + # We have a quote + quote_locations.append(idx + start_idx) + if len(quote_locations) == 2: + return raw_response[quote_locations[0] + 1 : quote_locations[1]] + return None + + def _dispatch_tool_call(self, tool_call: Any) -> int: + """Dispatches a function call to the appropriate handler.""" + arguments = self._load_tool_arguments(tool_call) + if arguments is None: + return 0 + if tool_call.name == "test_build_script_and_dockerfile": + self._test_buildscript_and_dockerfile( + tool_call, arguments["build_script"], arguments["dockerfile"] + ) + return 1 + if tool_call.name == "test_build_script": + self._test_buildscript_and_dockerfile( + tool_call, arguments["build_script"], "" + ) + return 1 + if tool_call.name == "run_commands_in_container": + self._func_handle_run_commands_in_container(tool_call, arguments["command"]) + return 1 + + logger.info("Unsupported tool call: %s", tool_call.name, trial=self.trial) + return 0 + + def _agent_run_function_based_loop( + self, prompt: Optional[Prompt], build_result: BuildResult + ) -> None: # pylint: disable=unused-argument + """Runs the agent loop using a function-based approach.""" + self.working_prompt = prompt + # Agent loop + try: + client = self.llm.get_chat_client(model=self.llm.get_model()) + + cur_round = 0 + self.exit_condition_met = False + # Function execution and LLM communication loop. + while self.exit_condition_met is False: + logger.info(f"Agent Round {cur_round}", trial=self.trial) + + # Increment the round counter, but trigger exit condition if max + # rounds reached. + if cur_round > self.args.max_round: + logger.info( + "Max discovery rounds reached (%s).", + self.args.max_round, + trial=self.trial, + ) + break + cur_round += 1 + + # Send prompt to LLM and get response. + logger.info("Sending prompt to LLM", trial=self.trial) + response = self.chat_llm_with_tools( + client, self.working_prompt, FIXER_TOOLS, self.trial + ) + + if not response: + logger.info( + "LLM did not return a response, skipping this round.", + trial=self.trial, + ) + continue + + # Handle LLM tool calls. + tools_analysed = 0 + logger.info("Iterating response output", trial=self.trial) + for tool_call in response.output: + logger.info("- Response out:" + str(tool_call), trial=self.trial) + if tool_call.type != "function_call": + continue + + logger.info( + "Handling tool call %s", tool_call.name, trial=self.trial + ) + logger.info( + "Tool call arguments: %s", tool_call.arguments, trial=self.trial + ) + tools_analysed += self._dispatch_tool_call(tool_call) + + # If no tool calls were made prepare LLM response saying we do not + # understand the message received. 
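The fallback used by `_load_tool_arguments` and `_extract_argument_from_broken_json` above recovers an argument value when the LLM emits tool-call arguments that are not valid JSON: it finds the quoted key and takes the text between the next two unescaped quotes. The snippet below is a standalone illustration of that recovery, not the class method itself.

```python
# Standalone illustration of recovering a value from malformed tool-call JSON.
import json
from typing import Optional


def extract_argument(raw: str, key: str) -> Optional[str]:
    start = raw.find(f'"{key}":') + len(f'"{key}":')
    quotes = []
    for idx in range(start, len(raw)):
        if raw[idx] == '"' and raw[idx - 1] != "\\":
            quotes.append(idx)
        if len(quotes) == 2:
            return raw[quotes[0] + 1 : quotes[1]]
    return None


raw = '{"command": "ls /src\ncat build.sh"'   # raw newline and missing brace
try:
    json.loads(raw)
except json.JSONDecodeError:
    print(extract_argument(raw, "command"))   # prints the two recovered commands
```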
+ if tools_analysed == 0 and not self.exit_condition_met: + logger.info( + "Did not execute any tool calls and there is no exit condition.", + trial=self.trial, + ) + self.working_prompt = self.llm.prompt_type()(None) + self.working_prompt.add_problem( + "I was unable to interpret your last message. Use tool " + "calls to direct this process instead of messages." + ) + + # Post LLM communication and function execution loop. + # Log details on success. + if self.exit_condition_met: + self._log_success() + + # TODO (David): Add handling for "why did we not succeed" case. + finally: + self.inspect_tool.terminate() + + def _agent_raw_loop( + self, prompt: Optional[Prompt], build_result: BuildResult + ) -> None: + """Runs the agent loop, sending prompts to the LLM and handling + responses.""" + # Agent loop + self.trial = 0 + try: + client = self.llm.get_chat_client(model=self.llm.get_model()) + while prompt: + logger.info(f"Agent Round {self.trial}", trial=self.trial) + # Pass prompt history to LLM and get response. + logger.info("Sending prompt to LLM", trial=self.trial) + response = self.chat_llm( + self.trial, client=client, prompt=prompt, trial=self.trial + ) + + # Handle LLM response. + logger.info("Handling LLM response", trial=self.trial) + prompt = self._handle_llm_reponse(response, build_result) + if not prompt: + break + if self.trial >= self.args.max_round: + logger.info( + f"Max discovery rounds reached ({self.args.max_round}).", + trial=self.trial, + ) + break + self.trial += 1 + finally: + self.inspect_tool.terminate() + + def _parse_tag(self, response: str, tag: str) -> str: + """Parses the tag from LLM response.""" + patterns = [rf"<{tag}>(.*?)", rf"```{tag}(.*?)```"] + + # Matches both xml and code style tags + for pattern in patterns: + match = re.search(pattern, response, re.DOTALL) + if match: + return match.group(1).strip() + + return "" + + def _parse_tags(self, response: str, tag: str) -> list[str]: + """Parses the tags from LLM response.""" + patterns = [rf"<{tag}>(.*?)", rf"```{tag}(.*?)```"] + found_matches = [] + + # Matches both xml and code style tags + for pattern in patterns: + matches = re.findall(pattern, response, re.DOTALL) + found_matches.extend([content.strip() for content in matches]) + + return found_matches + + def _test_build_fuzzers( + self, build_script: str, dockerfile: str = "" + ) -> tuple[subprocess.CompletedProcess, str]: + """Runs OSS-Fuzz's build_fuzzers command with the provided build script.""" + target_dst = self.original_project_name + "-copy-" + str(uuid.uuid4().hex)[:8] + shutil.copytree( + os.path.join( + oss_fuzz_checkout.OSS_FUZZ_DIR, "projects", self.original_project_name + ), + os.path.join(oss_fuzz_checkout.OSS_FUZZ_DIR, "projects", target_dst), + ) + + self.success_build_script = build_script + # Overwrite the build script with the new one + with open( + os.path.join( + oss_fuzz_checkout.OSS_FUZZ_DIR, "projects", target_dst, "build.sh" + ), + "w", + ) as f: + f.write(build_script) + + if dockerfile: + # Overwrite the Dockerfile with the new one + with open( + os.path.join( + oss_fuzz_checkout.OSS_FUZZ_DIR, "projects", target_dst, "Dockerfile" + ), + "w", + ) as f: + f.write(dockerfile) + + # Build project + cmd = ["python3", "infra/helper.py", "build_fuzzers", target_dst] + result = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + text=True, + encoding="utf-8", + errors="ignore", + cwd=oss_fuzz_checkout.OSS_FUZZ_DIR, + ) + return result, target_dst + + def _test_check_fuzzers(self, 
target_dst) -> subprocess.CompletedProcess: + """Runs OSS-Fuzz's check_build command to evaluate build fuzzers.""" + + cmd = ["python3", "infra/helper.py", "check_build", target_dst] + result = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + text=True, + encoding="utf-8", + errors="ignore", + cwd=oss_fuzz_checkout.OSS_FUZZ_DIR, + ) + return result + + def _simple_truncate_build_output(self, output: str) -> str: + """Truncates the build output to a manageable size.""" + if len(output) > 8000: + return output[:1500] + "\n... (truncated)" + output[-6500:] + return output + + def _parse_llm_reponse_and_operate( + self, response: str, tool: BaseTool, prompt: Prompt + ) -> Prompt: + """Parses and LLM response and takes appropriate action. This includes + parsing bash commands to be executed in the container tool or extracting + the build script and testing it for compilation.""" + # Initialise variables + prompt_text = "" success = False - else: - # Success build - logger.info('Build succeeded.', trial=self.trial) - logger.info('Testing fuzzers run.', trial=self.trial) - test_run_result = self._test_check_fuzzers(target_dst) - if test_run_result.returncode == 0: - logger.info('Fuzzers run successfully.', trial=self.trial) - self.check_all_passed = True - success = True - self.compiles = True + self.invalid = False + self.missing_binary = False + + logger.info("=" * 80, trial=self.trial) + logger.info(response, trial=self.trial) + logger.info("=" * 80, trial=self.trial) + + # Retrieve data from response + build_script = self._parse_tag(response, "bash") + commands = "; ".join(self._parse_tags(response, "command")) + + if commands: + logger.info("LLM Requested commands: %s", commands, trial=self.trial) + self.discovery_stage = True + + # Execute the command directly, then return the formatted result + result = tool.execute(commands) + prompt_text = self._format_bash_execution_result( + result, previous_prompt=prompt + ) + if result.returncode == 0: + success = True + elif build_script: + logger.info("LLM Provided build script.", trial=self.trial) + self.discovery_stage = False + + # Fix shebang to ensure docker image failing is reflected. 
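`_parse_tag` and `_parse_tags` above pull command and build-script payloads out of free-form LLM replies, accepting either XML-style tags or fenced code blocks. The sketch below mirrors that intent in a self-contained form; the closing-tag handling is an assumption about the intended patterns rather than a copy of the project's exact regexes.

```python
# Self-contained sketch of tag extraction for LLM replies (intent only).
import re


def parse_tags(response: str, tag: str) -> list[str]:
    patterns = [rf"<{tag}>(.*?)</{tag}>", rf"```{tag}(.*?)```"]
    found = []
    for pattern in patterns:
        found.extend(m.strip() for m in re.findall(pattern, response, re.DOTALL))
    return found


reply = "Try this:\n```bash\nmake -j$(nproc)\n```\nand also <command>ls /src</command>"
print(parse_tags(reply, "bash"))     # ['make -j$(nproc)']
print(parse_tags(reply, "command"))  # ['ls /src']
```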
+ lines = build_script.split("\n") + if lines[0].startswith("#!"): + lines[0] = "#!/bin/bash -eu" + else: + lines = ["#!/bin/bash -eu"] + lines + build_script = "\n".join(lines) + + build_result, target_dst = self._test_build_fuzzers(build_script) + if build_result.returncode != 0: + logger.info("Build failed.", trial=self.trial) + parsed_stdout = build_result.stdout + tag = "---------------------------------------------------------------" + + parsed_stdout = tag.join(parsed_stdout.split(tag)[3:]) + prompt_text = "Build failed, this is the output:\n" + parsed_stdout = self._simple_truncate_build_output(parsed_stdout) + prompt_text += f"{parsed_stdout}" + self.compiles = False + self.check_all_passed = False + success = False + else: + # Success build + logger.info("Build succeeded.", trial=self.trial) + logger.info("Testing fuzzers run.", trial=self.trial) + test_run_result = self._test_check_fuzzers(target_dst) + if test_run_result.returncode == 0: + logger.info("Fuzzers run successfully.", trial=self.trial) + self.check_all_passed = True + success = True + self.compiles = True + else: + logger.info("Fuzzers run failed.", trial=self.trial) + prompt_text = test_run_result.stdout + self.compiles = True + self.check_all_passed = False + success = False else: - logger.info('Fuzzers run failed.', trial=self.trial) - prompt_text = test_run_result.stdout - self.compiles = True - self.check_all_passed = False - success = False - else: - self.invalid = True - - self.last_status = success - self.last_result = prompt_text - - return prompt - - def _validate_operation_and_prepare_next_prompt( - self, build_result: BuildResult, prompt: Prompt) -> Optional[Prompt]: - """Interprets the results from operating on the LLM response and prepares - a new prompt for the next round of interaction.""" - - # Don't need to check for invalid result - if self.invalid: - return prompt - - # Execution fail - if self.discovery_stage: - logger.info('Validating BASH command response', trial=self.trial) - # Still in bash mode. 
- prompt.add_problem(self.last_result) - - # Store build result - build_result.compiles = False - build_result.compile_error = self.last_result - - return prompt - if not self.compiles: - logger.info('Validation build failure response', trial=self.trial) - retry = templates.LLM_RETRY.replace('{BASH_RESULT}', self.last_result) - prompt.add_problem(retry) - - # Store build result - build_result.compiles = False - build_result.compile_error = self.last_result - - return prompt - if not self.check_all_passed: - logger.info('Validating check_build failure', trial=self.trial) - retry = templates.LLM_RETRY_CHECK_ALL.replace('{BASH_RESULT}', - self.last_result) - prompt.add_problem(retry) - - # Store build result - build_result.compiles = False - build_result.compile_error = self.last_result - - return prompt - # Build script succeeded - return None - - def _handle_llm_reponse(self, response: str, - build_result: BuildResult) -> Optional[Prompt]: - """Validates LLM conclusion or executes its command.""" - prompt = self.llm.prompt_type()(None) - - if response: - prompt = self._parse_llm_reponse_and_operate(response, self.inspect_tool, - prompt) - logger.info('Handling conclusions', trial=self.trial) - prompt = self._validate_operation_and_prepare_next_prompt( - build_result, prompt) - if prompt is None: - logger.info('Succeeded fixing build script', trial=self.trial) - logger.info('-' * 25 + ' Build script: ' + '-' * 25, trial=self.trial) - logger.info(self.success_build_script, trial=self.trial) - logger.info('-' * 60, trial=self.trial) + self.invalid = True + + self.last_status = success + self.last_result = prompt_text + + return prompt + + def _validate_operation_and_prepare_next_prompt( + self, build_result: BuildResult, prompt: Prompt + ) -> Optional[Prompt]: + """Interprets the results from operating on the LLM response and prepares + a new prompt for the next round of interaction.""" + + # Don't need to check for invalid result + if self.invalid: + return prompt + + # Execution fail + if self.discovery_stage: + logger.info("Validating BASH command response", trial=self.trial) + # Still in bash mode. 
+ prompt.add_problem(self.last_result) + + # Store build result + build_result.compiles = False + build_result.compile_error = self.last_result + + return prompt + if not self.compiles: + logger.info("Validation build failure response", trial=self.trial) + retry = templates.LLM_RETRY.replace("{BASH_RESULT}", self.last_result) + prompt.add_problem(retry) + + # Store build result + build_result.compiles = False + build_result.compile_error = self.last_result + + return prompt + if not self.check_all_passed: + logger.info("Validating check_build failure", trial=self.trial) + retry = templates.LLM_RETRY_CHECK_ALL.replace( + "{BASH_RESULT}", self.last_result + ) + prompt.add_problem(retry) + + # Store build result + build_result.compiles = False + build_result.compile_error = self.last_result + + return prompt + # Build script succeeded return None - return prompt + def _handle_llm_reponse( + self, response: str, build_result: BuildResult + ) -> Optional[Prompt]: + """Validates LLM conclusion or executes its command.""" + prompt = self.llm.prompt_type()(None) + + if response: + prompt = self._parse_llm_reponse_and_operate( + response, self.inspect_tool, prompt + ) + logger.info("Handling conclusions", trial=self.trial) + prompt = self._validate_operation_and_prepare_next_prompt( + build_result, prompt + ) + if prompt is None: + logger.info("Succeeded fixing build script", trial=self.trial) + logger.info("-" * 25 + " Build script: " + "-" * 25, trial=self.trial) + logger.info(self.success_build_script, trial=self.trial) + logger.info("-" * 60, trial=self.trial) + return None + + return prompt def fix_build(args, oss_fuzz_base, use_tools: bool = True): - """Fixes the build of a given project.""" + """Fixes the build of a given project.""" - project_name = args.project - oss_fuzz_checkout.OSS_FUZZ_DIR = oss_fuzz_base + project_name = args.project + oss_fuzz_checkout.OSS_FUZZ_DIR = oss_fuzz_base - # Disabling caching - oss_fuzz_checkout.ENABLE_CACHING = False + # Disabling caching + oss_fuzz_checkout.ENABLE_CACHING = False - work_dirs = WorkDirs(args.work_dirs, keep=True) + work_dirs = WorkDirs(args.work_dirs, keep=True) - # Prepare LLM model - llm = models.LLM.setup( - ai_binary=os.getenv('AI_BINARY', ''), - name=args.model, - max_tokens=4096, - num_samples=1, - temperature=0.4, - temperature_list=[], - ) - llm.MAX_INPUT_TOKEN = 25000 + # Prepare LLM model + llm = models.LLM.setup( + ai_binary=os.getenv("AI_BINARY", ""), + name=args.model, + max_tokens=4096, + num_samples=1, + temperature=0.4, + temperature_list=[], + ) + llm.MAX_INPUT_TOKEN = 25000 - # Set up Build fixer agent - agent = BuildFixAgent(llm, project_name, work_dirs, args, use_tools=use_tools) + # Set up Build fixer agent + agent = BuildFixAgent(llm, project_name, work_dirs, args, use_tools=use_tools) - # Execute the agent - agent.execute([]) + # Execute the agent + agent.execute([]) diff --git a/experimental/build_fixer/templates.py b/experimental/build_fixer/templates.py index 0ac74cd112..51bacbc41d 100644 --- a/experimental/build_fixer/templates.py +++ b/experimental/build_fixer/templates.py @@ -14,7 +14,7 @@ # limitations under the License. """Templates for the build fixer tool.""" -BUILD_FIXER_LLM_PRIMING = ''' +BUILD_FIXER_LLM_PRIMING = """ You are an expert software developer that specializes in creating shell scripts that compile and build codebases. You must support other developers when their codebases no longer build. You have a technical tone that focus on clear and concise messaging. 
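The `fix_build` entry point above expects an argparse-style object carrying the attributes it and the agent read (`project`, `work_dirs`, `model`, `max_round`). The command-line parser that builds it is not part of this diff, so the invocation below is a hypothetical sketch with placeholder values.

```python
# Hypothetical invocation sketch for fix_build(); values are placeholders.
import argparse
from experimental.build_fixer.build_fix import fix_build

args = argparse.Namespace(
    project="libxml2",               # OSS-Fuzz project whose build is broken
    work_dirs="/tmp/build-fix-work", # scratch directory for WorkDirs
    model="gpt-4o",                  # placeholder model name
    max_round=20,                    # cap on agent rounds
)
fix_build(args, "/path/to/oss-fuzz", use_tools=True)
```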
@@ -56,7 +56,7 @@ CFLAGS=-O1 -fno-omit-frame-pointer -gline-tables-only -Wno-error=enum-constexpr-conversion -Wno-error=incompatible-function-pointer-types -Wno-error=int-conversion -Wno-error=deprecated-declarations -Wno-error=implicit-function-declaration -Wno-error=implicit-int -Wno-error=vla-cxx-extension -DFUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION -fsanitize=address -fsanitize-address-use-after-scope -fsanitize=fuzzer-no-link CXXFLAGS=-O1 -fno-omit-frame-pointer -gline-tables-only -Wno-error=enum-constexpr-conversion -Wno-error=incompatible-function-pointer-types -Wno-error=int-conversion -Wno-error=deprecated-declarations -Wno-error=implicit-function-declaration -Wno-error=implicit-int -Wno-error=vla-cxx-extension -DFUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION -fsanitize=address -fsanitize-address-use-after-scope -fsanitize=fuzzer-no-link -stdlib=libc++ LIB_FUZZING_ENGINE=-fsanitize=fuzzer -''' +""" BUILD_FIX_PROBLEM_TOOLS = """ Your task is to fix the build.sh script so that the project can be built successfully. @@ -118,15 +118,15 @@ If the build script fails or produces errors, you are encouraged to **return to interaction mode** by providing new `` tags. Use them to inspect logs, echo error messages, or run diagnostic commands (e.g., view files in `/tmp`, rerun failing commands with `-v`, etc.). This allows you to iteratively understand and fix the issues. """ -C_CPP_SPECIFICS = '''### OSS-Fuzz C/C++ projects +C_CPP_SPECIFICS = """### OSS-Fuzz C/C++ projects The project you are working on is a C/C++ project. You must use the relevant environment variables to compile the project: CC, CXX, CFLAGS, CXXFLAGS, LIB_FUZZING_ENGINE. The build script should be as C/C++ idiomatic as possible. -''' +""" -PYTHON_SPECIFICS = '''### OSS-Fuzz python projects +PYTHON_SPECIFICS = """### OSS-Fuzz python projects The project you are working on is a Python project. The build script should be as Pythonic as possible. @@ -137,9 +137,9 @@ If the build script does not unconditionally install the target codebase then the build script is not correct. Make sure to install the target codebase and avoid using packages already in installed in the Docker image. Avoid using `pip install .` and always use `python3 -m pip install .` instead. -''' +""" -LLM_RETRY = ''' +LLM_RETRY = """ I failed to build the project with the above provided build script. Please analyse the result and generate a new build script with the same assumption above. You must only returns the content of the build script and nothing else more as always. @@ -148,15 +148,15 @@ Here is a dump of the bash execution result. {BASH_RESULT} -''' +""" -LLM_RETRY_BASH = '''The output of the bash commands: +LLM_RETRY_BASH = """The output of the bash commands: {BASH_RESULT} -''' +""" -LLM_RETRY_CHECK_ALL = '''The build script worked, but when checking if the +LLM_RETRY_CHECK_ALL = """The build script worked, but when checking if the fuzzers run then the check failed. It is likely the changes you made caused no fuzzing harnesses to be built or the fuzzing harnesses are not runnable outside the container. @@ -166,4 +166,4 @@ – wraps the complete build script for both the target project and the fuzzing harness. Here is a dump of the bash execution result. 
-{BASH_RESULT}''' +{BASH_RESULT}""" diff --git a/experimental/build_generator/build_script_generator.py b/experimental/build_generator/build_script_generator.py index 609c01ef2e..dd811f922d 100644 --- a/experimental/build_generator/build_script_generator.py +++ b/experimental/build_generator/build_script_generator.py @@ -32,950 +32,971 @@ #### Logic for auto building a given source code folder #### ############################################################ class AutoBuildContainer: - """Auto build data container.""" + """Auto build data container.""" - def __init__(self, old: Optional["AutoBuildContainer"] = None): - if old: - self.list_of_commands = old.list_of_commands - self.list_of_required_packages = old.list_of_required_packages - self.heuristic_id = old.heuristic_id - else: - self.list_of_commands = [] - self.list_of_required_packages = [] - self.heuristic_id = '' + def __init__(self, old: Optional["AutoBuildContainer"] = None): + if old: + self.list_of_commands = old.list_of_commands + self.list_of_required_packages = old.list_of_required_packages + self.heuristic_id = old.heuristic_id + else: + self.list_of_commands = [] + self.list_of_required_packages = [] + self.heuristic_id = "" class BuildWorker: - """Keeper of data on auto generated builds.""" - - def __init__(self, build_suggestion: AutoBuildContainer, build_script: str, - build_directory: str, executable_files_build: Dict[str, - List[str]]): - self.build_suggestion: AutoBuildContainer = build_suggestion - self.build_script: str = build_script - self.build_directory: str = build_directory - self.executable_files_build: Dict[str, List[str]] = executable_files_build - self.base_fuzz_build: bool = False + """Keeper of data on auto generated builds.""" + + def __init__( + self, + build_suggestion: AutoBuildContainer, + build_script: str, + build_directory: str, + executable_files_build: Dict[str, List[str]], + ): + self.build_suggestion: AutoBuildContainer = build_suggestion + self.build_script: str = build_script + self.build_directory: str = build_directory + self.executable_files_build: Dict[str, List[str]] = executable_files_build + self.base_fuzz_build: bool = False class AutoBuildBase: - """Base class for auto builders.""" - - def __init__(self): - self.matches_found = {} - - @abstractmethod - def steps_to_build(self) -> Iterator[AutoBuildContainer]: - """Yields AutoBuildContainer objects.""" - - def match_files(self, file_list): - """Matches files needed for the build heuristic.""" - for fi in file_list: - base_file = os.path.basename(fi) - for key, val in self.matches_found.items(): - if base_file == key: - val.append(fi) - - def is_matched(self): - """Returns True if the build heuristic found matching files.""" - for found_matches in self.matches_found.values(): - if len(found_matches) > 0: - return True - return False + """Base class for auto builders.""" + + def __init__(self): + self.matches_found = {} + + @abstractmethod + def steps_to_build(self) -> Iterator[AutoBuildContainer]: + """Yields AutoBuildContainer objects.""" + + def match_files(self, file_list): + """Matches files needed for the build heuristic.""" + for fi in file_list: + base_file = os.path.basename(fi) + for key, val in self.matches_found.items(): + if base_file == key: + val.append(fi) + + def is_matched(self): + """Returns True if the build heuristic found matching files.""" + for found_matches in self.matches_found.values(): + if len(found_matches) > 0: + return True + return False - def determine_required_packages( - self, config_file_str: 
str) -> List[Tuple[str, str]]: - """Determine additional required package for installation in Dockerfile.""" + def determine_required_packages( + self, config_file_str: str + ) -> List[Tuple[str, str]]: + """Determine additional required package for installation in Dockerfile.""" - # Find all -l flags in makefile or other configurations - libs = re.findall(r"-l(\w+)", config_file_str) + # Find all -l flags in makefile or other configurations + libs = re.findall(r"-l(\w+)", config_file_str) - # Map to packages, skipping built-in or unmapped ones - required_packages = [(f'-l{lib}', constants.LIBRARY_PACKAGE_MAP[lib]) - for lib in libs - if lib in constants.LIBRARY_PACKAGE_MAP] + # Map to packages, skipping built-in or unmapped ones + required_packages = [ + (f"-l{lib}", constants.LIBRARY_PACKAGE_MAP[lib]) + for lib in libs + if lib in constants.LIBRARY_PACKAGE_MAP + ] - return list(set(required_packages)) + return list(set(required_packages)) class HeaderOnlyCBuilder(AutoBuildBase): - """Wrapper for building header-only targets""" - - def __init__(self): - super().__init__() - self.matches_found = {'.h': []} - - def match_files(self, file_list): - """Matches files needed for the build heuristic.""" - file_dicts = { - '.c': [], - '.h': [], - } - for fi in file_list: - for key, val in file_dicts.items(): - if fi.endswith(key) and 'test' not in fi and 'example' not in fi: - # Remove the first folder as that is "this" dir. - path_to_add = '/'.join(fi.split('/')[1:]) - val.append(path_to_add) - if not file_dicts['.c'] and file_dicts['.h']: - self.matches_found['.h'] = file_dicts['.h'] - - def steps_to_build(self) -> Iterator[AutoBuildContainer]: - build_container = AutoBuildContainer() - - header_writers = '' - for header_file in self.matches_found['.h']: - header_writers += f'echo "#include \\"{header_file}\\""' - header_writers += ' >> empty_wrapper.c\n' - - build_container.list_of_commands = [ - f'''touch empty_wrapper.c + """Wrapper for building header-only targets""" + + def __init__(self): + super().__init__() + self.matches_found = {".h": []} + + def match_files(self, file_list): + """Matches files needed for the build heuristic.""" + file_dicts = { + ".c": [], + ".h": [], + } + for fi in file_list: + for key, val in file_dicts.items(): + if fi.endswith(key) and "test" not in fi and "example" not in fi: + # Remove the first folder as that is "this" dir. 
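`determine_required_packages` above scans a build configuration for `-l<lib>` linker flags and maps each known library to an OS package via `constants.LIBRARY_PACKAGE_MAP`. The snippet below demonstrates the same lookup in isolation; the small mapping here is an assumed stand-in for the project's constant.

```python
# Standalone sketch of the "-l<lib> -> package" lookup; the map is an assumption.
import re

LIBRARY_PACKAGE_MAP = {"z": "zlib1g-dev", "ssl": "libssl-dev", "xml2": "libxml2-dev"}


def required_packages(config_text: str) -> list[tuple[str, str]]:
    libs = re.findall(r"-l(\w+)", config_text)
    return sorted({(f"-l{lib}", LIBRARY_PACKAGE_MAP[lib])
                   for lib in libs if lib in LIBRARY_PACKAGE_MAP})


print(required_packages("LDLIBS = -lz -lssl -lm -lxml2"))
# [('-lssl', 'libssl-dev'), ('-lxml2', 'libxml2-dev'), ('-lz', 'zlib1g-dev')]
```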
+ path_to_add = "/".join(fi.split("/")[1:]) + val.append(path_to_add) + if not file_dicts[".c"] and file_dicts[".h"]: + self.matches_found[".h"] = file_dicts[".h"] + + def steps_to_build(self) -> Iterator[AutoBuildContainer]: + build_container = AutoBuildContainer() + + header_writers = "" + for header_file in self.matches_found[".h"]: + header_writers += f'echo "#include \\"{header_file}\\""' + header_writers += " >> empty_wrapper.c\n" + + build_container.list_of_commands = [ + f"""touch empty_wrapper.c # Write includes for each of the header files {header_writers} rm -rf *.o $CC $CFLAGS -c empty_wrapper.c -o empty_wrapper.o llvm-ar rcs libfuzz.a *.o -''' - ] - build_container.heuristic_id = self.name + '1' - logger.info(build_container.list_of_commands[0]) - yield build_container +""" + ] + build_container.heuristic_id = self.name + "1" + logger.info(build_container.list_of_commands[0]) + yield build_container - @property - def name(self): - return 'HeaderOnlyCBuilder' + @property + def name(self): + return "HeaderOnlyCBuilder" class PureCFileCompiler(AutoBuildBase): - """Builder for compiling .c files direcetly in root repo dir.""" - - def __init__(self): - super().__init__() - self.matches_found = { - '.c': [], - } - - def match_files(self, file_list): - """Matches files needed for the build heuristic.""" - for fi in file_list: - for key, val in self.matches_found.items(): - if fi.endswith(key) and 'test' not in fi and 'example' not in fi: - logger.info('Adding %s', fi) - # Remove the first folder as that is "this" dir. - path_to_add = '/'.join(fi.split('/')[1:]) - val.append(path_to_add) - - def steps_to_build(self) -> Iterator[AutoBuildContainer]: - build_container = AutoBuildContainer() - build_container.list_of_commands = [ - '''for file in "%s"; do + """Builder for compiling .c files direcetly in root repo dir.""" + + def __init__(self): + super().__init__() + self.matches_found = { + ".c": [], + } + + def match_files(self, file_list): + """Matches files needed for the build heuristic.""" + for fi in file_list: + for key, val in self.matches_found.items(): + if fi.endswith(key) and "test" not in fi and "example" not in fi: + logger.info("Adding %s", fi) + # Remove the first folder as that is "this" dir. 
+ path_to_add = "/".join(fi.split("/")[1:]) + val.append(path_to_add) + + def steps_to_build(self) -> Iterator[AutoBuildContainer]: + build_container = AutoBuildContainer() + build_container.list_of_commands = [ + """for file in "%s"; do $CC $CFLAGS -c ${file} done rm -f ./test*.o llvm-ar rcs libfuzz.a *.o -''' % (' '.join(self.matches_found['.c'])) - ] - build_container.heuristic_id = self.name + '1' - logger.info(build_container.list_of_commands[0]) - yield build_container +""" + % (" ".join(self.matches_found[".c"])) + ] + build_container.heuristic_id = self.name + "1" + logger.info(build_container.list_of_commands[0]) + yield build_container - @property - def name(self): - return 'pureCFileCompiler' + @property + def name(self): + return "pureCFileCompiler" class PureCFileCompilerFind(AutoBuildBase): - """Builder for compiling .c files direcetly in root repo dir, using find.""" - - def __init__(self): - super().__init__() - self.matches_found = { - '.c': [], - } - - def match_files(self, file_list): - """Matches files needed for the build heuristic.""" - for fi in file_list: - for key, val in self.matches_found.items(): - if fi.endswith(key): - val.append(fi) - - def steps_to_build(self) -> Iterator[AutoBuildContainer]: - build_container = AutoBuildContainer() - build_container.list_of_commands = [ - '''find . -name "*.c" -exec $CC $CFLAGS -I./src -c {} \\; + """Builder for compiling .c files direcetly in root repo dir, using find.""" + + def __init__(self): + super().__init__() + self.matches_found = { + ".c": [], + } + + def match_files(self, file_list): + """Matches files needed for the build heuristic.""" + for fi in file_list: + for key, val in self.matches_found.items(): + if fi.endswith(key): + val.append(fi) + + def steps_to_build(self) -> Iterator[AutoBuildContainer]: + build_container = AutoBuildContainer() + build_container.list_of_commands = [ + """find . -name "*.c" -exec $CC $CFLAGS -I./src -c {} \\; find . -name "*.o" -exec cp {} . \\; rm -f ./test*.o llvm-ar rcs libfuzz.a *.o -''' - ] - build_container.heuristic_id = self.name + '1' - yield build_container +""" + ] + build_container.heuristic_id = self.name + "1" + yield build_container - @property - def name(self): - return 'pureCFileCompilerFind' + @property + def name(self): + return "pureCFileCompilerFind" class PureCPPFileCompilerFind(AutoBuildBase): - """Builder for compiling .cpp files direcetly in root repo dir, using find.""" - - def __init__(self): - super().__init__() - self.matches_found = { - '.cpp': [], - '.c': [], - } - - def match_files(self, file_list): - """Matches files needed for the build heuristic.""" - for fi in file_list: - for key, val in self.matches_found.items(): - if fi.endswith(key): - val.append(fi) - - def steps_to_build(self) -> Iterator[AutoBuildContainer]: - build_container = AutoBuildContainer() - build_container.list_of_commands = [ - '''find . -name "*.cpp" -exec $CXX $CXXFLAGS -I./src -c {} \\; + """Builder for compiling .cpp files direcetly in root repo dir, using find.""" + + def __init__(self): + super().__init__() + self.matches_found = { + ".cpp": [], + ".c": [], + } + + def match_files(self, file_list): + """Matches files needed for the build heuristic.""" + for fi in file_list: + for key, val in self.matches_found.items(): + if fi.endswith(key): + val.append(fi) + + def steps_to_build(self) -> Iterator[AutoBuildContainer]: + build_container = AutoBuildContainer() + build_container.list_of_commands = [ + """find . -name "*.cpp" -exec $CXX $CXXFLAGS -I./src -c {} \\; find . 
-name "*.o" -exec cp {} . \\; rm -f ./test*.o llvm-ar rcs libfuzz.a *.o -''' - ] - build_container.heuristic_id = self.name + '1' - yield build_container +""" + ] + build_container.heuristic_id = self.name + "1" + yield build_container - @property - def name(self): - return 'PureCPPFileCompilerFind' + @property + def name(self): + return "PureCPPFileCompilerFind" class PureMakefileScanner(AutoBuildBase): - """Auto builder for pure Makefile projects, only relying on "make".""" + """Auto builder for pure Makefile projects, only relying on "make".""" - def __init__(self): - super().__init__() - self.matches_found = { - 'Makefile': [], - } + def __init__(self): + super().__init__() + self.matches_found = { + "Makefile": [], + } - def steps_to_build(self) -> Iterator[AutoBuildContainer]: - build_container = AutoBuildContainer() - build_container.list_of_commands = ['make'] - build_container.heuristic_id = self.name + '1' - yield build_container + def steps_to_build(self) -> Iterator[AutoBuildContainer]: + build_container = AutoBuildContainer() + build_container.list_of_commands = ["make"] + build_container.heuristic_id = self.name + "1" + yield build_container - @property - def name(self): - return 'make' + @property + def name(self): + return "make" class PureMakefileScannerWithPThread(AutoBuildBase): - """Auto builder for pure Makefile projects, only relying on "make".""" - - def __init__(self): - super().__init__() - self.matches_found = { - 'Makefile': [], - } - - def steps_to_build(self) -> Iterator[AutoBuildContainer]: - build_container = AutoBuildContainer() - build_container.list_of_commands = [ - 'export CXXFLAGS="${CXXFLAGS} -lpthread"', 'make' - ] - build_container.heuristic_id = self.name + '1' - yield build_container + """Auto builder for pure Makefile projects, only relying on "make".""" - @property - def name(self): - return 'make' + def __init__(self): + super().__init__() + self.matches_found = { + "Makefile": [], + } + def steps_to_build(self) -> Iterator[AutoBuildContainer]: + build_container = AutoBuildContainer() + build_container.list_of_commands = [ + 'export CXXFLAGS="${CXXFLAGS} -lpthread"', + "make", + ] + build_container.heuristic_id = self.name + "1" + yield build_container -class PureMakefileScannerWithSubstitutions(AutoBuildBase): - """Auto builder for pure Makefile projects with substitions.""" - - def __init__(self): - super().__init__() - self.matches_found = { - 'Makefile': [], - } - - def steps_to_build(self) -> Iterator[AutoBuildContainer]: - build_container = AutoBuildContainer() - # The following substitutes varioues patterns of overwriting of compilers - # which happens in some build files. Patterns of Werror are also suppressed - # by converting them to Wno-error. 
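# Rough approximation of what the sed substitutions listed just below do to a
# Makefile, expressed with re.sub; the Makefile fragment is an assumed example.
import re

makefile = "CC=gcc\nCXX =g++\nCFLAGS=-Wall -Werror\n"
makefile = re.sub(r"-Werror", "-Wno-error", makefile)                # keep warnings non-fatal
makefile = re.sub(r"^CC=", "#CC=", makefile, flags=re.MULTILINE)     # stop the Makefile overriding $CC
makefile = re.sub(r"^CXX =", "#CXX=", makefile, flags=re.MULTILINE)  # same for $CXX
# makefile == "#CC=gcc\n#CXX=g++\nCFLAGS=-Wall -Wno-error\n"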
- build_container.list_of_commands = [ - 'sed -i \'s/-Werror/-Wno-error/g\' ./Makefile', - 'sed -i \'s/CC=/#CC=/g\' ./Makefile', - 'sed -i \'s/CXX=/#CXX=/g\' ./Makefile', - 'sed -i \'s/CC =/#CC=/g\' ./Makefile', - 'sed -i \'s/gcc/clang/g\' Make.Rules || true', - 'sed -i \'s/CXX =/#CXX=/g\' ./Makefile', 'make V=1 || true' - ] - build_container.heuristic_id = self.name + '1' - yield build_container + @property + def name(self): + return "make" - @property - def name(self): - return 'makeWithSubstitutions' +class PureMakefileScannerWithSubstitutions(AutoBuildBase): + """Auto builder for pure Makefile projects with substitions.""" + + def __init__(self): + super().__init__() + self.matches_found = { + "Makefile": [], + } + + def steps_to_build(self) -> Iterator[AutoBuildContainer]: + build_container = AutoBuildContainer() + # The following substitutes varioues patterns of overwriting of compilers + # which happens in some build files. Patterns of Werror are also suppressed + # by converting them to Wno-error. + build_container.list_of_commands = [ + "sed -i 's/-Werror/-Wno-error/g' ./Makefile", + "sed -i 's/CC=/#CC=/g' ./Makefile", + "sed -i 's/CXX=/#CXX=/g' ./Makefile", + "sed -i 's/CC =/#CC=/g' ./Makefile", + "sed -i 's/gcc/clang/g' Make.Rules || true", + "sed -i 's/CXX =/#CXX=/g' ./Makefile", + "make V=1 || true", + ] + build_container.heuristic_id = self.name + "1" + yield build_container + + @property + def name(self): + return "makeWithSubstitutions" -class PureMakefileScannerWithLibFlag(AutoBuildBase): - """Auto builder for pure Makefile projects, relying on "make" with - additional -l flags during linker process.""" - - def __init__(self): - super().__init__() - self.matches_found = { - 'Makefile': [], - } - - def steps_to_build(self) -> Iterator[AutoBuildContainer]: - # Determine possible lib flags from Makefile - lib_flags: List[Tuple[str, str]] = [] - for file in self.matches_found['Makefile']: - with open(file, 'r') as f: - content = f.read() - lib_flags.extend(self.determine_required_packages(content)) - - # Process required packages - build_container = AutoBuildContainer() - flags = [] - for flag, package in lib_flags: - flags.append(flag) - if package: - build_container.list_of_required_packages.append(package) - - # Route 1: Basic build with lib flags - build_container.list_of_commands = [ - f'make clean && make LDLIBS="{" ".join(flags)}"', - ('find . -type f -name "*.o" -print0 | ' - 'xargs -0 llvm-ar rcs libfuzz.a') - ] - build_container.heuristic_id = self.name + '1' - yield build_container - - # Route 2: Overriding CXXFLAGS - build_container_2 = AutoBuildContainer(build_container) - build_container_2.list_of_commands = [ - ('make clean && make CXXFLAGS="$CXXFLAGS"' - f' LDLIBS="{" ".join(flags)}"'), - ('find . -type f -name "*.o" -print0 | ' - 'xargs -0 llvm-ar rcs libfuzz.a') - ] - build_container_2.heuristic_id = self.name + '2' - yield build_container_2 - - # Route 2: Overriding CXXFLAGS and add PIC flag - build_container_3 = AutoBuildContainer(build_container) - build_container_3.list_of_commands = [ - ('make clean && make CXXFLAGS="$CXXFLAGS -fPIC"' - f' LDLIBS="{" ".join(flags)}"'), - ('find . 
-type f -name "*.o" -print0 | ' - 'xargs -0 llvm-ar rcs libfuzz.a') - ] - build_container_3.heuristic_id = self.name + '3' - yield build_container_3 - @property - def name(self): - return 'makewithlibflag' +class PureMakefileScannerWithLibFlag(AutoBuildBase): + """Auto builder for pure Makefile projects, relying on "make" with + additional -l flags during linker process.""" + + def __init__(self): + super().__init__() + self.matches_found = { + "Makefile": [], + } + + def steps_to_build(self) -> Iterator[AutoBuildContainer]: + # Determine possible lib flags from Makefile + lib_flags: List[Tuple[str, str]] = [] + for file in self.matches_found["Makefile"]: + with open(file, "r") as f: + content = f.read() + lib_flags.extend(self.determine_required_packages(content)) + + # Process required packages + build_container = AutoBuildContainer() + flags = [] + for flag, package in lib_flags: + flags.append(flag) + if package: + build_container.list_of_required_packages.append(package) + + # Route 1: Basic build with lib flags + build_container.list_of_commands = [ + f'make clean && make LDLIBS="{" ".join(flags)}"', + ('find . -type f -name "*.o" -print0 | ' "xargs -0 llvm-ar rcs libfuzz.a"), + ] + build_container.heuristic_id = self.name + "1" + yield build_container + + # Route 2: Overriding CXXFLAGS + build_container_2 = AutoBuildContainer(build_container) + build_container_2.list_of_commands = [ + ('make clean && make CXXFLAGS="$CXXFLAGS"' f' LDLIBS="{" ".join(flags)}"'), + ('find . -type f -name "*.o" -print0 | ' "xargs -0 llvm-ar rcs libfuzz.a"), + ] + build_container_2.heuristic_id = self.name + "2" + yield build_container_2 + + # Route 2: Overriding CXXFLAGS and add PIC flag + build_container_3 = AutoBuildContainer(build_container) + build_container_3.list_of_commands = [ + ( + 'make clean && make CXXFLAGS="$CXXFLAGS -fPIC"' + f' LDLIBS="{" ".join(flags)}"' + ), + ('find . -type f -name "*.o" -print0 | ' "xargs -0 llvm-ar rcs libfuzz.a"), + ] + build_container_3.heuristic_id = self.name + "3" + yield build_container_3 + + @property + def name(self): + return "makewithlibflag" class AutoRefConfScanner(AutoBuildBase): - """Auto-builder for patterns of "autoreconf fi; ./configure' make""" + """Auto-builder for patterns of "autoreconf fi; ./configure' make""" - def __init__(self): - super().__init__() - self.matches_found = { - 'configure.ac': [], - 'Makefile.am': [], - } + def __init__(self): + super().__init__() + self.matches_found = { + "configure.ac": [], + "Makefile.am": [], + } - def steps_to_build(self): - cmds_to_exec_from_root = ['autoreconf -fi', './configure', 'make'] - build_container = AutoBuildContainer() - build_container.list_of_commands = cmds_to_exec_from_root - build_container.heuristic_id = self.name + '1' - yield build_container + def steps_to_build(self): + cmds_to_exec_from_root = ["autoreconf -fi", "./configure", "make"] + build_container = AutoBuildContainer() + build_container.list_of_commands = cmds_to_exec_from_root + build_container.heuristic_id = self.name + "1" + yield build_container - @property - def name(self): - return 'autogen' + @property + def name(self): + return "autogen" class RawMake(AutoBuildBase): - """Similar to PureMake but also adds option for "make test". This is useful - to trigger more Fuzz Introspector analysis in the project.""" + """Similar to PureMake but also adds option for "make test". 
This is useful + to trigger more Fuzz Introspector analysis in the project.""" - def __init__(self): - super().__init__() - self.matches_found = { - 'Makefile': [], - } + def __init__(self): + super().__init__() + self.matches_found = { + "Makefile": [], + } - def steps_to_build(self): - cmds_to_exec_from_root = ['make'] - #yield cmds_to_exec_from_root - build_container = AutoBuildContainer() - build_container.list_of_commands = cmds_to_exec_from_root - build_container.heuristic_id = self.name + '1' - yield build_container + def steps_to_build(self): + cmds_to_exec_from_root = ["make"] + # yield cmds_to_exec_from_root + build_container = AutoBuildContainer() + build_container.list_of_commands = cmds_to_exec_from_root + build_container.heuristic_id = self.name + "1" + yield build_container - build_container2 = AutoBuildContainer() - build_container2.list_of_commands = cmds_to_exec_from_root + ['make test'] - build_container2.heuristic_id = self.name + '1' - yield build_container2 + build_container2 = AutoBuildContainer() + build_container2.list_of_commands = cmds_to_exec_from_root + ["make test"] + build_container2.heuristic_id = self.name + "1" + yield build_container2 - @property - def name(self): - return 'RawMake' + @property + def name(self): + return "RawMake" class AutogenScanner(AutoBuildBase): - """Auto builder for projects relying on "autoconf; autoheader.""" + """Auto builder for projects relying on "autoconf; autoheader.""" - def __init__(self): - super().__init__() - self.matches_found = { - 'configure.ac': [], - 'Makefile': [], - } + def __init__(self): + super().__init__() + self.matches_found = { + "configure.ac": [], + "Makefile": [], + } - def steps_to_build(self): - cmds_to_exec_from_root = ['autoconf', 'autoheader', './configure', 'make'] - #yield cmds_to_exec_from_root - build_container = AutoBuildContainer() - build_container.list_of_commands = cmds_to_exec_from_root - build_container.heuristic_id = self.name + '1' - yield build_container + def steps_to_build(self): + cmds_to_exec_from_root = ["autoconf", "autoheader", "./configure", "make"] + # yield cmds_to_exec_from_root + build_container = AutoBuildContainer() + build_container.list_of_commands = cmds_to_exec_from_root + build_container.heuristic_id = self.name + "1" + yield build_container - @property - def name(self): - return 'autogen' + @property + def name(self): + return "autogen" class AutogenScannerSH(AutoBuildBase): - """Auto builder for projects relying on "autogen.sh; autoconf; autoheader.""" + """Auto builder for projects relying on "autogen.sh; autoconf; autoheader.""" - def __init__(self): - super().__init__() - self.matches_found = {'configure.ac': [], 'autogen.sh': []} + def __init__(self): + super().__init__() + self.matches_found = {"configure.ac": [], "autogen.sh": []} - def steps_to_build(self): - cmds_to_exec_from_root = ['./autogen.sh', './configure', 'make'] - build_container = AutoBuildContainer() - build_container.list_of_commands = cmds_to_exec_from_root - build_container.heuristic_id = self.name + '1' - yield build_container + def steps_to_build(self): + cmds_to_exec_from_root = ["./autogen.sh", "./configure", "make"] + build_container = AutoBuildContainer() + build_container.list_of_commands = cmds_to_exec_from_root + build_container.heuristic_id = self.name + "1" + yield build_container - @property - def name(self): - return 'autogen20' + @property + def name(self): + return "autogen20" class BootstrapScanner(AutoBuildBase): - """Auto builder for projects that rely on bootstrap.sh; 
configure; make.""" + """Auto builder for projects that rely on bootstrap.sh; configure; make.""" - def __init__(self): - super().__init__() - self.matches_found = { - 'bootstrap.sh': [], - 'Makefile.am': [], - } + def __init__(self): + super().__init__() + self.matches_found = { + "bootstrap.sh": [], + "Makefile.am": [], + } - def steps_to_build(self): - cmds_to_exec_from_root = ['./bootstrap.sh', './configure', 'make'] - build_container = AutoBuildContainer() - build_container.list_of_commands = cmds_to_exec_from_root - build_container.heuristic_id = self.name + '1' - yield build_container + def steps_to_build(self): + cmds_to_exec_from_root = ["./bootstrap.sh", "./configure", "make"] + build_container = AutoBuildContainer() + build_container.list_of_commands = cmds_to_exec_from_root + build_container.heuristic_id = self.name + "1" + yield build_container - @property - def name(self): - return 'bootstrap-make' + @property + def name(self): + return "bootstrap-make" class AutogenConfScanner(AutoBuildBase): - """Auto builder for projects relying on "autoconf; autoheader.""" + """Auto builder for projects relying on "autoconf; autoheader.""" - def __init__(self): - super().__init__() - self.matches_found = { - 'configure.ac': [], - 'Makefile': [], - } + def __init__(self): + super().__init__() + self.matches_found = { + "configure.ac": [], + "Makefile": [], + } - def steps_to_build(self): - cmds_to_exec_from_root = ['./configure', 'make'] - #yield cmds_to_exec_from_root - build_container = AutoBuildContainer() - build_container.list_of_commands = cmds_to_exec_from_root - build_container.heuristic_id = self.name + '1' - yield build_container + def steps_to_build(self): + cmds_to_exec_from_root = ["./configure", "make"] + # yield cmds_to_exec_from_root + build_container = AutoBuildContainer() + build_container.list_of_commands = cmds_to_exec_from_root + build_container.heuristic_id = self.name + "1" + yield build_container - @property - def name(self): - return 'autogen-ConfMake' + @property + def name(self): + return "autogen-ConfMake" class CMakeScannerOptsParser(AutoBuildBase): - """Calls cmake to extract options from the CMakeLists.txt file of a project - and creates a build string where all BOOL values are set to OFF except those - with 'STATIC' in the name.""" - - def __init__(self): - super().__init__() - self.matches_found = { - 'CMakeLists.txt': [], - } - - def steps_to_build(self): - cmds_to_exec_from_root = [ - 'mkdir fuzz-build', - 'cd fuzz-build', - (f'cmake -DCMAKE_VERBOSE_MAKEFILE=ON {self.cmake_string} ' - '-DCMAKE_CXX_COMPILER=$CXX -DCMAKE_C_COMPILER=$CC ../'), - 'make V=1 || true', - ] - build_container = AutoBuildContainer() - build_container.list_of_commands = cmds_to_exec_from_root - build_container.heuristic_id = self.name + '1' - yield build_container - - @property - def name(self): - return 'autogen-ConfMakeOpt' - - def match_files(self, file_list: List[str]) -> None: - """Find CMakeLists.txt files and extract a string of the CMake options - that have all BOOL options set to OFF except for those with "STATIC" in the - name.""" - for fi in file_list: - # Focus on top dir - if fi.count('/') > 1: - continue - base_file = os.path.basename(fi) - for key, matches in self.matches_found.items(): - if base_file == key: - # Move directory - current_dir = os.getcwd() - cmake_base_dir = '/'.join(fi.split('/')[:-1]) - tmp_idx = 0 - tmp_dir = os.path.join(cmake_base_dir, f'temp-build-{tmp_idx}') - while os.path.isdir(tmp_dir): - tmp_idx += 1 - tmp_dir = os.path.join(cmake_base_dir, 
f'temp-build-{tmp_idx}') - - os.mkdir(tmp_dir) - os.chdir(tmp_dir) - extracted_string = self.extract_defensive_options() - if extracted_string: - matches.append(fi) - self.cmake_string = extracted_string - os.chdir(current_dir) - shutil.rmtree(tmp_dir) - - def extract_cmake_build_options(self) -> List[Dict[str, str]]: - """Extract options from CMakeLists.txt file one diretory up. Return as - list of dictionary items with the name, type and default value of the - CMake options.""" - option_elements = [] - - try: - output = subprocess.check_output('cmake -LAH ../ || true', - shell=True).decode() - except subprocess.CalledProcessError: - return option_elements - - # Parse the CMake options output to extract name, type and default value. - raw_options = [] - for line in output.split('\n'): - if ':' in line and '=' in line: - raw_options.append(line) - - for raw_option in raw_options: - option_default = raw_option.split('=')[-1] - option_type = raw_option.split('=')[0].split(':')[1] - option_name = raw_option.split('=')[0].split(':')[0] - - option_elements.append({ - 'name': option_name, - 'type': option_type, - 'default': option_default - }) - - return option_elements - - def extract_options_in_file(self) -> List[Dict[str, str]]: - """Extract CMake options from the CMakeLists.txt file one directory up.""" - with open('../CMakeLists.txt', 'r') as f: - cmake_content = f.read() - cmake_options = self.extract_cmake_build_options() - - # For each option in the cmake entire list of options identify which are - # defined inside of the CMakeLists.txt file of interest. - options_in_cmake_file = [] - for option in cmake_options: - if option['name'] in cmake_content: - options_in_cmake_file.append(option) - return options_in_cmake_file - - def extract_defensive_options(self) -> str: - """Extract options from CMakeLists.txt file as a string where all BOOL - options are set to False except for those with 'STATIC' in them.""" - options_in_cmake = self.extract_options_in_file() - cmake_string = '' - for option in options_in_cmake: - if option['type'] != 'BOOL': - continue - if 'STATIC' in option['name'] and option['default'] != 'ON': - cmake_string += f'-D{option["name"]}=ON ' - elif option['default'] != 'OFF': - cmake_string += f'-D{option["name"]}=OFF ' - return cmake_string + """Calls cmake to extract options from the CMakeLists.txt file of a project + and creates a build string where all BOOL values are set to OFF except those + with 'STATIC' in the name.""" + + def __init__(self): + super().__init__() + self.matches_found = { + "CMakeLists.txt": [], + } + + def steps_to_build(self): + cmds_to_exec_from_root = [ + "mkdir fuzz-build", + "cd fuzz-build", + ( + f"cmake -DCMAKE_VERBOSE_MAKEFILE=ON {self.cmake_string} " + "-DCMAKE_CXX_COMPILER=$CXX -DCMAKE_C_COMPILER=$CC ../" + ), + "make V=1 || true", + ] + build_container = AutoBuildContainer() + build_container.list_of_commands = cmds_to_exec_from_root + build_container.heuristic_id = self.name + "1" + yield build_container + + @property + def name(self): + return "autogen-ConfMakeOpt" + + def match_files(self, file_list: List[str]) -> None: + """Find CMakeLists.txt files and extract a string of the CMake options + that have all BOOL options set to OFF except for those with "STATIC" in the + name.""" + for fi in file_list: + # Focus on top dir + if fi.count("/") > 1: + continue + base_file = os.path.basename(fi) + for key, matches in self.matches_found.items(): + if base_file == key: + # Move directory + current_dir = os.getcwd() + cmake_base_dir = 
"/".join(fi.split("/")[:-1]) + tmp_idx = 0 + tmp_dir = os.path.join(cmake_base_dir, f"temp-build-{tmp_idx}") + while os.path.isdir(tmp_dir): + tmp_idx += 1 + tmp_dir = os.path.join(cmake_base_dir, f"temp-build-{tmp_idx}") + + os.mkdir(tmp_dir) + os.chdir(tmp_dir) + extracted_string = self.extract_defensive_options() + if extracted_string: + matches.append(fi) + self.cmake_string = extracted_string + os.chdir(current_dir) + shutil.rmtree(tmp_dir) + + def extract_cmake_build_options(self) -> List[Dict[str, str]]: + """Extract options from CMakeLists.txt file one diretory up. Return as + list of dictionary items with the name, type and default value of the + CMake options.""" + option_elements = [] + + try: + output = subprocess.check_output( + "cmake -LAH ../ || true", shell=True + ).decode() + except subprocess.CalledProcessError: + return option_elements + + # Parse the CMake options output to extract name, type and default value. + raw_options = [] + for line in output.split("\n"): + if ":" in line and "=" in line: + raw_options.append(line) + + for raw_option in raw_options: + option_default = raw_option.split("=")[-1] + option_type = raw_option.split("=")[0].split(":")[1] + option_name = raw_option.split("=")[0].split(":")[0] + + option_elements.append( + {"name": option_name, "type": option_type, "default": option_default} + ) + + return option_elements + + def extract_options_in_file(self) -> List[Dict[str, str]]: + """Extract CMake options from the CMakeLists.txt file one directory up.""" + with open("../CMakeLists.txt", "r") as f: + cmake_content = f.read() + cmake_options = self.extract_cmake_build_options() + + # For each option in the cmake entire list of options identify which are + # defined inside of the CMakeLists.txt file of interest. + options_in_cmake_file = [] + for option in cmake_options: + if option["name"] in cmake_content: + options_in_cmake_file.append(option) + return options_in_cmake_file + + def extract_defensive_options(self) -> str: + """Extract options from CMakeLists.txt file as a string where all BOOL + options are set to False except for those with 'STATIC' in them.""" + options_in_cmake = self.extract_options_in_file() + cmake_string = "" + for option in options_in_cmake: + if option["type"] != "BOOL": + continue + if "STATIC" in option["name"] and option["default"] != "ON": + cmake_string += f'-D{option["name"]}=ON ' + elif option["default"] != "OFF": + cmake_string += f'-D{option["name"]}=OFF ' + return cmake_string class CMakeScanner(AutoBuildBase): - """Auto builder for CMake projects.""" - - def __init__(self): - super().__init__() - self.matches_found = { - 'CMakeLists.txt': [], - } - - self.cmake_options = set() - - def match_files(self, file_list: List[str]) -> None: - for fi in file_list: - base_file = os.path.basename(fi) - for key, matches in self.matches_found.items(): - if base_file == key: - matches.append(fi) - - with open(fi, 'r') as f: - content = f.read() - for line in content.split('\n'): - if 'option(' in line: - option = line.split('option(')[1].split(' ')[0] - self.cmake_options.add(option) - - if len(self.cmake_options) > 0: - logger.info('Options:') - for option in self.cmake_options: - logger.info('%s', option) - - def steps_to_build(self): - # When we are running this, we are confident there are - # some heuristics that match what is needed for cmake builds. - # At this point, we will also scan for potential options - # in the cmake files, such as: - # - options related to shared libraries. 
- # - options related to which packags need installing. - cmds_to_exec_from_root = [ - 'mkdir fuzz-build', - 'cd fuzz-build', - 'cmake -DCMAKE_VERBOSE_MAKEFILE=ON ../', - 'make V=1 || true', - ] - build_container = AutoBuildContainer() - build_container.list_of_commands = cmds_to_exec_from_root - build_container.heuristic_id = self.name + '1' - yield build_container - - cmake_opts = [ - '-DCMAKE_VERBOSE_MAKEFILE=ON', '-DCMAKE_CXX_COMPILER=$CXX', - '-DCMAKE_CXX_FLAGS=\"$CXXFLAGS\"' - ] - - opt1 = [ - 'mkdir fuzz-build', - 'cd fuzz-build', - f'cmake {" ".join(cmake_opts)} ../', - 'make V=1 || true', - ] - build_container2 = AutoBuildContainer() - build_container2.list_of_commands = opt1 - build_container2.heuristic_id = self.name + '2' - yield build_container2 - - # Force static libraryes - opt_static = [ - 'mkdir fuzz-build', - 'cd fuzz-build', - f'cmake {" ".join(cmake_opts)} ../', - 'sed -i \'s/SHARED/STATIC/g\' ../CMakeLists.txt', - 'make V=1 || true', - ] - build_container_static = AutoBuildContainer() - build_container_static.list_of_commands = opt_static - build_container_static.heuristic_id = self.name + 'static' - yield build_container_static - - # Look for options often used for disabling dynamic shared libraries. - option_values = [] - for option in self.cmake_options: - if 'BUILD_SHARED_LIBS' == option: - option_values.append(f'-D{option}=OFF') - elif 'BUILD_STATIC' == option: - option_values.append(f'-D{option}=ON') - elif 'BUILD_SHARED' == option: - option_values.append(f'-D{option}=OFF') - elif 'ENABLE_STATIC' == option: - option_values.append(f'-D{option}=ON') - - if len(option_values) > 0: - option_string = ' '.join(option_values) - cmake_default_options = [ - '-DCMAKE_VERBOSE_MAKEFILE=ON', '-DCMAKE_CXX_COMPILER=$CXX', - '-DCMAKE_CXX_FLAGS=\"$CXXFLAGS\"' - ] - bopt = [ - 'mkdir fuzz-build', - 'cd fuzz-build', - f'cmake {" ".join(cmake_default_options)} {option_string} ../', - 'make V=1', - ] - build_container3 = AutoBuildContainer() - build_container3.list_of_commands = bopt - build_container3.heuristic_id = self.name + '3' - yield build_container3 - - # Build tests in-case - # Look for options often used for disabling dynamic shared libraries. 
- option_values = [] - for option in self.cmake_options: - if 'BUILD_SHARED_LIBS' == option: - option_values.append(f'-D{option}=OFF') - elif 'BUILD_STATIC' == option: - option_values.append(f'-D{option}=ON') - elif 'BUILD_SHARED' == option: - option_values.append(f'-D{option}=OFF') - elif 'ENABLE_STATIC' == option: - option_values.append(f'-D{option}=ON') - elif 'BUILD_TESTS' in option: - option_values.append(f'-D{option}=ON') - - if len(option_values) > 0: - option_string = ' '.join(option_values) - cmake_default_options = [ - '-DCMAKE_VERBOSE_MAKEFILE=ON', '-DCMAKE_CXX_COMPILER=$CXX', - '-DCMAKE_CXX_FLAGS=\"$CXXFLAGS\"' - ] - bopt = [ - 'mkdir fuzz-build', - 'cd fuzz-build', - f'cmake {" ".join(cmake_default_options)} {option_string} ../', - 'make V=1', - ] - build_container4 = AutoBuildContainer() - build_container4.list_of_commands = bopt - build_container4.heuristic_id = self.name + '3' - yield build_container4 - - @property - def name(self): - return 'cmake' + """Auto builder for CMake projects.""" + + def __init__(self): + super().__init__() + self.matches_found = { + "CMakeLists.txt": [], + } + + self.cmake_options = set() + + def match_files(self, file_list: List[str]) -> None: + for fi in file_list: + base_file = os.path.basename(fi) + for key, matches in self.matches_found.items(): + if base_file == key: + matches.append(fi) + + with open(fi, "r") as f: + content = f.read() + for line in content.split("\n"): + if "option(" in line: + option = line.split("option(")[1].split(" ")[0] + self.cmake_options.add(option) + + if len(self.cmake_options) > 0: + logger.info("Options:") + for option in self.cmake_options: + logger.info("%s", option) + + def steps_to_build(self): + # When we are running this, we are confident there are + # some heuristics that match what is needed for cmake builds. + # At this point, we will also scan for potential options + # in the cmake files, such as: + # - options related to shared libraries. + # - options related to which packags need installing. + cmds_to_exec_from_root = [ + "mkdir fuzz-build", + "cd fuzz-build", + "cmake -DCMAKE_VERBOSE_MAKEFILE=ON ../", + "make V=1 || true", + ] + build_container = AutoBuildContainer() + build_container.list_of_commands = cmds_to_exec_from_root + build_container.heuristic_id = self.name + "1" + yield build_container + + cmake_opts = [ + "-DCMAKE_VERBOSE_MAKEFILE=ON", + "-DCMAKE_CXX_COMPILER=$CXX", + '-DCMAKE_CXX_FLAGS="$CXXFLAGS"', + ] + + opt1 = [ + "mkdir fuzz-build", + "cd fuzz-build", + f'cmake {" ".join(cmake_opts)} ../', + "make V=1 || true", + ] + build_container2 = AutoBuildContainer() + build_container2.list_of_commands = opt1 + build_container2.heuristic_id = self.name + "2" + yield build_container2 + + # Force static libraryes + opt_static = [ + "mkdir fuzz-build", + "cd fuzz-build", + f'cmake {" ".join(cmake_opts)} ../', + "sed -i 's/SHARED/STATIC/g' ../CMakeLists.txt", + "make V=1 || true", + ] + build_container_static = AutoBuildContainer() + build_container_static.list_of_commands = opt_static + build_container_static.heuristic_id = self.name + "static" + yield build_container_static + + # Look for options often used for disabling dynamic shared libraries. 
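# Illustrative sketch of the option-name to -D flag mapping applied in the loop
# below; the option names are common CMake conventions and the values mirror
# the branches of that loop.
def shared_lib_flags(cmake_options):
    """Return -D flags that force static builds for well-known options."""
    defaults = {
        "BUILD_SHARED_LIBS": "OFF",
        "BUILD_SHARED": "OFF",
        "BUILD_STATIC": "ON",
        "ENABLE_STATIC": "ON",
    }
    return [f"-D{name}={value}"
            for name, value in defaults.items()
            if name in cmake_options]

# shared_lib_flags({"BUILD_SHARED_LIBS", "ENABLE_STATIC", "WITH_TESTS"})
#   -> ["-DBUILD_SHARED_LIBS=OFF", "-DENABLE_STATIC=ON"]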
+ option_values = [] + for option in self.cmake_options: + if "BUILD_SHARED_LIBS" == option: + option_values.append(f"-D{option}=OFF") + elif "BUILD_STATIC" == option: + option_values.append(f"-D{option}=ON") + elif "BUILD_SHARED" == option: + option_values.append(f"-D{option}=OFF") + elif "ENABLE_STATIC" == option: + option_values.append(f"-D{option}=ON") + + if len(option_values) > 0: + option_string = " ".join(option_values) + cmake_default_options = [ + "-DCMAKE_VERBOSE_MAKEFILE=ON", + "-DCMAKE_CXX_COMPILER=$CXX", + '-DCMAKE_CXX_FLAGS="$CXXFLAGS"', + ] + bopt = [ + "mkdir fuzz-build", + "cd fuzz-build", + f'cmake {" ".join(cmake_default_options)} {option_string} ../', + "make V=1", + ] + build_container3 = AutoBuildContainer() + build_container3.list_of_commands = bopt + build_container3.heuristic_id = self.name + "3" + yield build_container3 + + # Build tests in-case + # Look for options often used for disabling dynamic shared libraries. + option_values = [] + for option in self.cmake_options: + if "BUILD_SHARED_LIBS" == option: + option_values.append(f"-D{option}=OFF") + elif "BUILD_STATIC" == option: + option_values.append(f"-D{option}=ON") + elif "BUILD_SHARED" == option: + option_values.append(f"-D{option}=OFF") + elif "ENABLE_STATIC" == option: + option_values.append(f"-D{option}=ON") + elif "BUILD_TESTS" in option: + option_values.append(f"-D{option}=ON") + + if len(option_values) > 0: + option_string = " ".join(option_values) + cmake_default_options = [ + "-DCMAKE_VERBOSE_MAKEFILE=ON", + "-DCMAKE_CXX_COMPILER=$CXX", + '-DCMAKE_CXX_FLAGS="$CXXFLAGS"', + ] + bopt = [ + "mkdir fuzz-build", + "cd fuzz-build", + f'cmake {" ".join(cmake_default_options)} {option_string} ../', + "make V=1", + ] + build_container4 = AutoBuildContainer() + build_container4.list_of_commands = bopt + build_container4.heuristic_id = self.name + "3" + yield build_container4 + + @property + def name(self): + return "cmake" class KConfigBuildScanner(AutoBuildBase): - """Auto builder for KConfig-based projects.""" - - def __init__(self): - super().__init__() - self.matches_found = { - 'Config.in': [], - 'Makefile': [], - } - - def is_matched(self): - """Returns True if the build heuristic found matching files.""" - # Ensure both Config.in and Makefile exists - for found_matches in self.matches_found.values(): - if len(found_matches) == 0: - return False - return True + """Auto builder for KConfig-based projects.""" + + def __init__(self): + super().__init__() + self.matches_found = { + "Config.in": [], + "Makefile": [], + } + + def is_matched(self): + """Returns True if the build heuristic found matching files.""" + # Ensure both Config.in and Makefile exists + for found_matches in self.matches_found.values(): + if len(found_matches) == 0: + return False + return True - def steps_to_build(self) -> Iterator[AutoBuildContainer]: - base_command = [ - ''' + def steps_to_build(self) -> Iterator[AutoBuildContainer]: + base_command = [ + """ make defconfig make find . 
-type f -name "*.o" > objfiles llvm-ar rcs libfuzz.a $(cat objfiles) -''' - ] - build_container = AutoBuildContainer() - build_container.list_of_commands = base_command - build_container.heuristic_id = self.name + '1' - yield build_container - - # Alternative to avoid Gold lld - build_container_2 = AutoBuildContainer() - base_command.append('export CFLAGS="${CFLAGS} -fuse-ld=lld"') - base_command.append('export CFXXLAGS="${CXXFLAGS} -fuse-ld=lld"') - build_container_2.list_of_commands = base_command - build_container.heuristic_id = self.name + '2' - yield build_container_2 - - # Alternative to avoid Gold lld and add thread/crypt libraries - build_container_3 = AutoBuildContainer() - base_command.append('export CFLAGS="${CFLAGS} -lpthread -lcrypt"') - base_command.append('export CFXXLAGS="${CXXFLAGS} -lpthread -lcrypt"') - build_container_3.list_of_commands = base_command - build_container.heuristic_id = self.name + '3' - yield build_container_3 - - @property - def name(self): - return 'kconfig' +""" + ] + build_container = AutoBuildContainer() + build_container.list_of_commands = base_command + build_container.heuristic_id = self.name + "1" + yield build_container + + # Alternative to avoid Gold lld + build_container_2 = AutoBuildContainer() + base_command.append('export CFLAGS="${CFLAGS} -fuse-ld=lld"') + base_command.append('export CFXXLAGS="${CXXFLAGS} -fuse-ld=lld"') + build_container_2.list_of_commands = base_command + build_container.heuristic_id = self.name + "2" + yield build_container_2 + + # Alternative to avoid Gold lld and add thread/crypt libraries + build_container_3 = AutoBuildContainer() + base_command.append('export CFLAGS="${CFLAGS} -lpthread -lcrypt"') + base_command.append('export CFXXLAGS="${CXXFLAGS} -lpthread -lcrypt"') + build_container_3.list_of_commands = base_command + build_container.heuristic_id = self.name + "3" + yield build_container_3 + + @property + def name(self): + return "kconfig" def match_build_heuristics_on_folder(abspath_of_target: str): - """Yields AutoBuildContainer objects. - - Traverses the files in the target folder. Uses the file list as input to - auto build heuristics, and for each heuristic will yield any of the - build steps that are deemed matching.""" - all_files = utils.get_all_files_in_path(abspath_of_target) - all_checks = [ - AutogenConfScanner(), - PureCFileCompiler(), - PureCFileCompilerFind(), - PureCPPFileCompilerFind(), - PureMakefileScanner(), - PureMakefileScannerWithPThread(), - PureMakefileScannerWithSubstitutions(), - PureMakefileScannerWithLibFlag(), - AutogenScanner(), - AutoRefConfScanner(), - CMakeScanner(), - CMakeScannerOptsParser(), - RawMake(), - BootstrapScanner(), - AutogenScannerSH(), - HeaderOnlyCBuilder(), - KConfigBuildScanner(), - ] - - checks_to_test = [] - - logger.info('Filtering out build scripts') - build_heuristics_to_analyse = os.getenv('BUILD_HEURISTICS', 'all') - if build_heuristics_to_analyse == 'all': - checks_to_test = all_checks - else: - all_build_heuristics = build_heuristics_to_analyse.split(',') - for name in all_build_heuristics: - for check in all_checks: - if check.name == name: - checks_to_test.append(check) - - logger.info('Using %d checks.', len(checks_to_test)) - for scanner in checks_to_test: - scanner.match_files(all_files) - if scanner.is_matched(): - logger.info('Matched: %s', scanner.name) - yield from scanner.steps_to_build() + """Yields AutoBuildContainer objects. + + Traverses the files in the target folder. 
Uses the file list as input to + auto build heuristics, and for each heuristic will yield any of the + build steps that are deemed matching.""" + all_files = utils.get_all_files_in_path(abspath_of_target) + all_checks = [ + AutogenConfScanner(), + PureCFileCompiler(), + PureCFileCompilerFind(), + PureCPPFileCompilerFind(), + PureMakefileScanner(), + PureMakefileScannerWithPThread(), + PureMakefileScannerWithSubstitutions(), + PureMakefileScannerWithLibFlag(), + AutogenScanner(), + AutoRefConfScanner(), + CMakeScanner(), + CMakeScannerOptsParser(), + RawMake(), + BootstrapScanner(), + AutogenScannerSH(), + HeaderOnlyCBuilder(), + KConfigBuildScanner(), + ] + + checks_to_test = [] + + logger.info("Filtering out build scripts") + build_heuristics_to_analyse = os.getenv("BUILD_HEURISTICS", "all") + if build_heuristics_to_analyse == "all": + checks_to_test = all_checks + else: + all_build_heuristics = build_heuristics_to_analyse.split(",") + for name in all_build_heuristics: + for check in all_checks: + if check.name == name: + checks_to_test.append(check) + + logger.info("Using %d checks.", len(checks_to_test)) + for scanner in checks_to_test: + scanner.match_files(all_files) + if scanner.is_matched(): + logger.info("Matched: %s", scanner.name) + yield from scanner.steps_to_build() def get_all_binary_files_from_folder(path: str) -> Dict[str, List[str]]: - """Extracts binary artifacts from a list of files, based on file suffix.""" - all_files = utils.get_all_files_in_path(path, path) + """Extracts binary artifacts from a list of files, based on file suffix.""" + all_files = utils.get_all_files_in_path(path, path) - executable_files = {'static-libs': [], 'dynamic-libs': [], 'object-files': []} - for fil in all_files: - if fil.endswith('.o'): - executable_files['object-files'].append(fil) - if fil.endswith('.a'): - executable_files['static-libs'].append(fil) - if fil.endswith('.so'): - executable_files['dynamic-libs'].append(fil) - return executable_files + executable_files = {"static-libs": [], "dynamic-libs": [], "object-files": []} + for fil in all_files: + if fil.endswith(".o"): + executable_files["object-files"].append(fil) + if fil.endswith(".a"): + executable_files["static-libs"].append(fil) + if fil.endswith(".so"): + executable_files["dynamic-libs"].append(fil) + return executable_files -def wrap_build_script(test_dir: str, build_container: AutoBuildContainer, - abspath_of_target: str) -> str: - build_script = '#!/bin/bash\n' - build_script += f'rm -rf /{test_dir}\n' - build_script += f'cp -rf {abspath_of_target} {test_dir}\n' - build_script += f'cd {test_dir}\n' - for cmd in build_container.list_of_commands: - build_script += cmd + '\n' +def wrap_build_script( + test_dir: str, build_container: AutoBuildContainer, abspath_of_target: str +) -> str: + build_script = "#!/bin/bash\n" + build_script += f"rm -rf /{test_dir}\n" + build_script += f"cp -rf {abspath_of_target} {test_dir}\n" + build_script += f"cd {test_dir}\n" + for cmd in build_container.list_of_commands: + build_script += cmd + "\n" - return build_script + return build_script def convert_build_heuristics_to_scripts( - all_build_suggestions: List[AutoBuildContainer], testing_base_dir: str, - abspath_of_target: str) -> List[Tuple[str, str, AutoBuildContainer]]: - """Convert Auto build containers into bash scripts.""" - all_build_scripts = [] - for idx, build_suggestion in enumerate(all_build_suggestions): - test_dir = os.path.abspath( - os.path.join(os.getcwd(), testing_base_dir + str(idx))) - build_script = 
wrap_build_script(test_dir, build_suggestion, - abspath_of_target) - all_build_scripts.append((build_script, test_dir, build_suggestion)) - return all_build_scripts + all_build_suggestions: List[AutoBuildContainer], + testing_base_dir: str, + abspath_of_target: str, +) -> List[Tuple[str, str, AutoBuildContainer]]: + """Convert Auto build containers into bash scripts.""" + all_build_scripts = [] + for idx, build_suggestion in enumerate(all_build_suggestions): + test_dir = os.path.abspath( + os.path.join(os.getcwd(), testing_base_dir + str(idx)) + ) + build_script = wrap_build_script(test_dir, build_suggestion, abspath_of_target) + all_build_scripts.append((build_script, test_dir, build_suggestion)) + return all_build_scripts def extract_build_suggestions( - target_dir, testing_base_dir) -> List[Tuple[str, str, AutoBuildContainer]]: - """Statically create suggested build scripts for a project.""" - # Get all of the build heuristics - all_build_suggestions: List[AutoBuildContainer] = list( - match_build_heuristics_on_folder(target_dir)) - logger.info('Found %d possible build suggestions', len(all_build_suggestions)) - #all_build_suggestions = all_build_suggestions[:2] - for build_suggestion in all_build_suggestions: - logger.info('- %s', build_suggestion.heuristic_id) - - # Convert the build heuristics into build scripts - all_build_scripts = convert_build_heuristics_to_scripts( - all_build_suggestions, testing_base_dir, target_dir) - return all_build_scripts + target_dir, testing_base_dir +) -> List[Tuple[str, str, AutoBuildContainer]]: + """Statically create suggested build scripts for a project.""" + # Get all of the build heuristics + all_build_suggestions: List[AutoBuildContainer] = list( + match_build_heuristics_on_folder(target_dir) + ) + logger.info("Found %d possible build suggestions", len(all_build_suggestions)) + # all_build_suggestions = all_build_suggestions[:2] + for build_suggestion in all_build_suggestions: + logger.info("- %s", build_suggestion.heuristic_id) + + # Convert the build heuristics into build scripts + all_build_scripts = convert_build_heuristics_to_scripts( + all_build_suggestions, testing_base_dir, target_dir + ) + return all_build_scripts def raw_build_evaluation( all_build_scripts: List[Tuple[str, str, AutoBuildContainer]] ) -> Dict[str, BuildWorker]: - """Run each of the build scripts and extract any artifacts build by them.""" - build_results = {} - for build_script, test_dir, build_suggestion in all_build_scripts: - logger.info('Evaluating build heuristic %s', build_suggestion.heuristic_id) - with open('/src/build.sh', 'w') as bf: - bf.write(build_script) - - pkgs = build_suggestion.list_of_required_packages - if pkgs: - command = f'apt install -y {" ".join(pkgs)} && compile' - else: - command = 'compile' - try: - subprocess.check_call(command, - shell=True, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL) - except subprocess.CalledProcessError: - pass - logger.info('Finished evaluation.') - # Identify any binary artifacts built that weren't there prior - # to running the build. 
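# Condensed sketch of the evaluation step in raw_build_evaluation: each
# candidate script is written to /src/build.sh and the container's "compile"
# helper is invoked, optionally after installing required packages. The
# try_build name and the boolean return value are illustrative only.
import subprocess

def try_build(build_script, packages):
    with open("/src/build.sh", "w") as f:
        f.write(build_script)
    command = "compile"
    if packages:
        command = f'apt install -y {" ".join(packages)} && compile'
    completed = subprocess.run(command, shell=True, capture_output=True)
    return completed.returncode == 0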
- logger.info('Finding executables') - binary_files_build = get_all_binary_files_from_folder(test_dir) - logger.info('Finished looking for executables.') - - build_worker = BuildWorker(build_suggestion, build_script, test_dir, - binary_files_build) - - build_results[test_dir] = build_worker - - return build_results + """Run each of the build scripts and extract any artifacts build by them.""" + build_results = {} + for build_script, test_dir, build_suggestion in all_build_scripts: + logger.info("Evaluating build heuristic %s", build_suggestion.heuristic_id) + with open("/src/build.sh", "w") as bf: + bf.write(build_script) + + pkgs = build_suggestion.list_of_required_packages + if pkgs: + command = f'apt install -y {" ".join(pkgs)} && compile' + else: + command = "compile" + try: + subprocess.check_call( + command, + shell=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + except subprocess.CalledProcessError: + pass + logger.info("Finished evaluation.") + # Identify any binary artifacts built that weren't there prior + # to running the build. + logger.info("Finding executables") + binary_files_build = get_all_binary_files_from_folder(test_dir) + logger.info("Finished looking for executables.") + + build_worker = BuildWorker( + build_suggestion, build_script, test_dir, binary_files_build + ) + + build_results[test_dir] = build_worker + + return build_results diff --git a/experimental/build_generator/constants.py b/experimental/build_generator/constants.py index c1d21e55a8..64448a13b7 100644 --- a/experimental/build_generator/constants.py +++ b/experimental/build_generator/constants.py @@ -14,18 +14,18 @@ # limitations under the License. """Holds constants used for from-scratch generation.""" -SHARED_MEMORY_RESULTS_DIR = 'autogen-results' -PROJECT_BASE = 'temp-project-' +SHARED_MEMORY_RESULTS_DIR = "autogen-results" +PROJECT_BASE = "temp-project-" -MODEL_GPT_35_TURBO = 'gpt-3.5-turbo' -MODEL_GPT_4 = 'gpt-4' -MODEL_VERTEX = 'vertex' +MODEL_GPT_35_TURBO = "gpt-3.5-turbo" +MODEL_GPT_4 = "gpt-4" +MODEL_VERTEX = "vertex" MODELS = [MODEL_GPT_35_TURBO, MODEL_VERTEX] MAX_PROMPT_LENGTH = 25000 -INTROSPECTOR_OSS_FUZZ_DIR = '/src/inspector' -INTROSPECTOR_ALL_FUNCTIONS_FILE = 'all-fuzz-introspector-functions.json' +INTROSPECTOR_OSS_FUZZ_DIR = "/src/inspector" +INTROSPECTOR_ALL_FUNCTIONS_FILE = "all-fuzz-introspector-functions.json" # Common -l to required package mapping for Dockerfile installation LIBRARY_PACKAGE_MAP = { diff --git a/experimental/build_generator/file_utils.py b/experimental/build_generator/file_utils.py index 6bea864f18..e19d88fb63 100644 --- a/experimental/build_generator/file_utils.py +++ b/experimental/build_generator/file_utils.py @@ -17,56 +17,61 @@ from typing import List, Optional try: - # For execution outside of a docker container - from experimental.build_generator import templates + # For execution outside of a docker container + from experimental.build_generator import templates except (ImportError, SystemError): - # For execution inside of a docker container - import templates + # For execution inside of a docker container + import templates def determine_project_language(path: str) -> str: - """Returns the likely language of a project by looking at file suffixes.""" - all_files = get_all_files_in_path(path, path) + """Returns the likely language of a project by looking at file suffixes.""" + all_files = get_all_files_in_path(path, path) - language_dict = {'c': 0, 'c++': 0} - for source_file in all_files: - if source_file.endswith('.c'): - 
language_dict['c'] = language_dict['c'] + 1 - elif source_file.endswith('.cpp'): - language_dict['c++'] = language_dict['c++'] + 1 - elif source_file.endswith('.cc'): - language_dict['c++'] = language_dict['c++'] + 1 + language_dict = {"c": 0, "c++": 0} + for source_file in all_files: + if source_file.endswith(".c"): + language_dict["c"] = language_dict["c"] + 1 + elif source_file.endswith(".cpp"): + language_dict["c++"] = language_dict["c++"] + 1 + elif source_file.endswith(".cc"): + language_dict["c++"] = language_dict["c++"] + 1 - target_language = 'c++' - max_count = 0 - for language, count in language_dict.items(): - if count > max_count: - target_language = language - max_count = count - return target_language + target_language = "c++" + max_count = 0 + for language, count in language_dict.items(): + if count > max_count: + target_language = language + max_count = count + return target_language def get_language_defaults(language: str): - compilers_and_flags = { - 'c': ('$CC', '$CFLAGS', '/src/empty-fuzzer.c', templates.C_BASE_TEMPLATE), - 'c++': ('$CXX', '$CXXFLAGS', '/src/empty-fuzzer.cpp', - templates.CPP_BASE_TEMPLATE), - } - return compilers_and_flags[language] + compilers_and_flags = { + "c": ("$CC", "$CFLAGS", "/src/empty-fuzzer.c", templates.C_BASE_TEMPLATE), + "c++": ( + "$CXX", + "$CXXFLAGS", + "/src/empty-fuzzer.cpp", + templates.CPP_BASE_TEMPLATE, + ), + } + return compilers_and_flags[language] -def get_all_files_in_path(base_path: str, - path_to_subtract: Optional[str] = None) -> List[str]: - """Gets all files in a tree and returns as a list of strings.""" - all_files = [] - if path_to_subtract is None: - path_to_subtract = os.getcwd() - for root, _, files in os.walk(base_path): - for fi in files: - path = os.path.join(root, fi) - if path.startswith(path_to_subtract): - path = path[len(path_to_subtract):] - if len(path) > 0 and path[0] == '/': - path = path[1:] - all_files.append(path) - return all_files +def get_all_files_in_path( + base_path: str, path_to_subtract: Optional[str] = None +) -> List[str]: + """Gets all files in a tree and returns as a list of strings.""" + all_files = [] + if path_to_subtract is None: + path_to_subtract = os.getcwd() + for root, _, files in os.walk(base_path): + for fi in files: + path = os.path.join(root, fi) + if path.startswith(path_to_subtract): + path = path[len(path_to_subtract) :] + if len(path) > 0 and path[0] == "/": + path = path[1:] + all_files.append(path) + return all_files diff --git a/experimental/build_generator/llm_agent.py b/experimental/build_generator/llm_agent.py index 351d394a90..20f24f9928 100644 --- a/experimental/build_generator/llm_agent.py +++ b/experimental/build_generator/llm_agent.py @@ -34,472 +34,505 @@ class BuildScriptAgent(BaseAgent): - """Base class for build script agent.""" - - def __init__(self, - trial: int, - llm: LLM, - args: argparse.Namespace, - github_url: str, - language: str, - tools: Optional[list[BaseTool]] = None, - name: str = ''): - super().__init__(trial, llm, args, tools, name) - self.github_url = github_url - self.language = language - self.build_files = {} - self.last_status = False - self.last_result = '' - self.invalid = False - self.target_files = {} - self.discovery_stage = False - - # Get sample fuzzing harness - _, _, self.harness_path, self.harness_code = ( - file_utils.get_language_defaults(self.language)) - self.harness_name = self.harness_path.split('/')[-1].split('.')[0] - - def _parse_tag(self, response: str, tag: str) -> str: - """Parses the tag from LLM response.""" - 
patterns = [rf'<{tag}>(.*?)', rf'```{tag}(.*?)```'] - - # Matches both xml and code style tags - for pattern in patterns: - match = re.search(pattern, response, re.DOTALL) - if match: - return match.group(1).strip() - - return '' - - def _parse_tags(self, response: str, tag: str) -> list[str]: - """Parses the tags from LLM response.""" - patterns = [rf'<{tag}>(.*?)', rf'```{tag}(.*?)```'] - found_matches = [] - - # Matches both xml and code style tags - for pattern in patterns: - matches = re.findall(pattern, response, re.DOTALL) - found_matches.extend([content.strip() for content in matches]) - - return found_matches - - def _test_introspector_build(self, tool: BaseTool) -> bool: - """Helper to test the generated build script for introspector build.""" - # Handles environment variables for introspector build - envs = { - 'SANITIZER': 'introspector', - 'FUZZ_INTROSPECTOR_AUTO_FUZZ': '1', - 'PROJECT_NAME': 'auto-fuzz-proj', - 'FI_DISABLE_LIGHT': '1', - 'FUZZING_LANGUAGE': self.language, - 'FUZZINTRO_OUTDIR': self.args.work_dirs, - } - env_str = ' '.join(f"{key}='{value}'" for key, value in envs.items()) - - # Test introspector build - result = tool.execute(f'{env_str} compile') - - return result.returncode == 0 - - def _container_handle_bash_commands(self, response: str, tool: BaseTool, - prompt: Prompt) -> Prompt: - """Handles the command from LLM with container |tool|.""" - # Initialise variables - prompt_text = '' - success = False - self.invalid = False - self.missing_binary = False - - # Retrieve data from response - harness = self._parse_tag(response, 'fuzzer') - build_script = self._parse_tag(response, 'bash') - commands = '; '.join(self._parse_tags(response, 'command')) - - if commands: - self.discovery_stage = True - - # Execute the command directly, then return the formatted result - result = tool.execute(commands) - prompt_text = self._format_bash_execution_result(result, - previous_prompt=prompt) - if result.returncode == 0: - success = True - elif build_script: - self.discovery_stage = False - - # Restart the container to ensure a fresh session for test - if isinstance(tool, ProjectContainerTool): - tool.terminate() - tool = ProjectContainerTool(benchmark=tool.benchmark, name='test') - self.inspect_tool = tool - - # Update fuzzing harness - if harness: - self.harness_code = harness - if isinstance(tool, ProjectContainerTool): - tool.write_to_file(self.harness_code, self.harness_path) - - # Fix shebang to ensure docker image failing is reflected. 
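# Stand-alone version of the shebang normalization performed just below: the
# generated build script is forced to start with "#!/bin/bash -eu" so that a
# failing command (or an unset variable) aborts the build and the failure is
# reflected in the container build result.
def force_strict_shebang(build_script):
    lines = build_script.split("\n")
    if lines[0].startswith("#!"):
        lines[0] = "#!/bin/bash -eu"
    else:
        lines = ["#!/bin/bash -eu"] + lines
    return "\n".join(lines)

# force_strict_shebang("#!/bin/sh\nmake")  -> "#!/bin/bash -eu\nmake"
# force_strict_shebang("make")             -> "#!/bin/bash -eu\nmake"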
- lines = build_script.split('\n') - if lines[0].startswith("#!"): - lines[0] = "#!/bin/bash -eu" - else: - lines = ["#!/bin/bash -eu"] + lines - build_script = '\n'.join(lines) - - # Update build script - if isinstance(tool, ProjectContainerTool): - tool.write_to_file(build_script, tool.build_script_path) - - # Test and parse result - result = tool.execute('compile') - format_result = self._format_bash_execution_result( - result, previous_prompt=prompt) - prompt_text = self._parse_tag(format_result, 'stderr') + '\n' - if result.returncode == 0: - # Execution success, validating if the fuzzer binary built correctly - command = f'test -f $OUT/{self.harness_name}' - result = tool.execute(command) - - if result.returncode == 0: - success = True - # Test introspector build - introspector_build_result = self._test_introspector_build(tool) - command = f'test -d {constants.INTROSPECTOR_OSS_FUZZ_DIR}' - introspector_dir_check_result = tool.execute(command) - if introspector_dir_check_result.returncode == 0: - logger.info('Introspector build success', trial=-1) + """Base class for build script agent.""" + + def __init__( + self, + trial: int, + llm: LLM, + args: argparse.Namespace, + github_url: str, + language: str, + tools: Optional[list[BaseTool]] = None, + name: str = "", + ): + super().__init__(trial, llm, args, tools, name) + self.github_url = github_url + self.language = language + self.build_files = {} + self.last_status = False + self.last_result = "" + self.invalid = False + self.target_files = {} + self.discovery_stage = False + + # Get sample fuzzing harness + _, _, self.harness_path, self.harness_code = file_utils.get_language_defaults( + self.language + ) + self.harness_name = self.harness_path.split("/")[-1].split(".")[0] + + def _parse_tag(self, response: str, tag: str) -> str: + """Parses the tag from LLM response.""" + patterns = [rf"<{tag}>(.*?)", rf"```{tag}(.*?)```"] + + # Matches both xml and code style tags + for pattern in patterns: + match = re.search(pattern, response, re.DOTALL) + if match: + return match.group(1).strip() + + return "" + + def _parse_tags(self, response: str, tag: str) -> list[str]: + """Parses the tags from LLM response.""" + patterns = [rf"<{tag}>(.*?)", rf"```{tag}(.*?)```"] + found_matches = [] + + # Matches both xml and code style tags + for pattern in patterns: + matches = re.findall(pattern, response, re.DOTALL) + found_matches.extend([content.strip() for content in matches]) + + return found_matches + + def _test_introspector_build(self, tool: BaseTool) -> bool: + """Helper to test the generated build script for introspector build.""" + # Handles environment variables for introspector build + envs = { + "SANITIZER": "introspector", + "FUZZ_INTROSPECTOR_AUTO_FUZZ": "1", + "PROJECT_NAME": "auto-fuzz-proj", + "FI_DISABLE_LIGHT": "1", + "FUZZING_LANGUAGE": self.language, + "FUZZINTRO_OUTDIR": self.args.work_dirs, + } + env_str = " ".join(f"{key}='{value}'" for key, value in envs.items()) + + # Test introspector build + result = tool.execute(f"{env_str} compile") + + return result.returncode == 0 + + def _container_handle_bash_commands( + self, response: str, tool: BaseTool, prompt: Prompt + ) -> Prompt: + """Handles the command from LLM with container |tool|.""" + # Initialise variables + prompt_text = "" + success = False + self.invalid = False + self.missing_binary = False + + # Retrieve data from response + harness = self._parse_tag(response, "fuzzer") + build_script = self._parse_tag(response, "bash") + commands = "; 
".join(self._parse_tags(response, "command")) + + if commands: + self.discovery_stage = True + + # Execute the command directly, then return the formatted result + result = tool.execute(commands) + prompt_text = self._format_bash_execution_result( + result, previous_prompt=prompt + ) + if result.returncode == 0: + success = True + elif build_script: + self.discovery_stage = False + + # Restart the container to ensure a fresh session for test + if isinstance(tool, ProjectContainerTool): + tool.terminate() + tool = ProjectContainerTool(benchmark=tool.benchmark, name="test") + self.inspect_tool = tool + + # Update fuzzing harness + if harness: + self.harness_code = harness + if isinstance(tool, ProjectContainerTool): + tool.write_to_file(self.harness_code, self.harness_path) + + # Fix shebang to ensure docker image failing is reflected. + lines = build_script.split("\n") + if lines[0].startswith("#!"): + lines[0] = "#!/bin/bash -eu" else: - logger.info('Failed to get introspector results', trial=-1) - - if not introspector_build_result: - logger.info(('Introspector build returned error, ' - 'but light version worked.'), - trial=-1) - else: - # Fuzzer binary not compiled correctly - success = False - self.missing_binary = True - else: - self.invalid = True - - self.last_status = success - self.last_result = prompt_text - - return prompt - - def _container_handle_conclusion(self, cur_round: int, response: str, - build_result: BuildResult, - prompt: Prompt) -> Optional[Prompt]: - """Runs a compilation tool to validate the new build script from LLM.""" - logger.info('----- ROUND %02d Received conclusion -----', - cur_round, - trial=build_result.trial) - - # Don't need to check for invalid result - if self.invalid: - return prompt - - # Execution fail - if not self.last_status: - if self.missing_binary: - retry = templates.LLM_MISSING_BINARY.replace('{RESULT}', - self.last_result) - retry = retry.replace('{FUZZER_NAME}', self.harness_name) - else: - retry = templates.LLM_RETRY.replace('{BASH_RESULT}', self.last_result) - prompt.add_problem(retry) - - # Store build result - build_result.compiles = False - build_result.compile_error = self.last_result - - return prompt - - # Build success and compiled binary exist - build_result.compiles = True - build_result.fuzz_target_source = self.harness_code - build_script_source = self._parse_tag(response, 'bash') - if not build_script_source.startswith('#!'): - build_script_source = templates.EMPTY_OSS_FUZZ_BUILD + build_script_source - build_result.build_script_source = build_script_source - - return None - - def _container_tool_reaction(self, cur_round: int, response: str, - build_result: BuildResult) -> Optional[Prompt]: - """Validates LLM conclusion or executes its command.""" - prompt = self.llm.prompt_type()(None) - - if response: - prompt = self._container_handle_bash_commands(response, self.inspect_tool, - prompt) - - # Check result and try building with the new builds script - prompt = self._container_handle_conclusion(cur_round, response, - build_result, prompt) - - if prompt is None: - return None - - if not response or not prompt or not prompt.get(): - prompt = self._container_handle_invalid_tool_usage( - self.inspect_tool, cur_round, response, prompt) - - return prompt - - def _prepare_repository(self) -> str: - """Helper to prepare the repository for analysis.""" - target_path = os.path.join(self.args.work_dirs, - self.github_url.split('/')[-1]) - if not os.path.isdir(target_path): - subprocess.check_call( - f'git clone --recurse-submodules 
{self.github_url} {target_path}', - shell=True) - - return os.path.abspath(target_path) - - def _discover_headers(self) -> list[str]: - """Helper to discover some header files for inclusion.""" - # Prepare targert repository - target_path = self._prepare_repository() - - headers = set() - for root, _, files in os.walk(target_path): - for file in files: - if file.endswith((".h", ".hpp")): - header_path = os.path.join(root, file) - headers.add(header_path.replace(target_path, '')) - - return list(headers) - - def execute(self, result_history: list[Result]) -> BuildResult: - """Executes the agent based on previous result.""" - last_result = result_history[-1] - logger.info('Executing %s', self.name, trial=last_result.trial) - benchmark = last_result.benchmark - self.inspect_tool = ProjectContainerTool(benchmark, name='inspect') - self.inspect_tool.compile(extra_commands=' && rm -rf /out/* > /dev/null') - cur_round = 1 - dis_round = 1 - build_result = BuildResult(benchmark=benchmark, - trial=last_result.trial, - work_dirs=last_result.work_dirs, - author=self, - chat_history={self.name: ''}) - - prompt = self._initial_prompt(result_history) - try: - client = self.llm.get_chat_client(model=self.llm.get_model()) - while prompt: - # Sleep shortly to avoid RPM - time.sleep(6) - - response = self.chat_llm(cur_round, - client=client, - prompt=prompt, - trial=last_result.trial) - prompt = self._container_tool_reaction(cur_round, response, - build_result) - - if self.discovery_stage: - dis_round += 1 - if dis_round >= MAX_DISCOVERY_ROUND: - break + lines = ["#!/bin/bash -eu"] + lines + build_script = "\n".join(lines) + + # Update build script + if isinstance(tool, ProjectContainerTool): + tool.write_to_file(build_script, tool.build_script_path) + + # Test and parse result + result = tool.execute("compile") + format_result = self._format_bash_execution_result( + result, previous_prompt=prompt + ) + prompt_text = self._parse_tag(format_result, "stderr") + "\n" + if result.returncode == 0: + # Execution success, validating if the fuzzer binary built correctly + command = f"test -f $OUT/{self.harness_name}" + result = tool.execute(command) + + if result.returncode == 0: + success = True + # Test introspector build + introspector_build_result = self._test_introspector_build(tool) + command = f"test -d {constants.INTROSPECTOR_OSS_FUZZ_DIR}" + introspector_dir_check_result = tool.execute(command) + if introspector_dir_check_result.returncode == 0: + logger.info("Introspector build success", trial=-1) + else: + logger.info("Failed to get introspector results", trial=-1) + + if not introspector_build_result: + logger.info( + ( + "Introspector build returned error, " + "but light version worked." 
+ ), + trial=-1, + ) + else: + # Fuzzer binary not compiled correctly + success = False + self.missing_binary = True else: - cur_round += 1 - if cur_round >= self.max_round: - break - finally: - logger.info('Stopping and removing the inspect container %s', - self.inspect_tool.container_id, - trial=last_result.trial) - self.inspect_tool.terminate() + self.invalid = True + + self.last_status = success + self.last_result = prompt_text + + return prompt + + def _container_handle_conclusion( + self, cur_round: int, response: str, build_result: BuildResult, prompt: Prompt + ) -> Optional[Prompt]: + """Runs a compilation tool to validate the new build script from LLM.""" + logger.info( + "----- ROUND %02d Received conclusion -----", + cur_round, + trial=build_result.trial, + ) + + # Don't need to check for invalid result + if self.invalid: + return prompt + + # Execution fail + if not self.last_status: + if self.missing_binary: + retry = templates.LLM_MISSING_BINARY.replace( + "{RESULT}", self.last_result + ) + retry = retry.replace("{FUZZER_NAME}", self.harness_name) + else: + retry = templates.LLM_RETRY.replace("{BASH_RESULT}", self.last_result) + prompt.add_problem(retry) + + # Store build result + build_result.compiles = False + build_result.compile_error = self.last_result - return build_result + return prompt + + # Build success and compiled binary exist + build_result.compiles = True + build_result.fuzz_target_source = self.harness_code + build_script_source = self._parse_tag(response, "bash") + if not build_script_source.startswith("#!"): + build_script_source = templates.EMPTY_OSS_FUZZ_BUILD + build_script_source + build_result.build_script_source = build_script_source + + return None + + def _container_tool_reaction( + self, cur_round: int, response: str, build_result: BuildResult + ) -> Optional[Prompt]: + """Validates LLM conclusion or executes its command.""" + prompt = self.llm.prompt_type()(None) + + if response: + prompt = self._container_handle_bash_commands( + response, self.inspect_tool, prompt + ) + + # Check result and try building with the new builds script + prompt = self._container_handle_conclusion( + cur_round, response, build_result, prompt + ) + + if prompt is None: + return None + + if not response or not prompt or not prompt.get(): + prompt = self._container_handle_invalid_tool_usage( + self.inspect_tool, cur_round, response, prompt + ) + + return prompt + + def _prepare_repository(self) -> str: + """Helper to prepare the repository for analysis.""" + target_path = os.path.join(self.args.work_dirs, self.github_url.split("/")[-1]) + if not os.path.isdir(target_path): + subprocess.check_call( + f"git clone --recurse-submodules {self.github_url} {target_path}", + shell=True, + ) + + return os.path.abspath(target_path) + + def _discover_headers(self) -> list[str]: + """Helper to discover some header files for inclusion.""" + # Prepare targert repository + target_path = self._prepare_repository() + + headers = set() + for root, _, files in os.walk(target_path): + for file in files: + if file.endswith((".h", ".hpp")): + header_path = os.path.join(root, file) + headers.add(header_path.replace(target_path, "")) + + return list(headers) + + def execute(self, result_history: list[Result]) -> BuildResult: + """Executes the agent based on previous result.""" + last_result = result_history[-1] + logger.info("Executing %s", self.name, trial=last_result.trial) + benchmark = last_result.benchmark + self.inspect_tool = ProjectContainerTool(benchmark, name="inspect") + 
self.inspect_tool.compile(extra_commands=" && rm -rf /out/* > /dev/null") + cur_round = 1 + dis_round = 1 + build_result = BuildResult( + benchmark=benchmark, + trial=last_result.trial, + work_dirs=last_result.work_dirs, + author=self, + chat_history={self.name: ""}, + ) + + prompt = self._initial_prompt(result_history) + try: + client = self.llm.get_chat_client(model=self.llm.get_model()) + while prompt: + # Sleep shortly to avoid RPM + time.sleep(6) + + response = self.chat_llm( + cur_round, client=client, prompt=prompt, trial=last_result.trial + ) + prompt = self._container_tool_reaction( + cur_round, response, build_result + ) + + if self.discovery_stage: + dis_round += 1 + if dis_round >= MAX_DISCOVERY_ROUND: + break + else: + cur_round += 1 + if cur_round >= self.max_round: + break + finally: + logger.info( + "Stopping and removing the inspect container %s", + self.inspect_tool.container_id, + trial=last_result.trial, + ) + self.inspect_tool.terminate() + + return build_result class BuildSystemBuildScriptAgent(BuildScriptAgent): - """Generate a working Dockerfile and build script from scratch - with build system.""" - - def __init__(self, - trial: int, - llm: LLM, - args: argparse.Namespace, - github_url: str, - language: str, - tools: Optional[list[BaseTool]] = None, - name: str = ''): - super().__init__(trial, llm, args, github_url, language, tools, name) - self.target_files = { - 'Makefile': [], - 'configure.ac': [], - 'Makefile.am': [], - 'autogen.sh': [], - 'bootstrap.sh': [], - 'CMakeLists.txt': [], - 'Config.in': [], - } - - def _discover_build_configurations(self) -> bool: - """Helper to discover the build configuartions of a repository.""" - # Prepare targert repository - target_path = self._prepare_repository() - - # Locate common build configuration files - for root_dir, _, files in os.walk(target_path): - for file in files: - if file in self.target_files: - full_path = os.path.join(root_dir, file) - self.target_files[file].append(full_path) - - # Extract content of build files - for files in self.target_files.values(): - for file in files: - with open(file, 'r') as f: - self.build_files[file.replace(target_path, '')] = f.read() - - return len(self.build_files) > 0 - - def _initial_prompt(self, results: list[Result]) -> Prompt: - """Constructs initial prompt of the agent.""" - # pylint: disable=unused-argument - - prompt = self.llm.prompt_type()(None) - - # Extract build configuration files content - build_files_str = [] - for file, content in self.build_files.items(): - target_str = templates.LLM_BUILD_FILE_TEMPLATE.replace('{PATH}', file) - target_str = target_str.replace('{CONTENT}', content) - build_files_str.append(target_str) - - # Extract template Dockerfile content - dockerfile_str = templates.CLEAN_OSS_FUZZ_DOCKER - dockerfile_str = dockerfile_str.replace('{additional_packages}', '') - dockerfile_str = dockerfile_str.replace('{fuzzer_dir}', '$SRC/') - dockerfile_str = dockerfile_str.replace('{repo_url}', self.github_url) - dockerfile_str = dockerfile_str.replace('{project_repo_dir}', - self.github_url.split('/')[-1]) - - # Prepare prompt problem string - problem = templates.LLM_PROBLEM.replace('{BUILD_FILES}', - '\n'.join(build_files_str)) - problem = problem.replace('{DOCKERFILE}', dockerfile_str) - problem = problem.replace('{FUZZER}', self.harness_code) - problem = problem.replace('{FUZZER_NAME}', self.harness_name) - problem = problem.replace('{FUZZING_FILE}', - self.harness_path.split('/')[-1]) - - headers = self._discover_headers() - problem = 
problem.replace('{HEADERS}', - ','.join(headers[:SAMPLE_HEADERS_COUNT])) - - prompt.add_priming(templates.LLM_PRIMING) - prompt.add_problem(problem) - - return prompt - - def execute(self, result_history: list[Result]) -> BuildResult: - """Executes the agent based on previous result.""" - if not self._discover_build_configurations(): - logger.info('No known build configuration.', trial=self.trial) - return BuildResult(benchmark=result_history[-1].benchmark, - trial=result_history[-1].trial, - work_dirs=result_history[-1].work_dirs, - author=self, - chat_history={self.name: ''}) - - return super().execute(result_history) + """Generate a working Dockerfile and build script from scratch + with build system.""" + + def __init__( + self, + trial: int, + llm: LLM, + args: argparse.Namespace, + github_url: str, + language: str, + tools: Optional[list[BaseTool]] = None, + name: str = "", + ): + super().__init__(trial, llm, args, github_url, language, tools, name) + self.target_files = { + "Makefile": [], + "configure.ac": [], + "Makefile.am": [], + "autogen.sh": [], + "bootstrap.sh": [], + "CMakeLists.txt": [], + "Config.in": [], + } + + def _discover_build_configurations(self) -> bool: + """Helper to discover the build configuartions of a repository.""" + # Prepare targert repository + target_path = self._prepare_repository() + + # Locate common build configuration files + for root_dir, _, files in os.walk(target_path): + for file in files: + if file in self.target_files: + full_path = os.path.join(root_dir, file) + self.target_files[file].append(full_path) + + # Extract content of build files + for files in self.target_files.values(): + for file in files: + with open(file, "r") as f: + self.build_files[file.replace(target_path, "")] = f.read() + + return len(self.build_files) > 0 + + def _initial_prompt(self, results: list[Result]) -> Prompt: + """Constructs initial prompt of the agent.""" + # pylint: disable=unused-argument + + prompt = self.llm.prompt_type()(None) + + # Extract build configuration files content + build_files_str = [] + for file, content in self.build_files.items(): + target_str = templates.LLM_BUILD_FILE_TEMPLATE.replace("{PATH}", file) + target_str = target_str.replace("{CONTENT}", content) + build_files_str.append(target_str) + + # Extract template Dockerfile content + dockerfile_str = templates.CLEAN_OSS_FUZZ_DOCKER + dockerfile_str = dockerfile_str.replace("{additional_packages}", "") + dockerfile_str = dockerfile_str.replace("{fuzzer_dir}", "$SRC/") + dockerfile_str = dockerfile_str.replace("{repo_url}", self.github_url) + dockerfile_str = dockerfile_str.replace( + "{project_repo_dir}", self.github_url.split("/")[-1] + ) + + # Prepare prompt problem string + problem = templates.LLM_PROBLEM.replace( + "{BUILD_FILES}", "\n".join(build_files_str) + ) + problem = problem.replace("{DOCKERFILE}", dockerfile_str) + problem = problem.replace("{FUZZER}", self.harness_code) + problem = problem.replace("{FUZZER_NAME}", self.harness_name) + problem = problem.replace("{FUZZING_FILE}", self.harness_path.split("/")[-1]) + + headers = self._discover_headers() + problem = problem.replace("{HEADERS}", ",".join(headers[:SAMPLE_HEADERS_COUNT])) + + prompt.add_priming(templates.LLM_PRIMING) + prompt.add_problem(problem) + + return prompt + + def execute(self, result_history: list[Result]) -> BuildResult: + """Executes the agent based on previous result.""" + if not self._discover_build_configurations(): + logger.info("No known build configuration.", trial=self.trial) + return BuildResult( + 
benchmark=result_history[-1].benchmark, + trial=result_history[-1].trial, + work_dirs=result_history[-1].work_dirs, + author=self, + chat_history={self.name: ""}, + ) + + return super().execute(result_history) class AutoDiscoveryBuildScriptAgent(BuildScriptAgent): - """Generate a working Dockerfile and build script from scratch - with LLM auto discovery""" - - def __init__(self, - trial: int, - llm: LLM, - args: argparse.Namespace, - github_url: str, - language: str, - tools: Optional[list[BaseTool]] = None, - name: str = ''): - super().__init__(trial, llm, args, github_url, language, tools, name) - self.discovery_stage = True - - def _initial_prompt(self, results: list[Result]) -> Prompt: - """Constructs initial prompt of the agent.""" - # pylint: disable=unused-argument - - prompt = self.llm.prompt_type()(None) - - # Extract template Dockerfile content - dockerfile_str = templates.CLEAN_OSS_FUZZ_DOCKER - dockerfile_str = dockerfile_str.replace('{additional_packages}', '') - dockerfile_str = dockerfile_str.replace('{repo_url}', self.github_url) - dockerfile_str = dockerfile_str.replace('{project_repo_dir}', - self.github_url.split('/')[-1]) - - # Prepare prompt problem string - problem = templates.LLM_AUTO_DISCOVERY - problem = problem.replace('{PROJECT_NAME}', self.github_url.split('/')[-1]) - problem = problem.replace('{DOCKERFILE}', dockerfile_str) - problem = problem.replace('{MAX_DISCOVERY_ROUND}', str(MAX_DISCOVERY_ROUND)) - problem = problem.replace('{FUZZER}', self.harness_code) - problem = problem.replace('{FUZZER_NAME}', self.harness_name) - problem = problem.replace('{FUZZING_FILE}', - self.harness_path.split('/')[-1]) - - prompt.add_priming(templates.LLM_PRIMING) - prompt.add_problem(problem) - - return prompt - - def _container_handle_invalid_tool_usage(self, tool: BaseTool, cur_round: int, - response: str, - prompt: Prompt) -> Prompt: - """Formats a prompt to re-teach LLM how to use the |tool|.""" - # pylint: disable=unused-argument - - logger.warning('ROUND %02d Invalid response from LLM: %s', - cur_round, - response, - trial=self.trial) - prompt.add_problem(templates.LLM_NO_VALID_TAG) - - return prompt - - def _container_tool_reaction(self, cur_round: int, response: str, - build_result: BuildResult) -> Optional[Prompt]: - """Validates LLM conclusion or executes its command.""" - prompt = self.llm.prompt_type()(None) - - if response: - prompt = self._container_handle_bash_commands(response, self.inspect_tool, - prompt) - - if self.discovery_stage: - # Relay the command output back to LLM - feedback = templates.LLM_DOCKER_FEEDBACK - feedback = feedback.replace('{RESULT}', self.last_result) - prompt.add_problem(feedback) - else: - # Check result and try building with the new builds script - prompt = self._container_handle_conclusion(cur_round, response, - build_result, prompt) - - if prompt is None: - return None - - if not response or not prompt.get() or self.invalid: - prompt = self._container_handle_invalid_tool_usage( - self.inspect_tool, cur_round, response, prompt) - - return prompt - - def execute(self, result_history: list[Result]) -> BuildResult: - """Executes the agent based on previous result.""" - self._prepare_repository() - return super().execute(result_history) + """Generate a working Dockerfile and build script from scratch + with LLM auto discovery""" + + def __init__( + self, + trial: int, + llm: LLM, + args: argparse.Namespace, + github_url: str, + language: str, + tools: Optional[list[BaseTool]] = None, + name: str = "", + ): + super().__init__(trial, llm, 
args, github_url, language, tools, name) + self.discovery_stage = True + + def _initial_prompt(self, results: list[Result]) -> Prompt: + """Constructs initial prompt of the agent.""" + # pylint: disable=unused-argument + + prompt = self.llm.prompt_type()(None) + + # Extract template Dockerfile content + dockerfile_str = templates.CLEAN_OSS_FUZZ_DOCKER + dockerfile_str = dockerfile_str.replace("{additional_packages}", "") + dockerfile_str = dockerfile_str.replace("{repo_url}", self.github_url) + dockerfile_str = dockerfile_str.replace( + "{project_repo_dir}", self.github_url.split("/")[-1] + ) + + # Prepare prompt problem string + problem = templates.LLM_AUTO_DISCOVERY + problem = problem.replace("{PROJECT_NAME}", self.github_url.split("/")[-1]) + problem = problem.replace("{DOCKERFILE}", dockerfile_str) + problem = problem.replace("{MAX_DISCOVERY_ROUND}", str(MAX_DISCOVERY_ROUND)) + problem = problem.replace("{FUZZER}", self.harness_code) + problem = problem.replace("{FUZZER_NAME}", self.harness_name) + problem = problem.replace("{FUZZING_FILE}", self.harness_path.split("/")[-1]) + + prompt.add_priming(templates.LLM_PRIMING) + prompt.add_problem(problem) + + return prompt + + def _container_handle_invalid_tool_usage( + self, tool: BaseTool, cur_round: int, response: str, prompt: Prompt + ) -> Prompt: + """Formats a prompt to re-teach LLM how to use the |tool|.""" + # pylint: disable=unused-argument + + logger.warning( + "ROUND %02d Invalid response from LLM: %s", + cur_round, + response, + trial=self.trial, + ) + prompt.add_problem(templates.LLM_NO_VALID_TAG) + + return prompt + + def _container_tool_reaction( + self, cur_round: int, response: str, build_result: BuildResult + ) -> Optional[Prompt]: + """Validates LLM conclusion or executes its command.""" + prompt = self.llm.prompt_type()(None) + + if response: + prompt = self._container_handle_bash_commands( + response, self.inspect_tool, prompt + ) + + if self.discovery_stage: + # Relay the command output back to LLM + feedback = templates.LLM_DOCKER_FEEDBACK + feedback = feedback.replace("{RESULT}", self.last_result) + prompt.add_problem(feedback) + else: + # Check result and try building with the new builds script + prompt = self._container_handle_conclusion( + cur_round, response, build_result, prompt + ) + + if prompt is None: + return None + + if not response or not prompt.get() or self.invalid: + prompt = self._container_handle_invalid_tool_usage( + self.inspect_tool, cur_round, response, prompt + ) + + return prompt + + def execute(self, result_history: list[Result]) -> BuildResult: + """Executes the agent based on previous result.""" + self._prepare_repository() + return super().execute(result_history) diff --git a/experimental/build_generator/manager.py b/experimental/build_generator/manager.py index 39007f73d9..91e9355d6c 100644 --- a/experimental/build_generator/manager.py +++ b/experimental/build_generator/manager.py @@ -31,43 +31,44 @@ INTROSPECTOR_OSS_FUZZ_DIR = constants.INTROSPECTOR_OSS_FUZZ_DIR INTROSPECTOR_ALL_FUNCTIONS_FILE = constants.INTROSPECTOR_OSS_FUZZ_DIR -LLM_MODEL = '' +LLM_MODEL = "" -FUZZER_PRE_HEADERS = '''#include +FUZZER_PRE_HEADERS = """#include #include #include #include -''' +""" SECONDS_TO_RUN_HARNESS = 20 logger = logging.getLogger(name=__name__) -LOG_FMT = ('%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ' - ': %(funcName)s: %(message)s') +LOG_FMT = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " ": %(funcName)s: %(message)s" +) def setup_model(model: str): - global LLM_MODEL - 
LLM_MODEL = model + global LLM_MODEL + LLM_MODEL = model class Test: - """Holder of data about tests used by a repository.""" + """Holder of data about tests used by a repository.""" - def __init__(self, test_path, test_content): - self.test_path = test_path - self.test_content = test_content + def __init__(self, test_path, test_content): + self.test_path = test_path + self.test_content = test_content def get_all_functions_in_project(introspection_files_found): - all_functions_in_project = [] - for fi_yaml_file in introspection_files_found: - with open(fi_yaml_file, 'r') as file: - yaml_content = yaml.safe_load(file) - for elem in yaml_content['All functions']['Elements']: - all_functions_in_project.append(elem) + all_functions_in_project = [] + for fi_yaml_file in introspection_files_found: + with open(fi_yaml_file, "r") as file: + yaml_content = yaml.safe_load(file) + for elem in yaml_content["All functions"]["Elements"]: + all_functions_in_project.append(elem) - return all_functions_in_project + return all_functions_in_project ################################################## @@ -76,587 +77,628 @@ def get_all_functions_in_project(introspection_files_found): def get_all_header_files(all_files: List[str]) -> List[str]: - all_header_files = [] - to_avoids = ['stdlib.h', 'stdio.h', 'unistd.h'] - for yaml_file in all_files: - if yaml_file.endswith('.h'): - header_basename = os.path.basename(yaml_file) - if header_basename in to_avoids: - continue - all_header_files.append(yaml_file) - return all_header_files + all_header_files = [] + to_avoids = ["stdlib.h", "stdio.h", "unistd.h"] + for yaml_file in all_files: + if yaml_file.endswith(".h"): + header_basename = os.path.basename(yaml_file) + if header_basename in to_avoids: + continue + all_header_files.append(yaml_file) + return all_header_files def get_all_introspector_files(target_dir): - all_files = utils.get_all_files_in_path(target_dir) - introspection_files_found = [] - for yaml_file in all_files: - if 'allFunctionsWithMain' in yaml_file: - #print(yaml_file) - introspection_files_found.append(yaml_file) - elif 'fuzzerLogFile-' in yaml_file and yaml_file.endswith('.yaml'): - introspection_files_found.append(yaml_file) - return introspection_files_found + all_files = utils.get_all_files_in_path(target_dir) + introspection_files_found = [] + for yaml_file in all_files: + if "allFunctionsWithMain" in yaml_file: + # print(yaml_file) + introspection_files_found.append(yaml_file) + elif "fuzzerLogFile-" in yaml_file and yaml_file.endswith(".yaml"): + introspection_files_found.append(yaml_file) + return introspection_files_found def build_empty_fuzzers(build_workers, language) -> None: - """Run build scripts against an empty fuzzer harness.""" - # Stage 2: perform program analysis to extract data to be used for - # harness generation. + """Run build scripts against an empty fuzzer harness.""" + # Stage 2: perform program analysis to extract data to be used for + # harness generation. + + # For each of the auto generated build scripts try to link + # the resulting static libraries against an empty fuzzer. 
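# A minimal, self-contained sketch of the empty-fuzzer link check described in the
# comment above and assembled in the loop below; the compiler, harness path and
# library names in the example call are hypothetical.
import os
import subprocess


def try_link_empty_fuzzer(fuzz_compiler, empty_fuzzer_file, test_dir, static_libs):
    """Returns True if the empty harness links against the given static libraries."""
    cmd = [fuzz_compiler, "-fsanitize=fuzzer", "-fsanitize=address", empty_fuzzer_file]
    cmd.extend(os.path.join(test_dir, lib) for lib in static_libs)
    try:
        subprocess.check_call(" ".join(cmd), shell=True)
        return True
    except subprocess.CalledProcessError:
        return False


# Example (hypothetical values):
# try_link_empty_fuzzer("clang++", "/src/empty-fuzzer.cpp", "/src/test-fuzz-build-0", ["libfoo.a"])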
+ fuzz_compiler, _, empty_fuzzer_file, fuzz_template = utils.get_language_defaults( + language + ) + for test_dir, build_worker in build_workers: + logger.info( + "Test dir: %s :: %s", + test_dir, + str(build_worker.executable_files_build["refined-static-libs"]), + ) + + if not build_worker.executable_files_build["refined-static-libs"]: + continue + + logger.info("Trying to link in an empty fuzzer") + + # empty_fuzzer_file = '/src/empty-fuzzer.cpp' + with open(empty_fuzzer_file, "w") as f: + f.write(fuzz_template) + + # Try to link the fuzzer to the static libs + cmd = [ + fuzz_compiler, + "-fsanitize=fuzzer", + "-fsanitize=address", + empty_fuzzer_file, + ] + for refined_static_lib in build_worker.executable_files_build[ + "refined-static-libs" + ]: + cmd.append(os.path.join(test_dir, refined_static_lib)) + + logger.info("Command [%s]", " ".join(cmd)) + try: + subprocess.check_call(" ".join(cmd), shell=True) + base_fuzz_build = True + except subprocess.CalledProcessError: + base_fuzz_build = False + + logger.info("Base fuzz build: %s", str(base_fuzz_build)) + build_worker.base_fuzz_build = base_fuzz_build - # For each of the auto generated build scripts try to link - # the resulting static libraries against an empty fuzzer. - fuzz_compiler, _, empty_fuzzer_file, fuzz_template = ( - utils.get_language_defaults(language)) - for test_dir, build_worker in build_workers: - logger.info('Test dir: %s :: %s', test_dir, - str(build_worker.executable_files_build['refined-static-libs'])) - if not build_worker.executable_files_build['refined-static-libs']: - continue +def refine_static_libs(build_results) -> None: + """Returns a list of static libraries with substitution of common gtest + libraries, which should not be linked in the fuzzer builds.""" + for test_dir in build_results: + refined_static_list = [] + libs_to_avoid = { + "libgtest.a", + "libgmock.a", + "libgmock_main.a", + "libgtest_main.a", + } + build_worker = build_results[test_dir] + static_libs = build_worker.executable_files_build["static-libs"] + for static_lib in static_libs: + if any( + os.path.basename(static_lib) in lib_to_avoid + for lib_to_avoid in libs_to_avoid + ): + continue + refined_static_list.append(static_lib) + build_worker.executable_files_build["refined-static-libs"] = refined_static_list + + +def run_introspector_on_dir(build_worker, test_dir, language) -> Tuple[bool, List[str]]: + """Runs Fuzz Introspector on a target directory with the ability + to analyse code without having fuzzers (FUZZ_INTROSPECTOR_AUTO_FUZZ=1). + + This is done by running the bbuild script that succeeded using introspector + sanitizer from OSS-Fuzz, where introspector will collect data form any + executable linked during the vanilla build. - logger.info('Trying to link in an empty fuzzer') + This is done by way of the OSS-Fuzz `compile` command and by setting + the environment appropriately before running this command. 
+ """ + introspector_vanilla_build_script = build_worker.build_script + (fuzz_compiler, fuzz_flags, empty_fuzzer_file, fuzz_template) = ( + utils.get_language_defaults(language) + ) - #empty_fuzzer_file = '/src/empty-fuzzer.cpp' - with open(empty_fuzzer_file, 'w') as f: - f.write(fuzz_template) + with open(empty_fuzzer_file, "w") as f: + f.write(fuzz_template) # Try to link the fuzzer to the static libs - cmd = [ - fuzz_compiler, '-fsanitize=fuzzer', '-fsanitize=address', - empty_fuzzer_file + fuzzer_build_cmd = [ + fuzz_compiler, + fuzz_flags, + "$LIB_FUZZING_ENGINE", + empty_fuzzer_file, ] + fuzzer_build_cmd.append("-Wl,--allow-multiple-definition") for refined_static_lib in build_worker.executable_files_build[ - 'refined-static-libs']: - cmd.append(os.path.join(test_dir, refined_static_lib)) + "refined-static-libs" + ]: + fuzzer_build_cmd.append("-Wl,--whole-archive") + fuzzer_build_cmd.append(os.path.join(test_dir, refined_static_lib)) + fuzzer_build_cmd.append("-Wl,--no-whole-archive") - logger.info('Command [%s]', ' '.join(cmd)) - try: - subprocess.check_call(' '.join(cmd), shell=True) - base_fuzz_build = True - except subprocess.CalledProcessError: - base_fuzz_build = False + fuzzer_build_cmd.append("-o /src/compiled_binary") - logger.info('Base fuzz build: %s', str(base_fuzz_build)) - build_worker.base_fuzz_build = base_fuzz_build + introspector_vanilla_build_script += "\n" + introspector_vanilla_build_script += " ".join(fuzzer_build_cmd) + with open("/src/build.sh", "w") as bs: + bs.write(introspector_vanilla_build_script) -def refine_static_libs(build_results) -> None: - """Returns a list of static libraries with substitution of common gtest - libraries, which should not be linked in the fuzzer builds.""" - for test_dir in build_results: - refined_static_list = [] - libs_to_avoid = { - 'libgtest.a', 'libgmock.a', 'libgmock_main.a', 'libgtest_main.a' - } - build_worker = build_results[test_dir] - static_libs = build_worker.executable_files_build['static-libs'] - for static_lib in static_libs: - if any( - os.path.basename(static_lib) in lib_to_avoid - for lib_to_avoid in libs_to_avoid): - continue - refined_static_list.append(static_lib) - build_worker.executable_files_build[ - 'refined-static-libs'] = refined_static_list - - -def run_introspector_on_dir(build_worker, test_dir, - language) -> Tuple[bool, List[str]]: - """Runs Fuzz Introspector on a target directory with the ability - to analyse code without having fuzzers (FUZZ_INTROSPECTOR_AUTO_FUZZ=1). + if os.path.isfile("/src/compiled_binary"): + os.remove("/src/compiled_binary") - This is done by running the bbuild script that succeeded using introspector - sanitizer from OSS-Fuzz, where introspector will collect data form any - executable linked during the vanilla build. + modified_env = os.environ + modified_env["SANITIZER"] = "introspector" + modified_env["FUZZ_INTROSPECTOR_AUTO_FUZZ"] = "1" + modified_env["PROJECT_NAME"] = "auto-fuzz-proj" + modified_env["FUZZINTRO_OUTDIR"] = test_dir - This is done by way of the OSS-Fuzz `compile` command and by setting - the environment appropriately before running this command. 
- """ - introspector_vanilla_build_script = build_worker.build_script - (fuzz_compiler, fuzz_flags, empty_fuzzer_file, - fuzz_template) = utils.get_language_defaults(language) - - with open(empty_fuzzer_file, 'w') as f: - f.write(fuzz_template) - - # Try to link the fuzzer to the static libs - fuzzer_build_cmd = [ - fuzz_compiler, fuzz_flags, '$LIB_FUZZING_ENGINE', empty_fuzzer_file - ] - fuzzer_build_cmd.append('-Wl,--allow-multiple-definition') - for refined_static_lib in build_worker.executable_files_build[ - 'refined-static-libs']: - fuzzer_build_cmd.append('-Wl,--whole-archive') - fuzzer_build_cmd.append(os.path.join(test_dir, refined_static_lib)) - fuzzer_build_cmd.append('-Wl,--no-whole-archive') - - fuzzer_build_cmd.append('-o /src/compiled_binary') - - introspector_vanilla_build_script += '\n' - introspector_vanilla_build_script += ' '.join(fuzzer_build_cmd) - - with open('/src/build.sh', 'w') as bs: - bs.write(introspector_vanilla_build_script) - - if os.path.isfile('/src/compiled_binary'): - os.remove('/src/compiled_binary') - - modified_env = os.environ - modified_env['SANITIZER'] = 'introspector' - modified_env['FUZZ_INTROSPECTOR_AUTO_FUZZ'] = '1' - modified_env['PROJECT_NAME'] = 'auto-fuzz-proj' - modified_env['FUZZINTRO_OUTDIR'] = test_dir - - # Disable FI light because we want to make sure we can compile as well. - modified_env['FI_DISABLE_LIGHT'] = "1" - - try: - subprocess.check_call('compile', - shell=True, - env=modified_env, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL) - build_returned_error = False - except subprocess.CalledProcessError: - build_returned_error = True - - if not os.path.isfile('/src/compiled_binary'): - build_returned_error = True - - logger.info('Introspector build: %s', str(build_returned_error)) - return build_returned_error, fuzzer_build_cmd - - -def create_clean_oss_fuzz_from_empty(github_repo: str, build_worker, - language: str, test_dir) -> None: - """Converts a successful empty fuzzer build into an OSS-Fuzz project.""" - - # Save the results - nidx = 0 - oss_fuzz_folder = f'/out/empty-build-{nidx}' - while os.path.isdir(oss_fuzz_folder): - nidx += 1 - oss_fuzz_folder = f'/out/empty-build-{nidx}' - - os.makedirs(oss_fuzz_folder) - - introspector_vanilla_build_script = build_worker.build_script - (fuzz_compiler, fuzz_flags, empty_fuzzer_file, - fuzz_template) = utils.get_language_defaults(language) - - # Write empty fuzzer - with open(os.path.join(oss_fuzz_folder, os.path.basename(empty_fuzzer_file)), - 'w') as f: - f.write(fuzz_template) - - # Try to link the fuzzer to the static libs - fuzzer_build_cmd = [ - fuzz_compiler, fuzz_flags, '$LIB_FUZZING_ENGINE', empty_fuzzer_file - ] - fuzzer_build_cmd.append('-Wl,--allow-multiple-definition') - for refined_static_lib in build_worker.executable_files_build[ - 'refined-static-libs']: - fuzzer_build_cmd.append('-Wl,--whole-archive') - fuzzer_build_cmd.append(os.path.join(test_dir, refined_static_lib)) - - fuzzer_build_cmd.append('-Wl,--no-whole-archive') - - # Add inclusion of header file paths. This is anticipating any harnesses - # will make an effort to include relevant header files. 
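# For reference, the link line that run_introspector_on_dir() appends to /src/build.sh,
# shown for a hypothetical C++ project with one refined static library; the compiler,
# flags and archive path below are stand-ins for the values returned by
# utils.get_language_defaults().
fuzzer_build_cmd = [
    "clang++",               # fuzz_compiler (assumed)
    "$CXXFLAGS",             # fuzz_flags (assumed)
    "$LIB_FUZZING_ENGINE",
    "/src/empty-fuzzer.cpp",
    "-Wl,--allow-multiple-definition",
    "-Wl,--whole-archive",
    "/src/test-fuzz-build-0/libfoo.a",  # hypothetical refined static library
    "-Wl,--no-whole-archive",
    "-o /src/compiled_binary",
]
print(" ".join(fuzzer_build_cmd))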
- all_header_files = get_all_header_files(utils.get_all_files_in_path(test_dir)) - paths_to_include = set() - for header_file in all_header_files: - if not header_file.startswith('/src/'): - header_file = '/src/' + header_file - if '/test/' in header_file: - continue - if 'googletest' in header_file: - continue - - path_to_include = '/'.join(header_file.split('/')[:-1]) - paths_to_include.add(path_to_include) - for path_to_include in paths_to_include: - logger.info('Path to include: %s', path_to_include) - fuzzer_build_cmd.append(f'-I{path_to_include}') - - introspector_vanilla_build_script += '\n' - introspector_vanilla_build_script += ' '.join(fuzzer_build_cmd) - - #with open(os.path.join(oss_fuzz_folder, 'build.sh'), 'w') as bs: - # bs.write(introspector_vanilla_build_script) - - # Project yaml - project_yaml = { - 'homepage': github_repo, - 'language': language, - 'primary_contact': 'add_your_email@here.com', - 'main_repo': github_repo - } - with open(os.path.join(oss_fuzz_folder, 'project.yaml'), 'w') as project_out: - yaml.dump(project_yaml, project_out) - - # Create Dockerfile - project_repo_dir = github_repo.split('/')[-1] - additional_packages = build_worker.build_suggestion.list_of_required_packages - dockerfile = templates.CLEAN_OSS_FUZZ_DOCKER.format( - repo_url=github_repo, - project_repo_dir=project_repo_dir, - additional_packages=' '.join(additional_packages), - fuzzer_dir='$SRC/fuzzers/') - with open(os.path.join(oss_fuzz_folder, 'Dockerfile'), 'w') as docker_out: - docker_out.write(dockerfile) - - logger.info('Build script:') - logger.info(introspector_vanilla_build_script) - logger.info('-' * 45) - - # Build file - clean_build_content = convert_test_build_to_clean_build( - introspector_vanilla_build_script, project_repo_dir) - - with open(os.path.join(oss_fuzz_folder, 'build.sh'), 'w') as f: - f.write(clean_build_content) - - -def create_clean_oss_fuzz_from_success(github_repo: str, out_dir: str, - pkgs: List[str], language: str) -> None: - """Converts a successful out dir into a working OSS-Fuzz project.""" - oss_fuzz_folder = os.path.join(out_dir, 'oss-fuzz-project') - os.makedirs(oss_fuzz_folder) - - # Project yaml - project_yaml = { - 'homepage': github_repo, - 'language': language, - 'primary_contact': 'add_your_email@here.com', - 'main_repo': github_repo - } - with open(os.path.join(oss_fuzz_folder, 'project.yaml'), 'w') as project_out: - yaml.dump(project_yaml, project_out) - - # Copy fuzzer - _, _, fuzzer_target_file, _ = utils.get_language_defaults(language) - shutil.copy( - os.path.join(out_dir, os.path.basename(fuzzer_target_file)), - os.path.join(oss_fuzz_folder, - os.path.basename(fuzzer_target_file).replace('empty-', ''))) - - # Create Dockerfile - project_repo_dir = github_repo.split('/')[-1] - dockerfile = templates.CLEAN_OSS_FUZZ_DOCKER.format( - repo_url=github_repo, - project_repo_dir=project_repo_dir, - additional_packages=' '.join(pkgs), - fuzzer_dir='$SRC/fuzzers/') - with open(os.path.join(oss_fuzz_folder, 'Dockerfile'), 'w') as docker_out: - docker_out.write(dockerfile) - - # Build file - with open(os.path.join(out_dir, 'build.sh'), 'r') as f: - build_content = f.read() - - clean_build_content = convert_test_build_to_clean_build( - build_content, project_repo_dir) - - with open(os.path.join(oss_fuzz_folder, 'build.sh'), 'w') as f: - f.write(clean_build_content) - - -def create_clean_clusterfuzz_lite_from_success(github_repo: str, out_dir: str, - pkgs: List[str], - language: str) -> None: - """Converts a successful out dir into a working 
ClusterFuzzLite project.""" - cflite_folder = os.path.join(out_dir, 'clusterfuzz-lite-project') - os.makedirs(cflite_folder) - - # Project yaml - project_yaml = { - 'language': language, - } - with open(os.path.join(cflite_folder, 'project.yaml'), 'w') as project_out: - yaml.dump(project_yaml, project_out) - - # Copy fuzzer - _, _, fuzzer_target_file, _ = utils.get_language_defaults(language) - shutil.copy( - os.path.join(out_dir, os.path.basename(fuzzer_target_file)), - os.path.join(cflite_folder, - os.path.basename(fuzzer_target_file).replace('empty-', ''))) - - # Create Dockerfile - project_repo_dir = github_repo.split('/')[-1] - dockerfile = templates.CLEAN_DOCKER_CFLITE.format( - project_repo_dir=project_repo_dir, additional_packages=' '.join(pkgs)) - with open(os.path.join(cflite_folder, 'Dockerfile'), 'w') as docker_out: - docker_out.write(dockerfile) - - # Build file - with open(os.path.join(out_dir, 'build.sh'), 'r') as f: - build_content = f.read() - - clean_build_content = convert_test_build_to_clean_build( - build_content, project_repo_dir) - - with open(os.path.join(cflite_folder, 'build.sh'), 'w') as f: - f.write(clean_build_content) - - with open(os.path.join(cflite_folder, 'cflite_pr.yml'), 'w') as f: - f.write(templates.CFLITE_TEMPLATE) - - -def convert_fuzz_build_line_to_loop(clean_build_content: str, - original_build_folder: str, - project_repo_dir: str) -> str: - """Adjust fuzz building script so that harnesses are build in a loop - iterating $SRC/fuzzers/*. The goal of this is to make it easier to add - additional harnesses that will also get build. - """ - split_lines = clean_build_content.split('\n') - target_line_idx = -1 - for idx, tmp_line in enumerate(split_lines): - if '/src/generated-fuzzer' in tmp_line or '/src/empty-fuzzer' in tmp_line: - target_line_idx = idx - break - if target_line_idx == -1: - raise RuntimeError('Did not find harness build command.') - - wrapper_script = '''for fuzzer in $SRC/fuzzers/*; do - fuzzer_target=$(basename $fuzzer) - fuzzer_target="${fuzzer_target%.*}" - LINE_TO_SUBSTITUTE -done''' - target_line = split_lines[target_line_idx] + # Disable FI light because we want to make sure we can compile as well. + modified_env["FI_DISABLE_LIGHT"] = "1" - # Make adjustments to the harness build command: - # 1) Output fuzzers to $OUT/ instead of /src/generated-fuzzer - # 2) Name fuzzer baesd on bash variable instead of 'empty-fuzzer' - # 3) Use '$SRC/' instead of '/src/' - # 4) Rewrite file paths from test build directory to cloned directory, to - # adjust e.g. library and include paths. 
- target_line = target_line.replace( - '/src/generated-fuzzer', '$OUT/${fuzzer_target}').replace( - '/src/empty-fuzzer.cpp', - '${fuzzer}').replace('/src/empty-fuzzer.c', '${fuzzer}').replace( - '/src/', '$SRC/').replace(original_build_folder, project_repo_dir) + try: + subprocess.check_call( + "compile", + shell=True, + env=modified_env, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + build_returned_error = False + except subprocess.CalledProcessError: + build_returned_error = True + + if not os.path.isfile("/src/compiled_binary"): + build_returned_error = True - if '$OUT/${fuzzer_target}' not in target_line: - target_line += ' -o $OUT/${fuzzer_target}' + logger.info("Introspector build: %s", str(build_returned_error)) + return build_returned_error, fuzzer_build_cmd - wrapper_script = wrapper_script.replace('LINE_TO_SUBSTITUTE', target_line) - split_lines[target_line_idx] = wrapper_script - return '\n'.join(split_lines) +def create_clean_oss_fuzz_from_empty( + github_repo: str, build_worker, language: str, test_dir +) -> None: + """Converts a successful empty fuzzer build into an OSS-Fuzz project.""" -def convert_test_build_to_clean_build(test_build_script: str, - project_repo_dir: str) -> str: - """Rewrites a build.sh used during testing to a proper OSS-Fuzz build.sh.""" - split_build_content = test_build_script.split('\n') + # Save the results + nidx = 0 + oss_fuzz_folder = f"/out/empty-build-{nidx}" + while os.path.isdir(oss_fuzz_folder): + nidx += 1 + oss_fuzz_folder = f"/out/empty-build-{nidx}" - # Extract the test folder name - original_build_folder = split_build_content[1].split('/')[-1] + os.makedirs(oss_fuzz_folder) - # Remove the lines used in the testing build script to navigate test folders. - clean_build_content_lines = '\n'.join(split_build_content[:1] + - split_build_content[4:]) + introspector_vanilla_build_script = build_worker.build_script + (fuzz_compiler, fuzz_flags, empty_fuzzer_file, fuzz_template) = ( + utils.get_language_defaults(language) + ) - clean_build_content = convert_fuzz_build_line_to_loop( - clean_build_content_lines, original_build_folder, project_repo_dir) - return clean_build_content + # Write empty fuzzer + with open( + os.path.join(oss_fuzz_folder, os.path.basename(empty_fuzzer_file)), "w" + ) as f: + f.write(fuzz_template) + + # Try to link the fuzzer to the static libs + fuzzer_build_cmd = [ + fuzz_compiler, + fuzz_flags, + "$LIB_FUZZING_ENGINE", + empty_fuzzer_file, + ] + fuzzer_build_cmd.append("-Wl,--allow-multiple-definition") + for refined_static_lib in build_worker.executable_files_build[ + "refined-static-libs" + ]: + fuzzer_build_cmd.append("-Wl,--whole-archive") + fuzzer_build_cmd.append(os.path.join(test_dir, refined_static_lib)) + + fuzzer_build_cmd.append("-Wl,--no-whole-archive") + + # Add inclusion of header file paths. This is anticipating any harnesses + # will make an effort to include relevant header files. 
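# A small worked example of the include-path derivation used below: header files
# found in the test build directory are mapped to their parent directories and
# passed to the compiler as -I flags, skipping test and googletest headers.
# The header paths here are hypothetical.
headers = [
    "proj/include/foo.h",
    "proj/src/internal.h",
    "proj/test/helpers.h",           # skipped: under /test/
    "googletest/include/gtest.h",    # skipped: googletest
]
paths_to_include = set()
for header_file in headers:
    if not header_file.startswith("/src/"):
        header_file = "/src/" + header_file
    if "/test/" in header_file or "googletest" in header_file:
        continue
    paths_to_include.add("/".join(header_file.split("/")[:-1]))
print(sorted(f"-I{p}" for p in paths_to_include))
# ['-I/src/proj/include', '-I/src/proj/src']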
+ all_header_files = get_all_header_files(utils.get_all_files_in_path(test_dir)) + paths_to_include = set() + for header_file in all_header_files: + if not header_file.startswith("/src/"): + header_file = "/src/" + header_file + if "/test/" in header_file: + continue + if "googletest" in header_file: + continue + + path_to_include = "/".join(header_file.split("/")[:-1]) + paths_to_include.add(path_to_include) + for path_to_include in paths_to_include: + logger.info("Path to include: %s", path_to_include) + fuzzer_build_cmd.append(f"-I{path_to_include}") + + introspector_vanilla_build_script += "\n" + introspector_vanilla_build_script += " ".join(fuzzer_build_cmd) + + # with open(os.path.join(oss_fuzz_folder, 'build.sh'), 'w') as bs: + # bs.write(introspector_vanilla_build_script) + + # Project yaml + project_yaml = { + "homepage": github_repo, + "language": language, + "primary_contact": "add_your_email@here.com", + "main_repo": github_repo, + } + with open(os.path.join(oss_fuzz_folder, "project.yaml"), "w") as project_out: + yaml.dump(project_yaml, project_out) + + # Create Dockerfile + project_repo_dir = github_repo.split("/")[-1] + additional_packages = build_worker.build_suggestion.list_of_required_packages + dockerfile = templates.CLEAN_OSS_FUZZ_DOCKER.format( + repo_url=github_repo, + project_repo_dir=project_repo_dir, + additional_packages=" ".join(additional_packages), + fuzzer_dir="$SRC/fuzzers/", + ) + with open(os.path.join(oss_fuzz_folder, "Dockerfile"), "w") as docker_out: + docker_out.write(dockerfile) + + logger.info("Build script:") + logger.info(introspector_vanilla_build_script) + logger.info("-" * 45) + + # Build file + clean_build_content = convert_test_build_to_clean_build( + introspector_vanilla_build_script, project_repo_dir + ) + + with open(os.path.join(oss_fuzz_folder, "build.sh"), "w") as f: + f.write(clean_build_content) + + +def create_clean_oss_fuzz_from_success( + github_repo: str, out_dir: str, pkgs: List[str], language: str +) -> None: + """Converts a successful out dir into a working OSS-Fuzz project.""" + oss_fuzz_folder = os.path.join(out_dir, "oss-fuzz-project") + os.makedirs(oss_fuzz_folder) + + # Project yaml + project_yaml = { + "homepage": github_repo, + "language": language, + "primary_contact": "add_your_email@here.com", + "main_repo": github_repo, + } + with open(os.path.join(oss_fuzz_folder, "project.yaml"), "w") as project_out: + yaml.dump(project_yaml, project_out) + + # Copy fuzzer + _, _, fuzzer_target_file, _ = utils.get_language_defaults(language) + shutil.copy( + os.path.join(out_dir, os.path.basename(fuzzer_target_file)), + os.path.join( + oss_fuzz_folder, os.path.basename(fuzzer_target_file).replace("empty-", "") + ), + ) + + # Create Dockerfile + project_repo_dir = github_repo.split("/")[-1] + dockerfile = templates.CLEAN_OSS_FUZZ_DOCKER.format( + repo_url=github_repo, + project_repo_dir=project_repo_dir, + additional_packages=" ".join(pkgs), + fuzzer_dir="$SRC/fuzzers/", + ) + with open(os.path.join(oss_fuzz_folder, "Dockerfile"), "w") as docker_out: + docker_out.write(dockerfile) + + # Build file + with open(os.path.join(out_dir, "build.sh"), "r") as f: + build_content = f.read() + + clean_build_content = convert_test_build_to_clean_build( + build_content, project_repo_dir + ) + + with open(os.path.join(oss_fuzz_folder, "build.sh"), "w") as f: + f.write(clean_build_content) + + +def create_clean_clusterfuzz_lite_from_success( + github_repo: str, out_dir: str, pkgs: List[str], language: str +) -> None: + """Converts a successful out 
dir into a working ClusterFuzzLite project.""" + cflite_folder = os.path.join(out_dir, "clusterfuzz-lite-project") + os.makedirs(cflite_folder) + + # Project yaml + project_yaml = { + "language": language, + } + with open(os.path.join(cflite_folder, "project.yaml"), "w") as project_out: + yaml.dump(project_yaml, project_out) + + # Copy fuzzer + _, _, fuzzer_target_file, _ = utils.get_language_defaults(language) + shutil.copy( + os.path.join(out_dir, os.path.basename(fuzzer_target_file)), + os.path.join( + cflite_folder, os.path.basename(fuzzer_target_file).replace("empty-", "") + ), + ) + + # Create Dockerfile + project_repo_dir = github_repo.split("/")[-1] + dockerfile = templates.CLEAN_DOCKER_CFLITE.format( + project_repo_dir=project_repo_dir, additional_packages=" ".join(pkgs) + ) + with open(os.path.join(cflite_folder, "Dockerfile"), "w") as docker_out: + docker_out.write(dockerfile) + + # Build file + with open(os.path.join(out_dir, "build.sh"), "r") as f: + build_content = f.read() + + clean_build_content = convert_test_build_to_clean_build( + build_content, project_repo_dir + ) + + with open(os.path.join(cflite_folder, "build.sh"), "w") as f: + f.write(clean_build_content) + + with open(os.path.join(cflite_folder, "cflite_pr.yml"), "w") as f: + f.write(templates.CFLITE_TEMPLATE) + + +def convert_fuzz_build_line_to_loop( + clean_build_content: str, original_build_folder: str, project_repo_dir: str +) -> str: + """Adjust fuzz building script so that harnesses are build in a loop + iterating $SRC/fuzzers/*. The goal of this is to make it easier to add + additional harnesses that will also get build. + """ + split_lines = clean_build_content.split("\n") + target_line_idx = -1 + for idx, tmp_line in enumerate(split_lines): + if "/src/generated-fuzzer" in tmp_line or "/src/empty-fuzzer" in tmp_line: + target_line_idx = idx + break + if target_line_idx == -1: + raise RuntimeError("Did not find harness build command.") + + wrapper_script = """for fuzzer in $SRC/fuzzers/*; do + fuzzer_target=$(basename $fuzzer) + fuzzer_target="${fuzzer_target%.*}" + LINE_TO_SUBSTITUTE +done""" + target_line = split_lines[target_line_idx] + + # Make adjustments to the harness build command: + # 1) Output fuzzers to $OUT/ instead of /src/generated-fuzzer + # 2) Name fuzzer baesd on bash variable instead of 'empty-fuzzer' + # 3) Use '$SRC/' instead of '/src/' + # 4) Rewrite file paths from test build directory to cloned directory, to + # adjust e.g. library and include paths. + target_line = ( + target_line.replace("/src/generated-fuzzer", "$OUT/${fuzzer_target}") + .replace("/src/empty-fuzzer.cpp", "${fuzzer}") + .replace("/src/empty-fuzzer.c", "${fuzzer}") + .replace("/src/", "$SRC/") + .replace(original_build_folder, project_repo_dir) + ) + + if "$OUT/${fuzzer_target}" not in target_line: + target_line += " -o $OUT/${fuzzer_target}" + + wrapper_script = wrapper_script.replace("LINE_TO_SUBSTITUTE", target_line) + split_lines[target_line_idx] = wrapper_script + return "\n".join(split_lines) + + +def convert_test_build_to_clean_build( + test_build_script: str, project_repo_dir: str +) -> str: + """Rewrites a build.sh used during testing to a proper OSS-Fuzz build.sh.""" + split_build_content = test_build_script.split("\n") + + # Extract the test folder name + original_build_folder = split_build_content[1].split("/")[-1] + + # Remove the lines used in the testing build script to navigate test folders. 
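# A worked example of the harness-line rewrite performed by
# convert_fuzz_build_line_to_loop() above; the build line, library and folder
# names are hypothetical.
original = (
    "clang++ $CXXFLAGS $LIB_FUZZING_ENGINE /src/empty-fuzzer.cpp "
    "/src/test-fuzz-build-0/libfoo.a -o /src/generated-fuzzer"
)
rewritten = (
    original.replace("/src/generated-fuzzer", "$OUT/${fuzzer_target}")
    .replace("/src/empty-fuzzer.cpp", "${fuzzer}")
    .replace("/src/", "$SRC/")
    .replace("test-fuzz-build-0", "proj")  # test build dir -> cloned repo dir
)
print(rewritten)
# clang++ $CXXFLAGS $LIB_FUZZING_ENGINE ${fuzzer} $SRC/proj/libfoo.a -o $OUT/${fuzzer_target}
# The rewritten line is then wrapped in the `for fuzzer in $SRC/fuzzers/*; do ... done` loop.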
+ clean_build_content_lines = "\n".join( + split_build_content[:1] + split_build_content[4:] + ) + + clean_build_content = convert_fuzz_build_line_to_loop( + clean_build_content_lines, original_build_folder, project_repo_dir + ) + return clean_build_content def append_to_report(outdir, msg): - if not os.path.isdir(outdir): - os.mkdir(outdir) - report_path = os.path.join(outdir, 'report.txt') - with open(report_path, 'a+') as f: - f.write(msg + '\n') + if not os.path.isdir(outdir): + os.mkdir(outdir) + report_path = os.path.join(outdir, "report.txt") + with open(report_path, "a+") as f: + f.write(msg + "\n") def load_introspector_report(): - """Extract introspector as python dictionary from local run.""" - if not os.path.isfile(os.path.join(INTROSPECTOR_OSS_FUZZ_DIR, - 'summary.json')): - return None - with open(os.path.join(INTROSPECTOR_OSS_FUZZ_DIR, 'summary.json'), 'r') as f: - summary_report = json.loads(f.read()) - - # Get all functions folder - if not os.path.isfile( - os.path.join(INTROSPECTOR_OSS_FUZZ_DIR, INTROSPECTOR_ALL_FUNCTIONS_FILE)): - return None - with open( - os.path.join(INTROSPECTOR_OSS_FUZZ_DIR, INTROSPECTOR_ALL_FUNCTIONS_FILE), - 'r') as f: - all_functions_list = json.loads(f.read()) - - summary_report['MergedProjectProfile']['all-functions'] = all_functions_list - return summary_report - - -def auto_generate(github_url, disable_testing_build_scripts=False, outdir=''): - """Generates build script and fuzzer harnesses for a GitHub repository.""" - target_source_path = os.path.join(os.getcwd(), github_url.split('/')[-1]) - dst_folder = github_url.split('/')[-1] - - # clone the base project into a dedicated folder - if not os.path.isdir(target_source_path): - subprocess.check_call( - f'git clone --recurse-submodules {github_url} {dst_folder}', shell=True) - - # Stage 1: Build script generation - language = utils.determine_project_language(target_source_path) - logger.info('Target language: %s', language) - append_to_report(outdir, f'Target language: {language}') - - # record the path - logger.info('[+] Extracting build scripts statically') - all_build_scripts: List[Tuple[ - str, str, build_script_generator. - AutoBuildContainer]] = build_script_generator.extract_build_suggestions( - target_source_path, 'test-fuzz-build-') - - # Check each of the build scripts. - logger.info('[+] Testing build suggestions') - build_results = build_script_generator.raw_build_evaluation(all_build_scripts) - logger.info('Checking results of %d build generators', len(build_results)) - - if disable_testing_build_scripts: - logger.info('disabling testing build scripts') - return - - for test_dir, build_worker in build_results.items(): - build_heuristic = build_worker.build_suggestion.heuristic_id - static_libs = build_worker.executable_files_build['static-libs'] - - append_to_report( - outdir, - f'build success: {build_heuristic} :: {test_dir} :: {static_libs}') - logger.info('%s : %s : %s', build_heuristic, test_dir, static_libs) - - # For each of the auto generated build scripts identify the - # static libraries resulting from the build. 
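# A minimal illustration of the gtest filtering done by refine_static_libs():
# archives named after GoogleTest/GoogleMock are dropped before the empty-fuzzer
# link step. The library names below are hypothetical.
import os

libs_to_avoid = {"libgtest.a", "libgmock.a", "libgmock_main.a", "libgtest_main.a"}
static_libs = ["build/libfoo.a", "build/libgtest.a", "build/libgtest_main.a"]
refined = [
    lib
    for lib in static_libs
    if not any(os.path.basename(lib) in avoid for avoid in libs_to_avoid)
]
print(refined)
# ['build/libfoo.a']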
- refine_static_libs(build_results) - - refined_builds = [] - b_idx = 0 - for test_dir, build_worker in build_results.items(): - if len(build_worker.executable_files_build) > 1: - for ref_lib in build_worker.executable_files_build['refined-static-libs']: - b_idx += 1 - new_worker = build_script_generator.BuildWorker( - build_worker.build_suggestion, build_worker.build_script, - build_worker.build_directory, - build_worker.executable_files_build.copy()) - new_worker.build_suggestion.heuristic_id = ( - new_worker.build_suggestion.heuristic_id + f'-{b_idx}') - new_worker.executable_files_build['refined-static-libs'] = [ref_lib] - refined_builds.append((test_dir, new_worker)) - refined_builds.append((test_dir, build_worker)) - - build_results = refined_builds - - logger.info('logging builds') - for test_dir, build_worker in build_results: - logger.info('Sample:') - logger.info(json.dumps(build_worker.executable_files_build)) - logger.info('------------------------') - - # Stage 2: perform program analysis to extract data to be used for - # harness generation. - build_empty_fuzzers(build_results, language) - - # Stage 3: Harness generation and harness testing. - # We now know for which versions we can generate a base fuzzer. - # Continue by runnig an introspector build using the auto-generated - # build scripts but fuzz introspector as the sanitier. The introspector - # build will analyze all code build in the project, meaning we will - # extract build data for code linked in e.g. samples and more during - # the build. The consequence is we will have a lot more data than if - # we only were to build the base fuzzer using introspector builds. - # Then, proceed to use the generated program analysis data as arguments - # to heuristics which will generate fuzzers. - # We need to run introspector per build, because we're essentially not - # sure if the produced binary files are the same. We could maybe optimize - # this to check if there are differences in build output. - logger.info('Going through %d build results to generate fuzzers', - len(build_results)) - - for test_dir, build_worker in build_results: - logger.info('Checking build heuristic: %s', - build_worker.build_suggestion.heuristic_id) - - # Skip if build suggestion did not work with an empty fuzzer. 
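# A small sketch of the shape of the report returned by load_introspector_report()
# above and consumed below; the merged profile's "all-functions" list is what the
# function count and the later harness-generation heuristics are based on. The
# entries here are hypothetical placeholders.
introspector_report = {
    "MergedProjectProfile": {
        "all-functions": [
            {},  # one record per analysed function (hypothetical placeholder)
            {},
        ],
    },
}
func_count = len(introspector_report["MergedProjectProfile"]["all-functions"])
print(f"Found a total of {func_count} functions.")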
- if not build_worker.base_fuzz_build: - logger.info('Build failed, skipping') - continue - - # Run Fuzz Introspector on the target - logger.info('Running introspector build') - if os.path.isdir(INTROSPECTOR_OSS_FUZZ_DIR): - shutil.rmtree(INTROSPECTOR_OSS_FUZZ_DIR) - - build_returned_error, _ = run_introspector_on_dir(build_worker, test_dir, - language) - - if os.path.isdir(INTROSPECTOR_OSS_FUZZ_DIR): - logger.info('Introspector build success') - else: - logger.info('Failed to get introspector results') - - if build_returned_error: - logger.info( - 'Introspector build returned error, but light version worked.') - continue - - # Identify the relevant functions - introspector_report = load_introspector_report() - if introspector_report is None: - continue - - func_count = len( - introspector_report['MergedProjectProfile']['all-functions']) - logger.info('Found a total of %d functions.', func_count) - append_to_report(outdir, 'Introspector analysis done') - - logger.info('Test dir: %s', str(test_dir)) - - append_to_report(outdir, f'Total functions in {test_dir} : {func_count}') - - create_clean_oss_fuzz_from_empty(github_url, build_worker, language, - test_dir) + """Extract introspector as python dictionary from local run.""" + if not os.path.isfile(os.path.join(INTROSPECTOR_OSS_FUZZ_DIR, "summary.json")): + return None + with open(os.path.join(INTROSPECTOR_OSS_FUZZ_DIR, "summary.json"), "r") as f: + summary_report = json.loads(f.read()) + + # Get all functions folder + if not os.path.isfile( + os.path.join(INTROSPECTOR_OSS_FUZZ_DIR, INTROSPECTOR_ALL_FUNCTIONS_FILE) + ): + return None + with open( + os.path.join(INTROSPECTOR_OSS_FUZZ_DIR, INTROSPECTOR_ALL_FUNCTIONS_FILE), "r" + ) as f: + all_functions_list = json.loads(f.read()) + + summary_report["MergedProjectProfile"]["all-functions"] = all_functions_list + return summary_report + + +def auto_generate(github_url, disable_testing_build_scripts=False, outdir=""): + """Generates build script and fuzzer harnesses for a GitHub repository.""" + target_source_path = os.path.join(os.getcwd(), github_url.split("/")[-1]) + dst_folder = github_url.split("/")[-1] + + # clone the base project into a dedicated folder + if not os.path.isdir(target_source_path): + subprocess.check_call( + f"git clone --recurse-submodules {github_url} {dst_folder}", shell=True + ) + + # Stage 1: Build script generation + language = utils.determine_project_language(target_source_path) + logger.info("Target language: %s", language) + append_to_report(outdir, f"Target language: {language}") + + # record the path + logger.info("[+] Extracting build scripts statically") + all_build_scripts: List[ + Tuple[str, str, build_script_generator.AutoBuildContainer] + ] = build_script_generator.extract_build_suggestions( + target_source_path, "test-fuzz-build-" + ) + + # Check each of the build scripts. 
+ logger.info("[+] Testing build suggestions") + build_results = build_script_generator.raw_build_evaluation(all_build_scripts) + logger.info("Checking results of %d build generators", len(build_results)) + + if disable_testing_build_scripts: + logger.info("disabling testing build scripts") + return + + for test_dir, build_worker in build_results.items(): + build_heuristic = build_worker.build_suggestion.heuristic_id + static_libs = build_worker.executable_files_build["static-libs"] + + append_to_report( + outdir, f"build success: {build_heuristic} :: {test_dir} :: {static_libs}" + ) + logger.info("%s : %s : %s", build_heuristic, test_dir, static_libs) + + # For each of the auto generated build scripts identify the + # static libraries resulting from the build. + refine_static_libs(build_results) + + refined_builds = [] + b_idx = 0 + for test_dir, build_worker in build_results.items(): + if len(build_worker.executable_files_build) > 1: + for ref_lib in build_worker.executable_files_build["refined-static-libs"]: + b_idx += 1 + new_worker = build_script_generator.BuildWorker( + build_worker.build_suggestion, + build_worker.build_script, + build_worker.build_directory, + build_worker.executable_files_build.copy(), + ) + new_worker.build_suggestion.heuristic_id = ( + new_worker.build_suggestion.heuristic_id + f"-{b_idx}" + ) + new_worker.executable_files_build["refined-static-libs"] = [ref_lib] + refined_builds.append((test_dir, new_worker)) + refined_builds.append((test_dir, build_worker)) + + build_results = refined_builds + + logger.info("logging builds") + for test_dir, build_worker in build_results: + logger.info("Sample:") + logger.info(json.dumps(build_worker.executable_files_build)) + logger.info("------------------------") + + # Stage 2: perform program analysis to extract data to be used for + # harness generation. + build_empty_fuzzers(build_results, language) + + # Stage 3: Harness generation and harness testing. + # We now know for which versions we can generate a base fuzzer. + # Continue by runnig an introspector build using the auto-generated + # build scripts but fuzz introspector as the sanitier. The introspector + # build will analyze all code build in the project, meaning we will + # extract build data for code linked in e.g. samples and more during + # the build. The consequence is we will have a lot more data than if + # we only were to build the base fuzzer using introspector builds. + # Then, proceed to use the generated program analysis data as arguments + # to heuristics which will generate fuzzers. + # We need to run introspector per build, because we're essentially not + # sure if the produced binary files are the same. We could maybe optimize + # this to check if there are differences in build output. + logger.info( + "Going through %d build results to generate fuzzers", len(build_results) + ) + + for test_dir, build_worker in build_results: + logger.info( + "Checking build heuristic: %s", build_worker.build_suggestion.heuristic_id + ) + + # Skip if build suggestion did not work with an empty fuzzer. 
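For context on the data the introspector stages rely on, here is a minimal sketch of reading a local Fuzz Introspector run the same way load_introspector_report() above does; the directory and the all-functions file name are illustrative assumptions, not the module's constants.

import json
import os
from typing import Optional


def load_local_summary(
    introspector_dir: str, all_functions_file: str = "all-functions.json"
) -> Optional[dict]:
    """Read summary.json and attach the all-functions list, mirroring the code above."""
    summary_path = os.path.join(introspector_dir, "summary.json")
    functions_path = os.path.join(introspector_dir, all_functions_file)
    if not (os.path.isfile(summary_path) and os.path.isfile(functions_path)):
        return None
    with open(summary_path) as f:
        summary = json.load(f)
    with open(functions_path) as f:
        summary["MergedProjectProfile"]["all-functions"] = json.load(f)
    return summary


if __name__ == "__main__":
    report = load_local_summary("/tmp/introspector-out")  # hypothetical location
    if report is not None:
        print(len(report["MergedProjectProfile"]["all-functions"]), "functions found")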
+ if not build_worker.base_fuzz_build: + logger.info("Build failed, skipping") + continue + + # Run Fuzz Introspector on the target + logger.info("Running introspector build") + if os.path.isdir(INTROSPECTOR_OSS_FUZZ_DIR): + shutil.rmtree(INTROSPECTOR_OSS_FUZZ_DIR) + + build_returned_error, _ = run_introspector_on_dir( + build_worker, test_dir, language + ) + + if os.path.isdir(INTROSPECTOR_OSS_FUZZ_DIR): + logger.info("Introspector build success") + else: + logger.info("Failed to get introspector results") + + if build_returned_error: + logger.info("Introspector build returned error, but light version worked.") + continue + + # Identify the relevant functions + introspector_report = load_introspector_report() + if introspector_report is None: + continue + + func_count = len(introspector_report["MergedProjectProfile"]["all-functions"]) + logger.info("Found a total of %d functions.", func_count) + append_to_report(outdir, "Introspector analysis done") + + logger.info("Test dir: %s", str(test_dir)) + + append_to_report(outdir, f"Total functions in {test_dir} : {func_count}") + + create_clean_oss_fuzz_from_empty(github_url, build_worker, language, test_dir) def parse_commandline(): - """Commandline parser.""" - parser = argparse.ArgumentParser() - parser.add_argument('repo', help='Github url of target') - parser.add_argument('--disable-build-test', - action='store_true', - help='disables') - parser.add_argument('--out', '-o', help='Directory to store successful runs') - parser.add_argument('--model', - '-m', - help='Model to use for auto generation', - type=str) - return parser + """Commandline parser.""" + parser = argparse.ArgumentParser() + parser.add_argument("repo", help="Github url of target") + parser.add_argument("--disable-build-test", action="store_true", help="disables") + parser.add_argument("--out", "-o", help="Directory to store successful runs") + parser.add_argument( + "--model", "-m", help="Model to use for auto generation", type=str + ) + return parser def setup_logging(): - logging.basicConfig(level=logging.INFO, format=LOG_FMT) + logging.basicConfig(level=logging.INFO, format=LOG_FMT) def main(): - parser = parse_commandline() - args = parser.parse_args() - setup_logging() + parser = parse_commandline() + args = parser.parse_args() + setup_logging() - setup_model(args.model) + setup_model(args.model) - append_to_report(args.out, f'Analysing: {args.repo}') + append_to_report(args.out, f"Analysing: {args.repo}") - auto_generate(args.repo, args.disable_build_test, args.out) + auto_generate(args.repo, args.disable_build_test, args.out) -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/experimental/build_generator/runner.py b/experimental/build_generator/runner.py index ea5e6cf3dc..fd0e5ca65e 100644 --- a/experimental/build_generator/runner.py +++ b/experimental/build_generator/runner.py @@ -29,545 +29,607 @@ from experiment import oss_fuzz_checkout from experiment.benchmark import Benchmark from experiment.workdir import WorkDirs -from experimental.build_generator import (constants, file_utils, llm_agent, - templates) +from experimental.build_generator import ( + constants, + file_utils, + llm_agent, + templates, +) from llm_toolkit import models from results import Result silent_global = False logger = logging.getLogger(name=__name__) -LOG_FMT = ('%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ' - ': %(funcName)s: %(message)s') - - -def setup_worker_project(oss_fuzz_base: str, - project_name: str, - llm_model: str, - github_url: str = 
'', - from_agent: bool = False, - workdir: str = '') -> str: - """Setup empty OSS-Fuzz project used for analysis.""" - language = '' - - temp_project_dir = os.path.join(oss_fuzz_base, "projects", project_name) - if os.path.isdir(temp_project_dir): - shutil.rmtree(temp_project_dir) - - os.makedirs(temp_project_dir) - with open(os.path.join(temp_project_dir, 'project.yaml'), 'w') as f: - f.write(templates.EMPTY_PROJECT_YAML) - with open(os.path.join(temp_project_dir, 'build.sh'), 'w') as f: - f.write(templates.EMPTY_OSS_FUZZ_BUILD) - with open(os.path.join(temp_project_dir, 'Dockerfile'), 'w') as f: +LOG_FMT = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " ": %(funcName)s: %(message)s" +) + + +def setup_worker_project( + oss_fuzz_base: str, + project_name: str, + llm_model: str, + github_url: str = "", + from_agent: bool = False, + workdir: str = "", +) -> str: + """Setup empty OSS-Fuzz project used for analysis.""" + language = "" + + temp_project_dir = os.path.join(oss_fuzz_base, "projects", project_name) + if os.path.isdir(temp_project_dir): + shutil.rmtree(temp_project_dir) + + os.makedirs(temp_project_dir) + with open(os.path.join(temp_project_dir, "project.yaml"), "w") as f: + f.write(templates.EMPTY_PROJECT_YAML) + with open(os.path.join(temp_project_dir, "build.sh"), "w") as f: + f.write(templates.EMPTY_OSS_FUZZ_BUILD) + with open(os.path.join(temp_project_dir, "Dockerfile"), "w") as f: + if from_agent: + file_content = templates.CLEAN_OSS_FUZZ_DOCKER + file_content = file_content.replace("{additional_packages}", "") + file_content = file_content.replace("{repo_url}", github_url) + file_content = file_content.replace("{fuzzer_dir}", "$SRC/") + file_content = file_content.replace( + "{project_repo_dir}", github_url.split("/")[-1] + ) + else: + file_content = templates.AUTOGEN_DOCKER_FILE + + f.write(file_content) + + # Prepare demo fuzzing harness source if from_agent: - file_content = templates.CLEAN_OSS_FUZZ_DOCKER - file_content = file_content.replace('{additional_packages}', '') - file_content = file_content.replace('{repo_url}', github_url) - file_content = file_content.replace('{fuzzer_dir}', '$SRC/') - file_content = file_content.replace('{project_repo_dir}', - github_url.split('/')[-1]) + repo_path = os.path.join(workdir, "temp_repo") + git.Repo.clone_from(github_url, repo_path) + try: + language = file_utils.determine_project_language(repo_path) + _, _, name, code = file_utils.get_language_defaults(language) + with open(os.path.join(temp_project_dir, name.split("/")[-1]), "w") as f: + f.write(code) + finally: + if os.path.exists(repo_path) and os.path.isdir(repo_path): + shutil.rmtree(repo_path) + + if llm_model == "vertex": + json_config = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", None) + if json_config is None: + logger.info("vertex model is set but could not find configuration file.") + logger.info("Plese set GOOGLE_APPLICATION_CREDENTIALS env variable.") + sys.exit(1) + shutil.copyfile(json_config, os.path.join(temp_project_dir, "creds.json")) + + # Copy over the generator (only for general approach + if not from_agent: + files_to_copy = { + "build_script_generator.py", + "manager.py", + "templates.py", + "constants.py", + "file_utils.py", + } + for target_file in files_to_copy: + shutil.copyfile( + os.path.join(os.path.dirname(os.path.abspath(__file__)), target_file), + os.path.join(temp_project_dir, target_file.split("/")[-1]), + ) + + # Build a version of the project + if silent_global: + subprocess.check_call( + f"python3 infra/helper.py 
build_fuzzers {project_name}", + shell=True, + cwd=oss_fuzz_base, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) else: - file_content = templates.AUTOGEN_DOCKER_FILE + subprocess.check_call( + f"python3 infra/helper.py build_fuzzers {project_name}", + shell=True, + cwd=oss_fuzz_base, + ) + + return language + + +def run_autogen( + github_url, + outdir, + oss_fuzz_base, + worker_project, + model, + openai_api_key=None, + build_heuristics="all", + max_successful_builds: int = -1, + max_timeout: int = 0, +): + """Launch auto-gen analysis within OSS-Fuzz container.""" + initiator_cmd = f"python3 /src/manager.py {github_url} -o {outdir}" + initiator_cmd += f" --model={model}" + if max_successful_builds > 0: + initiator_cmd += f" --max-successful={max_successful_builds}" + + extra_environment = [] + if model == constants.MODEL_VERTEX: + extra_environment.append("-e") + extra_environment.append("GOOGLE_APPLICATION_CREDENTIALS=/src/creds.json") + elif openai_api_key: + extra_environment.append("-e") + extra_environment.append(f"OPENAI_API_KEY={openai_api_key}") + + cmd = [ + "docker", + "run", + "--rm", + "-e", + "FUZZING_ENGINE=libfuzzer", + "-e", + "SANITIZER=address", + "-e", + "ARCHITECTURE=x86_64", + "-e", + "HELPER=True", + "-e", + "FUZZING_LANGUAGE=c++", + "-e", + f"BUILD_HEURISTICS={build_heuristics}", + ] + extra_environment + + if max_timeout: + cmd = ["timeout", str(max_timeout)] + cmd + + cmd += [ + "-v", + f"{oss_fuzz_base}/build/out/{worker_project}:/out", + "-v", + f"{oss_fuzz_base}/build/work/{worker_project}:/work", + "-t", + f"gcr.io/oss-fuzz/{worker_project}", + # Command to run inside the container + initiator_cmd, + ] + + cmd_to_run = " ".join(cmd) + try: + if silent_global: + subprocess.check_call( + cmd_to_run, + cwd=oss_fuzz_base, + shell=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.STDOUT, + ) + else: + subprocess.check_call(cmd_to_run, cwd=oss_fuzz_base, shell=True) + except subprocess.CalledProcessError: + pass - f.write(file_content) - # Prepare demo fuzzing harness source - if from_agent: - repo_path = os.path.join(workdir, 'temp_repo') - git.Repo.clone_from(github_url, repo_path) +def read_targets_file(filename: str) -> List[str]: + """Parse input file.""" + res_targets = [] + with open(filename, "r") as f: + targets = f.read().split("\n") + for e in targets: + if len(e) < 6: + continue + if e: + res_targets.append(e) + return res_targets + + +def run_on_targets( + target, + oss_fuzz_base, + worker_project_name, + idx, + llm_model, + semaphore=None, + build_heuristics="all", + output="", + max_timeout: int = 0, +): + """Thread entry point for single project auto-gen.""" + + if semaphore is not None: + semaphore.acquire() + + openai_api_key = os.getenv("OPENAI_API_KEY", None) + + outdir = os.path.join("/out/", constants.SHARED_MEMORY_RESULTS_DIR) + with open("status-log.txt", "a") as f: + f.write(f"Targeting: {target} :: {idx}\n") + run_autogen( + target, + outdir, + oss_fuzz_base, + worker_project_name, + llm_model, + openai_api_key=openai_api_key, + build_heuristics=build_heuristics, + max_timeout=max_timeout, + ) + + # Cleanup the OSS-Fuzz docker image + clean_up_cmd = ["docker", "image", "rm", f"gcr.io/oss-fuzz/{worker_project_name}"] try: - language = file_utils.determine_project_language(repo_path) - _, _, name, code = file_utils.get_language_defaults(language) - with open(os.path.join(temp_project_dir, name.split('/')[-1]), 'w') as f: - f.write(code) - finally: - if os.path.exists(repo_path) and os.path.isdir(repo_path): - 
shutil.rmtree(repo_path) - - if llm_model == 'vertex': - json_config = os.environ.get('GOOGLE_APPLICATION_CREDENTIALS', None) - if json_config is None: - logger.info('vertex model is set but could not find configuration file.') - logger.info('Plese set GOOGLE_APPLICATION_CREDENTIALS env variable.') - sys.exit(1) - shutil.copyfile(json_config, os.path.join(temp_project_dir, 'creds.json')) - - # Copy over the generator (only for general approach - if not from_agent: - files_to_copy = { - 'build_script_generator.py', 'manager.py', 'templates.py', - 'constants.py', 'file_utils.py' - } - for target_file in files_to_copy: - shutil.copyfile( - os.path.join(os.path.dirname(os.path.abspath(__file__)), target_file), - os.path.join(temp_project_dir, - target_file.split('/')[-1])) - - # Build a version of the project - if silent_global: - subprocess.check_call( - f'python3 infra/helper.py build_fuzzers {project_name}', - shell=True, - cwd=oss_fuzz_base, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL) - else: - subprocess.check_call( - f'python3 infra/helper.py build_fuzzers {project_name}', - shell=True, - cwd=oss_fuzz_base) - - return language - - -def run_autogen(github_url, - outdir, - oss_fuzz_base, - worker_project, - model, - openai_api_key=None, - build_heuristics='all', - max_successful_builds: int = -1, - max_timeout: int = 0): - """Launch auto-gen analysis within OSS-Fuzz container.""" - initiator_cmd = f'python3 /src/manager.py {github_url} -o {outdir}' - initiator_cmd += f' --model={model}' - if max_successful_builds > 0: - initiator_cmd += f' --max-successful={max_successful_builds}' - - extra_environment = [] - if model == constants.MODEL_VERTEX: - extra_environment.append('-e') - extra_environment.append('GOOGLE_APPLICATION_CREDENTIALS=/src/creds.json') - elif openai_api_key: - extra_environment.append('-e') - extra_environment.append(f'OPENAI_API_KEY={openai_api_key}') - - cmd = [ - 'docker', - 'run', - '--rm', - '-e', - 'FUZZING_ENGINE=libfuzzer', - '-e', - 'SANITIZER=address', - '-e', - 'ARCHITECTURE=x86_64', - '-e', - 'HELPER=True', - '-e', - 'FUZZING_LANGUAGE=c++', - '-e', - f'BUILD_HEURISTICS={build_heuristics}', - ] + extra_environment - - if max_timeout: - cmd = ['timeout', str(max_timeout)] + cmd - - cmd += [ - '-v', - f'{oss_fuzz_base}/build/out/{worker_project}:/out', - '-v', - f'{oss_fuzz_base}/build/work/{worker_project}:/work', - '-t', - f'gcr.io/oss-fuzz/{worker_project}', - # Command to run inside the container - initiator_cmd - ] - - cmd_to_run = ' '.join(cmd) - try: - if silent_global: - subprocess.check_call(cmd_to_run, - cwd=oss_fuzz_base, - shell=True, - stdout=subprocess.DEVNULL, - stderr=subprocess.STDOUT) - else: - subprocess.check_call(cmd_to_run, cwd=oss_fuzz_base, shell=True) - except subprocess.CalledProcessError: - pass + subprocess.check_call(" ".join(clean_up_cmd), shell=True) + except subprocess.CalledProcessError: + pass + # Write to output directory + copy_result_to_out(worker_project_name, oss_fuzz_base, output) -def read_targets_file(filename: str) -> List[str]: - """Parse input file.""" - res_targets = [] - with open(filename, 'r') as f: - targets = f.read().split("\n") - for e in targets: - if len(e) < 6: - continue - if e: - res_targets.append(e) - return res_targets - - -def run_on_targets(target, - oss_fuzz_base, - worker_project_name, - idx, - llm_model, - semaphore=None, - build_heuristics='all', - output='', - max_timeout: int = 0): - """Thread entry point for single project auto-gen.""" - - if semaphore is not None: - 
semaphore.acquire() - - openai_api_key = os.getenv('OPENAI_API_KEY', None) - - outdir = os.path.join('/out/', constants.SHARED_MEMORY_RESULTS_DIR) - with open('status-log.txt', 'a') as f: - f.write(f'Targeting: {target} :: {idx}\n') - run_autogen(target, - outdir, - oss_fuzz_base, - worker_project_name, - llm_model, - openai_api_key=openai_api_key, - build_heuristics=build_heuristics, - max_timeout=max_timeout) - - # Cleanup the OSS-Fuzz docker image - clean_up_cmd = [ - 'docker', 'image', 'rm', f'gcr.io/oss-fuzz/{worker_project_name}' - ] - try: - subprocess.check_call(' '.join(clean_up_cmd), shell=True) - except subprocess.CalledProcessError: - pass - - # Write to output directory - copy_result_to_out(worker_project_name, oss_fuzz_base, output) - - if semaphore is not None: - semaphore.release() + if semaphore is not None: + semaphore.release() def get_next_worker_project(oss_fuzz_base: str) -> str: - """Gets next OSS-Fuzz worker projecet.""" - max_idx = -1 - for project_dir in os.listdir(os.path.join(oss_fuzz_base, 'projects')): - if not constants.PROJECT_BASE in project_dir: - continue - try: - tmp_idx = int(project_dir.replace(constants.PROJECT_BASE, '')) - max_idx = max(tmp_idx, max_idx) - except: - continue - return f'{constants.PROJECT_BASE}{max_idx+1}' - - -def copy_result_to_out(project_generated, - oss_fuzz_base, - output, - from_agent=False, - project_name='') -> None: - """Copy raw results into an output directory and in a refined format.""" - # Go through the output - os.makedirs(output, exist_ok=True) - raw_result_dir = os.path.join(output, 'raw-results') - os.makedirs(raw_result_dir, exist_ok=True) - - if from_agent: - project_directory = os.path.join(oss_fuzz_base, 'projects', - project_generated) - else: - project_directory = os.path.join(oss_fuzz_base, 'build', 'out', - project_generated) - - if not os.path.isdir(project_directory): - logger.info('Could not find project %s', project_directory) - return - shutil.copytree(project_directory, - os.path.join(raw_result_dir, project_generated), - dirs_exist_ok=True) - - oss_fuzz_projects = os.path.join(output, 'oss-fuzz-projects') - os.makedirs(oss_fuzz_projects, exist_ok=True) - dst_dir = '' - if from_agent: - build_dir = os.path.join(raw_result_dir, project_generated) - if not os.path.isdir(build_dir): - return - - dst_project = f'{project_name.lower()}-agent' - dst_dir = os.path.join(oss_fuzz_projects, dst_project) - shutil.copytree(build_dir, dst_dir, dirs_exist_ok=True) - else: - report_txt = os.path.join(raw_result_dir, project_generated, - 'autogen-results', 'report.txt') - if not os.path.isfile(report_txt): - return - - with open(report_txt, 'r') as f: - for line in f: - if 'Analysing' in line: - project_name = line.split('/')[-1].replace('\n', '').lower() - if not project_name: - return - - idx = 0 - while True: - base_build_dir = f'empty-build-{idx}' - idx += 1 - - build_dir = os.path.join(raw_result_dir, project_generated, - base_build_dir) - if not os.path.isdir(build_dir): - break - - dst_project = f'{project_name}-{base_build_dir}' - dst_dir = os.path.join(oss_fuzz_projects, dst_project) - if os.path.isdir(dst_dir): - logger.info('Destination dir alrady exists: %s. 
Skipping', dst_dir) - continue - shutil.copytree(build_dir, dst_dir) - - if not dst_dir: - return - # Make sure project.yaml has correct language - is_c = False - for elem in os.listdir(dst_dir): - if elem.endswith('.c'): - is_c = True - if is_c: - lang = 'c' - else: - lang = 'c++' - dst_yaml = os.path.join(dst_dir, 'project.yaml') - with open(dst_yaml, 'r') as f: - content = f.read() - lines = [] - for line in content.split('\n'): - if 'language' in line: - lines.append(f'language: {lang}') + """Gets next OSS-Fuzz worker projecet.""" + max_idx = -1 + for project_dir in os.listdir(os.path.join(oss_fuzz_base, "projects")): + if not constants.PROJECT_BASE in project_dir: + continue + try: + tmp_idx = int(project_dir.replace(constants.PROJECT_BASE, "")) + max_idx = max(tmp_idx, max_idx) + except: + continue + return f"{constants.PROJECT_BASE}{max_idx+1}" + + +def copy_result_to_out( + project_generated, oss_fuzz_base, output, from_agent=False, project_name="" +) -> None: + """Copy raw results into an output directory and in a refined format.""" + # Go through the output + os.makedirs(output, exist_ok=True) + raw_result_dir = os.path.join(output, "raw-results") + os.makedirs(raw_result_dir, exist_ok=True) + + if from_agent: + project_directory = os.path.join(oss_fuzz_base, "projects", project_generated) else: - lines.append(line) - with open(dst_yaml, 'w') as f: - f.write('\n'.join(lines)) - - -def run_parallels(oss_fuzz_base, - target_repositories, - llm_model, - build_heuristics, - output, - parallel_jobs=6, - max_timeout=0): - """Run auto-gen on a list of projects in parallel. - - Parallelisation is done by way of threads. Practically - all of the computation will happen inside an OSS-Fuzz - Docker container and not within this Python script as such.""" - semaphore = threading.Semaphore(parallel_jobs) - jobs = [] - projects_generated = [] - for idx, target in enumerate(target_repositories): - worker_project_name = get_next_worker_project(oss_fuzz_base) - logger.info('Worker project name: %s', worker_project_name) - projects_generated.append(worker_project_name) - try: - setup_worker_project(oss_fuzz_base, worker_project_name, llm_model) - except subprocess.CalledProcessError: - logger.info('Project setup issue for %s', worker_project_name) - continue - proc = threading.Thread(target=run_on_targets, - args=(target, oss_fuzz_base, worker_project_name, - idx, llm_model, semaphore, build_heuristics, - output, max_timeout)) - jobs.append(proc) - proc.start() + project_directory = os.path.join( + oss_fuzz_base, "build", "out", project_generated + ) + + if not os.path.isdir(project_directory): + logger.info("Could not find project %s", project_directory) + return + shutil.copytree( + project_directory, + os.path.join(raw_result_dir, project_generated), + dirs_exist_ok=True, + ) + + oss_fuzz_projects = os.path.join(output, "oss-fuzz-projects") + os.makedirs(oss_fuzz_projects, exist_ok=True) + dst_dir = "" + if from_agent: + build_dir = os.path.join(raw_result_dir, project_generated) + if not os.path.isdir(build_dir): + return - for proc in jobs: - proc.join() + dst_project = f"{project_name.lower()}-agent" + dst_dir = os.path.join(oss_fuzz_projects, dst_project) + shutil.copytree(build_dir, dst_dir, dirs_exist_ok=True) + else: + report_txt = os.path.join( + raw_result_dir, project_generated, "autogen-results", "report.txt" + ) + if not os.path.isfile(report_txt): + return + + with open(report_txt, "r") as f: + for line in f: + if "Analysing" in line: + project_name = 
line.split("/")[-1].replace("\n", "").lower() + if not project_name: + return + + idx = 0 + while True: + base_build_dir = f"empty-build-{idx}" + idx += 1 + + build_dir = os.path.join(raw_result_dir, project_generated, base_build_dir) + if not os.path.isdir(build_dir): + break + + dst_project = f"{project_name}-{base_build_dir}" + dst_dir = os.path.join(oss_fuzz_projects, dst_project) + if os.path.isdir(dst_dir): + logger.info("Destination dir alrady exists: %s. Skipping", dst_dir) + continue + shutil.copytree(build_dir, dst_dir) + + if not dst_dir: + return + # Make sure project.yaml has correct language + is_c = False + for elem in os.listdir(dst_dir): + if elem.endswith(".c"): + is_c = True + if is_c: + lang = "c" + else: + lang = "c++" + dst_yaml = os.path.join(dst_dir, "project.yaml") + with open(dst_yaml, "r") as f: + content = f.read() + lines = [] + for line in content.split("\n"): + if "language" in line: + lines.append(f"language: {lang}") + else: + lines.append(line) + with open(dst_yaml, "w") as f: + f.write("\n".join(lines)) + + +def run_parallels( + oss_fuzz_base, + target_repositories, + llm_model, + build_heuristics, + output, + parallel_jobs=6, + max_timeout=0, +): + """Run auto-gen on a list of projects in parallel. + + Parallelisation is done by way of threads. Practically + all of the computation will happen inside an OSS-Fuzz + Docker container and not within this Python script as such.""" + semaphore = threading.Semaphore(parallel_jobs) + jobs = [] + projects_generated = [] + for idx, target in enumerate(target_repositories): + worker_project_name = get_next_worker_project(oss_fuzz_base) + logger.info("Worker project name: %s", worker_project_name) + projects_generated.append(worker_project_name) + try: + setup_worker_project(oss_fuzz_base, worker_project_name, llm_model) + except subprocess.CalledProcessError: + logger.info("Project setup issue for %s", worker_project_name) + continue + proc = threading.Thread( + target=run_on_targets, + args=( + target, + oss_fuzz_base, + worker_project_name, + idx, + llm_model, + semaphore, + build_heuristics, + output, + max_timeout, + ), + ) + jobs.append(proc) + proc.start() + + for proc in jobs: + proc.join() def run_agent(target_repositories: List[str], args: argparse.Namespace): - """Generates build script and fuzzer harnesses for a GitHub repository using - llm agent approach.""" - # Process default arguments - oss_fuzz_base = os.path.abspath(args.oss_fuzz) - - # Set OSS_FUZZ_DIR in oss_fuzz_checkout as the agent will use this module - # for dealing with the generated project. 
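run_parallels above bounds concurrency with a threading.Semaphore that each worker thread acquires on entry and releases on exit. A minimal, generic sketch of that pattern (run_bounded and the demo worker are hypothetical, not part of this change):

import threading
import time
from typing import Callable, Iterable


def run_bounded(
    targets: Iterable[str], worker: Callable[[str], None], max_jobs: int = 6
) -> None:
    """Run worker(target) for every target with at most max_jobs running at once."""
    semaphore = threading.Semaphore(max_jobs)

    def _entry(target: str) -> None:
        semaphore.acquire()
        try:
            worker(target)
        finally:
            semaphore.release()

    threads = [threading.Thread(target=_entry, args=(t,)) for t in targets]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()


if __name__ == "__main__":

    def demo_worker(repo: str) -> None:
        print("building", repo)
        time.sleep(0.1)

    run_bounded([f"repo-{i}" for i in range(10)], demo_worker, max_jobs=3)

As in run_parallels, every thread is created up front and blocks on the semaphore inside the worker; this is simpler than a thread pool at the cost of one OS thread per target.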
- - oss_fuzz_checkout.OSS_FUZZ_DIR = oss_fuzz_base - work_dirs = WorkDirs(args.work_dirs, keep=True) - - # All agents - llm_agents = [ - llm_agent.AutoDiscoveryBuildScriptAgent, - llm_agent.BuildSystemBuildScriptAgent, - ] - - for target_repository in target_repositories: - logger.info('Target repository: %s', target_repository) - # Prepare environment - worker_project_name = get_next_worker_project(oss_fuzz_base) - try: - language = setup_worker_project(oss_fuzz_base, worker_project_name, - args.model, target_repository, True, - os.path.abspath(args.work_dirs)) - except subprocess.CalledProcessError: - logger.info('Issues setting up %s', target_repository) - continue - benchmark = Benchmark(worker_project_name, worker_project_name, '', '', '', - '', [], '') - - for llm_agent_ctr in llm_agents: - build_script = '' - harness = '' - - # Prepare new LLM model - llm = models.LLM.setup( - ai_binary=os.getenv('AI_BINARY', ''), - name=args.model, - max_tokens=4096, - num_samples=1, - temperature=0.4, - temperature_list=[], - ) - llm.MAX_INPUT_TOKEN = constants.MAX_PROMPT_LENGTH - - logger.info('Agent: %s.', llm_agent_ctr.__name__) - agent = llm_agent_ctr(trial=1, - llm=llm, - args=args, - github_url=target_repository, - language=language) - result_history = [ - Result(benchmark=benchmark, trial=1, work_dirs=work_dirs) - ] - - try: - build_result = agent.execute(result_history) - except OpenAIError: - logger.info(('Round 1 build script generation failed for project %s' - ' with openai errors'), target_repository) - break - except subprocess.CalledProcessError: - logger.info('Issue running agent, %s', target_repository) - break - - if build_result.compiles: - build_script = build_result.build_script_source - harness = build_result.fuzz_target_source - - logger.info('Build script generation success for project %s', - target_repository) - - # Update build script - build_script_path = os.path.join(oss_fuzz_base, 'projects', - worker_project_name, 'build.sh') - with open(build_script_path, 'w') as f: - f.write(build_script) - - # Update harness code - _, _, harness_name, default_code = file_utils.get_language_defaults( - language) - if not harness: - harness = default_code - - harness_path = os.path.join(oss_fuzz_base, 'projects', - worker_project_name, - harness_name.split('/')[-1]) - with open(harness_path, 'w') as f: - f.write(harness) - - # Copy result to out - copy_result_to_out(worker_project_name, oss_fuzz_base, args.out, True, - target_repository.split('/')[-1]) - break - - # Clean up workdir - if os.path.isdir(args.work_dirs): - shutil.rmtree(args.work_dirs) + """Generates build script and fuzzer harnesses for a GitHub repository using + llm agent approach.""" + # Process default arguments + oss_fuzz_base = os.path.abspath(args.oss_fuzz) + + # Set OSS_FUZZ_DIR in oss_fuzz_checkout as the agent will use this module + # for dealing with the generated project. 
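run_agent above tries a list of build-script agents in order and keeps the first result that compiles. A self-contained sketch of that selection pattern with dummy agents (the Outcome class and the Demo agents are hypothetical stand-ins, not the real agent classes):

from dataclasses import dataclass
from typing import List, Optional, Type


@dataclass
class Outcome:
    """Hypothetical stand-in for the agent build result used above."""

    compiles: bool
    build_script: str = ""
    harness: str = ""


class DemoDiscoveryAgent:
    """Hypothetical agent that fails to produce a compiling build."""

    def execute(self) -> Outcome:
        return Outcome(compiles=False)


class DemoBuildSystemAgent:
    """Hypothetical agent that succeeds."""

    def execute(self) -> Outcome:
        return Outcome(compiles=True, build_script="#!/bin/bash -eu\nmake\n")


def first_compiling_outcome(agent_classes: List[Type]) -> Optional[Outcome]:
    """Instantiate each agent in order and return the first compiling result."""
    for agent_cls in agent_classes:
        outcome = agent_cls().execute()
        if outcome.compiles:
            return outcome
    return None


if __name__ == "__main__":
    result = first_compiling_outcome([DemoDiscoveryAgent, DemoBuildSystemAgent])
    print("success" if result else "no compiling build")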
+ + oss_fuzz_checkout.OSS_FUZZ_DIR = oss_fuzz_base + work_dirs = WorkDirs(args.work_dirs, keep=True) + + # All agents + llm_agents = [ + llm_agent.AutoDiscoveryBuildScriptAgent, + llm_agent.BuildSystemBuildScriptAgent, + ] + + for target_repository in target_repositories: + logger.info("Target repository: %s", target_repository) + # Prepare environment + worker_project_name = get_next_worker_project(oss_fuzz_base) + try: + language = setup_worker_project( + oss_fuzz_base, + worker_project_name, + args.model, + target_repository, + True, + os.path.abspath(args.work_dirs), + ) + except subprocess.CalledProcessError: + logger.info("Issues setting up %s", target_repository) + continue + benchmark = Benchmark( + worker_project_name, worker_project_name, "", "", "", "", [], "" + ) + + for llm_agent_ctr in llm_agents: + build_script = "" + harness = "" + + # Prepare new LLM model + llm = models.LLM.setup( + ai_binary=os.getenv("AI_BINARY", ""), + name=args.model, + max_tokens=4096, + num_samples=1, + temperature=0.4, + temperature_list=[], + ) + llm.MAX_INPUT_TOKEN = constants.MAX_PROMPT_LENGTH + + logger.info("Agent: %s.", llm_agent_ctr.__name__) + agent = llm_agent_ctr( + trial=1, + llm=llm, + args=args, + github_url=target_repository, + language=language, + ) + result_history = [Result(benchmark=benchmark, trial=1, work_dirs=work_dirs)] + + try: + build_result = agent.execute(result_history) + except OpenAIError: + logger.info( + ( + "Round 1 build script generation failed for project %s" + " with openai errors" + ), + target_repository, + ) + break + except subprocess.CalledProcessError: + logger.info("Issue running agent, %s", target_repository) + break + + if build_result.compiles: + build_script = build_result.build_script_source + harness = build_result.fuzz_target_source + + logger.info( + "Build script generation success for project %s", target_repository + ) + + # Update build script + build_script_path = os.path.join( + oss_fuzz_base, "projects", worker_project_name, "build.sh" + ) + with open(build_script_path, "w") as f: + f.write(build_script) + + # Update harness code + _, _, harness_name, default_code = file_utils.get_language_defaults( + language + ) + if not harness: + harness = default_code + + harness_path = os.path.join( + oss_fuzz_base, + "projects", + worker_project_name, + harness_name.split("/")[-1], + ) + with open(harness_path, "w") as f: + f.write(harness) + + # Copy result to out + copy_result_to_out( + worker_project_name, + oss_fuzz_base, + args.out, + True, + target_repository.split("/")[-1], + ) + break + + # Clean up workdir + if os.path.isdir(args.work_dirs): + shutil.rmtree(args.work_dirs) def parse_commandline(): - """Parse the commandline.""" - parser = argparse.ArgumentParser() - parser.add_argument('--oss-fuzz', '-of', help='OSS-Fuzz base') - parser.add_argument('--input', '-i', help='Input to analyze') - parser.add_argument('--out', - '-o', - default='Generated builds', - help='Directory to store output.') - parser.add_argument('--silent', - '-s', - help='Disable logging in subprocess.', - action='store_true') - parser.add_argument('--build-heuristics', - '-b', - help='Comma-separated string of build heuristics to use', - default='all') - parser.add_argument( - '--model', - '-m', - help=f'LLM model to use. 
Available: {str(constants.MODELS)}', - type=str) - parser.add_argument('--agent', - '-a', - help='Use LLM Agent Builder or not.', - action='store_true') - parser.add_argument('--max-round', - '-mr', - help='Max round of trial for the llm build script agent.', - type=int, - default=10) - parser.add_argument('--work-dirs', - '-w', - help='Working directory path.', - type=str, - default='./work_dirs') - - return parser.parse_args() + """Parse the commandline.""" + parser = argparse.ArgumentParser() + parser.add_argument("--oss-fuzz", "-of", help="OSS-Fuzz base") + parser.add_argument("--input", "-i", help="Input to analyze") + parser.add_argument( + "--out", "-o", default="Generated builds", help="Directory to store output." + ) + parser.add_argument( + "--silent", "-s", help="Disable logging in subprocess.", action="store_true" + ) + parser.add_argument( + "--build-heuristics", + "-b", + help="Comma-separated string of build heuristics to use", + default="all", + ) + parser.add_argument( + "--model", + "-m", + help=f"LLM model to use. Available: {str(constants.MODELS)}", + type=str, + ) + parser.add_argument( + "--agent", "-a", help="Use LLM Agent Builder or not.", action="store_true" + ) + parser.add_argument( + "--max-round", + "-mr", + help="Max round of trial for the llm build script agent.", + type=int, + default=10, + ) + parser.add_argument( + "--work-dirs", + "-w", + help="Working directory path.", + type=str, + default="./work_dirs", + ) + + return parser.parse_args() def setup_logging(): - logging.basicConfig(level=logging.INFO, format=LOG_FMT) + logging.basicConfig(level=logging.INFO, format=LOG_FMT) def extract_target_repositories(target_input) -> list[str]: - if not target_input: - return [] + if not target_input: + return [] - if os.path.isfile(target_input): - target_repositories = read_targets_file(target_input) - else: - target_repositories = [target_input] + if os.path.isfile(target_input): + target_repositories = read_targets_file(target_input) + else: + target_repositories = [target_input] - refined_targets = [] - for repo in target_repositories: - # Remove trailing / - while repo.endswith('/'): - repo = repo[:-1] - refined_targets.append(repo) - logger.info(refined_targets) + refined_targets = [] + for repo in target_repositories: + # Remove trailing / + while repo.endswith("/"): + repo = repo[:-1] + refined_targets.append(repo) + logger.info(refined_targets) - return refined_targets + return refined_targets def main(): - global silent_global + global silent_global - args = parse_commandline() + args = parse_commandline() - setup_logging() - target_repositories = extract_target_repositories(args.input) - silent_global = args.silent + setup_logging() + target_repositories = extract_target_repositories(args.input) + silent_global = args.silent - if args.agent: - run_agent(target_repositories, args) - else: - run_parallels(os.path.abspath(args.oss_fuzz), target_repositories, - args.model, args.build_heuristics, args.out) + if args.agent: + run_agent(target_repositories, args) + else: + run_parallels( + os.path.abspath(args.oss_fuzz), + target_repositories, + args.model, + args.build_heuristics, + args.out, + ) -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/experimental/build_generator/templates.py b/experimental/build_generator/templates.py index 1cd536b303..85b326d21b 100644 --- a/experimental/build_generator/templates.py +++ b/experimental/build_generator/templates.py @@ -15,7 +15,7 @@ """Holds templates used by the auto-generator 
both inside and outside the OSS-Fuzz base builder.""" -OSS_FUZZ_LICENSE = '''# Copyright 2025 Google LLC. +OSS_FUZZ_LICENSE = """# Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,19 +30,25 @@ # limitations under the License. # ################################################################################ -''' +""" -EMPTY_OSS_FUZZ_BUILD = '''#!/bin/bash -eu -''' + OSS_FUZZ_LICENSE +EMPTY_OSS_FUZZ_BUILD = ( + """#!/bin/bash -eu +""" + + OSS_FUZZ_LICENSE +) -BASE_DOCKER_HEAD = OSS_FUZZ_LICENSE + ''' +BASE_DOCKER_HEAD = ( + OSS_FUZZ_LICENSE + + """ FROM gcr.io/oss-fuzz-base/base-builder RUN apt-get update && apt-get install -y make autoconf automake autopoint \\ libtool cmake pkg-config curl check libcpputest-dev \\ flex bison re2c protobuf-compiler uuid uuid-dev -''' +""" +) -CFLITE_TEMPLATE = '''name: ClusterFuzzLite PR fuzzing +CFLITE_TEMPLATE = """name: ClusterFuzzLite PR fuzzing on: workflow_dispatch: pull_request: @@ -72,11 +78,11 @@ mode: 'code-change' report-unreproducible-crashes: false sanitizer: ${{ matrix.sanitizer }} -''' +""" # Empty CPP harness that is used to confirm compilation when generating # auto-build scripts. -CPP_BASE_TEMPLATE = '''#include +CPP_BASE_TEMPLATE = """#include #include extern "C" int @@ -90,11 +96,11 @@ // end of fuzzer contents return 0; -}''' +}""" # Empty C harness that is used to confirm compilation when generating # auto-build scripts. -C_BASE_TEMPLATE = '''#include +C_BASE_TEMPLATE = """#include #include #include @@ -107,11 +113,13 @@ // end of fuzzer contents return 0; -}''' +}""" # Docker file used for starting the auto-gen workflow within an OSS-Fuzz # base-builder image. -AUTOGEN_DOCKER_FILE = BASE_DOCKER_HEAD + ''' +AUTOGEN_DOCKER_FILE = ( + BASE_DOCKER_HEAD + + """ RUN rm /usr/local/bin/cargo && \\ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \\ apt-get install -y cargo @@ -130,7 +138,8 @@ RUN python3 -m pip install pyyaml WORKDIR $SRC COPY build.sh $SRC/ -''' +""" +) EMPTY_PROJECT_YAML = """homepage: "https://github.com/google/oss-fuzz" language: c++ @@ -141,7 +150,9 @@ """ # Docker file used for OSS-Fuzz integrations. -CLEAN_OSS_FUZZ_DOCKER = BASE_DOCKER_HEAD + ''' {additional_packages} +CLEAN_OSS_FUZZ_DOCKER = ( + BASE_DOCKER_HEAD + + """ {additional_packages} COPY *.sh $SRC/ RUN mkdir -p {fuzzer_dir} COPY *.cpp *.c {fuzzer_dir} @@ -156,22 +167,26 @@ ENV FI_DISABLE_LIGHT=1 RUN git clone --recurse-submodules {repo_url} {project_repo_dir} WORKDIR $SRC/{project_repo_dir} -''' +""" +) -CLEAN_DOCKER_CFLITE = BASE_DOCKER_HEAD + ''' {additional_packages} +CLEAN_DOCKER_CFLITE = ( + BASE_DOCKER_HEAD + + """ {additional_packages} COPY . $SRC/{project_repo_dir} COPY .clusterfuzzlite/build.sh $SRC/build.sh COPY .clusterfuzzlite/*.cpp $SRC/ COPY .clusterfuzzlite/*.c $SRC/ WORKDIR $SRC/{project_repo_dir} -''' +""" +) # Template file for building LLM prompt -LLM_PRIMING = ''' +LLM_PRIMING = """ You are a developer wanting to build a given C/C++ projects. -''' +""" -LLM_PROBLEM = ''' +LLM_PROBLEM = """ You are tasked with generating a fuzzing harness and build script to fuzz a target project. Use the provided build system files to compile the project and link it with the fuzzing harness. 
### Output Format @@ -226,14 +241,14 @@ {HEADERS} -''' +""" -LLM_BUILD_FILE_TEMPLATE = ''' +LLM_BUILD_FILE_TEMPLATE = """ {PATH} {CONTENT} -''' +""" -LLM_RETRY = ''' +LLM_RETRY = """ I failed to build the project with the above provided build script. Please analyse the result and generate a new build script with the same assumption above. You must only returns the content of the build script and nothing else more as always. @@ -243,9 +258,9 @@ Here is a dump of the bash execution result. {BASH_RESULT} -''' +""" -LLM_AUTO_DISCOVERY = ''' +LLM_AUTO_DISCOVERY = """ You are tasked with generating a **build script** to compile and statically link a target project, and updating a **template fuzzing harness** by including relevant project headers. Do **not** modify the harness logic, only add `#include` statements. The source code is located at `$SRC/{PROJECT_NAME}` inside a Docker container running **Ubuntu 24.04**. The fuzzing harness template is at `$SRC/{FUZZING_FILE}` and is provided below. @@ -496,15 +511,15 @@ Your **first reply** must be a `` block to begin project exploration. Your **final reply** must include the `` block and, if the harness was modified, the `` block. -''' +""" -LLM_DOCKER_FEEDBACK = ''' +LLM_DOCKER_FEEDBACK = """ Here is the result of that command execution: {RESULT} -''' +""" -LLM_NO_VALID_TAG = ''' +LLM_NO_VALID_TAG = """ Your previous response is invalid. To be valid, the response must meet the following requirements regarding XML tags: @@ -516,12 +531,12 @@ - The tag is **required only if** the fuzzing harness has been modified. If included, it must contain the **entire source code** of the updated fuzzing harness, not just a diff or partial snippet. Do not include any content outside these XML tags. Revisit your output and regenerate it with these rules strictly followed. -''' +""" -LLM_MISSING_BINARY = ''' +LLM_MISSING_BINARY = """ The compiled binary was not found at `$OUT/{FUZZER_NAME}`. Please ensure that you use `-o $OUT/{FUZZER_NAME}` during the linking stage of the fuzzing harness. 
Below is the output from executing the previously generated build script for reference: {RESULT} -''' +""" diff --git a/experimental/end_to_end/cli.py b/experimental/end_to_end/cli.py index d20b9d241c..54887dca93 100644 --- a/experimental/end_to_end/cli.py +++ b/experimental/end_to_end/cli.py @@ -34,719 +34,777 @@ from llm_toolkit import models logger = logging.getLogger(name=__name__) -LOG_FMT = ('%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ' - ': %(funcName)s: %(message)s') +LOG_FMT = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " ": %(funcName)s: %(message)s" +) OFG_BASE_DIR = os.path.abspath( - os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..")) + os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..") +) def setup_workdirs(defined_dir): - """Sets up the working directory.""" - - if defined_dir: - workdir = defined_dir - else: - workdir = tempfile.mkdtemp() - logger.info('Using work directory: %s', workdir) - os.makedirs(workdir, exist_ok=True) - - # Clone two OSS-Fuzz projects - subprocess.check_call( - 'git clone https://github.com/google/oss-fuzz oss-fuzz-1', - shell=True, - cwd=workdir) - - # Clone another OSS-Fuzz, for OFG core - subprocess.check_call('git clone https://github.com/google/oss-fuzz oss-fuzz', - shell=True, - cwd=workdir) - os.mkdir(os.path.join(workdir, 'oss-fuzz', 'venv')) - - # Clone Fuzz Introspector - subprocess.check_call('git clone https://github.com/ossf/fuzz-introspector', - shell=True, - cwd=workdir) - - # Ensure fuzz introspector's requirements.txt is installed - subprocess.check_call('python3 -m pip install -r requirements.txt', - shell=True, - cwd=os.path.join(workdir, 'fuzz-introspector')) - subprocess.check_call('python3 -m pip install -r requirements.txt', - shell=True, - cwd=os.path.join(workdir, 'fuzz-introspector', 'tools', - 'web-fuzzing-introspection')) - return workdir + """Sets up the working directory.""" + + if defined_dir: + workdir = defined_dir + else: + workdir = tempfile.mkdtemp() + logger.info("Using work directory: %s", workdir) + os.makedirs(workdir, exist_ok=True) + + # Clone two OSS-Fuzz projects + subprocess.check_call( + "git clone https://github.com/google/oss-fuzz oss-fuzz-1", + shell=True, + cwd=workdir, + ) + + # Clone another OSS-Fuzz, for OFG core + subprocess.check_call( + "git clone https://github.com/google/oss-fuzz oss-fuzz", shell=True, cwd=workdir + ) + os.mkdir(os.path.join(workdir, "oss-fuzz", "venv")) + + # Clone Fuzz Introspector + subprocess.check_call( + "git clone https://github.com/ossf/fuzz-introspector", shell=True, cwd=workdir + ) + + # Ensure fuzz introspector's requirements.txt is installed + subprocess.check_call( + "python3 -m pip install -r requirements.txt", + shell=True, + cwd=os.path.join(workdir, "fuzz-introspector"), + ) + subprocess.check_call( + "python3 -m pip install -r requirements.txt", + shell=True, + cwd=os.path.join( + workdir, "fuzz-introspector", "tools", "web-fuzzing-introspection" + ), + ) + return workdir def _run_introspector_collection(runner_script, project, wd, semaphore): - """Run introspector on the given project.""" - semaphore.acquire() - - cmd = ['python3'] - cmd.append(runner_script) # introspector helper script - cmd.append('introspector') # force an introspector run - cmd.append(project) # target project - cmd.append('1') # run the harness for 1 second - cmd.append('--disable-webserver') # do not launch FI webapp - - try: - logger.info('Collecting introspector information on %s', project) - subprocess.check_call(' 
'.join(cmd), - shell=True, - cwd=wd, - stdout=subprocess.DEVNULL, - stderr=subprocess.STDOUT) - except subprocess.CalledProcessError: - pass - semaphore.release() - - -def extract_introspector_reports_for_benchmarks(projects_to_run, workdir, - parallel_build_jobs): - """Runs introspector through each report to collect program analysis data.""" - oss_fuzz_dir = os.path.join(workdir, 'oss-fuzz') - runner_script = os.path.join(workdir, 'fuzz-introspector', - 'oss_fuzz_integration', 'runner.py') - - semaphore = threading.Semaphore(parallel_build_jobs) - jobs = [] - - for project in projects_to_run: - proc = threading.Thread(target=_run_introspector_collection, - args=(runner_script, project, oss_fuzz_dir, - semaphore)) - jobs.append(proc) - proc.start() - - for proc in jobs: - proc.join() - - # Often the terminal will become corrupted after a lot of introspector runs. - # Call reset here to ensure we're in a safe state. - subprocess.check_call('reset', shell=True) + """Run introspector on the given project.""" + semaphore.acquire() + + cmd = ["python3"] + cmd.append(runner_script) # introspector helper script + cmd.append("introspector") # force an introspector run + cmd.append(project) # target project + cmd.append("1") # run the harness for 1 second + cmd.append("--disable-webserver") # do not launch FI webapp + + try: + logger.info("Collecting introspector information on %s", project) + subprocess.check_call( + " ".join(cmd), + shell=True, + cwd=wd, + stdout=subprocess.DEVNULL, + stderr=subprocess.STDOUT, + ) + except subprocess.CalledProcessError: + pass + semaphore.release() + + +def extract_introspector_reports_for_benchmarks( + projects_to_run, workdir, parallel_build_jobs +): + """Runs introspector through each report to collect program analysis data.""" + oss_fuzz_dir = os.path.join(workdir, "oss-fuzz") + runner_script = os.path.join( + workdir, "fuzz-introspector", "oss_fuzz_integration", "runner.py" + ) + + semaphore = threading.Semaphore(parallel_build_jobs) + jobs = [] + + for project in projects_to_run: + proc = threading.Thread( + target=_run_introspector_collection, + args=(runner_script, project, oss_fuzz_dir, semaphore), + ) + jobs.append(proc) + proc.start() + + for proc in jobs: + proc.join() + + # Often the terminal will become corrupted after a lot of introspector runs. + # Call reset here to ensure we're in a safe state. 
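Several helpers in this file repeat the same subprocess pattern: check_call with shell=True, an optional cwd, output silenced via DEVNULL, and CalledProcessError turned into a soft failure. A minimal sketch of a wrapper for that pattern (run_cmd is a hypothetical helper, not introduced by this change):

import subprocess
from typing import Optional


def run_cmd(cmd: str, cwd: Optional[str] = None, silent: bool = False) -> bool:
    """Run a shell command, optionally discarding output; report success as a bool."""
    kwargs = {}
    if silent:
        kwargs["stdout"] = subprocess.DEVNULL
        kwargs["stderr"] = subprocess.STDOUT
    try:
        subprocess.check_call(cmd, shell=True, cwd=cwd, **kwargs)
        return True
    except subprocess.CalledProcessError:
        return False


if __name__ == "__main__":
    print(run_cmd("echo hello", silent=True))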
+ subprocess.check_call("reset", shell=True) def shutdown_fi_webapp(): - """Shutsdown the FI webapp if it exists.""" - try: - subprocess.check_call('curl --silent http://localhost:8080/api/shutdown', - shell=True) - except subprocess.CalledProcessError: - pass + """Shutsdown the FI webapp if it exists.""" + try: + subprocess.check_call( + "curl --silent http://localhost:8080/api/shutdown", shell=True + ) + except subprocess.CalledProcessError: + pass def create_fi_db(workdir): - """Creates the FI webapp database""" - oss_fuzz_dir = os.path.join(workdir, 'oss-fuzz') - - fi_db_dir = os.path.join(workdir, 'fuzz-introspector', 'tools', - 'web-fuzzing-introspection', 'app', 'static', - 'assets', 'db') - cmd = ['python3'] - cmd.append('web_db_creator_from_summary.py') - cmd.append('--local-oss-fuzz') - cmd.append(oss_fuzz_dir) - try: - logger.info('Creating fuzz introspector database') - subprocess.check_call(' '.join(cmd), - shell=True, - cwd=fi_db_dir, - stdout=subprocess.DEVNULL, - stderr=subprocess.STDOUT) - logger.info('Created database successfully') - except subprocess.CalledProcessError: - logger.info('Failed creation of DB') + """Creates the FI webapp database""" + oss_fuzz_dir = os.path.join(workdir, "oss-fuzz") + + fi_db_dir = os.path.join( + workdir, + "fuzz-introspector", + "tools", + "web-fuzzing-introspection", + "app", + "static", + "assets", + "db", + ) + cmd = ["python3"] + cmd.append("web_db_creator_from_summary.py") + cmd.append("--local-oss-fuzz") + cmd.append(oss_fuzz_dir) + try: + logger.info("Creating fuzz introspector database") + subprocess.check_call( + " ".join(cmd), + shell=True, + cwd=fi_db_dir, + stdout=subprocess.DEVNULL, + stderr=subprocess.STDOUT, + ) + logger.info("Created database successfully") + except subprocess.CalledProcessError: + logger.info("Failed creation of DB") def launch_fi_webapp(workdir): - """Launches webapp so OFG can query projects.""" - logger.info('Launching webapp') - oss_fuzz_dir = os.path.join(workdir, 'oss-fuzz') - fi_webapp_dir = os.path.join(workdir, 'fuzz-introspector', 'tools', - 'web-fuzzing-introspection', 'app') - environ = os.environ.copy() - environ['FUZZ_INTROSPECTOR_LOCAL_OSS_FUZZ'] = oss_fuzz_dir - cmd = ['python3'] - cmd.append('main.py &') - - subprocess.check_call(' '.join(cmd), - shell=True, - cwd=fi_webapp_dir, - env=environ, - stdout=subprocess.DEVNULL, - stderr=subprocess.STDOUT) + """Launches webapp so OFG can query projects.""" + logger.info("Launching webapp") + oss_fuzz_dir = os.path.join(workdir, "oss-fuzz") + fi_webapp_dir = os.path.join( + workdir, "fuzz-introspector", "tools", "web-fuzzing-introspection", "app" + ) + environ = os.environ.copy() + environ["FUZZ_INTROSPECTOR_LOCAL_OSS_FUZZ"] = oss_fuzz_dir + cmd = ["python3"] + cmd.append("main.py &") + + subprocess.check_call( + " ".join(cmd), + shell=True, + cwd=fi_webapp_dir, + env=environ, + stdout=subprocess.DEVNULL, + stderr=subprocess.STDOUT, + ) def wait_until_fi_webapp_is_launched(): - """Return when the webapp has started""" - logger.info('Waiting for the webapp to start') - - sec_to_wait = 10 - for _ in range(10): - time.sleep(sec_to_wait) - - resp = requests.get('http://127.0.0.1:8080', timeout=10) - if 'Fuzzing' in resp.text: - return - # If this is reached then the webapp likely didn't start. - # Exit. 
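wait_until_fi_webapp_is_launched polls the local web UI until it responds with the expected text. A slightly more defensive sketch of the same idea that also tolerates connection errors while the server is still starting; the URL, marker text and timings are illustrative:

import time

import requests


def wait_for_http(url: str, marker: str, attempts: int = 10, delay: float = 10.0) -> bool:
    """Poll url until the response body contains marker, or give up after attempts."""
    for _ in range(attempts):
        time.sleep(delay)
        try:
            resp = requests.get(url, timeout=10)
        except requests.RequestException:
            continue  # server not accepting connections yet; retry
        if marker in resp.text:
            return True
    return False


if __name__ == "__main__":
    if not wait_for_http("http://127.0.0.1:8080", "Fuzzing", delay=1.0):
        print("webapp did not come up")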
- logger.info('Could not start FI webapp') - sys.exit(0) - - -def run_ofg_generation(projects_to_run, workdir, args, target_benchmarks=''): - """Runs harness generation""" - logger.info('Running OFG experiment: %s', os.getcwd()) - oss_fuzz_dir = os.path.join(workdir, 'oss-fuzz') - - cmd = ['python3', os.path.join(OFG_BASE_DIR, 'run_all_experiments.py')] - cmd.append('--model') - cmd.append(args.model) - - if not target_benchmarks: - cmd.append('-g') - cmd.append(args.benchmark_oracles) - cmd.append('-gp') - cmd.append(','.join(projects_to_run)) - cmd.append('-gm') - cmd.append(str(args.generate_benchmarks_max)) - else: - cmd.append('-b') - cmd.append(target_benchmarks) - cmd.append('--context') - cmd.append('-of') - cmd.append(oss_fuzz_dir) - cmd.append('-e') - cmd.append('http://127.0.0.1:8080/api') - cmd.append('-mr') - cmd.append(str(args.max_round)) - if args.hg_agent: - cmd.append('--agent') - - environ = os.environ.copy() - - environ['LLM_NUM_EVA'] = '4' - environ['LLM_NUM_EXP'] = '4' - environ['OFG_CLEAN_UP_OSS_FUZZ'] = '0' - environ['OFG_USE_CACHING'] = '0' - - subprocess.check_call(' '.join(cmd), shell=True, env=environ) + """Return when the webapp has started""" + logger.info("Waiting for the webapp to start") + + sec_to_wait = 10 + for _ in range(10): + time.sleep(sec_to_wait) + + resp = requests.get("http://127.0.0.1:8080", timeout=10) + if "Fuzzing" in resp.text: + return + # If this is reached then the webapp likely didn't start. + # Exit. + logger.info("Could not start FI webapp") + sys.exit(0) + + +def run_ofg_generation(projects_to_run, workdir, args, target_benchmarks=""): + """Runs harness generation""" + logger.info("Running OFG experiment: %s", os.getcwd()) + oss_fuzz_dir = os.path.join(workdir, "oss-fuzz") + + cmd = ["python3", os.path.join(OFG_BASE_DIR, "run_all_experiments.py")] + cmd.append("--model") + cmd.append(args.model) + + if not target_benchmarks: + cmd.append("-g") + cmd.append(args.benchmark_oracles) + cmd.append("-gp") + cmd.append(",".join(projects_to_run)) + cmd.append("-gm") + cmd.append(str(args.generate_benchmarks_max)) + else: + cmd.append("-b") + cmd.append(target_benchmarks) + cmd.append("--context") + cmd.append("-of") + cmd.append(oss_fuzz_dir) + cmd.append("-e") + cmd.append("http://127.0.0.1:8080/api") + cmd.append("-mr") + cmd.append(str(args.max_round)) + if args.hg_agent: + cmd.append("--agent") + + environ = os.environ.copy() + + environ["LLM_NUM_EVA"] = "4" + environ["LLM_NUM_EXP"] = "4" + environ["OFG_CLEAN_UP_OSS_FUZZ"] = "0" + environ["OFG_USE_CACHING"] = "0" + + subprocess.check_call(" ".join(cmd), shell=True, env=environ) def copy_generated_projects_to_harness_gen(out_gen, workdir): - """Copies projects from build generation ready for harness generation.""" - projects_dir = os.path.join(out_gen, 'oss-fuzz-projects') - if not os.path.isdir(projects_dir): - logger.info('Found no projects.') - return set() - - # Copy projects over - projects_to_run = [] - for project in os.listdir(projects_dir): - dst = os.path.join(workdir, 'oss-fuzz', 'projects', project) - if os.path.isdir(dst): - shutil.rmtree(dst) - logger.info('Copying: %s :: %s', os.path.join(projects_dir, project), - os.path.join(workdir, 'oss-fuzz', 'projects', project)) - shutil.copytree(os.path.join(projects_dir, project), - os.path.join(workdir, 'oss-fuzz', 'projects', project)) - projects_to_run.append(project) - return projects_to_run + """Copies projects from build generation ready for harness generation.""" + projects_dir = os.path.join(out_gen, "oss-fuzz-projects") + 
if not os.path.isdir(projects_dir): + logger.info("Found no projects.") + return set() + + # Copy projects over + projects_to_run = [] + for project in os.listdir(projects_dir): + dst = os.path.join(workdir, "oss-fuzz", "projects", project) + if os.path.isdir(dst): + shutil.rmtree(dst) + logger.info( + "Copying: %s :: %s", + os.path.join(projects_dir, project), + os.path.join(workdir, "oss-fuzz", "projects", project), + ) + shutil.copytree( + os.path.join(projects_dir, project), + os.path.join(workdir, "oss-fuzz", "projects", project), + ) + projects_to_run.append(project) + return projects_to_run def create_merged_oss_fuzz_projects( - projects_to_run, - workdir, - merged_project_out_dir='final-oss-fuzz-projects') -> None: - """Create OSS-Fuzz projects using successful harnesses.""" - - logger.info('Merging harnesses for the following projects: %s', - str(projects_to_run)) - logger.info('Writing results in %s', merged_project_out_dir) - - # Get list of projects created auto-building for. - generated_projects = [] - for project_name in projects_to_run: - project_yaml = os.path.join(workdir, 'oss-fuzz', 'projects', project_name, - 'project.yaml') - if not os.path.isfile(project_yaml): - continue - with open(project_yaml, 'r', encoding='utf-8') as f: - project_dict = yaml.safe_load(f) - - generated_projects.append({ - 'name': project_name, - 'language': project_dict['language'] - }) - - # Iterate results and copy fuzz harnesses into dedicated project folder. - results_dir = 'results' - if not os.path.isdir(results_dir): - logger.info('No results identified') - return - - for result in os.listdir(results_dir): - # Find project name - project = {} - for project_gen in generated_projects: - if result.startswith(f'output-{project_gen["name"]}'): - project = project_gen - if not project: - continue - - # Copy the harness over - #if not os.path.isdir('final-oss-fuzz-projects'): - # os.makedirs('final-oss-fuzz-projects') - project_dir = os.path.join(merged_project_out_dir, project['name']) - #if not os.path.isdir(project_dir): - # os.makedirs(project_dir) - os.makedirs(project_dir, exist_ok=True) - - # Check if it was successful - idx_to_copy = '' - status_base = os.path.join('results', result, 'status') - for idx in sorted(os.listdir(status_base)): - id_path = os.path.join(status_base, idx) - if not os.path.isdir(id_path): - continue - result_json = os.path.join(id_path, 'result.json') - if not os.path.isfile(result_json): - continue - with open(result_json, 'r') as f: - json_dict = json.loads(f.read()) - if json_dict['compiles']: - idx_to_copy = idx - break - - if not idx_to_copy: - logger.info('Did not find a harness to copy') - continue - logger.debug('Copying idx: %s', idx_to_copy) - - # Copy over the harness - fuzz_src = os.path.join('results', result, 'fuzz_targets', - f'{idx_to_copy}.fuzz_target') - with open(fuzz_src, 'r') as f: - fuzz_content = f.read() - idx = 0 - - while True: - if 'extern \'C\'' in fuzz_content or 'std::' in fuzz_content: - fuzz_dst = os.path.join(project_dir, f'empty-fuzzer.{idx}.cpp') - else: - fuzz_dst = os.path.join(project_dir, f'empty-fuzzer.{idx}.c') - if not os.path.isfile(fuzz_dst): - break - idx += 1 - - # Copy the harness - build_src = os.path.join(workdir, 'oss-fuzz', 'projects', project['name'], - 'build.sh') - build_dst = os.path.join(project_dir, 'build.sh') - shutil.copy(build_src, build_dst) - - docker_src = os.path.join(workdir, 'oss-fuzz', 'projects', project['name'], - 'Dockerfile') - docker_dst = os.path.join(project_dir, 'Dockerfile') - 
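# Minimal sketch of the "first compiling harness" lookup performed above,
# factored into a helper. The results/<result>/status/<idx>/result.json layout
# and the 'compiles' key come from the surrounding code; the helper name is
# hypothetical.
import json
import os


def first_compiling_index(result_dir):
    """Return the earliest status index whose result.json reports a successful build."""
    status_base = os.path.join(result_dir, "status")
    for idx in sorted(os.listdir(status_base)):
        result_json = os.path.join(status_base, idx, "result.json")
        if not os.path.isfile(result_json):
            continue
        with open(result_json, "r", encoding="utf-8") as f:
            if json.load(f).get("compiles"):
                return idx
    return ""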
shutil.copy(docker_src, docker_dst) - - project_yaml_src = os.path.join(workdir, 'oss-fuzz', 'projects', - project['name'], 'project.yaml') - project_yaml_dst = os.path.join(project_dir, 'project.yaml') - shutil.copy(project_yaml_src, project_yaml_dst) - - shutil.copy(fuzz_src, fuzz_dst) + projects_to_run, workdir, merged_project_out_dir="final-oss-fuzz-projects" +) -> None: + """Create OSS-Fuzz projects using successful harnesses.""" + + logger.info( + "Merging harnesses for the following projects: %s", str(projects_to_run) + ) + logger.info("Writing results in %s", merged_project_out_dir) + + # Get list of projects created auto-building for. + generated_projects = [] + for project_name in projects_to_run: + project_yaml = os.path.join( + workdir, "oss-fuzz", "projects", project_name, "project.yaml" + ) + if not os.path.isfile(project_yaml): + continue + with open(project_yaml, "r", encoding="utf-8") as f: + project_dict = yaml.safe_load(f) + + generated_projects.append( + {"name": project_name, "language": project_dict["language"]} + ) + + # Iterate results and copy fuzz harnesses into dedicated project folder. + results_dir = "results" + if not os.path.isdir(results_dir): + logger.info("No results identified") + return + + for result in os.listdir(results_dir): + # Find project name + project = {} + for project_gen in generated_projects: + if result.startswith(f'output-{project_gen["name"]}'): + project = project_gen + if not project: + continue + + # Copy the harness over + # if not os.path.isdir('final-oss-fuzz-projects'): + # os.makedirs('final-oss-fuzz-projects') + project_dir = os.path.join(merged_project_out_dir, project["name"]) + # if not os.path.isdir(project_dir): + # os.makedirs(project_dir) + os.makedirs(project_dir, exist_ok=True) + + # Check if it was successful + idx_to_copy = "" + status_base = os.path.join("results", result, "status") + for idx in sorted(os.listdir(status_base)): + id_path = os.path.join(status_base, idx) + if not os.path.isdir(id_path): + continue + result_json = os.path.join(id_path, "result.json") + if not os.path.isfile(result_json): + continue + with open(result_json, "r") as f: + json_dict = json.loads(f.read()) + if json_dict["compiles"]: + idx_to_copy = idx + break + + if not idx_to_copy: + logger.info("Did not find a harness to copy") + continue + logger.debug("Copying idx: %s", idx_to_copy) + + # Copy over the harness + fuzz_src = os.path.join( + "results", result, "fuzz_targets", f"{idx_to_copy}.fuzz_target" + ) + with open(fuzz_src, "r") as f: + fuzz_content = f.read() + idx = 0 + + while True: + if "extern 'C'" in fuzz_content or "std::" in fuzz_content: + fuzz_dst = os.path.join(project_dir, f"empty-fuzzer.{idx}.cpp") + else: + fuzz_dst = os.path.join(project_dir, f"empty-fuzzer.{idx}.c") + if not os.path.isfile(fuzz_dst): + break + idx += 1 + + # Copy the harness + build_src = os.path.join( + workdir, "oss-fuzz", "projects", project["name"], "build.sh" + ) + build_dst = os.path.join(project_dir, "build.sh") + shutil.copy(build_src, build_dst) + + docker_src = os.path.join( + workdir, "oss-fuzz", "projects", project["name"], "Dockerfile" + ) + docker_dst = os.path.join(project_dir, "Dockerfile") + shutil.copy(docker_src, docker_dst) + + project_yaml_src = os.path.join( + workdir, "oss-fuzz", "projects", project["name"], "project.yaml" + ) + project_yaml_dst = os.path.join(project_dir, "project.yaml") + shutil.copy(project_yaml_src, project_yaml_dst) + + shutil.copy(fuzz_src, fuzz_dst) def _create_data_dir(workdir): - """Copy data from 
build generation to directory for cloud experimentation""" - dst_dir = _get_next_data_dst_dir() - oss_fuzz_build_out = os.path.join(workdir, 'oss-fuzz', 'build', 'out') - - # Copy OSS-Fuzz data - projects_to_copy = [] - out_folders = ['inspector', 'report', 'report_target', 'textcov_reports'] - for bp in os.listdir(oss_fuzz_build_out): - src_project = os.path.join(oss_fuzz_build_out, bp) - dst_project = os.path.join(dst_dir, 'oss-fuzz2', 'build', 'out', bp) - - # Make sure all directories are there - do_copy = True - for out_folder in out_folders: - if not os.path.isdir(os.path.join(src_project, out_folder)): - do_copy = False - if not do_copy: - continue - os.makedirs(dst_project, exist_ok=True) - - for out_folder in out_folders: - shutil.copytree(os.path.join(src_project, out_folder), - os.path.join(dst_project, out_folder)) - projects_to_copy.append(bp) - - os.makedirs(os.path.join(dst_dir, 'oss-fuzz2', 'projects'), exist_ok=True) - - for project in projects_to_copy: - p_src = os.path.join(workdir, 'oss-fuzz', 'projects', project) - p_dst = os.path.join(dst_dir, 'oss-fuzz2', 'projects', project) - shutil.copytree(p_src, p_dst) - - # Copy Fuzz Introspector data - fuzz_introspector_db_folder = os.path.join(workdir, 'fuzz-introspector', - 'tools', - 'web-fuzzing-introspection', 'app', - 'static', 'assets', 'db') - shutil.copytree(fuzz_introspector_db_folder, - os.path.join(dst_dir, 'fuzz_introspector_db')) - - # Delete .gitignore that may exist in the DB folder. We do this because the - # files are needed when uploaded to OFG. - gitignore_file = os.path.join(dst_dir, 'fuzz_introspector_db', '.gitignore') - if os.path.isfile(gitignore_file): - os.remove(gitignore_file) - - return dst_dir + """Copy data from build generation to directory for cloud experimentation""" + dst_dir = _get_next_data_dst_dir() + oss_fuzz_build_out = os.path.join(workdir, "oss-fuzz", "build", "out") + + # Copy OSS-Fuzz data + projects_to_copy = [] + out_folders = ["inspector", "report", "report_target", "textcov_reports"] + for bp in os.listdir(oss_fuzz_build_out): + src_project = os.path.join(oss_fuzz_build_out, bp) + dst_project = os.path.join(dst_dir, "oss-fuzz2", "build", "out", bp) + + # Make sure all directories are there + do_copy = True + for out_folder in out_folders: + if not os.path.isdir(os.path.join(src_project, out_folder)): + do_copy = False + if not do_copy: + continue + os.makedirs(dst_project, exist_ok=True) + + for out_folder in out_folders: + shutil.copytree( + os.path.join(src_project, out_folder), + os.path.join(dst_project, out_folder), + ) + projects_to_copy.append(bp) + + os.makedirs(os.path.join(dst_dir, "oss-fuzz2", "projects"), exist_ok=True) + + for project in projects_to_copy: + p_src = os.path.join(workdir, "oss-fuzz", "projects", project) + p_dst = os.path.join(dst_dir, "oss-fuzz2", "projects", project) + shutil.copytree(p_src, p_dst) + + # Copy Fuzz Introspector data + fuzz_introspector_db_folder = os.path.join( + workdir, + "fuzz-introspector", + "tools", + "web-fuzzing-introspection", + "app", + "static", + "assets", + "db", + ) + shutil.copytree( + fuzz_introspector_db_folder, os.path.join(dst_dir, "fuzz_introspector_db") + ) + + # Delete .gitignore that may exist in the DB folder. We do this because the + # files are needed when uploaded to OFG. 
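# Sketch of the "all required introspector outputs present" guard used in the
# copy loop of _create_data_dir above, written as a small predicate. The folder
# names mirror the out_folders list; the function itself is illustrative.
import os

REQUIRED_OUT_FOLDERS = ("inspector", "report", "report_target", "textcov_reports")


def has_all_introspector_outputs(project_out_dir):
    """True when every folder needed for the FI database exists for a project."""
    return all(
        os.path.isdir(os.path.join(project_out_dir, folder))
        for folder in REQUIRED_OUT_FOLDERS
    )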
+ gitignore_file = os.path.join(dst_dir, "fuzz_introspector_db", ".gitignore") + if os.path.isfile(gitignore_file): + os.remove(gitignore_file) + + return dst_dir def prepare_fuzz_introspector_db(out_gen, workdir, parallel_introspector_jobs): - # Run introspector collection on the generated projects - projects_to_run = copy_generated_projects_to_harness_gen(out_gen, workdir) - extract_introspector_reports_for_benchmarks(projects_to_run, workdir, - parallel_introspector_jobs) - - # Create a fuzz introspector database based on the projects in - # the working directory's OSS-Fuzz. - shutdown_fi_webapp() - create_fi_db(workdir) - - -def run_harness_generation(workdir, - args, - target_project='', - target_function=''): - """Runs harness generation based on the projects in `out_gen`""" - - # Read the json file from FI to get all current projects. - fi_project_json = os.path.join(workdir, 'fuzz-introspector', 'tools', - 'web-fuzzing-introspection', 'app', 'static', - 'assets', 'db', 'all-project-current.json') - if not os.path.isfile(fi_project_json): - logger.info('Did not find FI DB file.') - set() - - projects_to_run = [] - if target_project: - projects_to_run = [target_project] - else: - with open(fi_project_json, 'r') as f: - json_content = json.load(f) - for elem in json_content: - projects_to_run.append(elem['project_name']) - - # Launch the fuzz introspector webapp so it's ready for OFG core - shutdown_fi_webapp() - launch_fi_webapp(workdir) - wait_until_fi_webapp_is_launched() - dst_data_dir = _create_data_dir(workdir) - logger.info('Wrote data directory for OFG experiments in %s', dst_data_dir) - - # Generate benchmarks if asked to - if target_project and target_function: - logger.info('Generating benchmark for specific function') - introspector.set_introspector_endpoints('http://127.0.0.1:8080/api') - benchmark_dir = introspector.generate_benchmark_for_targeted_function( - target_project, target_function) - if not benchmark_dir: - logger.info('Failed to generated benchmarks.') - sys.exit(1) - else: - logger.info('Generating a broad set of benchmarks') - benchmark_dir = '' - - # Run OFG core using local OSS-Fuzz and local Fuzz Introspector. - run_ofg_generation(projects_to_run, workdir, args, benchmark_dir) - - create_merged_oss_fuzz_projects(projects_to_run, workdir) - return projects_to_run + # Run introspector collection on the generated projects + projects_to_run = copy_generated_projects_to_harness_gen(out_gen, workdir) + extract_introspector_reports_for_benchmarks( + projects_to_run, workdir, parallel_introspector_jobs + ) + + # Create a fuzz introspector database based on the projects in + # the working directory's OSS-Fuzz. + shutdown_fi_webapp() + create_fi_db(workdir) + + +def run_harness_generation(workdir, args, target_project="", target_function=""): + """Runs harness generation based on the projects in `out_gen`""" + + # Read the json file from FI to get all current projects. 
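# Hedged sketch of reading the project list that run_harness_generation
# consults: all-project-current.json is assumed (as in the code above) to be a
# list of objects carrying a 'project_name' key; returning an empty list when
# the file is missing is an illustrative choice.
import json
import os


def load_fi_project_names(fi_project_json):
    """Read project names from the local fuzz-introspector database dump."""
    if not os.path.isfile(fi_project_json):
        return []
    with open(fi_project_json, "r", encoding="utf-8") as f:
        return [elem["project_name"] for elem in json.load(f)]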
+ fi_project_json = os.path.join( + workdir, + "fuzz-introspector", + "tools", + "web-fuzzing-introspection", + "app", + "static", + "assets", + "db", + "all-project-current.json", + ) + if not os.path.isfile(fi_project_json): + logger.info("Did not find FI DB file.") + set() + + projects_to_run = [] + if target_project: + projects_to_run = [target_project] + else: + with open(fi_project_json, "r") as f: + json_content = json.load(f) + for elem in json_content: + projects_to_run.append(elem["project_name"]) + + # Launch the fuzz introspector webapp so it's ready for OFG core + shutdown_fi_webapp() + launch_fi_webapp(workdir) + wait_until_fi_webapp_is_launched() + dst_data_dir = _create_data_dir(workdir) + logger.info("Wrote data directory for OFG experiments in %s", dst_data_dir) + + # Generate benchmarks if asked to + if target_project and target_function: + logger.info("Generating benchmark for specific function") + introspector.set_introspector_endpoints("http://127.0.0.1:8080/api") + benchmark_dir = introspector.generate_benchmark_for_targeted_function( + target_project, target_function + ) + if not benchmark_dir: + logger.info("Failed to generated benchmarks.") + sys.exit(1) + else: + logger.info("Generating a broad set of benchmarks") + benchmark_dir = "" + + # Run OFG core using local OSS-Fuzz and local Fuzz Introspector. + run_ofg_generation(projects_to_run, workdir, args, benchmark_dir) + + create_merged_oss_fuzz_projects(projects_to_run, workdir) + return projects_to_run def setup_logging(): - """Initiate logging.""" - logging.basicConfig(level=logging.DEBUG, format=LOG_FMT) + """Initiate logging.""" + logging.basicConfig(level=logging.DEBUG, format=LOG_FMT) def _get_next_folder_in_idx(base_name): - """Get next pre-named work directory.""" - idx = 0 - while True: - if not os.path.isdir(f'{base_name}-{idx}'): - break - idx += 1 - return f'{base_name}-{idx}' + """Get next pre-named work directory.""" + idx = 0 + while True: + if not os.path.isdir(f"{base_name}-{idx}"): + break + idx += 1 + return f"{base_name}-{idx}" def get_next_out_folder(): - """Get next pre-named work directory.""" - return _get_next_folder_in_idx('generated-projects') + """Get next pre-named work directory.""" + return _get_next_folder_in_idx("generated-projects") def _get_next_data_dst_dir(): - """Gets next data dir""" - return _get_next_folder_in_idx('data-dir') + """Gets next data dir""" + return _get_next_folder_in_idx("data-dir") def _run_build_generation(workdir, out_folder, args): - """ Build script generation. """ - oss_fuzz_dir = os.path.join(workdir, 'oss-fuzz-1') - target_repositories = runner.extract_target_repositories(args.input) - if args.build_generation_mode == 'agent': - # Prepare arguments used deeper in OFG core. - # TODO(David) make this cleaner. - args.oss_fuzz = oss_fuzz_dir - args.work_dirs = 'work_dirs' - runner.run_agent(target_repositories, args) - elif args.build_generation_mode == 'template-based': - runner.run_parallels(os.path.abspath(oss_fuzz_dir), - target_repositories, - args.model, - 'all', - out_folder, - parallel_jobs=args.build_jobs, - max_timeout=args.build_timeout) - else: - logger.info('Unknown build generation mode: %s', args.build_generation_mode) - sys.exit(1) - - -def run_fuzz_introspector_db_creation(workdir, generated_builds, - parallel_build_jobs): - """Entrypoint for fuzz introspector database creation.""" - - workdir = os.path.abspath(workdir) - - # Create working directory if it doesn't exist. 
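# Usage sketch for the incrementing-suffix helpers defined above: repeated runs
# land in 'generated-projects-0', 'generated-projects-1', and so on. The
# standalone copy below mirrors _get_next_folder_in_idx purely for illustration.
import os


def next_indexed_dir(base_name):
    """Return the first '<base_name>-<idx>' path that does not exist yet."""
    idx = 0
    while os.path.isdir(f"{base_name}-{idx}"):
        idx += 1
    return f"{base_name}-{idx}"


# e.g. a fresh checkout yields 'generated-projects-0'; a second run, 'generated-projects-1'.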
- if not os.path.isdir(workdir): - workdir = setup_workdirs(workdir) - prepare_fuzz_introspector_db(generated_builds, workdir, parallel_build_jobs) + """Build script generation.""" + oss_fuzz_dir = os.path.join(workdir, "oss-fuzz-1") + target_repositories = runner.extract_target_repositories(args.input) + if args.build_generation_mode == "agent": + # Prepare arguments used deeper in OFG core. + # TODO(David) make this cleaner. + args.oss_fuzz = oss_fuzz_dir + args.work_dirs = "work_dirs" + runner.run_agent(target_repositories, args) + elif args.build_generation_mode == "template-based": + runner.run_parallels( + os.path.abspath(oss_fuzz_dir), + target_repositories, + args.model, + "all", + out_folder, + parallel_jobs=args.build_jobs, + max_timeout=args.build_timeout, + ) + else: + logger.info("Unknown build generation mode: %s", args.build_generation_mode) + sys.exit(1) + + +def run_fuzz_introspector_db_creation(workdir, generated_builds, parallel_build_jobs): + """Entrypoint for fuzz introspector database creation.""" + + workdir = os.path.abspath(workdir) + + # Create working directory if it doesn't exist. + if not os.path.isdir(workdir): + workdir = setup_workdirs(workdir) + prepare_fuzz_introspector_db(generated_builds, workdir, parallel_build_jobs) def run_build_generation(args): - """Generates builds and harnesses for repositories in input.""" + """Generates builds and harnesses for repositories in input.""" - # Prepare working directory. - workdir = setup_workdirs(args.workdir) + # Prepare working directory. + workdir = setup_workdirs(args.workdir) - abs_workdir = os.path.abspath(workdir) - if not args.out: - out_folder = get_next_out_folder() - else: - out_folder = args.out + abs_workdir = os.path.abspath(workdir) + if not args.out: + out_folder = get_next_out_folder() + else: + out_folder = args.out - _run_build_generation(abs_workdir, out_folder, args) + _run_build_generation(abs_workdir, out_folder, args) def run_cmd_fix_build(args): - """Command entrypoint for fixing OSS-Fuzz build scripts.""" - workdir = setup_workdirs(None) - abs_workdir = os.path.abspath(workdir) - oss_fuzz_dir = os.path.join(abs_workdir, 'oss-fuzz') - args.work_dirs = 'work_dirs' - build_fix.fix_build(args, oss_fuzz_dir) + """Command entrypoint for fixing OSS-Fuzz build scripts.""" + workdir = setup_workdirs(None) + abs_workdir = os.path.abspath(workdir) + oss_fuzz_dir = os.path.join(abs_workdir, "oss-fuzz") + args.work_dirs = "work_dirs" + build_fix.fix_build(args, oss_fuzz_dir) def run_cmd_harness_generation(args): - """Entrypoint for command for harness generation.""" + """Entrypoint for command for harness generation.""" - # Prepare working directory. - abs_workdir = os.path.abspath(args.workdir) + # Prepare working directory. + abs_workdir = os.path.abspath(args.workdir) - # Run harness generation. - projects_run = run_harness_generation(abs_workdir, args, args.project, - args.function_name) + # Run harness generation. + projects_run = run_harness_generation( + abs_workdir, args, args.project, args.function_name + ) - # Log results. - logger.info('Finished analysis') - logger.info('Projects generated (%d): ', len(projects_run)) + # Log results. + logger.info("Finished analysis") + logger.info("Projects generated (%d): ", len(projects_run)) def run_full(args): - """Generates builds and harnesses for repositories in input.""" + """Generates builds and harnesses for repositories in input.""" - # Prepare working directory. - workdir = setup_workdirs(args.workdir) + # Prepare working directory. 
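# Hedged aside on _run_build_generation above: the agent / template-based
# branching can also be routed through a mapping, which keeps the unknown-mode
# error path in one place. The handler callables below are stand-ins, not
# functions from this module.
def dispatch_build_generation(mode, agent_fn, template_fn):
    """Route a --build-generation-mode value to the matching generator callable."""
    handlers = {
        "agent": agent_fn,
        "template-based": template_fn,
    }
    handler = handlers.get(mode)
    if handler is None:
        raise SystemExit(f"Unknown build generation mode: {mode}")
    return handler()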
+ workdir = setup_workdirs(args.workdir) - abs_workdir = os.path.abspath(workdir) - if not args.out: - out_folder = get_next_out_folder() - else: - out_folder = args.out + abs_workdir = os.path.abspath(workdir) + if not args.out: + out_folder = get_next_out_folder() + else: + out_folder = args.out - _run_build_generation(abs_workdir, out_folder, args) + _run_build_generation(abs_workdir, out_folder, args) - # Prepare fuzz introspector database. - prepare_fuzz_introspector_db(out_folder, abs_workdir, args.build_jobs) + # Prepare fuzz introspector database. + prepare_fuzz_introspector_db(out_folder, abs_workdir, args.build_jobs) - # Run harness generation. - projects_run = run_harness_generation(abs_workdir, args) + # Run harness generation. + projects_run = run_harness_generation(abs_workdir, args) - # Log results. - logger.info('Finished analysis') - logger.info('Projects generated (%d): ', len(projects_run)) + # Log results. + logger.info("Finished analysis") + logger.info("Projects generated (%d): ", len(projects_run)) def _add_base_build_gen_arguments(parser): - """Adds base arguments for build generation.""" - parser.add_argument('--build-generation-mode', - '-bgm', - default='agent', - help='Build generation mode. Defines how the build ' - 'generation is done. ' - 'Available modes: agent, template-based.') - parser.add_argument( - '--input', - '-i', - help=('Input to analyze. This can be either a URL to a git repository ' - 'or a file with each line being a URL to a git reopsitory.')) - parser.add_argument('--model', - '-m', - help=('Models available: ' - f'{", ".join(models.LLM.all_llm_names())}.'), - type=str) - parser.add_argument('--build-jobs', - help='Parallel build-generator jobs to run.', - default=2, - type=int) - parser.add_argument( - '--build-timeout', - help='Timeout for build generation per project, in seconds.', - default=0, - type=int) - parser.add_argument('-w', '--workdir', help='Work directory to use') + """Adds base arguments for build generation.""" + parser.add_argument( + "--build-generation-mode", + "-bgm", + default="agent", + help="Build generation mode. Defines how the build " + "generation is done. " + "Available modes: agent, template-based.", + ) + parser.add_argument( + "--input", + "-i", + help=( + "Input to analyze. This can be either a URL to a git repository " + "or a file with each line being a URL to a git reopsitory." 
+ ), + ) + parser.add_argument( + "--model", + "-m", + help=("Models available: " f'{", ".join(models.LLM.all_llm_names())}.'), + type=str, + ) + parser.add_argument( + "--build-jobs", + help="Parallel build-generator jobs to run.", + default=2, + type=int, + ) + parser.add_argument( + "--build-timeout", + help="Timeout for build generation per project, in seconds.", + default=0, + type=int, + ) + parser.add_argument("-w", "--workdir", help="Work directory to use") def _add_base_harness_gen_arguments(parser): - """Adds base arguments for harness generation.""" - parser.add_argument('--hg-agent', - '-ha', - help='Enable agent harness generation', - action='store_true') - parser.add_argument('-gm', - '--generate-benchmarks-max', - help='Max targets to generate per benchmark heuristic.', - type=int, - default=5) - parser.add_argument('-mr', - '--max-round', - type=int, - default=5, - help='Max trial round for agents.') - parser.add_argument( - '--benchmark-oracles', - default=('far-reach-low-coverage,low-cov-with-fuzz-keyword,' - 'easy-params-far-reach,test-migration')) + """Adds base arguments for harness generation.""" + parser.add_argument( + "--hg-agent", "-ha", help="Enable agent harness generation", action="store_true" + ) + parser.add_argument( + "-gm", + "--generate-benchmarks-max", + help="Max targets to generate per benchmark heuristic.", + type=int, + default=5, + ) + parser.add_argument( + "-mr", "--max-round", type=int, default=5, help="Max trial round for agents." + ) + parser.add_argument( + "--benchmark-oracles", + default=( + "far-reach-low-coverage,low-cov-with-fuzz-keyword," + "easy-params-far-reach,test-migration" + ), + ) def parse_commandline(): - """Parse the commandline.""" - parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers(dest='command') - - # Parser for fixing OSS-Fuzz build - fix_build_parser = subparsers.add_parser('fix-build', - help='Fixes OSS-Fuzz build scripts') - - fix_build_parser.add_argument('--project', - type=str, - help='The project to fix') - fix_build_parser.add_argument('--model', - help='The model to use for build fixing.') - fix_build_parser.add_argument('-mr', - '--max-round', - type=int, - default=20, - help='Max trial round for agents.') - - # Run build generation. - run_build_gen = subparsers.add_parser( - 'generate-builds', - help='Generate OSS-Fuzz projects with build scripts but empty fuzzers.') - run_build_gen.add_argument('--out', - '-o', - help='Directory to store output.', - default='oss-fuzz-generated') - run_build_gen.add_argument('-mr', - '--max-round', - type=int, - default=5, - help='Max trial round for agents.') - _add_base_build_gen_arguments(run_build_gen) - - # Generate fuzz introspector database. 
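# Hedged aside: _add_base_build_gen_arguments / _add_base_harness_gen_arguments
# share flags by mutating each subparser; argparse's parents= mechanism gives
# the same sharing declaratively. A minimal, self-contained sketch using only a
# subset of the real flags.
import argparse

base_build = argparse.ArgumentParser(add_help=False)
base_build.add_argument("--model", "-m", type=str)
base_build.add_argument("--build-jobs", type=int, default=2)
base_build.add_argument("-w", "--workdir")

demo_parser = argparse.ArgumentParser()
demo_subparsers = demo_parser.add_subparsers(dest="command")
demo_subparsers.add_parser("generate-builds", parents=[base_build])
demo_subparsers.add_parser("generate-full", parents=[base_build])

demo_args = demo_parser.parse_args(["generate-builds", "--build-jobs", "4"])
assert demo_args.command == "generate-builds" and demo_args.build_jobs == 4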
- run_generate_fi_db_parser = subparsers.add_parser( - 'generate-fuzz-introspector-database', - help='Generates a fuzz introspector database from auto build projects.') - - run_generate_fi_db_parser.add_argument('--generated-builds', required=True) - run_generate_fi_db_parser.add_argument('--workdir', required=True) - run_generate_fi_db_parser.add_argument('--parallel-build-jobs', - type=int, - default=5) - - # Run harness generation - run_harness_generation_parser = subparsers.add_parser( - 'generate-harnesses', - help="Harness generation of OSS-Fuzz projects.", - ) - - run_harness_generation_parser.add_argument( - '--model', - '-m', - help=('Models available: ' - f'{", ".join(models.LLM.all_llm_names())}.'), - type=str) - run_harness_generation_parser.add_argument('-w', - '--workdir', - help='Work directory to use') - run_harness_generation_parser.add_argument( - '--project', default='', help='Limit analysis to specified project.') - run_harness_generation_parser.add_argument('--function-name', - default='', - help='Target function') - _add_base_harness_gen_arguments(run_harness_generation_parser) - - # Run a full end to end generation. - run_full_parser = subparsers.add_parser( - 'generate-full', - help="Generate OSS-Fuzz integration from git URLs.", - ) - run_full_parser.add_argument('--out', - '-o', - help='Directory to store output.', - default='oss-fuzz-generated') - - _add_base_build_gen_arguments(run_full_parser) - _add_base_harness_gen_arguments(run_full_parser) - - return parser.parse_args() + """Parse the commandline.""" + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest="command") + + # Parser for fixing OSS-Fuzz build + fix_build_parser = subparsers.add_parser( + "fix-build", help="Fixes OSS-Fuzz build scripts" + ) + + fix_build_parser.add_argument("--project", type=str, help="The project to fix") + fix_build_parser.add_argument("--model", help="The model to use for build fixing.") + fix_build_parser.add_argument( + "-mr", "--max-round", type=int, default=20, help="Max trial round for agents." + ) + + # Run build generation. + run_build_gen = subparsers.add_parser( + "generate-builds", + help="Generate OSS-Fuzz projects with build scripts but empty fuzzers.", + ) + run_build_gen.add_argument( + "--out", "-o", help="Directory to store output.", default="oss-fuzz-generated" + ) + run_build_gen.add_argument( + "-mr", "--max-round", type=int, default=5, help="Max trial round for agents." + ) + _add_base_build_gen_arguments(run_build_gen) + + # Generate fuzz introspector database. + run_generate_fi_db_parser = subparsers.add_parser( + "generate-fuzz-introspector-database", + help="Generates a fuzz introspector database from auto build projects.", + ) + + run_generate_fi_db_parser.add_argument("--generated-builds", required=True) + run_generate_fi_db_parser.add_argument("--workdir", required=True) + run_generate_fi_db_parser.add_argument("--parallel-build-jobs", type=int, default=5) + + # Run harness generation + run_harness_generation_parser = subparsers.add_parser( + "generate-harnesses", + help="Harness generation of OSS-Fuzz projects.", + ) + + run_harness_generation_parser.add_argument( + "--model", + "-m", + help=("Models available: " f'{", ".join(models.LLM.all_llm_names())}.'), + type=str, + ) + run_harness_generation_parser.add_argument( + "-w", "--workdir", help="Work directory to use" + ) + run_harness_generation_parser.add_argument( + "--project", default="", help="Limit analysis to specified project." 
+ ) + run_harness_generation_parser.add_argument( + "--function-name", default="", help="Target function" + ) + _add_base_harness_gen_arguments(run_harness_generation_parser) + + # Run a full end to end generation. + run_full_parser = subparsers.add_parser( + "generate-full", + help="Generate OSS-Fuzz integration from git URLs.", + ) + run_full_parser.add_argument( + "--out", "-o", help="Directory to store output.", default="oss-fuzz-generated" + ) + + _add_base_build_gen_arguments(run_full_parser) + _add_base_harness_gen_arguments(run_full_parser) + + return parser.parse_args() def main(): - args = parse_commandline() - setup_logging() - - if args.command == 'generate-full': - run_full(args) - if args.command == 'generate-fuzz-introspector-database': - run_fuzz_introspector_db_creation(args.workdir, args.generated_builds, - args.parallel_build_jobs) - if args.command == 'generate-builds': - run_build_generation(args) - if args.command == 'generate-harnesses': - run_cmd_harness_generation(args) - if args.command == 'fix-build': - run_cmd_fix_build(args) - - -if __name__ == '__main__': - main() + args = parse_commandline() + setup_logging() + + if args.command == "generate-full": + run_full(args) + if args.command == "generate-fuzz-introspector-database": + run_fuzz_introspector_db_creation( + args.workdir, args.generated_builds, args.parallel_build_jobs + ) + if args.command == "generate-builds": + run_build_generation(args) + if args.command == "generate-harnesses": + run_cmd_harness_generation(args) + if args.command == "fix-build": + run_cmd_fix_build(args) + + +if __name__ == "__main__": + main() diff --git a/experimental/from_scratch/generate.py b/experimental/from_scratch/generate.py index dfebc8928c..b5433eb911 100644 --- a/experimental/from_scratch/generate.py +++ b/experimental/from_scratch/generate.py @@ -22,19 +22,22 @@ # pyright: reportMissingImports = false from fuzz_introspector import commands as fi_commands -from fuzz_introspector.analyses import (far_reach_low_coverage_analyser, - test_analyser) +from fuzz_introspector.analyses import ( + far_reach_low_coverage_analyser, + test_analyser, +) from experiment import benchmark as benchmarklib from llm_toolkit import models, output_parser, prompt_builder, prompts -LOG_FMT = ('%(asctime)s.%(msecs)03d %(levelname)s ' - '%(module)s - %(funcName)s: %(message)s') +LOG_FMT = ( + "%(asctime)s.%(msecs)03d %(levelname)s " "%(module)s - %(funcName)s: %(message)s" +) logging.basicConfig( level=logging.INFO, format=LOG_FMT, - datefmt='%Y-%m-%d %H:%M:%S', + datefmt="%Y-%m-%d %H:%M:%S", ) logger = logging.getLogger(name=__name__) @@ -44,523 +47,558 @@ def _add_default_args(parser): - """Default arguments for subparsers""" - parser.add_argument('-m', - '--model', - default=models.DefaultModel.name, - help=('Models available: ' - f'{", ".join(models.LLM.all_llm_names())}')) - parser.add_argument('-o', - '--out-dir', - default='./results', - help='Directory where results will be stored.') - parser.add_argument('-t', - '--target-dir', - help='Directory with project source.', - required=True) - parser.add_argument('-l', - '--language', - help='Main language of the target project source.', - required=True) + """Default arguments for subparsers""" + parser.add_argument( + "-m", + "--model", + default=models.DefaultModel.name, + help=("Models available: " f'{", ".join(models.LLM.all_llm_names())}'), + ) + parser.add_argument( + "-o", + "--out-dir", + default="./results", + help="Directory where results will be stored.", + ) + parser.add_argument( + 
"-t", "--target-dir", help="Directory with project source.", required=True + ) + parser.add_argument( + "-l", + "--language", + help="Main language of the target project source.", + required=True, + ) def parse_args() -> argparse.Namespace: - """Parses command line arguments.""" - parser = argparse.ArgumentParser(description='Fuzz generation helpers.') - subparsers = parser.add_subparsers(dest='command') - - # Far reach generator - far_reach_low_cov_generator = subparsers.add_parser( - 'generate-far-reach-targets', - help='Generate fuzzing harnesses for far reaching functions') - _add_default_args(far_reach_low_cov_generator) - - # Function specific generation - function_target_generator = subparsers.add_parser( - 'generate-for-function', - help='Generate fuzzing harness for specific function.') - _add_default_args(function_target_generator) - function_target_generator.add_argument('-f', - '--function', - help=('Function to target.')) - function_target_generator.add_argument( - '--only-exact-match', - action='store_true', - help=('Flag to indicate if exact function name' - 'matching is needed.')) - - # Source code location generator - source_location_generator = subparsers.add_parser( - 'generate-for-source', - help='Generate fuzzing harness for source code location') - _add_default_args(source_location_generator) - source_location_generator.add_argument( - '-s', - '--source-file', - help='Source file name to locate target function.', - default='') - source_location_generator.add_argument( - '-sl', - '--source-line', - type=int, - help='Source line number to locate target function.', - default=0) - - # Test to harness generator - test_to_harness_all_generator = subparsers.add_parser( - 'generate-test-to-harness', - help='Convert all tests in target to fuzz harnesses.') - _add_default_args(test_to_harness_all_generator) - - return parser.parse_args() + """Parses command line arguments.""" + parser = argparse.ArgumentParser(description="Fuzz generation helpers.") + subparsers = parser.add_subparsers(dest="command") + + # Far reach generator + far_reach_low_cov_generator = subparsers.add_parser( + "generate-far-reach-targets", + help="Generate fuzzing harnesses for far reaching functions", + ) + _add_default_args(far_reach_low_cov_generator) + + # Function specific generation + function_target_generator = subparsers.add_parser( + "generate-for-function", help="Generate fuzzing harness for specific function." 
+ ) + _add_default_args(function_target_generator) + function_target_generator.add_argument( + "-f", "--function", help=("Function to target.") + ) + function_target_generator.add_argument( + "--only-exact-match", + action="store_true", + help=("Flag to indicate if exact function name" "matching is needed."), + ) + + # Source code location generator + source_location_generator = subparsers.add_parser( + "generate-for-source", help="Generate fuzzing harness for source code location" + ) + _add_default_args(source_location_generator) + source_location_generator.add_argument( + "-s", + "--source-file", + help="Source file name to locate target function.", + default="", + ) + source_location_generator.add_argument( + "-sl", + "--source-line", + type=int, + help="Source line number to locate target function.", + default=0, + ) + + # Test to harness generator + test_to_harness_all_generator = subparsers.add_parser( + "generate-test-to-harness", + help="Convert all tests in target to fuzz harnesses.", + ) + _add_default_args(test_to_harness_all_generator) + + return parser.parse_args() def setup_model(args) -> models.LLM: - return models.LLM.setup(ai_binary='', - name=args.model, - max_tokens=MAX_TOKENS, - num_samples=NUM_SAMPLES, - temperature=TEMPERATURE) + return models.LLM.setup( + ai_binary="", + name=args.model, + max_tokens=MAX_TOKENS, + num_samples=NUM_SAMPLES, + temperature=TEMPERATURE, + ) def get_target_benchmark_for_function( language, target_dir, target_function_name, only_exact_match ) -> Tuple[Optional[benchmarklib.Benchmark], Optional[dict[str, Any]]]: - """Run introspector analysis on a target directory and extract benchmark""" - entrypoint = introspector_lang_to_entrypoint(language) - - _, report = fi_commands.analyse_end_to_end(arg_language=language, - target_dir=target_dir, - entrypoint=entrypoint, - out_dir='.', - coverage_url='', - report_name='report-name', - module_only=True, - dump_files=False) - project = report['light-project'] - introspector_project = report.get('introspector-project', None) - if introspector_project: - logger.info('Found introspector repoject') - for analysis in introspector_project.optional_analyses: - logger.info(analysis.name) - if analysis.name == 'FarReachLowCoverageAnalyser': - logger.info(analysis.get_json_string_result()) - else: - logger.info('Did not find any introspector project') - - # Get target function - function = project.find_function_by_name(target_function_name, - only_exact_match) - - if function: - param_list = [] - - for idx, arg_name in enumerate(function.arg_names): - param_list.append({'name': arg_name, 'type': function.arg_types[idx]}) - - # Build a context. 
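# Sketch: the name/type pairing built in the loops above can also be written
# with zip(); this assumes arg_names and arg_types stay index-aligned, as the
# surrounding code already does.
def build_param_list(arg_names, arg_types):
    """Pair argument names with their types for benchmarklib.Benchmark params."""
    return [{"name": name, "type": typ} for name, typ in zip(arg_names, arg_types)]


assert build_param_list(["data", "size"], ["const uint8_t *", "size_t"]) == [
    {"name": "data", "type": "const uint8_t *"},
    {"name": "size", "type": "size_t"},
]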
- function_source = function.function_source_code_as_text() - xrefs = project.get_cross_references_by_name(function.name) - logger.info('Total xrefs found %d', len(xrefs)) - if len(xrefs) > 10: - xrefs = xrefs[:10] - xref_strings = [xref.function_source_code_as_text() for xref in xrefs] - - context = { - 'func_source': function_source, - 'files': [], - 'decl': '', - 'xrefs': xref_strings, - 'header': '', - } - - return benchmarklib.Benchmark( - benchmark_id='sample', - project='no-name', - language=language, - function_name=function.name, - function_signature=function.sig, - return_type=function.return_type, - params=param_list, - target_path=function.parent_source.source_file), context - - return None, None + """Run introspector analysis on a target directory and extract benchmark""" + entrypoint = introspector_lang_to_entrypoint(language) + + _, report = fi_commands.analyse_end_to_end( + arg_language=language, + target_dir=target_dir, + entrypoint=entrypoint, + out_dir=".", + coverage_url="", + report_name="report-name", + module_only=True, + dump_files=False, + ) + project = report["light-project"] + introspector_project = report.get("introspector-project", None) + if introspector_project: + logger.info("Found introspector repoject") + for analysis in introspector_project.optional_analyses: + logger.info(analysis.name) + if analysis.name == "FarReachLowCoverageAnalyser": + logger.info(analysis.get_json_string_result()) + else: + logger.info("Did not find any introspector project") + + # Get target function + function = project.find_function_by_name(target_function_name, only_exact_match) + + if function: + param_list = [] + + for idx, arg_name in enumerate(function.arg_names): + param_list.append({"name": arg_name, "type": function.arg_types[idx]}) + + # Build a context. 
+ function_source = function.function_source_code_as_text() + xrefs = project.get_cross_references_by_name(function.name) + logger.info("Total xrefs found %d", len(xrefs)) + if len(xrefs) > 10: + xrefs = xrefs[:10] + xref_strings = [xref.function_source_code_as_text() for xref in xrefs] + + context = { + "func_source": function_source, + "files": [], + "decl": "", + "xrefs": xref_strings, + "header": "", + } + + return ( + benchmarklib.Benchmark( + benchmark_id="sample", + project="no-name", + language=language, + function_name=function.name, + function_signature=function.sig, + return_type=function.return_type, + params=param_list, + target_path=function.parent_source.source_file, + ), + context, + ) + + return None, None def get_target_benchmark_for_source( language, target_dir, target_source_file, target_source_line ) -> Tuple[Optional[benchmarklib.Benchmark], Optional[dict[str, Any]]]: - """Run introspector analysis on a target directory and extract benchmark""" - entrypoint = introspector_lang_to_entrypoint(language) - - _, report = fi_commands.analyse_end_to_end(arg_language=language, - target_dir=target_dir, - entrypoint=entrypoint, - out_dir='.', - coverage_url='', - report_name='report-name', - module_only=True, - dump_files=False) - project = report['light-project'] - introspector_project = report.get('introspector-project', None) - if introspector_project: - logger.info('Found introspector repoject') - for analysis in introspector_project.optional_analyses: - logger.info(analysis.name) - if analysis.name == 'FarReachLowCoverageAnalyser': - logger.info(analysis.get_json_string_result()) - else: - logger.info('Did not find any introspector project') - - # Get target function - function = project.get_function_by_source_suffix_line(target_source_file, - target_source_line) - - if function: - param_list = [] - - for idx, arg_name in enumerate(function.arg_names): - param_list.append({'name': arg_name, 'type': function.arg_types[idx]}) - - # Build a context. 
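# Hedged sketch: get_target_benchmark_for_function and
# get_target_benchmark_for_source assemble the same context dictionary; a
# shared helper like the one below could capture that step. 'project' and
# 'function' are the fuzz-introspector objects used above; the helper name is
# hypothetical.
def build_prompt_context(project, function, max_xrefs=10):
    """Assemble the prompt context used by both benchmark extractors."""
    xrefs = project.get_cross_references_by_name(function.name)[:max_xrefs]
    return {
        "func_source": function.function_source_code_as_text(),
        "files": [],
        "decl": "",
        "xrefs": [xref.function_source_code_as_text() for xref in xrefs],
        "header": "",
    }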
- function_source = function.function_source_code_as_text() - xrefs = project.get_cross_references_by_name(function.name) - logger.info('Total xrefs found %d', len(xrefs)) - if len(xrefs) > 10: - xrefs = xrefs[:10] - xref_strings = [xref.function_source_code_as_text() for xref in xrefs] - - context = { - 'func_source': function_source, - 'files': [], - 'decl': '', - 'xrefs': xref_strings, - 'header': '', - } - - return benchmarklib.Benchmark( - benchmark_id='sample', - project='no-name', - language=language, - function_name=function.name, - function_signature=function.sig, - return_type=function.return_type, - params=param_list, - target_path=function.parent_source.source_file), context - - return None, None - - -def construct_fuzz_prompt(model, benchmark, context, - language) -> prompts.Prompt: - """Local benchmarker""" - if language in ['c', 'c++']: - builder = prompt_builder.DefaultTemplateBuilder(model, benchmark=benchmark) - elif language == 'rust': - builder = prompt_builder.DefaultRustTemplateBuilder(model, - benchmark=benchmark) - else: - builder = prompt_builder.DefaultJvmTemplateBuilder(model, - benchmark=benchmark) - - fuzz_prompt = builder.build([], project_context_content=context) - return fuzz_prompt + """Run introspector analysis on a target directory and extract benchmark""" + entrypoint = introspector_lang_to_entrypoint(language) + + _, report = fi_commands.analyse_end_to_end( + arg_language=language, + target_dir=target_dir, + entrypoint=entrypoint, + out_dir=".", + coverage_url="", + report_name="report-name", + module_only=True, + dump_files=False, + ) + project = report["light-project"] + introspector_project = report.get("introspector-project", None) + if introspector_project: + logger.info("Found introspector repoject") + for analysis in introspector_project.optional_analyses: + logger.info(analysis.name) + if analysis.name == "FarReachLowCoverageAnalyser": + logger.info(analysis.get_json_string_result()) + else: + logger.info("Did not find any introspector project") + + # Get target function + function = project.get_function_by_source_suffix_line( + target_source_file, target_source_line + ) + + if function: + param_list = [] + + for idx, arg_name in enumerate(function.arg_names): + param_list.append({"name": arg_name, "type": function.arg_types[idx]}) + + # Build a context. 
+ function_source = function.function_source_code_as_text() + xrefs = project.get_cross_references_by_name(function.name) + logger.info("Total xrefs found %d", len(xrefs)) + if len(xrefs) > 10: + xrefs = xrefs[:10] + xref_strings = [xref.function_source_code_as_text() for xref in xrefs] + + context = { + "func_source": function_source, + "files": [], + "decl": "", + "xrefs": xref_strings, + "header": "", + } + + return ( + benchmarklib.Benchmark( + benchmark_id="sample", + project="no-name", + language=language, + function_name=function.name, + function_signature=function.sig, + return_type=function.return_type, + params=param_list, + target_path=function.parent_source.source_file, + ), + context, + ) + + return None, None + + +def construct_fuzz_prompt(model, benchmark, context, language) -> prompts.Prompt: + """Local benchmarker""" + if language in ["c", "c++"]: + builder = prompt_builder.DefaultTemplateBuilder(model, benchmark=benchmark) + elif language == "rust": + builder = prompt_builder.DefaultRustTemplateBuilder(model, benchmark=benchmark) + else: + builder = prompt_builder.DefaultJvmTemplateBuilder(model, benchmark=benchmark) + + fuzz_prompt = builder.build([], project_context_content=context) + return fuzz_prompt def print_prompt(fuzz_prompt: prompts.Prompt) -> None: - """Prints prompt to stdout.""" - print('Querying with the prompt') - print('-' * 40) - raw_prompt = fuzz_prompt.get() - if isinstance(raw_prompt, list): - for elem in raw_prompt: - if isinstance(elem, dict) and 'content' in elem: - print(elem['content']) - else: - print(raw_prompt) - print('-' * 40) + """Prints prompt to stdout.""" + print("Querying with the prompt") + print("-" * 40) + raw_prompt = fuzz_prompt.get() + if isinstance(raw_prompt, list): + for elem in raw_prompt: + if isinstance(elem, dict) and "content" in elem: + print(elem["content"]) + else: + print(raw_prompt) + print("-" * 40) def get_fuzz_prompt_str(fuzz_prompt: prompts.Prompt) -> str: - """Prints prompt to stdout.""" - prompt_string = '' - raw_prompt = fuzz_prompt.get() - if isinstance(raw_prompt, list): - for elem in raw_prompt: - if isinstance(elem, dict) and 'content' in elem: - prompt_string += elem['content'] - return prompt_string + """Prints prompt to stdout.""" + prompt_string = "" + raw_prompt = fuzz_prompt.get() + if isinstance(raw_prompt, list): + for elem in raw_prompt: + if isinstance(elem, dict) and "content" in elem: + prompt_string += elem["content"] + return prompt_string def introspector_lang_to_entrypoint(language: str) -> str: - """Map an introspector language to entrypoint function.""" - if language in ['c', 'c++']: - return 'LLVMFuzzerTestOneInput' - if language == 'jvm': - return 'fuzzerTestOneInput' - if language == 'rust': - return 'fuzz_target' + """Map an introspector language to entrypoint function.""" + if language in ["c", "c++"]: + return "LLVMFuzzerTestOneInput" + if language == "jvm": + return "fuzzerTestOneInput" + if language == "rust": + return "fuzz_target" - # Not supporting other language yet - return '' + # Not supporting other language yet + return "" def get_far_reach_benchmarks( language, target_dir ) -> list[Tuple[Optional[benchmarklib.Benchmark], Optional[dict[str, Any]]]]: - """Run introspector analysis to extract fear-reaching targets and generate - harnesses for it.""" - entrypoint = introspector_lang_to_entrypoint(language) - - _, report = fi_commands.analyse_end_to_end(arg_language=language, - target_dir=target_dir, - entrypoint=entrypoint, - out_dir='.', - coverage_url='', - 
report_name='report-name', - module_only=True, - dump_files=False) - project = report['light-project'] - introspector_project = report.get('introspector-project', None) - - far_analysis = far_reach_low_coverage_analyser.FarReachLowCoverageAnalyser() - far_analysis.standalone_analysis(introspector_project.proj_profile, - introspector_project.profiles, '') - - target_benchmarks = [] - for target_function in far_analysis.json_results.get('functions', []): - # Get target function - target_function_name = target_function['function_name'] - if target_function_name: - function = project.find_function_by_name(target_function_name, True) - else: - function = None - - if function: - param_list = [] - - for idx, arg_name in enumerate(function.arg_names): - param_list.append({'name': arg_name, 'type': function.arg_types[idx]}) - - # Build a context. - # Shorten the source function text if necessary. - function_source = function.function_source_code_as_text() - if len(function_source) > 1000: - logger.info('Function source is %d bytes. Shortening to 1000', - len(function_source)) - function_source = function_source[:1000] + '\n ....' - - xrefs = project.get_cross_references(function) - if len(xrefs) > 10: - xrefs = xrefs[:10] - xref_strings = [] - for xref in xrefs: - source_str = xref.function_source_code_as_text() - # Only include xref if it's not too large. - if len(source_str) > 2000: - continue - xref_strings.append(source_str) - - context = { - 'func_source': function_source, - 'files': [], - 'decl': '', - 'xrefs': xref_strings, - 'header': '', - } - - target_benchmarks.append((benchmarklib.Benchmark( - benchmark_id='sample', - project='no-name', - language=language, - function_name=function.name, - function_signature=function.sig, - return_type=function.return_type, - params=param_list, - target_path=function.parent_source.source_file), context)) - - return target_benchmarks + """Run introspector analysis to extract fear-reaching targets and generate + harnesses for it.""" + entrypoint = introspector_lang_to_entrypoint(language) + + _, report = fi_commands.analyse_end_to_end( + arg_language=language, + target_dir=target_dir, + entrypoint=entrypoint, + out_dir=".", + coverage_url="", + report_name="report-name", + module_only=True, + dump_files=False, + ) + project = report["light-project"] + introspector_project = report.get("introspector-project", None) + + far_analysis = far_reach_low_coverage_analyser.FarReachLowCoverageAnalyser() + far_analysis.standalone_analysis( + introspector_project.proj_profile, introspector_project.profiles, "" + ) + + target_benchmarks = [] + for target_function in far_analysis.json_results.get("functions", []): + # Get target function + target_function_name = target_function["function_name"] + if target_function_name: + function = project.find_function_by_name(target_function_name, True) + else: + function = None + + if function: + param_list = [] + + for idx, arg_name in enumerate(function.arg_names): + param_list.append({"name": arg_name, "type": function.arg_types[idx]}) + + # Build a context. + # Shorten the source function text if necessary. + function_source = function.function_source_code_as_text() + if len(function_source) > 1000: + logger.info( + "Function source is %d bytes. Shortening to 1000", + len(function_source), + ) + function_source = function_source[:1000] + "\n ...." 
+ + xrefs = project.get_cross_references(function) + if len(xrefs) > 10: + xrefs = xrefs[:10] + xref_strings = [] + for xref in xrefs: + source_str = xref.function_source_code_as_text() + # Only include xref if it's not too large. + if len(source_str) > 2000: + continue + xref_strings.append(source_str) + + context = { + "func_source": function_source, + "files": [], + "decl": "", + "xrefs": xref_strings, + "header": "", + } + + target_benchmarks.append( + ( + benchmarklib.Benchmark( + benchmark_id="sample", + project="no-name", + language=language, + function_name=function.name, + function_signature=function.sig, + return_type=function.return_type, + params=param_list, + target_path=function.parent_source.source_file, + ), + context, + ) + ) + + return target_benchmarks def get_next_out_dir(out_dir: str) -> str: - """Prepare next folder to put generate harness in.""" - idx = 0 - while True: - target_response = out_dir + f'-{idx}' # os.path.join(out_dir, str(idx)) - if not os.path.isdir(target_response): - return target_response - idx += 1 + """Prepare next folder to put generate harness in.""" + idx = 0 + while True: + target_response = out_dir + f"-{idx}" # os.path.join(out_dir, str(idx)) + if not os.path.isdir(target_response): + return target_response + idx += 1 def get_introspector_language(args) -> str: - """Gets the language in introspector style from the CLI args.""" - if args.language == 'c': - return 'c' - if args.language in ['c++', 'cpp']: - return 'c++' - if args.language in ['jvm', 'java']: - return 'jvm' - if args.language in ['rs', 'rust']: - return 'rust' - - print(f'Language {args.language} not support. Exiting.') - sys.exit(0) + """Gets the language in introspector style from the CLI args.""" + if args.language == "c": + return "c" + if args.language in ["c++", "cpp"]: + return "c++" + if args.language in ["jvm", "java"]: + return "jvm" + if args.language in ["rs", "rust"]: + return "rust" + + print(f"Language {args.language} not support. Exiting.") + sys.exit(0) def generate_far_reach_targets(args): - """Generates a set of harnesses based on far-reach analysis.""" - model = setup_model(args) - language = get_introspector_language(args) - # Get the benchmarks corresponding to far-reach analysis. - target_pairs = get_far_reach_benchmarks(language, args.target_dir) - fuzz_harnesses = [] - for target_benchmark, context in target_pairs: - fuzz_prompt = construct_fuzz_prompt(model, target_benchmark, context, - language) - str_prompt = get_fuzz_prompt_str(fuzz_prompt) - if len(str_prompt) > 15000: - logger.info('Skipping prompt because its too large') - print_prompt(fuzz_prompt) - continue - print_prompt(fuzz_prompt) + """Generates a set of harnesses based on far-reach analysis.""" + model = setup_model(args) + language = get_introspector_language(args) + # Get the benchmarks corresponding to far-reach analysis. 
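# Alternative sketch of the CLI-language normalisation done by
# get_introspector_language above, using a lookup table. The accepted aliases
# are the ones handled there; the helper name and the SystemExit behaviour are
# illustrative.
LANGUAGE_ALIASES = {
    "c": "c",
    "c++": "c++",
    "cpp": "c++",
    "jvm": "jvm",
    "java": "jvm",
    "rs": "rust",
    "rust": "rust",
}


def normalise_language(cli_language):
    """Map a CLI --language value to fuzz-introspector's language name."""
    try:
        return LANGUAGE_ALIASES[cli_language]
    except KeyError:
        raise SystemExit(f"Language {cli_language} not supported.")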
+ target_pairs = get_far_reach_benchmarks(language, args.target_dir) + fuzz_harnesses = [] + for target_benchmark, context in target_pairs: + fuzz_prompt = construct_fuzz_prompt(model, target_benchmark, context, language) + str_prompt = get_fuzz_prompt_str(fuzz_prompt) + if len(str_prompt) > 15000: + logger.info("Skipping prompt because its too large") + print_prompt(fuzz_prompt) + continue + print_prompt(fuzz_prompt) + + try: + fuzz_harness_response = model.ask_llm(fuzz_prompt) + fuzz_harness_source = output_parser.filter_code(fuzz_harness_response) + fuzz_harnesses.append(fuzz_harness_source) + except Exception: # pylint: disable=broad-exception-caught + pass - try: - fuzz_harness_response = model.ask_llm(fuzz_prompt) - fuzz_harness_source = output_parser.filter_code(fuzz_harness_response) - fuzz_harnesses.append(fuzz_harness_source) - except Exception: # pylint: disable=broad-exception-caught - pass - - response_dir = get_next_out_dir(args.out_dir) - os.makedirs(response_dir, exist_ok=True) - extension = '.raw' - if args.language == 'c++': - extension = '.cpp' - elif args.language == 'c': - extension = '.c' - for idx, fuzz_harness in enumerate(fuzz_harnesses): - # adjust extension if needed - if extension == '.cpp' and 'extern "C"' not in fuzz_harness: - act_ext = '.c' - else: - act_ext = extension - with open(os.path.join(response_dir, f'harness_{idx}{act_ext}'), - 'w', - encoding='utf-8') as f: - f.write(fuzz_harness) + response_dir = get_next_out_dir(args.out_dir) + os.makedirs(response_dir, exist_ok=True) + extension = ".raw" + if args.language == "c++": + extension = ".cpp" + elif args.language == "c": + extension = ".c" + for idx, fuzz_harness in enumerate(fuzz_harnesses): + # adjust extension if needed + if extension == ".cpp" and 'extern "C"' not in fuzz_harness: + act_ext = ".c" + else: + act_ext = extension + with open( + os.path.join(response_dir, f"harness_{idx}{act_ext}"), "w", encoding="utf-8" + ) as f: + f.write(fuzz_harness) def generate_test_to_harness_targets(args): - """Test to harness converter""" - model = setup_model(args) - language = get_introspector_language(args) - - entrypoint = introspector_lang_to_entrypoint(args.language) - - _, report = fi_commands.analyse_end_to_end(arg_language=language, - target_dir=args.target_dir, - entrypoint=entrypoint, - out_dir='.', - coverage_url='', - report_name='report-name', - module_only=True, - dump_files=False) - introspector_project = report.get('introspector-project', None) - - tth_analysis = test_analyser.TestAnalyser() - tth_analysis.standalone_analysis(introspector_project.proj_profile, - introspector_project.profiles, '') - tests = tth_analysis.test_file_paths - for test_file_path in tests: - benchmark = benchmarklib.Benchmark(benchmark_id='sample', - project='no-name', - language=language, - function_name='', - function_signature='', - return_type='', - params=[], - target_path='', - test_file_path=test_file_path) - - with open(test_file_path, 'r', encoding='utf-8') as f: - test_source = f.read() - - # If the test source code is above a certain limit we'll reduce - # the size of it to avoid having a too long token count. 
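# Sketch of the source-truncation guards used in this file (function sources
# capped around 1000 characters, test files trimmed to a head and a tail before
# prompting), folded into one helper. The limits mirror the surrounding code;
# the helper itself is illustrative.
def truncate_source(text, limit, keep_tail=False):
    """Trim an overly long source snippet before embedding it in a prompt."""
    if len(text) <= limit:
        return text
    if keep_tail:
        half = limit // 2
        return text[:half] + "\n.....\n" + text[-half:]
    return text[:limit] + "\n ...."


assert truncate_source("x" * 6000, 4800, keep_tail=True).count(".....") == 1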
- if len(test_source) > 5000: - test_source = test_source[:2400] + '\n.....\n' + test_source[-2400:] - builder = prompt_builder.TestToHarnessConverter(model, benchmark=benchmark) - fuzz_prompt = builder.build([], test_source_code=test_source) - - try: - raw_result = model.ask_llm(fuzz_prompt) - except Exception: # pylint: disable=broad-exception-caught - continue - - logger.info('Filtering code') - generated_source = output_parser.filter_code(raw_result) - logger.info('Done filtering code') - - response_dir = get_next_out_dir(args.out_dir) - os.makedirs(response_dir, exist_ok=True) - with open(os.path.join(response_dir, 'fuzz.c'), 'w', encoding='utf-8') as f: - f.write(generated_source) + """Test to harness converter""" + model = setup_model(args) + language = get_introspector_language(args) + + entrypoint = introspector_lang_to_entrypoint(args.language) + + _, report = fi_commands.analyse_end_to_end( + arg_language=language, + target_dir=args.target_dir, + entrypoint=entrypoint, + out_dir=".", + coverage_url="", + report_name="report-name", + module_only=True, + dump_files=False, + ) + introspector_project = report.get("introspector-project", None) + + tth_analysis = test_analyser.TestAnalyser() + tth_analysis.standalone_analysis( + introspector_project.proj_profile, introspector_project.profiles, "" + ) + tests = tth_analysis.test_file_paths + for test_file_path in tests: + benchmark = benchmarklib.Benchmark( + benchmark_id="sample", + project="no-name", + language=language, + function_name="", + function_signature="", + return_type="", + params=[], + target_path="", + test_file_path=test_file_path, + ) + + with open(test_file_path, "r", encoding="utf-8") as f: + test_source = f.read() + + # If the test source code is above a certain limit we'll reduce + # the size of it to avoid having a too long token count. + if len(test_source) > 5000: + test_source = test_source[:2400] + "\n.....\n" + test_source[-2400:] + builder = prompt_builder.TestToHarnessConverter(model, benchmark=benchmark) + fuzz_prompt = builder.build([], test_source_code=test_source) + + try: + raw_result = model.ask_llm(fuzz_prompt) + except Exception: # pylint: disable=broad-exception-caught + continue + + logger.info("Filtering code") + generated_source = output_parser.filter_code(raw_result) + logger.info("Done filtering code") + + response_dir = get_next_out_dir(args.out_dir) + os.makedirs(response_dir, exist_ok=True) + with open(os.path.join(response_dir, "fuzz.c"), "w", encoding="utf-8") as f: + f.write(generated_source) def generate_for_target_function(args): - """Generate harness for single function/source location""" - model = setup_model(args) - language = get_introspector_language(args) - if args.command == 'generate-for-function': - - target_benchmark, context = get_target_benchmark_for_function( - language, args.target_dir, args.function, args.only_exact_match) - else: - target_benchmark, context = get_target_benchmark_for_source( - language, args.target_dir, args.source_file, args.source_line) - - if target_benchmark is None: - print('Could not find target function. 
Exiting.') - sys.exit(0) + """Generate harness for single function/source location""" + model = setup_model(args) + language = get_introspector_language(args) + if args.command == "generate-for-function": + + target_benchmark, context = get_target_benchmark_for_function( + language, args.target_dir, args.function, args.only_exact_match + ) + else: + target_benchmark, context = get_target_benchmark_for_source( + language, args.target_dir, args.source_file, args.source_line + ) - fuzz_prompt = construct_fuzz_prompt(model, target_benchmark, context, - language) - print_prompt(fuzz_prompt) - os.makedirs(args.out_dir, exist_ok=True) - print(f'Running query and writing results in {args.out_dir}') - raw_response = model.ask_llm(fuzz_prompt) - generated_fuzz_harness = output_parser.filter_code(raw_response) - - response_dir = get_next_out_dir(args.out_dir) - os.makedirs(response_dir, exist_ok=True) - extension = '.raw' - if args.language == 'c++': - extension = '.cpp' - elif args.language == 'c': - extension = '.c' - - if extension == '.cpp' and 'extern "C"' not in generated_fuzz_harness: - act_ext = '.c' - else: - act_ext = extension - - with open(os.path.join(response_dir, f'harness{act_ext}'), - 'w', - encoding='utf-8') as f: - f.write(generated_fuzz_harness) + if target_benchmark is None: + print("Could not find target function. Exiting.") + sys.exit(0) + + fuzz_prompt = construct_fuzz_prompt(model, target_benchmark, context, language) + print_prompt(fuzz_prompt) + os.makedirs(args.out_dir, exist_ok=True) + print(f"Running query and writing results in {args.out_dir}") + raw_response = model.ask_llm(fuzz_prompt) + generated_fuzz_harness = output_parser.filter_code(raw_response) + + response_dir = get_next_out_dir(args.out_dir) + os.makedirs(response_dir, exist_ok=True) + extension = ".raw" + if args.language == "c++": + extension = ".cpp" + elif args.language == "c": + extension = ".c" + + if extension == ".cpp" and 'extern "C"' not in generated_fuzz_harness: + act_ext = ".c" + else: + act_ext = extension + + with open( + os.path.join(response_dir, f"harness{act_ext}"), "w", encoding="utf-8" + ) as f: + f.write(generated_fuzz_harness) def main(): - """Entrypoint""" - args = parse_args() + """Entrypoint""" + args = parse_args() - if args.command == 'generate-far-reach-targets': - generate_far_reach_targets(args) - elif args.command == 'generate-test-to-harness': - generate_test_to_harness_targets(args) - else: - generate_for_target_function(args) + if args.command == "generate-far-reach-targets": + generate_far_reach_targets(args) + elif args.command == "generate-test-to-harness": + generate_test_to_harness_targets(args) + else: + generate_for_target_function(args) if __name__ == "__main__": - main() + main() diff --git a/experimental/jvm/constants.py b/experimental/jvm/constants.py index 5accfabf82..c2cffa5c14 100644 --- a/experimental/jvm/constants.py +++ b/experimental/jvm/constants.py @@ -16,18 +16,15 @@ """Provides a set of constant values for new Java projects integration""" MAVEN_URL = { - '3.1.1': - 'https://archive.apache.org/dist/maven/maven-3/3.1.1/binaries/apache-maven-3.1.1-bin.zip', - '3.2.5': - 'https://archive.apache.org/dist/maven/maven-3/3.2.5/binaries/apache-maven-3.2.5-bin.zip', - '3.9.2': - 'https://archive.apache.org/dist/maven/maven-3/3.9.2/binaries/apache-maven-3.9.2-bin.zip', + "3.1.1": "https://archive.apache.org/dist/maven/maven-3/3.1.1/binaries/apache-maven-3.1.1-bin.zip", + "3.2.5": 
"https://archive.apache.org/dist/maven/maven-3/3.2.5/binaries/apache-maven-3.2.5-bin.zip", + "3.9.2": "https://archive.apache.org/dist/maven/maven-3/3.9.2/binaries/apache-maven-3.9.2-bin.zip", } -GRADLE_URL = 'https://services.gradle.org/distributions/gradle-7.4.2-bin.zip' +GRADLE_URL = "https://services.gradle.org/distributions/gradle-7.4.2-bin.zip" -ANT_URL = 'https://dlcdn.apache.org//ant/binaries/apache-ant-1.10.14-bin.zip' +ANT_URL = "https://dlcdn.apache.org//ant/binaries/apache-ant-1.10.14-bin.zip" -PROTO_URL = 'https://github.com/protocolbuffers/protobuf/releases/download/v3.15.8/protoc-3.15.8-linux-x86_64.zip' +PROTO_URL = "https://github.com/protocolbuffers/protobuf/releases/download/v3.15.8/protoc-3.15.8-linux-x86_64.zip" -JDK15_URL = 'https://download.java.net/java/GA/jdk15.0.2/0d1cfde4252546c6931946de8db48ee2/7/GPL/openjdk-15.0.2_linux-x64_bin.tar.gz' +JDK15_URL = "https://download.java.net/java/GA/jdk15.0.2/0d1cfde4252546c6931946de8db48ee2/7/GPL/openjdk-15.0.2_linux-x64_bin.tar.gz" diff --git a/experimental/jvm/generate_projects.py b/experimental/jvm/generate_projects.py index 99df3a2268..7e11723517 100644 --- a/experimental/jvm/generate_projects.py +++ b/experimental/jvm/generate_projects.py @@ -14,88 +14,89 @@ # limitations under the License. """Manager for running auto-gen from scratch.""" -import sys - -sys.path.append('../../') - import argparse import logging import os import shutil +import sys from experimental.jvm import utils +sys.path.append("../../") + + silent_global = False logger = logging.getLogger(name=__name__) -LOG_FMT = ('%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ' - ': %(funcName)s: %(message)s') +LOG_FMT = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " ": %(funcName)s: %(message)s" +) def parse_commandline(): - """Parse the commandline.""" - parser = argparse.ArgumentParser() - parser.add_argument('--workdir', '-w', help='Working directory') - parser.add_argument('--oss-fuzz', '-o', help='OSS-Fuzz base') - parser.add_argument( - '--github-url', - '-u', - help='A comma separated string with all GitHub URLs of target projects') - parser.add_argument('--silent', - '-s', - help='Disable logging in subprocess.', - action='store_true') - return parser.parse_args() + """Parse the commandline.""" + parser = argparse.ArgumentParser() + parser.add_argument("--workdir", "-w", help="Working directory") + parser.add_argument("--oss-fuzz", "-o", help="OSS-Fuzz base") + parser.add_argument( + "--github-url", + "-u", + help="A comma separated string with all GitHub URLs of target projects", + ) + parser.add_argument( + "--silent", "-s", help="Disable logging in subprocess.", action="store_true" + ) + return parser.parse_args() def main(): - global silent_global - - args = parse_commandline() - oss_fuzz_dir = os.path.abspath(args.oss_fuzz) - work_dir = os.path.abspath(args.workdir) - silent_global = args.silent - logging.basicConfig(level=logging.INFO, format=LOG_FMT) - - generated_project_name_list = [] - for url in args.github_url.split(','): - # Retrieve project name - project_name = utils.get_project_name(url) - if not project_name: - # Malformed url - logger.warning('Skipping wrong github url: %s', url) - continue - - # Clone project for static analysis - base_dir = os.path.join(oss_fuzz_dir, 'projects', project_name) - if os.path.isdir(base_dir): - # Project already exists, reuse the existing project directly - generated_project_name_list.append(os.path.basename(base_dir)) - continue - - project_dir = os.path.join(base_dir, 'proj') - if not 
utils.git_clone_project(url, project_dir): - # Clone error or invalid url - logger.warning('Failed to clone from the github url: %s', url) - shutil.rmtree(base_dir) - continue - - # Prepare OSS-Fuzz base files - if not utils.prepare_base_files(base_dir, project_name, url): - # Invalid build type or non-Java project - logger.warning('Build type of project %s is not supported.', project_name) - shutil.rmtree(base_dir) - continue - - # Clean up project and store generated project name - generated_project_name_list.append(os.path.basename(base_dir)) - shutil.rmtree(project_dir) - - # Store generated project name - if generated_project_name_list: - with open(os.path.join(work_dir, 'project-name'), 'w') as file: - file.write(','.join(generated_project_name_list)) - - -if __name__ == '__main__': - main() + global silent_global + + args = parse_commandline() + oss_fuzz_dir = os.path.abspath(args.oss_fuzz) + work_dir = os.path.abspath(args.workdir) + silent_global = args.silent + logging.basicConfig(level=logging.INFO, format=LOG_FMT) + + generated_project_name_list = [] + for url in args.github_url.split(","): + # Retrieve project name + project_name = utils.get_project_name(url) + if not project_name: + # Malformed url + logger.warning("Skipping wrong github url: %s", url) + continue + + # Clone project for static analysis + base_dir = os.path.join(oss_fuzz_dir, "projects", project_name) + if os.path.isdir(base_dir): + # Project already exists, reuse the existing project directly + generated_project_name_list.append(os.path.basename(base_dir)) + continue + + project_dir = os.path.join(base_dir, "proj") + if not utils.git_clone_project(url, project_dir): + # Clone error or invalid url + logger.warning("Failed to clone from the github url: %s", url) + shutil.rmtree(base_dir) + continue + + # Prepare OSS-Fuzz base files + if not utils.prepare_base_files(base_dir, project_name, url): + # Invalid build type or non-Java project + logger.warning("Build type of project %s is not supported.", project_name) + shutil.rmtree(base_dir) + continue + + # Clean up project and store generated project name + generated_project_name_list.append(os.path.basename(base_dir)) + shutil.rmtree(project_dir) + + # Store generated project name + if generated_project_name_list: + with open(os.path.join(work_dir, "project-name"), "w") as file: + file.write(",".join(generated_project_name_list)) + + +if __name__ == "__main__": + main() diff --git a/experimental/jvm/result_merger.py b/experimental/jvm/result_merger.py index 278955069a..7e366d0b35 100644 --- a/experimental/jvm/result_merger.py +++ b/experimental/jvm/result_merger.py @@ -23,116 +23,121 @@ from typing import Optional logger = logging.getLogger(name=__name__) -LOG_FMT = ('%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ' - ': %(funcName)s: %(message)s') +LOG_FMT = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " ": %(funcName)s: %(message)s" +) def retrieve_top_harness_index(benchmark_dir: str) -> Optional[str]: - """Return the top harness index for a generated target""" - result = [] - for path in os.listdir(os.path.join(benchmark_dir, 'status')): - base_path = os.path.join(benchmark_dir, 'status', path) - result_path = os.path.join(base_path, 'result.json') - if os.path.isdir(base_path) and os.path.isfile(result_path): - with open(result_path, 'r') as f: - json_dict = json.load(f) - if json_dict.get('compiles', False): - result.append({ - 'index': path, - 'coverage': json_dict.get('coverage', 0.0) - }) - - if result: - return sorted(result, 
key=lambda item: (item.get('coverage')), - reverse=True)[0].get('index') - - return None + """Return the top harness index for a generated target""" + result = [] + for path in os.listdir(os.path.join(benchmark_dir, "status")): + base_path = os.path.join(benchmark_dir, "status", path) + result_path = os.path.join(base_path, "result.json") + if os.path.isdir(base_path) and os.path.isfile(result_path): + with open(result_path, "r") as f: + json_dict = json.load(f) + if json_dict.get("compiles", False): + result.append( + {"index": path, "coverage": json_dict.get("coverage", 0.0)} + ) + + if result: + return sorted(result, key=lambda item: (item.get("coverage")), reverse=True)[ + 0 + ].get("index") + + return None def retrieve_success_harness_name(project_dir_list: list[str]) -> list[str]: - """Returns name of success harnesses of the target project.""" - result_list = [] - for project_dir in project_dir_list: - top_harness_index = retrieve_top_harness_index(project_dir) - if top_harness_index: - result_list.append(f'{os.path.basename(project_dir)}-{top_harness_index}') + """Returns name of success harnesses of the target project.""" + result_list = [] + for project_dir in project_dir_list: + top_harness_index = retrieve_top_harness_index(project_dir) + if top_harness_index: + result_list.append(f"{os.path.basename(project_dir)}-{top_harness_index}") - return result_list + return result_list def group_result_dir_by_project_name( - result_dir: str, project_name_list: list[str]) -> dict[str, list]: - """This function group the generater directory by project name.""" - result_dir_map = {} - for project_basename in os.listdir(result_dir): - for project_name in project_name_list: - if project_basename.startswith(f'output-{project_name}-'): - project_list = result_dir_map.get(project_name, []) - project_list.append(os.path.join(result_dir, project_basename)) - result_dir_map[project_name] = project_list - break - - return result_dir_map - - -def copy_success_ofg_autogens(result_dir: str, destination: str, - project_name_list: list[str], - oss_fuzz_path: str) -> None: - """Copies the success harnesses for each projects to destination.""" - result_dir_map = group_result_dir_by_project_name(os.path.abspath(result_dir), - project_name_list) - - for project_name, project_dir_list in result_dir_map.items(): - logger.info('Handling project %s', project_name) - destination_dir = os.path.join(destination, project_name) - if not os.path.isdir(destination_dir): - os.mkdir(destination_dir) - - top_harnesses = retrieve_success_harness_name(project_dir_list) - logger.info('Found %d success harnesses for project %s', len(top_harnesses), - project_name) - for name in top_harnesses: - src_dir = os.path.join(oss_fuzz_path, 'projects', name) - dst_dir = os.path.join(destination_dir, name.rsplit('-', 1)[0]) - shutil.copytree(src_dir, dst_dir) - - if len(os.listdir(destination_dir)) == 0: - logger.warning('No success target for project %s', project_name) - shutil.rmtree(destination_dir) + result_dir: str, project_name_list: list[str] +) -> dict[str, list]: + """This function group the generater directory by project name.""" + result_dir_map = {} + for project_basename in os.listdir(result_dir): + for project_name in project_name_list: + if project_basename.startswith(f"output-{project_name}-"): + project_list = result_dir_map.get(project_name, []) + project_list.append(os.path.join(result_dir, project_basename)) + result_dir_map[project_name] = project_list + break + + return result_dir_map + + +def 
copy_success_ofg_autogens( + result_dir: str, destination: str, project_name_list: list[str], oss_fuzz_path: str +) -> None: + """Copies the success harnesses for each projects to destination.""" + result_dir_map = group_result_dir_by_project_name( + os.path.abspath(result_dir), project_name_list + ) + + for project_name, project_dir_list in result_dir_map.items(): + logger.info("Handling project %s", project_name) + destination_dir = os.path.join(destination, project_name) + if not os.path.isdir(destination_dir): + os.mkdir(destination_dir) + + top_harnesses = retrieve_success_harness_name(project_dir_list) + logger.info( + "Found %d success harnesses for project %s", + len(top_harnesses), + project_name, + ) + for name in top_harnesses: + src_dir = os.path.join(oss_fuzz_path, "projects", name) + dst_dir = os.path.join(destination_dir, name.rsplit("-", 1)[0]) + shutil.copytree(src_dir, dst_dir) + + if len(os.listdir(destination_dir)) == 0: + logger.warning("No success target for project %s", project_name) + shutil.rmtree(destination_dir) def parse_commandline(): - """Parse commandline.""" - parser = argparse.ArgumentParser() - parser.add_argument('--result-dir', - '-r', - help='Results created by OFG.', - type=str) - parser.add_argument( - '--destination-dir', - '-d', - help='Folder with projects generated by from-scratch OFG.', - type=str) - parser.add_argument( - '--project-name', - '-p', - help='A comma separated string of all target project name.', - type=str) - parser.add_argument('--oss-fuzz-path', - '-o', - help='Path of the OSS-Fuzz used by OFG.', - type=str) - return parser.parse_args() + """Parse commandline.""" + parser = argparse.ArgumentParser() + parser.add_argument("--result-dir", "-r", help="Results created by OFG.", type=str) + parser.add_argument( + "--destination-dir", + "-d", + help="Folder with projects generated by from-scratch OFG.", + type=str, + ) + parser.add_argument( + "--project-name", + "-p", + help="A comma separated string of all target project name.", + type=str, + ) + parser.add_argument( + "--oss-fuzz-path", "-o", help="Path of the OSS-Fuzz used by OFG.", type=str + ) + return parser.parse_args() def main(): - """CLI entrypoint.""" - logging.basicConfig(level=logging.INFO, format=LOG_FMT) + """CLI entrypoint.""" + logging.basicConfig(level=logging.INFO, format=LOG_FMT) - args = parse_commandline() - copy_success_ofg_autogens(args.result_dir, args.destination_dir, - args.project_name, args.oss_fuzz_path) + args = parse_commandline() + copy_success_ofg_autogens( + args.result_dir, args.destination_dir, args.project_name, args.oss_fuzz_path + ) if __name__ == "__main__": - main() + main() diff --git a/experimental/jvm/utils.py b/experimental/jvm/utils.py index f9207895b4..55496d493f 100644 --- a/experimental/jvm/utils.py +++ b/experimental/jvm/utils.py @@ -15,197 +15,198 @@ ############################################################################### """Provides a set of utils for oss-fuzz-gen on new Java projects integration""" -import sys - -sys.path.append('../../') - import logging import os import subprocess +import sys from typing import Optional from urllib3.util import parse_url from experimental.jvm import constants, oss_fuzz_templates +sys.path.append("../../") + + logger = logging.getLogger(__name__) # Project preparation utils ########################### def git_clone_project(github_url: str, destination: str) -> bool: - """Clone project from github url to destination""" - cmd = ['git clone', github_url, destination] - try: - 
subprocess.check_call(" ".join(cmd), - shell=True, - timeout=600, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL) - except subprocess.TimeoutExpired: - return False - except subprocess.CalledProcessError: - return False - return True + """Clone project from github url to destination""" + cmd = ["git clone", github_url, destination] + try: + subprocess.check_call( + " ".join(cmd), + shell=True, + timeout=600, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + except subprocess.TimeoutExpired: + return False + except subprocess.CalledProcessError: + return False + return True def get_project_name(github_url: str) -> Optional[str]: - """Get project name by simplify github url""" - # HTTPS Type - # https://github.com/{user}/{proj_name} or https://github.com/{user}/{proj_name}.git - # or - # SSH Type - # git@github.com:{user}/{proj_name} or git@github.com:{user}/{proj_name}.git - - # Remove the .git suffix - if github_url.endswith('.git'): - github_url = github_url[:-4] - - if github_url.startswith('https://'): - # Validate url for HTTPS type - parsed_url = parse_url(github_url) - host = parsed_url.host - path = parsed_url.path - if path and host == 'github.com' and len(path.split('/')) == 3: - return path.split('/')[2] - elif github_url.startswith('git@github.com:'): - # Validate url for SSH type - path = github_url.split('/') - if len(path) == 2: - return path[1] - - # Malformed or invalid github url - return None + """Get project name by simplify github url""" + # HTTPS Type + # https://github.com/{user}/{proj_name} or https://github.com/{user}/{proj_name}.git + # or + # SSH Type + # git@github.com:{user}/{proj_name} or git@github.com:{user}/{proj_name}.git + + # Remove the .git suffix + if github_url.endswith(".git"): + github_url = github_url[:-4] + + if github_url.startswith("https://"): + # Validate url for HTTPS type + parsed_url = parse_url(github_url) + host = parsed_url.host + path = parsed_url.path + if path and host == "github.com" and len(path.split("/")) == 3: + return path.split("/")[2] + elif github_url.startswith("git@github.com:"): + # Validate url for SSH type + path = github_url.split("/") + if len(path) == 2: + return path[1] + + # Malformed or invalid github url + return None def prepare_base_files(base_dir: str, project_name: str, url: str) -> bool: - """Prepare OSS-Fuzz base files for Java project fuzzing""" + """Prepare OSS-Fuzz base files for Java project fuzzing""" - # Determine build type and build directory for the project - build_type, version = _find_project_build_type(os.path.join(base_dir, "proj"), - project_name) + # Determine build type and build directory for the project + build_type, version = _find_project_build_type( + os.path.join(base_dir, "proj"), project_name + ) - # Preapre build.sh and Dockerfile - build_file = _get_build_file(build_type) - docker_file = _get_docker_file(build_type, version, url) - if not docker_file or not build_file: - return False + # Preapre build.sh and Dockerfile + build_file = _get_build_file(build_type) + docker_file = _get_docker_file(build_type, version, url) + if not docker_file or not build_file: + return False - try: - with open(os.path.join(base_dir, 'build.sh'), 'w') as f: - f.write(build_file) + try: + with open(os.path.join(base_dir, "build.sh"), "w") as f: + f.write(build_file) - with open(os.path.join(base_dir, 'Dockerfile'), 'w') as f: - f.write(docker_file) + with open(os.path.join(base_dir, "Dockerfile"), "w") as f: + f.write(docker_file) - with open(os.path.join(base_dir, 
'project.yaml'), 'w') as f: - f.write(oss_fuzz_templates.YAML_JAVA.replace("{TARGET_REPO}", url)) + with open(os.path.join(base_dir, "project.yaml"), "w") as f: + f.write(oss_fuzz_templates.YAML_JAVA.replace("{TARGET_REPO}", url)) - with open(os.path.join(base_dir, 'Fuzz.java'), 'w') as f: - f.write(oss_fuzz_templates.FUZZER_JAVA) - except: - return False + with open(os.path.join(base_dir, "Fuzz.java"), "w") as f: + f.write(oss_fuzz_templates.FUZZER_JAVA) + except: + return False - return True + return True def _get_build_file(build_type: str) -> str: - """Prepare build.sh content for this project.""" + """Prepare build.sh content for this project.""" - if build_type == 'ant': - build_file = oss_fuzz_templates.BUILD_JAVA_ANT - elif build_type == 'gradle': - build_file = oss_fuzz_templates.BUILD_JAVA_GRADLE - elif build_type == 'maven': - build_file = oss_fuzz_templates.BUILD_JAVA_MAVEN - else: - return '' + if build_type == "ant": + build_file = oss_fuzz_templates.BUILD_JAVA_ANT + elif build_type == "gradle": + build_file = oss_fuzz_templates.BUILD_JAVA_GRADLE + elif build_type == "maven": + build_file = oss_fuzz_templates.BUILD_JAVA_MAVEN + else: + return "" - return build_file + oss_fuzz_templates.BUILD_JAVA_BASE + return build_file + oss_fuzz_templates.BUILD_JAVA_BASE def _get_docker_file(build_type: str, version: str, url: str) -> str: - """Prepare build.sh content for this project.""" + """Prepare build.sh content for this project.""" - if build_type == 'ant': - docker_file = oss_fuzz_templates.DOCKERFILE_JAVA_ANT - docker_file = docker_file.replace('{ANT_URL}', constants.ANT_URL) - elif build_type == 'gradle': - docker_file = oss_fuzz_templates.DOCKERFILE_JAVA_GRADLE - docker_file = docker_file.replace('{GRADLE_URL}', constants.GRADLE_URL) - elif build_type == 'maven': - # Check for invalid version - if version not in constants.MAVEN_URL: - return '' + if build_type == "ant": + docker_file = oss_fuzz_templates.DOCKERFILE_JAVA_ANT + docker_file = docker_file.replace("{ANT_URL}", constants.ANT_URL) + elif build_type == "gradle": + docker_file = oss_fuzz_templates.DOCKERFILE_JAVA_GRADLE + docker_file = docker_file.replace("{GRADLE_URL}", constants.GRADLE_URL) + elif build_type == "maven": + # Check for invalid version + if version not in constants.MAVEN_URL: + return "" - docker_file = oss_fuzz_templates.DOCKERFILE_JAVA_MAVEN - docker_file = docker_file.replace('{MAVEN_URL}', - constants.MAVEN_URL[version]) - docker_file = docker_file.replace('{MAVEN_VERSION}', version) - else: - return '' + docker_file = oss_fuzz_templates.DOCKERFILE_JAVA_MAVEN + docker_file = docker_file.replace("{MAVEN_URL}", constants.MAVEN_URL[version]) + docker_file = docker_file.replace("{MAVEN_VERSION}", version) + else: + return "" - docker_file = docker_file.replace('{PROTO_URL}', constants.PROTO_URL) - docker_file = docker_file.replace('{JDK15_URL}', constants.JDK15_URL) - docker_file = docker_file.replace('{TARGET_REPO}', url) + docker_file = docker_file.replace("{PROTO_URL}", constants.PROTO_URL) + docker_file = docker_file.replace("{JDK15_URL}", constants.JDK15_URL) + docker_file = docker_file.replace("{TARGET_REPO}", url) - return docker_file + return docker_file # Java Project discovery utils ############################## def _find_dir_build_type(project_dir: str) -> tuple[str, str]: - """Determine the java build project type of the directory""" + """Determine the java build project type of the directory""" - if os.path.exists(os.path.join(project_dir, 'pom.xml')): - return 'maven', 
_get_maven_version(project_dir) + if os.path.exists(os.path.join(project_dir, "pom.xml")): + return "maven", _get_maven_version(project_dir) - if os.path.exists(os.path.join( - project_dir, 'build.gradle')) or os.path.exists( - os.path.join(project_dir, 'build.gradle.kts')): - return 'gradle', '' + if os.path.exists(os.path.join(project_dir, "build.gradle")) or os.path.exists( + os.path.join(project_dir, "build.gradle.kts") + ): + return "gradle", "" - if os.path.exists(os.path.join(project_dir, 'build.xml')): - return 'ant', '' + if os.path.exists(os.path.join(project_dir, "build.xml")): + return "ant", "" - return '', '' + return "", "" def _get_maven_version(base_dir: str) -> str: - """Prepare Maven specific logic for build.sh.""" - with open(os.path.join(base_dir, 'pom.xml'), 'r') as f: - data = f.read() - - # Determine if the project requires older JVM - if '1.5' in data or '1.5' in data: - return '3.1.1' + """Prepare Maven specific logic for build.sh.""" + with open(os.path.join(base_dir, "pom.xml"), "r") as f: + data = f.read() - if '1.6' in data or '1.6' in data: - return '3.2.5' + # Determine if the project requires older JVM + if "1.5" in data or "1.5" in data: + return "3.1.1" - return '3.9.2' + if "1.6" in data or "1.6" in data: + return "3.2.5" + return "3.9.2" -def _find_project_build_type(project_dir: str, - proj_name: str) -> tuple[str, str]: - """Search for base project directory to detect project build type""" - # Search for current directory first - project_build_data = _find_dir_build_type(project_dir) - if project_build_data: - return project_build_data - # Search for sub directory with name same as project name - for subdir in os.listdir(project_dir): - if os.path.isdir(os.path.join(project_dir, subdir)) and subdir == proj_name: - target_dir = os.path.join(project_dir, subdir) - project_build_data = _find_dir_build_type(target_dir) +def _find_project_build_type(project_dir: str, proj_name: str) -> tuple[str, str]: + """Search for base project directory to detect project build type""" + # Search for current directory first + project_build_data = _find_dir_build_type(project_dir) if project_build_data: - return project_build_data - - # Recursively look for subdirectory that contains build property file - for root, _, _ in os.walk(project_dir): - project_build_data = _find_dir_build_type(root) - if project_build_data: - return project_build_data - - return '', '' + return project_build_data + + # Search for sub directory with name same as project name + for subdir in os.listdir(project_dir): + if os.path.isdir(os.path.join(project_dir, subdir)) and subdir == proj_name: + target_dir = os.path.join(project_dir, subdir) + project_build_data = _find_dir_build_type(target_dir) + if project_build_data: + return project_build_data + + # Recursively look for subdirectory that contains build property file + for root, _, _ in os.walk(project_dir): + project_build_data = _find_dir_build_type(root) + if project_build_data: + return project_build_data + + return "", "" diff --git a/experimental/manual/oss_fuzz_vuln_prompt.py b/experimental/manual/oss_fuzz_vuln_prompt.py index 8857df3f54..0cd74350d2 100644 --- a/experimental/manual/oss_fuzz_vuln_prompt.py +++ b/experimental/manual/oss_fuzz_vuln_prompt.py @@ -75,154 +75,163 @@ * Thoroughly analyze the patch to verify that it eliminates the crash and prevents exploitation. 
""" PROMPT_MAX_LENGTH = 1048576 -PROJECTS_DIR = os.path.join('oss-fuzz', 'projects') -STACK_FRAME_START_REGEX = re.compile(r'\s*#\d+\s+0x[0-9A-Fa-f]+\s+') +PROJECTS_DIR = os.path.join("oss-fuzz", "projects") +STACK_FRAME_START_REGEX = re.compile(r"\s*#\d+\s+0x[0-9A-Fa-f]+\s+") STACK_FRAME_PATH_LINE_REGEX = re.compile( - r'(?<=\[|\(|\s)([a-zA-Z/.][^\s]*?)\s*(:|@)\s*(\d+)(?=\]$|\)$|:\d+$|$)') -EXCLUDED_FILE_PATH_SUBSTRINGS = ('/compiler-rt/', '/glibc-', '/usr/') + r"(?<=\[|\(|\s)([a-zA-Z/.][^\s]*?)\s*(:|@)\s*(\d+)(?=\]$|\)$|:\d+$|$)" +) +EXCLUDED_FILE_PATH_SUBSTRINGS = ("/compiler-rt/", "/glibc-", "/usr/") def get_local_repo_path(repo_url): - """Returns the local path of the repository.""" - local_repo_name = repo_url.split('/')[-1] - return os.path.join(PROJECTS_DIR, local_repo_name) + """Returns the local path of the repository.""" + local_repo_name = repo_url.split("/")[-1] + return os.path.join(PROJECTS_DIR, local_repo_name) def get_git_commit_range(regression_range): - """Converts a regression range to a git commit range.""" - # If the range is a single commit, return the range as previous commit:commit. - if not ':' in regression_range and not '..' in regression_range: - return f"{regression_range}~..{regression_range}" + """Converts a regression range to a git commit range.""" + # If the range is a single commit, return the range as previous commit:commit. + if not ":" in regression_range and not ".." in regression_range: + return f"{regression_range}~..{regression_range}" - return regression_range.replace(':', '..') + return regression_range.replace(":", "..") def get_changeset_diff(repo_url, regression_range): - """Fetches the code diff for a given commit range in a Git repository.""" - local_repo_path = get_local_repo_path(repo_url) - - try: - if not os.path.exists(local_repo_path): - subprocess.run(["git", "clone", repo_url, local_repo_path], - stdout=subprocess.DEVNULL, - check=True) - else: - subprocess.run(["git", "pull"], - cwd=local_repo_path, - stdout=subprocess.DEVNULL, - check=True) - except Exception as e: - raise RuntimeError(f"Error cloning/pulling repository {repo_url}: {e}") - - try: - repo = Repo(local_repo_path) - except InvalidGitRepositoryError: - raise ValueError(f"Invalid Git repository path: {local_repo_path}") - - try: - diff = repo.git.diff(get_git_commit_range(regression_range)) - return diff.encode('utf-8', 'replace').decode('utf-8') - except Exception as e: - raise RuntimeError(f"Error retrieving changeset diff: {e}") + """Fetches the code diff for a given commit range in a Git repository.""" + local_repo_path = get_local_repo_path(repo_url) + + try: + if not os.path.exists(local_repo_path): + subprocess.run( + ["git", "clone", repo_url, local_repo_path], + stdout=subprocess.DEVNULL, + check=True, + ) + else: + subprocess.run( + ["git", "pull"], + cwd=local_repo_path, + stdout=subprocess.DEVNULL, + check=True, + ) + except Exception as e: + raise RuntimeError(f"Error cloning/pulling repository {repo_url}: {e}") + + try: + repo = Repo(local_repo_path) + except InvalidGitRepositoryError: + raise ValueError(f"Invalid Git repository path: {local_repo_path}") + + try: + diff = repo.git.diff(get_git_commit_range(regression_range)) + return diff.encode("utf-8", "replace").decode("utf-8") + except Exception as e: + raise RuntimeError(f"Error retrieving changeset diff: {e}") def find_file(file_path, commit): - """Finds a file in a git commit tree.""" - # Check if the file exists in the git commit tree. 
- for tree in commit.tree.traverse(): - if tree.path == file_path: - return file_path - - # Check if another file with same name exists in the commit tree. - filename = os.path.basename(file_path) - for tree in commit.tree.traverse(): - if os.path.basename(tree.path) == filename: - return str(tree.path) - - # File not found. - return None - - -def get_file_content(repo_url, crash_revision, file_path): - """Fetches the content of a file in a Git repository.""" - local_repo_path = get_local_repo_path(repo_url) - local_file_path = file_path[:].removeprefix('/src/') - local_file_path = local_file_path.split('/', 1)[-1] - - try: - repo = Repo(local_repo_path) - except InvalidGitRepositoryError: - raise ValueError(f"Invalid git repository path: {local_repo_path}") - - try: - commit = repo.commit(crash_revision) - except BadName: - print(f"Error: Commit hash '{crash_revision}' not found in repository.") + """Finds a file in a git commit tree.""" + # Check if the file exists in the git commit tree. + for tree in commit.tree.traverse(): + if tree.path == file_path: + return file_path + + # Check if another file with same name exists in the commit tree. + filename = os.path.basename(file_path) + for tree in commit.tree.traverse(): + if os.path.basename(tree.path) == filename: + return str(tree.path) + + # File not found. return None - local_file_path = find_file(local_file_path, commit) - if not local_file_path: - print(f"Error: '{file_path}' not found in repository.") - return None - try: - return commit.tree[local_file_path].data_stream.read().decode('utf-8') - except Exception: - print(f"Error: '{file_path}' not found at commit '{crash_revision}'.") - return None +def get_file_content(repo_url, crash_revision, file_path): + """Fetches the content of a file in a Git repository.""" + local_repo_path = get_local_repo_path(repo_url) + local_file_path = file_path[:].removeprefix("/src/") + local_file_path = local_file_path.split("/", 1)[-1] + + try: + repo = Repo(local_repo_path) + except InvalidGitRepositoryError: + raise ValueError(f"Invalid git repository path: {local_repo_path}") + + try: + commit = repo.commit(crash_revision) + except BadName: + print(f"Error: Commit hash '{crash_revision}' not found in repository.") + return None + + local_file_path = find_file(local_file_path, commit) + if not local_file_path: + print(f"Error: '{file_path}' not found in repository.") + return None + + try: + return commit.tree[local_file_path].data_stream.read().decode("utf-8") + except Exception: + print(f"Error: '{file_path}' not found at commit '{crash_revision}'.") + return None if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Generate a prompt for a security vulnerability.") - parser.add_argument("--repo_url", help="Path to the GitHub repository") - parser.add_argument( - "--regression_range", - help="Commit range in the format 'start_commit:end_commit'") - parser.add_argument("--crash_revision", - help="Revision where the crash occurred") - parser.add_argument("--crash_stacktrace", - help="File containing the crash stacktrace") - - args = parser.parse_args() - os.makedirs(PROJECTS_DIR, exist_ok=True) - - with open(args.crash_stacktrace) as file_handle: - stacktrace = file_handle.read() - changeset_diff = get_changeset_diff(args.repo_url, args.regression_range) - source_code_content = "" - found_sanitizer_error = False - parsed_file_paths = set() - for line in stacktrace.splitlines(): - if not STACK_FRAME_START_REGEX.match(line): - continue - - match = 
STACK_FRAME_PATH_LINE_REGEX.search(line) - if not match: - continue - - file_path = match.group(1) - if file_path in parsed_file_paths: - continue - if any( - substring in file_path for substring in EXCLUDED_FILE_PATH_SUBSTRINGS): - continue - - file_content = get_file_content(args.repo_url, args.crash_revision, - file_path) - if not file_content: - continue - - source_code_content += ( - f'**FILE CONTENT: {file_path} **\n{file_content}\n**FILE CONTENT END**\n' + parser = argparse.ArgumentParser( + description="Generate a prompt for a security vulnerability." + ) + parser.add_argument("--repo_url", help="Path to the GitHub repository") + parser.add_argument( + "--regression_range", + help="Commit range in the format 'start_commit:end_commit'", + ) + parser.add_argument("--crash_revision", help="Revision where the crash occurred") + parser.add_argument( + "--crash_stacktrace", help="File containing the crash stacktrace" + ) + + args = parser.parse_args() + os.makedirs(PROJECTS_DIR, exist_ok=True) + + with open(args.crash_stacktrace) as file_handle: + stacktrace = file_handle.read() + changeset_diff = get_changeset_diff(args.repo_url, args.regression_range) + source_code_content = "" + found_sanitizer_error = False + parsed_file_paths = set() + for line in stacktrace.splitlines(): + if not STACK_FRAME_START_REGEX.match(line): + continue + + match = STACK_FRAME_PATH_LINE_REGEX.search(line) + if not match: + continue + + file_path = match.group(1) + if file_path in parsed_file_paths: + continue + if any(substring in file_path for substring in EXCLUDED_FILE_PATH_SUBSTRINGS): + continue + + file_content = get_file_content(args.repo_url, args.crash_revision, file_path) + if not file_content: + continue + + source_code_content += ( + f"**FILE CONTENT: {file_path} **\n{file_content}\n**FILE CONTENT END**\n" + ) + parsed_file_paths.add(file_path) + + source_code_content = source_code_content[ + : PROMPT_MAX_LENGTH + - len(PROMPT_TEMPLATE) + - len(stacktrace) + - len(changeset_diff) + ] + prompt = PROMPT_TEMPLATE.format( + stacktrace=stacktrace, + changeset_diff=changeset_diff, + source_code_content=source_code_content, ) - parsed_file_paths.add(file_path) - - source_code_content = source_code_content[:PROMPT_MAX_LENGTH - - len(PROMPT_TEMPLATE) - - len(stacktrace) - - len(changeset_diff)] - prompt = PROMPT_TEMPLATE.format(stacktrace=stacktrace, - changeset_diff=changeset_diff, - source_code_content=source_code_content) - print(prompt) + print(prompt) diff --git a/experimental/manual/prompter.py b/experimental/manual/prompter.py index 2deebb6a53..dd30fe5489 100644 --- a/experimental/manual/prompter.py +++ b/experimental/manual/prompter.py @@ -32,59 +32,62 @@ def parse_args() -> argparse.Namespace: - """Parses command line arguments.""" - parser = argparse.ArgumentParser( - description='Run all experiments that evaluates all target functions.') - parser.add_argument('-n', - '--num-samples', - type=int, - default=NUM_SAMPLES, - help='The number of samples to request from LLM.') - parser.add_argument( - '-t', - '--temperature', - type=float, - default=TEMPERATURE, - help=('A value presenting the variety of the targets generated by LLM. 
' - 'It should be within [0,2] for Gemini-1.5 models and [0,1] for ' - 'Gemini-1.0 models')) - parser.add_argument('-l', - '--model', - default=models.DefaultModel.name, - help=('Models available: ' - f'{", ".join(models.LLM.all_llm_names())}')) - parser.add_argument('-p', - '--prompt', - help='Prompt file for LLM.', - required=True) - parser.add_argument('-r', - '--response-dir', - default='./responses', - help='LLM response directory.') - return parser.parse_args() + """Parses command line arguments.""" + parser = argparse.ArgumentParser( + description="Run all experiments that evaluates all target functions." + ) + parser.add_argument( + "-n", + "--num-samples", + type=int, + default=NUM_SAMPLES, + help="The number of samples to request from LLM.", + ) + parser.add_argument( + "-t", + "--temperature", + type=float, + default=TEMPERATURE, + help=( + "A value presenting the variety of the targets generated by LLM. " + "It should be within [0,2] for Gemini-1.5 models and [0,1] for " + "Gemini-1.0 models" + ), + ) + parser.add_argument( + "-l", + "--model", + default=models.DefaultModel.name, + help=("Models available: " f'{", ".join(models.LLM.all_llm_names())}'), + ) + parser.add_argument("-p", "--prompt", help="Prompt file for LLM.", required=True) + parser.add_argument( + "-r", "--response-dir", default="./responses", help="LLM response directory." + ) + return parser.parse_args() def setup_model() -> models.LLM: - return models.LLM.setup( - ai_binary='', - name=args.model, - max_tokens=MAX_TOKENS, - num_samples=args.num_samples, - temperature=args.temperature, - ) + return models.LLM.setup( + ai_binary="", + name=args.model, + max_tokens=MAX_TOKENS, + num_samples=args.num_samples, + temperature=args.temperature, + ) def construct_prompt() -> prompts.Prompt: - with open(args.prompt, 'r') as prompt_file: - content = prompt_file.read() - prompt = model.prompt_type()() - prompt.add_problem(content) - return prompt + with open(args.prompt, "r") as prompt_file: + content = prompt_file.read() + prompt = model.prompt_type()() + prompt.add_problem(content) + return prompt if __name__ == "__main__": - args = parse_args() - model = setup_model() - prompt = construct_prompt() - os.makedirs(args.response_dir, exist_ok=True) - model.query_llm(prompt, response_dir=args.response_dir) + args = parse_args() + model = setup_model() + prompt = construct_prompt() + os.makedirs(args.response_dir, exist_ok=True) + model.query_llm(prompt, response_dir=args.response_dir) diff --git a/helper/diff_target.py b/helper/diff_target.py index b78441f626..a87e9b3c84 100644 --- a/helper/diff_target.py +++ b/helper/diff_target.py @@ -25,60 +25,61 @@ def load_yaml(file_path): - """Load a YAML file and return its contents.""" - try: - with open(file_path, 'r') as f: - return yaml.safe_load(f) - except Exception as e: - print(f"Error loading YAML file {file_path}: {e}") - return None + """Load a YAML file and return its contents.""" + try: + with open(file_path, "r") as f: + return yaml.safe_load(f) + except Exception as e: + print(f"Error loading YAML file {file_path}: {e}") + return None def save_yaml(file_path, data): - """Save a dictionary to a YAML file.""" - try: - with open(file_path, 'w') as f: - yaml.safe_dump(data, f) - except Exception as e: - print(f"Error saving YAML file {file_path}: {e}") + """Save a dictionary to a YAML file.""" + try: + with open(file_path, "w") as f: + yaml.safe_dump(data, f) + except Exception as e: + print(f"Error saving YAML file {file_path}: {e}") def overwrite_target_name(dir_a, 
dir_b): - """Overwrite target_name in A if target_path is the same and target_name differs.""" - common_files = set(os.listdir(dir_a)) & set(os.listdir(dir_b)) - updated_files = [] + """Overwrite target_name in A if target_path is the same and target_name differs.""" + common_files = set(os.listdir(dir_a)) & set(os.listdir(dir_b)) + updated_files = [] - for file_name in common_files: - if not file_name.endswith('.yaml'): - continue + for file_name in common_files: + if not file_name.endswith(".yaml"): + continue - file_a = os.path.join(dir_a, file_name) - file_b = os.path.join(dir_b, file_name) + file_a = os.path.join(dir_a, file_name) + file_b = os.path.join(dir_b, file_name) - data_a = load_yaml(file_a) - data_b = load_yaml(file_b) + data_a = load_yaml(file_a) + data_b = load_yaml(file_b) - if data_a and data_b: - if data_a.get('target_path') == data_b.get('target_path'): - if data_a.get('target_name') != data_b.get('target_name'): - data_a['target_name'] = data_b['target_name'] - save_yaml(file_a, data_a) - updated_files.append(file_name) + if data_a and data_b: + if data_a.get("target_path") == data_b.get("target_path"): + if data_a.get("target_name") != data_b.get("target_name"): + data_a["target_name"] = data_b["target_name"] + save_yaml(file_a, data_a) + updated_files.append(file_name) - return updated_files + return updated_files if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Update target_name in A if target_path matches B.") - parser.add_argument("dir_a", help="Path to directory A (to be updated)") - parser.add_argument("dir_b", help="Path to directory B (source of updates)") - args = parser.parse_args() + parser = argparse.ArgumentParser( + description="Update target_name in A if target_path matches B." + ) + parser.add_argument("dir_a", help="Path to directory A (to be updated)") + parser.add_argument("dir_b", help="Path to directory B (source of updates)") + args = parser.parse_args() - updated_files = overwrite_target_name(args.dir_a, args.dir_b) - if updated_files: - print("The following files in A had target_name updated from B:") - for file_name in updated_files: - print(file_name) - else: - print("No updates were necessary.") + updated_files = overwrite_target_name(args.dir_a, args.dir_b) + if updated_files: + print("The following files in A had target_name updated from B:") + for file_name in updated_files: + print(file_name) + else: + print("No updates were necessary.") diff --git a/helper/result_string_search.py b/helper/result_string_search.py index b6c680202b..0385df5134 100644 --- a/helper/result_string_search.py +++ b/helper/result_string_search.py @@ -32,126 +32,144 @@ def _parse_args() -> argparse.Namespace: - """Parses arguments.""" - parser = argparse.ArgumentParser( - description= - 'Search for all benchmark that contains the in ') - parser.add_argument('-is', - '--include-strings', - type=str, - nargs='+', - default=[], - help='The string to include in result.') - parser.add_argument('-es', - '--exclude-strings', - type=str, - nargs='+', - default=[], - help='The string to exclude in result.') - parser.add_argument('-r', - '--result', - type=str, - required=True, - help='The root path to the result directory.') - parser.add_argument( - '-b', - '--sub', - type=str, - default='', - help=('The subdirectory to search in each output-* directories. 
' - 'Search in all subdirectories by default')) - parser.add_argument( - '-u', - '--url', - type=str, - default='', - help='Optional url template to web report for easy access.') - - args = parser.parse_args() - assert os.path.isdir(args.result), '--result must be an existing directory.' - - output_dirs = os.listdir(args.result) - assert any( - os.path.isdir(os.path.join(args.result, d, args.sub)) for d in output_dirs - ), ('--sub must be a directory in output-* directories under \n' - 'E.g. fixed_targets, logs, raw_targets, status.') - - return args - - -def find_in_file(include_lines: list[str], exclude_lines: list[str], - file_path: str) -> bool: - """Returns True if the file_path matches the in/exclude strings.""" - with open(file_path) as f: - lines = f.readlines() - - for line in lines: - if any(exclude_line in line for exclude_line in exclude_lines): - return False - - for include_line in include_lines: - if not any(include_line in line for line in lines): - return False - - # logging.info('Matched in %s', file_path) - return True - - -def find_in_dir(include_lines: list[str], exclude_lines: list[str], - file_paths: list[str]) -> list[str]: - """Returns files in |file_paths| that contain |include_lines|.""" - # Caveat: With support for multiline search in potentially large files - # (e.g. log files), this function does not search for the exact substring - # in the file containing the new line char. Instead, it returns True when: - # other_text other_text - # other_text other_text - # can be found in the file. - return [ - file_path for file_path in file_paths - if find_in_file(include_lines, exclude_lines, file_path) - ] + """Parses arguments.""" + parser = argparse.ArgumentParser( + description="Search for all benchmark that contains the in " + ) + parser.add_argument( + "-is", + "--include-strings", + type=str, + nargs="+", + default=[], + help="The string to include in result.", + ) + parser.add_argument( + "-es", + "--exclude-strings", + type=str, + nargs="+", + default=[], + help="The string to exclude in result.", + ) + parser.add_argument( + "-r", + "--result", + type=str, + required=True, + help="The root path to the result directory.", + ) + parser.add_argument( + "-b", + "--sub", + type=str, + default="", + help=( + "The subdirectory to search in each output-* directories. " + "Search in all subdirectories by default" + ), + ) + parser.add_argument( + "-u", + "--url", + type=str, + default="", + help="Optional url template to web report for easy access.", + ) + + args = parser.parse_args() + assert os.path.isdir(args.result), "--result must be an existing directory." + + output_dirs = os.listdir(args.result) + assert any( + os.path.isdir(os.path.join(args.result, d, args.sub)) for d in output_dirs + ), ( + "--sub must be a directory in output-* directories under \n" + "E.g. fixed_targets, logs, raw_targets, status." 
+ ) + + return args + + +def find_in_file( + include_lines: list[str], exclude_lines: list[str], file_path: str +) -> bool: + """Returns True if the file_path matches the in/exclude strings.""" + with open(file_path) as f: + lines = f.readlines() + + for line in lines: + if any(exclude_line in line for exclude_line in exclude_lines): + return False + + for include_line in include_lines: + if not any(include_line in line for line in lines): + return False + + # logging.info('Matched in %s', file_path) + return True + + +def find_in_dir( + include_lines: list[str], exclude_lines: list[str], file_paths: list[str] +) -> list[str]: + """Returns files in |file_paths| that contain |include_lines|.""" + # Caveat: With support for multiline search in potentially large files + # (e.g. log files), this function does not search for the exact substring + # in the file containing the new line char. Instead, it returns True when: + # other_text other_text + # other_text other_text + # can be found in the file. + return [ + file_path + for file_path in file_paths + if find_in_file(include_lines, exclude_lines, file_path) + ] def main(): - args = _parse_args() - result_dir = args.result - include_lines = args.include_strings - exclude_lines = args.exclude_strings - sub = args.sub - hits = [] - - logging.info( - 'Search files including string:\n\t%s\nBut exclude string:\n\t%s\n', - '\n\t'.join(include_lines), '\n\t'.join(exclude_lines)) - # Iterates through all output-*/ - for output_dir in sorted(os.listdir(result_dir)): - if not os.path.isdir(os.path.join(result_dir, output_dir)): - continue - - # Iterates through all subdirectories. - for path, sub_dir, files in os.walk( - os.path.join(result_dir, output_dir, sub)): - # Except corpora. - if 'corpora' in sub_dir: - sub_dir.remove('corpora') - - # Iterates through all files in directory. - if file_paths := find_in_dir( - include_lines, exclude_lines, - [os.path.join(path, file_name) for file_name in files]): - hits.extend(file_paths) - break - - hits.sort() - url = args.url - if url: - benchmark_report = '\n'.join([f'{url}/{hit}' for hit in hits]) - else: - benchmark_report = '\n'.join(hits) - logging.info('Found Report URLs:') - print(benchmark_report) - - -if __name__ == '__main__': - logging.basicConfig(level=logging.INFO) - main() + args = _parse_args() + result_dir = args.result + include_lines = args.include_strings + exclude_lines = args.exclude_strings + sub = args.sub + hits = [] + + logging.info( + "Search files including string:\n\t%s\nBut exclude string:\n\t%s\n", + "\n\t".join(include_lines), + "\n\t".join(exclude_lines), + ) + # Iterates through all output-*/ + for output_dir in sorted(os.listdir(result_dir)): + if not os.path.isdir(os.path.join(result_dir, output_dir)): + continue + + # Iterates through all subdirectories. + for path, sub_dir, files in os.walk(os.path.join(result_dir, output_dir, sub)): + # Except corpora. + if "corpora" in sub_dir: + sub_dir.remove("corpora") + + # Iterates through all files in directory. 
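A minimal sketch (not part of the diff) of the per-file rule that find_in_file and find_in_dir above implement: a file is a hit when every --include-strings entry appears on some line and no line contains an --exclude-strings entry; the include strings need not share a line. demo_matches and the sample lines are hypothetical:

def demo_matches(lines, include_strings, exclude_strings):
    """Mirror the exclude-first, then include-anywhere check used above."""
    if any(ex in line for line in lines for ex in exclude_strings):
        return False
    return all(any(inc in line for line in lines) for inc in include_strings)

assert demo_matches(["compiles: True", "coverage: 0.3"], ["True", "coverage"], [])
assert not demo_matches(["compiles: False"], ["compiles"], ["False"])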
+ if file_paths := find_in_dir( + include_lines, + exclude_lines, + [os.path.join(path, file_name) for file_name in files], + ): + hits.extend(file_paths) + break + + hits.sort() + url = args.url + if url: + benchmark_report = "\n".join([f"{url}/{hit}" for hit in hits]) + else: + benchmark_report = "\n".join(hits) + logging.info("Found Report URLs:") + print(benchmark_report) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + main() diff --git a/helper/update_comp_benchmarks.py b/helper/update_comp_benchmarks.py index 84903f1850..409bdedbfb 100644 --- a/helper/update_comp_benchmarks.py +++ b/helper/update_comp_benchmarks.py @@ -24,64 +24,67 @@ from experiment.benchmark import Benchmark -BENCHMARK_DIR = 'benchmark-sets' -SOURCE_SET = 'all' -TARGET_SET = 'comparison' +BENCHMARK_DIR = "benchmark-sets" +SOURCE_SET = "all" +TARGET_SET = "comparison" def parse_args() -> argparse.Namespace: - """parse arguments""" - parser = argparse.ArgumentParser( - description= - 'Updates all benchmark yamls in to match with .') - parser.add_argument( - '-s', - '--source', - type=str, - default=os.path.join(BENCHMARK_DIR, SOURCE_SET), - help='The source benchmark set used to update target set.') - parser.add_argument('-t', - '--target', - type=str, - default=os.path.join(BENCHMARK_DIR, TARGET_SET), - help='The target benchmark set to update.') + """parse arguments""" + parser = argparse.ArgumentParser( + description="Updates all benchmark yamls in to match with ." + ) + parser.add_argument( + "-s", + "--source", + type=str, + default=os.path.join(BENCHMARK_DIR, SOURCE_SET), + help="The source benchmark set used to update target set.", + ) + parser.add_argument( + "-t", + "--target", + type=str, + default=os.path.join(BENCHMARK_DIR, TARGET_SET), + help="The target benchmark set to update.", + ) - args = parser.parse_args() - assert os.path.isdir(args.target), '--target must be an existing directory.' - assert os.path.isdir(args.source), '--source must be an existing directory.' + args = parser.parse_args() + assert os.path.isdir(args.target), "--target must be an existing directory." + assert os.path.isdir(args.source), "--source must be an existing directory." - return args + return args def main(): - args = parse_args() - target_path = args.target - src_path = args.source + args = parse_args() + target_path = args.target + src_path = args.source - for file_name in os.listdir(target_path): - if not file_name.endswith('.yaml'): - continue + for file_name in os.listdir(target_path): + if not file_name.endswith(".yaml"): + continue - target_bms = Benchmark.from_yaml(os.path.join(target_path, file_name)) - try: - source_bms = Benchmark.from_yaml(os.path.join(src_path, file_name)) - except FileNotFoundError: - logging.error('%s is not found in %s', file_name, src_path) - continue + target_bms = Benchmark.from_yaml(os.path.join(target_path, file_name)) + try: + source_bms = Benchmark.from_yaml(os.path.join(src_path, file_name)) + except FileNotFoundError: + logging.error("%s is not found in %s", file_name, src_path) + continue - # Get raw name of the functions selected in target. - functions = [b.function_name for b in target_bms] - # Get the selected benchmarks from source. - selected_bms = [] - for b in source_bms: - if b.function_name in functions: - selected_bms.append(b) + # Get raw name of the functions selected in target. + functions = [b.function_name for b in target_bms] + # Get the selected benchmarks from source. 
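A minimal sketch (not part of the diff) of the selection performed just below, shown over plain dicts instead of Benchmark objects; the function names and dicts are hypothetical stand-ins:

target_function_names = ["png_read_info"]  # raw names taken from the target set
source_benchmarks = [
    {"function_name": "png_read_info"},
    {"function_name": "png_destroy_read_struct"},
]
selected = [b for b in source_benchmarks if b["function_name"] in target_function_names]
assert [b["function_name"] for b in selected] == ["png_read_info"]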
+ selected_bms = [] + for b in source_bms: + if b.function_name in functions: + selected_bms.append(b) - Benchmark.to_yaml(selected_bms, outdir=target_path) - logging.info('Updated %s', file_name) + Benchmark.to_yaml(selected_bms, outdir=target_path) + logging.info("Updated %s", file_name) -if __name__ == '__main__': - logging.basicConfig(level=logging.INFO) +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) - main() + main() diff --git a/llm_toolkit/code_fixer.py b/llm_toolkit/code_fixer.py index 6e1f307807..91cd522405 100755 --- a/llm_toolkit/code_fixer.py +++ b/llm_toolkit/code_fixer.py @@ -33,708 +33,753 @@ NO_MEMBER_ERROR_REGEX = r"error: no member named '.*' in '([^':]*):?.*'" FILE_NOT_FOUND_ERROR_REGEX = r"fatal error: '([^']*)' file not found" UNDEFINED_REF_ERROR_REGEX = r"undefined reference to `([^']*)'" -UNKNOWN_TYPE_ERROR = 'error: unknown type name' +UNKNOWN_TYPE_ERROR = "error: unknown type name" # The following strings identify errors when a C fuzz target attempts to use # FuzzedDataProvider. -FALSE_FUZZED_DATA_PROVIDER_ERROR = 'include/fuzzer/FuzzedDataProvider.h:16:10:' -FALSE_EXTERN_KEYWORD_ERROR = 'expected identifier or \'(\'\nextern "C"' -FDP_INCLUDE_STATEMENT = '#include ' +FALSE_FUZZED_DATA_PROVIDER_ERROR = "include/fuzzer/FuzzedDataProvider.h:16:10:" +FALSE_EXTERN_KEYWORD_ERROR = "expected identifier or '('\nextern \"C\"" +FDP_INCLUDE_STATEMENT = "#include " def parse_args(): - """Parses command line arguments.""" - argparser = argparse.ArgumentParser( - description='Fix the raw fuzz targets generated by LLM.') - argparser.add_argument( - '-t', - '--target-dir', - type=str, - default='./fixed_targets', - help='The directory to store all fixed LLM-generated targets.') - argparser.add_argument( - '-o', - '--intermediate-output-dir', - type=str, - default='./code_fix_output', - help=('The directory to store all intermediate output files (LLM prompt, ' - 'rawoutput).')) - argparser.add_argument('-p', - '--project', - type=str, - required=True, - help='The project name.') - argparser.add_argument('-f', - '--function', - type=str, - required=True, - help='The function name.') - argparser.add_argument('-l', - '--log', - type=str, - required=True, - help='The build log file containing the error to fix.') - - args = argparser.parse_args() - if args.target_dir and os.listdir(args.target_dir): - assert os.path.isdir( - args.target_dir - ), f'--target-dir must take an existing directory: {args.target_dir}.' - assert os.listdir( - args.target_dir - ), f'--target-dir must take a non-empty directory: {args.target_dir}.' - - os.makedirs(args.intermediate_output_dir, exist_ok=True) - - return args + """Parses command line arguments.""" + argparser = argparse.ArgumentParser( + description="Fix the raw fuzz targets generated by LLM." + ) + argparser.add_argument( + "-t", + "--target-dir", + type=str, + default="./fixed_targets", + help="The directory to store all fixed LLM-generated targets.", + ) + argparser.add_argument( + "-o", + "--intermediate-output-dir", + type=str, + default="./code_fix_output", + help=( + "The directory to store all intermediate output files (LLM prompt, " + "rawoutput)." + ), + ) + argparser.add_argument( + "-p", "--project", type=str, required=True, help="The project name." + ) + argparser.add_argument( + "-f", "--function", type=str, required=True, help="The function name." 
+ ) + argparser.add_argument( + "-l", + "--log", + type=str, + required=True, + help="The build log file containing the error to fix.", + ) + + args = argparser.parse_args() + if args.target_dir and os.listdir(args.target_dir): + assert os.path.isdir( + args.target_dir + ), f"--target-dir must take an existing directory: {args.target_dir}." + assert os.listdir( + args.target_dir + ), f"--target-dir must take a non-empty directory: {args.target_dir}." + + os.makedirs(args.intermediate_output_dir, exist_ok=True) + + return args def get_target_files(target_dir: str) -> list[str]: - """Returns the fuzz target files in the raw target directory.""" - return [ - os.path.join(target_dir, f) - for f in os.listdir(target_dir) - if benchmarklib.is_c_file(f) or benchmarklib.is_cpp_file(f) - ] - - -def collect_specific_fixes(project: str, - file_name: str) -> list[Callable[[str], str]]: - """Returns a list code fix functions given the language and |project|.""" - required_fixes = set() - if benchmarklib.is_cpp_file(file_name): - required_fixes = required_fixes.union([ - append_extern_c, - insert_cstdint, - insert_cstdlib, - ]) - - # TODO(Dongge): Remove this. - if benchmarklib.is_c_file(file_name): - required_fixes = required_fixes.union([ - insert_stdint, - include_builtin_library, - ]) - - # TODO(Dongge): Remove this. - if project == 'libpng-proto': - required_fixes = required_fixes.union([ - remove_nonexist_png_functions, - include_pngrio, - remove_const_from_png_symbols, - ]) - - return list(required_fixes) - - -def apply_specific_fixes(content: str, - required_fixes: list[Callable[[str], str]]) -> str: - """Fixes frequent errors in |raw_content| and returns fixed content.""" - for required_fix in required_fixes: - content = required_fix(content) - - return content + """Returns the fuzz target files in the raw target directory.""" + return [ + os.path.join(target_dir, f) + for f in os.listdir(target_dir) + if benchmarklib.is_c_file(f) or benchmarklib.is_cpp_file(f) + ] + + +def collect_specific_fixes(project: str, file_name: str) -> list[Callable[[str], str]]: + """Returns a list code fix functions given the language and |project|.""" + required_fixes = set() + if benchmarklib.is_cpp_file(file_name): + required_fixes = required_fixes.union( + [ + append_extern_c, + insert_cstdint, + insert_cstdlib, + ] + ) + + # TODO(Dongge): Remove this. + if benchmarklib.is_c_file(file_name): + required_fixes = required_fixes.union( + [ + insert_stdint, + include_builtin_library, + ] + ) + + # TODO(Dongge): Remove this. 
+ if project == "libpng-proto": + required_fixes = required_fixes.union( + [ + remove_nonexist_png_functions, + include_pngrio, + remove_const_from_png_symbols, + ] + ) + + return list(required_fixes) + + +def apply_specific_fixes( + content: str, required_fixes: list[Callable[[str], str]] +) -> str: + """Fixes frequent errors in |raw_content| and returns fixed content.""" + for required_fix in required_fixes: + content = required_fix(content) + + return content def fix_all_targets(target_dir: str, project: str): - """Reads raw content, applies fixes, and saves the fixed content.""" - for file in get_target_files(target_dir): - with open(file) as raw_file: - raw_content = raw_file.read() - specific_fixes = collect_specific_fixes(project, file) - fixed_content = apply_specific_fixes(raw_content, specific_fixes) - with open(os.path.join(target_dir, os.path.basename(file)), - 'w+') as fixed_file: - fixed_file.write(fixed_content) + """Reads raw content, applies fixes, and saves the fixed content.""" + for file in get_target_files(target_dir): + with open(file) as raw_file: + raw_content = raw_file.read() + specific_fixes = collect_specific_fixes(project, file) + fixed_content = apply_specific_fixes(raw_content, specific_fixes) + with open(os.path.join(target_dir, os.path.basename(file)), "w+") as fixed_file: + fixed_file.write(fixed_content) # ========================= Specific Fixes ========================= # def append_extern_c(raw_content: str) -> str: - """Appends `extern "C"` before fuzzer entry `LLVMFuzzerTestOneInput`.""" - pattern = r'int LLVMFuzzerTestOneInput' - replacement = f'extern "C" {pattern}' - fixed_content = re.sub(pattern, replacement, raw_content) - return fixed_content + """Appends `extern "C"` before fuzzer entry `LLVMFuzzerTestOneInput`.""" + pattern = r"int LLVMFuzzerTestOneInput" + replacement = f'extern "C" {pattern}' + fixed_content = re.sub(pattern, replacement, raw_content) + return fixed_content def insert_cstdlib(raw_content: str) -> str: - """Includes `cstdlib` library.""" - fixed_content = f'#include \n{raw_content}' - return fixed_content + """Includes `cstdlib` library.""" + fixed_content = f"#include \n{raw_content}" + return fixed_content def insert_cstdint(raw_content: str) -> str: - """Includes `cstdint` library.""" - fixed_content = f'#include \n{raw_content}' - return fixed_content + """Includes `cstdint` library.""" + fixed_content = f"#include \n{raw_content}" + return fixed_content def insert_stdint(content: str) -> str: - """Includes `stdint` library.""" - include_stdint = '#include \n' - if include_stdint not in content: - content = f'{include_stdint}{content}' - return content + """Includes `stdint` library.""" + include_stdint = "#include \n" + if include_stdint not in content: + content = f"{include_stdint}{content}" + return content def remove_nonexist_png_functions(content: str) -> str: - """Removes non-exist functions in libpng-proto.""" - non_exist_functions = [ - r'.*png_init_io.*', - r'.*png_set_write_fn.*', - r'.*png_set_compression_level.*', - r'.*.png_write_.*', - ] - for pattern in non_exist_functions: - content = re.sub(pattern, '', content) - return content + """Removes non-exist functions in libpng-proto.""" + non_exist_functions = [ + r".*png_init_io.*", + r".*png_set_write_fn.*", + r".*png_set_compression_level.*", + r".*.png_write_.*", + ] + for pattern in non_exist_functions: + content = re.sub(pattern, "", content) + return content def include_builtin_library(content: str) -> str: - """Includes builtin libraries when its 
function was invoked.""" - library_function_dict = { - '#include ': [ - 'malloc', - 'calloc', - 'free', - ], - '#include ': ['memcpy',] - } - for library, functions in library_function_dict.items(): - use_lib_functions = any(f in content for f in functions) - if use_lib_functions and not library in content: - content = f'{library}\n{content}' - return content + """Includes builtin libraries when its function was invoked.""" + library_function_dict = { + "#include ": [ + "malloc", + "calloc", + "free", + ], + "#include ": [ + "memcpy", + ], + } + for library, functions in library_function_dict.items(): + use_lib_functions = any(f in content for f in functions) + if use_lib_functions and not library in content: + content = f"{library}\n{content}" + return content def include_pngrio(content: str) -> str: - """Includes when using its functions.""" - functions = [ - 'png_read_data', - 'png_default_read_data', - ] - use_pngrio_funcitons = any(f in content for f in functions) - include_pngrio_stmt = '#include "pngrio.c"' + """Includes when using its functions.""" + functions = [ + "png_read_data", + "png_default_read_data", + ] + use_pngrio_funcitons = any(f in content for f in functions) + include_pngrio_stmt = '#include "pngrio.c"' - if use_pngrio_funcitons and not include_pngrio_stmt in content: - content = f'{include_pngrio}\n{content}' - return content + if use_pngrio_funcitons and not include_pngrio_stmt in content: + content = f"{include_pngrio}\n{content}" + return content def remove_const_from_png_symbols(content: str) -> str: - """Removes const from png types.""" - re.sub(r'png_const_', 'png_', content) - return content + """Removes const from png types.""" + re.sub(r"png_const_", "png_", content) + return content # ========================= LLM Fixes ========================= # -def extract_error_message(log_path: str, project_target_basename: str, - language: str) -> list[str]: - """Extracts error message and its context from the file in |log_path|.""" +def extract_error_message( + log_path: str, project_target_basename: str, language: str +) -> list[str]: + """Extracts error message and its context from the file in |log_path|.""" - with open(log_path) as log_file: - # A more accurate way to extract the error message. - log_lines = log_file.readlines() + with open(log_path) as log_file: + # A more accurate way to extract the error message. 
+ log_lines = log_file.readlines() - errors = extract_error_from_lines(log_lines, project_target_basename, - language) - if not errors: - logger.warning('Failed to parse error message from %s.', log_path) - return errors + errors = extract_error_from_lines(log_lines, project_target_basename, language) + if not errors: + logger.warning("Failed to parse error message from %s.", log_path) + return errors -def extract_error_from_lines(log_lines: list[str], project_target_basename: str, - language: str) -> list[str]: - """Extracts error message and its context from the file in |log_path|.""" - # Error message extraction for Java projects - if language == 'jvm': - started = False +def extract_error_from_lines( + log_lines: list[str], project_target_basename: str, language: str +) -> list[str]: + """Extracts error message and its context from the file in |log_path|.""" + # Error message extraction for Java projects + if language == "jvm": + started = False + errors = [] + for log_line in log_lines: + if started: + errors.append(log_line) + if log_line == "ERROR:__main__:Building fuzzers failed.": + break + else: + if ": error:" in log_line: + errors.append(log_line) + started = True + + return errors + + # Error message extraction for Rust projects + if language == "rust": + started = False + errors = [] + for log_line in log_lines: + if started: + errors.append(log_line) + if log_line == "error: could not compile": + break + else: + if log_line.startswith(("error[E", "warning:")): + errors.append(log_line) + started = True + + return errors + + target_name, _ = os.path.splitext(project_target_basename) + + error_lines_range: list[Optional[int]] = [None, None] + temp_range: list[Optional[int]] = [None, None] + + error_start_pattern = r"\S*" + target_name + r"(\.\S*)?:\d+:\d+: .+: .+\n?" + error_include_pattern = ( + r"In file included from \S*" + target_name + r"(\.\S*)?:\d+:\n?" + ) + error_end_pattern = r".*\d+ errors? generated.\n?" + + error_keywords = [ + "multiple definition of", + "undefined reference to", + ] errors = [] - for log_line in log_lines: - if started: - errors.append(log_line) - if log_line == 'ERROR:__main__:Building fuzzers failed.': - break - else: - if ': error:' in log_line: - errors.append(log_line) - started = True + unique_symbol = set() + for i, line in enumerate(log_lines): + # Add GNU ld errors in interest. + found_keyword = False + for keyword in error_keywords: + if keyword not in line: + continue + found_keyword = True + symbol = line.split(keyword)[-1] + if symbol not in unique_symbol: + unique_symbol.add(symbol) + errors.append(line.rstrip()) + break + if found_keyword: + continue + + # Add clang/clang++ diagnostics. + if temp_range[0] is None and ( + re.fullmatch(error_include_pattern, line) + or re.fullmatch(error_start_pattern, line) + ): + temp_range[0] = i + if temp_range[0] is not None and re.fullmatch(error_end_pattern, line): + temp_range[1] = i - 1 # Exclude current line. + # In case the original fuzz target was written in C and building with + # clang failed, and building with clang++ also failed, we take the + # error from clang++, which comes after. 
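# A small standalone illustration of the clang diagnostic range extraction used
# above: locate the first line that looks like "<target>:<line>:<col>: <sev>: ..."
# and the closing "N errors generated." line, then keep everything in between.
# The regexes mirror the ones defined above; the sample log lines are made up.
import re

target_name = "fuzz_target"  # hypothetical fuzz target basename
error_start_pattern = r"\S*" + target_name + r"(\.\S*)?:\d+:\d+: .+: .+\n?"
error_end_pattern = r".*\d+ errors? generated.\n?"

log_lines = [
    "Compiling fuzz target...\n",
    "/src/fuzz_target.cc:12:5: error: unknown type name 'png_infop'\n",
    "    png_infop info;\n",
    "1 error generated.\n",
]

start = end = None
for i, line in enumerate(log_lines):
    if start is None and re.fullmatch(error_start_pattern, line):
        start = i
    if start is not None and re.fullmatch(error_end_pattern, line):
        end = i - 1  # Exclude the "N errors generated." line itself.

print(log_lines[start : end + 1])  # The diagnostic lines in the middle.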
+ error_lines_range = temp_range + temp_range = [None, None] + + if error_lines_range[0] is not None and error_lines_range[1] is not None: + errors.extend( + line.rstrip() + for line in log_lines[error_lines_range[0] : error_lines_range[1] + 1] + ) + + return group_error_messages(errors) - return errors - # Error message extraction for Rust projects - if language == 'rust': - started = False - errors = [] - for log_line in log_lines: - if started: - errors.append(log_line) - if log_line == 'error: could not compile': - break - else: - if log_line.startswith(('error[E', 'warning:')): - errors.append(log_line) - started = True +def group_error_messages(error_lines: list[str]) -> list[str]: + """Groups multi-line error block into one string""" + state_unknown = "UNKNOWN" + state_include = "INCLUDE" + state_diag = "DIAG" + + diag_error_pattern = re.compile(r"(\S*):\d+:\d+: (.+): (.+)") + include_error_pattern = re.compile(r"In file included from (\S*):\d+:") + error_blocks = [] + curr_block = [] + src_file = "" + curr_state = state_unknown + for line in error_lines: + if not line: # Trim empty lines. + continue + + diag_match = diag_error_pattern.fullmatch(line) + include_match = include_error_pattern.fullmatch(line) + + if diag_match: + err_src = diag_match.group(1) + severity = diag_match.group(2) + + # Matched a note diag line under another diag, + # giving help info to fix the previous error. + if severity == "note": + curr_block.append(line) + continue + + # Matched a diag line but under an included file line, + # indicating the specific error in the included file, + if curr_state == state_include and err_src != src_file: + curr_block.append(line) + continue + + curr_state = state_diag + if curr_block: + error_blocks.append("\n".join(curr_block)) + curr_block = [] + + if include_match: + src_file = include_match.group(1) + curr_state = state_include + if curr_block: + error_blocks.append("\n".join(curr_block)) + curr_block = [] + + # Keep unknown error lines separated. + if curr_state == state_unknown and curr_block: + error_blocks.append("\n".join(curr_block)) + curr_block = [] - return errors + curr_block.append(line) - target_name, _ = os.path.splitext(project_target_basename) - - error_lines_range: list[Optional[int]] = [None, None] - temp_range: list[Optional[int]] = [None, None] - - error_start_pattern = r'\S*' + target_name + r'(\.\S*)?:\d+:\d+: .+: .+\n?' - error_include_pattern = (r'In file included from \S*' + target_name + - r'(\.\S*)?:\d+:\n?') - error_end_pattern = r'.*\d+ errors? generated.\n?' - - error_keywords = [ - 'multiple definition of', - 'undefined reference to', - ] - errors = [] - unique_symbol = set() - for i, line in enumerate(log_lines): - # Add GNU ld errors in interest. - found_keyword = False - for keyword in error_keywords: - if keyword not in line: - continue - found_keyword = True - symbol = line.split(keyword)[-1] - if symbol not in unique_symbol: - unique_symbol.add(symbol) - errors.append(line.rstrip()) - break - if found_keyword: - continue - - # Add clang/clang++ diagnostics. - if (temp_range[0] is None and (re.fullmatch(error_include_pattern, line) or - re.fullmatch(error_start_pattern, line))): - temp_range[0] = i - if temp_range[0] is not None and re.fullmatch(error_end_pattern, line): - temp_range[1] = i - 1 # Exclude current line. - # In case the original fuzz target was written in C and building with - # clang failed, and building with clang++ also failed, we take the - # error from clang++, which comes after. 
- error_lines_range = temp_range - temp_range = [None, None] - - if error_lines_range[0] is not None and error_lines_range[1] is not None: - errors.extend( - line.rstrip() - for line in log_lines[error_lines_range[0]:error_lines_range[1] + 1]) - - return group_error_messages(errors) + if curr_block: + error_blocks.append("\n".join(curr_block)) + return error_blocks + + +def llm_fix( + ai_binary: str, + target_path: str, + benchmark: benchmarklib.Benchmark, + llm_fix_id: int, + error_desc: Optional[str], + errors: list[str], + fixer_model_name: str, + language: str, +) -> None: + """Reads and fixes |target_path| in place with LLM based on |error_log|.""" + fuzz_target_source_code = parser.parse_code(target_path) + + _, target_ext = os.path.splitext(os.path.basename(target_path)) + response_dir = f"{os.path.splitext(target_path)[0]}-F{llm_fix_id}" + os.makedirs(response_dir, exist_ok=True) + prompt_path = os.path.join(response_dir, "prompt.txt") + + apply_llm_fix( + ai_binary, + benchmark, + fuzz_target_source_code, + error_desc, + errors, + prompt_path, + response_dir, + language, + fixer_model_name, + temperature=0.5 - llm_fix_id * 0.04, + ) + + fixed_code_candidates = [] + for file in os.listdir(response_dir): + if not parser.is_raw_output(file): + continue + fixed_code_path = os.path.join(response_dir, file) + fixed_code = parser.parse_code(fixed_code_path) + fixed_code_candidates.append([fixed_code_path, fixed_code]) + + if not fixed_code_candidates: + logger.info("LLM did not generate rawoutput for %s", prompt_path) + return + + # TODO(Dongge): Use the common vote: + # LLM gives multiple responses to one query. In many experiments, I + # found the code compartment of some of the responses are exactly the same. In + # these cases, we can use the most common code of all responses as it could be + # a safer choice. Currently, we prefer the longest code to encourage code + # complexity. + # TODO(Dongge): Exclude the candidate if it is identical to the original + # code. 
+ preferred_fix_path, preferred_fix_code = max( + fixed_code_candidates, key=lambda x: len(x[1]) + ) + logger.info("Will use the longest fix: %s", os.path.relpath(preferred_fix_path)) + preferred_fix_name, _ = os.path.splitext(preferred_fix_path) + fixed_target_path = os.path.join(response_dir, f"{preferred_fix_name}{target_ext}") + parser.save_output(preferred_fix_code, fixed_target_path) + parser.save_output(preferred_fix_code, target_path) + + +def apply_llm_fix( + ai_binary: str, + benchmark: benchmarklib.Benchmark, + fuzz_target_source_code: str, + error_desc: Optional[str], + errors: list[str], + prompt_path: str, + response_dir: str, + language: str, + fixer_model_name: str = models.DefaultModel.name, + temperature: float = 0.4, +): + """Queries LLM to fix the code.""" + fixer_model = models.LLM.setup( + ai_binary=ai_binary, + name=fixer_model_name, + num_samples=1, + temperature=temperature, + ) + + if language == "jvm": + builder = prompt_builder.JvmFixingBuilder( + fixer_model, benchmark, fuzz_target_source_code, errors + ) + prompt = builder.build([], None, None) + prompt.save(prompt_path) + else: + builder = prompt_builder.DefaultTemplateBuilder(fixer_model) + context = collect_context(benchmark, errors) + instruction = collect_instructions(benchmark, errors, fuzz_target_source_code) + prompt = builder.build_fixer_prompt( + benchmark, + fuzz_target_source_code, + error_desc, + errors, + context=context, + instruction=instruction, + ) + prompt.save(prompt_path) -def group_error_messages(error_lines: list[str]) -> list[str]: - """Groups multi-line error block into one string""" - state_unknown = 'UNKNOWN' - state_include = 'INCLUDE' - state_diag = 'DIAG' - - diag_error_pattern = re.compile(r'(\S*):\d+:\d+: (.+): (.+)') - include_error_pattern = re.compile(r'In file included from (\S*):\d+:') - error_blocks = [] - curr_block = [] - src_file = '' - curr_state = state_unknown - for line in error_lines: - if not line: # Trim empty lines. - continue - - diag_match = diag_error_pattern.fullmatch(line) - include_match = include_error_pattern.fullmatch(line) - - if diag_match: - err_src = diag_match.group(1) - severity = diag_match.group(2) - - # Matched a note diag line under another diag, - # giving help info to fix the previous error. - if severity == 'note': - curr_block.append(line) - continue + fixer_model.query_llm(prompt, response_dir) - # Matched a diag line but under an included file line, - # indicating the specific error in the included file, - if curr_state == state_include and err_src != src_file: - curr_block.append(line) - continue - - curr_state = state_diag - if curr_block: - error_blocks.append('\n'.join(curr_block)) - curr_block = [] - - if include_match: - src_file = include_match.group(1) - curr_state = state_include - if curr_block: - error_blocks.append('\n'.join(curr_block)) - curr_block = [] - - # Keep unknown error lines separated. 
- if curr_state == state_unknown and curr_block: - error_blocks.append('\n'.join(curr_block)) - curr_block = [] - - curr_block.append(line) - - if curr_block: - error_blocks.append('\n'.join(curr_block)) - return error_blocks - - -def llm_fix(ai_binary: str, target_path: str, benchmark: benchmarklib.Benchmark, - llm_fix_id: int, error_desc: Optional[str], errors: list[str], - fixer_model_name: str, language: str) -> None: - """Reads and fixes |target_path| in place with LLM based on |error_log|.""" - fuzz_target_source_code = parser.parse_code(target_path) - - _, target_ext = os.path.splitext(os.path.basename(target_path)) - response_dir = f'{os.path.splitext(target_path)[0]}-F{llm_fix_id}' - os.makedirs(response_dir, exist_ok=True) - prompt_path = os.path.join(response_dir, 'prompt.txt') - - apply_llm_fix(ai_binary, - benchmark, - fuzz_target_source_code, - error_desc, - errors, - prompt_path, - response_dir, - language, - fixer_model_name, - temperature=0.5 - llm_fix_id * 0.04) - - fixed_code_candidates = [] - for file in os.listdir(response_dir): - if not parser.is_raw_output(file): - continue - fixed_code_path = os.path.join(response_dir, file) - fixed_code = parser.parse_code(fixed_code_path) - fixed_code_candidates.append([fixed_code_path, fixed_code]) - - if not fixed_code_candidates: - logger.info('LLM did not generate rawoutput for %s', prompt_path) - return - - # TODO(Dongge): Use the common vote: - # LLM gives multiple responses to one query. In many experiments, I - # found the code compartment of some of the responses are exactly the same. In - # these cases, we can use the most common code of all responses as it could be - # a safer choice. Currently, we prefer the longest code to encourage code - # complexity. - # TODO(Dongge): Exclude the candidate if it is identical to the original - # code. 
- preferred_fix_path, preferred_fix_code = max(fixed_code_candidates, - key=lambda x: len(x[1])) - logger.info('Will use the longest fix: %s', - os.path.relpath(preferred_fix_path)) - preferred_fix_name, _ = os.path.splitext(preferred_fix_path) - fixed_target_path = os.path.join(response_dir, - f'{preferred_fix_name}{target_ext}') - parser.save_output(preferred_fix_code, fixed_target_path) - parser.save_output(preferred_fix_code, target_path) - - -def apply_llm_fix(ai_binary: str, - benchmark: benchmarklib.Benchmark, - fuzz_target_source_code: str, - error_desc: Optional[str], - errors: list[str], - prompt_path: str, - response_dir: str, - language: str, - fixer_model_name: str = models.DefaultModel.name, - temperature: float = 0.4): - """Queries LLM to fix the code.""" - fixer_model = models.LLM.setup( - ai_binary=ai_binary, - name=fixer_model_name, - num_samples=1, - temperature=temperature, - ) - - if language == 'jvm': - builder = prompt_builder.JvmFixingBuilder(fixer_model, benchmark, - fuzz_target_source_code, errors) - prompt = builder.build([], None, None) - prompt.save(prompt_path) - else: - builder = prompt_builder.DefaultTemplateBuilder(fixer_model) - - context = collect_context(benchmark, errors) - instruction = collect_instructions(benchmark, errors, - fuzz_target_source_code) - prompt = builder.build_fixer_prompt(benchmark, - fuzz_target_source_code, - error_desc, - errors, - context=context, - instruction=instruction) - prompt.save(prompt_path) - - fixer_model.query_llm(prompt, response_dir) - - -def collect_context(benchmark: benchmarklib.Benchmark, - errors: list[str]) -> str: - """Collects the useful context to fix the errors.""" - if not errors: - return '' - - context = '' - for error in errors: - context += _collect_context_no_member(benchmark, error) - - return context - - -def _collect_context_no_member(benchmark: benchmarklib.Benchmark, - error: str) -> str: - """Collects the useful context to fix 'no member in' errors.""" - matched = re.search(NO_MEMBER_ERROR_REGEX, error) - if not matched: - return '' - target_type = matched.group(1) - ci = context_introspector.ContextRetriever(benchmark) - return ci.get_type_def(target_type) - - -def collect_instructions(benchmark: benchmarklib.Benchmark, errors: list[str], - fuzz_target_source_code: str) -> str: - """Collects the useful instructions to fix the errors.""" - if not errors: - return '' - - instruction = '' - for error in errors: - instruction += _collect_instruction_file_not_found(benchmark, error, - fuzz_target_source_code) - instruction += _collect_instruction_undefined_reference( - benchmark, error, fuzz_target_source_code) - instruction += _collect_instruction_fdp_in_c_target(benchmark, errors, - fuzz_target_source_code) - instruction += _collect_instruction_no_goto(fuzz_target_source_code) - instruction += _collect_instruction_builtin_libs_first(benchmark, errors) - instruction += _collect_instruction_extern(benchmark) - instruction += _collect_consume_buffers(fuzz_target_source_code) - - return instruction + +def collect_context(benchmark: benchmarklib.Benchmark, errors: list[str]) -> str: + """Collects the useful context to fix the errors.""" + if not errors: + return "" + + context = "" + for error in errors: + context += _collect_context_no_member(benchmark, error) + + return context + + +def _collect_context_no_member(benchmark: benchmarklib.Benchmark, error: str) -> str: + """Collects the useful context to fix 'no member in' errors.""" + matched = re.search(NO_MEMBER_ERROR_REGEX, error) + if not matched: 
+ return "" + target_type = matched.group(1) + ci = context_introspector.ContextRetriever(benchmark) + return ci.get_type_def(target_type) + + +def collect_instructions( + benchmark: benchmarklib.Benchmark, errors: list[str], fuzz_target_source_code: str +) -> str: + """Collects the useful instructions to fix the errors.""" + if not errors: + return "" + + instruction = "" + for error in errors: + instruction += _collect_instruction_file_not_found( + benchmark, error, fuzz_target_source_code + ) + instruction += _collect_instruction_undefined_reference( + benchmark, error, fuzz_target_source_code + ) + instruction += _collect_instruction_fdp_in_c_target( + benchmark, errors, fuzz_target_source_code + ) + instruction += _collect_instruction_no_goto(fuzz_target_source_code) + instruction += _collect_instruction_builtin_libs_first(benchmark, errors) + instruction += _collect_instruction_extern(benchmark) + instruction += _collect_consume_buffers(fuzz_target_source_code) + + return instruction def _collect_instruction_undefined_reference( - benchmark: benchmarklib.Benchmark, error: str, - fuzz_target_source_code: str) -> str: - """Collects the instructions to fix the 'undefined reference' errors.""" - matched_funcs = re.findall(UNDEFINED_REF_ERROR_REGEX, error) - if not matched_funcs: - return '' - instruction = '' - for undefined_func in matched_funcs: - if undefined_func == 'LLVMFuzzerTestOneInput': - continue + benchmark: benchmarklib.Benchmark, error: str, fuzz_target_source_code: str +) -> str: + """Collects the instructions to fix the 'undefined reference' errors.""" + matched_funcs = re.findall(UNDEFINED_REF_ERROR_REGEX, error) + if not matched_funcs: + return "" + instruction = "" + for undefined_func in matched_funcs: + if undefined_func == "LLVMFuzzerTestOneInput": + continue + ci = context_introspector.ContextRetriever(benchmark) + header_file = ci.get_prefixed_header_file_by_name(undefined_func) + if header_file and header_file not in fuzz_target_source_code: + instruction += ( + "You must add the following #include statement to fix the error of " + f"undefined reference to {undefined_func}:\n\n" + f"{header_file}\n.\n" + ) + elif not header_file and benchmark.is_c_project: + instruction += ( + f"You must remove the function {undefined_func} from the" + " generated fuzz target, because the function does not exist.\n" + ) + elif not header_file or header_file in fuzz_target_source_code: + # C project: NO header file found, or + # C++: Cannot map demangled C++ function name to signature + source_file = ci.get_prefixed_source_file(undefined_func) + if not source_file and benchmark.function_name in undefined_func: + source_file = ci.get_prefixed_source_file() + if source_file: + if header_file: + # To avoid redefinition. + instruction += ( + "You must remove the following statement\n\n" + f"{header_file}\n" + ) + instruction += ( + "You must add the following #include statement to fix the error of " + f"undefined reference to `{undefined_func}':\n" + f"\n{source_file}\n.\n" + ) + else: + instruction += ( + f"To fix undefined reference to `{undefined_func}'," + "check the library documentation (e.g. README.md, comments) for " + "special instructions, such as required macros or specific inclusion " + "methods. Ensure any necessary definitions or inclusions are " + "correctly implemented in your generated fuzz target, following the " + "library's guidance." 
+ ) + if not instruction: + instruction += ( + f"To fix undefined reference to `{undefined_func}'," + "check the library documentation (e.g. README.md, comments) for " + "special instructions, such as required macros or specific inclusion " + "methods. Ensure any necessary definitions or inclusions are " + "correctly implemented in your generated fuzz target, following the " + "library's guidance." + ) + return instruction + + +def _collect_instruction_file_not_found( + benchmark: benchmarklib.Benchmark, error: str, fuzz_target_source_code: str +) -> str: + """Collects the useful instruction to fix 'file not found' errors.""" + matched = re.search(FILE_NOT_FOUND_ERROR_REGEX, error) + if not matched: + return "" + + # Step 1: Say the file does not exist, do not include it. + wrong_file = matched.group(1) + instruction = ( + f"IMPORTANT: DO NOT include the header file {wrong_file} in the generated" + " fuzz target again, the file does not exist in the project-under-test.\n" + ) + # Step 2: Suggest the header file of the same name as the wrong one. ci = context_introspector.ContextRetriever(benchmark) - header_file = ci.get_prefixed_header_file_by_name(undefined_func) - if header_file and header_file not in fuzz_target_source_code: - instruction += ( - 'You must add the following #include statement to fix the error of ' - f'undefined reference to {undefined_func}:\n\n' - f'{header_file}\n.\n') - elif not header_file and benchmark.is_c_project: - instruction += ( - f'You must remove the function {undefined_func} from the' - ' generated fuzz target, because the function does not exist.\n') - elif not header_file or header_file in fuzz_target_source_code: - # C project: NO header file found, or - # C++: Cannot map demangled C++ function name to signature - source_file = ci.get_prefixed_source_file(undefined_func) - if not source_file and benchmark.function_name in undefined_func: - source_file = ci.get_prefixed_source_file() - if source_file: - if header_file: - # To avoid redefinition. - instruction += ('You must remove the following statement\n\n' - f'{header_file}\n') + same_name_headers = ci.get_same_header_file_paths(wrong_file) + if same_name_headers: + statements = "\n".join([f'#include "{header}"' for header in same_name_headers]) instruction += ( - 'You must add the following #include statement to fix the error of ' - f"undefined reference to `{undefined_func}':\n" - f'\n{source_file}\n.\n') - else: - instruction += ( - f"To fix undefined reference to `{undefined_func}'," - 'check the library documentation (e.g. README.md, comments) for ' - 'special instructions, such as required macros or specific inclusion ' - 'methods. Ensure any necessary definitions or inclusions are ' - 'correctly implemented in your generated fuzz target, following the ' - "library's guidance.") - if not instruction: - instruction += ( - f"To fix undefined reference to `{undefined_func}'," - 'check the library documentation (e.g. README.md, comments) for ' - 'special instructions, such as required macros or specific inclusion ' - 'methods. 
Ensure any necessary definitions or inclusions are ' - 'correctly implemented in your generated fuzz target, following the ' - "library's guidance.") - return instruction - - -def _collect_instruction_file_not_found(benchmark: benchmarklib.Benchmark, - error: str, - fuzz_target_source_code: str) -> str: - """Collects the useful instruction to fix 'file not found' errors.""" - matched = re.search(FILE_NOT_FOUND_ERROR_REGEX, error) - if not matched: - return '' - - # Step 1: Say the file does not exist, do not include it. - wrong_file = matched.group(1) - instruction = ( - f'IMPORTANT: DO NOT include the header file {wrong_file} in the generated' - ' fuzz target again, the file does not exist in the project-under-test.\n' - ) - # Step 2: Suggest the header file of the same name as the wrong one. - ci = context_introspector.ContextRetriever(benchmark) - same_name_headers = ci.get_same_header_file_paths(wrong_file) - if same_name_headers: - statements = '\n'.join( - [f'#include "{header}"' for header in same_name_headers]) - instruction += ( - f'Replace the non-existent {wrong_file} with the ' - 'following statement, which share the same file name but exists under ' - 'the correct path in the project-under-test:\n' - f'\n{statements}\n\n') - return instruction + f"Replace the non-existent {wrong_file} with the " + "following statement, which share the same file name but exists under " + "the correct path in the project-under-test:\n" + f"\n{statements}\n\n" + ) + return instruction + + # Step 3: Suggest the header/source file of the function under test. + function_file = ci.get_prefixed_header_file() + if function_file and f'#include "{function_file}"' in fuzz_target_source_code: + function_file_base_name = os.path.basename(function_file) - # Step 3: Suggest the header/source file of the function under test. - function_file = ci.get_prefixed_header_file() - if function_file and f'#include "{function_file}"' in fuzz_target_source_code: - function_file_base_name = os.path.basename(function_file) - - instruction += ( - 'In the generated code, ensure that the path prefix of ' - f'{function_file_base_name} is consistent with other include ' - f'statements related to the project ({benchmark.project}). For example,' - 'if another include statement is ' - f'#include <{benchmark.project}/header.h>, you must modify' - f' the path prefix in #include "{function_file}" to match ' - 'it, resulting in ' - f'#include <{benchmark.project}/{function_file_base_name}>.') + instruction += ( + "In the generated code, ensure that the path prefix of " + f"{function_file_base_name} is consistent with other include " + f"statements related to the project ({benchmark.project}). For example," + "if another include statement is " + f"#include <{benchmark.project}/header.h>, you must modify" + f' the path prefix in #include "{function_file}" to match ' + "it, resulting in " + f"#include <{benchmark.project}/{function_file_base_name}>." + ) + return instruction + + if function_file: + instruction += ( + f"If the non-existent {wrong_file} was included " + f"for the declaration of {benchmark.function_signature}, " + "you must replace it with the EXACT path of the actual file " + f"{function_file}. For example:\n" + f'\n#include "{function_file}"\n\n' + ) + + # Step 4: Suggest similar alternatives. 
+ similar_headers = ci.get_similar_header_file_paths(wrong_file) + if similar_headers: + statements = "\n".join([f'#include "{header}"' for header in similar_headers]) + instruction += ( + "Otherwise, consider replacing it with some of the following statements" + f"that may be correct alternatives:\n\n{statements}\n\n" + ) return instruction - if function_file: - instruction += ( - f'If the non-existent {wrong_file} was included ' - f'for the declaration of {benchmark.function_signature}, ' - 'you must replace it with the EXACT path of the actual file ' - f'{function_file}. For example:\n' - f'\n#include "{function_file}"\n\n') - - # Step 4: Suggest similar alternatives. - similar_headers = ci.get_similar_header_file_paths(wrong_file) - if similar_headers: - statements = '\n'.join( - [f'#include "{header}"' for header in similar_headers]) - instruction += ( - 'Otherwise, consider replacing it with some of the following statements' - f'that may be correct alternatives:\n\n{statements}\n\n') - return instruction - - -def _collect_instruction_fdp_in_c_target(benchmark: benchmarklib.Benchmark, - errors: list[str], - fuzz_target_source_code: str) -> str: - """Collects instructions to ask LLM do not use FuzzedDataProvier in C targets - """ - has_error_from_fdp = any(FALSE_EXTERN_KEYWORD_ERROR in error or - FALSE_FUZZED_DATA_PROVIDER_ERROR in error - for error in errors) - include_fdp = FDP_INCLUDE_STATEMENT in fuzz_target_source_code - is_c = benchmark.file_type == benchmarklib.FileType.C - if (has_error_from_fdp or include_fdp) and is_c: - return ( - 'Please modify the generated C fuzz target to remove' - 'FuzzedDataProvider and replace all its functionalities ' - 'with equivalent C code, because it will cause build failure in C fuzz ' - 'targets.\nAlso, ensure the whole fuzz target must be compatible with ' - 'plain C and does not include any C++ specific code or dependencies.\n') - - return '' + +def _collect_instruction_fdp_in_c_target( + benchmark: benchmarklib.Benchmark, errors: list[str], fuzz_target_source_code: str +) -> str: + """Collects instructions to ask LLM do not use FuzzedDataProvier in C targets""" + has_error_from_fdp = any( + FALSE_EXTERN_KEYWORD_ERROR in error or FALSE_FUZZED_DATA_PROVIDER_ERROR in error + for error in errors + ) + include_fdp = FDP_INCLUDE_STATEMENT in fuzz_target_source_code + is_c = benchmark.file_type == benchmarklib.FileType.C + if (has_error_from_fdp or include_fdp) and is_c: + return ( + "Please modify the generated C fuzz target to remove" + "FuzzedDataProvider and replace all its functionalities " + "with equivalent C code, because it will cause build failure in C fuzz " + "targets.\nAlso, ensure the whole fuzz target must be compatible with " + "plain C and does not include any C++ specific code or dependencies.\n" + ) + + return "" def _collect_instruction_no_goto(fuzz_target_source_code: str) -> str: - """Collects the instruction to avoid using goto.""" - if 'goto' in fuzz_target_source_code: - return ( - 'EXTREMELY IMPORTANT: AVOID USING goto. If you have to ' - 'write code using goto, you MUST MUST also declare all ' - 'variables BEFORE the goto. Never introduce new variables ' - 'after the goto.') - return '' - - -def _collect_instruction_builtin_libs_first(benchmark: benchmarklib.Benchmark, - errors: list[str]) -> str: - """Collects the instructions to include builtin libraries first to fix - unknown type name error.""" - # Refine this, e.g., check if the symbol is builtin or from a project file. 
- if any(UNKNOWN_TYPE_ERROR in error for error in errors): - return ( - 'IMPORTANT: ALWAYS INCLUDE STANDARD LIBRARIES BEFORE PROJECT-SPECIFIC ' - f'({benchmark.project}) LIBRARIES. This order prevents errors like ' - '"unknown type name" for basic types. Additionally, include ' - 'project-specific libraries that contain declarations before those that' - 'use these declared symbols.') - return '' + """Collects the instruction to avoid using goto.""" + if "goto" in fuzz_target_source_code: + return ( + "EXTREMELY IMPORTANT: AVOID USING goto. If you have to " + "write code using goto, you MUST MUST also declare all " + "variables BEFORE the goto. Never introduce new variables " + "after the goto." + ) + return "" + + +def _collect_instruction_builtin_libs_first( + benchmark: benchmarklib.Benchmark, errors: list[str] +) -> str: + """Collects the instructions to include builtin libraries first to fix + unknown type name error.""" + # Refine this, e.g., check if the symbol is builtin or from a project file. + if any(UNKNOWN_TYPE_ERROR in error for error in errors): + return ( + "IMPORTANT: ALWAYS INCLUDE STANDARD LIBRARIES BEFORE PROJECT-SPECIFIC " + f"({benchmark.project}) LIBRARIES. This order prevents errors like " + '"unknown type name" for basic types. Additionally, include ' + "project-specific libraries that contain declarations before those that" + "use these declared symbols." + ) + return "" def _collect_instruction_extern(benchmark: benchmarklib.Benchmark) -> str: - """Collects the instructions to use extern "C" in C++ fuzz targets.""" - if not benchmark.needs_extern: - return '' - instruction = ( - f'IMPORTANT: The fuzz target ({benchmark.target_path}) is written in C++,' - ' whereas the project-under-test ({PROJECT_NAME}) is written in C. All ' - f'headers, functions, and code from the {benchmark.project} project must ' - 'be consistently wrapped in extern "C" to ensure error-free ' - 'compilation and linkage between C and C++:\n\nextern "C" {\n //' - 'Include necessary C headers, source files, functions, and code here.\n}' - '\n\n') - return instruction + """Collects the instructions to use extern "C" in C++ fuzz targets.""" + if not benchmark.needs_extern: + return "" + instruction = ( + f"IMPORTANT: The fuzz target ({benchmark.target_path}) is written in C++," + " whereas the project-under-test ({PROJECT_NAME}) is written in C. All " + f"headers, functions, and code from the {benchmark.project} project must " + 'be consistently wrapped in extern "C" to ensure error-free ' + 'compilation and linkage between C and C++:\n\nextern "C" {\n //' + "Include necessary C headers, source files, functions, and code here.\n}" + "\n\n" + ) + return instruction def _collect_consume_buffers(fuzz_target_source_code: str) -> str: - """Provides advice on the use of ConsumeBytes and ConsumeData""" + """Provides advice on the use of ConsumeBytes and ConsumeData""" + + instruction = "" + + for buffer_method in ["ConsumeBytes", "ConsumeData"]: + if buffer_method in fuzz_target_source_code: + instruction += ( + "IMPORTANT: the harness source code contains a call to `" + f"{buffer_method}`. Whenever this function is used, you MUST validate" + " the size of the vector returned, and make sure that the size of the" + f" vector is equal to argument given to `{buffer_method}`. If it is " + "not equal, the harness should not proceed.\n" + ) + instruction += ( + f"Furthermore, consider changing {buffer_method} to " + "`ConsumeRandomLengthString` for creating `char` buffers or strings. 
" + "In most cases, `ConsumeRandomLengthString` is preferred, and " + f"should be used instead of {buffer_method}\n" + ) - instruction = '' - - for buffer_method in ['ConsumeBytes', 'ConsumeData']: - if buffer_method in fuzz_target_source_code: - instruction += ( - 'IMPORTANT: the harness source code contains a call to `' - f'{buffer_method}`. Whenever this function is used, you MUST validate' - ' the size of the vector returned, and make sure that the size of the' - f' vector is equal to argument given to `{buffer_method}`. If it is ' - 'not equal, the harness should not proceed.\n') - instruction += ( - f'Furthermore, consider changing {buffer_method} to ' - '`ConsumeRandomLengthString` for creating `char` buffers or strings. ' - 'In most cases, `ConsumeRandomLengthString` is preferred, and ' - f'should be used instead of {buffer_method}\n') - - return instruction + return instruction def main(): - args = parse_args() - fix_all_targets(args.target_dir, args.project) + args = parse_args() + fix_all_targets(args.target_dir, args.project) -if __name__ == '__main__': - sys.exit(main()) +if __name__ == "__main__": + sys.exit(main()) diff --git a/llm_toolkit/corpus_generator.py b/llm_toolkit/corpus_generator.py index ebf52f24d4..dc0a5230e7 100644 --- a/llm_toolkit/corpus_generator.py +++ b/llm_toolkit/corpus_generator.py @@ -30,61 +30,68 @@ def get_script( target_harness_path: str, benchmark: Benchmark, ) -> str: - """Uses LLMs to generate a python script that will create a seed corpus for a - harness. + """Uses LLMs to generate a python script that will create a seed corpus for a + harness. - The script generated is purely generated and should be considered untrusted - in the general sense. OSS-Fuzz-gen already executes arbitrary code since - OSS-Fuzz-gen executes arbitrary open source projects with no checking on - what code is committed to the given projects.""" - corpus_model = models.LLM.setup( - ai_binary=ai_binary, - name=fixer_model_name, - ) + The script generated is purely generated and should be considered untrusted + in the general sense. 
OSS-Fuzz-gen already executes arbitrary code since + OSS-Fuzz-gen executes arbitrary open source projects with no checking on + what code is committed to the given projects.""" + corpus_model = models.LLM.setup( + ai_binary=ai_binary, + name=fixer_model_name, + ) - # Get the corpus generation template - with open( - os.path.join(prompt_builder.DEFAULT_TEMPLATE_DIR, - 'corpus_generation_via_python_script.txt'), 'r') as f: - prompt_to_query = f.read() - with open(target_harness_path) as target_harness_file: - target_harness_code = target_harness_file.read() + # Get the corpus generation template + with open( + os.path.join( + prompt_builder.DEFAULT_TEMPLATE_DIR, + "corpus_generation_via_python_script.txt", + ), + "r", + ) as f: + prompt_to_query = f.read() + with open(target_harness_path) as target_harness_file: + target_harness_code = target_harness_file.read() - prompt_to_query = prompt_to_query.replace('{HARNESS_SOURCE_CODE}', - target_harness_code) + prompt_to_query = prompt_to_query.replace( + "{HARNESS_SOURCE_CODE}", target_harness_code + ) - project_repository = oss_fuzz_checkout.get_project_repository( - benchmark.project) - target_source_code = introspector.query_introspector_function_source( - benchmark.project, benchmark.function_signature) + project_repository = oss_fuzz_checkout.get_project_repository(benchmark.project) + target_source_code = introspector.query_introspector_function_source( + benchmark.project, benchmark.function_signature + ) - prompt_to_query = prompt_to_query.replace('{PROJECT_NAME}', benchmark.project) - prompt_to_query = prompt_to_query.replace('{PROJECT_REPOSITORY}', - project_repository) - prompt_to_query = prompt_to_query.replace('{TARGET_FUNCTION_SOURCE}', - target_source_code) + prompt_to_query = prompt_to_query.replace("{PROJECT_NAME}", benchmark.project) + prompt_to_query = prompt_to_query.replace( + "{PROJECT_REPOSITORY}", project_repository + ) + prompt_to_query = prompt_to_query.replace( + "{TARGET_FUNCTION_SOURCE}", target_source_code + ) - prompt = corpus_model.prompt_type()() - prompt.add_priming(prompt_to_query) + prompt = corpus_model.prompt_type()() + prompt.add_priming(prompt_to_query) - response_dir = f'{os.path.splitext(target_harness_path)[0]}-corpus' - os.makedirs(response_dir, exist_ok=True) - prompt_path = os.path.join(response_dir, 'prompt.txt') - prompt.save(prompt_path) + response_dir = f"{os.path.splitext(target_harness_path)[0]}-corpus" + os.makedirs(response_dir, exist_ok=True) + prompt_path = os.path.join(response_dir, "prompt.txt") + prompt.save(prompt_path) - corpus_model.query_llm(prompt, response_dir) - for file in os.listdir(response_dir): - if not parser.is_raw_output(file): - continue - corpus_generator_path = os.path.join(response_dir, file) - with open(corpus_generator_path, 'r') as f: - corpus_generator_source = f.read() + corpus_model.query_llm(prompt, response_dir) + for file in os.listdir(response_dir): + if not parser.is_raw_output(file): + continue + corpus_generator_path = os.path.join(response_dir, file) + with open(corpus_generator_path, "r") as f: + corpus_generator_source = f.read() - corpus_generator_source = corpus_generator_source.replace('', '') - corpus_generator_source = corpus_generator_source.replace('', '') - corpus_generator_source = corpus_generator_source.replace('```python', '') - corpus_generator_source = corpus_generator_source.replace('```', '') - return corpus_generator_source + corpus_generator_source = corpus_generator_source.replace("", "") + corpus_generator_source = 
corpus_generator_source.replace("", "") + corpus_generator_source = corpus_generator_source.replace("```python", "") + corpus_generator_source = corpus_generator_source.replace("```", "") + return corpus_generator_source - # Return an empty Python program if generation failed. - return 'import os' + # Return an empty Python program if generation failed. + return "import os" diff --git a/llm_toolkit/crash_triager.py b/llm_toolkit/crash_triager.py index 6b4c9511ef..6c427559e5 100644 --- a/llm_toolkit/crash_triager.py +++ b/llm_toolkit/crash_triager.py @@ -24,10 +24,11 @@ class TriageResult: - """Crash triage results.""" - NOT_APPLICABLE = '-' - DRIVER = 'DRIVER' - PROJECT = 'PROJECT' + """Crash triage results.""" + + NOT_APPLICABLE = "-" + DRIVER = "DRIVER" + PROJECT = "PROJECT" # ========================= LLM Triage ========================= # @@ -39,48 +40,51 @@ def llm_triage( crash_func: dict, triage_model_name: str, ) -> str: - """Triages crash with LLM based on crash information and relevant code.""" - with open(driver_path) as target_file: - driver_code = target_file.read() - - response_dir = f'{os.path.splitext(driver_path)[0]}-triage' - os.makedirs(response_dir, exist_ok=True) - prompt_path = os.path.join(response_dir, 'prompt.txt') - - apply_llm_triage(ai_binary, - benchmark, - driver_code, - crash_info, - crash_func, - prompt_path, - response_dir, - triage_model_name, - temperature=0.5) - - triage_candidates = [] - triage_result = TriageResult.NOT_APPLICABLE - for file in os.listdir(response_dir): - if not parser.is_raw_output(file): - continue - triage_path = os.path.join(response_dir, file) - triage_result, triage = parser.parse_triage(triage_path) - triage_candidates.append([triage_path, triage]) - - if not triage_candidates: - logging.warning('LLM did not generate rawoutput for %s', prompt_path) - return TriageResult.NOT_APPLICABLE - - # TODO(maoyixie): Use the common vote - # Currently, we prefer the longest triage. - preferred_triage_path, preferred_triage = max(triage_candidates, - key=lambda x: len(x[1])) - logging.info('Will use the longest triage: %s', - os.path.relpath(preferred_triage_path)) - preferred_triage_name, _ = os.path.splitext(preferred_triage_path) - triage_report_path = os.path.join(response_dir, - f'{preferred_triage_name}.txt') - parser.save_output(preferred_triage, triage_report_path) - return triage_result + """Triages crash with LLM based on crash information and relevant code.""" + with open(driver_path) as target_file: + driver_code = target_file.read() + + response_dir = f"{os.path.splitext(driver_path)[0]}-triage" + os.makedirs(response_dir, exist_ok=True) + prompt_path = os.path.join(response_dir, "prompt.txt") + + apply_llm_triage( + ai_binary, + benchmark, + driver_code, + crash_info, + crash_func, + prompt_path, + response_dir, + triage_model_name, + temperature=0.5, + ) + + triage_candidates = [] + triage_result = TriageResult.NOT_APPLICABLE + for file in os.listdir(response_dir): + if not parser.is_raw_output(file): + continue + triage_path = os.path.join(response_dir, file) + triage_result, triage = parser.parse_triage(triage_path) + triage_candidates.append([triage_path, triage]) + + if not triage_candidates: + logging.warning("LLM did not generate rawoutput for %s", prompt_path) + return TriageResult.NOT_APPLICABLE + + # TODO(maoyixie): Use the common vote + # Currently, we prefer the longest triage. 
+ preferred_triage_path, preferred_triage = max( + triage_candidates, key=lambda x: len(x[1]) + ) + logging.info( + "Will use the longest triage: %s", os.path.relpath(preferred_triage_path) + ) + preferred_triage_name, _ = os.path.splitext(preferred_triage_path) + triage_report_path = os.path.join(response_dir, f"{preferred_triage_name}.txt") + parser.save_output(preferred_triage, triage_report_path) + return triage_result def apply_llm_triage( @@ -94,17 +98,18 @@ def apply_llm_triage( triage_model_name: str = models.DefaultModel.name, temperature: float = 0.4, ): - """Queries LLM to triage the crash.""" - triage_model = models.LLM.setup( - ai_binary=ai_binary, - name=triage_model_name, - num_samples=1, - temperature=temperature, - ) - - builder = prompt_builder.DefaultTemplateBuilder(triage_model) - prompt = builder.build_triager_prompt(benchmark, driver_code, crash_info, - crash_func) - prompt.save(prompt_path) - - triage_model.query_llm(prompt, response_dir) + """Queries LLM to triage the crash.""" + triage_model = models.LLM.setup( + ai_binary=ai_binary, + name=triage_model_name, + num_samples=1, + temperature=temperature, + ) + + builder = prompt_builder.DefaultTemplateBuilder(triage_model) + prompt = builder.build_triager_prompt( + benchmark, driver_code, crash_info, crash_func + ) + prompt.save(prompt_path) + + triage_model.query_llm(prompt, response_dir) diff --git a/llm_toolkit/models.py b/llm_toolkit/models.py index f272b3aef0..0cf6a545f0 100644 --- a/llm_toolkit/models.py +++ b/llm_toolkit/models.py @@ -32,9 +32,14 @@ import openai import tiktoken import vertexai -from google.api_core.exceptions import (GoogleAPICallError, InternalServerError, - InvalidArgument, ResourceExhausted, - ServiceUnavailable, TooManyRequests) +from google.api_core.exceptions import ( + GoogleAPICallError, + InternalServerError, + InvalidArgument, + ResourceExhausted, + ServiceUnavailable, + TooManyRequests, +) from vertexai import generative_models from vertexai.preview.generative_models import ChatSession, GenerativeModel from vertexai.preview.language_models import CodeGenerationModel @@ -51,1021 +56,1107 @@ class LLM: - """Base LLM.""" - - # Should be set by the subclass. - name: str - # TODO(mihaimaruseac): Should this be MAX_TOKENS or a different global? - context_window: int = 2000 # Default token size. - - MAX_INPUT_TOKEN: int = sys.maxsize - - _max_attempts = 5 # Maximum number of attempts to get prediction response - - def __init__( - self, - ai_binary: str, - max_tokens: int = MAX_TOKENS, - num_samples: int = NUM_SAMPLES, - temperature: float = TEMPERATURE, - temperature_list: Optional[list[float]] = None, - ): - self.ai_binary = ai_binary - - # Model parameters. - self.max_tokens = max_tokens - self.num_samples = num_samples - self.temperature = temperature - self.temperature_list = temperature_list - - # Preserve chat history for OpenAI - self.messages = [] - - def cloud_setup(self): - """Runs Cloud specific-setup.""" - # Only a subset of models need a cloud specific set up, so - # we can pass for the remainder of the models as they don't - # need to implement specific handling of this. 
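# A condensed sketch of the model-selection pattern used by this class's setup()
# below: walk all (transitive) subclasses and pick the one whose `name` matches.
# The subclass here is an invented stand-in, not a real model binding.
class Model:
    name = ""

    @classmethod
    def all_subclasses(cls):
        yield cls
        for sub in cls.__subclasses__():
            yield from sub.all_subclasses()

    @classmethod
    def setup(cls, name: str):
        for sub in cls.all_subclasses():
            if getattr(sub, "name", None) == name:
                return sub()
        raise ValueError(f"Bad model type {name}")


class FakeGPT(Model):  # hypothetical model subclass
    name = "fake-gpt"


assert isinstance(Model.setup("fake-gpt"), FakeGPT)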
- - @classmethod - def setup( - cls, - ai_binary: str, - name: str, - max_tokens: int = MAX_TOKENS, - num_samples: int = NUM_SAMPLES, - temperature: float = TEMPERATURE, - temperature_list: Optional[list[float]] = None, - ): - """Prepares the LLM for fuzz target generation.""" - if ai_binary: - return AIBinaryModel(name, ai_binary, max_tokens, num_samples, - temperature) - - for subcls in cls.all_llm_subclasses(): - if getattr(subcls, 'name', None) == name: - return subcls( - ai_binary, - max_tokens, - num_samples, - temperature, - temperature_list, - ) - - raise ValueError(f'Bad model type {name}') - - @classmethod - def all_llm_subclasses(cls): - """All subclasses.""" - yield cls - for subcls in cls.__subclasses__(): - yield from subcls.all_llm_subclasses() - - @classmethod - def all_llm_names(cls): - """Returns the current model name and all child model names.""" - names = [] - for subcls in cls.all_llm_subclasses(): - if hasattr(subcls, 'name') and subcls.name != AIBinaryModel.name: - names.append(subcls.name) - return names - - @abstractmethod - def estimate_token_num(self, text) -> int: - """Estimates the number of tokens in |text|.""" - - # ============================== Generation ============================== # - @abstractmethod - def query_llm(self, prompt: prompts.Prompt, response_dir: str) -> None: - """Queries the LLM and stores responses in |response_dir|.""" - - def ask_llm(self, prompt: prompts.Prompt) -> str: - """Queries LLM a single prompt and returns its response.""" - del prompt - return '' - - @abstractmethod - def chat_llm_with_tools(self, client: Any, prompt: Optional[prompts.Prompt], - tools) -> Any: - """Queries the LLM in the given chat session with tools.""" - - @abstractmethod - def chat_llm(self, client: Any, prompt: prompts.Prompt) -> str: - """Queries the LLM in the given chat session and returns the response.""" - - @abstractmethod - def get_model(self) -> Any: - """Returns the underlying model instance.""" - - @abstractmethod - def prompt_type(self) -> type[prompts.Prompt]: - """Returns the expected prompt type.""" - - def _delay_for_retry(self, attempt_count: int) -> None: - """Sleeps for a while based on the |attempt_count|.""" - # Exponentially increase from 5 to 80 seconds + some random to jitter. - delay = 5 * 2**attempt_count + random.randint(1, 5) - logging.warning('Retry in %d seconds...', delay) - time.sleep(delay) - - def _is_retryable_error(self, err: Exception, - api_errors: list[Type[Exception]], - tb: traceback.StackSummary) -> bool: - """Validates if |err| is worth retrying.""" - if any(isinstance(err, api_error) for api_error in api_errors): - return True - - # A known case from vertex package, no content due to mismatch roles. - if (isinstance(err, ValueError) and - 'Content roles do not match' in str(err) and tb[-1].filename.endswith( - 'vertexai/generative_models/_generative_models.py')): - return True - - # A known case from vertex package, content blocked by safety filters. - if (isinstance(err, ValueError) and - 'blocked by the safety filters' in str(err) and - tb[-1].filename.endswith( - 'vertexai/generative_models/_generative_models.py')): - return True - - return False - - def with_retry_on_error(self, func: Callable, - api_errs: list[Type[Exception]]) -> Any: - """ - Retry when the function returns an expected error with exponential backoff. 
- """ - for attempt in range(1, self._max_attempts + 1): - try: - return func() - except Exception as err: - logging.warning('LLM API Error when responding (attempt %d): %s', - attempt, err) - tb = traceback.extract_tb(err.__traceback__) - if (not self._is_retryable_error(err, api_errs, tb) or - attempt == self._max_attempts): - logging.warning( - 'LLM API cannot fix error when responding (attempt %d) %s: %s', - attempt, err, traceback.format_exc()) - raise err - self._delay_for_retry(attempt_count=attempt) - return None - - def _save_output(self, index: int, content: str, response_dir: str) -> None: - """Saves the raw |content| from the model ouput.""" - sample_id = index + 1 - raw_output_path = os.path.join(response_dir, f'{sample_id:02}.rawoutput') - with open(raw_output_path, 'w+') as output_file: - output_file.write(content) - - def truncate_prompt(self, - raw_prompt_text: Any, - extra_text: Any = None) -> Any: - """Truncates the prompt text to fit in MAX_INPUT_TOKEN.""" - del extra_text - return raw_prompt_text - - @abstractmethod - def get_chat_client(self, model: Any) -> Any: - """Returns a new chat session.""" + """Base LLM.""" + + # Should be set by the subclass. + name: str + # TODO(mihaimaruseac): Should this be MAX_TOKENS or a different global? + context_window: int = 2000 # Default token size. + + MAX_INPUT_TOKEN: int = sys.maxsize + + _max_attempts = 5 # Maximum number of attempts to get prediction response + + def __init__( + self, + ai_binary: str, + max_tokens: int = MAX_TOKENS, + num_samples: int = NUM_SAMPLES, + temperature: float = TEMPERATURE, + temperature_list: Optional[list[float]] = None, + ): + self.ai_binary = ai_binary + + # Model parameters. + self.max_tokens = max_tokens + self.num_samples = num_samples + self.temperature = temperature + self.temperature_list = temperature_list + + # Preserve chat history for OpenAI + self.messages = [] + + def cloud_setup(self): + """Runs Cloud specific-setup.""" + # Only a subset of models need a cloud specific set up, so + # we can pass for the remainder of the models as they don't + # need to implement specific handling of this. 
+ + @classmethod + def setup( + cls, + ai_binary: str, + name: str, + max_tokens: int = MAX_TOKENS, + num_samples: int = NUM_SAMPLES, + temperature: float = TEMPERATURE, + temperature_list: Optional[list[float]] = None, + ): + """Prepares the LLM for fuzz target generation.""" + if ai_binary: + return AIBinaryModel(name, ai_binary, max_tokens, num_samples, temperature) + + for subcls in cls.all_llm_subclasses(): + if getattr(subcls, "name", None) == name: + return subcls( + ai_binary, + max_tokens, + num_samples, + temperature, + temperature_list, + ) + + raise ValueError(f"Bad model type {name}") + + @classmethod + def all_llm_subclasses(cls): + """All subclasses.""" + yield cls + for subcls in cls.__subclasses__(): + yield from subcls.all_llm_subclasses() + + @classmethod + def all_llm_names(cls): + """Returns the current model name and all child model names.""" + names = [] + for subcls in cls.all_llm_subclasses(): + if hasattr(subcls, "name") and subcls.name != AIBinaryModel.name: + names.append(subcls.name) + return names + + @abstractmethod + def estimate_token_num(self, text) -> int: + """Estimates the number of tokens in |text|.""" + + # ============================== Generation ============================== # + @abstractmethod + def query_llm(self, prompt: prompts.Prompt, response_dir: str) -> None: + """Queries the LLM and stores responses in |response_dir|.""" + + def ask_llm(self, prompt: prompts.Prompt) -> str: + """Queries LLM a single prompt and returns its response.""" + del prompt + return "" + + @abstractmethod + def chat_llm_with_tools( + self, client: Any, prompt: Optional[prompts.Prompt], tools + ) -> Any: + """Queries the LLM in the given chat session with tools.""" + + @abstractmethod + def chat_llm(self, client: Any, prompt: prompts.Prompt) -> str: + """Queries the LLM in the given chat session and returns the response.""" + + @abstractmethod + def get_model(self) -> Any: + """Returns the underlying model instance.""" + + @abstractmethod + def prompt_type(self) -> type[prompts.Prompt]: + """Returns the expected prompt type.""" + + def _delay_for_retry(self, attempt_count: int) -> None: + """Sleeps for a while based on the |attempt_count|.""" + # Exponentially increase from 5 to 80 seconds + some random to jitter. + delay = 5 * 2**attempt_count + random.randint(1, 5) + logging.warning("Retry in %d seconds...", delay) + time.sleep(delay) + + def _is_retryable_error( + self, + err: Exception, + api_errors: list[Type[Exception]], + tb: traceback.StackSummary, + ) -> bool: + """Validates if |err| is worth retrying.""" + if any(isinstance(err, api_error) for api_error in api_errors): + return True + + # A known case from vertex package, no content due to mismatch roles. + if ( + isinstance(err, ValueError) + and "Content roles do not match" in str(err) + and tb[-1].filename.endswith( + "vertexai/generative_models/_generative_models.py" + ) + ): + return True + + # A known case from vertex package, content blocked by safety filters. + if ( + isinstance(err, ValueError) + and "blocked by the safety filters" in str(err) + and tb[-1].filename.endswith( + "vertexai/generative_models/_generative_models.py" + ) + ): + return True + + return False + + def with_retry_on_error( + self, func: Callable, api_errs: list[Type[Exception]] + ) -> Any: + """ + Retry when the function returns an expected error with exponential backoff. 
+ """ + for attempt in range(1, self._max_attempts + 1): + try: + return func() + except Exception as err: + logging.warning( + "LLM API Error when responding (attempt %d): %s", attempt, err + ) + tb = traceback.extract_tb(err.__traceback__) + if ( + not self._is_retryable_error(err, api_errs, tb) + or attempt == self._max_attempts + ): + logging.warning( + "LLM API cannot fix error when responding (attempt %d) %s: %s", + attempt, + err, + traceback.format_exc(), + ) + raise err + self._delay_for_retry(attempt_count=attempt) + return None + + def _save_output(self, index: int, content: str, response_dir: str) -> None: + """Saves the raw |content| from the model ouput.""" + sample_id = index + 1 + raw_output_path = os.path.join(response_dir, f"{sample_id:02}.rawoutput") + with open(raw_output_path, "w+") as output_file: + output_file.write(content) + + def truncate_prompt(self, raw_prompt_text: Any, extra_text: Any = None) -> Any: + """Truncates the prompt text to fit in MAX_INPUT_TOKEN.""" + del extra_text + return raw_prompt_text + + @abstractmethod + def get_chat_client(self, model: Any) -> Any: + """Returns a new chat session.""" class GPT(LLM): - """OpenAI's GPT model encapsulator.""" - - name = 'gpt-3.5-turbo' - - def get_model(self) -> Any: - """Returns the underlying model instance.""" - # Placeholder: No suitable implementation/usage yet. - - def get_chat_client(self, model: Any) -> Any: - """Returns a new chat session.""" - return self._get_client() - - def truncate_prompt(self, - raw_prompt_text: Any, - extra_text: Any = None) -> Any: - """Truncates the prompt text to fit in MAX_INPUT_TOKEN.""" - # Obtain token counts - extra_text = extra_text or '' - extra_tokens = self.estimate_token_num(extra_text) - total_tokens = self.estimate_token_num(raw_prompt_text) - - # Assume 10 round max, with 1 portion reserved for tags and roles - remaining = math.floor(self.MAX_INPUT_TOKEN / 11 - extra_tokens) - if remaining <= 0: - logger.warning('Insufficient tokens to add any text: %d, %d', - extra_tokens, remaining) - return '' - - # Space enough, return directly - if total_tokens <= remaining: - return raw_prompt_text - - # Truncate marker - marker = '\n..(truncated)...\n' - - # Space not enough for marker, recursively truncate again - if remaining < self.estimate_token_num(marker): - logger.warning('Insufficient tokens to add marker: %d, %d', extra_tokens, - remaining) - return self.truncate_prompt(raw_prompt_text[:remaining], extra_text) - - # Add marker and return the truncated string with another recursive check - truncated_prompt = raw_prompt_text[-remaining:] + marker - logger.info('Truncated %d tokens from %d to %d chars.', - len(raw_prompt_text) - remaining, len(raw_prompt_text), - len(truncated_prompt)) - - return self.truncate_prompt(truncated_prompt, extra_text) - - def _get_tiktoken_encoding(self, model_name: str): - """Returns the tiktoken encoding for the model.""" - try: - return tiktoken.encoding_for_model(model_name) - except KeyError: - logger.info('Could not get a tiktoken encoding for %s.', model_name) - return tiktoken.get_encoding('cl100k_base') - - def _get_client(self): - """Returns the OpenAI client.""" - return openai.OpenAI(api_key=os.getenv('OPENAI_API_KEY')) - - # ================================ Prompt ================================ # - def estimate_token_num(self, text) -> int: - """Estimates the number of tokens in |text|.""" - # https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken - - encoder = self._get_tiktoken_encoding(self.name) - - if 
isinstance(text, str): - return len(encoder.encode(text)) - - num_tokens = 0 - for message in text: - num_tokens += 3 - for key, value in message.items(): - num_tokens += len(encoder.encode(value)) - if key == 'name': - num_tokens += 1 - num_tokens += 3 - - return num_tokens - - def prompt_type(self) -> type[prompts.Prompt]: - """Returns the expected prompt type.""" - return prompts.OpenAIPrompt - - def chat_llm(self, client: Any, prompt: prompts.Prompt) -> str: - """Queries LLM in a chat session and returns its response.""" - if self.ai_binary: - raise ValueError(f'OpenAI does not use local AI binary: {self.ai_binary}') - if self.temperature_list: - logger.info('OpenAI does not allow temperature list: %s', - self.temperature_list) - - self.messages.extend(prompt.get()) - - completion = self.with_retry_on_error( - lambda: client.chat.completions.create(messages=self.messages, - model=self.name, - n=self.num_samples, - temperature=self.temperature), - [openai.OpenAIError]) - - llm_response = completion.choices[0].message.content - self.messages.append({'role': 'assistant', 'content': llm_response}) - - return llm_response - - def chat_llm_with_tools(self, client: Any, prompt: Optional[prompts.Prompt], - tools) -> Any: - """Queries LLM in a chat session with tools.""" - if self.ai_binary: - raise ValueError(f'OpenAI does not use local AI binary: {self.ai_binary}') - if self.temperature_list: - logger.info('OpenAI does not allow temperature list: %s', - self.temperature_list) - - if prompt: - self.messages.extend(prompt.get()) - - result = self.with_retry_on_error( - lambda: client.responses.create( - model=self.name, input=self.messages, tools=tools), - [openai.OpenAIError]) - return result - - def ask_llm(self, prompt: prompts.Prompt) -> str: - """Queries LLM a single prompt and returns its response.""" - if self.ai_binary: - raise ValueError(f'OpenAI does not use local AI binary: {self.ai_binary}') - if self.temperature_list: - logger.info('OpenAI does not allow temperature list: %s', - self.temperature_list) - - client = self._get_client() - - completion = self.with_retry_on_error( - lambda: client.chat.completions.create(messages=prompt.get(), - model=self.name, - n=self.num_samples, - temperature=self.temperature), - [openai.OpenAIError]) - return completion.choices[0].message.content - - # ============================== Generation ============================== # - def query_llm(self, prompt: prompts.Prompt, response_dir: str) -> None: - """Queries OpenAI's API and stores response in |response_dir|.""" - if self.ai_binary: - raise ValueError(f'OpenAI does not use local AI binary: {self.ai_binary}') - if self.temperature_list: - logger.info('OpenAI does not allow temperature list: %s', - self.temperature_list) - - client = self._get_client() - - completion = self.with_retry_on_error( - lambda: client.chat.completions.create(messages=prompt.get(), - model=self.name, - n=self.num_samples, - temperature=self.temperature), - [openai.OpenAIError]) - for index, choice in enumerate(completion.choices): # type: ignore - content = choice.message.content - self._save_output(index, content, response_dir) + """OpenAI's GPT model encapsulator.""" + + name = "gpt-3.5-turbo" + + def get_model(self) -> Any: + """Returns the underlying model instance.""" + # Placeholder: No suitable implementation/usage yet. 
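
The token estimate above follows the tiktoken recipe from the OpenAI cookbook link in the code: roughly three tokens of overhead per chat message plus the encoded length of every field, and three more to prime the reply. A self-contained version is sketched below; the per-message constants vary somewhat by model, so treat the result as an approximation.

    import tiktoken


    def estimate_chat_tokens(messages, model_name="gpt-4o"):
        """Rough token count for a chat-completion payload."""
        try:
            encoder = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # Unknown model names fall back to a generic encoding.
            encoder = tiktoken.get_encoding("cl100k_base")

        num_tokens = 0
        for message in messages:
            num_tokens += 3  # Fixed per-message overhead.
            for key, value in message.items():
                num_tokens += len(encoder.encode(value))
                if key == "name":
                    num_tokens += 1
        return num_tokens + 3  # Priming tokens for the assistant reply.


    print(estimate_chat_tokens([{"role": "user", "content": "Write a fuzz target."}]))
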
+ + def get_chat_client(self, model: Any) -> Any: + """Returns a new chat session.""" + return self._get_client() + + def truncate_prompt(self, raw_prompt_text: Any, extra_text: Any = None) -> Any: + """Truncates the prompt text to fit in MAX_INPUT_TOKEN.""" + # Obtain token counts + extra_text = extra_text or "" + extra_tokens = self.estimate_token_num(extra_text) + total_tokens = self.estimate_token_num(raw_prompt_text) + + # Assume 10 round max, with 1 portion reserved for tags and roles + remaining = math.floor(self.MAX_INPUT_TOKEN / 11 - extra_tokens) + if remaining <= 0: + logger.warning( + "Insufficient tokens to add any text: %d, %d", extra_tokens, remaining + ) + return "" + + # Space enough, return directly + if total_tokens <= remaining: + return raw_prompt_text + + # Truncate marker + marker = "\n..(truncated)...\n" + + # Space not enough for marker, recursively truncate again + if remaining < self.estimate_token_num(marker): + logger.warning( + "Insufficient tokens to add marker: %d, %d", extra_tokens, remaining + ) + return self.truncate_prompt(raw_prompt_text[:remaining], extra_text) + + # Add marker and return the truncated string with another recursive check + truncated_prompt = raw_prompt_text[-remaining:] + marker + logger.info( + "Truncated %d tokens from %d to %d chars.", + len(raw_prompt_text) - remaining, + len(raw_prompt_text), + len(truncated_prompt), + ) + + return self.truncate_prompt(truncated_prompt, extra_text) + + def _get_tiktoken_encoding(self, model_name: str): + """Returns the tiktoken encoding for the model.""" + try: + return tiktoken.encoding_for_model(model_name) + except KeyError: + logger.info("Could not get a tiktoken encoding for %s.", model_name) + return tiktoken.get_encoding("cl100k_base") + + def _get_client(self): + """Returns the OpenAI client.""" + return openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + + # ================================ Prompt ================================ # + def estimate_token_num(self, text) -> int: + """Estimates the number of tokens in |text|.""" + # https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken + + encoder = self._get_tiktoken_encoding(self.name) + + if isinstance(text, str): + return len(encoder.encode(text)) + + num_tokens = 0 + for message in text: + num_tokens += 3 + for key, value in message.items(): + num_tokens += len(encoder.encode(value)) + if key == "name": + num_tokens += 1 + num_tokens += 3 + + return num_tokens + + def prompt_type(self) -> type[prompts.Prompt]: + """Returns the expected prompt type.""" + return prompts.OpenAIPrompt + + def chat_llm(self, client: Any, prompt: prompts.Prompt) -> str: + """Queries LLM in a chat session and returns its response.""" + if self.ai_binary: + raise ValueError(f"OpenAI does not use local AI binary: {self.ai_binary}") + if self.temperature_list: + logger.info( + "OpenAI does not allow temperature list: %s", self.temperature_list + ) + + self.messages.extend(prompt.get()) + + completion = self.with_retry_on_error( + lambda: client.chat.completions.create( + messages=self.messages, + model=self.name, + n=self.num_samples, + temperature=self.temperature, + ), + [openai.OpenAIError], + ) + + llm_response = completion.choices[0].message.content + self.messages.append({"role": "assistant", "content": llm_response}) + + return llm_response + + def chat_llm_with_tools( + self, client: Any, prompt: Optional[prompts.Prompt], tools + ) -> Any: + """Queries LLM in a chat session with tools.""" + if self.ai_binary: + raise ValueError(f"OpenAI 
does not use local AI binary: {self.ai_binary}") + if self.temperature_list: + logger.info( + "OpenAI does not allow temperature list: %s", self.temperature_list + ) + + if prompt: + self.messages.extend(prompt.get()) + + result = self.with_retry_on_error( + lambda: client.responses.create( + model=self.name, input=self.messages, tools=tools + ), + [openai.OpenAIError], + ) + return result + + def ask_llm(self, prompt: prompts.Prompt) -> str: + """Queries LLM a single prompt and returns its response.""" + if self.ai_binary: + raise ValueError(f"OpenAI does not use local AI binary: {self.ai_binary}") + if self.temperature_list: + logger.info( + "OpenAI does not allow temperature list: %s", self.temperature_list + ) + + client = self._get_client() + + completion = self.with_retry_on_error( + lambda: client.chat.completions.create( + messages=prompt.get(), + model=self.name, + n=self.num_samples, + temperature=self.temperature, + ), + [openai.OpenAIError], + ) + return completion.choices[0].message.content + + # ============================== Generation ============================== # + def query_llm(self, prompt: prompts.Prompt, response_dir: str) -> None: + """Queries OpenAI's API and stores response in |response_dir|.""" + if self.ai_binary: + raise ValueError(f"OpenAI does not use local AI binary: {self.ai_binary}") + if self.temperature_list: + logger.info( + "OpenAI does not allow temperature list: %s", self.temperature_list + ) + + client = self._get_client() + + completion = self.with_retry_on_error( + lambda: client.chat.completions.create( + messages=prompt.get(), + model=self.name, + n=self.num_samples, + temperature=self.temperature, + ), + [openai.OpenAIError], + ) + for index, choice in enumerate(completion.choices): # type: ignore + content = choice.message.content + self._save_output(index, content, response_dir) class GPT4(GPT): - """OpenAI's GPT-4 model.""" + """OpenAI's GPT-4 model.""" - name = 'gpt-4' + name = "gpt-4" class GPT41(GPT): - """OpenAI's GPT-4.1 model.""" + """OpenAI's GPT-4.1 model.""" - name = 'gpt-4.1' + name = "gpt-4.1" class GPT41Mini(GPT): - """OpenAI's GPT-4.1-Mini model.""" + """OpenAI's GPT-4.1-Mini model.""" - name = 'gpt-4.1-mini' + name = "gpt-4.1-mini" class GPT4o(GPT): - """OpenAI's GPT-4o model.""" + """OpenAI's GPT-4o model.""" - name = 'gpt-4o' - MAX_INPUT_TOKEN = 128000 - _gpt_ai_model = 'gpt-4o' + name = "gpt-4o" + MAX_INPUT_TOKEN = 128000 + _gpt_ai_model = "gpt-4o" class ChatGPT4oLatest(GPT): - """OpenAI's chatgpt-4o-latest model.""" + """OpenAI's chatgpt-4o-latest model.""" - name = 'chatgpt-4o-latest' - MAX_INPUT_TOKEN = 128000 - _gpt_ai_model = 'gpt-4o' + name = "chatgpt-4o-latest" + MAX_INPUT_TOKEN = 128000 + _gpt_ai_model = "gpt-4o" class GPT4oMini(GPT): - """OpenAI's GPT-4o-mini model.""" + """OpenAI's GPT-4o-mini model.""" - name = 'gpt-4o-mini' + name = "gpt-4o-mini" class GPT4Turbo(GPT): - """OpenAI's GPT-4 Turbo model.""" + """OpenAI's GPT-4 Turbo model.""" - name = 'gpt-4-turbo' + name = "gpt-4-turbo" class ChatGPT(GPT): - """OpenAI's GPT model with chat session.""" - - name = 'chatgpt-3.5-turbo' - - def __init__( - self, - ai_binary: str, - max_tokens: int = MAX_TOKENS, - num_samples: int = NUM_SAMPLES, - temperature: float = TEMPERATURE, - temperature_list: Optional[list[float]] = None, - ): - super().__init__(ai_binary, max_tokens, num_samples, temperature, - temperature_list) - self.conversation_history = [] - - def chat_llm(self, client: Any, prompt: prompts.Prompt) -> str: - """Queries the LLM in the given chat session and 
returns the response.""" - if self.ai_binary: - raise ValueError(f'OpenAI does not use local AI binary: {self.ai_binary}') - if self.temperature_list: - logger.info('OpenAI does not allow temperature list: %s', - self.temperature_list) - - self.conversation_history.extend(prompt.get()) - - completion = self.with_retry_on_error( - lambda: client.chat.completions.create( - messages=self.conversation_history, - model=self.name, - n=self.num_samples, - temperature=self.temperature), [openai.OpenAIError]) - - # Choose the longest response - longest_response = max( - (choice.message.content for choice in completion.choices), key=len) - self.conversation_history.append({ - 'role': 'assistant', - 'content': longest_response - }) - - return longest_response + """OpenAI's GPT model with chat session.""" + + name = "chatgpt-3.5-turbo" + + def __init__( + self, + ai_binary: str, + max_tokens: int = MAX_TOKENS, + num_samples: int = NUM_SAMPLES, + temperature: float = TEMPERATURE, + temperature_list: Optional[list[float]] = None, + ): + super().__init__( + ai_binary, max_tokens, num_samples, temperature, temperature_list + ) + self.conversation_history = [] + + def chat_llm(self, client: Any, prompt: prompts.Prompt) -> str: + """Queries the LLM in the given chat session and returns the response.""" + if self.ai_binary: + raise ValueError(f"OpenAI does not use local AI binary: {self.ai_binary}") + if self.temperature_list: + logger.info( + "OpenAI does not allow temperature list: %s", self.temperature_list + ) + + self.conversation_history.extend(prompt.get()) + + completion = self.with_retry_on_error( + lambda: client.chat.completions.create( + messages=self.conversation_history, + model=self.name, + n=self.num_samples, + temperature=self.temperature, + ), + [openai.OpenAIError], + ) + + # Choose the longest response + longest_response = max( + (choice.message.content for choice in completion.choices), key=len + ) + self.conversation_history.append( + {"role": "assistant", "content": longest_response} + ) + + return longest_response class ChatGPT4(ChatGPT): - """OpenAI's GPT4 model with chat session.""" + """OpenAI's GPT4 model with chat session.""" - name = 'chatgpt-4' + name = "chatgpt-4" class ChatGPT4o(ChatGPT): - """OpenAI's GPT-4o model with chat session.""" + """OpenAI's GPT-4o model with chat session.""" - name = 'chatgpt-4o' + name = "chatgpt-4o" class ChatGPT4oMini(ChatGPT): - """OpenAI's GPT-4o-mini model with chat session.""" + """OpenAI's GPT-4o-mini model with chat session.""" - name = 'chatgpt-4o-mini' + name = "chatgpt-4o-mini" class ChatGPT4Turbo(ChatGPT): - """OpenAI's GPT-4 Turbo model with chat session.""" + """OpenAI's GPT-4 Turbo model with chat session.""" - name = 'chatgpt-4-turbo' + name = "chatgpt-4-turbo" class AzureGPT(GPT): - """Azure's GPT model.""" + """Azure's GPT model.""" - name = 'gpt-3.5-turbo-azure' + name = "gpt-3.5-turbo-azure" - def _get_tiktoken_encoding(self, model_name: str): - """Returns the tiktoken encoding for the model.""" - return super()._get_tiktoken_encoding(model_name.replace('-azure', '')) + def _get_tiktoken_encoding(self, model_name: str): + """Returns the tiktoken encoding for the model.""" + return super()._get_tiktoken_encoding(model_name.replace("-azure", "")) - def _get_client(self): - """Returns the Azure client.""" - return openai.AzureOpenAI(azure_endpoint=os.getenv( - "AZURE_OPENAI_ENDPOINT", "https://api.openai.com"), - api_key=os.getenv("AZURE_OPENAI_API_KEY"), - api_version=os.getenv("AZURE_OPENAI_API_VERSION", - "2024-02-01")) + def 
_get_client(self): + """Returns the Azure client.""" + return openai.AzureOpenAI( + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com"), + api_key=os.getenv("AZURE_OPENAI_API_KEY"), + api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-01"), + ) class AzureGPT4(AzureGPT): - """Azure's GPTi-4 model.""" + """Azure's GPTi-4 model.""" - name = 'gpt-4-azure' + name = "gpt-4-azure" class AzureGPT4o(AzureGPT): - """Azure's GPTi-4 model.""" + """Azure's GPTi-4 model.""" - name = 'gpt-4o-azure' + name = "gpt-4o-azure" class Claude(LLM): - """Anthropic's Claude model encapsulator.""" - - _max_output_tokens = 4096 - _vertex_ai_model = '' - context_window = 200000 - - # ================================ Prompt ================================ # - def estimate_token_num(self, text) -> int: - """Estimates the number of tokens in |text|.""" - client = anthropic.Client() - return client.count_tokens(text) - - def prompt_type(self) -> type[prompts.Prompt]: - """Returns the expected prompt type.""" - return prompts.ClaudePrompt - - def get_model(self) -> str: - return self._vertex_ai_model - - # ============================== Generation ============================== # - def query_llm(self, prompt: prompts.Prompt, response_dir: str) -> None: - """Queries Claude's API and stores response in |response_dir|.""" - if self.ai_binary: - raise ValueError(f'Claude does not use local AI binary: {self.ai_binary}') - if self.temperature_list: - logger.info('Claude does not allow temperature list: %s', - self.temperature_list) - - vertex_ai_locations = os.getenv('VERTEX_AI_LOCATIONS', - 'europe-west1').split(',') - project_id = os.getenv('GOOGLE_CLOUD_PROJECT', 'oss-fuzz') - region = random.sample(vertex_ai_locations, 1)[0] - client = anthropic.AnthropicVertex(region=region, project_id=project_id) - - completion = self.with_retry_on_error( - lambda: client.messages.create(max_tokens=self._max_output_tokens, - messages=prompt.get(), - model=self.get_model(), - temperature=self.temperature), - [anthropic.AnthropicError]) - for index, choice in enumerate(completion.content): - content = choice.text - self._save_output(index, content, response_dir) - - def get_chat_client(self, model: Any) -> Any: - """Returns a new chat session.""" - del model - # Placeholder: To Be Implemented. - - def chat_llm(self, client: Any, prompt: prompts.Prompt) -> Any: - """Queries the LLM in the given chat session and returns the response.""" - del client, prompt - # Placeholder: To Be Implemented. - - def chat_llm_with_tools(self, client: Any, prompt: Optional[prompts.Prompt], - tools) -> Any: - """Queries the LLM in the given chat session with tools.""" - # Placeholder: To Be Implemented. 
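
The Azure subclasses above only swap in a different client while reusing the GPT logic. A minimal sketch of standing one up from environment variables follows; the endpoint and deployment names are placeholders, the call needs valid Azure OpenAI credentials, and with Azure the model argument typically refers to the deployment name rather than the raw model id.

    import os

    import openai

    client = openai.AzureOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT", "https://example.openai.azure.com"),
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-01"),
    )

    # AZURE_OPENAI_DEPLOYMENT is a hypothetical variable naming the deployment.
    completion = client.chat.completions.create(
        model=os.getenv("AZURE_OPENAI_DEPLOYMENT", "gpt-4o"),
        messages=[{"role": "user", "content": "Say hello."}],
        n=1,
        temperature=0.4,
    )
    print(completion.choices[0].message.content)
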
- return + """Anthropic's Claude model encapsulator.""" + + _max_output_tokens = 4096 + _vertex_ai_model = "" + context_window = 200000 + + # ================================ Prompt ================================ # + def estimate_token_num(self, text) -> int: + """Estimates the number of tokens in |text|.""" + client = anthropic.Client() + return client.count_tokens(text) + + def prompt_type(self) -> type[prompts.Prompt]: + """Returns the expected prompt type.""" + return prompts.ClaudePrompt + + def get_model(self) -> str: + return self._vertex_ai_model + + # ============================== Generation ============================== # + def query_llm(self, prompt: prompts.Prompt, response_dir: str) -> None: + """Queries Claude's API and stores response in |response_dir|.""" + if self.ai_binary: + raise ValueError(f"Claude does not use local AI binary: {self.ai_binary}") + if self.temperature_list: + logger.info( + "Claude does not allow temperature list: %s", self.temperature_list + ) + + vertex_ai_locations = os.getenv("VERTEX_AI_LOCATIONS", "europe-west1").split( + "," + ) + project_id = os.getenv("GOOGLE_CLOUD_PROJECT", "oss-fuzz") + region = random.sample(vertex_ai_locations, 1)[0] + client = anthropic.AnthropicVertex(region=region, project_id=project_id) + + completion = self.with_retry_on_error( + lambda: client.messages.create( + max_tokens=self._max_output_tokens, + messages=prompt.get(), + model=self.get_model(), + temperature=self.temperature, + ), + [anthropic.AnthropicError], + ) + for index, choice in enumerate(completion.content): + content = choice.text + self._save_output(index, content, response_dir) + + def get_chat_client(self, model: Any) -> Any: + """Returns a new chat session.""" + del model + # Placeholder: To Be Implemented. + + def chat_llm(self, client: Any, prompt: prompts.Prompt) -> Any: + """Queries the LLM in the given chat session and returns the response.""" + del client, prompt + # Placeholder: To Be Implemented. + + def chat_llm_with_tools( + self, client: Any, prompt: Optional[prompts.Prompt], tools + ) -> Any: + """Queries the LLM in the given chat session with tools.""" + # Placeholder: To Be Implemented. + return class ClaudeHaikuV3(Claude): - """Claude Haiku 3.""" + """Claude Haiku 3.""" - name = 'vertex_ai_claude-3-haiku' - _vertex_ai_model = 'claude-3-haiku@20240307' + name = "vertex_ai_claude-3-haiku" + _vertex_ai_model = "claude-3-haiku@20240307" class ClaudeOpusV3(Claude): - """Claude Opus 3.""" + """Claude Opus 3.""" - name = 'vertex_ai_claude-3-opus' - _vertex_ai_model = 'claude-3-opus@20240229' + name = "vertex_ai_claude-3-opus" + _vertex_ai_model = "claude-3-opus@20240229" class ClaudeSonnetV3D5(Claude): - """Claude Sonnet 3.5.""" + """Claude Sonnet 3.5.""" - name = 'vertex_ai_claude-3-5-sonnet' - _vertex_ai_model = 'claude-3-5-sonnet@20240620' + name = "vertex_ai_claude-3-5-sonnet" + _vertex_ai_model = "claude-3-5-sonnet@20240620" class GoogleModel(LLM): - """Generic Google model.""" - - def prompt_type(self) -> type[prompts.Prompt]: - """Returns the expected prompt type.""" - return prompts.TextPrompt - - def estimate_token_num(self, text) -> int: - """Estimates the number of tokens in |text|.""" - # A rough estimation for very large prompt: Gemini suggest 4 char per token, - # using 3 here to be safer. 
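
The GoogleModel heuristics above avoid calling a tokenizer at all: very large prompts are estimated at roughly three characters per token, everything else at about 1.5 tokens per word, and character indices are mapped back proportionally. A standalone restatement, with MAX_INPUT_TOKEN as a hypothetical limit:

    import re

    MAX_INPUT_TOKEN = 128000


    def estimate_tokens(text: str) -> int:
        """Rough token estimate without a tokenizer."""
        text = text or ""
        # For huge prompts assume ~3 chars per token (conservative vs. 4).
        if len(text) // 3 > MAX_INPUT_TOKEN:
            return len(text) // 3
        # Otherwise assume ~1.5 tokens per alphanumeric word.
        return int(len(re.split("[^a-zA-Z0-9]+", text)) * 1.5 + 0.5)


    def estimate_char_index(token_target: int, text: str) -> int:
        """Maps a token budget to an approximate character index in |text|."""
        total = estimate_tokens(text)
        if not total:
            return 0
        # Proportional mapping: token_target / total of the character length.
        return int(len(text) * token_target / total)


    sample = "int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size);"
    print(estimate_tokens(sample), estimate_char_index(5, sample))
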
- text = text or '' - if len(text) // 3 > self.MAX_INPUT_TOKEN: - return len(text) // 3 - - # Otherwise, roughly 1.5 tokens per word: - return int(len(re.split('[^a-zA-Z0-9]+', text)) * 1.5 + 0.5) - - def _estimate_char_index(self, token_target: int, text: str) -> int: - """ - Estimates a character index in `text` corresponding to approximately - `token_target` tokens. It uses the total token count for `text` and - assumes a roughly linear relation between token count and character - length. - """ - total_tokens = self.estimate_token_num(text) - if not total_tokens: - return 0 - # Proportional mapping: If text has T tokens over L characters, then - # token_target corresponds to roughly (token_target / T) * L characters. - return int(len(text) * token_target / total_tokens) - - # ============================== Generation ============================== # - def query_llm(self, prompt: prompts.Prompt, response_dir: str) -> None: - """Queries a Google LLM and stores results in |response_dir|.""" - if not self.ai_binary: - logger.info('Error: This model requires a local AI binary: %s', - self.ai_binary) - sys.exit(1) - if self.temperature_list: - logger.info('AI Binary does not implement temperature list: %s', - self.temperature_list) - - with tempfile.NamedTemporaryFile(delete=False, mode='w') as f: - f.write(prompt.get()) - prompt_path = f.name - - try: - command = [ - self.ai_binary, - f'-model={self.name}', - f'-prompt={prompt_path}', - f'-response={response_dir}', - f'-max-tokens={self.max_tokens}', - f'-expected-samples={self.num_samples}', - f'-temperature={self.temperature}', - ] - - proc = subprocess.Popen( - command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - stdin=subprocess.DEVNULL, - ) - stdout, stderr = proc.communicate() - - if proc.returncode != 0: - logger.info('Failed to generate targets with prompt %s', prompt.get()) - logger.info('stdout: %s', stdout) - logger.info('stderr: %s', stderr) - finally: - os.unlink(prompt_path) - - def get_model(self) -> Any: - """Returns the underlying model instance.""" - raise NotImplementedError - - def get_chat_client(self, model: Any) -> Any: - """Returns a new chat session.""" - del model - raise NotImplementedError - - def chat_llm(self, client: Any, prompt: prompts.Prompt) -> Any: - """Queries the LLM in the given chat session and returns the response.""" - del client, prompt - raise NotImplementedError - - def chat_llm_with_tools(self, client: Any, prompt: Optional[prompts.Prompt], - tools) -> Any: - """Queries the LLM in the given chat session with tools.""" - # Placeholder: To Be Implemented. - return + """Generic Google model.""" + + def prompt_type(self) -> type[prompts.Prompt]: + """Returns the expected prompt type.""" + return prompts.TextPrompt + + def estimate_token_num(self, text) -> int: + """Estimates the number of tokens in |text|.""" + # A rough estimation for very large prompt: Gemini suggest 4 char per token, + # using 3 here to be safer. + text = text or "" + if len(text) // 3 > self.MAX_INPUT_TOKEN: + return len(text) // 3 + + # Otherwise, roughly 1.5 tokens per word: + return int(len(re.split("[^a-zA-Z0-9]+", text)) * 1.5 + 0.5) + + def _estimate_char_index(self, token_target: int, text: str) -> int: + """ + Estimates a character index in `text` corresponding to approximately + `token_target` tokens. It uses the total token count for `text` and + assumes a roughly linear relation between token count and character + length. 
+ """ + total_tokens = self.estimate_token_num(text) + if not total_tokens: + return 0 + # Proportional mapping: If text has T tokens over L characters, then + # token_target corresponds to roughly (token_target / T) * L characters. + return int(len(text) * token_target / total_tokens) + + # ============================== Generation ============================== # + def query_llm(self, prompt: prompts.Prompt, response_dir: str) -> None: + """Queries a Google LLM and stores results in |response_dir|.""" + if not self.ai_binary: + logger.info( + "Error: This model requires a local AI binary: %s", self.ai_binary + ) + sys.exit(1) + if self.temperature_list: + logger.info( + "AI Binary does not implement temperature list: %s", + self.temperature_list, + ) + + with tempfile.NamedTemporaryFile(delete=False, mode="w") as f: + f.write(prompt.get()) + prompt_path = f.name + + try: + command = [ + self.ai_binary, + f"-model={self.name}", + f"-prompt={prompt_path}", + f"-response={response_dir}", + f"-max-tokens={self.max_tokens}", + f"-expected-samples={self.num_samples}", + f"-temperature={self.temperature}", + ] + + proc = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + stdin=subprocess.DEVNULL, + ) + stdout, stderr = proc.communicate() + + if proc.returncode != 0: + logger.info("Failed to generate targets with prompt %s", prompt.get()) + logger.info("stdout: %s", stdout) + logger.info("stderr: %s", stderr) + finally: + os.unlink(prompt_path) + + def get_model(self) -> Any: + """Returns the underlying model instance.""" + raise NotImplementedError + + def get_chat_client(self, model: Any) -> Any: + """Returns a new chat session.""" + del model + raise NotImplementedError + + def chat_llm(self, client: Any, prompt: prompts.Prompt) -> Any: + """Queries the LLM in the given chat session and returns the response.""" + del client, prompt + raise NotImplementedError + + def chat_llm_with_tools( + self, client: Any, prompt: Optional[prompts.Prompt], tools + ) -> Any: + """Queries the LLM in the given chat session with tools.""" + # Placeholder: To Be Implemented. 
+ return class VertexAIModel(GoogleModel): - """Vertex AI model.""" - - _vertex_ai_model = '' - _max_output_tokens = 2048 - - def cloud_setup(self): - """Sets Vertex AI cloud location.""" - vertex_ai_locations = os.getenv('VERTEX_AI_LOCATIONS', - 'us-central1').split(',') - location = random.sample(vertex_ai_locations, 1)[0] - - logging.info('Using location %s for Vertex AI', location) - vertexai.init(location=location,) - - def get_model(self) -> Any: - return CodeGenerationModel.from_pretrained(self._vertex_ai_model) - - def do_generate(self, model: Any, prompt: str, config: dict[str, Any]) -> Any: - return model.predict(prefix=prompt, **config).text - - def _prepare_parameters(self) -> list[dict]: - """Prepares the parameter dictionary for LLM query.""" - return [{ - 'temperature': - self.temperature_list[index % len(self.temperature_list)] - if self.temperature_list else self.temperature, - 'max_output_tokens': - self._max_output_tokens - } for index in range(self.num_samples)] - - def query_llm(self, prompt: prompts.Prompt, response_dir: str) -> None: - if self.ai_binary: - logger.info('VertexAI does not use local AI binary: %s', self.ai_binary) - - model = self.get_model() - parameters_list = self._prepare_parameters() - - for i in range(self.num_samples): - response = self.with_retry_on_error( - lambda i=i: self.do_generate(model, prompt.get(), parameters_list[i]), - [GoogleAPICallError]) or '' - self._save_output(i, response, response_dir) - - def ask_llm(self, prompt: prompts.Prompt) -> str: - if self.ai_binary: - logger.info('VertexAI does not use local AI binary: %s', self.ai_binary) - - model = self.get_model() - # TODO: Allow each trial to customize its parameters_list. - parameter = self._prepare_parameters()[0] - response = self.with_retry_on_error( - lambda: self.do_generate(model, prompt.get(), parameter), - [GoogleAPICallError]) or '' - return response + """Vertex AI model.""" + + _vertex_ai_model = "" + _max_output_tokens = 2048 + + def cloud_setup(self): + """Sets Vertex AI cloud location.""" + vertex_ai_locations = os.getenv("VERTEX_AI_LOCATIONS", "us-central1").split(",") + location = random.sample(vertex_ai_locations, 1)[0] + + logging.info("Using location %s for Vertex AI", location) + vertexai.init( + location=location, + ) + + def get_model(self) -> Any: + return CodeGenerationModel.from_pretrained(self._vertex_ai_model) + + def do_generate(self, model: Any, prompt: str, config: dict[str, Any]) -> Any: + return model.predict(prefix=prompt, **config).text + + def _prepare_parameters(self) -> list[dict]: + """Prepares the parameter dictionary for LLM query.""" + return [ + { + "temperature": ( + self.temperature_list[index % len(self.temperature_list)] + if self.temperature_list + else self.temperature + ), + "max_output_tokens": self._max_output_tokens, + } + for index in range(self.num_samples) + ] + + def query_llm(self, prompt: prompts.Prompt, response_dir: str) -> None: + if self.ai_binary: + logger.info("VertexAI does not use local AI binary: %s", self.ai_binary) + + model = self.get_model() + parameters_list = self._prepare_parameters() + + for i in range(self.num_samples): + response = ( + self.with_retry_on_error( + lambda i=i: self.do_generate( + model, prompt.get(), parameters_list[i] + ), + [GoogleAPICallError], + ) + or "" + ) + self._save_output(i, response, response_dir) + + def ask_llm(self, prompt: prompts.Prompt) -> str: + if self.ai_binary: + logger.info("VertexAI does not use local AI binary: %s", self.ai_binary) + + model = self.get_model() + # 
TODO: Allow each trial to customize its parameters_list. + parameter = self._prepare_parameters()[0] + response = ( + self.with_retry_on_error( + lambda: self.do_generate(model, prompt.get(), parameter), + [GoogleAPICallError], + ) + or "" + ) + return response class GeminiModel(VertexAIModel): - """Gemini models.""" - - safety_config = [ - generative_models.SafetySetting( - category=generative_models.HarmCategory. - HARM_CATEGORY_DANGEROUS_CONTENT, - threshold=generative_models.HarmBlockThreshold.BLOCK_NONE, - ), - generative_models.SafetySetting( - category=generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT, - threshold=generative_models.HarmBlockThreshold.BLOCK_NONE, - ), - generative_models.SafetySetting( - category=generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH, - threshold=generative_models.HarmBlockThreshold.BLOCK_NONE, - ), - generative_models.SafetySetting( - category=generative_models.HarmCategory. - HARM_CATEGORY_SEXUALLY_EXPLICIT, - threshold=generative_models.HarmBlockThreshold.BLOCK_NONE, - ), - ] - - def get_model(self) -> Any: - return GenerativeModel(self._vertex_ai_model) - - def do_generate(self, model: Any, prompt: str, config: dict[str, Any]) -> Any: - # Loosen inapplicable restrictions just in case. - logger.info('%s generating response with config: %s', self.name, config) - return model.generate_content(prompt, - generation_config=config, - safety_settings=self.safety_config).text + """Gemini models.""" + + safety_config = [ + generative_models.SafetySetting( + category=generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=generative_models.HarmBlockThreshold.BLOCK_NONE, + ), + generative_models.SafetySetting( + category=generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=generative_models.HarmBlockThreshold.BLOCK_NONE, + ), + generative_models.SafetySetting( + category=generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=generative_models.HarmBlockThreshold.BLOCK_NONE, + ), + generative_models.SafetySetting( + category=generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=generative_models.HarmBlockThreshold.BLOCK_NONE, + ), + ] + + def get_model(self) -> Any: + return GenerativeModel(self._vertex_ai_model) + + def do_generate(self, model: Any, prompt: str, config: dict[str, Any]) -> Any: + # Loosen inapplicable restrictions just in case. 
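
The per-sample parameter preparation above cycles through an optional temperature list so each of the num_samples requests can use a different temperature, falling back to a single value otherwise. A quick standalone version with made-up settings:

    NUM_SAMPLES = 4
    MAX_OUTPUT_TOKENS = 2048

    temperature = 0.4                 # Used when no list is given.
    temperature_list = [0.2, 0.5, 0.8]  # Cycled round-robin across samples.

    parameters = [
        {
            "temperature": (
                temperature_list[i % len(temperature_list)]
                if temperature_list
                else temperature
            ),
            "max_output_tokens": MAX_OUTPUT_TOKENS,
        }
        for i in range(NUM_SAMPLES)
    ]

    for i, config in enumerate(parameters):
        print(f"sample {i + 1}: {config}")
    # Sample 1 uses 0.2, sample 2 uses 0.5, sample 3 uses 0.8, sample 4 wraps to 0.2.
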
+ logger.info("%s generating response with config: %s", self.name, config) + return model.generate_content( + prompt, generation_config=config, safety_settings=self.safety_config + ).text class VertexAICodeBisonModel(VertexAIModel): - """code-bison.""" + """code-bison.""" - name = 'vertex_ai_code-bison' - _vertex_ai_model = 'code-bison' + name = "vertex_ai_code-bison" + _vertex_ai_model = "code-bison" class VertexAICodeBison32KModel(VertexAIModel): - """code-bison-32k.""" + """code-bison-32k.""" - _max_output_tokens = 8192 - context_window = 32000 + _max_output_tokens = 8192 + context_window = 32000 - name = 'vertex_ai_code-bison-32k' - _vertex_ai_model = 'code-bison-32k' + name = "vertex_ai_code-bison-32k" + _vertex_ai_model = "code-bison-32k" class GeminiPro(GeminiModel): - """Gemini Pro.""" + """Gemini Pro.""" - _max_output_tokens = 8192 - context_window = 32760 + _max_output_tokens = 8192 + context_window = 32760 - name = 'vertex_ai_gemini-pro' - _vertex_ai_model = 'gemini-1.0-pro' + name = "vertex_ai_gemini-pro" + _vertex_ai_model = "gemini-1.0-pro" class GeminiUltra(GeminiModel): - """Gemini Ultra.""" + """Gemini Ultra.""" - _max_output_tokens = 2048 - context_window = 32760 # TODO(dongge): Confirm this later. + _max_output_tokens = 2048 + context_window = 32760 # TODO(dongge): Confirm this later. - name = 'vertex_ai_gemini-ultra' - _vertex_ai_model = 'gemini-ultra' + name = "vertex_ai_gemini-ultra" + _vertex_ai_model = "gemini-ultra" class GeminiExperimental(GeminiModel): - """Gemini Experimental.""" + """Gemini Experimental.""" - _max_output_tokens = 8192 - context_window = 32760 # TODO(dongge): Confirm this later. + _max_output_tokens = 8192 + context_window = 32760 # TODO(dongge): Confirm this later. - name = 'vertex_ai_gemini-experimental' - _vertex_ai_model = 'gemini-experimental' + name = "vertex_ai_gemini-experimental" + _vertex_ai_model = "gemini-experimental" class GeminiV1D5(GeminiModel): - """Gemini 1.5.""" + """Gemini 1.5.""" - _max_output_tokens = 8192 - context_window = 2000000 + _max_output_tokens = 8192 + context_window = 2000000 - name = 'vertex_ai_gemini-1-5' - _vertex_ai_model = 'gemini-1.5-pro-002' + name = "vertex_ai_gemini-1-5" + _vertex_ai_model = "gemini-1.5-pro-002" class GeminiV2Flash(GeminiV1D5): - """Gemini 2 Flash.""" - name = 'vertex_ai_gemini-2-flash' - _vertex_ai_model = 'gemini-2.0-flash-001' + """Gemini 2 Flash.""" + + name = "vertex_ai_gemini-2-flash" + _vertex_ai_model = "gemini-2.0-flash-001" class GeminiV2(GeminiV1D5): - """Gemini 2.""" - name = 'vertex_ai_gemini-2' - _vertex_ai_model = 'gemini-2.0-pro-exp-02-05' + """Gemini 2.""" + + name = "vertex_ai_gemini-2" + _vertex_ai_model = "gemini-2.0-pro-exp-02-05" class GeminiV2Think(GeminiV1D5): - """Gemini 2 thinking.""" - name = 'vertex_ai_gemini-2-think' - _vertex_ai_model = 'gemini-2.0-flash-thinking-exp-01-21' + """Gemini 2 thinking.""" + + name = "vertex_ai_gemini-2-think" + _vertex_ai_model = "gemini-2.0-flash-thinking-exp-01-21" class GeminiV2D5Flash(GeminiModel): - """Gemini 2.5 flash.""" - _max_output_tokens = 65535 - context_window = 1048576 - name = 'vertex_ai_gemini-2-5-flash' - _vertex_ai_model = 'gemini-2.5-flash-preview-04-17' + """Gemini 2.5 flash.""" + + _max_output_tokens = 65535 + context_window = 1048576 + name = "vertex_ai_gemini-2-5-flash" + _vertex_ai_model = "gemini-2.5-flash-preview-04-17" class GeminiV2D5Pro(GeminiModel): - """Gemini 2.5 pro.""" - _max_output_tokens = 65535 - context_window = 1048576 - name = 'vertex_ai_gemini-2-5-pro' - _vertex_ai_model = 
'gemini-2.5-pro-preview-05-06' + """Gemini 2.5 pro.""" + + _max_output_tokens = 65535 + context_window = 1048576 + name = "vertex_ai_gemini-2-5-pro" + _vertex_ai_model = "gemini-2.5-pro-preview-05-06" class GeminiV1D5Chat(GeminiV1D5): - """Gemini 1.5 for chat session.""" - name = 'vertex_ai_gemini-1-5-chat' - _vertex_ai_model = 'gemini-1.5-pro-002' - - # Avoids sending large prompts. - MAX_INPUT_TOKEN: int = 128000 # max 2000000 - - def get_chat_client(self, model: GenerativeModel) -> Any: - return model.start_chat(response_validation=False) - - @retryable( - exceptions=[ - GoogleAPICallError, - InvalidArgument, - ValueError, # TODO(dongge): Handle RECITATION specifically. - IndexError, # A known error from vertexai. - InternalServerError, - ], - other_exceptions={ - ResourceExhausted: 100, - TooManyRequests: 100, - ServiceUnavailable: 100, - }) - def _do_generate(self, client: ChatSession, prompt: str, - config: dict[str, Any]) -> Any: - """Generates chat response.""" - logger.info('%s generating response with config: %s', self.name, config) - return client.send_message( - prompt, - stream=False, - generation_config=config, - safety_settings=self.safety_config).text # type: ignore - - def truncate_prompt(self, - raw_prompt_text: Any, - extra_text: Any = None) -> Any: - """Truncates the prompt text to fit in MAX_INPUT_TOKEN.""" - extra_text = extra_text or '' - extra_tokens = self.estimate_token_num(extra_text) - total_tokens = self.estimate_token_num(raw_prompt_text) - - # Allow buffer space for potential prompts that will be appended later. - # Allocates 1/10 of MAX_INPUT_TOKEN per prompt text block, assuming up to 10 - # blocks in the final prompt. - # TODO(dongge): Move this to prompt builder (e.g., `append()`), dynamically - # reduce each prompt text block if there is no space for raw_prompt_text. - allowed_tokens = self.MAX_INPUT_TOKEN // 10 - extra_tokens - if allowed_tokens <= 0: - logger.warning('Insufficient tokens to add any text: %d, %d', - extra_tokens, allowed_tokens) - return '' - - # raw_prompt_text already fits within the allowed #tokens, return it as is. - if total_tokens <= allowed_tokens: - return raw_prompt_text - - marker = '\n...(truncated due to exceeding input token limit)...\n' - marker_tokens = self.estimate_token_num(marker) - - # extra_tokens is too large that allowed_tokens cannot include the marker, - # return just a prefix of raw_prompt_text. - if allowed_tokens < marker_tokens: - prefix_index = self._estimate_char_index(allowed_tokens, raw_prompt_text) - logger.warning('Insufficient tokens to add marker: %d, %d', extra_tokens, - allowed_tokens) - return self.truncate_prompt(raw_prompt_text[:prefix_index], extra_text) - - # Prefix of the truncated prompt, 100 tokens by default. - prefix_tokens = min(100, allowed_tokens - marker_tokens) - prefix_index = self._estimate_char_index(prefix_tokens, raw_prompt_text) - - # Extra tokens beyond the allowed limit, with a 50-token buffer. - excess_tokens = total_tokens - allowed_tokens + 50 - # Suffix keeps the last portion after removal_tokens, that is, remove a - # block and keep the last (total_tokens - removal_tokens) tokens. 
- tokens_before_suffix = prefix_tokens + marker_tokens + excess_tokens - suffix_index = self._estimate_char_index(tokens_before_suffix, - raw_prompt_text) - - truncated_prompt = (raw_prompt_text[:prefix_index] + marker + - raw_prompt_text[suffix_index:]) - logger.info('Truncated %d tokens from %d to %d chars.', excess_tokens, - len(raw_prompt_text), len(truncated_prompt)) - return self.truncate_prompt(truncated_prompt, extra_text) - - def chat_llm(self, client: ChatSession, prompt: prompts.Prompt) -> str: - if self.ai_binary: - logger.info('VertexAI does not use local AI binary: %s', self.ai_binary) - - # TODO(dongge): Use different values for different trials - parameters_list = self._prepare_parameters()[0] - response = self._do_generate(client, prompt.get(), parameters_list) or '' - return response - - def chat_llm_with_tools(self, client: Any, prompt: Optional[prompts.Prompt], - tools) -> Any: - """Queries the LLM in the given chat session with tools.""" - # Placeholder: To Be Implemented. - return + """Gemini 1.5 for chat session.""" + + name = "vertex_ai_gemini-1-5-chat" + _vertex_ai_model = "gemini-1.5-pro-002" + + # Avoids sending large prompts. + MAX_INPUT_TOKEN: int = 128000 # max 2000000 + + def get_chat_client(self, model: GenerativeModel) -> Any: + return model.start_chat(response_validation=False) + + @retryable( + exceptions=[ + GoogleAPICallError, + InvalidArgument, + ValueError, # TODO(dongge): Handle RECITATION specifically. + IndexError, # A known error from vertexai. + InternalServerError, + ], + other_exceptions={ + ResourceExhausted: 100, + TooManyRequests: 100, + ServiceUnavailable: 100, + }, + ) + def _do_generate( + self, client: ChatSession, prompt: str, config: dict[str, Any] + ) -> Any: + """Generates chat response.""" + logger.info("%s generating response with config: %s", self.name, config) + return client.send_message( + prompt, + stream=False, + generation_config=config, + safety_settings=self.safety_config, + ).text # type: ignore + + def truncate_prompt(self, raw_prompt_text: Any, extra_text: Any = None) -> Any: + """Truncates the prompt text to fit in MAX_INPUT_TOKEN.""" + extra_text = extra_text or "" + extra_tokens = self.estimate_token_num(extra_text) + total_tokens = self.estimate_token_num(raw_prompt_text) + + # Allow buffer space for potential prompts that will be appended later. + # Allocates 1/10 of MAX_INPUT_TOKEN per prompt text block, assuming up to 10 + # blocks in the final prompt. + # TODO(dongge): Move this to prompt builder (e.g., `append()`), dynamically + # reduce each prompt text block if there is no space for raw_prompt_text. + allowed_tokens = self.MAX_INPUT_TOKEN // 10 - extra_tokens + if allowed_tokens <= 0: + logger.warning( + "Insufficient tokens to add any text: %d, %d", + extra_tokens, + allowed_tokens, + ) + return "" + + # raw_prompt_text already fits within the allowed #tokens, return it as is. + if total_tokens <= allowed_tokens: + return raw_prompt_text + + marker = "\n...(truncated due to exceeding input token limit)...\n" + marker_tokens = self.estimate_token_num(marker) + + # extra_tokens is too large that allowed_tokens cannot include the marker, + # return just a prefix of raw_prompt_text. 
+ if allowed_tokens < marker_tokens: + prefix_index = self._estimate_char_index(allowed_tokens, raw_prompt_text) + logger.warning( + "Insufficient tokens to add marker: %d, %d", + extra_tokens, + allowed_tokens, + ) + return self.truncate_prompt(raw_prompt_text[:prefix_index], extra_text) + + # Prefix of the truncated prompt, 100 tokens by default. + prefix_tokens = min(100, allowed_tokens - marker_tokens) + prefix_index = self._estimate_char_index(prefix_tokens, raw_prompt_text) + + # Extra tokens beyond the allowed limit, with a 50-token buffer. + excess_tokens = total_tokens - allowed_tokens + 50 + # Suffix keeps the last portion after removal_tokens, that is, remove a + # block and keep the last (total_tokens - removal_tokens) tokens. + tokens_before_suffix = prefix_tokens + marker_tokens + excess_tokens + suffix_index = self._estimate_char_index(tokens_before_suffix, raw_prompt_text) + + truncated_prompt = ( + raw_prompt_text[:prefix_index] + marker + raw_prompt_text[suffix_index:] + ) + logger.info( + "Truncated %d tokens from %d to %d chars.", + excess_tokens, + len(raw_prompt_text), + len(truncated_prompt), + ) + return self.truncate_prompt(truncated_prompt, extra_text) + def chat_llm(self, client: ChatSession, prompt: prompts.Prompt) -> str: + if self.ai_binary: + logger.info("VertexAI does not use local AI binary: %s", self.ai_binary) -class GeminiV2FlashChat(GeminiV1D5Chat): - """Gemini 2 Flash for chat session.""" - name = 'vertex_ai_gemini-2-flash-chat' - _vertex_ai_model = 'gemini-2.0-flash-001' + # TODO(dongge): Use different values for different trials + parameters_list = self._prepare_parameters()[0] + response = self._do_generate(client, prompt.get(), parameters_list) or "" + return response + def chat_llm_with_tools( + self, client: Any, prompt: Optional[prompts.Prompt], tools + ) -> Any: + """Queries the LLM in the given chat session with tools.""" + # Placeholder: To Be Implemented. 
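
The chat-model truncation above keeps a short prefix, inserts a marker, then keeps the tail of the prompt, recursing because the token estimates are approximate. A simplified character-based sketch of the same middle-out idea follows; it uses an exact character budget instead of the token estimator, so a single pass suffices and the numbers are illustrative only.

    MARKER = "\n...(truncated due to exceeding input token limit)...\n"


    def truncate_middle(text: str, budget_chars: int, prefix_chars: int = 100) -> str:
        """Keeps the head and tail of |text|, dropping the middle, to fit the budget."""
        if len(text) <= budget_chars:
            return text
        if budget_chars <= len(MARKER):
            # No room for the marker at all: keep only a prefix.
            return text[:budget_chars]
        prefix = min(prefix_chars, budget_chars - len(MARKER))
        # Whatever budget remains after the prefix and marker goes to the tail.
        suffix = budget_chars - prefix - len(MARKER)
        tail = text[-suffix:] if suffix > 0 else ""
        return text[:prefix] + MARKER + tail


    long_prompt = "A" * 500 + "B" * 500
    short = truncate_middle(long_prompt, budget_chars=200)
    print(len(short), short[:10], short[-10:])  # 200 chars: A-prefix, marker, B-tail.
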
+ return -class GeminiV2Chat(GeminiV1D5Chat): - """Gemini 2 for chat session.""" - name = 'vertex_ai_gemini-2-chat' - _vertex_ai_model = 'gemini-2.0-pro-exp-02-05' +class GeminiV2FlashChat(GeminiV1D5Chat): + """Gemini 2 Flash for chat session.""" -class GeminiV2ThinkChat(GeminiV1D5Chat): - """Gemini 2 for chat session.""" - name = 'vertex_ai_gemini-2-think-chat' - _vertex_ai_model = 'gemini-2.0-flash-thinking-exp-01-21' + name = "vertex_ai_gemini-2-flash-chat" + _vertex_ai_model = "gemini-2.0-flash-001" -class GeminiV2D5FlashChat(GeminiV1D5Chat): - """Gemini 2.5 flash for chat session.""" - _max_output_tokens = 65535 - context_window = 1048576 - name = 'vertex_ai_gemini-2-5-flash-chat' - _vertex_ai_model = 'gemini-2.5-flash-preview-04-17' +class GeminiV2Chat(GeminiV1D5Chat): + """Gemini 2 for chat session.""" + name = "vertex_ai_gemini-2-chat" + _vertex_ai_model = "gemini-2.0-pro-exp-02-05" -class GeminiV2D5ProChat(GeminiV1D5Chat): - """Gemini 2.5 pro for chat session.""" - _max_output_tokens = 65535 - context_window = 1048576 - name = 'vertex_ai_gemini-2-5-pro-chat' - _vertex_ai_model = 'gemini-2.5-pro-preview-05-06' +class GeminiV2ThinkChat(GeminiV1D5Chat): + """Gemini 2 for chat session.""" + + name = "vertex_ai_gemini-2-think-chat" + _vertex_ai_model = "gemini-2.0-flash-thinking-exp-01-21" -class AIBinaryModel(GoogleModel): - """A customized model hosted internally.""" - name = 'ai_binary_model' +class GeminiV2D5FlashChat(GeminiV1D5Chat): + """Gemini 2.5 flash for chat session.""" + + _max_output_tokens = 65535 + context_window = 1048576 + name = "vertex_ai_gemini-2-5-flash-chat" + _vertex_ai_model = "gemini-2.5-flash-preview-04-17" - def __init__(self, name: str, *args, **kwargs): - super().__init__(*args, **kwargs) - self.name = name - def get_model(self) -> Any: - """Returns the underlying model instance.""" - # Placeholder: No suitable implementation/usage yet. +class GeminiV2D5ProChat(GeminiV1D5Chat): + """Gemini 2.5 pro for chat session.""" - def get_chat_client(self, model: Any) -> Any: - """Returns a new chat session.""" - del model - # Placeholder: To Be Implemented. + _max_output_tokens = 65535 + context_window = 1048576 + name = "vertex_ai_gemini-2-5-pro-chat" + _vertex_ai_model = "gemini-2.5-pro-preview-05-06" - def chat_llm(self, client: Any, prompt: prompts.Prompt) -> Any: - """Queries the LLM in the given chat session and returns the response.""" - del client, prompt - # Placeholder: To Be Implemented. - def chat_llm_with_tools(self, client: Any, prompt: Optional[prompts.Prompt], - tools) -> Any: - """Queries the LLM in the given chat session with tools.""" - # Placeholder: To Be Implemented. - return +class AIBinaryModel(GoogleModel): + """A customized model hosted internally.""" + + name = "ai_binary_model" + + def __init__(self, name: str, *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = name + + def get_model(self) -> Any: + """Returns the underlying model instance.""" + # Placeholder: No suitable implementation/usage yet. + + def get_chat_client(self, model: Any) -> Any: + """Returns a new chat session.""" + del model + # Placeholder: To Be Implemented. + + def chat_llm(self, client: Any, prompt: prompts.Prompt) -> Any: + """Queries the LLM in the given chat session and returns the response.""" + del client, prompt + # Placeholder: To Be Implemented. + + def chat_llm_with_tools( + self, client: Any, prompt: Optional[prompts.Prompt], tools + ) -> Any: + """Queries the LLM in the given chat session with tools.""" + # Placeholder: To Be Implemented. 
+ return DefaultModel = GeminiV1D5 diff --git a/llm_toolkit/output_parser.py b/llm_toolkit/output_parser.py index 708d6caa47..74ee7705ab 100755 --- a/llm_toolkit/output_parser.py +++ b/llm_toolkit/output_parser.py @@ -20,108 +20,113 @@ from llm_toolkit.crash_triager import TriageResult -RAW_OUTPUT_EXT = '.rawoutput' +RAW_OUTPUT_EXT = ".rawoutput" def is_raw_output(file: str) -> bool: - """Checks if the |file| is a raw output from LLM by its extension.""" - return file.endswith(RAW_OUTPUT_EXT) + """Checks if the |file| is a raw output from LLM by its extension.""" + return file.endswith(RAW_OUTPUT_EXT) def parse_args() -> argparse.Namespace: - """Parses command line arguments.""" - parser = argparse.ArgumentParser() - parser.add_argument('-r', - '--llm-response-path', - type=str, - required=True, - help='A file containing the response from LLM.') - parser.add_argument('-o', - '--output-path', - type=str, - required=True, - help='A directory to save the parsed output.') - args = parser.parse_args() - - return args - - -def _parse_code_block_by_marker(lines: list[str], start_marker: str, - end_marker: str) -> list[str]: - """Parses code block lines based on markers.""" - block = [] - in_block = False - contains_api = False - - for line in lines: - if not in_block and start_marker in line.lower(): - in_block = True # Start a code block. - if not contains_api: - block = [] # Ignore previous block because it does not contain API. - elif in_block and end_marker in line: - in_block = False # Finish a code block. - if contains_api: - break # Found fuzz target. - elif in_block: - block.append(line) - contains_api = contains_api or 'LLVMFuzzerTestOneInput' in line - return block if block else lines + """Parses command line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "-r", + "--llm-response-path", + type=str, + required=True, + help="A file containing the response from LLM.", + ) + parser.add_argument( + "-o", + "--output-path", + type=str, + required=True, + help="A directory to save the parsed output.", + ) + args = parser.parse_args() + + return args + + +def _parse_code_block_by_marker( + lines: list[str], start_marker: str, end_marker: str +) -> list[str]: + """Parses code block lines based on markers.""" + block = [] + in_block = False + contains_api = False + + for line in lines: + if not in_block and start_marker in line.lower(): + in_block = True # Start a code block. + if not contains_api: + block = [] # Ignore previous block because it does not contain API. + elif in_block and end_marker in line: + in_block = False # Finish a code block. + if contains_api: + break # Found fuzz target. + elif in_block: + block.append(line) + contains_api = contains_api or "LLVMFuzzerTestOneInput" in line + return block if block else lines def parse_code(response_path: str) -> str: - """Parses the expected output from the |response_path|.""" - with open(response_path) as file: - response = file.read() - return filter_code(response) + """Parses the expected output from the |response_path|.""" + with open(response_path) as file: + response = file.read() + return filter_code(response) def filter_code(response: str) -> str: - # TODO(dongge): Merge this into prompt_builder.post_process_generated_code(). 
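
The block extraction in output_parser scans for a start marker, collects lines until the end marker, and only commits to a block if it contains the fuzz entry point (LLVMFuzzerTestOneInput), falling back to the whole response otherwise. A condensed sketch using the ```c fence as the marker pair:

    def extract_code_block(lines, start_marker="```c", end_marker="```",
                           api="LLVMFuzzerTestOneInput"):
        """Returns the marked block containing |api|, or all lines as a fallback."""
        block = []
        in_block = False
        contains_api = False
        for line in lines:
            if not in_block and start_marker in line.lower():
                in_block = True      # Start of a candidate block.
                if not contains_api:
                    block = []       # Discard an earlier block without the API.
            elif in_block and end_marker in line:
                in_block = False     # End of the candidate block.
                if contains_api:
                    break            # Found a block with the fuzz target.
            elif in_block:
                block.append(line)
                contains_api = contains_api or api in line
        return block if block else lines


    response = "Here is a fuzz target:\n```c\nint LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { return 0; }\n```\nGood luck!"
    print("\n".join(extract_code_block(response.splitlines())))
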
- solution = response.split('')[0] - lines = solution.splitlines() - lines = _parse_code_block_by_marker(lines, '```c', '```') - lines = _parse_code_block_by_marker(lines, '```java', '```') - lines = _parse_code_block_by_marker(lines, '```python', '```') - lines = _parse_code_block_by_marker(lines, '```rust', '```') - lines = _parse_code_block_by_marker(lines, '```java_code', '```') - lines = _parse_code_block_by_marker(lines, '', '') - lines = _parse_code_block_by_marker(lines, '', '') - - # Remove leading and trailing empty lines. - while lines and not lines[0].strip(): - lines.pop(0) - while lines and not lines[-1].strip(): - lines.pop() - - return '\n'.join(lines) + # TODO(dongge): Merge this into prompt_builder.post_process_generated_code(). + solution = response.split("")[0] + lines = solution.splitlines() + lines = _parse_code_block_by_marker(lines, "```c", "```") + lines = _parse_code_block_by_marker(lines, "```java", "```") + lines = _parse_code_block_by_marker(lines, "```python", "```") + lines = _parse_code_block_by_marker(lines, "```rust", "```") + lines = _parse_code_block_by_marker(lines, "```java_code", "```") + lines = _parse_code_block_by_marker(lines, "", "") + lines = _parse_code_block_by_marker(lines, "", "") + + # Remove leading and trailing empty lines. + while lines and not lines[0].strip(): + lines.pop(0) + while lines and not lines[-1].strip(): + lines.pop() + + return "\n".join(lines) def parse_triage(triage_path: str) -> tuple[str, str]: - """Parses the triage from the |triage_path|.""" - with open(triage_path) as file: - triage = file.read() - solution = triage.split('')[0] - lines = solution.splitlines() - for line in lines: - if "Crash is caused by bug in fuzz driver" in line: - return (TriageResult.DRIVER, '\n'.join(lines)) - if "Crash is caused by bug in project" in line: - return (TriageResult.PROJECT, '\n'.join(lines)) + """Parses the triage from the |triage_path|.""" + with open(triage_path) as file: + triage = file.read() + solution = triage.split("")[0] + lines = solution.splitlines() + for line in lines: + if "Crash is caused by bug in fuzz driver" in line: + return (TriageResult.DRIVER, "\n".join(lines)) + if "Crash is caused by bug in project" in line: + return (TriageResult.PROJECT, "\n".join(lines)) - return (TriageResult.NOT_APPLICABLE, '\n'.join(lines)) + return (TriageResult.NOT_APPLICABLE, "\n".join(lines)) def save_output(content: str, output_path: str) -> None: - """Saves the parsed |content| to |output_path|.""" - with open(output_path, 'w+') as output_file: - output_file.write(content) + """Saves the parsed |content| to |output_path|.""" + with open(output_path, "w+") as output_file: + output_file.write(content) def main(): - args = parse_args() - content = parse_code(args.llm_response_path) - save_output(content, args.output_path) + args = parse_args() + content = parse_code(args.llm_response_path) + save_output(content, args.output_path) if __name__ == "__main__": - sys.exit(main()) + sys.exit(main()) diff --git a/llm_toolkit/prompt_builder.py b/llm_toolkit/prompt_builder.py index f18c633307..a6f3352e38 100644 --- a/llm_toolkit/prompt_builder.py +++ b/llm_toolkit/prompt_builder.py @@ -32,2032 +32,2286 @@ logger = logging.getLogger(__name__) -DEFAULT_TEMPLATE_DIR: str = os.path.join(os.path.dirname(__file__), - '../prompts/template_xml/') -AGENT_TEMPLATE_DIR: str = os.path.join(os.path.dirname(__file__), - '../prompts/agent/') +DEFAULT_TEMPLATE_DIR: str = os.path.join( + os.path.dirname(__file__), "../prompts/template_xml/" +) 
+AGENT_TEMPLATE_DIR: str = os.path.join(os.path.dirname(__file__), "../prompts/agent/")
 
 # TODO(Dongge): Refactor this to avoid hard-coding.
 # Example files.
-EXAMPLE_PATH = os.path.join(os.path.dirname(__file__), '..', 'prompts',
-                            'example')
+EXAMPLE_PATH = os.path.join(os.path.dirname(__file__), "..", "prompts", "example")
 # Example with FuzzedDataProvider.
-FDP_EXAMPLE_1_PROBLEM = os.path.join(EXAMPLE_PATH, 'gdImageString-problem.txt')
-FDP_EXAMPLE_1_SOLUTION = os.path.join(EXAMPLE_PATH, 'gdImageString-solution.cc')
-FDP_EXAMPLE_2_PROBLEM = os.path.join(EXAMPLE_PATH, 'mpg123_decode-problem.txt')
-FDP_EXAMPLE_2_SOLUTION = os.path.join(EXAMPLE_PATH, 'mpg123_decode-solution.cc')
-C_EXAMPLE_1_PROBLEM = os.path.join(EXAMPLE_PATH, 'fuzzerPolygonToCells.txt')
-C_EXAMPLE_1_SOLUTION = os.path.join(EXAMPLE_PATH, 'fuzzerPolygonToCells.c')
-C_EXAMPLE_2_PROBLEM = os.path.join(EXAMPLE_PATH, 'dns_message_parse.txt')
-C_EXAMPLE_2_SOLUTION = os.path.join(EXAMPLE_PATH, 'dns_message_parse.c')
-FDP_JVM_EXAMPLE_1_PROBLEM = os.path.join(EXAMPLE_PATH, 'joni_regex-problem.txt')
-FDP_JVM_EXAMPLE_1_SOLUTION = os.path.join(EXAMPLE_PATH,
-                                          'joni_regex-solution.java')
-FDP_JVM_EXAMPLE_2_PROBLEM = os.path.join(EXAMPLE_PATH,
-                                         'jansi_colors-problem.txt')
-FDP_JVM_EXAMPLE_2_SOLUTION = os.path.join(EXAMPLE_PATH,
-                                          'jansi_colors-solution.java')
+FDP_EXAMPLE_1_PROBLEM = os.path.join(EXAMPLE_PATH, "gdImageString-problem.txt")
+FDP_EXAMPLE_1_SOLUTION = os.path.join(EXAMPLE_PATH, "gdImageString-solution.cc")
+FDP_EXAMPLE_2_PROBLEM = os.path.join(EXAMPLE_PATH, "mpg123_decode-problem.txt")
+FDP_EXAMPLE_2_SOLUTION = os.path.join(EXAMPLE_PATH, "mpg123_decode-solution.cc")
+C_EXAMPLE_1_PROBLEM = os.path.join(EXAMPLE_PATH, "fuzzerPolygonToCells.txt")
+C_EXAMPLE_1_SOLUTION = os.path.join(EXAMPLE_PATH, "fuzzerPolygonToCells.c")
+C_EXAMPLE_2_PROBLEM = os.path.join(EXAMPLE_PATH, "dns_message_parse.txt")
+C_EXAMPLE_2_SOLUTION = os.path.join(EXAMPLE_PATH, "dns_message_parse.c")
+FDP_JVM_EXAMPLE_1_PROBLEM = os.path.join(EXAMPLE_PATH, "joni_regex-problem.txt")
+FDP_JVM_EXAMPLE_1_SOLUTION = os.path.join(EXAMPLE_PATH, "joni_regex-solution.java")
+FDP_JVM_EXAMPLE_2_PROBLEM = os.path.join(EXAMPLE_PATH, "jansi_colors-problem.txt")
+FDP_JVM_EXAMPLE_2_SOLUTION = os.path.join(EXAMPLE_PATH, "jansi_colors-solution.java")
 
 EXAMPLES = {
-    'c++': [
+    "c++": [
         [FDP_EXAMPLE_1_PROBLEM, FDP_EXAMPLE_1_SOLUTION],
         [FDP_EXAMPLE_2_PROBLEM, FDP_EXAMPLE_2_SOLUTION],
     ],
-    'c': [
+    "c": [
         [C_EXAMPLE_1_PROBLEM, C_EXAMPLE_1_SOLUTION],
         [C_EXAMPLE_2_PROBLEM, C_EXAMPLE_2_SOLUTION],
     ],
-    'jvm': [
+    "jvm": [
         [FDP_JVM_EXAMPLE_1_PROBLEM, FDP_JVM_EXAMPLE_1_SOLUTION],
         [FDP_JVM_EXAMPLE_2_PROBLEM, FDP_JVM_EXAMPLE_2_SOLUTION],
     ],
 }
 
-BUILD_ERROR_SUMMARY = 'The code has the following build issues:'
-FUZZ_ERROR_SUMMARY = 'The code can build successfully but has a runtime issue: '
+BUILD_ERROR_SUMMARY = "The code has the following build issues:"
+FUZZ_ERROR_SUMMARY = "The code can build successfully but has a runtime issue: "
 
-C_PROMPT_HEADERS_TO_ALWAYS_INCLUDES = ['stdio.h', 'stdlib.h', 'stdint.h']
+C_PROMPT_HEADERS_TO_ALWAYS_INCLUDES = ["stdio.h", "stdlib.h", "stdint.h"]
 
 
 class PromptBuilder:
-  """Prompt builder."""
-
-  def __init__(self, model: models.LLM, initial=None):
-    self._model = model
-    self._prompt = model.prompt_type()(initial)
-
-  @abstractmethod
-  def build(self,
-            example_pair: list[list[str]],
-            project_example_content: Optional[list[list[str]]] = None,
-            project_context_content: Optional[dict] = None) -> prompts.Prompt:
-    """Builds a prompt."""
-
-  @abstractmethod
- def build_fixer_prompt(self, benchmark: Benchmark, raw_code: str, - error_desc: Optional[str], - errors: list[str]) -> prompts.Prompt: - """Builds a fixer prompt.""" - - @abstractmethod - def build_triager_prompt(self, benchmark: Benchmark, driver_code: str, - crash_info: str, crash_func: dict) -> prompts.Prompt: - """Builds a triager prompt.""" - - def post_process_generated_code(self, generated_code: str) -> str: - """Allows prompt builder to adjust the generated code.""" - # return the same by default - return generated_code + """Prompt builder.""" + + def __init__(self, model: models.LLM, initial=None): + self._model = model + self._prompt = model.prompt_type()(initial) + + @abstractmethod + def build( + self, + example_pair: list[list[str]], + project_example_content: Optional[list[list[str]]] = None, + project_context_content: Optional[dict] = None, + ) -> prompts.Prompt: + """Builds a prompt.""" + + @abstractmethod + def build_fixer_prompt( + self, + benchmark: Benchmark, + raw_code: str, + error_desc: Optional[str], + errors: list[str], + ) -> prompts.Prompt: + """Builds a fixer prompt.""" + + @abstractmethod + def build_triager_prompt( + self, benchmark: Benchmark, driver_code: str, crash_info: str, crash_func: dict + ) -> prompts.Prompt: + """Builds a triager prompt.""" + + def post_process_generated_code(self, generated_code: str) -> str: + """Allows prompt builder to adjust the generated code.""" + # return the same by default + return generated_code class DefaultTemplateBuilder(PromptBuilder): - """Default builder for C/C++.""" - - def __init__(self, - model: models.LLM, - benchmark: Optional[Benchmark] = None, - template_dir: str = DEFAULT_TEMPLATE_DIR, - initial: Any = None): - super().__init__(model, initial) - self._template_dir = template_dir - self.benchmark = benchmark - - # Load templates. - self.priming_template_file = self._find_template(template_dir, - 'priming.txt') - self.cpp_priming_filler_file = self._find_template( - template_dir, 'cpp-specific-priming-filler.txt') - self.problem_template_file = self._find_template(template_dir, - 'problem.txt') - self.solution_template_file = self._find_template(template_dir, - 'solution.txt') - self.context_template_file = self._find_template(template_dir, - 'context.txt') - self.fixer_priming_template_file = self._find_template( - template_dir, 'fixer_priming.txt') - self.fixer_problem_template_file = self._find_template( - template_dir, 'fixer_problem.txt') - self.fixer_context_template_file = self._find_template( - template_dir, 'fixer_context.txt') - self.fixer_instruction_template_file = self._find_template( - template_dir, 'fixer_instruction.txt') - self.triager_priming_template_file = self._find_template( - template_dir, 'triager_priming.txt') - self.triager_problem_template_file = self._find_template( - template_dir, 'triager_problem.txt') - - def _format_priming(self, benchmark: Benchmark) -> str: - """Formats a priming based on the prompt template.""" - priming = self._get_template(self.priming_template_file) - priming = priming.replace('{LANGUAGE}', benchmark.file_type.value) - priming = priming.replace('{FUZZ_TARGET_PATH}', benchmark.target_path) - # TODO(Dongge): Add project name and fuzz target file path. - if benchmark.needs_extern: - priming += ( - 'IMPORTANT: The fuzz target is written in C++, whereas the ' - 'project-under-test is written in C. 
All headers, functions, and code' - 'from the project must be consistently wrapped in ' - 'extern "C" to ensure error-free compilation and linkage' - 'between C and C++:\n\nextern "C" {\n //Include necessary C ' - 'headers, source files, functions, and code here.\n}\n\n') - if benchmark.file_type == FileType.CPP: - type_specific_priming = self._get_template(self.cpp_priming_filler_file) - else: - type_specific_priming = '' - priming = priming.replace('{TYPE_SPECIFIC_PRIMING}', type_specific_priming) - return priming - - def _find_template(self, template_dir: str, template_name: str) -> str: - """Finds template file based on |template_dir|.""" - preferred_template = os.path.join(template_dir, template_name) - # Use the preferred template if it exists. - if os.path.isfile(preferred_template): - return preferred_template - # Fall back to the default template. - default_template = os.path.join(DEFAULT_TEMPLATE_DIR, template_name) - return default_template - - def _get_template(self, template_file: str) -> str: - """Reads the template for prompts.""" - with open(template_file) as file: - return file.read() - - def format_problem(self, problem_content: str) -> str: - """Formats a problem based on the prompt template.""" - problem = self._get_template(self.problem_template_file) - problem = problem.replace('{PROBLEM_CONTENT}', problem_content) - return problem - - def format_solution(self, solution_content: str) -> str: - """Formats a solution based on the prompt template.""" - solution = self._get_template(self.solution_template_file) - solution = solution.replace('{SOLUTION_CONTENT}', solution_content) - return solution - - def format_context(self, context_info: dict) -> str: - context = jinja2.Template(self._get_template(self.context_template_file), - trim_blocks=True, - lstrip_blocks=True) - return context.render( - headers='\n'.join(context_info['files']), - must_insert=context_info['decl'], - func_source=context_info['func_source'], - xrefs='\n'.join(context_info['xrefs']), - include_statement=context_info['header'], - ) - - def _select_examples(self, examples: list[list], - prompt_size: int) -> list[list[str]]: - """Selects |examples| based on |prompt_size|.""" - # First remove repeated examples to avoid over fitting. - targets = set() - unique_examples = [] - for example in examples: - if example[2] in targets: - continue - targets.add(example[2]) - unique_examples.append(example) - - if (sum(example[0] for example in unique_examples) + prompt_size - < self._model.context_window): - return [[example[1], example[2]] for example in examples] - - # Then prioritize complex (i.e., long) examples. - unique_examples.sort(key=lambda x: x[0], reverse=True) - selected_examples = [] - for example in unique_examples: - if example[0] + prompt_size >= self._model.context_window: - # The estimation is inaccurate, if an example's size equals to - # the limit, it's safer to not include the example. - continue - selected_examples.append([example[1], example[2]]) - prompt_size += example[0] - - # Write the most complex examples at the end so that LLM gives them - # a higher weight. - selected_examples.sort(key=len, reverse=True) - return selected_examples - - def _add_examples(self, - example_files: list[list[str]], - final_problem: str, - example_content: Optional[list[list[str]]] = None): - """Constructs the |example_files| to be used in the prompt.""" - # Estimate prompt size so far. - prompt_size = self._model.estimate_token_num(self._prompt.get()) - # Estimate space needed for the final problem. 
- final_problem_prompt = self._prompt.create_prompt_piece( - final_problem, 'user') - query_size = prompt_size + self._model.estimate_token_num( - final_problem_prompt) - - # Collect all examples in a single list - examples = [] - for problem, solution in example_files: - with open(problem) as problem_file: - problem = problem_file.read()[:-1] - with open(solution) as solution_file: - solution = solution_file.read()[:-1] - solution = project_targets.filter_target_lines(solution) - examples.append((problem, solution)) - # TODO(mihaimaruseac): Should we start from these first? - if example_content: - for problem, solution in example_content: - solution = project_targets.filter_target_lines(solution) - examples.append((problem, solution)) - - # Next, we need to expand all templates and determine how much the size - # of the prompt would increase when adding each one of them: - weights = [] - for problem, solution in examples: - problem = self.format_problem(problem) - solution = self.format_solution(solution) - problem_prompt = self._prompt.create_prompt_piece(problem, 'user') - solution_prompt = self._prompt.create_prompt_piece(solution, 'assistant') - problem_weight = self._model.estimate_token_num(problem_prompt) - solution_weight = self._model.estimate_token_num(solution_prompt) - total_weight = problem_weight + solution_weight + 1 # one \n - weights.append((total_weight, problem, solution)) - - # Select examples up to context window and add them to prompt. - selected_examples = self._select_examples(weights, query_size) - for problem, solution in selected_examples: - self._prompt.add_problem(problem) - self._prompt.add_solution(solution) - - def build(self, - example_pair: list[list[str]], - project_example_content: Optional[list[list[str]]] = None, - project_context_content: Optional[dict] = None) -> prompts.Prompt: - """Constructs a prompt using the templates in |self| and saves it.""" - if not self.benchmark: - return self._prompt - priming = self._format_priming(self.benchmark) - final_problem = self.format_problem(self.benchmark.function_signature) - final_problem += (f'You MUST call \n' - f'{self.benchmark.function_signature}\n' - f' in your solution!\n') - if project_context_content: - final_problem += self.format_context(project_context_content) - final_problem += '\n' - self._prepare_prompt(priming, final_problem, example_pair, - project_example_content) - return self._prompt - - def build_fixer_prompt(self, - benchmark: Benchmark, - raw_code: str, - error_desc: Optional[str], - errors: list[str], - coverage_result: Optional[CoverageResult] = None, - context: str = '', - instruction: str = '') -> prompts.Prompt: - """Prepares the code-fixing prompt.""" - priming, priming_weight = self._format_fixer_priming(benchmark) - - if error_desc and errors: - pass - elif coverage_result: - error_desc = coverage_result.insight - errors = coverage_result.suggestions.splitlines() - else: - error_desc = '' - errors = [] - problem = self._format_fixer_problem(raw_code, error_desc, errors, - priming_weight, context, instruction) - - self._prepare_prompt(priming, problem) - return self._prompt - - def _format_fixer_priming(self, benchmark: Benchmark) -> Tuple[str, int]: - """Formats a priming for code fixer based on the template.""" - with open(self.fixer_priming_template_file) as f: - priming = f.read().strip() + '\n' - priming = priming.replace('{LANGUAGE}', benchmark.file_type.value) - if benchmark.needs_extern: - priming += ('\nNote that some code may need to be wrapped with ' - 'extern "C" 
because the project under test is ' - 'written in C but the fuzz target is in C++.\n') - priming_prompt = self._prompt.create_prompt_piece(priming, 'system') - priming_weight = self._model.estimate_token_num(priming_prompt) - # NOTE: We need to return the priming _as text_ and the weight. Otherwise, - # in the case of structured prompts, we will create nested structures. - return priming, priming_weight - - def _format_fixer_problem(self, raw_code: str, error_desc: Optional[str], - errors: list[str], priming_weight: int, - context: str, instruction: str) -> str: - """Formats a problem for code fixer based on the template.""" - with open(self.fixer_problem_template_file) as f: - problem = f.read().strip() - problem = problem.replace('{CODE_TO_BE_FIXED}', raw_code) - if error_desc: - error_summary = FUZZ_ERROR_SUMMARY + error_desc - else: - # Build error does not pass error desc. - error_summary = BUILD_ERROR_SUMMARY - problem = problem.replace('{ERROR_SUMMARY}', error_summary) - - if context: - with open(self.fixer_context_template_file) as f: - context_template = f.read().strip() - context = context_template.replace('{CONTEXT_SOURCE_CODE}', context) - problem = problem.replace('{CONTEXT}', context) - - if instruction: - with open(self.fixer_instruction_template_file) as f: - instruction_template = f.read().strip() - instruction = instruction_template.replace('{INSTRUCTION}', instruction) - problem = problem.replace('{INSTRUCTION}', instruction) - - problem_prompt = self._prompt.create_prompt_piece(problem, 'user') - template_piece = self._prompt.create_prompt_piece('{ERROR_MESSAGES}', - 'user') - - problem_weight = self._model.estimate_token_num(problem_prompt) - template_weight = self._model.estimate_token_num(template_piece) - - # the template will be replaced later and should not be counted - prompt_size = priming_weight + problem_weight - template_weight - # Add extra 20-tokens redundancy - # TODO(mihaimaruseac): Is this needed? - prompt_size += 20 - - # We are adding errors one by one until we reach the maximum prompt size - selected_errors = [] - for error in errors: - error_prompt = self._prompt.create_prompt_piece(error, 'user') - error_token_num = self._model.estimate_token_num(error_prompt) - if prompt_size + error_token_num >= self._model.context_window: - # The estimation is inaccurate, if an example's size equals to - # the limit, it's safer to not include the example. - break - prompt_size += error_token_num - selected_errors.append(error) - - # Now, compose the problem part of the prompt - error_message = '\n'.join(selected_errors) - if error_message.strip(): - return problem.replace('{ERROR_MESSAGES}', error_message) - - # Expecting empty error message for NO_COV_INCREASE. - if SemanticCheckResult.is_no_cov_increase_err(error_desc): - return problem.replace('\n', '')\ - .replace('{ERROR_MESSAGES}\n', '')\ - .replace('\n', '') - - # Log warning for an unexpected empty error message. 
- logger.warning( - 'Unexpected empty error message in fix prompt for error_desc: %s', - str(error_desc)) - return problem.replace('{ERROR_MESSAGES}', error_message) - - def build_triager_prompt(self, benchmark: Benchmark, driver_code: str, - crash_info: str, crash_func: dict) -> prompts.Prompt: - """Prepares the crash-triaging prompt.""" - priming, priming_weight = self._format_triager_priming() - problem = self._format_triager_problem(benchmark, driver_code, crash_info, - crash_func, priming_weight) - - self._prepare_prompt(priming, problem) - return self._prompt - - def _format_triager_priming(self) -> Tuple[str, int]: - """Formats a priming for crash triage based on the template.""" - with open(self.triager_priming_template_file) as f: - priming = f.read().strip() + '\n' - priming_prompt = self._prompt.create_prompt_piece(priming, 'system') - priming_weight = self._model.estimate_token_num(priming_prompt) - # NOTE: We need to return the priming _as text_ and the weight. Otherwise, - # in the case of structured prompts, we will create nested structures. - return priming, priming_weight - - def _format_triager_problem(self, benchmark: Benchmark, driver_code: str, - crash_info: str, crash_func: dict, - priming_weight: int) -> str: - """Formats a problem for crash triage based on the template.""" - all_func_code = [] - for func_name, line_number in crash_func.items(): - if func_name == 'LLVMFuzzerTestOneInput': - driver_code = self._slice_driver_code(benchmark.project, driver_code, - line_number) - else: - func_code = self._slice_func_code(benchmark.project, func_name, - line_number) - all_func_code.append(func_code) - - with open(self.triager_problem_template_file) as f: - problem = f.read().strip() - problem = problem.replace('{CRASH_REPORT}', crash_info.strip())\ - .replace('{DRIVER_CODE}', driver_code.strip()) - - problem_prompt = self._prompt.create_prompt_piece(problem, 'user') - template_piece = self._prompt.create_prompt_piece('{PROJECT_FUNCTION_CODE}', - 'user') - - problem_weight = self._model.estimate_token_num(problem_prompt) - template_weight = self._model.estimate_token_num(template_piece) - - prompt_size = priming_weight + problem_weight - template_weight - # Add extra 20-tokens redundancy - prompt_size += 20 - - # Add function code one by one until we reach the maximum prompt size - selected_func_code = [] - for func_code in all_func_code: - func_code_prompt = self._prompt.create_prompt_piece(func_code, 'user') - func_code_token_num = self._model.estimate_token_num(func_code_prompt) - if prompt_size + func_code_token_num >= self._model.context_window: - # The estimation is inaccurate, if an example's size equals to - # the limit, it's safer to not include the example. 
- logger.warning('Breaking because adding this function code \ - would exceed context window') - break - prompt_size += func_code_token_num - selected_func_code.append(func_code) - - # Compose the problem part of the prompt - project_function_code = '\n'.join(selected_func_code) - if project_function_code.strip(): - return problem.replace('{PROJECT_FUNCTION_CODE}', - project_function_code.strip()) - - logger.warning( - 'Empty project function code in triage prompt for project: %s, \ - function name: %s', benchmark.project, benchmark.function_name) - - return problem.replace('{PROJECT_FUNCTION_CODE}', \ - 'No relevant project function code') - - def _prepare_prompt( - self, - priming: str, - final_problem: str, - example_pair: Optional[list[list[str]]] = None, - project_example_content: Optional[list[list[str]]] = None): - """Constructs a prompt using the parameters and saves it.""" - self._prompt.add_priming(priming) - - if example_pair is None: - example_pair = [] - - self._add_examples(example_pair, final_problem, project_example_content) - self._prompt.add_problem(final_problem) - - def _slice_driver_code(self, project: str, driver_code: str, - target_lines: set) -> str: - """Slice the driver code up to the target line.""" - target_line = max(target_lines) - lines = driver_code.split('\n') - - if target_line > len(lines): - logger.warning( - 'Driver target line exceed maxium limit in Project: %s, \ - try to use whole driver code in trigae prompt', project) - return driver_code - - code_snippet = '\n'.join(lines[:target_line]) - result = f'\nLine 1 - {target_line}:\n{code_snippet}' - return result - - def _slice_func_code(self, project: str, func_name: str, - target_lines: set) -> str: - """Slice target line and four preceding lines from function code.""" - func_sig = introspector.query_introspector_function_signature( - project, func_name) - func_code = introspector.query_introspector_function_source( - project, func_sig) - begin_line, end_line = introspector.query_introspector_function_line( - project, func_sig) - - if begin_line != 0 and end_line != 0 and all( - begin_line <= line <= end_line for line in target_lines): - lines = func_code.split('\n') - output_lines = set() - result = [] - for line in sorted(target_lines): - start = max(line - 4, begin_line) - end = line - if not any(l in output_lines for l in range(start, end + 1)): - code_snippet = '\n'.join(lines[(start - - begin_line):(end - begin_line) + 1]) - result.append(f'\nFunction Name:\n{func_name}\n\ - Line {start} - {end}:\n{code_snippet}') - output_lines.update(range(start, end + 1)) - return '\n'.join(result) - - logger.warning('Failed to slice Project: %s Function: %s at Lines: %s', - project, func_name, target_lines) - return '' + """Default builder for C/C++.""" + + def __init__( + self, + model: models.LLM, + benchmark: Optional[Benchmark] = None, + template_dir: str = DEFAULT_TEMPLATE_DIR, + initial: Any = None, + ): + super().__init__(model, initial) + self._template_dir = template_dir + self.benchmark = benchmark + + # Load templates. 
+        self.priming_template_file = self._find_template(template_dir, "priming.txt")
+        self.cpp_priming_filler_file = self._find_template(
+            template_dir, "cpp-specific-priming-filler.txt"
+        )
+        self.problem_template_file = self._find_template(template_dir, "problem.txt")
+        self.solution_template_file = self._find_template(template_dir, "solution.txt")
+        self.context_template_file = self._find_template(template_dir, "context.txt")
+        self.fixer_priming_template_file = self._find_template(
+            template_dir, "fixer_priming.txt"
+        )
+        self.fixer_problem_template_file = self._find_template(
+            template_dir, "fixer_problem.txt"
+        )
+        self.fixer_context_template_file = self._find_template(
+            template_dir, "fixer_context.txt"
+        )
+        self.fixer_instruction_template_file = self._find_template(
+            template_dir, "fixer_instruction.txt"
+        )
+        self.triager_priming_template_file = self._find_template(
+            template_dir, "triager_priming.txt"
+        )
+        self.triager_problem_template_file = self._find_template(
+            template_dir, "triager_problem.txt"
+        )
+
+    def _format_priming(self, benchmark: Benchmark) -> str:
+        """Formats a priming based on the prompt template."""
+        priming = self._get_template(self.priming_template_file)
+        priming = priming.replace("{LANGUAGE}", benchmark.file_type.value)
+        priming = priming.replace("{FUZZ_TARGET_PATH}", benchmark.target_path)
+        # TODO(Dongge): Add project name and fuzz target file path.
+        if benchmark.needs_extern:
+            priming += (
+                "IMPORTANT: The fuzz target is written in C++, whereas the "
+                "project-under-test is written in C. All headers, functions, and code "
+                "from the project must be consistently wrapped in "
+                'extern "C" to ensure error-free compilation and linkage '
+                'between C and C++:\n\nextern "C" {\n //Include necessary C '
+                "headers, source files, functions, and code here.\n}\n\n"
+            )
+        if benchmark.file_type == FileType.CPP:
+            type_specific_priming = self._get_template(self.cpp_priming_filler_file)
+        else:
+            type_specific_priming = ""
+        priming = priming.replace("{TYPE_SPECIFIC_PRIMING}", type_specific_priming)
+        return priming
+
+    def _find_template(self, template_dir: str, template_name: str) -> str:
+        """Finds template file based on |template_dir|."""
+        preferred_template = os.path.join(template_dir, template_name)
+        # Use the preferred template if it exists.
+        if os.path.isfile(preferred_template):
+            return preferred_template
+        # Fall back to the default template.
+ default_template = os.path.join(DEFAULT_TEMPLATE_DIR, template_name) + return default_template + + def _get_template(self, template_file: str) -> str: + """Reads the template for prompts.""" + with open(template_file) as file: + return file.read() + + def format_problem(self, problem_content: str) -> str: + """Formats a problem based on the prompt template.""" + problem = self._get_template(self.problem_template_file) + problem = problem.replace("{PROBLEM_CONTENT}", problem_content) + return problem + + def format_solution(self, solution_content: str) -> str: + """Formats a solution based on the prompt template.""" + solution = self._get_template(self.solution_template_file) + solution = solution.replace("{SOLUTION_CONTENT}", solution_content) + return solution + + def format_context(self, context_info: dict) -> str: + context = jinja2.Template( + self._get_template(self.context_template_file), + trim_blocks=True, + lstrip_blocks=True, + ) + return context.render( + headers="\n".join(context_info["files"]), + must_insert=context_info["decl"], + func_source=context_info["func_source"], + xrefs="\n".join(context_info["xrefs"]), + include_statement=context_info["header"], + ) + + def _select_examples( + self, examples: list[list], prompt_size: int + ) -> list[list[str]]: + """Selects |examples| based on |prompt_size|.""" + # First remove repeated examples to avoid over fitting. + targets = set() + unique_examples = [] + for example in examples: + if example[2] in targets: + continue + targets.add(example[2]) + unique_examples.append(example) + + if ( + sum(example[0] for example in unique_examples) + prompt_size + < self._model.context_window + ): + return [[example[1], example[2]] for example in examples] + + # Then prioritize complex (i.e., long) examples. + unique_examples.sort(key=lambda x: x[0], reverse=True) + selected_examples = [] + for example in unique_examples: + if example[0] + prompt_size >= self._model.context_window: + # The estimation is inaccurate, if an example's size equals to + # the limit, it's safer to not include the example. + continue + selected_examples.append([example[1], example[2]]) + prompt_size += example[0] + + # Write the most complex examples at the end so that LLM gives them + # a higher weight. + selected_examples.sort(key=len, reverse=True) + return selected_examples + + def _add_examples( + self, + example_files: list[list[str]], + final_problem: str, + example_content: Optional[list[list[str]]] = None, + ): + """Constructs the |example_files| to be used in the prompt.""" + # Estimate prompt size so far. + prompt_size = self._model.estimate_token_num(self._prompt.get()) + # Estimate space needed for the final problem. + final_problem_prompt = self._prompt.create_prompt_piece(final_problem, "user") + query_size = prompt_size + self._model.estimate_token_num(final_problem_prompt) + + # Collect all examples in a single list + examples = [] + for problem, solution in example_files: + with open(problem) as problem_file: + problem = problem_file.read()[:-1] + with open(solution) as solution_file: + solution = solution_file.read()[:-1] + solution = project_targets.filter_target_lines(solution) + examples.append((problem, solution)) + # TODO(mihaimaruseac): Should we start from these first? 
+ if example_content: + for problem, solution in example_content: + solution = project_targets.filter_target_lines(solution) + examples.append((problem, solution)) + + # Next, we need to expand all templates and determine how much the size + # of the prompt would increase when adding each one of them: + weights = [] + for problem, solution in examples: + problem = self.format_problem(problem) + solution = self.format_solution(solution) + problem_prompt = self._prompt.create_prompt_piece(problem, "user") + solution_prompt = self._prompt.create_prompt_piece(solution, "assistant") + problem_weight = self._model.estimate_token_num(problem_prompt) + solution_weight = self._model.estimate_token_num(solution_prompt) + total_weight = problem_weight + solution_weight + 1 # one \n + weights.append((total_weight, problem, solution)) + + # Select examples up to context window and add them to prompt. + selected_examples = self._select_examples(weights, query_size) + for problem, solution in selected_examples: + self._prompt.add_problem(problem) + self._prompt.add_solution(solution) + + def build( + self, + example_pair: list[list[str]], + project_example_content: Optional[list[list[str]]] = None, + project_context_content: Optional[dict] = None, + ) -> prompts.Prompt: + """Constructs a prompt using the templates in |self| and saves it.""" + if not self.benchmark: + return self._prompt + priming = self._format_priming(self.benchmark) + final_problem = self.format_problem(self.benchmark.function_signature) + final_problem += ( + f"You MUST call \n" + f"{self.benchmark.function_signature}\n" + f" in your solution!\n" + ) + if project_context_content: + final_problem += self.format_context(project_context_content) + final_problem += "\n" + self._prepare_prompt( + priming, final_problem, example_pair, project_example_content + ) + return self._prompt + + def build_fixer_prompt( + self, + benchmark: Benchmark, + raw_code: str, + error_desc: Optional[str], + errors: list[str], + coverage_result: Optional[CoverageResult] = None, + context: str = "", + instruction: str = "", + ) -> prompts.Prompt: + """Prepares the code-fixing prompt.""" + priming, priming_weight = self._format_fixer_priming(benchmark) + + if error_desc and errors: + pass + elif coverage_result: + error_desc = coverage_result.insight + errors = coverage_result.suggestions.splitlines() + else: + error_desc = "" + errors = [] + problem = self._format_fixer_problem( + raw_code, error_desc, errors, priming_weight, context, instruction + ) + + self._prepare_prompt(priming, problem) + return self._prompt + + def _format_fixer_priming(self, benchmark: Benchmark) -> Tuple[str, int]: + """Formats a priming for code fixer based on the template.""" + with open(self.fixer_priming_template_file) as f: + priming = f.read().strip() + "\n" + priming = priming.replace("{LANGUAGE}", benchmark.file_type.value) + if benchmark.needs_extern: + priming += ( + "\nNote that some code may need to be wrapped with " + 'extern "C" because the project under test is ' + "written in C but the fuzz target is in C++.\n" + ) + priming_prompt = self._prompt.create_prompt_piece(priming, "system") + priming_weight = self._model.estimate_token_num(priming_prompt) + # NOTE: We need to return the priming _as text_ and the weight. Otherwise, + # in the case of structured prompts, we will create nested structures. 
+ return priming, priming_weight + + def _format_fixer_problem( + self, + raw_code: str, + error_desc: Optional[str], + errors: list[str], + priming_weight: int, + context: str, + instruction: str, + ) -> str: + """Formats a problem for code fixer based on the template.""" + with open(self.fixer_problem_template_file) as f: + problem = f.read().strip() + problem = problem.replace("{CODE_TO_BE_FIXED}", raw_code) + if error_desc: + error_summary = FUZZ_ERROR_SUMMARY + error_desc + else: + # Build error does not pass error desc. + error_summary = BUILD_ERROR_SUMMARY + problem = problem.replace("{ERROR_SUMMARY}", error_summary) + + if context: + with open(self.fixer_context_template_file) as f: + context_template = f.read().strip() + context = context_template.replace("{CONTEXT_SOURCE_CODE}", context) + problem = problem.replace("{CONTEXT}", context) + + if instruction: + with open(self.fixer_instruction_template_file) as f: + instruction_template = f.read().strip() + instruction = instruction_template.replace("{INSTRUCTION}", instruction) + problem = problem.replace("{INSTRUCTION}", instruction) + + problem_prompt = self._prompt.create_prompt_piece(problem, "user") + template_piece = self._prompt.create_prompt_piece("{ERROR_MESSAGES}", "user") + + problem_weight = self._model.estimate_token_num(problem_prompt) + template_weight = self._model.estimate_token_num(template_piece) + + # the template will be replaced later and should not be counted + prompt_size = priming_weight + problem_weight - template_weight + # Add extra 20-tokens redundancy + # TODO(mihaimaruseac): Is this needed? + prompt_size += 20 + + # We are adding errors one by one until we reach the maximum prompt size + selected_errors = [] + for error in errors: + error_prompt = self._prompt.create_prompt_piece(error, "user") + error_token_num = self._model.estimate_token_num(error_prompt) + if prompt_size + error_token_num >= self._model.context_window: + # The estimation is inaccurate, if an example's size equals to + # the limit, it's safer to not include the example. + break + prompt_size += error_token_num + selected_errors.append(error) + + # Now, compose the problem part of the prompt + error_message = "\n".join(selected_errors) + if error_message.strip(): + return problem.replace("{ERROR_MESSAGES}", error_message) + + # Expecting empty error message for NO_COV_INCREASE. + if SemanticCheckResult.is_no_cov_increase_err(error_desc): + return ( + problem.replace("\n", "") + .replace("{ERROR_MESSAGES}\n", "") + .replace("\n", "") + ) + + # Log warning for an unexpected empty error message. 
+ logger.warning( + "Unexpected empty error message in fix prompt for error_desc: %s", + str(error_desc), + ) + return problem.replace("{ERROR_MESSAGES}", error_message) + + def build_triager_prompt( + self, benchmark: Benchmark, driver_code: str, crash_info: str, crash_func: dict + ) -> prompts.Prompt: + """Prepares the crash-triaging prompt.""" + priming, priming_weight = self._format_triager_priming() + problem = self._format_triager_problem( + benchmark, driver_code, crash_info, crash_func, priming_weight + ) + + self._prepare_prompt(priming, problem) + return self._prompt + + def _format_triager_priming(self) -> Tuple[str, int]: + """Formats a priming for crash triage based on the template.""" + with open(self.triager_priming_template_file) as f: + priming = f.read().strip() + "\n" + priming_prompt = self._prompt.create_prompt_piece(priming, "system") + priming_weight = self._model.estimate_token_num(priming_prompt) + # NOTE: We need to return the priming _as text_ and the weight. Otherwise, + # in the case of structured prompts, we will create nested structures. + return priming, priming_weight + + def _format_triager_problem( + self, + benchmark: Benchmark, + driver_code: str, + crash_info: str, + crash_func: dict, + priming_weight: int, + ) -> str: + """Formats a problem for crash triage based on the template.""" + all_func_code = [] + for func_name, line_number in crash_func.items(): + if func_name == "LLVMFuzzerTestOneInput": + driver_code = self._slice_driver_code( + benchmark.project, driver_code, line_number + ) + else: + func_code = self._slice_func_code( + benchmark.project, func_name, line_number + ) + all_func_code.append(func_code) + + with open(self.triager_problem_template_file) as f: + problem = f.read().strip() + problem = problem.replace("{CRASH_REPORT}", crash_info.strip()).replace( + "{DRIVER_CODE}", driver_code.strip() + ) + + problem_prompt = self._prompt.create_prompt_piece(problem, "user") + template_piece = self._prompt.create_prompt_piece( + "{PROJECT_FUNCTION_CODE}", "user" + ) + + problem_weight = self._model.estimate_token_num(problem_prompt) + template_weight = self._model.estimate_token_num(template_piece) + + prompt_size = priming_weight + problem_weight - template_weight + # Add extra 20-tokens redundancy + prompt_size += 20 + + # Add function code one by one until we reach the maximum prompt size + selected_func_code = [] + for func_code in all_func_code: + func_code_prompt = self._prompt.create_prompt_piece(func_code, "user") + func_code_token_num = self._model.estimate_token_num(func_code_prompt) + if prompt_size + func_code_token_num >= self._model.context_window: + # The estimation is inaccurate, if an example's size equals to + # the limit, it's safer to not include the example. 
+                logger.warning(
+                    "Breaking because adding this function code \
+                    would exceed context window"
+                )
+                break
+            prompt_size += func_code_token_num
+            selected_func_code.append(func_code)
+
+        # Compose the problem part of the prompt
+        project_function_code = "\n".join(selected_func_code)
+        if project_function_code.strip():
+            return problem.replace(
+                "{PROJECT_FUNCTION_CODE}", project_function_code.strip()
+            )
+
+        logger.warning(
+            "Empty project function code in triage prompt for project: %s, \
+            function name: %s",
+            benchmark.project,
+            benchmark.function_name,
+        )
+
+        return problem.replace(
+            "{PROJECT_FUNCTION_CODE}", "No relevant project function code"
+        )
+
+    def _prepare_prompt(
+        self,
+        priming: str,
+        final_problem: str,
+        example_pair: Optional[list[list[str]]] = None,
+        project_example_content: Optional[list[list[str]]] = None,
+    ):
+        """Constructs a prompt using the parameters and saves it."""
+        self._prompt.add_priming(priming)
+
+        if example_pair is None:
+            example_pair = []
+
+        self._add_examples(example_pair, final_problem, project_example_content)
+        self._prompt.add_problem(final_problem)
+
+    def _slice_driver_code(
+        self, project: str, driver_code: str, target_lines: set
+    ) -> str:
+        """Slice the driver code up to the target line."""
+        target_line = max(target_lines)
+        lines = driver_code.split("\n")
+
+        if target_line > len(lines):
+            logger.warning(
+                "Driver target line exceeds the maximum limit in Project: %s, \
+                try to use whole driver code in triage prompt",
+                project,
+            )
+            return driver_code
+
+        code_snippet = "\n".join(lines[:target_line])
+        result = f"\nLine 1 - {target_line}:\n{code_snippet}"
+        return result
+
+    def _slice_func_code(self, project: str, func_name: str, target_lines: set) -> str:
+        """Slice target line and four preceding lines from function code."""
+        func_sig = introspector.query_introspector_function_signature(
+            project, func_name
+        )
+        func_code = introspector.query_introspector_function_source(project, func_sig)
+        begin_line, end_line = introspector.query_introspector_function_line(
+            project, func_sig
+        )
+
+        if (
+            begin_line != 0
+            and end_line != 0
+            and all(begin_line <= line <= end_line for line in target_lines)
+        ):
+            lines = func_code.split("\n")
+            output_lines = set()
+            result = []
+            for line in sorted(target_lines):
+                start = max(line - 4, begin_line)
+                end = line
+                if not any(l in output_lines for l in range(start, end + 1)):
+                    code_snippet = "\n".join(
+                        lines[(start - begin_line) : (end - begin_line) + 1]
+                    )
+                    result.append(
+                        f"\nFunction Name:\n{func_name}\n\
+                    Line {start} - {end}:\n{code_snippet}"
+                    )
+                    output_lines.update(range(start, end + 1))
+            return "\n".join(result)
+
+        logger.warning(
+            "Failed to slice Project: %s Function: %s at Lines: %s",
+            project,
+            func_name,
+            target_lines,
+        )
+        return ""
 
 
 class PrototyperTemplateBuilder(DefaultTemplateBuilder):
-  """Builder specifically targeted C (and excluding C++)."""
-
-  def __init__(self,
-               model: models.LLM,
-               benchmark: Benchmark,
-               template_dir: str = DEFAULT_TEMPLATE_DIR,
-               initial: Any = None):
-    super().__init__(model, benchmark, template_dir, initial)
-    self.agent_templare_dir = AGENT_TEMPLATE_DIR
-
-    # Load templates.
- if benchmark.is_c_target: - self.priming_template_file = self._find_template( - self.agent_templare_dir, 'prototyper-priming.c.txt') - elif benchmark.is_cpp_target: - self.priming_template_file = self._find_template( - self.agent_templare_dir, 'prototyper-priming.cpp.txt') - else: - self.problem_template_file = self._find_template( - self.agent_templare_dir, 'prototyper-priming.txt') - - self.cpp_priming_filler_file = self._find_template( - template_dir, 'cpp-specific-priming-filler.txt') - self.problem_template_file = self._find_template(template_dir, - 'problem.txt') - self.solution_template_file = self._find_template(template_dir, - 'solution.txt') - self.context_template_file = self._find_template(template_dir, - 'context.txt') - - def build(self, - example_pair: list[list[str]], - project_example_content: Optional[list[list[str]]] = None, - project_context_content: Optional[dict] = None, - tool_guides: str = '', - project_dir: str = '') -> prompts.Prompt: - """Constructs a prompt using the templates in |self| and saves it.""" - if not self.benchmark: - return self._prompt - priming = self._format_priming(self.benchmark) - priming = priming.replace('{PROJECT_DIR}', project_dir) - final_problem = self.format_problem(self.benchmark.function_signature) - final_problem += (f'You MUST call \n' - f'{self.benchmark.function_signature}\n' - f' in your solution!\n') - if project_context_content: - final_problem += self.format_context(project_context_content) - self._prepare_prompt(priming, final_problem, example_pair, - project_example_content) - self._prompt.append(tool_guides, True) - return self._prompt + """Builder specifically targeted C (and excluding C++).""" + + def __init__( + self, + model: models.LLM, + benchmark: Benchmark, + template_dir: str = DEFAULT_TEMPLATE_DIR, + initial: Any = None, + ): + super().__init__(model, benchmark, template_dir, initial) + self.agent_templare_dir = AGENT_TEMPLATE_DIR + + # Load templates. 
+ if benchmark.is_c_target: + self.priming_template_file = self._find_template( + self.agent_templare_dir, "prototyper-priming.c.txt" + ) + elif benchmark.is_cpp_target: + self.priming_template_file = self._find_template( + self.agent_templare_dir, "prototyper-priming.cpp.txt" + ) + else: + self.problem_template_file = self._find_template( + self.agent_templare_dir, "prototyper-priming.txt" + ) + + self.cpp_priming_filler_file = self._find_template( + template_dir, "cpp-specific-priming-filler.txt" + ) + self.problem_template_file = self._find_template(template_dir, "problem.txt") + self.solution_template_file = self._find_template(template_dir, "solution.txt") + self.context_template_file = self._find_template(template_dir, "context.txt") + + def build( + self, + example_pair: list[list[str]], + project_example_content: Optional[list[list[str]]] = None, + project_context_content: Optional[dict] = None, + tool_guides: str = "", + project_dir: str = "", + ) -> prompts.Prompt: + """Constructs a prompt using the templates in |self| and saves it.""" + if not self.benchmark: + return self._prompt + priming = self._format_priming(self.benchmark) + priming = priming.replace("{PROJECT_DIR}", project_dir) + final_problem = self.format_problem(self.benchmark.function_signature) + final_problem += ( + f"You MUST call \n" + f"{self.benchmark.function_signature}\n" + f" in your solution!\n" + ) + if project_context_content: + final_problem += self.format_context(project_context_content) + self._prepare_prompt( + priming, final_problem, example_pair, project_example_content + ) + self._prompt.append(tool_guides, True) + return self._prompt class PrototyperFixerTemplateBuilder(PrototyperTemplateBuilder): - """Builder specifically targeted C (and excluding C++).""" - - def __init__(self, - model: models.LLM, - benchmark: Benchmark, - build_result: BuildResult, - compile_log: str, - template_dir: str = DEFAULT_TEMPLATE_DIR, - initial: Any = None): - super().__init__(model, benchmark, template_dir, initial) - # Load templates. - self.priming_template_file = self._find_template(self.agent_templare_dir, - 'prototyper-fixing.txt') - self.build_result = build_result - self.compile_log = compile_log - - def build(self, - example_pair: list[list[str]], - project_example_content: Optional[list[list[str]]] = None, - project_context_content: Optional[dict] = None, - tool_guides: str = '', - project_dir: str = '') -> prompts.Prompt: - """Constructs a prompt using the templates in |self| and saves it.""" - del (example_pair, project_example_content, project_context_content, - tool_guides) - if not self.benchmark: - return self._prompt - - if self.build_result.build_script_source: - build_text = (f'\n{self.build_result.build_script_source}\n' - '') - else: - build_text = 'Build script reuses `/src/build.bk.sh`.' 
- - prompt = self._get_template(self.priming_template_file) - prompt = prompt.replace('{FUZZ_TARGET_SOURCE}', - self.build_result.fuzz_target_source) - prompt = prompt.replace('{BUILD_TEXT}', build_text) - prompt = prompt.replace('{COMPILE_LOG}', self.compile_log) - prompt = prompt.replace('{FUNCTION_SIGNATURE}', - self.benchmark.function_signature) - prompt = prompt.replace('{PROJECT_DIR}', project_dir) - self._prompt.append(prompt) - - return self._prompt + """Builder specifically targeted C (and excluding C++).""" + + def __init__( + self, + model: models.LLM, + benchmark: Benchmark, + build_result: BuildResult, + compile_log: str, + template_dir: str = DEFAULT_TEMPLATE_DIR, + initial: Any = None, + ): + super().__init__(model, benchmark, template_dir, initial) + # Load templates. + self.priming_template_file = self._find_template( + self.agent_templare_dir, "prototyper-fixing.txt" + ) + self.build_result = build_result + self.compile_log = compile_log + + def build( + self, + example_pair: list[list[str]], + project_example_content: Optional[list[list[str]]] = None, + project_context_content: Optional[dict] = None, + tool_guides: str = "", + project_dir: str = "", + ) -> prompts.Prompt: + """Constructs a prompt using the templates in |self| and saves it.""" + del ( + example_pair, + project_example_content, + project_context_content, + tool_guides, + ) + if not self.benchmark: + return self._prompt + + if self.build_result.build_script_source: + build_text = ( + f"\n{self.build_result.build_script_source}\n" + "" + ) + else: + build_text = "Build script reuses `/src/build.bk.sh`." + + prompt = self._get_template(self.priming_template_file) + prompt = prompt.replace( + "{FUZZ_TARGET_SOURCE}", self.build_result.fuzz_target_source + ) + prompt = prompt.replace("{BUILD_TEXT}", build_text) + prompt = prompt.replace("{COMPILE_LOG}", self.compile_log) + prompt = prompt.replace( + "{FUNCTION_SIGNATURE}", self.benchmark.function_signature + ) + prompt = prompt.replace("{PROJECT_DIR}", project_dir) + self._prompt.append(prompt) + + return self._prompt class CoverageAnalyzerTemplateBuilder(PrototyperTemplateBuilder): - """Builder specifically targeted C (and excluding C++).""" - - def __init__(self, - model: models.LLM, - benchmark: Benchmark, - run_result: RunResult, - template_dir: str = DEFAULT_TEMPLATE_DIR, - initial: Any = None): - super().__init__(model, benchmark, template_dir, initial) - # Load templates. 
- self.priming_template_file = self._find_template( - self.agent_templare_dir, 'coverage-analyzer-priming.txt') - self.run_result = run_result - - def build(self, - example_pair: list[list[str]], - project_example_content: Optional[list[list[str]]] = None, - project_context_content: Optional[dict] = None, - tool_guides: str = '', - project_dir: str = '') -> prompts.Prompt: - """Constructs a prompt using the templates in |self| and saves it.""" - del (example_pair, project_example_content, project_context_content) - if not self.benchmark: - return self._prompt - - prompt = self._get_template(self.priming_template_file) - prompt = prompt.replace('{LANGUAGE}', self.benchmark.file_type.value) - prompt = prompt.replace('{PROJECT}', self.benchmark.project) - prompt = prompt.replace('{PROJECT_DIR}', project_dir) - prompt = prompt.replace('{PROJECT_LANGUAGE}', self.benchmark.language) - prompt = prompt.replace('{FUNCTION_SIGNATURE}', - self.benchmark.function_signature) - prompt = prompt.replace('{FUZZ_TARGET}', self.run_result.fuzz_target_source) - prompt = prompt.replace('{TOOL_GUIDES}', tool_guides) - prompt = prompt.replace('{FUZZING_LOG}', self.run_result.run_log) - - self._prompt.append(prompt) - return self._prompt + """Builder specifically targeted C (and excluding C++).""" + + def __init__( + self, + model: models.LLM, + benchmark: Benchmark, + run_result: RunResult, + template_dir: str = DEFAULT_TEMPLATE_DIR, + initial: Any = None, + ): + super().__init__(model, benchmark, template_dir, initial) + # Load templates. + self.priming_template_file = self._find_template( + self.agent_templare_dir, "coverage-analyzer-priming.txt" + ) + self.run_result = run_result + + def build( + self, + example_pair: list[list[str]], + project_example_content: Optional[list[list[str]]] = None, + project_context_content: Optional[dict] = None, + tool_guides: str = "", + project_dir: str = "", + ) -> prompts.Prompt: + """Constructs a prompt using the templates in |self| and saves it.""" + del (example_pair, project_example_content, project_context_content) + if not self.benchmark: + return self._prompt + + prompt = self._get_template(self.priming_template_file) + prompt = prompt.replace("{LANGUAGE}", self.benchmark.file_type.value) + prompt = prompt.replace("{PROJECT}", self.benchmark.project) + prompt = prompt.replace("{PROJECT_DIR}", project_dir) + prompt = prompt.replace("{PROJECT_LANGUAGE}", self.benchmark.language) + prompt = prompt.replace( + "{FUNCTION_SIGNATURE}", self.benchmark.function_signature + ) + prompt = prompt.replace("{FUZZ_TARGET}", self.run_result.fuzz_target_source) + prompt = prompt.replace("{TOOL_GUIDES}", tool_guides) + prompt = prompt.replace("{FUZZING_LOG}", self.run_result.run_log) + + self._prompt.append(prompt) + return self._prompt class EnhancerTemplateBuilder(PrototyperTemplateBuilder): - """Builder specifically targeted C (and excluding C++).""" - - def __init__(self, - model: models.LLM, - benchmark: Benchmark, - build_result: BuildResult, - error_desc: str = '', - errors: Optional[list[str]] = None, - coverage_result: Optional[CoverageResult] = None, - template_dir: str = DEFAULT_TEMPLATE_DIR, - initial: Any = None): - super().__init__(model, benchmark, template_dir, initial) - # Load templates. 
- self.priming_template_file = self._find_template(self.agent_templare_dir, - 'enhancer-priming.txt') - self.build_result = build_result - self.error_desc = error_desc - self.errors = errors - self.coverage_result = coverage_result - - def build(self, - example_pair: list[list[str]], - project_example_content: Optional[list[list[str]]] = None, - project_context_content: Optional[dict] = None, - tool_guides: str = '', - project_dir: str = '') -> prompts.Prompt: - """Constructs a prompt using the templates in |self| and saves it.""" - del (example_pair, project_example_content, project_context_content) - if not self.benchmark: - return self._prompt - - priming = self._get_template(self.priming_template_file) - priming = priming.replace('{LANGUAGE}', self.benchmark.file_type.value) - priming = priming.replace('{FUNCTION_SIGNATURE}', - self.benchmark.function_signature) - priming = priming.replace('{PROJECT_DIR}', project_dir) - priming = priming.replace('{TOOL_GUIDES}', tool_guides) - if self.build_result.build_script_source: - build_text = (f'\n{self.build_result.build_script_source}\n' - '') - else: - build_text = 'Build script reuses `/src/build.bk.sh`.' - priming = priming.replace('{BUILD_TEXT}', build_text) - priming_weight = self._model.estimate_token_num(priming) - # TODO(dongge): Refine this logic. - if self.error_desc and self.errors: - error_desc = self.error_desc - errors = self.errors - elif self.coverage_result: - error_desc = self.coverage_result.insight - errors = self.coverage_result.suggestions.splitlines() - else: - error_desc = '' - errors = [] - problem = self._format_fixer_problem(self.build_result.fuzz_target_source, - error_desc, errors, priming_weight, '', - '') - - self._prepare_prompt(priming, problem) - return self._prompt + """Builder specifically targeted C (and excluding C++).""" + + def __init__( + self, + model: models.LLM, + benchmark: Benchmark, + build_result: BuildResult, + error_desc: str = "", + errors: Optional[list[str]] = None, + coverage_result: Optional[CoverageResult] = None, + template_dir: str = DEFAULT_TEMPLATE_DIR, + initial: Any = None, + ): + super().__init__(model, benchmark, template_dir, initial) + # Load templates. + self.priming_template_file = self._find_template( + self.agent_templare_dir, "enhancer-priming.txt" + ) + self.build_result = build_result + self.error_desc = error_desc + self.errors = errors + self.coverage_result = coverage_result + + def build( + self, + example_pair: list[list[str]], + project_example_content: Optional[list[list[str]]] = None, + project_context_content: Optional[dict] = None, + tool_guides: str = "", + project_dir: str = "", + ) -> prompts.Prompt: + """Constructs a prompt using the templates in |self| and saves it.""" + del (example_pair, project_example_content, project_context_content) + if not self.benchmark: + return self._prompt + + priming = self._get_template(self.priming_template_file) + priming = priming.replace("{LANGUAGE}", self.benchmark.file_type.value) + priming = priming.replace( + "{FUNCTION_SIGNATURE}", self.benchmark.function_signature + ) + priming = priming.replace("{PROJECT_DIR}", project_dir) + priming = priming.replace("{TOOL_GUIDES}", tool_guides) + if self.build_result.build_script_source: + build_text = ( + f"\n{self.build_result.build_script_source}\n" + "" + ) + else: + build_text = "Build script reuses `/src/build.bk.sh`." + priming = priming.replace("{BUILD_TEXT}", build_text) + priming_weight = self._model.estimate_token_num(priming) + # TODO(dongge): Refine this logic. 
+ if self.error_desc and self.errors: + error_desc = self.error_desc + errors = self.errors + elif self.coverage_result: + error_desc = self.coverage_result.insight + errors = self.coverage_result.suggestions.splitlines() + else: + error_desc = "" + errors = [] + problem = self._format_fixer_problem( + self.build_result.fuzz_target_source, + error_desc, + errors, + priming_weight, + "", + "", + ) + + self._prepare_prompt(priming, problem) + return self._prompt class CoverageEnhancerTemplateBuilder(PrototyperTemplateBuilder): - """Builder specifically targeted C (and excluding C++).""" - - def __init__(self, - model: models.LLM, - benchmark: Benchmark, - build_result: BuildResult, - coverage_result: CoverageResult, - template_dir: str = DEFAULT_TEMPLATE_DIR, - initial: Any = None): - super().__init__(model, benchmark, template_dir, initial) - # Load templates. - self.priming_template_file = self._find_template( - self.agent_templare_dir, 'enhancer-coverage-priming.txt') - self.build_result = build_result - self.coverage_result = coverage_result - - def build(self, - example_pair: list[list[str]], - project_example_content: Optional[list[list[str]]] = None, - project_context_content: Optional[dict] = None, - tool_guides: str = '', - project_dir: str = '') -> prompts.Prompt: - """Constructs a prompt using the templates in |self| and saves it.""" - del (example_pair, project_example_content, project_context_content) - if not self.benchmark: - return self._prompt - - prompt = self._get_template(self.priming_template_file) - prompt = prompt.replace('{TOOL_GUIDES}', tool_guides) - prompt = prompt.replace('{LANGUAGE}', self.benchmark.file_type.value) - prompt = prompt.replace('{PROJECT}', self.benchmark.project) - prompt = prompt.replace('{PROJECT_DIR}', project_dir) - prompt = prompt.replace('{PROJECT_LANGUAGE}', self.benchmark.language) - prompt = prompt.replace('{FUZZ_TARGET}', - self.build_result.fuzz_target_source) - prompt = prompt.replace('{FUNCTION_SIGNATURE}', - self.benchmark.function_signature) - - if self.build_result.build_script_source: - build_text = (f'\n{self.build_result.build_script_source}\n' - '') - else: - build_text = 'Build script reuses `/src/build.bk.sh`.' - prompt = prompt.replace('{BUILD_TEXT}', build_text) - prompt = prompt.replace('{INSIGHTS}', self.coverage_result.insight) - prompt = prompt.replace('{SUGGESTIONS}', self.coverage_result.suggestions) - self._prompt.append(prompt) - return self._prompt + """Builder specifically targeted C (and excluding C++).""" + + def __init__( + self, + model: models.LLM, + benchmark: Benchmark, + build_result: BuildResult, + coverage_result: CoverageResult, + template_dir: str = DEFAULT_TEMPLATE_DIR, + initial: Any = None, + ): + super().__init__(model, benchmark, template_dir, initial) + # Load templates. 
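# Sketch of the error-source selection above: the enhancer prefers explicit build
# errors when they exist, otherwise it falls back to the coverage analysis,
# turning the free-text suggestions into one "error" line per suggestion.
# CoverageResultStub is a stand-in dataclass, not the project's real type.
from dataclasses import dataclass
from typing import Optional


@dataclass
class CoverageResultStub:
    insight: str
    suggestions: str


def pick_error_context(
    error_desc: str,
    errors: list[str],
    coverage_result: Optional[CoverageResultStub],
) -> tuple[str, list[str]]:
    if error_desc and errors:
        return error_desc, errors
    if coverage_result:
        return coverage_result.insight, coverage_result.suggestions.splitlines()
    return "", []


desc, errs = pick_error_context(
    "", [], CoverageResultStub("shallow coverage", "try larger inputs\ncover error paths")
)
assert errs == ["try larger inputs", "cover error paths"]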
+ self.priming_template_file = self._find_template( + self.agent_templare_dir, "enhancer-coverage-priming.txt" + ) + self.build_result = build_result + self.coverage_result = coverage_result + + def build( + self, + example_pair: list[list[str]], + project_example_content: Optional[list[list[str]]] = None, + project_context_content: Optional[dict] = None, + tool_guides: str = "", + project_dir: str = "", + ) -> prompts.Prompt: + """Constructs a prompt using the templates in |self| and saves it.""" + del (example_pair, project_example_content, project_context_content) + if not self.benchmark: + return self._prompt + + prompt = self._get_template(self.priming_template_file) + prompt = prompt.replace("{TOOL_GUIDES}", tool_guides) + prompt = prompt.replace("{LANGUAGE}", self.benchmark.file_type.value) + prompt = prompt.replace("{PROJECT}", self.benchmark.project) + prompt = prompt.replace("{PROJECT_DIR}", project_dir) + prompt = prompt.replace("{PROJECT_LANGUAGE}", self.benchmark.language) + prompt = prompt.replace("{FUZZ_TARGET}", self.build_result.fuzz_target_source) + prompt = prompt.replace( + "{FUNCTION_SIGNATURE}", self.benchmark.function_signature + ) + + if self.build_result.build_script_source: + build_text = ( + f"\n{self.build_result.build_script_source}\n" + "" + ) + else: + build_text = "Build script reuses `/src/build.bk.sh`." + prompt = prompt.replace("{BUILD_TEXT}", build_text) + prompt = prompt.replace("{INSIGHTS}", self.coverage_result.insight) + prompt = prompt.replace("{SUGGESTIONS}", self.coverage_result.suggestions) + self._prompt.append(prompt) + return self._prompt class FunctionAnalyzerTemplateBuilder(PrototyperTemplateBuilder): - """ Builder for function analyzer.""" + """Builder for function analyzer.""" - def __init__(self, - model: models.LLM, - benchmark: Benchmark, - template_dir: str = DEFAULT_TEMPLATE_DIR, - initial: Any = None): - super().__init__(model, benchmark, template_dir, initial) + def __init__( + self, + model: models.LLM, + benchmark: Benchmark, + template_dir: str = DEFAULT_TEMPLATE_DIR, + initial: Any = None, + ): + super().__init__(model, benchmark, template_dir, initial) - # Load templates. - self.function_analyzer_instruction_template_file = self._find_template( - self.agent_templare_dir, 'function-analyzer-instruction.txt') - self.function_analyzer_prompt_template_file = self._find_template( - self.agent_templare_dir, 'function-analyzer-priming.txt') + # Load templates. 
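# Sketch of the two-stage prompt split used by FunctionAnalyzerTemplateBuilder:
# one fixed instruction text (loaded from the instruction template) plus a
# per-target priming that only needs the project name and function signature.
# The template strings below are placeholders, not the real template files.
_INSTRUCTION = "Analyse the target function and report its input constraints."
_PRIMING = "Project: {PROJECT_NAME}\nTarget: {FUNCTION_SIGNATURE}\n"


def build_function_analyzer_prompt(project_name: str, function_signature: str) -> list[str]:
    priming = _PRIMING.replace("{PROJECT_NAME}", project_name)
    priming = priming.replace("{FUNCTION_SIGNATURE}", function_signature)
    # The instruction piece is reusable across targets; the priming piece is not.
    return [_INSTRUCTION, priming]


print(build_function_analyzer_prompt("zlib", "int inflate(z_stream*, int)"))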
+ self.function_analyzer_instruction_template_file = self._find_template( + self.agent_templare_dir, "function-analyzer-instruction.txt" + ) + self.function_analyzer_prompt_template_file = self._find_template( + self.agent_templare_dir, "function-analyzer-priming.txt" + ) - def build_instruction(self) -> prompts.Prompt: - """Constructs a prompt using the templates in |self| and saves it.""" - if not self.benchmark: - return self._prompt + def build_instruction(self) -> prompts.Prompt: + """Constructs a prompt using the templates in |self| and saves it.""" + if not self.benchmark: + return self._prompt - prompt = self._get_template( - self.function_analyzer_instruction_template_file) + prompt = self._get_template(self.function_analyzer_instruction_template_file) - self._prompt.append(prompt) + self._prompt.append(prompt) - return self._prompt + return self._prompt - def build_prompt(self, project_name, function_signature) -> prompts.Prompt: - """Constructs a prompt using the templates in |self| and saves it.""" - if not self.benchmark: - return self._prompt + def build_prompt(self, project_name, function_signature) -> prompts.Prompt: + """Constructs a prompt using the templates in |self| and saves it.""" + if not self.benchmark: + return self._prompt - prompt = self._get_template(self.function_analyzer_prompt_template_file) + prompt = self._get_template(self.function_analyzer_prompt_template_file) - prompt = prompt.replace('{PROJECT_NAME}', project_name) - prompt = prompt.replace('{FUNCTION_SIGNATURE}', function_signature) + prompt = prompt.replace("{PROJECT_NAME}", project_name) + prompt = prompt.replace("{FUNCTION_SIGNATURE}", function_signature) - self._prompt.append(prompt) + self._prompt.append(prompt) - return self._prompt + return self._prompt - def build(self, - example_pair: Optional[list[list[str]]] = None, - project_example_content: Optional[list[list[str]]] = None, - project_context_content: Optional[dict] = None, - tool_guides: str = '', - project_dir: str = '', - project_name: str = '', - function_signature: str = '') -> prompts.Prompt: - """Constructs a prompt using the templates in |self| and saves it.""" - return self.build_prompt(project_name, function_signature) + def build( + self, + example_pair: Optional[list[list[str]]] = None, + project_example_content: Optional[list[list[str]]] = None, + project_context_content: Optional[dict] = None, + tool_guides: str = "", + project_dir: str = "", + project_name: str = "", + function_signature: str = "", + ) -> prompts.Prompt: + """Constructs a prompt using the templates in |self| and saves it.""" + return self.build_prompt(project_name, function_signature) class CrashAnalyzerTemplateBuilder(DefaultTemplateBuilder): - """Builder for C/C++.""" - - def __init__(self, - model: models.LLM, - benchmark: Optional[Benchmark] = None, - template_dir: str = DEFAULT_TEMPLATE_DIR, - initial: Any = None): - super().__init__(model, benchmark, template_dir, initial) - self.agent_templare_dir = AGENT_TEMPLATE_DIR - - self.crash_analyzer_priming_template_file = self._find_template( - self.agent_templare_dir, 'crash_analyzer-priming.txt') - - def _prepare_prompt( - self, - priming: str, - final_problem: str, - example_pair: Optional[list[list[str]]] = None, - project_example_content: Optional[list[list[str]]] = None): - """Constructs a prompt using the parameters and saves it.""" - self._prompt.add_priming(priming) - - def build_crash_analyzer_prompt(self, benchmark: Benchmark, driver_code: str, - crash_info: str, - crash_func: dict) -> 
prompts.Prompt: - """Prepares the crash analyzer prompt.""" - all_func_code = [] - for func_name, line_number in crash_func.items(): - if func_name == 'LLVMFuzzerTestOneInput': - driver_code = self._slice_driver_code(benchmark.project, driver_code, - line_number) - else: - func_code = self._slice_func_code(benchmark.project, func_name, - line_number) - all_func_code.append(func_code) - - with open(self.crash_analyzer_priming_template_file) as f: - priming = f.read().strip() - priming = priming.replace('{CRASH_REPORT}', crash_info.strip())\ - .replace('{DRIVER_CODE}', driver_code.strip()) - - priming_prompt = self._prompt.create_prompt_piece(priming, 'user') - template_piece = self._prompt.create_prompt_piece('{PROJECT_FUNCTION_CODE}', - 'user') - - priming_weight = self._model.estimate_token_num(priming_prompt) - template_weight = self._model.estimate_token_num(template_piece) - - prompt_size = priming_weight - template_weight - # Add extra 20-tokens redundancy - prompt_size += 20 - - # Add function code one by one until we reach the maximum prompt size - selected_func_code = [] - for func_code in all_func_code: - func_code_prompt = self._prompt.create_prompt_piece(func_code, 'user') - func_code_token_num = self._model.estimate_token_num(func_code_prompt) - if prompt_size + func_code_token_num >= self._model.context_window: - # The estimation is inaccurate, if an example's size equals to - # the limit, it's safer to not include the example. - logger.warning('Breaking because adding this function code \ - would exceed context window') - break - prompt_size += func_code_token_num - selected_func_code.append(func_code) - - project_function_code = '\n'.join(selected_func_code) - if project_function_code.strip(): - priming.replace('{PROJECT_FUNCTION_CODE}', project_function_code.strip()) - else: - logger.warning( - 'Empty project function code in triage prompt for project: %s, \ - function name: %s', benchmark.project, benchmark.function_name) - priming.replace('{PROJECT_FUNCTION_CODE}', \ - 'No relevant project function code') - - self._prepare_prompt(priming, '') - return self._prompt + """Builder for C/C++.""" + + def __init__( + self, + model: models.LLM, + benchmark: Optional[Benchmark] = None, + template_dir: str = DEFAULT_TEMPLATE_DIR, + initial: Any = None, + ): + super().__init__(model, benchmark, template_dir, initial) + self.agent_templare_dir = AGENT_TEMPLATE_DIR + + self.crash_analyzer_priming_template_file = self._find_template( + self.agent_templare_dir, "crash_analyzer-priming.txt" + ) + + def _prepare_prompt( + self, + priming: str, + final_problem: str, + example_pair: Optional[list[list[str]]] = None, + project_example_content: Optional[list[list[str]]] = None, + ): + """Constructs a prompt using the parameters and saves it.""" + self._prompt.add_priming(priming) + + def build_crash_analyzer_prompt( + self, benchmark: Benchmark, driver_code: str, crash_info: str, crash_func: dict + ) -> prompts.Prompt: + """Prepares the crash analyzer prompt.""" + all_func_code = [] + for func_name, line_number in crash_func.items(): + if func_name == "LLVMFuzzerTestOneInput": + driver_code = self._slice_driver_code( + benchmark.project, driver_code, line_number + ) + else: + func_code = self._slice_func_code( + benchmark.project, func_name, line_number + ) + all_func_code.append(func_code) + + with open(self.crash_analyzer_priming_template_file) as f: + priming = f.read().strip() + priming = priming.replace("{CRASH_REPORT}", crash_info.strip()).replace( + "{DRIVER_CODE}", driver_code.strip() 
+ ) + + priming_prompt = self._prompt.create_prompt_piece(priming, "user") + template_piece = self._prompt.create_prompt_piece( + "{PROJECT_FUNCTION_CODE}", "user" + ) + + priming_weight = self._model.estimate_token_num(priming_prompt) + template_weight = self._model.estimate_token_num(template_piece) + + prompt_size = priming_weight - template_weight + # Add extra 20-tokens redundancy + prompt_size += 20 + + # Add function code one by one until we reach the maximum prompt size + selected_func_code = [] + for func_code in all_func_code: + func_code_prompt = self._prompt.create_prompt_piece(func_code, "user") + func_code_token_num = self._model.estimate_token_num(func_code_prompt) + if prompt_size + func_code_token_num >= self._model.context_window: + # The estimation is inaccurate, if an example's size equals to + # the limit, it's safer to not include the example. + logger.warning( + "Breaking because adding this function code \ + would exceed context window" + ) + break + prompt_size += func_code_token_num + selected_func_code.append(func_code) + + project_function_code = "\n".join(selected_func_code) + if project_function_code.strip(): + priming.replace("{PROJECT_FUNCTION_CODE}", project_function_code.strip()) + else: + logger.warning( + "Empty project function code in triage prompt for project: %s, \ + function name: %s", + benchmark.project, + benchmark.function_name, + ) + priming.replace( + "{PROJECT_FUNCTION_CODE}", "No relevant project function code" + ) + + self._prepare_prompt(priming, "") + return self._prompt class DefaultJvmTemplateBuilder(PromptBuilder): - """Default builder for JVM projects.""" - - def __init__(self, - model: models.LLM, - benchmark: Benchmark, - template_dir: str = DEFAULT_TEMPLATE_DIR): - super().__init__(model) - self._template_dir = template_dir - self.benchmark = benchmark - self.project_url = oss_fuzz_checkout.get_project_repository( - self.benchmark.project) - - # Retrieve additional properties for the target method - temp_properties = introspector.query_introspector_function_props( - self.benchmark.project, self.benchmark.function_signature) - self.exceptions = temp_properties.get('exceptions', []) - self.is_jvm_static = temp_properties.get('is-jvm-static', False) - self.need_close = temp_properties.get('need_close', False) - - # Load templates. - self.priming_template_file = self._find_template(template_dir, - 'jvm_priming.txt') - self.data_filler_template_file = self._find_template( - template_dir, 'jvm_specific_data_filler.txt') - self.requirement_template_file = self._find_template( - template_dir, 'jvm_requirement.txt') - self.problem_template_file = self._find_template(template_dir, - 'jvm_problem.txt') - self.target_template_file = self._find_template(template_dir, - 'jvm_target.txt') - self.arg_description_template_file = self._find_template( - template_dir, 'jvm_arg_description.txt') - self.import_template_file = self._find_template(template_dir, - 'jvm_import_mapping.txt') - - def _find_template(self, template_dir: str, template_name: str) -> str: - """Finds template file based on |template_dir|.""" - preferred_template = os.path.join(template_dir, template_name) - # Use the preferred template if it exists. - if os.path.isfile(preferred_template): - return preferred_template - # Fall back to the default template. 
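# Sketch of the context-window budgeting above: sliced function snippets are
# appended greedily until the estimated prompt size would reach the model's
# context window (plus a small redundancy margin). estimate_tokens is a crude
# stand-in for the model's real token estimator. Note that str.replace returns a
# new string, so the final substitution must be assigned back to take effect.
def estimate_tokens(text: str) -> int:
    return max(1, len(text) // 4)  # Rough heuristic stand-in.


def select_function_code(priming: str, candidates: list[str], context_window: int) -> str:
    prompt_size = estimate_tokens(priming) + 20  # Extra 20-token redundancy.
    selected: list[str] = []
    for snippet in candidates:
        cost = estimate_tokens(snippet)
        if prompt_size + cost >= context_window:
            break  # Estimation is approximate; stop before risking overflow.
        prompt_size += cost
        selected.append(snippet)
    joined = "\n".join(selected) or "No relevant project function code"
    return priming.replace("{PROJECT_FUNCTION_CODE}", joined)


filled = select_function_code(
    "Crash report...\n{PROJECT_FUNCTION_CODE}", ["void a() {}", "void b() {}"], 64
)
assert "{PROJECT_FUNCTION_CODE}" not in filled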
- default_template = os.path.join(DEFAULT_TEMPLATE_DIR, template_name) - return default_template - - def _get_template(self, template_file: str) -> str: - """Reads the template for prompts.""" - with open(template_file) as file: - return file.read() - - def _format_exceptions(self) -> str: - """Formats the exception thrown from this method or constructor.""" - if self.exceptions: - exception_str_list = [ - f'{exp}' for exp in self.exceptions - ] - return '\n' + '\n'.join( - exception_str_list) + '\n' - - return '' - - def _format_import_mapping(self, full_class_name: str) -> str: - """Formats the import mapping row on the prompt template.""" - # full_class_name format: .$ - # For example, the inner class Inner in class Test of package - # a.b.c will have a full_class_name of a.b.c.Test$Inner - class_name = full_class_name.rsplit('.')[-1] - full_class_name = full_class_name.split('$')[0] - - mapping = self._get_template(self.import_template_file) - mapping = mapping.replace('{CLASS_NAME}', class_name) - mapping = mapping.replace('{FULL_CLASS_NAME}', full_class_name) - - return mapping - - def _format_generic_argument(self, arg_type: str) -> Tuple[str, str]: - """Formats generic argument description.""" - generic_types = arg_type.split('<', 1)[1][:-1].split(',') - - new_types = [] - generic_desc = [] - for generic_type in generic_types: - if generic_type.endswith(('T', 'K', 'V')): - generic_type = 'java.lang.Object' - - new_types.append(generic_type) - - desc = (f'For generic type of {generic_type}, you MUST use ' - '{RANDOM_METHODS} to generate the needed variable.') - - method_str = self._get_methods_for_simple_type(generic_type) - if method_str: - desc = desc.replace('{RANDOM_METHODS}', method_str) - else: - desc = desc.replace('{RANDOM_METHODS}', - 'correct constructors or static methods') - - generic_desc.append(desc) - - if not generic_desc: - return '', '' - - generic_types = ','.join(new_types) - return f' with generic types of {generic_types}', '\n'.join(generic_desc) - - def _format_argument(self, count: int, arg_type: str) -> str: - """Formats general argument description.""" - method_str = self._get_methods_for_simple_type(arg_type) - - # Simple arguments - argument = self._get_template(self.arg_description_template_file) - argument = argument.replace('{ARG_COUNT}', str(count)) - - if method_str: - type_str = '{SIMPLE_TYPE} variable.' - desc_str = f'You must use {method_str} to generate {{ARRAY_OR_NOT}}.' - else: - type_str = '{SIMPLE_TYPE} instance {GENERIC_TYPE}.' - desc_str = ('Please generate {ARRAY_OR_NOT}. 
You should use constructors ' - 'or static methods for the generation.\nPlease also insert ' - 'random data into the created instance.') - - argument = argument.replace('{TYPE}', type_str) - argument = argument.replace('{GENERAL_DESC}', desc_str) - - # Array handling - if '[]' in arg_type: - arg_type_no_array = arg_type.replace('[]', '').split('<')[0] - argument = argument.replace('{SIMPLE_TYPE}', - f'an array of {arg_type_no_array} ') - argument = argument.replace( - '{ARRAY_OR_NOT}', - (f'multiple {arg_type_no_array} objects and initialise an array ' - 'of {arg_type_no_array} with the generated objects.')) - else: - argument = argument.replace('{SIMPLE_TYPE}', f'a {arg_type}') - argument = argument.replace('{ARRAY_OR_NOT}', 'the needed parameter.') - - # Generic type handling - generic_type = '' - generic_desc = '' - if self._has_generic(arg_type): - generic_type, generic_desc = self._format_generic_argument(arg_type) - - argument = argument.replace('{GENERIC_TYPE}', generic_type) - argument = argument.replace('{GENERIC_DESC}', generic_desc) - - return argument - - def _format_requirement(self, signature: str) -> str: - """Formats a requirement based on the prompt template.""" - classes = [] - - class_name = signature[1:].split(']')[0] - if self._need_import(class_name): - classes.append(class_name) - - for arg_dict in self.benchmark.params: - arg_type = arg_dict['type'].split('<')[0] - if self._need_import(arg_type): - classes.append(arg_type) - - classes = list(set(classes)) - mappings = [self._format_import_mapping(type) for type in classes] - - requirement = self._get_template(self.requirement_template_file) - requirement = requirement.replace('{IMPORT_MAPPINGS}', '\n'.join(mappings)) - - harness_name = os.path.basename(self.benchmark.target_path).replace( - '.java', '') - if harness_name: - requirement = requirement.replace('{HARNESS_NAME}', harness_name) - else: - requirement = requirement.replace('{HARNESS_NAME}', 'Fuzz') - - class_name = self.benchmark.function_name[1:].split(']')[0] - if '' in self.benchmark.function_name: - creation = (f'The target method is a constructor of {class_name} ' - 'invoke it directly with new keyword.') - elif self.is_jvm_static: - creation = ('The target method is a static method, invoke it directly ' - 'without creating an object.') - else: - creation = (f'You must create the {class_name} object before calling ' - 'the target method.') - requirement = requirement.replace('{STATIC_OR_INSTANCE}', creation) - - close_statement = '' - if self.need_close: - close_statement = ( - 'You MUST invoke the close method of the ' - f'{class_name} objects in the finally block after the target method ' - 'is invoked.') - - requirement = requirement.replace('{NEED_CLOSE}', close_statement) - - return requirement - - def _format_data_filler(self) -> str: - """Formats a data_filler based on the prompt template.""" - data_filler = self._get_template(self.data_filler_template_file) - return data_filler - - def _format_arguments(self) -> str: - """Formats a list of argument descriptions.""" - argument_descriptions = [] - - for count, function_arg in enumerate(self.benchmark.params): - arg_type = function_arg['type'] - argument = self._format_argument(count, arg_type) - argument_descriptions.append(argument) - - return '' + '\n'.join(argument_descriptions) + '' - - def _format_constructors(self) -> str: - """Formats a list of functions / constructors to create the object for - invoking the target method.""" - if self.is_jvm_static: - return '' - - constructors = [] - 
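# Sketch of the argument-description logic above: an argument slot is rendered
# differently depending on whether the type is an array and whether it carries
# generic parameters; unresolved type variables (T, K, V) collapse to
# java.lang.Object. The returned strings are simplified stand-ins for the real
# jvm_arg_description template.
def describe_argument(count: int, arg_type: str) -> str:
    base = arg_type.replace("[]", "").split("<")[0]
    if "[]" in arg_type:
        shape = f"Argument #{count}: an array of {base}"
    else:
        shape = f"Argument #{count}: a {base}"
    if "<" in arg_type and arg_type.endswith(">"):
        generics = [
            "java.lang.Object" if g.strip().endswith(("T", "K", "V")) else g.strip()
            for g in arg_type.split("<", 1)[1][:-1].split(",")
        ]
        shape += f" with generic types of {','.join(generics)}"
    return shape


print(describe_argument(0, "java.util.List<T>"))
print(describe_argument(1, "byte[]"))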
ctrs = introspector.query_introspector_matching_function_constructor_type( - self.benchmark.project, self.benchmark.return_type, False) - for ctr in ctrs: - constructor_sig = ctr.get('function_signature', '') - if constructor_sig: - constructors.append(f'{constructor_sig}') - exceptions = introspector.query_introspector_function_props( - ctr.get('project', ''), constructor_sig).get('exceptions', []) - self.exceptions.extend(exceptions) - - if constructors: - ctr_str = '\n'.join(constructors) - return f'{ctr_str}' - - functions = [] - funcs = introspector.query_introspector_matching_function_constructor_type( - self.benchmark.project, self.benchmark.return_type, True) - for func in funcs: - is_static = func.get('is_static', False) - function_sig = func.get('function_signature', '') - if not function_sig: - continue - exceptions = introspector.query_introspector_function_props( - func.get('project', ''), function_sig).get('exceptions', []) - self.exceptions.extend(exceptions) - if is_static: - functions.append(f'{function_sig}') - else: - function_class = function_sig[1:].split(']')[0] - function_str = f'{function_sig}' - function_str = function_str + ( - 'You MUST create an ' - f'{function_class} object before calling this constructing method.' - '') - function_str = f'{function_str}' - functions.append(function_str) - if functions: - func_str = '\n'.join(functions) - return f'{func_str}' - - return '' - - def _format_source_reference(self, signature: str) -> Tuple[str, str]: - """Formats the source code reference for this target.""" - # Query for source code of the target method - source_code = introspector.query_introspector_function_source( - self.benchmark.project, signature) - - # Query for source code of target method callsites - xref_source_list = [] - for xref_source in introspector.query_introspector_cross_references( - self.benchmark.project, signature): - if xref_source: - xref_source_list.append(xref_source) - - return source_code, '\n'.join(xref_source_list) - - def _format_problem(self, signature: str) -> str: - """Formats a problem based on the prompt template.""" - is_constructor = bool('' in signature) - - problem = self._get_template(self.problem_template_file) - problem = problem.replace('{TARGET}', - self._get_template(self.target_template_file)) - problem = problem.replace('{SIGNATURE}', signature) - problem = problem.replace('{CLASS}', signature.split('].')[0][1:]) - problem = problem.replace('{REQUIREMENTS}', - self._format_requirement(signature)) - problem = problem.replace('{ARGUMENTS}', self._format_arguments()) - problem = problem.replace('{CONSTRUCTORS}', self._format_constructors()) - problem = problem.replace('{EXCEPTIONS}', self._format_exceptions()) - - self_source, cross_source = self._format_source_reference(signature) - problem = problem.replace('{SELF_SOURCE}', self_source) - problem = problem.replace('{CROSS_SOURCE}', cross_source) - problem = problem.replace("{PROJECT_NAME}", self.benchmark.project) - problem = problem.replace("{PROJECT_URL}", self.project_url) - problem = problem.replace('{DATA_MAPPING}', self._format_data_filler()) - - if is_constructor: - problem = problem.replace('{METHOD_OR_CONSTRUCTOR}', 'constructor') - else: - problem = problem.replace('{METHOD_OR_CONSTRUCTOR}', 'method') - - return problem - - def _prepare_prompt(self, prompt_str: str): - """Constructs a prompt using the parameters and saves it.""" - self._prompt.add_priming(self._get_template(self.priming_template_file)) - self._prompt.add_problem(prompt_str) - - def 
_has_generic(self, arg: str) -> bool: - """Determine if the argument type contains generic type.""" - return ('<' in arg and not arg.startswith('<') and arg.endswith('>') and - 'java.lang.Class' not in arg and 'java.lang.Object' not in arg) - - def _need_import(self, class_name: str) -> bool: - """Determine if the class with class_name needed to be imported.""" - return '.' in class_name and not class_name.startswith('java.lang.') - - def _get_methods_for_simple_type(self, simple_type: str) -> str: - """Retrieve string descrbing how to generate random data of - the provided simple type.""" - simple_type_mapping = { - 'int': [ - 'FuzzedDataProvider::consumeInt()', - 'FuzzedDataProvider::consumeInt(int, int)' - ], - 'boolean': [ - 'FuzzedDataProvider::consumeBoolean()', - 'FuzzedDataProvider::pickValue(boolean[])' - ], - 'byte': [ - 'FuzzedDataProvider::consumeByte()', - 'FuzzedDataProvider::consumeByte(byte,byte)' - ], - 'byte[]': [ - 'FuzzedDataProvider::consumeBytes(int)', - 'FuzzedDataProvider::consumeRemainingAsBytes()' - ], - 'short': [ - 'FuzzedDataProvider::consumeShort()', - 'FuzzedDataProvider::consumeShort(short,short)' - ], - 'long': [ - 'FuzzedDataProvider::consumeLong()', - 'FuzzedDataProvider::consumeLong(long, long)' - ], - 'float': [ - 'FuzzedDataProvider::consumeFloat()', - 'FuzzedDataProvider::consumeRegularFloat()', - 'FuzzedDataProvider::consumeRegularFloat(float,float)', - 'FuzzedDataProvider::consumeProbabilityFloat()' - ], - 'double': [ - 'FuzzedDataProvider::consumeDouble()', - 'FuzzedDataProvider::consumeRegularDouble()', - 'FuzzedDataProvider::consumeRegularDouble(double, double)', - 'FuzzedDataProvider::consumeProbabilityDouble()' - ], - 'char': [ - 'FuzzedDataProvider::consumeChar()', - 'FuzzedDataProvider::consumeCharNoSurrogates()', - 'FuzzedDataProvider::consumeChar(char, char)' - ], - 'string': [ - 'FuzzedDataProvider::consumeString(int)', - 'FuzzedDataProvider::consumeAsciiString(int)', - 'FuzzedDataProvider::consumeRemainingAsString()', - 'FuzzedDataProvider::consumeRemainingAsAsciiString()' - ], - 'class': ['Object::getClass()'] - } - - # Extract simple type - simple_type = simple_type.replace('java.lang.Integer', 'int') - simple_type = simple_type.replace('java.lang.Character', 'char') - simple_type = simple_type.split('.')[-1].lower() - - if simple_type in simple_type_mapping: - return ' or '.join(simple_type_mapping[simple_type]) - - # If the type is not found, try if it is an array of any above types - simple_type = simple_type.replace('[]', '') - return ' or '.join(simple_type_mapping.get(simple_type, [])) - - def build(self, - example_pair: list[list[str]], - project_example_content: Optional[list[list[str]]] = None, - project_context_content: Optional[dict] = None) -> prompts.Prompt: - """Constructs a prompt using the templates in |self| and saves it. - Ignore target_file_type, project_example_content - and project_context_content parameters. - """ - final_problem = self._format_problem(self.benchmark.function_signature) - self._prepare_prompt(final_problem) - return self._prompt - - def build_fixer_prompt(self, benchmark: Benchmark, raw_code: str, - error_desc: Optional[str], - errors: list[str]) -> prompts.Prompt: - """Builds a fixer prompt.""" - # Do nothing for jvm project now. - return self._prompt - - def build_triager_prompt(self, benchmark: Benchmark, driver_code: str, - crash_info: str, crash_func: dict) -> prompts.Prompt: - """Builds a triager prompt.""" - # Do nothing for jvm project now. 
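# Sketch of the simple-type lookup above: boxed java.lang types are normalised to
# their primitive spelling, array markers are stripped as a fallback, and the
# result is an "or"-joined list of FuzzedDataProvider methods the prompt tells
# the model to use. Only a small subset of the real mapping is shown.
_SIMPLE_TYPE_METHODS = {
    "int": ["FuzzedDataProvider::consumeInt()", "FuzzedDataProvider::consumeInt(int, int)"],
    "string": ["FuzzedDataProvider::consumeString(int)"],
}


def methods_for_type(simple_type: str) -> str:
    simple_type = simple_type.replace("java.lang.Integer", "int")
    simple_type = simple_type.split(".")[-1].lower()
    if simple_type in _SIMPLE_TYPE_METHODS:
        return " or ".join(_SIMPLE_TYPE_METHODS[simple_type])
    # Not found directly: try the element type of an array.
    return " or ".join(_SIMPLE_TYPE_METHODS.get(simple_type.replace("[]", ""), []))


assert "consumeInt" in methods_for_type("java.lang.Integer")
assert methods_for_type("com.example.Foo") == ""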
- return self._prompt - - def post_process_generated_code(self, generated_code: str) -> str: - """Allows prompt builder to adjust the generated code.""" - # From observation, the LLM model keeps using wrong method calls including - # FuzzedDataProvider::consumeObject() or FuzzedDataProvider::getObject() or - # FuzzedDataProvider::consumeInt(int) to generate random Object / Integer - # instance. These methods are not valid in FuzzedDataProvider. - - # The fixes here change the calling of data.consumeObject() and - # data.getObject() to data.consumeString(int) - generated_code = generated_code.replace( - 'data.consumeObject()', 'data.consumeString(data.remainingBytes()/2)') - generated_code = generated_code.replace( - 'data.getObject()', 'data.consumeString(data.remainingBytes()/2)') - - # The fixes here change the calling of data.consumeInt(int) to - # data.consumeInt(0, int). For example, data.consumeInt(12345) will - # be replaced by data.consumeInt(0, 12345) - for wrong_method_call in re.findall(r'(data\.consumeInt\(([0-9]+)\))', - generated_code): - old_method_call = wrong_method_call[0] - new_method_call = f'data.consumeInt(0, {wrong_method_call[1]})' - generated_code = generated_code.replace(old_method_call, new_method_call) - - return generated_code + """Default builder for JVM projects.""" + + def __init__( + self, + model: models.LLM, + benchmark: Benchmark, + template_dir: str = DEFAULT_TEMPLATE_DIR, + ): + super().__init__(model) + self._template_dir = template_dir + self.benchmark = benchmark + self.project_url = oss_fuzz_checkout.get_project_repository( + self.benchmark.project + ) + + # Retrieve additional properties for the target method + temp_properties = introspector.query_introspector_function_props( + self.benchmark.project, self.benchmark.function_signature + ) + self.exceptions = temp_properties.get("exceptions", []) + self.is_jvm_static = temp_properties.get("is-jvm-static", False) + self.need_close = temp_properties.get("need_close", False) + + # Load templates. + self.priming_template_file = self._find_template( + template_dir, "jvm_priming.txt" + ) + self.data_filler_template_file = self._find_template( + template_dir, "jvm_specific_data_filler.txt" + ) + self.requirement_template_file = self._find_template( + template_dir, "jvm_requirement.txt" + ) + self.problem_template_file = self._find_template( + template_dir, "jvm_problem.txt" + ) + self.target_template_file = self._find_template(template_dir, "jvm_target.txt") + self.arg_description_template_file = self._find_template( + template_dir, "jvm_arg_description.txt" + ) + self.import_template_file = self._find_template( + template_dir, "jvm_import_mapping.txt" + ) + + def _find_template(self, template_dir: str, template_name: str) -> str: + """Finds template file based on |template_dir|.""" + preferred_template = os.path.join(template_dir, template_name) + # Use the preferred template if it exists. + if os.path.isfile(preferred_template): + return preferred_template + # Fall back to the default template. 
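# Sketch of the post-processing pass above: generated harness code is patched for
# FuzzedDataProvider calls that do not exist, e.g. the single-argument
# data.consumeInt(N) is rewritten to the valid two-argument data.consumeInt(0, N).
# re.sub with a replacement function expresses the same findall-and-replace loop.
import re


def fix_consume_int_calls(generated_code: str) -> str:
    return re.sub(
        r"data\.consumeInt\(([0-9]+)\)",
        lambda m: f"data.consumeInt(0, {m.group(1)})",
        generated_code,
    )


assert fix_consume_int_calls("int x = data.consumeInt(12345);") == (
    "int x = data.consumeInt(0, 12345);"
)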
+ default_template = os.path.join(DEFAULT_TEMPLATE_DIR, template_name) + return default_template + + def _get_template(self, template_file: str) -> str: + """Reads the template for prompts.""" + with open(template_file) as file: + return file.read() + + def _format_exceptions(self) -> str: + """Formats the exception thrown from this method or constructor.""" + if self.exceptions: + exception_str_list = [ + f"{exp}" for exp in self.exceptions + ] + return "\n" + "\n".join(exception_str_list) + "\n" + + return "" + + def _format_import_mapping(self, full_class_name: str) -> str: + """Formats the import mapping row on the prompt template.""" + # full_class_name format: .$ + # For example, the inner class Inner in class Test of package + # a.b.c will have a full_class_name of a.b.c.Test$Inner + class_name = full_class_name.rsplit(".")[-1] + full_class_name = full_class_name.split("$")[0] + + mapping = self._get_template(self.import_template_file) + mapping = mapping.replace("{CLASS_NAME}", class_name) + mapping = mapping.replace("{FULL_CLASS_NAME}", full_class_name) + + return mapping + + def _format_generic_argument(self, arg_type: str) -> Tuple[str, str]: + """Formats generic argument description.""" + generic_types = arg_type.split("<", 1)[1][:-1].split(",") + + new_types = [] + generic_desc = [] + for generic_type in generic_types: + if generic_type.endswith(("T", "K", "V")): + generic_type = "java.lang.Object" + + new_types.append(generic_type) + + desc = ( + f"For generic type of {generic_type}, you MUST use " + "{RANDOM_METHODS} to generate the needed variable." + ) + + method_str = self._get_methods_for_simple_type(generic_type) + if method_str: + desc = desc.replace("{RANDOM_METHODS}", method_str) + else: + desc = desc.replace( + "{RANDOM_METHODS}", "correct constructors or static methods" + ) + + generic_desc.append(desc) + + if not generic_desc: + return "", "" + + generic_types = ",".join(new_types) + return f" with generic types of {generic_types}", "\n".join(generic_desc) + + def _format_argument(self, count: int, arg_type: str) -> str: + """Formats general argument description.""" + method_str = self._get_methods_for_simple_type(arg_type) + + # Simple arguments + argument = self._get_template(self.arg_description_template_file) + argument = argument.replace("{ARG_COUNT}", str(count)) + + if method_str: + type_str = "{SIMPLE_TYPE} variable." + desc_str = f"You must use {method_str} to generate {{ARRAY_OR_NOT}}." + else: + type_str = "{SIMPLE_TYPE} instance {GENERIC_TYPE}." + desc_str = ( + "Please generate {ARRAY_OR_NOT}. You should use constructors " + "or static methods for the generation.\nPlease also insert " + "random data into the created instance." + ) + + argument = argument.replace("{TYPE}", type_str) + argument = argument.replace("{GENERAL_DESC}", desc_str) + + # Array handling + if "[]" in arg_type: + arg_type_no_array = arg_type.replace("[]", "").split("<")[0] + argument = argument.replace( + "{SIMPLE_TYPE}", f"an array of {arg_type_no_array} " + ) + argument = argument.replace( + "{ARRAY_OR_NOT}", + ( + f"multiple {arg_type_no_array} objects and initialise an array " + "of {arg_type_no_array} with the generated objects." 
+ ), + ) + else: + argument = argument.replace("{SIMPLE_TYPE}", f"a {arg_type}") + argument = argument.replace("{ARRAY_OR_NOT}", "the needed parameter.") + + # Generic type handling + generic_type = "" + generic_desc = "" + if self._has_generic(arg_type): + generic_type, generic_desc = self._format_generic_argument(arg_type) + + argument = argument.replace("{GENERIC_TYPE}", generic_type) + argument = argument.replace("{GENERIC_DESC}", generic_desc) + + return argument + + def _format_requirement(self, signature: str) -> str: + """Formats a requirement based on the prompt template.""" + classes = [] + + class_name = signature[1:].split("]")[0] + if self._need_import(class_name): + classes.append(class_name) + + for arg_dict in self.benchmark.params: + arg_type = arg_dict["type"].split("<")[0] + if self._need_import(arg_type): + classes.append(arg_type) + + classes = list(set(classes)) + mappings = [self._format_import_mapping(type) for type in classes] + + requirement = self._get_template(self.requirement_template_file) + requirement = requirement.replace("{IMPORT_MAPPINGS}", "\n".join(mappings)) + + harness_name = os.path.basename(self.benchmark.target_path).replace(".java", "") + if harness_name: + requirement = requirement.replace("{HARNESS_NAME}", harness_name) + else: + requirement = requirement.replace("{HARNESS_NAME}", "Fuzz") + + class_name = self.benchmark.function_name[1:].split("]")[0] + if "" in self.benchmark.function_name: + creation = ( + f"The target method is a constructor of {class_name} " + "invoke it directly with new keyword." + ) + elif self.is_jvm_static: + creation = ( + "The target method is a static method, invoke it directly " + "without creating an object." + ) + else: + creation = ( + f"You must create the {class_name} object before calling " + "the target method." + ) + requirement = requirement.replace("{STATIC_OR_INSTANCE}", creation) + + close_statement = "" + if self.need_close: + close_statement = ( + "You MUST invoke the close method of the " + f"{class_name} objects in the finally block after the target method " + "is invoked." 
+ ) + + requirement = requirement.replace("{NEED_CLOSE}", close_statement) + + return requirement + + def _format_data_filler(self) -> str: + """Formats a data_filler based on the prompt template.""" + data_filler = self._get_template(self.data_filler_template_file) + return data_filler + + def _format_arguments(self) -> str: + """Formats a list of argument descriptions.""" + argument_descriptions = [] + + for count, function_arg in enumerate(self.benchmark.params): + arg_type = function_arg["type"] + argument = self._format_argument(count, arg_type) + argument_descriptions.append(argument) + + return "" + "\n".join(argument_descriptions) + "" + + def _format_constructors(self) -> str: + """Formats a list of functions / constructors to create the object for + invoking the target method.""" + if self.is_jvm_static: + return "" + + constructors = [] + ctrs = introspector.query_introspector_matching_function_constructor_type( + self.benchmark.project, self.benchmark.return_type, False + ) + for ctr in ctrs: + constructor_sig = ctr.get("function_signature", "") + if constructor_sig: + constructors.append(f"{constructor_sig}") + exceptions = introspector.query_introspector_function_props( + ctr.get("project", ""), constructor_sig + ).get("exceptions", []) + self.exceptions.extend(exceptions) + + if constructors: + ctr_str = "\n".join(constructors) + return f"{ctr_str}" + + functions = [] + funcs = introspector.query_introspector_matching_function_constructor_type( + self.benchmark.project, self.benchmark.return_type, True + ) + for func in funcs: + is_static = func.get("is_static", False) + function_sig = func.get("function_signature", "") + if not function_sig: + continue + exceptions = introspector.query_introspector_function_props( + func.get("project", ""), function_sig + ).get("exceptions", []) + self.exceptions.extend(exceptions) + if is_static: + functions.append(f"{function_sig}") + else: + function_class = function_sig[1:].split("]")[0] + function_str = f"{function_sig}" + function_str = function_str + ( + "You MUST create an " + f"{function_class} object before calling this constructing method." 
+ "" + ) + function_str = f"{function_str}" + functions.append(function_str) + if functions: + func_str = "\n".join(functions) + return f"{func_str}" + + return "" + + def _format_source_reference(self, signature: str) -> Tuple[str, str]: + """Formats the source code reference for this target.""" + # Query for source code of the target method + source_code = introspector.query_introspector_function_source( + self.benchmark.project, signature + ) + + # Query for source code of target method callsites + xref_source_list = [] + for xref_source in introspector.query_introspector_cross_references( + self.benchmark.project, signature + ): + if xref_source: + xref_source_list.append(xref_source) + + return source_code, "\n".join(xref_source_list) + + def _format_problem(self, signature: str) -> str: + """Formats a problem based on the prompt template.""" + is_constructor = bool("" in signature) + + problem = self._get_template(self.problem_template_file) + problem = problem.replace( + "{TARGET}", self._get_template(self.target_template_file) + ) + problem = problem.replace("{SIGNATURE}", signature) + problem = problem.replace("{CLASS}", signature.split("].")[0][1:]) + problem = problem.replace("{REQUIREMENTS}", self._format_requirement(signature)) + problem = problem.replace("{ARGUMENTS}", self._format_arguments()) + problem = problem.replace("{CONSTRUCTORS}", self._format_constructors()) + problem = problem.replace("{EXCEPTIONS}", self._format_exceptions()) + + self_source, cross_source = self._format_source_reference(signature) + problem = problem.replace("{SELF_SOURCE}", self_source) + problem = problem.replace("{CROSS_SOURCE}", cross_source) + problem = problem.replace("{PROJECT_NAME}", self.benchmark.project) + problem = problem.replace("{PROJECT_URL}", self.project_url) + problem = problem.replace("{DATA_MAPPING}", self._format_data_filler()) + + if is_constructor: + problem = problem.replace("{METHOD_OR_CONSTRUCTOR}", "constructor") + else: + problem = problem.replace("{METHOD_OR_CONSTRUCTOR}", "method") + + return problem + + def _prepare_prompt(self, prompt_str: str): + """Constructs a prompt using the parameters and saves it.""" + self._prompt.add_priming(self._get_template(self.priming_template_file)) + self._prompt.add_problem(prompt_str) + + def _has_generic(self, arg: str) -> bool: + """Determine if the argument type contains generic type.""" + return ( + "<" in arg + and not arg.startswith("<") + and arg.endswith(">") + and "java.lang.Class" not in arg + and "java.lang.Object" not in arg + ) + + def _need_import(self, class_name: str) -> bool: + """Determine if the class with class_name needed to be imported.""" + return "." 
in class_name and not class_name.startswith("java.lang.") + + def _get_methods_for_simple_type(self, simple_type: str) -> str: + """Retrieve string descrbing how to generate random data of + the provided simple type.""" + simple_type_mapping = { + "int": [ + "FuzzedDataProvider::consumeInt()", + "FuzzedDataProvider::consumeInt(int, int)", + ], + "boolean": [ + "FuzzedDataProvider::consumeBoolean()", + "FuzzedDataProvider::pickValue(boolean[])", + ], + "byte": [ + "FuzzedDataProvider::consumeByte()", + "FuzzedDataProvider::consumeByte(byte,byte)", + ], + "byte[]": [ + "FuzzedDataProvider::consumeBytes(int)", + "FuzzedDataProvider::consumeRemainingAsBytes()", + ], + "short": [ + "FuzzedDataProvider::consumeShort()", + "FuzzedDataProvider::consumeShort(short,short)", + ], + "long": [ + "FuzzedDataProvider::consumeLong()", + "FuzzedDataProvider::consumeLong(long, long)", + ], + "float": [ + "FuzzedDataProvider::consumeFloat()", + "FuzzedDataProvider::consumeRegularFloat()", + "FuzzedDataProvider::consumeRegularFloat(float,float)", + "FuzzedDataProvider::consumeProbabilityFloat()", + ], + "double": [ + "FuzzedDataProvider::consumeDouble()", + "FuzzedDataProvider::consumeRegularDouble()", + "FuzzedDataProvider::consumeRegularDouble(double, double)", + "FuzzedDataProvider::consumeProbabilityDouble()", + ], + "char": [ + "FuzzedDataProvider::consumeChar()", + "FuzzedDataProvider::consumeCharNoSurrogates()", + "FuzzedDataProvider::consumeChar(char, char)", + ], + "string": [ + "FuzzedDataProvider::consumeString(int)", + "FuzzedDataProvider::consumeAsciiString(int)", + "FuzzedDataProvider::consumeRemainingAsString()", + "FuzzedDataProvider::consumeRemainingAsAsciiString()", + ], + "class": ["Object::getClass()"], + } + + # Extract simple type + simple_type = simple_type.replace("java.lang.Integer", "int") + simple_type = simple_type.replace("java.lang.Character", "char") + simple_type = simple_type.split(".")[-1].lower() + + if simple_type in simple_type_mapping: + return " or ".join(simple_type_mapping[simple_type]) + + # If the type is not found, try if it is an array of any above types + simple_type = simple_type.replace("[]", "") + return " or ".join(simple_type_mapping.get(simple_type, [])) + + def build( + self, + example_pair: list[list[str]], + project_example_content: Optional[list[list[str]]] = None, + project_context_content: Optional[dict] = None, + ) -> prompts.Prompt: + """Constructs a prompt using the templates in |self| and saves it. + Ignore target_file_type, project_example_content + and project_context_content parameters. + """ + final_problem = self._format_problem(self.benchmark.function_signature) + self._prepare_prompt(final_problem) + return self._prompt + + def build_fixer_prompt( + self, + benchmark: Benchmark, + raw_code: str, + error_desc: Optional[str], + errors: list[str], + ) -> prompts.Prompt: + """Builds a fixer prompt.""" + # Do nothing for jvm project now. + return self._prompt + + def build_triager_prompt( + self, benchmark: Benchmark, driver_code: str, crash_info: str, crash_func: dict + ) -> prompts.Prompt: + """Builds a triager prompt.""" + # Do nothing for jvm project now. 
+ return self._prompt + + def post_process_generated_code(self, generated_code: str) -> str: + """Allows prompt builder to adjust the generated code.""" + # From observation, the LLM model keeps using wrong method calls including + # FuzzedDataProvider::consumeObject() or FuzzedDataProvider::getObject() or + # FuzzedDataProvider::consumeInt(int) to generate random Object / Integer + # instance. These methods are not valid in FuzzedDataProvider. + + # The fixes here change the calling of data.consumeObject() and + # data.getObject() to data.consumeString(int) + generated_code = generated_code.replace( + "data.consumeObject()", "data.consumeString(data.remainingBytes()/2)" + ) + generated_code = generated_code.replace( + "data.getObject()", "data.consumeString(data.remainingBytes()/2)" + ) + + # The fixes here change the calling of data.consumeInt(int) to + # data.consumeInt(0, int). For example, data.consumeInt(12345) will + # be replaced by data.consumeInt(0, 12345) + for wrong_method_call in re.findall( + r"(data\.consumeInt\(([0-9]+)\))", generated_code + ): + old_method_call = wrong_method_call[0] + new_method_call = f"data.consumeInt(0, {wrong_method_call[1]})" + generated_code = generated_code.replace(old_method_call, new_method_call) + + return generated_code class DefaultRustTemplateBuilder(PromptBuilder): - """Default builder for Rust projects.""" - - def __init__(self, - model: models.LLM, - benchmark: Benchmark, - template_dir: str = DEFAULT_TEMPLATE_DIR): - super().__init__(model) - self._template_dir = template_dir - self.benchmark = benchmark - self.project_url = oss_fuzz_checkout.get_project_repository( - self.benchmark.project) - - # Load templates. - self.priming_template_file = self._find_template(template_dir, - 'rust_priming.txt') - self.problem_template_file = self._find_template(template_dir, - 'rust_problem.txt') - - def _find_template(self, template_dir: str, template_name: str) -> str: - """Finds template file based on |template_dir|.""" - preferred_template = os.path.join(template_dir, template_name) - # Use the preferred template if it exists. - if os.path.isfile(preferred_template): - return preferred_template - - # Fall back to the default template. 
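# Sketch of the template lookup shared by the builders above: a template name is
# resolved against the caller-supplied directory first and falls back to the
# default template directory when no override exists. The directory names below
# are illustrative, not the repository's actual paths.
import os

DEFAULT_TEMPLATE_DIR_EXAMPLE = "prompts/templates"


def find_template(template_dir: str, template_name: str) -> str:
    preferred = os.path.join(template_dir, template_name)
    if os.path.isfile(preferred):
        return preferred  # Project-specific override.
    return os.path.join(DEFAULT_TEMPLATE_DIR_EXAMPLE, template_name)


print(find_template("prompts/custom", "rust_priming.txt"))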
- default_template = os.path.join(DEFAULT_TEMPLATE_DIR, template_name) - return default_template - - def _get_template(self, template_file: str) -> str: - """Reads the template for prompts.""" - with open(template_file) as file: - return file.read() - - def _format_target(self, signature: str) -> str: - """Format the target function for the prompts creation.""" - target = self._get_template(self.problem_template_file) - arg_count = len(self.benchmark.params) - arg_type = [arg_dict['type'] for arg_dict in self.benchmark.params] - - target = target.replace('{FUNCTION_SIGNATURE}', signature) - target = target.replace('{ARG_COUNT}', str(arg_count)) - target = target.replace('{ARG_TYPE}', ','.join(arg_type)) - - return target - - def _format_problem(self, signature: str) -> str: - """Formats a problem based on the prompt template.""" - problem = self._format_target(signature) - - problem = problem.replace('{PROJECT_NAME}', self.benchmark.project) - problem = problem.replace('{PROJECT_URL}', self.project_url) - - return problem - - def _prepare_prompt(self, prompt_str: str): - """Constructs a prompt using the parameters and saves it.""" - self._prompt.add_priming(self._get_template(self.priming_template_file)) - self._prompt.add_problem(prompt_str) - - def build(self, - example_pair: list[list[str]], - project_example_content: Optional[list[list[str]]] = None, - project_context_content: Optional[dict] = None) -> prompts.Prompt: - """Constructs a prompt using the templates in |self| and saves it. - Ignore target_file_type, project_example_content - and project_context_content parameters. - """ - final_problem = self._format_problem(self.benchmark.function_signature) - self._prepare_prompt(final_problem) - return self._prompt - - def build_fixer_prompt(self, benchmark: Benchmark, raw_code: str, - error_desc: Optional[str], - errors: list[str]) -> prompts.Prompt: - """Builds a fixer prompt.""" - # Do nothing for rust project now. - return self._prompt - - def build_triager_prompt(self, benchmark: Benchmark, driver_code: str, - crash_info: str, crash_func: dict) -> prompts.Prompt: - """Builds a triager prompt.""" - # Do nothing for rust project now. - return self._prompt - - def post_process_generated_code(self, generated_code: str) -> str: - """Allows prompt builder to adjust the generated code.""" - # Do nothing for rust project now. - return generated_code + """Default builder for Rust projects.""" + + def __init__( + self, + model: models.LLM, + benchmark: Benchmark, + template_dir: str = DEFAULT_TEMPLATE_DIR, + ): + super().__init__(model) + self._template_dir = template_dir + self.benchmark = benchmark + self.project_url = oss_fuzz_checkout.get_project_repository( + self.benchmark.project + ) + + # Load templates. + self.priming_template_file = self._find_template( + template_dir, "rust_priming.txt" + ) + self.problem_template_file = self._find_template( + template_dir, "rust_problem.txt" + ) + + def _find_template(self, template_dir: str, template_name: str) -> str: + """Finds template file based on |template_dir|.""" + preferred_template = os.path.join(template_dir, template_name) + # Use the preferred template if it exists. + if os.path.isfile(preferred_template): + return preferred_template + + # Fall back to the default template. 
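# Sketch of the Rust problem formatting above: the problem template receives the
# target signature, the number of parameters, and a comma-joined list of parameter
# types taken from benchmark.params. The template string is a stand-in for
# rust_problem.txt.
_RUST_PROBLEM = (
    "Target: {FUNCTION_SIGNATURE}\n"
    "It takes {ARG_COUNT} argument(s) of type(s): {ARG_TYPE}\n"
)


def format_rust_target(signature: str, params: list[dict]) -> str:
    target = _RUST_PROBLEM.replace("{FUNCTION_SIGNATURE}", signature)
    target = target.replace("{ARG_COUNT}", str(len(params)))
    return target.replace("{ARG_TYPE}", ",".join(p["type"] for p in params))


print(
    format_rust_target(
        "fn parse(data: &[u8], strict: bool)", [{"type": "&[u8]"}, {"type": "bool"}]
    )
)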
+ default_template = os.path.join(DEFAULT_TEMPLATE_DIR, template_name) + return default_template + + def _get_template(self, template_file: str) -> str: + """Reads the template for prompts.""" + with open(template_file) as file: + return file.read() + + def _format_target(self, signature: str) -> str: + """Format the target function for the prompts creation.""" + target = self._get_template(self.problem_template_file) + arg_count = len(self.benchmark.params) + arg_type = [arg_dict["type"] for arg_dict in self.benchmark.params] + + target = target.replace("{FUNCTION_SIGNATURE}", signature) + target = target.replace("{ARG_COUNT}", str(arg_count)) + target = target.replace("{ARG_TYPE}", ",".join(arg_type)) + + return target + + def _format_problem(self, signature: str) -> str: + """Formats a problem based on the prompt template.""" + problem = self._format_target(signature) + + problem = problem.replace("{PROJECT_NAME}", self.benchmark.project) + problem = problem.replace("{PROJECT_URL}", self.project_url) + + return problem + + def _prepare_prompt(self, prompt_str: str): + """Constructs a prompt using the parameters and saves it.""" + self._prompt.add_priming(self._get_template(self.priming_template_file)) + self._prompt.add_problem(prompt_str) + + def build( + self, + example_pair: list[list[str]], + project_example_content: Optional[list[list[str]]] = None, + project_context_content: Optional[dict] = None, + ) -> prompts.Prompt: + """Constructs a prompt using the templates in |self| and saves it. + Ignore target_file_type, project_example_content + and project_context_content parameters. + """ + final_problem = self._format_problem(self.benchmark.function_signature) + self._prepare_prompt(final_problem) + return self._prompt + + def build_fixer_prompt( + self, + benchmark: Benchmark, + raw_code: str, + error_desc: Optional[str], + errors: list[str], + ) -> prompts.Prompt: + """Builds a fixer prompt.""" + # Do nothing for rust project now. + return self._prompt + + def build_triager_prompt( + self, benchmark: Benchmark, driver_code: str, crash_info: str, crash_func: dict + ) -> prompts.Prompt: + """Builds a triager prompt.""" + # Do nothing for rust project now. + return self._prompt + + def post_process_generated_code(self, generated_code: str) -> str: + """Allows prompt builder to adjust the generated code.""" + # Do nothing for rust project now. + return generated_code class JvmFixingBuilder(PromptBuilder): - """Prompt builder for fixing JVM harness with complication error or - to increase code coverage.""" - - def __init__(self, - model: models.LLM, - benchmark: Benchmark, - generated_harness: str, - errors: list[str], - template_dir: str = DEFAULT_TEMPLATE_DIR): - super().__init__(model) - self._template_dir = template_dir - self.benchmark = benchmark - self.generated_harness = generated_harness - self.error_str = '\n'.join(errors) - - # Load templates. - self.template_file = self._find_template(template_dir, 'jvm_fixer.txt') - - def _find_template(self, template_dir: str, template_name: str) -> str: - """Finds template file based on |template_dir|.""" - preferred_template = os.path.join(template_dir, template_name) - # Use the preferred template if it exists. - if os.path.isfile(preferred_template): - return preferred_template - # Fall back to the default template. 
- default_template = os.path.join(DEFAULT_TEMPLATE_DIR, template_name) - return default_template - - def _get_template(self, template_file: str) -> str: - """Reads the template for prompts.""" - with open(template_file) as file: - return file.read() - - def build(self, - example_pair: list[list[str]], - project_example_content: Optional[list[list[str]]] = None, - project_context_content: Optional[dict] = None) -> prompts.Prompt: - """Constructs a prompt using the templates in |self| and saves it. - Ignore target_file_type, project_example_content - and project_context_content parameters. - """ - with open(self.template_file, 'r') as f: - prompt_text = f.read() - - proj = self.benchmark.project - - # Format the repository - target_repository = oss_fuzz_checkout.get_project_repository( - self.benchmark.project) - - # Add information - prompt_text = prompt_text.replace('{TARGET_REPO}', target_repository) - prompt_text = prompt_text.replace('{HARNESS_NAME}', - self.benchmark.target_name) - prompt_text = prompt_text.replace('{GENERATED_HARNESS}', - self.generated_harness) - - # Add all public candidates to prompt - methods = introspector.query_introspector_all_public_candidates(proj) - name = [method['function_name'] for method in methods] - prompt_text = prompt_text.replace('{PUBLIC_METHODS}', ','.join(name)) - - # Add source code of all existing harnesses to prompt - source_list = [] - harnesses = introspector.query_introspector_for_harness_intrinsics(proj) - for pair in harnesses: - path = pair.get('source', '') - if path: - source = introspector.query_introspector_source_code(proj, path) - if source: - source_list.append(source) - - prompt_text = prompt_text.replace('{EXISTING_HARNESS}', - '\n---\n'.join(source_list)) - - if self.error_str: - prompt_text = prompt_text.replace('{ERRORS}', - ('There are no errors, please consider ' - 'increasing the code coverage.')) - else: - prompt_text = prompt_text.replace('{ERRORS}', self.error_str) - - self._prompt.add_problem(prompt_text) - return self._prompt - - def build_fixer_prompt(self, benchmark: Benchmark, raw_code: str, - error_desc: Optional[str], - errors: list[str]) -> prompts.Prompt: - """Builds a fixer prompt.""" - # Do nothing for jvm project now. - return self._prompt - - def build_triager_prompt(self, benchmark: Benchmark, driver_code: str, - crash_info: str, crash_func: dict) -> prompts.Prompt: - """Builds a triager prompt.""" - # Do nothing for jvm project now. - return self._prompt - - def post_process_generated_code(self, generated_code: str) -> str: - """Allows prompt builder to adjust the generated code.""" - return generated_code + """Prompt builder for fixing JVM harness with complication error or + to increase code coverage.""" + + def __init__( + self, + model: models.LLM, + benchmark: Benchmark, + generated_harness: str, + errors: list[str], + template_dir: str = DEFAULT_TEMPLATE_DIR, + ): + super().__init__(model) + self._template_dir = template_dir + self.benchmark = benchmark + self.generated_harness = generated_harness + self.error_str = "\n".join(errors) + + # Load templates. + self.template_file = self._find_template(template_dir, "jvm_fixer.txt") + + def _find_template(self, template_dir: str, template_name: str) -> str: + """Finds template file based on |template_dir|.""" + preferred_template = os.path.join(template_dir, template_name) + # Use the preferred template if it exists. + if os.path.isfile(preferred_template): + return preferred_template + # Fall back to the default template. 
+ default_template = os.path.join(DEFAULT_TEMPLATE_DIR, template_name) + return default_template + + def _get_template(self, template_file: str) -> str: + """Reads the template for prompts.""" + with open(template_file) as file: + return file.read() + + def build( + self, + example_pair: list[list[str]], + project_example_content: Optional[list[list[str]]] = None, + project_context_content: Optional[dict] = None, + ) -> prompts.Prompt: + """Constructs a prompt using the templates in |self| and saves it. + Ignore target_file_type, project_example_content + and project_context_content parameters. + """ + with open(self.template_file, "r") as f: + prompt_text = f.read() + + proj = self.benchmark.project + + # Format the repository + target_repository = oss_fuzz_checkout.get_project_repository( + self.benchmark.project + ) + + # Add information + prompt_text = prompt_text.replace("{TARGET_REPO}", target_repository) + prompt_text = prompt_text.replace("{HARNESS_NAME}", self.benchmark.target_name) + prompt_text = prompt_text.replace("{GENERATED_HARNESS}", self.generated_harness) + + # Add all public candidates to prompt + methods = introspector.query_introspector_all_public_candidates(proj) + name = [method["function_name"] for method in methods] + prompt_text = prompt_text.replace("{PUBLIC_METHODS}", ",".join(name)) + + # Add source code of all existing harnesses to prompt + source_list = [] + harnesses = introspector.query_introspector_for_harness_intrinsics(proj) + for pair in harnesses: + path = pair.get("source", "") + if path: + source = introspector.query_introspector_source_code(proj, path) + if source: + source_list.append(source) + + prompt_text = prompt_text.replace( + "{EXISTING_HARNESS}", "\n---\n".join(source_list) + ) + + if self.error_str: + prompt_text = prompt_text.replace( + "{ERRORS}", + ( + "There are no errors, please consider " + "increasing the code coverage." + ), + ) + else: + prompt_text = prompt_text.replace("{ERRORS}", self.error_str) + + self._prompt.add_problem(prompt_text) + return self._prompt + + def build_fixer_prompt( + self, + benchmark: Benchmark, + raw_code: str, + error_desc: Optional[str], + errors: list[str], + ) -> prompts.Prompt: + """Builds a fixer prompt.""" + # Do nothing for jvm project now. + return self._prompt + + def build_triager_prompt( + self, benchmark: Benchmark, driver_code: str, crash_info: str, crash_func: dict + ) -> prompts.Prompt: + """Builds a triager prompt.""" + # Do nothing for jvm project now. + return self._prompt + + def post_process_generated_code(self, generated_code: str) -> str: + """Allows prompt builder to adjust the generated code.""" + return generated_code class DefaultPythonTemplateBuilder(PromptBuilder): - """Default builder for Python projects.""" - - def __init__(self, - model: models.LLM, - benchmark: Benchmark, - template_dir: str = DEFAULT_TEMPLATE_DIR): - super().__init__(model) - self._template_dir = template_dir - self.benchmark = benchmark - self.project_url = oss_fuzz_checkout.get_project_repository( - self.benchmark.project) - - # Load templates. - self.priming_template_file = self._find_template(template_dir, - 'python_priming.txt') - self.problem_template_file = self._find_template(template_dir, - 'python_problem.txt') - - def _find_template(self, template_dir: str, template_name: str) -> str: - """Finds template file based on |template_dir|.""" - preferred_template = os.path.join(template_dir, template_name) - # Use the preferred template if it exists. 
- if os.path.isfile(preferred_template): - return preferred_template - # Fall back to the default template. - default_template = os.path.join(DEFAULT_TEMPLATE_DIR, template_name) - return default_template - - def _get_template(self, template_file: str) -> str: - """Reads the template for prompts.""" - with open(template_file) as file: - return file.read() - - def _format_target(self, signature: str) -> str: - """Format the target function for the prompts creation.""" - target = self._get_template(self.problem_template_file) - signature_split = signature.rsplit('.', 1) - - # Determine if the target is class function of instance function - if self.benchmark.params[0].get('name', '') == 'self': - arg_count = len(self.benchmark.params) - 1 - desc = ('This is an instance function. You MUST create the needed ' - f'class {signature_split[0]} before invoking the target ' - f'function {signature_split[-1]}.') - else: - arg_count = len(self.benchmark.params) - desc = 'This is a class function. You MUST invoke it directly.' - - target = target.replace('{METHOD_SIGNATURE}', signature) - target = target.replace('{PACKAGE}', signature_split[0]) - target = target.replace('{ARG_COUNT}', str(arg_count)) - target = target.replace('{CLASS_METHOD_OR_GENERAL_METHOD}', desc) - - return target - - def _format_problem(self, signature: str) -> str: - """Formats a problem based on the prompt template.""" - problem = self._format_target(signature) - - problem = problem.replace('{PROJECT_NAME}', self.benchmark.project) - problem = problem.replace('{PROJECT_URL}', self.project_url) - - return problem - - def _prepare_prompt(self, prompt_str: str): - """Constructs a prompt using the parameters and saves it.""" - self._prompt.add_priming(self._get_template(self.priming_template_file)) - self._prompt.add_problem(prompt_str) - - def build(self, - example_pair: list[list[str]], - project_example_content: Optional[list[list[str]]] = None, - project_context_content: Optional[dict] = None) -> prompts.Prompt: - """Constructs a prompt using the templates in |self| and saves it. - Ignore target_file_type, project_example_content - and project_context_content parameters. - """ - final_problem = self._format_problem(self.benchmark.function_signature) - self._prepare_prompt(final_problem) - return self._prompt - - def build_fixer_prompt(self, benchmark: Benchmark, raw_code: str, - error_desc: Optional[str], - errors: list[str]) -> prompts.Prompt: - """Builds a fixer prompt.""" - # Do nothing for python project now. - return self._prompt - - def build_triager_prompt(self, benchmark: Benchmark, driver_code: str, - crash_info: str, crash_func: dict) -> prompts.Prompt: - """Builds a triager prompt.""" - # Do nothing for python project now. - return self._prompt - - def post_process_generated_code(self, generated_code: str) -> str: - """Allows prompt builder to adjust the generated code.""" - # Do nothing for python project now. - return generated_code + """Default builder for Python projects.""" + + def __init__( + self, + model: models.LLM, + benchmark: Benchmark, + template_dir: str = DEFAULT_TEMPLATE_DIR, + ): + super().__init__(model) + self._template_dir = template_dir + self.benchmark = benchmark + self.project_url = oss_fuzz_checkout.get_project_repository( + self.benchmark.project + ) + + # Load templates. 
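Every builder in this file resolves its templates the same way: prefer template_name under the caller-supplied template_dir, and fall back to DEFAULT_TEMPLATE_DIR when the override does not exist. A standalone sketch of that lookup; the directory and file names here are placeholders, not the repository's actual paths:

import os

DEFAULT_TEMPLATE_DIR = "default_templates"  # placeholder for the module constant

def find_template(template_dir: str, template_name: str) -> str:
    """Returns the preferred template path if it exists, else the default one."""
    preferred = os.path.join(template_dir, template_name)
    if os.path.isfile(preferred):
        return preferred
    return os.path.join(DEFAULT_TEMPLATE_DIR, template_name)

# A project-specific override wins only when the file actually exists on disk.
print(find_template("my_overrides", "python_priming.txt"))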
+ self.priming_template_file = self._find_template( + template_dir, "python_priming.txt" + ) + self.problem_template_file = self._find_template( + template_dir, "python_problem.txt" + ) + + def _find_template(self, template_dir: str, template_name: str) -> str: + """Finds template file based on |template_dir|.""" + preferred_template = os.path.join(template_dir, template_name) + # Use the preferred template if it exists. + if os.path.isfile(preferred_template): + return preferred_template + # Fall back to the default template. + default_template = os.path.join(DEFAULT_TEMPLATE_DIR, template_name) + return default_template + + def _get_template(self, template_file: str) -> str: + """Reads the template for prompts.""" + with open(template_file) as file: + return file.read() + + def _format_target(self, signature: str) -> str: + """Format the target function for the prompts creation.""" + target = self._get_template(self.problem_template_file) + signature_split = signature.rsplit(".", 1) + + # Determine if the target is class function of instance function + if self.benchmark.params[0].get("name", "") == "self": + arg_count = len(self.benchmark.params) - 1 + desc = ( + "This is an instance function. You MUST create the needed " + f"class {signature_split[0]} before invoking the target " + f"function {signature_split[-1]}." + ) + else: + arg_count = len(self.benchmark.params) + desc = "This is a class function. You MUST invoke it directly." + + target = target.replace("{METHOD_SIGNATURE}", signature) + target = target.replace("{PACKAGE}", signature_split[0]) + target = target.replace("{ARG_COUNT}", str(arg_count)) + target = target.replace("{CLASS_METHOD_OR_GENERAL_METHOD}", desc) + + return target + + def _format_problem(self, signature: str) -> str: + """Formats a problem based on the prompt template.""" + problem = self._format_target(signature) + + problem = problem.replace("{PROJECT_NAME}", self.benchmark.project) + problem = problem.replace("{PROJECT_URL}", self.project_url) + + return problem + + def _prepare_prompt(self, prompt_str: str): + """Constructs a prompt using the parameters and saves it.""" + self._prompt.add_priming(self._get_template(self.priming_template_file)) + self._prompt.add_problem(prompt_str) + + def build( + self, + example_pair: list[list[str]], + project_example_content: Optional[list[list[str]]] = None, + project_context_content: Optional[dict] = None, + ) -> prompts.Prompt: + """Constructs a prompt using the templates in |self| and saves it. + Ignore target_file_type, project_example_content + and project_context_content parameters. + """ + final_problem = self._format_problem(self.benchmark.function_signature) + self._prepare_prompt(final_problem) + return self._prompt + + def build_fixer_prompt( + self, + benchmark: Benchmark, + raw_code: str, + error_desc: Optional[str], + errors: list[str], + ) -> prompts.Prompt: + """Builds a fixer prompt.""" + # Do nothing for python project now. + return self._prompt + + def build_triager_prompt( + self, benchmark: Benchmark, driver_code: str, crash_info: str, crash_func: dict + ) -> prompts.Prompt: + """Builds a triager prompt.""" + # Do nothing for python project now. + return self._prompt + + def post_process_generated_code(self, generated_code: str) -> str: + """Allows prompt builder to adjust the generated code.""" + # Do nothing for python project now. 
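DefaultPythonTemplateBuilder._format_target above distinguishes an instance function from a class function by checking whether the first benchmark parameter is named self, and lowers the reported argument count by one in that case. A small sketch of just that decision, with a made-up signature and parameter list:

# Sketch of the instance-vs-class decision in _format_target.
# The signature and params are hypothetical, not real benchmark data.
signature = "mypkg.Parser.parse"
params = [{"name": "self"}, {"name": "data"}]

package, func = signature.rsplit(".", 1)
if params and params[0].get("name", "") == "self":
    arg_count = len(params) - 1  # drop the implicit self argument
    desc = (
        "This is an instance function. You MUST create the needed "
        f"class {package} before invoking the target function {func}."
    )
else:
    arg_count = len(params)
    desc = "This is a class function. You MUST invoke it directly."

print(arg_count, desc)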
+ return generated_code class CSpecificBuilder(PromptBuilder): - """Builder specifically targeted C (and excluding C++).""" - - def __init__(self, - model: models.LLM, - benchmark: Benchmark, - template_dir: str = DEFAULT_TEMPLATE_DIR): - super().__init__(model) - self._template_dir = template_dir - self.benchmark = benchmark - - # Load templates. - self.priming_template_file = self._find_template(template_dir, - 'c-priming.txt') - - def _find_template(self, template_dir: str, template_name: str) -> str: - """Finds template file based on |template_dir|.""" - preferred_template = os.path.join(template_dir, template_name) - # Use the preferred template if it exists. - if os.path.isfile(preferred_template): - return preferred_template - # Fall back to the default template. - default_template = os.path.join(DEFAULT_TEMPLATE_DIR, template_name) - return default_template - - def _get_template(self, template_file: str) -> str: - """Reads the template for prompts.""" - with open(template_file) as file: - return file.read() - - def build(self, - example_pair: list[list[str]], - project_example_content: Optional[list[list[str]]] = None, - project_context_content: Optional[dict] = None) -> prompts.Prompt: - """Constructs a prompt using the templates in |self| and saves it.""" - - with open(self.priming_template_file, 'r') as f: - prompt_text = f.read() - - # Format the priming - target_repository = oss_fuzz_checkout.get_project_repository( - self.benchmark.project) - prompt_text = prompt_text.replace('{TARGET_REPO}', target_repository) - prompt_text = prompt_text.replace('{TARGET_FUNCTION}', - self.benchmark.function_signature) - function_source = introspector.query_introspector_function_source( - self.benchmark.project, self.benchmark.function_signature) - prompt_text = prompt_text.replace('{TARGET_FUNCTION_SOURCE_CODE}', - function_source) - - # Set header inclusion string if there are any headers. - headers_to_include = \ - introspector.query_introspector_header_files_to_include( - self.benchmark.project, self.benchmark.function_signature) - header_inclusion_string = '' - if headers_to_include: - header_inclusion_string = ', '.join(headers_to_include) - - # TODO: Programmatically select and refine the header. - prompt_text = prompt_text.replace('{TARGET_HEADER_FILES}', - header_inclusion_string) - - # Add function arg types - arg_types = introspector.query_introspector_function_debug_arg_types( - self.benchmark.project, self.benchmark.function_signature) - - arg_types_text = '' - if arg_types: - arg_types_text = 'The target function takes the following arguments:\n' - arg_types_text += '- ' + '- '.join(f'{arg}\n' for arg in arg_types) - - arg_types_text += ( - 'You must make sure the arguments passed to the ' - 'function match the types of the function. Do this by casting ' - 'appropriately.\n') - - prompt_text = prompt_text.replace('{FUNCTION_ARG_TYPES_MSG}', - arg_types_text) - - sample_cross_references = introspector.query_introspector_sample_xrefs( - self.benchmark.project, self.benchmark.function_signature) - if sample_cross_references: - additional_text = ( - 'The target function is used in various places of the target project.' 
- 'Please see the following samples of code using the target, which ' - 'you should use as inspiration for the harness to structure the code:' - '\n') - - exp_usage = 'Example usage:\n' - additional_text += exp_usage + exp_usage.join( - f'```c{elem}\n```\n' for elem in sample_cross_references) - else: - additional_text = '' - - prompt_text = prompt_text.replace('{ADDITIONAL_INFORMATION}', - additional_text) - - self._prompt.add_priming(prompt_text) - return self._prompt - - def build_fixer_prompt(self, benchmark: Benchmark, raw_code: str, - error_desc: Optional[str], - errors: list[str]) -> prompts.Prompt: - """Prepares the code-fixing prompt.""" - return self._prompt - - def build_triager_prompt(self, benchmark: Benchmark, driver_code: str, - crash_info: str, crash_func: dict) -> prompts.Prompt: - """Builds a triager prompt.""" - return self._prompt - - def post_process_generated_code(self, generated_code: str) -> str: - """Adds specific C headers we always want in the harnesses.""" - # TODO: explore if we can make this more precise, by only adding headers - # if needed. - for header in C_PROMPT_HEADERS_TO_ALWAYS_INCLUDES: - generated_code = f'#include <{header}>\n' + generated_code - return generated_code + """Builder specifically targeted C (and excluding C++).""" + + def __init__( + self, + model: models.LLM, + benchmark: Benchmark, + template_dir: str = DEFAULT_TEMPLATE_DIR, + ): + super().__init__(model) + self._template_dir = template_dir + self.benchmark = benchmark + + # Load templates. + self.priming_template_file = self._find_template(template_dir, "c-priming.txt") + + def _find_template(self, template_dir: str, template_name: str) -> str: + """Finds template file based on |template_dir|.""" + preferred_template = os.path.join(template_dir, template_name) + # Use the preferred template if it exists. + if os.path.isfile(preferred_template): + return preferred_template + # Fall back to the default template. + default_template = os.path.join(DEFAULT_TEMPLATE_DIR, template_name) + return default_template + + def _get_template(self, template_file: str) -> str: + """Reads the template for prompts.""" + with open(template_file) as file: + return file.read() + + def build( + self, + example_pair: list[list[str]], + project_example_content: Optional[list[list[str]]] = None, + project_context_content: Optional[dict] = None, + ) -> prompts.Prompt: + """Constructs a prompt using the templates in |self| and saves it.""" + + with open(self.priming_template_file, "r") as f: + prompt_text = f.read() + + # Format the priming + target_repository = oss_fuzz_checkout.get_project_repository( + self.benchmark.project + ) + prompt_text = prompt_text.replace("{TARGET_REPO}", target_repository) + prompt_text = prompt_text.replace( + "{TARGET_FUNCTION}", self.benchmark.function_signature + ) + function_source = introspector.query_introspector_function_source( + self.benchmark.project, self.benchmark.function_signature + ) + prompt_text = prompt_text.replace( + "{TARGET_FUNCTION_SOURCE_CODE}", function_source + ) + + # Set header inclusion string if there are any headers. + headers_to_include = introspector.query_introspector_header_files_to_include( + self.benchmark.project, self.benchmark.function_signature + ) + header_inclusion_string = "" + if headers_to_include: + header_inclusion_string = ", ".join(headers_to_include) + + # TODO: Programmatically select and refine the header. 
+ prompt_text = prompt_text.replace( + "{TARGET_HEADER_FILES}", header_inclusion_string + ) + + # Add function arg types + arg_types = introspector.query_introspector_function_debug_arg_types( + self.benchmark.project, self.benchmark.function_signature + ) + + arg_types_text = "" + if arg_types: + arg_types_text = "The target function takes the following arguments:\n" + arg_types_text += "- " + "- ".join(f"{arg}\n" for arg in arg_types) + + arg_types_text += ( + "You must make sure the arguments passed to the " + "function match the types of the function. Do this by casting " + "appropriately.\n" + ) + + prompt_text = prompt_text.replace("{FUNCTION_ARG_TYPES_MSG}", arg_types_text) + + sample_cross_references = introspector.query_introspector_sample_xrefs( + self.benchmark.project, self.benchmark.function_signature + ) + if sample_cross_references: + additional_text = ( + "The target function is used in various places of the target project." + "Please see the following samples of code using the target, which " + "you should use as inspiration for the harness to structure the code:" + "\n" + ) + + exp_usage = "Example usage:\n" + additional_text += exp_usage + exp_usage.join( + f"```c{elem}\n```\n" for elem in sample_cross_references + ) + else: + additional_text = "" + + prompt_text = prompt_text.replace("{ADDITIONAL_INFORMATION}", additional_text) + + self._prompt.add_priming(prompt_text) + return self._prompt + + def build_fixer_prompt( + self, + benchmark: Benchmark, + raw_code: str, + error_desc: Optional[str], + errors: list[str], + ) -> prompts.Prompt: + """Prepares the code-fixing prompt.""" + return self._prompt + + def build_triager_prompt( + self, benchmark: Benchmark, driver_code: str, crash_info: str, crash_func: dict + ) -> prompts.Prompt: + """Builds a triager prompt.""" + return self._prompt + + def post_process_generated_code(self, generated_code: str) -> str: + """Adds specific C headers we always want in the harnesses.""" + # TODO: explore if we can make this more precise, by only adding headers + # if needed. + for header in C_PROMPT_HEADERS_TO_ALWAYS_INCLUDES: + generated_code = f"#include <{header}>\n" + generated_code + return generated_code class TestToHarnessConverter(PromptBuilder): - """Builder for test-to-harness conversion.""" - - def __init__(self, - model: models.LLM, - benchmark: Benchmark, - template_dir: str = DEFAULT_TEMPLATE_DIR): - super().__init__(model) - self._template_dir = template_dir - self.benchmark = benchmark - - self.harness_source_code = introspector.query_introspector_source_code( - self.benchmark.project, self.benchmark.target_path, 0, 10000) - - self.general_jvm_imports = [ - 'import com.code_intelligence.jazzer.api.FuzzedDataProvider;' - ] - - # Load templates. 
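CSpecificBuilder.post_process_generated_code above unconditionally prepends a fixed set of C headers to every generated harness. A standalone sketch of that step; the header list is an assumed stand-in for C_PROMPT_HEADERS_TO_ALWAYS_INCLUDES, which is defined outside this hunk:

# Assumed stand-in for C_PROMPT_HEADERS_TO_ALWAYS_INCLUDES (not shown in this diff).
ALWAYS_INCLUDE_HEADERS = ["stdio.h", "stdlib.h", "stdint.h"]

def prepend_headers(generated_code: str) -> str:
    """Prepends one #include per header; the last header in the list ends up first."""
    for header in ALWAYS_INCLUDE_HEADERS:
        generated_code = f"#include <{header}>\n" + generated_code
    return generated_code

harness = "int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { return 0; }\n"
print(prepend_headers(harness))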
- self.priming_template_file = self._find_template( - template_dir, 'test_to_harness_priming.txt') - jvm_requirement_template_file = self._find_template( - template_dir, 'jvm_requirement_test_to_harness.txt') - - # Constant prompt description and text - self.language_prompt = { - 'c': - '''This is a C programming language so the harness + """Builder for test-to-harness conversion.""" + + def __init__( + self, + model: models.LLM, + benchmark: Benchmark, + template_dir: str = DEFAULT_TEMPLATE_DIR, + ): + super().__init__(model) + self._template_dir = template_dir + self.benchmark = benchmark + + self.harness_source_code = introspector.query_introspector_source_code( + self.benchmark.project, self.benchmark.target_path, 0, 10000 + ) + + self.general_jvm_imports = [ + "import com.code_intelligence.jazzer.api.FuzzedDataProvider;" + ] + + # Load templates. + self.priming_template_file = self._find_template( + template_dir, "test_to_harness_priming.txt" + ) + jvm_requirement_template_file = self._find_template( + template_dir, "jvm_requirement_test_to_harness.txt" + ) + + # Constant prompt description and text + self.language_prompt = { + "c": """This is a C programming language so the harness should be written in C. This means the harness should have the structure: int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {} Specifically, you should *not* include any `extern "C"` in the harness definition, and you should write the harness in pure C. - ''', - 'c++': - '''This is a CPP programming language so the harness + """, + "c++": """This is a CPP programming language so the harness should be written in CPP. This means the harness should have the structure: extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {} - ''', - 'jvm': - self._get_template(jvm_requirement_template_file).replace( - '{HARNESS_NAME}', self.benchmark.target_name) - } - - def _find_template(self, template_dir: str, template_name: str) -> str: - """Finds template file based on |template_dir|.""" - preferred_template = os.path.join(template_dir, template_name) - # Use the preferred template if it exists. - if os.path.isfile(preferred_template): - return preferred_template - # Fall back to the default template. - default_template = os.path.join(DEFAULT_TEMPLATE_DIR, template_name) - return default_template - - def _get_template(self, template_file: str) -> str: - """Reads the template for prompts.""" - with open(template_file) as file: - return file.read() - - def _extract_jvm_imports(self, src: str, cls: list[str]) -> list[str]: - """Extract and interpret import statements from java source.""" - - # Extract import statements - # General import statemet: import test.Test; - # Static import statement: import static test.Test.InnerTest; - # Generic import statement: import test.*; - import_pattern = r'^\s*(import\s+(static\s+)?([\w.*]+);)' - imports = re.compile(import_pattern, re.MULTILINE).findall(src) - - # Group public classes by packages - cls_map = {} - for cls_name in cls: - if '.' 
in cls_name: - package, name = cls_name.rsplit('.', 1) - if package not in cls_map: - cls_map[package] = [] - cls_map[package].append(name) - - # Generalise public classes import statements - results = set() - for package, cls_name_list in cls_map.items(): - if len(cls_name_list) >= 3: - # Generalise the import package if it has more than three items - results.add(f'import {package}.*;') - else: - # Import each class separately - for cls_name in cls_name_list: - results.add(f'import {package}.{cls_name};') - - # Retrieve other import statements for reference - others = set() - for full_import, _, cls_name in imports: - if cls_name.startswith('java'): - results.add(full_import) - elif '*' in cls_name: - package = cls_name.rstrip('.*') - if package not in cls_map: - others.add(full_import) - else: - others.add(full_import) - - self.general_jvm_imports = list(sorted(results)) - return list(sorted(others)) - - def _get_jvm_public_candidates(self, proj: str) -> list[str]: - """Helper function to retrieve list of public candidates for jvm.""" - method_set = set() - methods = introspector.query_introspector_all_public_candidates(proj) - for method in methods: - if "" not in method['function_name']: - method_set.add(method['function_name']) - return list(method_set) - - def extract_header_files(self, text): - # Include any weird macros defined that does not have any values. This - # was found empirically to be valuable. - includes_in_test = set() - for line in text.split('\n'): - if '#include' in line and 'test' not in line: - includes_in_test.add(line) - return includes_in_test - - def build(self, - example_pair: list[list[str]], - project_example_content: Optional[list[list[str]]] = None, - project_context_content: Optional[dict] = None, - target_repository: str = '', - test_source_code: str = '') -> prompts.Prompt: - """Constructs a prompt using the templates in |self| and saves it.""" - - with open(self.priming_template_file, 'r') as f: - prompt_text = f.read() - - # Format the priming - if not target_repository: - target_repository = oss_fuzz_checkout.get_project_repository( - self.benchmark.project) - if not test_source_code: - test_source_code = introspector.query_introspector_test_source( - self.benchmark.project, - self.benchmark.test_file_path.replace('//', '/')) - - prompt_text = prompt_text.replace("{TARGET_REPO}", target_repository) - prompt_text = prompt_text.replace("{TEST_SOURCE_CODE}", test_source_code) - - language_text = self.language_prompt.get(self.benchmark.language.lower(), - '') - prompt_text = prompt_text.replace('{PROGRAMMING_LANGUAGE_TEXT}', - language_text) - - if self.benchmark.language == 'jvm': - prompt_text = prompt_text.replace('{HEADER_FILE_LANG}', '') - prompt_text = prompt_text.replace('{HARNESS_HEADERS}', '') - - # Fuzz Introspector use JVM as it support other JVM languages in addition - # to Java. Currently, the logic in OSS-Fuzz-Gen is only working on Java. 
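The _extract_jvm_imports logic in this hunk groups the project's public classes by package and emits a wildcard import once a package contributes three or more classes, otherwise one explicit import per class. A standalone sketch of that grouping rule, with a made-up class list:

# Sketch of the wildcard-vs-explicit import rule in _extract_jvm_imports.
# The class names are hypothetical.
public_classes = [
    "com.example.parser.Lexer",
    "com.example.parser.Parser",
    "com.example.parser.Token",
    "com.example.util.Hex",
]

by_package: dict[str, list[str]] = {}
for cls in public_classes:
    if "." in cls:
        package, name = cls.rsplit(".", 1)
        by_package.setdefault(package, []).append(name)

imports = set()
for package, names in by_package.items():
    if len(names) >= 3:
        imports.add(f"import {package}.*;")  # generalise busy packages
    else:
        imports.update(f"import {package}.{name};" for name in names)

print(sorted(imports))
# ['import com.example.parser.*;', 'import com.example.util.Hex;']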
- prompt_text = prompt_text.replace('{PROG_LANG}', 'Java') - - # Provide list of public classes of this project - classes = introspector.query_introspector_public_classes( - self.benchmark.project) - prompt_text = prompt_text.replace('{PUBLIC_CLASSES}', ','.join(classes)) - - # Proivde sample harness code - harness_sample_text = ('There are already harnesses targeting this ' - 'project, and an example of this is:\n' - f'\n{self.harness_source_code}\n') - prompt_text = prompt_text.replace('{TARGET_SAMPLE_HARNESS}', - harness_sample_text) - - # Extract must included methods - methods = self._get_jvm_public_candidates(self.benchmark.project) - prompt_text = prompt_text.replace('{PUBLIC_METHODS}', ','.join(methods)) - - # Extract import list - other_import_list = self._extract_jvm_imports(test_source_code, classes) - prompt_text = prompt_text.replace('{IMPORT_STATEMENTS}', - '\n'.join(self.general_jvm_imports)) - prompt_text = prompt_text.replace('{OTHER_IMPORT_STATEMENTS}', - '\n'.join(other_import_list)) - else: - included_header_files = self.extract_header_files(test_source_code) - if included_header_files: - harness_included_header_files = ( - 'The following header files are used in the ' - 'test source code. Please make sure to include the same ones: ' - f'{included_header_files}') - else: - harness_included_header_files = '' - prompt_text = prompt_text.replace('{HARNESS_HEADERS}', - harness_included_header_files) - - headers_to_include = \ - introspector.query_introspector_header_files( - self.benchmark.project) - if headers_to_include: - header_inclusion_string = '\n' - header_inclusion_string += ''.join( - f'{h}\n' for h in headers_to_include) - - header_inclusion_string += '\n' - header_inclusion_string = ( - 'The following header files exist in the project source code. ' - 'If the harness you create needs any header files make sure ' - 'they are in the list:\n' - f'{header_inclusion_string}') - else: - header_inclusion_string = '' - prompt_text = prompt_text.replace('{HEADER_FILE_LANG}', - header_inclusion_string) - prompt_text = prompt_text.replace('{PROG_LANG}', self.benchmark.language) - - harness_sample_text = ('There are already harnesses targeting this ' - 'project, and an example of this is:\n' - f'{self.harness_source_code}') - prompt_text = prompt_text.replace('{TARGET_SAMPLE_HARNESS}', - harness_sample_text) - - self._prompt.add_priming(prompt_text) - return self._prompt - - def build_fixer_prompt(self, benchmark: Benchmark, raw_code: str, - error_desc: Optional[str], - errors: list[str]) -> prompts.Prompt: - """Prepares the code-fixing prompt.""" - return self._prompt - - def build_triager_prompt(self, benchmark: Benchmark, driver_code: str, - crash_info: str, crash_func: dict) -> prompts.Prompt: - """Builds a triager prompt.""" - return self._prompt - - def post_process_generated_code(self, generated_code: str) -> str: - """Adds specific C headers we always want in the harnesses for C/C++. 
- Add general import statements and remove unnecessary statments for JVM""" - if self.benchmark.language.lower() == 'jvm': - # For JVM - # Remove assert and out statements - fixed_code = [] - prefixes = ['assert', 'System.out'] - for line in generated_code.split('\n'): - if not any(line.lstrip().startswith(prefix) for prefix in prefixes): - fixed_code.append(line) - - # Add general import statements - import_str = '\n'.join(self.general_jvm_imports) - generated_code = '\n'.join(fixed_code) - generated_code = f'{import_str}\n{generated_code}' - else: - # For C/C++ - for header in C_PROMPT_HEADERS_TO_ALWAYS_INCLUDES: - generated_code = f'#include <{header}>\n{generated_code}' - generated_code += '\n' - if self.benchmark.language.lower() == 'c': - generated_code = generated_code.replace( - 'extern "C" int LLVMFuzzerTestOneInput', - 'int LLVMFuzzerTestOneInput') - - return generated_code + """, + "jvm": self._get_template(jvm_requirement_template_file).replace( + "{HARNESS_NAME}", self.benchmark.target_name + ), + } + + def _find_template(self, template_dir: str, template_name: str) -> str: + """Finds template file based on |template_dir|.""" + preferred_template = os.path.join(template_dir, template_name) + # Use the preferred template if it exists. + if os.path.isfile(preferred_template): + return preferred_template + # Fall back to the default template. + default_template = os.path.join(DEFAULT_TEMPLATE_DIR, template_name) + return default_template + + def _get_template(self, template_file: str) -> str: + """Reads the template for prompts.""" + with open(template_file) as file: + return file.read() + + def _extract_jvm_imports(self, src: str, cls: list[str]) -> list[str]: + """Extract and interpret import statements from java source.""" + + # Extract import statements + # General import statemet: import test.Test; + # Static import statement: import static test.Test.InnerTest; + # Generic import statement: import test.*; + import_pattern = r"^\s*(import\s+(static\s+)?([\w.*]+);)" + imports = re.compile(import_pattern, re.MULTILINE).findall(src) + + # Group public classes by packages + cls_map = {} + for cls_name in cls: + if "." in cls_name: + package, name = cls_name.rsplit(".", 1) + if package not in cls_map: + cls_map[package] = [] + cls_map[package].append(name) + + # Generalise public classes import statements + results = set() + for package, cls_name_list in cls_map.items(): + if len(cls_name_list) >= 3: + # Generalise the import package if it has more than three items + results.add(f"import {package}.*;") + else: + # Import each class separately + for cls_name in cls_name_list: + results.add(f"import {package}.{cls_name};") + + # Retrieve other import statements for reference + others = set() + for full_import, _, cls_name in imports: + if cls_name.startswith("java"): + results.add(full_import) + elif "*" in cls_name: + package = cls_name.rstrip(".*") + if package not in cls_map: + others.add(full_import) + else: + others.add(full_import) + + self.general_jvm_imports = list(sorted(results)) + return list(sorted(others)) + + def _get_jvm_public_candidates(self, proj: str) -> list[str]: + """Helper function to retrieve list of public candidates for jvm.""" + method_set = set() + methods = introspector.query_introspector_all_public_candidates(proj) + for method in methods: + if "" not in method["function_name"]: + method_set.add(method["function_name"]) + return list(method_set) + + def extract_header_files(self, text): + # Include any weird macros defined that does not have any values. 
This + # was found empirically to be valuable. + includes_in_test = set() + for line in text.split("\n"): + if "#include" in line and "test" not in line: + includes_in_test.add(line) + return includes_in_test + + def build( + self, + example_pair: list[list[str]], + project_example_content: Optional[list[list[str]]] = None, + project_context_content: Optional[dict] = None, + target_repository: str = "", + test_source_code: str = "", + ) -> prompts.Prompt: + """Constructs a prompt using the templates in |self| and saves it.""" + + with open(self.priming_template_file, "r") as f: + prompt_text = f.read() + + # Format the priming + if not target_repository: + target_repository = oss_fuzz_checkout.get_project_repository( + self.benchmark.project + ) + if not test_source_code: + test_source_code = introspector.query_introspector_test_source( + self.benchmark.project, self.benchmark.test_file_path.replace("//", "/") + ) + + prompt_text = prompt_text.replace("{TARGET_REPO}", target_repository) + prompt_text = prompt_text.replace("{TEST_SOURCE_CODE}", test_source_code) + + language_text = self.language_prompt.get(self.benchmark.language.lower(), "") + prompt_text = prompt_text.replace("{PROGRAMMING_LANGUAGE_TEXT}", language_text) + + if self.benchmark.language == "jvm": + prompt_text = prompt_text.replace("{HEADER_FILE_LANG}", "") + prompt_text = prompt_text.replace("{HARNESS_HEADERS}", "") + + # Fuzz Introspector use JVM as it support other JVM languages in addition + # to Java. Currently, the logic in OSS-Fuzz-Gen is only working on Java. + prompt_text = prompt_text.replace("{PROG_LANG}", "Java") + + # Provide list of public classes of this project + classes = introspector.query_introspector_public_classes( + self.benchmark.project + ) + prompt_text = prompt_text.replace("{PUBLIC_CLASSES}", ",".join(classes)) + + # Proivde sample harness code + harness_sample_text = ( + "There are already harnesses targeting this " + "project, and an example of this is:\n" + f"\n{self.harness_source_code}\n" + ) + prompt_text = prompt_text.replace( + "{TARGET_SAMPLE_HARNESS}", harness_sample_text + ) + + # Extract must included methods + methods = self._get_jvm_public_candidates(self.benchmark.project) + prompt_text = prompt_text.replace("{PUBLIC_METHODS}", ",".join(methods)) + + # Extract import list + other_import_list = self._extract_jvm_imports(test_source_code, classes) + prompt_text = prompt_text.replace( + "{IMPORT_STATEMENTS}", "\n".join(self.general_jvm_imports) + ) + prompt_text = prompt_text.replace( + "{OTHER_IMPORT_STATEMENTS}", "\n".join(other_import_list) + ) + else: + included_header_files = self.extract_header_files(test_source_code) + if included_header_files: + harness_included_header_files = ( + "The following header files are used in the " + "test source code. Please make sure to include the same ones: " + f"{included_header_files}" + ) + else: + harness_included_header_files = "" + prompt_text = prompt_text.replace( + "{HARNESS_HEADERS}", harness_included_header_files + ) + + headers_to_include = introspector.query_introspector_header_files( + self.benchmark.project + ) + if headers_to_include: + header_inclusion_string = "\n" + header_inclusion_string += "".join( + f"{h}\n" for h in headers_to_include + ) + + header_inclusion_string += "\n" + header_inclusion_string = ( + "The following header files exist in the project source code. 
" + "If the harness you create needs any header files make sure " + "they are in the list:\n" + f"{header_inclusion_string}" + ) + else: + header_inclusion_string = "" + prompt_text = prompt_text.replace( + "{HEADER_FILE_LANG}", header_inclusion_string + ) + prompt_text = prompt_text.replace("{PROG_LANG}", self.benchmark.language) + + harness_sample_text = ( + "There are already harnesses targeting this " + "project, and an example of this is:\n" + f"{self.harness_source_code}" + ) + prompt_text = prompt_text.replace( + "{TARGET_SAMPLE_HARNESS}", harness_sample_text + ) + + self._prompt.add_priming(prompt_text) + return self._prompt + + def build_fixer_prompt( + self, + benchmark: Benchmark, + raw_code: str, + error_desc: Optional[str], + errors: list[str], + ) -> prompts.Prompt: + """Prepares the code-fixing prompt.""" + return self._prompt + + def build_triager_prompt( + self, benchmark: Benchmark, driver_code: str, crash_info: str, crash_func: dict + ) -> prompts.Prompt: + """Builds a triager prompt.""" + return self._prompt + + def post_process_generated_code(self, generated_code: str) -> str: + """Adds specific C headers we always want in the harnesses for C/C++. + Add general import statements and remove unnecessary statments for JVM""" + if self.benchmark.language.lower() == "jvm": + # For JVM + # Remove assert and out statements + fixed_code = [] + prefixes = ["assert", "System.out"] + for line in generated_code.split("\n"): + if not any(line.lstrip().startswith(prefix) for prefix in prefixes): + fixed_code.append(line) + + # Add general import statements + import_str = "\n".join(self.general_jvm_imports) + generated_code = "\n".join(fixed_code) + generated_code = f"{import_str}\n{generated_code}" + else: + # For C/C++ + for header in C_PROMPT_HEADERS_TO_ALWAYS_INCLUDES: + generated_code = f"#include <{header}>\n{generated_code}" + generated_code += "\n" + if self.benchmark.language.lower() == "c": + generated_code = generated_code.replace( + 'extern "C" int LLVMFuzzerTestOneInput', + "int LLVMFuzzerTestOneInput", + ) + + return generated_code diff --git a/llm_toolkit/prompts.py b/llm_toolkit/prompts.py index 6459045b49..08ed8db9f6 100644 --- a/llm_toolkit/prompts.py +++ b/llm_toolkit/prompts.py @@ -23,175 +23,184 @@ class Prompt: - """Base prompt.""" + """Base prompt.""" - def __init__(self, initial=None): - """Constructor.""" + def __init__(self, initial=None): + """Constructor.""" - @abstractmethod - def append(self, text: str, to_existing: bool = False) -> None: - """Appends to the formatted prompt.""" + @abstractmethod + def append(self, text: str, to_existing: bool = False) -> None: + """Appends to the formatted prompt.""" - @abstractmethod - def get(self) -> Any: - """Gets the final formatted prompt.""" + @abstractmethod + def get(self) -> Any: + """Gets the final formatted prompt.""" - @abstractmethod - def gettext(self) -> Any: - """Gets the final formatted prompt in plain text.""" + @abstractmethod + def gettext(self) -> Any: + """Gets the final formatted prompt in plain text.""" - @abstractmethod - def create_prompt_piece(self, content: str, role: str) -> Any: - """Creates prompt based on the |content| and |role|.""" + @abstractmethod + def create_prompt_piece(self, content: str, role: str) -> Any: + """Creates prompt based on the |content| and |role|.""" - @abstractmethod - def add_priming(self, priming_content: str) -> None: - """Adds |priming_content| to prompt.""" + @abstractmethod + def add_priming(self, priming_content: str) -> None: + """Adds |priming_content| to 
prompt.""" - @abstractmethod - def add_problem(self, problem_content: str) -> None: - """Adds |problem_content| to prompt.""" + @abstractmethod + def add_problem(self, problem_content: str) -> None: + """Adds |problem_content| to prompt.""" - @abstractmethod - def add_solution(self, solution_content: str) -> None: - """Adds |solution_content| to prompt.""" + @abstractmethod + def add_solution(self, solution_content: str) -> None: + """Adds |solution_content| to prompt.""" - @abstractmethod - def save(self, location: str) -> None: - """Saves the prompt to a filelocation.""" + @abstractmethod + def save(self, location: str) -> None: + """Saves the prompt to a filelocation.""" class TextPrompt(Prompt): - """Text-style prompts.""" + """Text-style prompts.""" - def __init__(self, initial=None): - if not initial: - initial = '' + def __init__(self, initial=None): + if not initial: + initial = "" - self._text = initial + self._text = initial - def append(self, text: str, to_existing: bool = False) -> None: - """Appends the final formatted prompt.""" - # TextPrompt only got one text element, ignoring to_existing flag - self._text += text + def append(self, text: str, to_existing: bool = False) -> None: + """Appends the final formatted prompt.""" + # TextPrompt only got one text element, ignoring to_existing flag + self._text += text - def get(self) -> Any: - """Gets the final formatted prompt.""" - return self._text + def get(self) -> Any: + """Gets the final formatted prompt.""" + return self._text - def gettext(self) -> Any: - """Gets the final formatted prompt in plain text.""" - return self.get() + def gettext(self) -> Any: + """Gets the final formatted prompt in plain text.""" + return self.get() - def add_priming(self, priming_content: str) -> None: - """Constructs the prompt priming in the required format.""" - self._text += f'{priming_content}\n' + def add_priming(self, priming_content: str) -> None: + """Constructs the prompt priming in the required format.""" + self._text += f"{priming_content}\n" - def add_problem(self, problem_content: str) -> None: - """Constructs the prompt problem in the required format.""" - self._text += f'{problem_content}\n' + def add_problem(self, problem_content: str) -> None: + """Constructs the prompt problem in the required format.""" + self._text += f"{problem_content}\n" - def add_solution(self, solution_content: str) -> None: - """Constructs the prompt solution in the required format.""" - self._text += f'{solution_content}\n' + def add_solution(self, solution_content: str) -> None: + """Constructs the prompt solution in the required format.""" + self._text += f"{solution_content}\n" - def create_prompt_piece(self, content: str, role: str) -> Any: - """Returns a prompt piece in the format wanted by Google.""" - # Ignore role, just return content - del role - # TODO(Dongge): Use role as XML tags. - return content + def create_prompt_piece(self, content: str, role: str) -> Any: + """Returns a prompt piece in the format wanted by Google.""" + # Ignore role, just return content + del role + # TODO(Dongge): Use role as XML tags. 
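TextPrompt, shown above in both the old and the reformatted layout, simply concatenates priming, problem, and solution sections with trailing newlines, so the plain-text form is the prompt itself. A short usage sketch, assuming the repository root is on PYTHONPATH so llm_toolkit.prompts and its dependencies import cleanly:

from llm_toolkit.prompts import TextPrompt

prompt = TextPrompt()
prompt.add_priming("You are a security engineer writing fuzz harnesses.")
prompt.add_problem("Write a libFuzzer harness for parse_header().")

# For TextPrompt, get() and gettext() return the same concatenated string.
assert prompt.get() == prompt.gettext()
print(prompt.gettext())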
+ return content - def save(self, location: str) -> None: - """Saves the prompt to a filelocation.""" - with open(location, 'w+') as prompt_file: - prompt_file.write(self.get()) + def save(self, location: str) -> None: + """Saves the prompt to a filelocation.""" + with open(location, "w+") as prompt_file: + prompt_file.write(self.get()) class OpenAIPrompt(Prompt): - """OpenAI style structured prompt.""" - - def __init__(self, initial=None): - if not initial: - initial = [] - - self._prompt = initial - - def get(self) -> Any: - """Gets the final formatted prompt.""" - return self._prompt - - def gettext(self) -> str: - """Gets the final formatted prompt in plain text.""" - result = '' - for item in self.get(): - result = (f'{result}\n{item.get("role", "Unknown")}:' - f'\n{item.get("content", "")}') - - return result - - def add_priming(self, priming_content: str) -> None: - """Constructs the prompt priming in the required format.""" - if not priming_content: - logger.warning('Content is empty, skipping the prompt append process') - return - - self._prompt.append({ - 'role': 'system', - 'content': priming_content, - }) - - def add_problem(self, problem_content: str) -> None: - """Constructs the prompt problem in the required format.""" - if not problem_content: - logger.warning('Content is empty, skipping the prompt append process') - return - - self._prompt.append({ - 'role': 'user', - 'content': problem_content, - }) - - def add_solution(self, solution_content: str) -> None: - """Constructs the prompt solution in the required format.""" - if not solution_content: - logger.warning('Content is empty, skipping the prompt append process') - return - - self._prompt.append({ - 'role': 'assistant', - 'content': solution_content, - }) - - def create_prompt_piece(self, content: str, role: str) -> Any: - """Returns a prompt piece in the format wanted by OpenAI.""" - # TODO(mihaimaruseac): We might want to consider stripping the XML tags - # here? The roles kind of simulate them. 
- if not content or not role: - logger.warning('Content or role is empty, ' - 'skipping the prompt append process') - return [] - - return [{'role': role, 'content': content}] - - def save(self, location: str) -> None: - """Saves the prompt to a filelocation.""" - with open(location, 'w+') as prompt_file: - json.dump(self._prompt, prompt_file) - - def append(self, text: str, to_existing: bool = False) -> None: - """Appends to the formatted prompt.""" - if to_existing and self._prompt: - # With to_existing flag, attach the string to the original content - # of the existing prompt - self._prompt[-1]['content'] += text - elif self._prompt: - # With no to_existing flag, append a new prompt with role user - self.add_problem(text) - else: - # There are no prompt exists, append the text as priming prompt - self.add_priming(text) + """OpenAI style structured prompt.""" + + def __init__(self, initial=None): + if not initial: + initial = [] + + self._prompt = initial + + def get(self) -> Any: + """Gets the final formatted prompt.""" + return self._prompt + + def gettext(self) -> str: + """Gets the final formatted prompt in plain text.""" + result = "" + for item in self.get(): + result = ( + f'{result}\n{item.get("role", "Unknown")}:' + f'\n{item.get("content", "")}' + ) + + return result + + def add_priming(self, priming_content: str) -> None: + """Constructs the prompt priming in the required format.""" + if not priming_content: + logger.warning("Content is empty, skipping the prompt append process") + return + + self._prompt.append( + { + "role": "system", + "content": priming_content, + } + ) + + def add_problem(self, problem_content: str) -> None: + """Constructs the prompt problem in the required format.""" + if not problem_content: + logger.warning("Content is empty, skipping the prompt append process") + return + + self._prompt.append( + { + "role": "user", + "content": problem_content, + } + ) + + def add_solution(self, solution_content: str) -> None: + """Constructs the prompt solution in the required format.""" + if not solution_content: + logger.warning("Content is empty, skipping the prompt append process") + return + + self._prompt.append( + { + "role": "assistant", + "content": solution_content, + } + ) + + def create_prompt_piece(self, content: str, role: str) -> Any: + """Returns a prompt piece in the format wanted by OpenAI.""" + # TODO(mihaimaruseac): We might want to consider stripping the XML tags + # here? The roles kind of simulate them. 
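OpenAIPrompt stores the same sections as role-tagged chat messages rather than concatenated text: priming becomes a system message, problems become user messages, and solutions become assistant messages. A short usage sketch, again assuming llm_toolkit.prompts is importable from the repository root:

from llm_toolkit.prompts import OpenAIPrompt

prompt = OpenAIPrompt()
prompt.add_priming("You are a security engineer writing fuzz harnesses.")
prompt.add_problem("Write a libFuzzer harness for parse_header().")
prompt.add_solution('extern "C" int LLVMFuzzerTestOneInput(...) { return 0; }')

# get() returns the OpenAI-style message list:
# [{'role': 'system', ...}, {'role': 'user', ...}, {'role': 'assistant', ...}]
for message in prompt.get():
    print(message["role"], "->", message["content"][:40])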
+ if not content or not role: + logger.warning( + "Content or role is empty, " "skipping the prompt append process" + ) + return [] + + return [{"role": role, "content": content}] + + def save(self, location: str) -> None: + """Saves the prompt to a filelocation.""" + with open(location, "w+") as prompt_file: + json.dump(self._prompt, prompt_file) + + def append(self, text: str, to_existing: bool = False) -> None: + """Appends to the formatted prompt.""" + if to_existing and self._prompt: + # With to_existing flag, attach the string to the original content + # of the existing prompt + self._prompt[-1]["content"] += text + elif self._prompt: + # With no to_existing flag, append a new prompt with role user + self.add_problem(text) + else: + # There are no prompt exists, append the text as priming prompt + self.add_priming(text) class ClaudePrompt(OpenAIPrompt): - """Claude style structured prompt.""" + """Claude style structured prompt.""" diff --git a/logger.py b/logger.py index 7ce2fb2dd1..6ca97bf6a7 100644 --- a/logger.py +++ b/logger.py @@ -25,185 +25,206 @@ from results import Result, RunResult, TrialResult -FINAL_RESULT_JSON = 'result.json' +FINAL_RESULT_JSON = "result.json" class CustomLoggerAdapter(logging.LoggerAdapter): - """A note-taker to log and record experiment status, key info, and final - results.""" - - def process(self, msg, kwargs): - # Combine 'extra' dictionaries and modify the message - kwargs['extra'] = {**(self.extra or {}), **(kwargs.get('extra') or {})} - return msg, kwargs - - def write_to_file(self, - file_path: str, - file_content: str, - mode: str = 'a') -> None: - """Writes the |file_content| into a local |file_path|.""" - with open(file_path, mode) as file: - file.writelines(file_content) - - def write_fuzz_target(self, result: Result) -> None: - """Writes fuzz target.""" - fuzz_target_path = os.path.join(result.work_dirs.fuzz_targets, - f'{result.trial:02d}.fuzz_target') - self.write_to_file(fuzz_target_path, result.fuzz_target_source, 'w') - - def write_build_script(self, result: Result) -> None: - """Writes build script.""" - build_script_path = os.path.join(result.work_dirs.fuzz_targets, - f'{result.trial:02d}.build_script') - self.write_to_file(build_script_path, result.build_script_source, 'w') - - def write_result(self, - result_status_dir: str, - result: TrialResult, - finished: bool = False) -> None: - """Writes the final result into JSON for report generation.""" - trial_result_dir = os.path.join(result_status_dir, f'{result.trial:02d}') - os.makedirs(trial_result_dir, exist_ok=True) - with open(os.path.join(trial_result_dir, FINAL_RESULT_JSON), 'w') as f: - json.dump(result.to_dict() | {'finished': finished}, f) - - def write_chat_history(self, result: Result) -> None: - """Writes chat history.""" - # TODO(dongge): Find a proper way to write this. 
- trial_result_dir = os.path.join(result.work_dirs.status, - f'{result.trial:02d}') - os.makedirs(trial_result_dir, exist_ok=True) - chat_history_path = os.path.join(trial_result_dir, 'log.txt') - chat_history = '\n'.join( - f'\n\n\n************************{agent_name}************************\n' - f'{chat_history}\n' - for agent_name, chat_history in result.chat_history.items()) - self.write_to_file(chat_history_path, chat_history) - - def download_gcs_file(self, local_path: str, gs_url: str) -> bool: - """Downloads a file from Google Cloud storage to a local file.""" - parsed_url = urlparse(gs_url) - if parsed_url.scheme != "gs": - logging.error("URL must start with 'gs://': %s", parsed_url) - - bucket_name = parsed_url.netloc - blob_name = parsed_url.path.lstrip("/") - - client = storage.Client() - bucket = client.bucket(bucket_name) - blob = bucket.blob(blob_name) - - if blob.exists(): - # Download blob to a temporary file - with tempfile.NamedTemporaryFile(delete=False) as tmp: - tmp_path = tmp.name - blob.download_to_filename(tmp_path) - # Append the temporary file's content to the local file - with open(tmp_path, 'rb') as tmp_file, open(local_path, - 'ab') as local_file: - local_file.write(tmp_file.read()) - - os.remove(tmp_path) - return True - return False - - def download_run_log(self, result: RunResult) -> None: - local_run_log_path = os.path.join(result.work_dirs.run_logs, - f'{result.trial:02d}.log') - if self.download_gcs_file(local_run_log_path, result.run_log): - info('Downloading cloud run log: %s to %s', - result.log_path, - local_run_log_path, - trial=result.trial) - else: - warning('Cloud run log gsc file does not exit: %s to %s', - result.log_path, - local_run_log_path, - trial=result.trial) - - -def debug(msg: object, - *args: object, - trial: int, - exc_info=None, - stack_info: bool = False, - stacklevel: int = 1, - extra: Mapping[str, object] | None = None, - **kwargs: object) -> None: - return get_trial_logger(trial=trial).debug(msg, - *args, - exc_info=exc_info, - stack_info=stack_info, - stacklevel=stacklevel, - extra=extra, - **kwargs) - - -def info(msg: object, - *args: object, - trial: int, - exc_info=None, - stack_info: bool = False, - stacklevel: int = 1, - extra: Mapping[str, object] | None = None, - **kwargs: object) -> None: - return get_trial_logger(trial=trial).info(msg, - *args, - exc_info=exc_info, - stack_info=stack_info, - stacklevel=stacklevel, - extra=extra, - **kwargs) - - -def warning(msg: object, - *args: object, - trial: int, - exc_info=None, - stack_info: bool = False, - stacklevel: int = 1, - extra: Mapping[str, object] | None = None, - **kwargs: object) -> None: - return get_trial_logger(trial=trial).warning(msg, - *args, - exc_info=exc_info, - stack_info=stack_info, - stacklevel=stacklevel, - extra=extra, - **kwargs) - - -def error(msg: object, - *args: object, - trial: int, - exc_info=None, - stack_info: bool = False, - stacklevel: int = 1, - extra: Mapping[str, object] | None = None, - **kwargs: object) -> None: - return get_trial_logger(trial=trial).error(msg, - *args, - exc_info=exc_info, - stack_info=stack_info, - stacklevel=stacklevel, - extra=extra, - **kwargs) - - -def get_trial_logger(name: str = __name__, - trial: int = 0, - level=logging.DEBUG) -> CustomLoggerAdapter: - """Sets up or retrieves a thread-local CustomLoggerAdapter for each thread.""" - logger = logging.getLogger(name) - if not logger.handlers: - formatter = logging.Formatter( - fmt=('%(asctime)s [Trial ID: %(trial)02d] %(levelname)s ' - '[%(module)s.%(funcName)s]: 
%(message)s'), - datefmt='%Y-%m-%d %H:%M:%S') - handler = logging.StreamHandler() - handler.setFormatter(formatter) - logger.addHandler(handler) - logger.setLevel(level) - logger.propagate = False - - return CustomLoggerAdapter(logger, {'trial': trial}) + """A note-taker to log and record experiment status, key info, and final + results.""" + + def process(self, msg, kwargs): + # Combine 'extra' dictionaries and modify the message + kwargs["extra"] = {**(self.extra or {}), **(kwargs.get("extra") or {})} + return msg, kwargs + + def write_to_file(self, file_path: str, file_content: str, mode: str = "a") -> None: + """Writes the |file_content| into a local |file_path|.""" + with open(file_path, mode) as file: + file.writelines(file_content) + + def write_fuzz_target(self, result: Result) -> None: + """Writes fuzz target.""" + fuzz_target_path = os.path.join( + result.work_dirs.fuzz_targets, f"{result.trial:02d}.fuzz_target" + ) + self.write_to_file(fuzz_target_path, result.fuzz_target_source, "w") + + def write_build_script(self, result: Result) -> None: + """Writes build script.""" + build_script_path = os.path.join( + result.work_dirs.fuzz_targets, f"{result.trial:02d}.build_script" + ) + self.write_to_file(build_script_path, result.build_script_source, "w") + + def write_result( + self, result_status_dir: str, result: TrialResult, finished: bool = False + ) -> None: + """Writes the final result into JSON for report generation.""" + trial_result_dir = os.path.join(result_status_dir, f"{result.trial:02d}") + os.makedirs(trial_result_dir, exist_ok=True) + with open(os.path.join(trial_result_dir, FINAL_RESULT_JSON), "w") as f: + json.dump(result.to_dict() | {"finished": finished}, f) + + def write_chat_history(self, result: Result) -> None: + """Writes chat history.""" + # TODO(dongge): Find a proper way to write this. 
+ trial_result_dir = os.path.join(result.work_dirs.status, f"{result.trial:02d}") + os.makedirs(trial_result_dir, exist_ok=True) + chat_history_path = os.path.join(trial_result_dir, "log.txt") + chat_history = "\n".join( + f"\n\n\n************************{agent_name}************************\n" + f"{chat_history}\n" + for agent_name, chat_history in result.chat_history.items() + ) + self.write_to_file(chat_history_path, chat_history) + + def download_gcs_file(self, local_path: str, gs_url: str) -> bool: + """Downloads a file from Google Cloud storage to a local file.""" + parsed_url = urlparse(gs_url) + if parsed_url.scheme != "gs": + logging.error("URL must start with 'gs://': %s", parsed_url) + + bucket_name = parsed_url.netloc + blob_name = parsed_url.path.lstrip("/") + + client = storage.Client() + bucket = client.bucket(bucket_name) + blob = bucket.blob(blob_name) + + if blob.exists(): + # Download blob to a temporary file + with tempfile.NamedTemporaryFile(delete=False) as tmp: + tmp_path = tmp.name + blob.download_to_filename(tmp_path) + # Append the temporary file's content to the local file + with open(tmp_path, "rb") as tmp_file, open(local_path, "ab") as local_file: + local_file.write(tmp_file.read()) + + os.remove(tmp_path) + return True + return False + + def download_run_log(self, result: RunResult) -> None: + local_run_log_path = os.path.join( + result.work_dirs.run_logs, f"{result.trial:02d}.log" + ) + if self.download_gcs_file(local_run_log_path, result.run_log): + info( + "Downloading cloud run log: %s to %s", + result.log_path, + local_run_log_path, + trial=result.trial, + ) + else: + warning( + "Cloud run log gsc file does not exit: %s to %s", + result.log_path, + local_run_log_path, + trial=result.trial, + ) + + +def debug( + msg: object, + *args: object, + trial: int, + exc_info=None, + stack_info: bool = False, + stacklevel: int = 1, + extra: Mapping[str, object] | None = None, + **kwargs: object, +) -> None: + return get_trial_logger(trial=trial).debug( + msg, + *args, + exc_info=exc_info, + stack_info=stack_info, + stacklevel=stacklevel, + extra=extra, + **kwargs, + ) + + +def info( + msg: object, + *args: object, + trial: int, + exc_info=None, + stack_info: bool = False, + stacklevel: int = 1, + extra: Mapping[str, object] | None = None, + **kwargs: object, +) -> None: + return get_trial_logger(trial=trial).info( + msg, + *args, + exc_info=exc_info, + stack_info=stack_info, + stacklevel=stacklevel, + extra=extra, + **kwargs, + ) + + +def warning( + msg: object, + *args: object, + trial: int, + exc_info=None, + stack_info: bool = False, + stacklevel: int = 1, + extra: Mapping[str, object] | None = None, + **kwargs: object, +) -> None: + return get_trial_logger(trial=trial).warning( + msg, + *args, + exc_info=exc_info, + stack_info=stack_info, + stacklevel=stacklevel, + extra=extra, + **kwargs, + ) + + +def error( + msg: object, + *args: object, + trial: int, + exc_info=None, + stack_info: bool = False, + stacklevel: int = 1, + extra: Mapping[str, object] | None = None, + **kwargs: object, +) -> None: + return get_trial_logger(trial=trial).error( + msg, + *args, + exc_info=exc_info, + stack_info=stack_info, + stacklevel=stacklevel, + extra=extra, + **kwargs, + ) + + +def get_trial_logger( + name: str = __name__, trial: int = 0, level=logging.DEBUG +) -> CustomLoggerAdapter: + """Sets up or retrieves a thread-local CustomLoggerAdapter for each thread.""" + logger = logging.getLogger(name) + if not logger.handlers: + formatter = logging.Formatter( + fmt=( + 
"%(asctime)s [Trial ID: %(trial)02d] %(levelname)s " + "[%(module)s.%(funcName)s]: %(message)s" + ), + datefmt="%Y-%m-%d %H:%M:%S", + ) + handler = logging.StreamHandler() + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.setLevel(level) + logger.propagate = False + + return CustomLoggerAdapter(logger, {"trial": trial}) diff --git a/pipeline.py b/pipeline.py index 13834ebbe4..c6c781c395 100644 --- a/pipeline.py +++ b/pipeline.py @@ -23,8 +23,8 @@ from stage.writing_stage import WritingStage -class Pipeline(): - """The fuzzing main pipeline, consisting of three iterative stages: +class Pipeline: + """The fuzzing main pipeline, consisting of three iterative stages: 1. Writing stage generates or refines the fuzz target and its associated build script to improve code coverage and enhance bug-finding capabilities for the function under test. @@ -35,123 +35,157 @@ class Pipeline(): writing stage in the next iteration. """ - def __init__(self, - args: argparse.Namespace, - trial: int, - writing_stage_agents: Optional[list[BaseAgent]] = None, - execution_stage_agents: Optional[list[BaseAgent]] = None, - analysis_stage_agents: Optional[list[BaseAgent]] = None): - self.args = args - self.trial = trial - self.logger = logger.get_trial_logger(trial=trial) - self.logger.debug('Pipeline Initialized') - self.writing_stage: WritingStage = WritingStage(args, trial, - writing_stage_agents) - self.execution_stage: ExecutionStage = ExecutionStage( - args, trial, execution_stage_agents) - self.analysis_stage: AnalysisStage = AnalysisStage(args, trial, - analysis_stage_agents) - - def _terminate(self, result_history: list[Result], cycle_count: int) -> bool: - """Validates if the termination conditions have been satisfied.""" - if not cycle_count: - return False - - if cycle_count > 5: - self.logger.info('[Cycle %d] Terminate after 5 cycles: %s', cycle_count, - result_history) - return True - - last_result = result_history[-1] - if isinstance(last_result, BuildResult) and not last_result.success: - self.logger.debug('[Cycle %d] Last result is failed BuildResult: %s', - cycle_count, last_result) - return True - - if not isinstance(last_result, AnalysisResult): - self.logger.warning('[Cycle %d] Last result is not AnalysisResult: %s', - cycle_count, result_history) - return True - - if last_result.success: - self.logger.info('[Cycle %d] Generation succeeds: %s', cycle_count, - result_history) - return True - - if isinstance(last_result, AnalysisResult) and not last_result.success: - self.logger.info('[Cycle %d] Generation continues: %s', cycle_count, - result_history) - return False - - self.logger.warning('[Cycle %d] Last result is unexpected: %s', cycle_count, - last_result) - return True - - def _update_status(self, - result_history: list[Result], - finished: bool = False) -> None: - trial_result = TrialResult(benchmark=result_history[-1].benchmark, - trial=self.trial, - work_dirs=result_history[-1].work_dirs, - result_history=result_history) - self.logger.write_result( - result_status_dir=trial_result.best_result.work_dirs.status, - result=trial_result, - finished=finished) - - def _execute_one_cycle(self, result_history: list[Result], - cycle_count: int) -> None: - """Executes the stages once.""" - self.logger.info('[Cycle %d] Initial result is %s', cycle_count, - result_history[-1]) - # Writing stage. 
- result_history.append( - self.writing_stage.execute(result_history=result_history)) - self._update_status(result_history=result_history) - if (not isinstance(result_history[-1], BuildResult) or - not result_history[-1].success): - self.logger.warning('[Cycle %d] Build failure, skipping the rest steps', - cycle_count) - return - - # Execution stage. - result_history.append( - self.execution_stage.execute(result_history=result_history)) - self._update_status(result_history=result_history) - if (not isinstance(result_history[-1], RunResult) or - not result_history[-1].log_path): - self.logger.warning('[Cycle %d] Run failure, skipping the rest steps', - cycle_count) - return - - # Analysis stage. - result_history.append( - self.analysis_stage.execute(result_history=result_history)) - # TODO(maoyi): add the indicator for the success of analysis stage - if not isinstance(result_history[-1], AnalysisResult): - self.logger.warning( - '[Cycle %d] Analysis failure, skipping the rest steps', cycle_count) - return - self._update_status(result_history=result_history) - self.logger.info('[Cycle %d] Analysis result %s: %s', cycle_count, - result_history[-1].success, result_history[-1]) - - def execute(self, result_history: list[Result]) -> list[Result]: - """ - Runs the fuzzing pipeline iteratively to assess and refine the fuzz target. - 1. Writing Stage refines the fuzz target and its build script using insights - from the previous cycle. - 2. Execution Stage measures the performance of the revised fuzz target. - 3. Analysis Stage examines the execution results to guide the next cycle's - improvements. - The process repeats until the termination conditions are met. - """ - self.logger.debug('Pipeline starts') - cycle_count = 0 - self._update_status(result_history=result_history) - while not self._terminate(result_history=result_history, - cycle_count=cycle_count): - cycle_count += 1 - self._execute_one_cycle(result_history=result_history, - cycle_count=cycle_count) - return result_history + def __init__( + self, + args: argparse.Namespace, + trial: int, + writing_stage_agents: Optional[list[BaseAgent]] = None, + execution_stage_agents: Optional[list[BaseAgent]] = None, + analysis_stage_agents: Optional[list[BaseAgent]] = None, + ): + self.args = args + self.trial = trial + self.logger = logger.get_trial_logger(trial=trial) + self.logger.debug("Pipeline Initialized") + self.writing_stage: WritingStage = WritingStage( + args, trial, writing_stage_agents + ) + self.execution_stage: ExecutionStage = ExecutionStage( + args, trial, execution_stage_agents + ) + self.analysis_stage: AnalysisStage = AnalysisStage( + args, trial, analysis_stage_agents + ) + + def _terminate(self, result_history: list[Result], cycle_count: int) -> bool: + """Validates if the termination conditions have been satisfied.""" + if not cycle_count: + return False + + if cycle_count > 5: + self.logger.info( + "[Cycle %d] Terminate after 5 cycles: %s", cycle_count, result_history + ) + return True + + last_result = result_history[-1] + if isinstance(last_result, BuildResult) and not last_result.success: + self.logger.debug( + "[Cycle %d] Last result is failed BuildResult: %s", + cycle_count, + last_result, + ) + return True + + if not isinstance(last_result, AnalysisResult): + self.logger.warning( + "[Cycle %d] Last result is not AnalysisResult: %s", + cycle_count, + result_history, + ) + return True + + if last_result.success: + self.logger.info( + "[Cycle %d] Generation succeeds: %s", cycle_count, result_history + ) + return True + + 
if isinstance(last_result, AnalysisResult) and not last_result.success: + self.logger.info( + "[Cycle %d] Generation continues: %s", cycle_count, result_history + ) + return False + + self.logger.warning( + "[Cycle %d] Last result is unexpected: %s", cycle_count, last_result + ) + return True + + def _update_status( + self, result_history: list[Result], finished: bool = False + ) -> None: + trial_result = TrialResult( + benchmark=result_history[-1].benchmark, + trial=self.trial, + work_dirs=result_history[-1].work_dirs, + result_history=result_history, + ) + self.logger.write_result( + result_status_dir=trial_result.best_result.work_dirs.status, + result=trial_result, + finished=finished, + ) + + def _execute_one_cycle( + self, result_history: list[Result], cycle_count: int + ) -> None: + """Executes the stages once.""" + self.logger.info( + "[Cycle %d] Initial result is %s", cycle_count, result_history[-1] + ) + # Writing stage. + result_history.append(self.writing_stage.execute(result_history=result_history)) + self._update_status(result_history=result_history) + if ( + not isinstance(result_history[-1], BuildResult) + or not result_history[-1].success + ): + self.logger.warning( + "[Cycle %d] Build failure, skipping the rest steps", cycle_count + ) + return + + # Execution stage. + result_history.append( + self.execution_stage.execute(result_history=result_history) + ) + self._update_status(result_history=result_history) + if ( + not isinstance(result_history[-1], RunResult) + or not result_history[-1].log_path + ): + self.logger.warning( + "[Cycle %d] Run failure, skipping the rest steps", cycle_count + ) + return + + # Analysis stage. + result_history.append( + self.analysis_stage.execute(result_history=result_history) + ) + # TODO(maoyi): add the indicator for the success of analysis stage + if not isinstance(result_history[-1], AnalysisResult): + self.logger.warning( + "[Cycle %d] Analysis failure, skipping the rest steps", cycle_count + ) + return + self._update_status(result_history=result_history) + self.logger.info( + "[Cycle %d] Analysis result %s: %s", + cycle_count, + result_history[-1].success, + result_history[-1], + ) + + def execute(self, result_history: list[Result]) -> list[Result]: + """ + Runs the fuzzing pipeline iteratively to assess and refine the fuzz target. + 1. Writing Stage refines the fuzz target and its build script using insights + from the previous cycle. + 2. Execution Stage measures the performance of the revised fuzz target. + 3. Analysis Stage examines the execution results to guide the next cycle's + improvements. + The process repeats until the termination conditions are met. + """ + self.logger.debug("Pipeline starts") + cycle_count = 0 + self._update_status(result_history=result_history) + while not self._terminate( + result_history=result_history, cycle_count=cycle_count + ): + cycle_count += 1 + self._execute_one_cycle( + result_history=result_history, cycle_count=cycle_count + ) + return result_history diff --git a/prompts/agent/coverage-analyzer-priming.txt b/prompts/agent/coverage-analyzer-priming.txt index 4c665a35b3..2ddfde29c0 100644 --- a/prompts/agent/coverage-analyzer-priming.txt +++ b/prompts/agent/coverage-analyzer-priming.txt @@ -25,7 +25,7 @@ True The low coverage comes from the fact that the current fuzz target exercises only one very narrow code path—in this case, a single call to {FUNCTION_SIGNATURE} with naive argument derived directly from the input data. 
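The point this priming text makes, deriving several varied arguments from a single fuzzer input instead of passing the bytes through verbatim, can be sketched in a few lines (pure illustration; the real targets are generated per project and language):

```python
# Illustration of "input variation": one raw fuzzer input is split into a
# variable number of distinct arguments so different branch combinations get
# exercised, rather than always feeding one unprocessed argument.
def derive_arguments(data: bytes) -> list[bytes]:
    if not data:
        return [b""]
    argc = data[0] % 4 + 1                 # first byte picks 1..4 arguments
    payload = data[1:] or b"\x00"
    chunk = max(1, len(payload) // argc)
    return [payload[i * chunk:(i + 1) * chunk] for i in range(argc)]
```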
This approach misses many branches within the {PROJECT} because: -* Single Argument Limitation: By always providing a unprocessed and naive argument, the fuzz target never tests the handling of complex values, which likely involves additional logic (e.g., iterating over the array, handling edge cases like empty or very long tokens, and validating numeric conversions for lengths). +* Single Argument Limitation: By always providing an unprocessed and naive argument, the fuzz target never tests the handling of complex values, which likely involves additional logic (e.g., iterating over the array, handling edge cases like empty or very long tokens, and validating numeric conversions for lengths). * Lack of Input Variation: Since the fuzzer input is used verbatim as the only command argument, many conditional paths (e.g., those triggered by specific token contents or argument counts) remain untested. diff --git a/prompts/agent/enhancer-coverage-priming.txt b/prompts/agent/enhancer-coverage-priming.txt index ad35421bc0..4d170e0a65 100644 --- a/prompts/agent/enhancer-coverage-priming.txt +++ b/prompts/agent/enhancer-coverage-priming.txt @@ -47,7 +47,7 @@ For each input parameter, understand: Step 6: Plan Fuzz Target Implementation. Decide how to implement the refined fuzz target: -* The fuzz target can compile so your can reuse most of the code as a scaffold. +* The fuzz target can compile so you can reuse most of the code as a scaffold. * Only modify the parts caused the low coverage. * Prepare to output the FULL new fuzz target, do not leave out any code that is the same as before. * **Extract parameters** from the `data` and `size` variable of `LLVMFuzzerTestOneInput(const uint8_t *data, size_t size)`. diff --git a/report/aggregate_coverage_diff.py b/report/aggregate_coverage_diff.py index acafc4324f..eb95bac623 100644 --- a/report/aggregate_coverage_diff.py +++ b/report/aggregate_coverage_diff.py @@ -31,62 +31,62 @@ def compute_coverage_diff(project: str, coverage_links: list[str]): - existing_textcov = evaluator.load_existing_textcov(project) - coverage_summary = evaluator.load_existing_coverage_summary(project) - - # Can't use an anonymous client here as the coverage links may be on private - # buckets. - storage_client = storage.Client() - new_textcov = textcov.Textcov() - - for coverage_link in coverage_links: - path = coverage_link.removeprefix('gs://').split('/') - bucket = storage_client.bucket(path[0]) - textcovs_path = '/'.join(path[1:] + ['textcov_reports']) - - blobs = storage_client.list_blobs(bucket, - prefix=f'{textcovs_path}/', - delimiter='/') - for blob in blobs: - logging.info('Loading %s', blob.name) - with blob.open() as f: - # TODO: skip other functions defined the target. - new_textcov.merge(textcov.Textcov.from_file(f)) - - new_textcov.subtract_covered_lines(existing_textcov) - try: - total_lines = coverage_summary['data'][0]['totals']['lines']['count'] - except KeyError: - total_lines = 1 - - return new_textcov.covered_lines / total_lines - #print(f'{project}:', new_textcov.covered_lines / total_lines) + existing_textcov = evaluator.load_existing_textcov(project) + coverage_summary = evaluator.load_existing_coverage_summary(project) + + # Can't use an anonymous client here as the coverage links may be on private + # buckets. 
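The metric `compute_coverage_diff` produces is simply "lines newly covered by the generated targets, as a fraction of the project's total line count from the existing coverage summary". A toy example with made-up numbers:

```python
# Toy numbers only: the ratio compute_coverage_diff returns.
new_covered_lines = 150        # lines hit by the OFG-generated targets
already_covered_lines = 120    # overlap with the existing fuzzers' coverage
total_lines = 2_000            # total line count from the coverage summary

coverage_diff = (new_covered_lines - already_covered_lines) / total_lines
print(coverage_diff)           # 0.015, i.e. 1.5% newly covered lines
```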
+ storage_client = storage.Client() + new_textcov = textcov.Textcov() + + for coverage_link in coverage_links: + path = coverage_link.removeprefix("gs://").split("/") + bucket = storage_client.bucket(path[0]) + textcovs_path = "/".join(path[1:] + ["textcov_reports"]) + + blobs = storage_client.list_blobs( + bucket, prefix=f"{textcovs_path}/", delimiter="/" + ) + for blob in blobs: + logging.info("Loading %s", blob.name) + with blob.open() as f: + # TODO: skip other functions defined the target. + new_textcov.merge(textcov.Textcov.from_file(f)) + + new_textcov.subtract_covered_lines(existing_textcov) + try: + total_lines = coverage_summary["data"][0]["totals"]["lines"]["count"] + except KeyError: + total_lines = 1 + + return new_textcov.covered_lines / total_lines + # print(f'{project}:', new_textcov.covered_lines / total_lines) def main(): - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=logging.INFO) - project_coverages = {} + project_coverages = {} - data = json.load(sys.stdin) - for benchmark in data['benchmarks']: - # TODO(ochang): Properly store the project, as projects can have '-' in the name. - project = benchmark['benchmark'].split('-')[1] - report = benchmark.get('max_line_coverage_diff_report') - if report: - project_coverages.setdefault(project, []).append(report) + data = json.load(sys.stdin) + for benchmark in data["benchmarks"]: + # TODO(ochang): Properly store the project, as projects can have '-' in the name. + project = benchmark["benchmark"].split("-")[1] + report = benchmark.get("max_line_coverage_diff_report") + if report: + project_coverages.setdefault(project, []).append(report) - diffs = {} - for project, coverage_links in project_coverages.items(): - logging.info('Computing coverage diff for %s', project) - try: - diffs[project] = compute_coverage_diff(project, coverage_links) - except Exception: - logging.error('Failed to compute coverage for %s', project) - traceback.print_exc() + diffs = {} + for project, coverage_links in project_coverages.items(): + logging.info("Computing coverage diff for %s", project) + try: + diffs[project] = compute_coverage_diff(project, coverage_links) + except Exception: + logging.error("Failed to compute coverage for %s", project) + traceback.print_exc() - print(diffs) + print(diffs) -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/report/common.py b/report/common.py index 55db0cb536..f5f38b5bee 100644 --- a/report/common.py +++ b/report/common.py @@ -32,715 +32,742 @@ MAX_RUN_LOGS_LEN = 16 * 1024 -TARGET_EXTS = project_src.SEARCH_EXTS + ['.java', '.py', '.rs' - ] + ['.fuzz_target'] +TARGET_EXTS = project_src.SEARCH_EXTS + [".java", ".py", ".rs"] + [".fuzz_target"] -_CHAT_PROMPT_START_MARKER = re.compile(r'') -_CHAT_PROMPT_END_MARKER = re.compile(r'') -_CHAT_RESPONSE_START_MARKER = re.compile(r'') -_CHAT_RESPONSE_END_MARKER = re.compile(r'') +_CHAT_PROMPT_START_MARKER = re.compile(r"") +_CHAT_PROMPT_END_MARKER = re.compile(r"") +_CHAT_RESPONSE_START_MARKER = re.compile(r"") +_CHAT_RESPONSE_END_MARKER = re.compile(r"") @dataclasses.dataclass class AccumulatedResult: - """Container for storing accumulated results.""" - compiles: int = 0 - crashes: int = 0 - crash_cases: int = 0 - total_runs: int = 0 - total_coverage: float = 0.0 - total_line_coverage_diff: float = 0.0 + """Container for storing accumulated results.""" - @property - def average_coverage(self) -> float: - return self.total_coverage / float(self.total_runs) + compiles: int = 0 + crashes: int = 0 + crash_cases: int 
= 0 + total_runs: int = 0 + total_coverage: float = 0.0 + total_line_coverage_diff: float = 0.0 - @property - def average_line_coverage_diff(self) -> float: - return self.total_line_coverage_diff / float(self.total_runs) + @property + def average_coverage(self) -> float: + return self.total_coverage / float(self.total_runs) - @property - def build_rate(self) -> float: - return float(self.compiles) / float(self.total_runs) + @property + def average_line_coverage_diff(self) -> float: + return self.total_line_coverage_diff / float(self.total_runs) + + @property + def build_rate(self) -> float: + return float(self.compiles) / float(self.total_runs) @dataclasses.dataclass class Benchmark: - """The class of a benchmark function and its experiment results.""" - id: str - status: str - result: run_one_experiment.AggregatedResult - signature: str = '' - project: str = '' - function: str = '' - language: str = '' + """The class of a benchmark function and its experiment results.""" + + id: str + status: str + result: run_one_experiment.AggregatedResult + signature: str = "" + project: str = "" + function: str = "" + language: str = "" @dataclasses.dataclass class Project: - """Results for a project entire.""" - name: str - count: int = 0 - success: int = 0 - coverage_gain: float = 0.0 - coverage_relative_gain: float = 0.0 - coverage_ofg_total_new_covered_lines = 0 - coverage_existing_total_covered_lines = 0 - coverage_existing_total_lines = 0 - coverage_ofg_total_covered_lines = 0 + """Results for a project entire.""" + + name: str + count: int = 0 + success: int = 0 + coverage_gain: float = 0.0 + coverage_relative_gain: float = 0.0 + coverage_ofg_total_new_covered_lines = 0 + coverage_existing_total_covered_lines = 0 + coverage_existing_total_lines = 0 + coverage_ofg_total_covered_lines = 0 @dataclasses.dataclass class Sample: - """Result of a fuzz target sample of a benchmark.""" - id: str - status: str - result: evaluator.Result - - @property - def stacktrace(self) -> str: - if not self.result: - return '' - reproducer_link = self.result.reproducer_path - return f'{reproducer_link}/stacktrace' - - @property - def target_binary(self) -> str: - if not self.result: - return '' - reproducer_link = self.result.reproducer_path - return f'{reproducer_link}/target_binary' - - @property - def reproducer(self) -> str: - if not self.result: - return '' - reproducer_link = self.result.reproducer_path - return f'{reproducer_link}/artifacts' - - @property - def run_log(self) -> str: - if not self.result: - return '' - reproducer_link = self.result.reproducer_path - return reproducer_link.removesuffix('reproducer') + 'run.log' + """Result of a fuzz target sample of a benchmark.""" + + id: str + status: str + result: evaluator.Result + + @property + def stacktrace(self) -> str: + if not self.result: + return "" + reproducer_link = self.result.reproducer_path + return f"{reproducer_link}/stacktrace" + + @property + def target_binary(self) -> str: + if not self.result: + return "" + reproducer_link = self.result.reproducer_path + return f"{reproducer_link}/target_binary" + + @property + def reproducer(self) -> str: + if not self.result: + return "" + reproducer_link = self.result.reproducer_path + return f"{reproducer_link}/artifacts" + + @property + def run_log(self) -> str: + if not self.result: + return "" + reproducer_link = self.result.reproducer_path + return reproducer_link.removesuffix("reproducer") + "run.log" @dataclasses.dataclass class Target: - code: str - fixer_prompt: Optional[str] = None - 
build_script_code: Optional[str] = None + code: str + fixer_prompt: Optional[str] = None + build_script_code: Optional[str] = None @dataclasses.dataclass class Triage: - result: str - triager_prompt: str + result: str + triager_prompt: str @dataclasses.dataclass class LogPart: - chat_prompt: bool = False - chat_response: bool = False - content: str = '' + chat_prompt: bool = False + chat_response: bool = False + content: str = "" class FileSystem: - """ - FileSystem provides a wrapper over standard library and GCS client and - automatically chooses which to use based on the provided path. - """ - - _gcs_client = None - - @classmethod - def _get_gcs_client(cls): """ - Returns a cached storage client (a new one is created on first call. - - A new client does authentication on first call, so caching the client will - same multiple authentication round trips to GCP. - """ - if cls._gcs_client is None: - cls._gcs_client = storage.Client() - - return cls._gcs_client - - def __init__(self, path: str): - logging.debug('file operation %s', path) - self._path = path - self._gcs_bucket: Optional[storage.Bucket] = None - - if path.startswith('gs://'): - path = path.removeprefix('gs://') - self._gcs_bucket = FileSystem._get_gcs_client().bucket(path.split('/')[0]) - self._path = '/'.join(path.split('/')[1:]) - - def listdir(self) -> List[str]: - """listdir returns a list of files and directories in path.""" - if self._gcs_bucket is not None: - # Make sure the path ends with a /, otherwise GCS just returns the - # directory as a prefix and not list the contents. - prefix = self._path - if not self._path.endswith('/'): - prefix = f'{self._path}/' - - # Unfortunately GCS doesn't work like a normal file system and the client - # library doesn't even pretend there is a directory hierarchy. - # The list API does return a list of prefixes that we can join with the - # list of objects to get something close to listdir(). But client library - # is pretty weird and it stores the prefixes on the iterator... - # https://github.com/googleapis/python-storage/blob/64edbd922a605247203790a90f9536d54e3a705a/google/cloud/storage/client.py#L1356 - it = self._gcs_bucket.list_blobs(prefix=prefix, delimiter='/') - paths = [f.name for f in it] + [p.removesuffix('/') for p in it.prefixes] - r = [p.removeprefix(prefix) for p in paths] - return r - - return os.listdir(self._path) - - def exists(self) -> bool: - """exists returns true if the path is a file or directory.""" - if self._gcs_bucket is not None: - return self.isfile() or self.isdir() - - return os.path.exists(self._path) - - def isfile(self) -> bool: - """isfile returns true if the path is a file.""" - if self._gcs_bucket is not None: - return self._gcs_bucket.blob(self._path).exists() - - return os.path.isfile(self._path) - - def isdir(self) -> bool: - """isfile returns true if the path is a directory.""" - if self._gcs_bucket is not None: - return len(self.listdir()) > 0 - - return os.path.isdir(self._path) - - def makedirs(self): - """makedirs create parent(s) and directory in specified path.""" - if self._gcs_bucket is not None: - # Do nothing. GCS doesn't have directories and files can be created with - # any path. - return - - os.makedirs(self._path) - - def open(self, *args, **kwargs) -> io.IOBase: + FileSystem provides a wrapper over standard library and GCS client and + automatically chooses which to use based on the provided path. """ - open returns a file handle to the file located at the specified path. 
- - It has identical function signature to standard library open(). - """ - if self._gcs_bucket is not None: - return self._gcs_bucket.blob(self._path).open(*args, **kwargs) - - return open(self._path, *args, **kwargs) - def getsize(self) -> int: - """getsize returns the byte size of the file at the specified path.""" - if self._gcs_bucket is not None: - blob = self._gcs_bucket.get_blob(self._path) - if blob is None: - raise FileNotFoundError( - 'GCS blob not found gs://{self._gcs_bucket.bucket}/{self._path}.') - - # size can be None if use Bucket.blob() instead of Bucket.get_blob(). The - # type checker doesn't know this and insists we check if size is None. - return blob.size if blob.size is not None else 0 - - return os.path.getsize(self._path) + _gcs_client = None + + @classmethod + def _get_gcs_client(cls): + """ + Returns a cached storage client (a new one is created on first call. + + A new client does authentication on first call, so caching the client will + same multiple authentication round trips to GCP. + """ + if cls._gcs_client is None: + cls._gcs_client = storage.Client() + + return cls._gcs_client + + def __init__(self, path: str): + logging.debug("file operation %s", path) + self._path = path + self._gcs_bucket: Optional[storage.Bucket] = None + + if path.startswith("gs://"): + path = path.removeprefix("gs://") + self._gcs_bucket = FileSystem._get_gcs_client().bucket(path.split("/")[0]) + self._path = "/".join(path.split("/")[1:]) + + def listdir(self) -> List[str]: + """listdir returns a list of files and directories in path.""" + if self._gcs_bucket is not None: + # Make sure the path ends with a /, otherwise GCS just returns the + # directory as a prefix and not list the contents. + prefix = self._path + if not self._path.endswith("/"): + prefix = f"{self._path}/" + + # Unfortunately GCS doesn't work like a normal file system and the client + # library doesn't even pretend there is a directory hierarchy. + # The list API does return a list of prefixes that we can join with the + # list of objects to get something close to listdir(). But client library + # is pretty weird and it stores the prefixes on the iterator... + # https://github.com/googleapis/python-storage/blob/64edbd922a605247203790a90f9536d54e3a705a/google/cloud/storage/client.py#L1356 + it = self._gcs_bucket.list_blobs(prefix=prefix, delimiter="/") + paths = [f.name for f in it] + [p.removesuffix("/") for p in it.prefixes] + r = [p.removeprefix(prefix) for p in paths] + return r + + return os.listdir(self._path) + + def exists(self) -> bool: + """exists returns true if the path is a file or directory.""" + if self._gcs_bucket is not None: + return self.isfile() or self.isdir() + + return os.path.exists(self._path) + + def isfile(self) -> bool: + """isfile returns true if the path is a file.""" + if self._gcs_bucket is not None: + return self._gcs_bucket.blob(self._path).exists() + + return os.path.isfile(self._path) + + def isdir(self) -> bool: + """isfile returns true if the path is a directory.""" + if self._gcs_bucket is not None: + return len(self.listdir()) > 0 + + return os.path.isdir(self._path) + + def makedirs(self): + """makedirs create parent(s) and directory in specified path.""" + if self._gcs_bucket is not None: + # Do nothing. GCS doesn't have directories and files can be created with + # any path. + return + + os.makedirs(self._path) + + def open(self, *args, **kwargs) -> io.IOBase: + """ + open returns a file handle to the file located at the specified path. 
+ + It has identical function signature to standard library open(). + """ + if self._gcs_bucket is not None: + return self._gcs_bucket.blob(self._path).open(*args, **kwargs) + + return open(self._path, *args, **kwargs) + + def getsize(self) -> int: + """getsize returns the byte size of the file at the specified path.""" + if self._gcs_bucket is not None: + blob = self._gcs_bucket.get_blob(self._path) + if blob is None: + raise FileNotFoundError( + "GCS blob not found gs://{self._gcs_bucket.bucket}/{self._path}." + ) + + # size can be None if use Bucket.blob() instead of Bucket.get_blob(). The + # type checker doesn't know this and insists we check if size is None. + return blob.size if blob.size is not None else 0 + + return os.path.getsize(self._path) class Results: - """Results provides functions to explore the experiment results in a - particular directory.""" - - def __init__(self, results_dir='results', benchmark_set='all'): - self._results_dir = results_dir - self._benchmark_dir = os.path.join('benchmark-sets', benchmark_set) - - def list_benchmark_ids(self) -> List[str]: - return sorted( - filter(self._is_valid_benchmark_dir, - FileSystem(self._results_dir).listdir())) - - def match_benchmark(self, benchmark_id: str, results: list[evaluator.Result], - targets: list[str]) -> Benchmark: - """Returns a benchmark class based on |benchmark_id|.""" - num_finished_trials = len([result for result in results if result.finished]) - status = 'Done' if num_finished_trials == len(results) else ( - f'Running ({num_finished_trials}/{len(results)})') - - filtered_results = [(i, stat) for i, stat in enumerate(results) if stat] - - if filtered_results: - result = run_one_experiment.aggregate_results(filtered_results, targets) - else: - result = run_one_experiment.AggregatedResult() - - return self._create_benchmark(benchmark_id, status, result) - - def get_final_target_code(self, benchmark: str, sample: str) -> str: - """Gets the targets of benchmark |benchmark| with sample ID |sample|.""" - targets_dir = os.path.join(self._results_dir, benchmark, 'fixed_targets') - # TODO(donggeliu): Make this consistent with agent output. - if not os.path.exists(targets_dir): - return '' - - for name in sorted(FileSystem(targets_dir).listdir()): - path = os.path.join(targets_dir, name) - if name.startswith(sample + '.') and FileSystem(path).isfile(): - with FileSystem(path).open() as f: - code = f.read() - code = json.dumps(code) - return code - return '' - - def get_logs(self, benchmark: str, sample: str) -> list[LogPart]: - status_dir = os.path.join(self._results_dir, benchmark, 'status') - results_path = os.path.join(status_dir, sample, 'log.txt') - if not FileSystem(results_path).exists(): - return [] - - with FileSystem(results_path).open() as f: - return _parse_log_parts(f.read()) - - def get_run_logs(self, benchmark: str, sample: str) -> str: - """Returns the content of the last run log.""" - run_logs_dir = os.path.join(self._results_dir, benchmark, 'logs', 'run') - largest_iteration, last_log_file = -1, None - for name in FileSystem(run_logs_dir).listdir(): - if name.startswith(sample + '.'): - iteration = WorkDirs.get_run_log_iteration(name) - if iteration is None: - # Be compatible with older results where there is no '-Fxx' in run - # log file name. 
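The `FileSystem` wrapper reformatted above is path-scheme driven, so the same call sites work for local result directories and for `gs://` URLs. A small usage sketch (the bucket and directory names are placeholders, and GCS access requires credentials):

```python
# Usage sketch for report.common.FileSystem; paths below are made up.
from report.common import FileSystem  # assuming the repo root is on sys.path

for path in (
    "results/output-demo-project/status",
    "gs://example-bucket/results/output-demo-project/status",
):
    fs = FileSystem(path)
    if fs.isdir():
        for name in fs.listdir():
            print(path, "->", name)
```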
- last_log_file = name - break - - if largest_iteration < iteration: - largest_iteration, last_log_file = iteration, name - - if not last_log_file: - return '' - - log_path = os.path.join(run_logs_dir, last_log_file) - log_size = FileSystem(log_path).getsize() - with FileSystem(log_path).open(errors='replace') as f: - if log_size <= MAX_RUN_LOGS_LEN: - return f.read() - - truncated_len = MAX_RUN_LOGS_LEN // 2 - logs_beginning = f.read(truncated_len) - f.seek(log_size - truncated_len - 1, os.SEEK_SET) - logs_ending = f.read() - - return logs_beginning + '\n...truncated...\n' + logs_ending - - return '' - - def get_triage(self, benchmark: str, sample: str) -> Triage: - """Gets the triage of benchmark |benchmark| with sample ID |sample|.""" - result = '' - triager_prompt = '' - fixed_dir = os.path.join(self._results_dir, benchmark, 'fixed_targets') - triage_dir = os.path.join(fixed_dir, f'{sample}-triage') - if not os.path.exists(triage_dir): - return Triage(result, triager_prompt) - - for name in os.listdir(triage_dir): - if name == 'prompt.txt': - with FileSystem(os.path.join(triage_dir, name)).open() as f: - triager_prompt = f.read() - - # Prepare prompt for being used in HTML. - triager_prompt = self._prepare_prompt_for_html_text(triager_prompt) - - if name.endswith('.txt') and name != 'prompt.txt': - triage_path = os.path.join(triage_dir, name) - with open(triage_path) as f: - result = f.read() - - return Triage(result, triager_prompt) - - def get_targets(self, benchmark: str, sample: str) -> list[Target]: - """Gets the targets of benchmark |benchmark| with sample ID |sample|.""" - return (self._get_targets(benchmark, sample) or - [self._get_targets_agent(benchmark, sample)]) - - def _get_targets(self, benchmark: str, sample: str) -> list[Target]: - """Gets the targets of benchmark |benchmark| with sample ID |sample| from - the OFG version 1 (single prompt).""" - targets_dir = os.path.join(self._results_dir, benchmark, 'fixed_targets') - # TODO(donggeliu): Make this consistent with agent output. - if not os.path.exists(targets_dir): - return [] - - targets = [] - - for name in sorted(FileSystem(targets_dir).listdir()): - path = os.path.join(targets_dir, name) - if name.startswith(sample + '.') and FileSystem(path).isfile(): - logging.debug(path) - with FileSystem(path).open() as f: - code = f.read() - targets.insert(0, Target(code=code)) - - if name.startswith(sample + '-F') and FileSystem(path).isdir(): - targets.append(self._get_fixed_target(path)) - - return targets - - def _get_targets_agent(self, benchmark: str, trial: str) -> Target: - """Gets the targets of benchmark |benchmark| with trial ID |trial| from - the OFG version 2 (LLM agents).""" - fuzz_target_dir = os.path.join(self._results_dir, benchmark, 'fuzz_targets') - files = sorted(FileSystem(fuzz_target_dir).listdir()) - - fuzz_target_code = '' - if f'{trial:02s}.fuzz_target' in files: - fuzz_target_path = os.path.join(fuzz_target_dir, - f'{trial:02s}.fuzz_target') - with FileSystem(fuzz_target_path).open() as f: - fuzz_target_code = f.read() - - build_script_code = '' - if f'{trial:02s}.build_script' in files: - build_script_path = os.path.join(fuzz_target_dir, - f'{trial:02s}.build_script') - with FileSystem(build_script_path).open() as f: - build_script_code = f.read() - - # TODO(dongge): Properly show build script code in reports. 
- return Target(code=fuzz_target_code, - fixer_prompt=None, - build_script_code=build_script_code) - - def get_samples(self, results: list[evaluator.Result], - targets: list[str]) -> list[Sample]: - """Gets the samples and their status of the given benchmark |bnmk|.""" - samples = [] - - for i, sample_id in enumerate(self._sample_ids(targets)): - status = 'Running' - result = evaluator.Result() - if results[i]: - status = 'Done' - result = results[i] - - samples.append(Sample(sample_id, status, result)) - - return samples - - def get_prompt(self, benchmark: str) -> Optional[str]: - """Gets the prompt for a given benchmark.""" - root_dir = os.path.join(self._results_dir, benchmark) - for name in FileSystem(root_dir).listdir(): - if re.match(r'^prompt.*txt$', name): - with FileSystem(os.path.join(root_dir, name)).open() as f: - content = f.read() - - # Prepare prompt text for HTML. - return self._prepare_prompt_for_html_text(content) - - return None - - def get_results(self, - benchmark: str) -> tuple[list[evaluator.Result], list[str]]: - """ - Returns results of all samples. Items can be None if they're not complete. - """ - targets = self._get_generated_targets( - benchmark) or self._get_agent_generated_targets(benchmark) - - results = [] - status_dir = os.path.join(self._results_dir, benchmark, 'status') - - for sample_id in self._sample_ids(targets): - results_path = os.path.join(status_dir, sample_id, 'result.json') - if not FileSystem(results_path).exists(): - results.append(None) - continue - - with FileSystem(results_path).open() as f: - try: - data = json.load(f) - except Exception: - return [], [] - - # TODO(dongge): Add new attributes to evaluator.Result. - valid_attributes = inspect.signature(evaluator.Result.__init__).parameters - filtered_data = { - key: value for key, value in data.items() if key in valid_attributes - } - results.append(evaluator.Result(**filtered_data)) - - return results, targets - - def get_macro_insights(self, - benchmarks: list[Benchmark]) -> AccumulatedResult: - """Returns macro insights from the aggregated benchmark results.""" - accumulated_results = AccumulatedResult() - for benchmark in benchmarks: - accumulated_results.compiles += int( - benchmark.result.build_success_rate > 0.0) - accumulated_results.crashes += int(benchmark.result.found_bug > 0) - accumulated_results.total_coverage += benchmark.result.max_coverage - accumulated_results.total_runs += 1 - accumulated_results.total_line_coverage_diff += ( - benchmark.result.max_line_coverage_diff) - return accumulated_results - - def get_coverage_language_gains(self): - """Gets report.json created by experiment runners.""" - summary_path = os.path.join(self._results_dir, 'report.json') - if FileSystem(summary_path).exists(): - with FileSystem(summary_path).open() as f: + """Results provides functions to explore the experiment results in a + particular directory.""" + + def __init__(self, results_dir="results", benchmark_set="all"): + self._results_dir = results_dir + self._benchmark_dir = os.path.join("benchmark-sets", benchmark_set) + + def list_benchmark_ids(self) -> List[str]: + return sorted( + filter( + self._is_valid_benchmark_dir, FileSystem(self._results_dir).listdir() + ) + ) + + def match_benchmark( + self, benchmark_id: str, results: list[evaluator.Result], targets: list[str] + ) -> Benchmark: + """Returns a benchmark class based on |benchmark_id|.""" + num_finished_trials = len([result for result in results if result.finished]) + status = ( + "Done" + if num_finished_trials == 
len(results) + else (f"Running ({num_finished_trials}/{len(results)})") + ) + + filtered_results = [(i, stat) for i, stat in enumerate(results) if stat] + + if filtered_results: + result = run_one_experiment.aggregate_results(filtered_results, targets) + else: + result = run_one_experiment.AggregatedResult() + + return self._create_benchmark(benchmark_id, status, result) + + def get_final_target_code(self, benchmark: str, sample: str) -> str: + """Gets the targets of benchmark |benchmark| with sample ID |sample|.""" + targets_dir = os.path.join(self._results_dir, benchmark, "fixed_targets") + # TODO(donggeliu): Make this consistent with agent output. + if not os.path.exists(targets_dir): + return "" + + for name in sorted(FileSystem(targets_dir).listdir()): + path = os.path.join(targets_dir, name) + if name.startswith(sample + ".") and FileSystem(path).isfile(): + with FileSystem(path).open() as f: + code = f.read() + code = json.dumps(code) + return code + return "" + + def get_logs(self, benchmark: str, sample: str) -> list[LogPart]: + status_dir = os.path.join(self._results_dir, benchmark, "status") + results_path = os.path.join(status_dir, sample, "log.txt") + if not FileSystem(results_path).exists(): + return [] + + with FileSystem(results_path).open() as f: + return _parse_log_parts(f.read()) + + def get_run_logs(self, benchmark: str, sample: str) -> str: + """Returns the content of the last run log.""" + run_logs_dir = os.path.join(self._results_dir, benchmark, "logs", "run") + largest_iteration, last_log_file = -1, None + for name in FileSystem(run_logs_dir).listdir(): + if name.startswith(sample + "."): + iteration = WorkDirs.get_run_log_iteration(name) + if iteration is None: + # Be compatible with older results where there is no '-Fxx' in run + # log file name. + last_log_file = name + break + + if largest_iteration < iteration: + largest_iteration, last_log_file = iteration, name + + if not last_log_file: + return "" + + log_path = os.path.join(run_logs_dir, last_log_file) + log_size = FileSystem(log_path).getsize() + with FileSystem(log_path).open(errors="replace") as f: + if log_size <= MAX_RUN_LOGS_LEN: + return f.read() + + truncated_len = MAX_RUN_LOGS_LEN // 2 + logs_beginning = f.read(truncated_len) + f.seek(log_size - truncated_len - 1, os.SEEK_SET) + logs_ending = f.read() + + return logs_beginning + "\n...truncated...\n" + logs_ending + + return "" + + def get_triage(self, benchmark: str, sample: str) -> Triage: + """Gets the triage of benchmark |benchmark| with sample ID |sample|.""" + result = "" + triager_prompt = "" + fixed_dir = os.path.join(self._results_dir, benchmark, "fixed_targets") + triage_dir = os.path.join(fixed_dir, f"{sample}-triage") + if not os.path.exists(triage_dir): + return Triage(result, triager_prompt) + + for name in os.listdir(triage_dir): + if name == "prompt.txt": + with FileSystem(os.path.join(triage_dir, name)).open() as f: + triager_prompt = f.read() + + # Prepare prompt for being used in HTML. 
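`get_run_logs` above caps what it returns at `MAX_RUN_LOGS_LEN` by keeping the head and tail of an oversized log and dropping the middle. The same strategy in isolation:

```python
# Standalone version of the truncation used by Results.get_run_logs().
MAX_RUN_LOGS_LEN = 16 * 1024

def truncate_run_log(text: str, cap: int = MAX_RUN_LOGS_LEN) -> str:
    if len(text) <= cap:
        return text
    half = cap // 2
    return text[:half] + "\n...truncated...\n" + text[-half:]
```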
+ triager_prompt = self._prepare_prompt_for_html_text(triager_prompt) + + if name.endswith(".txt") and name != "prompt.txt": + triage_path = os.path.join(triage_dir, name) + with open(triage_path) as f: + result = f.read() + + return Triage(result, triager_prompt) + + def get_targets(self, benchmark: str, sample: str) -> list[Target]: + """Gets the targets of benchmark |benchmark| with sample ID |sample|.""" + return self._get_targets(benchmark, sample) or [ + self._get_targets_agent(benchmark, sample) + ] + + def _get_targets(self, benchmark: str, sample: str) -> list[Target]: + """Gets the targets of benchmark |benchmark| with sample ID |sample| from + the OFG version 1 (single prompt).""" + targets_dir = os.path.join(self._results_dir, benchmark, "fixed_targets") + # TODO(donggeliu): Make this consistent with agent output. + if not os.path.exists(targets_dir): + return [] + + targets = [] + + for name in sorted(FileSystem(targets_dir).listdir()): + path = os.path.join(targets_dir, name) + if name.startswith(sample + ".") and FileSystem(path).isfile(): + logging.debug(path) + with FileSystem(path).open() as f: + code = f.read() + targets.insert(0, Target(code=code)) + + if name.startswith(sample + "-F") and FileSystem(path).isdir(): + targets.append(self._get_fixed_target(path)) + + return targets + + def _get_targets_agent(self, benchmark: str, trial: str) -> Target: + """Gets the targets of benchmark |benchmark| with trial ID |trial| from + the OFG version 2 (LLM agents).""" + fuzz_target_dir = os.path.join(self._results_dir, benchmark, "fuzz_targets") + files = sorted(FileSystem(fuzz_target_dir).listdir()) + + fuzz_target_code = "" + if f"{trial:02s}.fuzz_target" in files: + fuzz_target_path = os.path.join(fuzz_target_dir, f"{trial:02s}.fuzz_target") + with FileSystem(fuzz_target_path).open() as f: + fuzz_target_code = f.read() + + build_script_code = "" + if f"{trial:02s}.build_script" in files: + build_script_path = os.path.join( + fuzz_target_dir, f"{trial:02s}.build_script" + ) + with FileSystem(build_script_path).open() as f: + build_script_code = f.read() + + # TODO(dongge): Properly show build script code in reports. + return Target( + code=fuzz_target_code, + fixer_prompt=None, + build_script_code=build_script_code, + ) + + def get_samples( + self, results: list[evaluator.Result], targets: list[str] + ) -> list[Sample]: + """Gets the samples and their status of the given benchmark |bnmk|.""" + samples = [] + + for i, sample_id in enumerate(self._sample_ids(targets)): + status = "Running" + result = evaluator.Result() + if results[i]: + status = "Done" + result = results[i] + + samples.append(Sample(sample_id, status, result)) + + return samples + + def get_prompt(self, benchmark: str) -> Optional[str]: + """Gets the prompt for a given benchmark.""" + root_dir = os.path.join(self._results_dir, benchmark) + for name in FileSystem(root_dir).listdir(): + if re.match(r"^prompt.*txt$", name): + with FileSystem(os.path.join(root_dir, name)).open() as f: + content = f.read() + + # Prepare prompt text for HTML. + return self._prepare_prompt_for_html_text(content) + + return None + + def get_results(self, benchmark: str) -> tuple[list[evaluator.Result], list[str]]: + """ + Returns results of all samples. Items can be None if they're not complete. 
+ """ + targets = self._get_generated_targets( + benchmark + ) or self._get_agent_generated_targets(benchmark) + + results = [] + status_dir = os.path.join(self._results_dir, benchmark, "status") + + for sample_id in self._sample_ids(targets): + results_path = os.path.join(status_dir, sample_id, "result.json") + if not FileSystem(results_path).exists(): + results.append(None) + continue + + with FileSystem(results_path).open() as f: + try: + data = json.load(f) + except Exception: + return [], [] + + # TODO(dongge): Add new attributes to evaluator.Result. + valid_attributes = inspect.signature(evaluator.Result.__init__).parameters + filtered_data = { + key: value for key, value in data.items() if key in valid_attributes + } + results.append(evaluator.Result(**filtered_data)) + + return results, targets + + def get_macro_insights(self, benchmarks: list[Benchmark]) -> AccumulatedResult: + """Returns macro insights from the aggregated benchmark results.""" + accumulated_results = AccumulatedResult() + for benchmark in benchmarks: + accumulated_results.compiles += int( + benchmark.result.build_success_rate > 0.0 + ) + accumulated_results.crashes += int(benchmark.result.found_bug > 0) + accumulated_results.total_coverage += benchmark.result.max_coverage + accumulated_results.total_runs += 1 + accumulated_results.total_line_coverage_diff += ( + benchmark.result.max_line_coverage_diff + ) + return accumulated_results + + def get_coverage_language_gains(self): + """Gets report.json created by experiment runners.""" + summary_path = os.path.join(self._results_dir, "report.json") + if FileSystem(summary_path).exists(): + with FileSystem(summary_path).open() as f: + try: + return json.load(f) + except ValueError: + # Skip if error + logging.debug("Failed to decode project_coverage_gain.json") + return {} + + def get_project_summary(self, benchmarks: list[Benchmark]) -> list[Project]: + """Returns a list of project summary.""" + project_summary_dict = {} + for benchmark in benchmarks: + if benchmark.project not in project_summary_dict: + project_summary_dict[benchmark.project] = Project(benchmark.project) + project_summary_dict[benchmark.project].count += 1 + project_summary_dict[benchmark.project].success += ( + benchmark.result.build_success_count > 0 + ) + + # Retrieve coverage gain information + coverage_dict = {} + summary_path = os.path.join(self._results_dir, "report.json") + if FileSystem(summary_path).exists(): + with FileSystem(summary_path).open() as f: + try: + coverage_dict = json.load(f).get("project_summary", {}) + except ValueError: + # Skip if error + logging.debug("Failed to decode project_coverage_gain.json") + + # Update project summary with coverage gain information + project_summary_list = list(project_summary_dict.values()) + if coverage_dict: + for project in project_summary_list: + if project.name in coverage_dict: + project.coverage_gain = coverage_dict[project.name].get( + "coverage_diff", 0.0 + ) + project.coverage_relative_gain = coverage_dict[project.name].get( + "coverage_relative_gain", 0.0 + ) + project.coverage_ofg_total_new_covered_lines = coverage_dict[ + project.name + ].get("coverage_ofg_total_new_covered_lines", 0) + project.coverage_existing_total_covered_lines = coverage_dict[ + project.name + ].get("coverage_existing_total_covered_lines", 0) + project.coverage_existing_total_lines = coverage_dict[ + project.name + ].get("coverage_existing_total_lines", 0) + project.coverage_ofg_total_covered_lines = coverage_dict[ + project.name + 
].get("coverage_ofg_total_covered_lines", 0) + + return project_summary_list + + def _prepare_prompt_for_html_text(self, raw_prompt_content: str) -> str: + """Converts a raw prompt file into presentable HTML text.""" try: - return json.load(f) - except ValueError: - # Skip if error - logging.debug('Failed to decode project_coverage_gain.json') - return {} - - def get_project_summary(self, benchmarks: list[Benchmark]) -> list[Project]: - """Returns a list of project summary.""" - project_summary_dict = {} - for benchmark in benchmarks: - if benchmark.project not in project_summary_dict: - project_summary_dict[benchmark.project] = Project(benchmark.project) - project_summary_dict[benchmark.project].count += 1 - project_summary_dict[benchmark.project].success += ( - benchmark.result.build_success_count > 0) - - # Retrieve coverage gain information - coverage_dict = {} - summary_path = os.path.join(self._results_dir, 'report.json') - if FileSystem(summary_path).exists(): - with FileSystem(summary_path).open() as f: + structured_prompt = json.loads(raw_prompt_content) + if isinstance(structured_prompt, list) and structured_prompt: + html_presentable_content = "" + for elem in structured_prompt: + if isinstance(elem, dict) and "content" in elem: + html_presentable_content += f'\n{elem["content"]}' + logging.debug("Converted structured prompt to raw text.") + return html_presentable_content + except json.decoder.JSONDecodeError: + logging.debug("Using raw prompt text.") + + # If execution goes here it the input was not a structured prompt but just + # raw text, which is then returned. + return raw_prompt_content + + def _is_valid_benchmark_dir(self, cur_dir: str) -> bool: + """Checks if |cur_dir| is a valid benchmark directory (e.g., no lost+found).""" + # Check prefix. + if not cur_dir.startswith("output-"): + return False + + # Skip checking sub-directories in GCS. It's a lot of filesystem operations + # to go over the network. + if cur_dir.startswith("gs://"): + return True + + # Check sub-directories. + # TODO(donggeliu): Make this consistent with agent output. + # We used to expect 'fixed_targets' and 'raw_targets' here, but the agent + # workflow doesn't populate them. As a result, these directories don't get + # uploaded to GCS. + expected_dirs = ["status"] + return all( + FileSystem(os.path.join(self._results_dir, cur_dir, expected_dir)).isdir() + for expected_dir in expected_dirs + ) + + # TODO(dongge): Deprecate this. + def _get_generated_targets(self, benchmark: str) -> list[str]: + """Gets the targets of benchmark |benchmark| from the OFG version 1 (single + prompt).""" + targets = [] + raw_targets_dir = os.path.join(self._results_dir, benchmark, "raw_targets") + # TODO(donggeliu): Make this consistent with agent output. 
+ if not os.path.exists(raw_targets_dir): + return [] + + for filename in sorted(FileSystem(raw_targets_dir).listdir()): + if os.path.splitext(filename)[1] in TARGET_EXTS: + targets.append(os.path.join(raw_targets_dir, filename)) + + return targets + + def _get_agent_generated_targets(self, benchmark: str) -> list[str]: + """Gets the targets of benchmark |benchmark| from the OFG version 2 (LLM + agent).""" + targets = [] + fuzz_targets_dir = os.path.join(self._results_dir, benchmark, "fuzz_targets") + for filename in sorted(FileSystem(fuzz_targets_dir).listdir()): + if os.path.splitext(filename)[1] in TARGET_EXTS: + targets.append(os.path.join(fuzz_targets_dir, filename)) + + return targets + + def _get_fixed_target(self, path: str) -> Target: + """Gets the fixed fuzz target from the benchmark's result |path|.""" + code = "" + fixer_prompt = "" + for name in FileSystem(path).listdir(): + if name.endswith(".txt"): + with FileSystem(os.path.join(path, name)).open() as f: + fixer_prompt = f.read() + + # Prepare prompt for being used in HTML. + fixer_prompt = self._prepare_prompt_for_html_text(fixer_prompt) + + if name.endswith(".rawoutput"): + with FileSystem(os.path.join(path, name)).open() as f: + code = f.read() + + return Target(code, fixer_prompt) + + def _sample_ids(self, target_paths: list[str]): + for target in target_paths: + yield os.path.splitext(os.path.basename(target))[0] + + def _create_benchmark( + self, + benchmark_id: str, + status: str, + result: run_one_experiment.AggregatedResult, + ) -> Benchmark: + project = "-".join(benchmark_id.split("-")[1:-1]) + function = benchmark_id.split("-")[-1] + signature = self._find_benchmark_signature(project, function) or benchmark_id + language = self._find_benchmark_language(project) + return Benchmark( + benchmark_id, status, result, signature, project, function, language + ) + + def _find_benchmark_signature(self, project: str, target_function: str) -> str: + """Finds the function signature by searching for its |benchmark_id|.""" + project_path = os.path.join(self._benchmark_dir, f"{project}.yaml") + if not FileSystem(project_path).isfile(): + return "" + + matched_prefix_signature = "" + with FileSystem(project_path).open() as project_yaml_file: + functions = yaml.safe_load(project_yaml_file).get("functions", []) + for function in functions: + function_name = function.get("name", "") + function_signature = function.get("signature", "") + + # Best match is a full match, but sometimes the result directory only + # has the first n characters of a long function name so a full match is + # not possible. + # To avoid returning early on a prefix match when there is a full match + # farther down the list, we only return the prefix match at the end. 
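The full-match-over-prefix-match rule spelled out in the comment above is easiest to see with a toy `functions` list (the names and signatures are invented):

```python
# Toy data for _find_benchmark_signature: the loop remembers a prefix match as
# a fallback but keeps scanning, so an exact-name entry later in the list wins.
functions = [
    {"name": "parse_header_from_buffer",
     "signature": "int parse_header_from_buffer(const uint8_t *, size_t)"},
    {"name": "parse_header",
     "signature": "int parse_header(const char *)"},
]
# With target_function == "parse_header":
#   entry 1: prefix match, recorded as matched_prefix_signature
#   entry 2: exact match, returned immediately -> "int parse_header(const char *)"
```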
+ if function_name.lower() == target_function.lower(): + return function_signature + if function_name.lower().startswith(target_function.lower()): + if matched_prefix_signature: + logging.warning( + "Multiple substring matches found when looking for function " + "name %s", + function_name, + ) + matched_prefix_signature = function_signature + + return matched_prefix_signature + + def _find_benchmark_language(self, project: str) -> str: + """Finds the programming language of the benchmark.""" + if not self._benchmark_dir: + return "" + + project_path = os.path.join(self._benchmark_dir, f"{project}.yaml") + if not FileSystem(project_path).isfile(): + return "" + try: - coverage_dict = json.load(f).get('project_summary', {}) - except ValueError: - # Skip if error - logging.debug('Failed to decode project_coverage_gain.json') - - # Update project summary with coverage gain information - project_summary_list = list(project_summary_dict.values()) - if coverage_dict: - for project in project_summary_list: - if project.name in coverage_dict: - project.coverage_gain = coverage_dict[project.name].get( - 'coverage_diff', 0.0) - project.coverage_relative_gain = coverage_dict[project.name].get( - 'coverage_relative_gain', 0.0) - project.coverage_ofg_total_new_covered_lines = coverage_dict[ - project.name].get('coverage_ofg_total_new_covered_lines', 0) - project.coverage_existing_total_covered_lines = coverage_dict[ - project.name].get('coverage_existing_total_covered_lines', 0) - project.coverage_existing_total_lines = coverage_dict[ - project.name].get('coverage_existing_total_lines', 0) - project.coverage_ofg_total_covered_lines = coverage_dict[ - project.name].get('coverage_ofg_total_covered_lines', 0) - - return project_summary_list - - def _prepare_prompt_for_html_text(self, raw_prompt_content: str) -> str: - """Converts a raw prompt file into presentable HTML text.""" - try: - structured_prompt = json.loads(raw_prompt_content) - if isinstance(structured_prompt, list) and structured_prompt: - html_presentable_content = '' - for elem in structured_prompt: - if isinstance(elem, dict) and 'content' in elem: - html_presentable_content += f'\n{elem["content"]}' - logging.debug('Converted structured prompt to raw text.') - return html_presentable_content - except json.decoder.JSONDecodeError: - logging.debug('Using raw prompt text.') - - # If execution goes here it the input was not a structured prompt but just - # raw text, which is then returned. - return raw_prompt_content - - def _is_valid_benchmark_dir(self, cur_dir: str) -> bool: - """Checks if |cur_dir| is a valid benchmark directory (e.g., no lost+found). - """ - # Check prefix. - if not cur_dir.startswith('output-'): - return False - - # Skip checking sub-directories in GCS. It's a lot of filesystem operations - # to go over the network. - if cur_dir.startswith('gs://'): - return True - - # Check sub-directories. - # TODO(donggeliu): Make this consistent with agent output. - # We used to expect 'fixed_targets' and 'raw_targets' here, but the agent - # workflow doesn't populate them. As a result, these directories don't get - # uploaded to GCS. - expected_dirs = ['status'] - return all( - FileSystem(os.path.join(self._results_dir, cur_dir, - expected_dir)).isdir() - for expected_dir in expected_dirs) - - # TODO(dongge): Deprecate this. 
- def _get_generated_targets(self, benchmark: str) -> list[str]: - """Gets the targets of benchmark |benchmark| from the OFG version 1 (single - prompt).""" - targets = [] - raw_targets_dir = os.path.join(self._results_dir, benchmark, 'raw_targets') - # TODO(donggeliu): Make this consistent with agent output. - if not os.path.exists(raw_targets_dir): - return [] - - for filename in sorted(FileSystem(raw_targets_dir).listdir()): - if os.path.splitext(filename)[1] in TARGET_EXTS: - targets.append(os.path.join(raw_targets_dir, filename)) - - return targets - - def _get_agent_generated_targets(self, benchmark: str) -> list[str]: - """Gets the targets of benchmark |benchmark| from the OFG version 2 (LLM - agent).""" - targets = [] - fuzz_targets_dir = os.path.join(self._results_dir, benchmark, - 'fuzz_targets') - for filename in sorted(FileSystem(fuzz_targets_dir).listdir()): - if os.path.splitext(filename)[1] in TARGET_EXTS: - targets.append(os.path.join(fuzz_targets_dir, filename)) - - return targets - - def _get_fixed_target(self, path: str) -> Target: - """Gets the fixed fuzz target from the benchmark's result |path|.""" - code = '' - fixer_prompt = '' - for name in FileSystem(path).listdir(): - if name.endswith('.txt'): - with FileSystem(os.path.join(path, name)).open() as f: - fixer_prompt = f.read() - - # Prepare prompt for being used in HTML. - fixer_prompt = self._prepare_prompt_for_html_text(fixer_prompt) - - if name.endswith('.rawoutput'): - with FileSystem(os.path.join(path, name)).open() as f: - code = f.read() - - return Target(code, fixer_prompt) - - def _sample_ids(self, target_paths: list[str]): - for target in target_paths: - yield os.path.splitext(os.path.basename(target))[0] - - def _create_benchmark( - self, benchmark_id: str, status: str, - result: run_one_experiment.AggregatedResult) -> Benchmark: - project = '-'.join(benchmark_id.split('-')[1:-1]) - function = benchmark_id.split('-')[-1] - signature = self._find_benchmark_signature(project, - function) or benchmark_id - language = self._find_benchmark_language(project) - return Benchmark(benchmark_id, status, result, signature, project, function, - language) - - def _find_benchmark_signature(self, project: str, - target_function: str) -> str: - """Finds the function signature by searching for its |benchmark_id|.""" - project_path = os.path.join(self._benchmark_dir, f'{project}.yaml') - if not FileSystem(project_path).isfile(): - return '' - - matched_prefix_signature = '' - with FileSystem(project_path).open() as project_yaml_file: - functions = yaml.safe_load(project_yaml_file).get('functions', []) - for function in functions: - function_name = function.get('name', '') - function_signature = function.get('signature', '') - - # Best match is a full match, but sometimes the result directory only - # has the first n characters of a long function name so a full match is - # not possible. - # To avoid returning early on a prefix match when there is a full match - # farther down the list, we only return the prefix match at the end. 
- if function_name.lower() == target_function.lower(): - return function_signature - if function_name.lower().startswith(target_function.lower()): - if matched_prefix_signature: - logging.warning( - 'Multiple substring matches found when looking for function ' - 'name %s', function_name) - matched_prefix_signature = function_signature - - return matched_prefix_signature - - def _find_benchmark_language(self, project: str) -> str: - """Finds the programming language of the benchmark.""" - if not self._benchmark_dir: - return '' - - project_path = os.path.join(self._benchmark_dir, f'{project}.yaml') - if not FileSystem(project_path).isfile(): - return '' - - try: - with FileSystem(project_path).open() as f: - benchmark_data = yaml.safe_load(f) - return benchmark_data.get('language', '') - except Exception as e: - logging.error('Failed to read benchmark file %s: %s', project_path, e) - return '' + with FileSystem(project_path).open() as f: + benchmark_data = yaml.safe_load(f) + return benchmark_data.get("language", "") + except Exception as e: + logging.error("Failed to read benchmark file %s: %s", project_path, e) + return "" def _parse_log_parts(log: str) -> list[LogPart]: - """Parse log into parts.""" - parts = [] - idx = 0 - next_marker = _CHAT_PROMPT_START_MARKER - - while idx < len(log): - match = next_marker.search(log, idx) - if not match: - parts.append(LogPart(content=log[idx:])) - break - - if match.start() > idx: - # Log content in between chat logs. - parts.append(LogPart(content=log[idx:match.start()])) - - # Read up to the start of the corresponding end marker. - end_idx = len(log) - - chat_prompt = False - chat_response = False - if next_marker == _CHAT_PROMPT_START_MARKER: - end = _CHAT_PROMPT_END_MARKER.search(log, match.end()) - chat_prompt = True - next_marker = _CHAT_RESPONSE_START_MARKER - else: - assert next_marker == _CHAT_RESPONSE_START_MARKER - end = _CHAT_RESPONSE_END_MARKER.search(log, match.end()) - chat_response = True - next_marker = _CHAT_PROMPT_START_MARKER - - if end: - end_idx = end.start() - # Skip past the end tag. - idx = end.end() - else: - # No corresponding end tag, just read till the end of the log. - end_idx = len(log) - idx = end_idx - - parts.append( - LogPart(chat_prompt=chat_prompt, + """Parse log into parts.""" + parts = [] + idx = 0 + next_marker = _CHAT_PROMPT_START_MARKER + + while idx < len(log): + match = next_marker.search(log, idx) + if not match: + parts.append(LogPart(content=log[idx:])) + break + + if match.start() > idx: + # Log content in between chat logs. + parts.append(LogPart(content=log[idx : match.start()])) + + # Read up to the start of the corresponding end marker. + end_idx = len(log) + + chat_prompt = False + chat_response = False + if next_marker == _CHAT_PROMPT_START_MARKER: + end = _CHAT_PROMPT_END_MARKER.search(log, match.end()) + chat_prompt = True + next_marker = _CHAT_RESPONSE_START_MARKER + else: + assert next_marker == _CHAT_RESPONSE_START_MARKER + end = _CHAT_RESPONSE_END_MARKER.search(log, match.end()) + chat_response = True + next_marker = _CHAT_PROMPT_START_MARKER + + if end: + end_idx = end.start() + # Skip past the end tag. + idx = end.end() + else: + # No corresponding end tag, just read till the end of the log. 
+ end_idx = len(log) + idx = end_idx + + parts.append( + LogPart( + chat_prompt=chat_prompt, chat_response=chat_response, - content=log[match.end():end_idx])) + content=log[match.end() : end_idx], + ) + ) - return parts + return parts diff --git a/report/compare_results.py b/report/compare_results.py index 52bc1e52c2..a50c08f47a 100644 --- a/report/compare_results.py +++ b/report/compare_results.py @@ -23,111 +23,160 @@ def extract_basename_from_filename(filename): - """ - Extract the basename from the filename. + """ + Extract the basename from the filename. - Args: - - filename (str): The name of the file. + Args: + - filename (str): The name of the file. - Returns: - - str: The extracted basename. - """ - return os.path.basename(os.path.splitext(filename)[0]) + Returns: + - str: The extracted basename. + """ + return os.path.basename(os.path.splitext(filename)[0]) def merge_tables(file1, file2): - """ - Merge, compare, and sort two CSV tables based on the benchmark name and build - rate differences. - - Args: - - file1 (str): Path to the first CSV file. - - file2 (str): Path to the second CSV file. - - Returns: - - DataFrame: The merged, compared, and sorted table. - """ - basename1 = extract_basename_from_filename(file1) - basename2 = extract_basename_from_filename(file2) - - df1 = pd.concat((chunk for chunk in pd.read_csv(file1, chunksize=5000))) - df2 = pd.concat((chunk for chunk in pd.read_csv(file2, chunksize=5000))) - - merged_df = df1.merge(df2, - on='Benchmark', - how='outer', - suffixes=(f'_{basename1}', f'_{basename2}')) - - # Fill NaN values with '-' and convert columns to string type - for col in merged_df.columns: - merged_df[col] = merged_df[col].fillna('-').astype(str) - - # Calculate build rate differences for sorting - merged_df['build_rate_diff'] = merged_df.apply( - lambda row: abs( - float(row[f'Build rate_{basename1}']) - float(row[ - f'Build rate_{basename2}'])) - if row[f'Build rate_{basename1}'] != '-' and row[f'Build rate_{basename2}' - ] != '-' else 0, - axis=1) - - # Sorting criteria - merged_df['sort_diff'] = merged_df['build_rate_diff'].apply(lambda x: 0 - if x == 0 else 1) - merged_df['sort_non_zero'] = merged_df.apply( - lambda row: 1 if (row[f'Build rate_{basename1}'] != '0' or row[ - f'Build rate_{basename2}'] != '0') and row[f'Build rate_{basename1}'] - != '-' and row[f'Build rate_{basename2}'] != '-' else 0, - axis=1) - merged_df['sort_zero'] = merged_df.apply( - lambda row: 1 if row[f'Build rate_{basename1}'] == '0' and row[ - f'Build rate_{basename2}'] == '0' else 0, - axis=1) - merged_df['sort_missing'] = merged_df.apply( - lambda row: 1 if row[f'Build rate_{basename1}'] == '-' or row[ - f'Build rate_{basename2}'] == '-' else 0, - axis=1) - merged_df['sort_basename'] = merged_df.apply( - lambda row: 1 if row[f'Build rate_{basename1}'] == '-' and row[ - f'Build rate_{basename2}'] != '-' else 0, - axis=1) - - merged_df.sort_values(by=[ - 'sort_diff', 'build_rate_diff', 'sort_non_zero', 'sort_zero', - 'sort_missing', 'sort_basename' - ], - ascending=[False, False, False, False, True, True], - inplace=True) - merged_df.drop(columns=[ - 'sort_diff', 'sort_non_zero', 'sort_zero', 'sort_missing', - 'sort_basename', 'build_rate_diff' - ], - inplace=True) - - columns_order = [ - 'Benchmark', 'Status_' + basename1, 'Status_' + basename2, - 'Build rate_' + basename1, 'Build rate_' + basename2, - 'Crash rate_' + basename1, 'Crash rate_' + basename2, - 'Coverage_' + basename1, 'Coverage_' + basename2, - 'Line coverage diff_' + basename1, 'Line coverage diff_' + 
basename2 - ] - merged_df = merged_df[columns_order] - - return merged_df - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description=( - 'Merge, compare, and sort two CSV tables based on the benchmark name and ' - 'build rate differences.')) - parser.add_argument('file1', type=str, help='Path to the first CSV file.') - parser.add_argument('file2', type=str, help='Path to the second CSV file.') - - args = parser.parse_args() - - output_df = merge_tables(args.file1, args.file2) - - input_basename2 = extract_basename_from_filename(args.file2) - output_filename = f'{input_basename2}_merged.csv' - output_df.to_csv(output_filename, index=False) - print(f'Output saved to {output_filename}') + """ + Merge, compare, and sort two CSV tables based on the benchmark name and build + rate differences. + + Args: + - file1 (str): Path to the first CSV file. + - file2 (str): Path to the second CSV file. + + Returns: + - DataFrame: The merged, compared, and sorted table. + """ + basename1 = extract_basename_from_filename(file1) + basename2 = extract_basename_from_filename(file2) + + df1 = pd.concat((chunk for chunk in pd.read_csv(file1, chunksize=5000))) + df2 = pd.concat((chunk for chunk in pd.read_csv(file2, chunksize=5000))) + + merged_df = df1.merge( + df2, on="Benchmark", how="outer", suffixes=(f"_{basename1}", f"_{basename2}") + ) + + # Fill NaN values with '-' and convert columns to string type + for col in merged_df.columns: + merged_df[col] = merged_df[col].fillna("-").astype(str) + + # Calculate build rate differences for sorting + merged_df["build_rate_diff"] = merged_df.apply( + lambda row: ( + abs( + float(row[f"Build rate_{basename1}"]) + - float(row[f"Build rate_{basename2}"]) + ) + if row[f"Build rate_{basename1}"] != "-" + and row[f"Build rate_{basename2}"] != "-" + else 0 + ), + axis=1, + ) + + # Sorting criteria + merged_df["sort_diff"] = merged_df["build_rate_diff"].apply( + lambda x: 0 if x == 0 else 1 + ) + merged_df["sort_non_zero"] = merged_df.apply( + lambda row: ( + 1 + if ( + row[f"Build rate_{basename1}"] != "0" + or row[f"Build rate_{basename2}"] != "0" + ) + and row[f"Build rate_{basename1}"] != "-" + and row[f"Build rate_{basename2}"] != "-" + else 0 + ), + axis=1, + ) + merged_df["sort_zero"] = merged_df.apply( + lambda row: ( + 1 + if row[f"Build rate_{basename1}"] == "0" + and row[f"Build rate_{basename2}"] == "0" + else 0 + ), + axis=1, + ) + merged_df["sort_missing"] = merged_df.apply( + lambda row: ( + 1 + if row[f"Build rate_{basename1}"] == "-" + or row[f"Build rate_{basename2}"] == "-" + else 0 + ), + axis=1, + ) + merged_df["sort_basename"] = merged_df.apply( + lambda row: ( + 1 + if row[f"Build rate_{basename1}"] == "-" + and row[f"Build rate_{basename2}"] != "-" + else 0 + ), + axis=1, + ) + + merged_df.sort_values( + by=[ + "sort_diff", + "build_rate_diff", + "sort_non_zero", + "sort_zero", + "sort_missing", + "sort_basename", + ], + ascending=[False, False, False, False, True, True], + inplace=True, + ) + merged_df.drop( + columns=[ + "sort_diff", + "sort_non_zero", + "sort_zero", + "sort_missing", + "sort_basename", + "build_rate_diff", + ], + inplace=True, + ) + + columns_order = [ + "Benchmark", + "Status_" + basename1, + "Status_" + basename2, + "Build rate_" + basename1, + "Build rate_" + basename2, + "Crash rate_" + basename1, + "Crash rate_" + basename2, + "Coverage_" + basename1, + "Coverage_" + basename2, + "Line coverage diff_" + basename1, + "Line coverage diff_" + basename2, + ] + merged_df = merged_df[columns_order] + + return 
merged_df + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=( + "Merge, compare, and sort two CSV tables based on the benchmark name and " + "build rate differences." + ) + ) + parser.add_argument("file1", type=str, help="Path to the first CSV file.") + parser.add_argument("file2", type=str, help="Path to the second CSV file.") + + args = parser.parse_args() + + output_df = merge_tables(args.file1, args.file2) + + input_basename2 = extract_basename_from_filename(args.file2) + output_filename = f"{input_basename2}_merged.csv" + output_df.to_csv(output_filename, index=False) + print(f"Output saved to {output_filename}") diff --git a/report/docker_run.py b/report/docker_run.py index 15b19bd8a6..dc3f37f2d4 100755 --- a/report/docker_run.py +++ b/report/docker_run.py @@ -26,438 +26,538 @@ # Configure logging to display all messages at or above INFO level logging.basicConfig(level=logging.INFO) -BENCHMARK_SET = 'comparison' -FREQUENCY_LABEL = 'daily' +BENCHMARK_SET = "comparison" +FREQUENCY_LABEL = "daily" RUN_TIMEOUT = 300 -SUB_DIR = 'default' -MODEL = 'vertex_ai_gemini-1-5' +SUB_DIR = "default" +MODEL = "vertex_ai_gemini-1-5" DELAY = 0 NUM_SAMPLES = 10 LLM_FIX_LIMIT = 5 MAX_ROUND = 100 -DATA_DIR = '/experiment/data-dir/' +DATA_DIR = "/experiment/data-dir/" def _parse_args(cmd) -> argparse.Namespace: - """Parses the command line arguments.""" - parser = argparse.ArgumentParser(description='Run experiments') - parser.add_argument( - '-b', - '--benchmark-set', - type=str, - default=BENCHMARK_SET, - help=f'Experiment benchmark set, default: {BENCHMARK_SET}.') - parser.add_argument( - '-l', - '--frequency-label', - type=str, - default=FREQUENCY_LABEL, - help=(f'Used as part of Cloud Build tags and GCS report directory, ' - f'default: {FREQUENCY_LABEL}.')) - parser.add_argument( - '-to', - '--run-timeout', - type=int, - default=RUN_TIMEOUT, - help=f'Fuzzing timeout in seconds, default: {RUN_TIMEOUT} seconds.') - parser.add_argument( - '-sd', - '--sub-dir', - type=str, - default=SUB_DIR, - help= - f'The subdirectory for the generated report in GCS, default: {SUB_DIR}.') - parser.add_argument('-m', - '--model', - type=str, - default=MODEL, - help=f'Large Language Model name, default: {MODEL}.') - parser.add_argument( - '-d', - '--delay', - type=int, - default=DELAY, - help=f'Delay each benchmark experiment by N seconds, default: {DELAY}.') - parser.add_argument( - '-i', - '--local-introspector', - type=str, - default="false", - help= - 'If set to "true" will use a local version of fuzz introspector\'s webapp' - ) - parser.add_argument( - '-ns', - '--num-samples', - type=int, - default=NUM_SAMPLES, - help=f'The number of samples to request from LLM, default: {NUM_SAMPLES}') - parser.add_argument( - '-nf', - '--llm-fix-limit', - type=int, - default=LLM_FIX_LIMIT, - help=f'The number of fixes to request from LLM, default: {LLM_FIX_LIMIT}') - parser.add_argument( - '-vt', - '--vary-temperature', - type=str, - default="true", - help= - 'Use different temperatures for each sample. Set to "false" to disable.') - parser.add_argument( - '-ag', - '--agent', - type=str, - default="false", - help='Enables agent enhancement. Set to "true" to enable.') - parser.add_argument('-mr', - '--max-round', - type=int, - default=MAX_ROUND, - help=f'Max trial round for agents, default: {MAX_ROUND}.') - parser.add_argument( - '-rd', - '--redirect-outs', - type=str, - default="false", - help= - 'Redirects experiments stdout/stderr to file. 
Set to "true" to enable.') - - args, additional_args = parser.parse_known_args(cmd) - - # Arguments after the first element ("--") separator. - args.additional_args = additional_args[1:] - - # Parse boolean arguments - args.local_introspector = args.local_introspector.lower() == "true" - args.vary_temperature = args.vary_temperature.lower() == "true" - args.agent = args.agent.lower() == "true" - args.redirect_outs = args.redirect_outs.lower() == "true" - - return args + """Parses the command line arguments.""" + parser = argparse.ArgumentParser(description="Run experiments") + parser.add_argument( + "-b", + "--benchmark-set", + type=str, + default=BENCHMARK_SET, + help=f"Experiment benchmark set, default: {BENCHMARK_SET}.", + ) + parser.add_argument( + "-l", + "--frequency-label", + type=str, + default=FREQUENCY_LABEL, + help=( + f"Used as part of Cloud Build tags and GCS report directory, " + f"default: {FREQUENCY_LABEL}." + ), + ) + parser.add_argument( + "-to", + "--run-timeout", + type=int, + default=RUN_TIMEOUT, + help=f"Fuzzing timeout in seconds, default: {RUN_TIMEOUT} seconds.", + ) + parser.add_argument( + "-sd", + "--sub-dir", + type=str, + default=SUB_DIR, + help=f"The subdirectory for the generated report in GCS, default: {SUB_DIR}.", + ) + parser.add_argument( + "-m", + "--model", + type=str, + default=MODEL, + help=f"Large Language Model name, default: {MODEL}.", + ) + parser.add_argument( + "-d", + "--delay", + type=int, + default=DELAY, + help=f"Delay each benchmark experiment by N seconds, default: {DELAY}.", + ) + parser.add_argument( + "-i", + "--local-introspector", + type=str, + default="false", + help='If set to "true" will use a local version of fuzz introspector\'s webapp', + ) + parser.add_argument( + "-ns", + "--num-samples", + type=int, + default=NUM_SAMPLES, + help=f"The number of samples to request from LLM, default: {NUM_SAMPLES}", + ) + parser.add_argument( + "-nf", + "--llm-fix-limit", + type=int, + default=LLM_FIX_LIMIT, + help=f"The number of fixes to request from LLM, default: {LLM_FIX_LIMIT}", + ) + parser.add_argument( + "-vt", + "--vary-temperature", + type=str, + default="true", + help='Use different temperatures for each sample. Set to "false" to disable.', + ) + parser.add_argument( + "-ag", + "--agent", + type=str, + default="false", + help='Enables agent enhancement. Set to "true" to enable.', + ) + parser.add_argument( + "-mr", + "--max-round", + type=int, + default=MAX_ROUND, + help=f"Max trial round for agents, default: {MAX_ROUND}.", + ) + parser.add_argument( + "-rd", + "--redirect-outs", + type=str, + default="false", + help='Redirects experiments stdout/stderr to file. Set to "true" to enable.', + ) + + args, additional_args = parser.parse_known_args(cmd) + + # Arguments after the first element ("--") separator. 
+ args.additional_args = additional_args[1:] + + # Parse boolean arguments + args.local_introspector = args.local_introspector.lower() == "true" + args.vary_temperature = args.vary_temperature.lower() == "true" + args.agent = args.agent.lower() == "true" + args.redirect_outs = args.redirect_outs.lower() == "true" + + return args def _run_command(command: list[str], shell=False): - """Runs a command and return its exit code.""" - process = subprocess.run(command, shell=shell, check=False) - return process.returncode + """Runs a command and return its exit code.""" + process = subprocess.run(command, shell=shell, check=False) + return process.returncode def _authorize_gcloud(): - """Authorizes to gcloud""" - # When running the docker container locally we need to activate the service - # account from the env variable. - # When running on GCP this step is unnecessary. - google_creds = os.environ.get('GOOGLE_APPLICATION_CREDENTIALS', '') - if google_creds: - logging.info("GOOGLE APPLICATION CREDENTIALS set: %s.", google_creds) - _run_command([ - 'gcloud', 'auth', 'activate-service-account', - 'LLM-EVAL@oss-fuzz.iam.gserviceaccount.com', '--key-file', google_creds - ]) - else: - # TODO: Set GOOGLE_APPLICATION_CREDENTIALS and ensure cloud build uses it. - logging.info("GOOGLE APPLICATION CREDENTIALS is not set.") + """Authorizes to gcloud""" + # When running the docker container locally we need to activate the service + # account from the env variable. + # When running on GCP this step is unnecessary. + google_creds = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", "") + if google_creds: + logging.info("GOOGLE APPLICATION CREDENTIALS set: %s.", google_creds) + _run_command( + [ + "gcloud", + "auth", + "activate-service-account", + "LLM-EVAL@oss-fuzz.iam.gserviceaccount.com", + "--key-file", + google_creds, + ] + ) + else: + # TODO: Set GOOGLE_APPLICATION_CREDENTIALS and ensure cloud build uses it. + logging.info("GOOGLE APPLICATION CREDENTIALS is not set.") def _log_common_args(args): - """Prints args useful for logging""" - logging.info("Benchmark set is %s.", args.benchmark_set) - logging.info("Frequency label is %s.", args.frequency_label) - logging.info("Run timeout is %s.", args.run_timeout) - logging.info( - "Sub-directory is %s. Please consider using sub-directory to classify" - " your experiment.", args.sub_dir) - logging.info("LLM is %s.", args.model) - logging.info("DELAY is %s.", args.delay) + """Prints args useful for logging""" + logging.info("Benchmark set is %s.", args.benchmark_set) + logging.info("Frequency label is %s.", args.frequency_label) + logging.info("Run timeout is %s.", args.run_timeout) + logging.info( + "Sub-directory is %s. 
Please consider using sub-directory to classify" + " your experiment.", + args.sub_dir, + ) + logging.info("LLM is %s.", args.model) + logging.info("DELAY is %s.", args.delay) def main(cmd=None): - """Main entrypoint""" - if os.path.isfile('/experiment/data-dir.zip'): - subprocess.check_call( - 'apt-get install -y zip && zip -s0 data-dir.zip --out newd.zip && unzip newd.zip && rm ./data-dir.z*', - shell=True, - cwd='/experiment') - if os.path.isdir(DATA_DIR): - run_on_data_from_scratch(cmd) - else: - run_standard(cmd) + """Main entrypoint""" + if os.path.isfile("/experiment/data-dir.zip"): + subprocess.check_call( + "apt-get install -y zip && zip -s0 data-dir.zip --out newd.zip && unzip newd.zip && rm ./data-dir.z*", + shell=True, + cwd="/experiment", + ) + if os.path.isdir(DATA_DIR): + run_on_data_from_scratch(cmd) + else: + run_standard(cmd) def run_on_data_from_scratch(cmd=None): - """Creates experiment for projects that are not in OSS-Fuzz upstream""" - args = _parse_args(cmd) - - # Uses python3 by default and /venv/bin/python3 for Docker containers. - python_path = "/venv/bin/python3" if os.path.exists( - "/venv/bin/python3") else "python3" - os.environ["PYTHON"] = python_path - - _authorize_gcloud() - _log_common_args(args) - - # Launch starter, which set ups a Fuzz Introspector instance, which - # will be used for creating benchmarks and extract context. - logging.info('Running starter script') - subprocess.check_call('/experiment/report/custom_oss_fuzz_fi_starter.sh', - shell=True) - - date = datetime.datetime.now().strftime('%Y-%m-%d') - - # Experiment name is used to label the Cloud Builds and as part of the - # GCS directory that build logs are stored in. - # - # Example directory: 2023-12-02-daily-comparison - experiment_name = f"{date}-{args.frequency_label}-{args.benchmark_set}" - - # Report directory uses the same name as experiment. - # See upload_report.sh on how this is used. - gcs_report_dir = f"{args.sub_dir}/{experiment_name}" - - # Trends report use a similarly named path. - gcs_trend_report_path = f"{args.sub_dir}/{experiment_name}.json" - - local_results_dir = 'results' - - # Generate a report and upload it to GCS - report_process = subprocess.Popen([ - "bash", "report/upload_report.sh", local_results_dir, gcs_report_dir, - args.benchmark_set, args.model - ]) - - # Launch run_all_experiments.py - # some notes: - # - we will generate benchmarks using the local FI running - # - we will use the oss-fuzz project of our workdir, which is - # the only one that has the projets. 
- environ = os.environ.copy() - - # We need to make sure that we use our version of OSS-Fuzz - environ['OSS_FUZZ_DATA_DIR'] = os.path.join(DATA_DIR, 'oss-fuzz2') - - # Get project names to analyse - project_in_oss_fuzz = [] - for project_name in os.listdir( - os.path.join(DATA_DIR, 'oss-fuzz2', 'build', 'out')): - project_path = os.path.join(DATA_DIR, 'oss-fuzz2', 'build', 'out', - project_name) - if not os.path.isdir(project_path): - continue - project_in_oss_fuzz.append(project_name) - project_names = ','.join(project_in_oss_fuzz) - - introspector_endpoint = "http://127.0.0.1:8080/api" - - cmd = [python_path, 'run_all_experiments.py'] - cmd.append('-g') - cmd.append( - 'far-reach-low-coverage,low-cov-with-fuzz-keyword,easy-params-far-reach') - cmd.append('-gp') - cmd.append(project_names) - cmd.append('-gm') - cmd.append(str(8)) - cmd.append('-e') - cmd.append(introspector_endpoint) - cmd.append('-mr') - cmd.append(str(args.max_round)) - - vary_temperature = [0.5, 0.6, 0.7, 0.8, 0.9] if args.vary_temperature else [] - cmd += [ - "--run-timeout", - str(args.run_timeout), "--cloud-experiment-name", experiment_name, - "--cloud-experiment-bucket", "oss-fuzz-gcb-experiment-run-logs", - "--template-directory", "prompts/template_xml", "--work-dir", - local_results_dir, "--num-samples", - str(args.num_samples), "--delay", - str(args.delay), "--context", "--temperature-list", - *[str(temp) for temp in vary_temperature], "--model", args.model - ] - if args.agent: - cmd.append("--agent") - - # Run the experiment and redirect to file if indicated. - if args.redirect_outs: - with open(f"{local_results_dir}/logs-from-run.txt", "w") as outfile: - process = subprocess.run(cmd, - stdout=outfile, - stderr=outfile, - env=environ, - check=False) - ret_val = process.returncode - else: - process = subprocess.run(cmd, env=environ, check=False) - ret_val = process.returncode - - os.environ["ret_val"] = str(ret_val) - - with open("/experiment_ended", "w"): - pass - - logging.info("Shutting down introspector") - try: - subprocess.run(["curl", "--silent", "http://localhost:8080/api/shutdown"], - check=False, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL) - except Exception: - pass - - # Wait for the report process to finish uploading. - report_process.wait() - - trends_cmd = [ - python_path, "-m", "report.trends_report.upload_summary", "--results-dir", - local_results_dir, "--output-path", - f"gs://oss-fuzz-gcb-experiment-run-logs/trend-reports/" - f"{gcs_trend_report_path}", "--name", experiment_name, "--date", date, - "--url", f"https://llm-exp.oss-fuzz.com/Result-reports/{gcs_report_dir}", - "--run-timeout", - str(args.run_timeout), "--num-samples", - str(args.num_samples), "--llm-fix-limit", - str(args.llm_fix_limit), "--model", args.model, "--tags", - args.frequency_label, args.sub_dir, "--commit-hash", - subprocess.check_output(["git", "rev-parse", - "HEAD"]).decode().strip(), "--commit-date", - subprocess.check_output(["git", "show", "--no-patch", "--format=%cs" - ]).decode().strip(), "--git-branch", - subprocess.check_output(["git", "branch", "--show"]).decode().strip() - ] - - subprocess.run(trends_cmd, check=False) - - # Exit with the return value of `./run_all_experiments`. - return ret_val - + """Creates experiment for projects that are not in OSS-Fuzz upstream""" + args = _parse_args(cmd) -def run_standard(cmd=None): - """The main function.""" - args = _parse_args(cmd) + # Uses python3 by default and /venv/bin/python3 for Docker containers. 
+ python_path = ( + "/venv/bin/python3" if os.path.exists("/venv/bin/python3") else "python3" + ) + os.environ["PYTHON"] = python_path - # Uses python3 by default and /venv/bin/python3 for Docker containers. - python_path = "/venv/bin/python3" if os.path.exists( - "/venv/bin/python3") else "python3" - os.environ["PYTHON"] = python_path + _authorize_gcloud() + _log_common_args(args) - _authorize_gcloud() - _log_common_args(args) + # Launch starter, which set ups a Fuzz Introspector instance, which + # will be used for creating benchmarks and extract context. + logging.info("Running starter script") + subprocess.check_call( + "/experiment/report/custom_oss_fuzz_fi_starter.sh", shell=True + ) + + date = datetime.datetime.now().strftime("%Y-%m-%d") + + # Experiment name is used to label the Cloud Builds and as part of the + # GCS directory that build logs are stored in. + # + # Example directory: 2023-12-02-daily-comparison + experiment_name = f"{date}-{args.frequency_label}-{args.benchmark_set}" + + # Report directory uses the same name as experiment. + # See upload_report.sh on how this is used. + gcs_report_dir = f"{args.sub_dir}/{experiment_name}" + + # Trends report use a similarly named path. + gcs_trend_report_path = f"{args.sub_dir}/{experiment_name}.json" + + local_results_dir = "results" + + # Generate a report and upload it to GCS + report_process = subprocess.Popen( + [ + "bash", + "report/upload_report.sh", + local_results_dir, + gcs_report_dir, + args.benchmark_set, + args.model, + ] + ) + + # Launch run_all_experiments.py + # some notes: + # - we will generate benchmarks using the local FI running + # - we will use the oss-fuzz project of our workdir, which is + # the only one that has the projets. + environ = os.environ.copy() + + # We need to make sure that we use our version of OSS-Fuzz + environ["OSS_FUZZ_DATA_DIR"] = os.path.join(DATA_DIR, "oss-fuzz2") + + # Get project names to analyse + project_in_oss_fuzz = [] + for project_name in os.listdir(os.path.join(DATA_DIR, "oss-fuzz2", "build", "out")): + project_path = os.path.join(DATA_DIR, "oss-fuzz2", "build", "out", project_name) + if not os.path.isdir(project_path): + continue + project_in_oss_fuzz.append(project_name) + project_names = ",".join(project_in_oss_fuzz) - if args.local_introspector: - os.environ["BENCHMARK_SET"] = args.benchmark_set introspector_endpoint = "http://127.0.0.1:8080/api" - logging.info("LOCAL_INTROSPECTOR is enabled: %s", introspector_endpoint) - _run_command(['bash', 'report/launch_local_introspector.sh'], shell=True) - else: - introspector_endpoint = "https://introspector.oss-fuzz.com/api" - logging.info("LOCAL_INTROSPECTOR was not specified. Defaulting to %s.", - introspector_endpoint) - - logging.info("NUM_SAMPLES is %s.", args.num_samples) - - if args.llm_fix_limit: - os.environ["LLM_FIX_LIMIT"] = str(args.llm_fix_limit) - logging.info("LLM_FIX_LIMIT is set to %s.", args.llm_fix_limit) - - vary_temperature = [0.5, 0.6, 0.7, 0.8, 0.9] if args.vary_temperature else [] - - date = datetime.datetime.now().strftime('%Y-%m-%d') - local_results_dir = 'results' - - # Experiment name is used to label the Cloud Builds and as part of the - # GCS directory that build logs are stored in. - # - # Example directory: 2023-12-02-daily-comparison - experiment_name = f"{date}-{args.frequency_label}-{args.benchmark_set}" - - # Report directory uses the same name as experiment. - # See upload_report.sh on how this is used. 
- gcs_report_dir = f"{args.sub_dir}/{experiment_name}" - - # Trends report use a similarly named path. - gcs_trend_report_path = f"{args.sub_dir}/{experiment_name}.json" - - # Generate a report and upload it to GCS - report_process = subprocess.Popen([ - "bash", "report/upload_report.sh", local_results_dir, gcs_report_dir, - args.benchmark_set, args.model - ]) - - # Prepare the command to run experiments - run_cmd = [ - python_path, "run_all_experiments.py", "--benchmarks-directory", - f"benchmark-sets/{args.benchmark_set}", "--run-timeout", - str(args.run_timeout), "--cloud-experiment-name", experiment_name, - "--cloud-experiment-bucket", "oss-fuzz-gcb-experiment-run-logs", - "--template-directory", "prompts/template_xml", "--work-dir", - local_results_dir, "--num-samples", - str(args.num_samples), "--delay", - str(args.delay), "--context", "--introspector-endpoint", - introspector_endpoint, "--temperature-list", - *[str(temp) for temp in vary_temperature], "--model", args.model, - "--max-round", - str(args.max_round) - ] - - if args.agent: - run_cmd.append("--agent") - - if args.additional_args: - run_cmd.extend(args.additional_args) - - # Run the experiment and redirect to file if indicated. - if args.redirect_outs: - with open(f"{local_results_dir}/logs-from-run.txt", "w") as outfile: - process = subprocess.run(run_cmd, - stdout=outfile, - stderr=outfile, - check=False) - ret_val = process.returncode - else: - process = subprocess.run(run_cmd, check=False) - ret_val = process.returncode - - os.environ["ret_val"] = str(ret_val) - - with open("/experiment_ended", "w"): - pass - - if args.local_introspector: + + cmd = [python_path, "run_all_experiments.py"] + cmd.append("-g") + cmd.append("far-reach-low-coverage,low-cov-with-fuzz-keyword,easy-params-far-reach") + cmd.append("-gp") + cmd.append(project_names) + cmd.append("-gm") + cmd.append(str(8)) + cmd.append("-e") + cmd.append(introspector_endpoint) + cmd.append("-mr") + cmd.append(str(args.max_round)) + + vary_temperature = [0.5, 0.6, 0.7, 0.8, 0.9] if args.vary_temperature else [] + cmd += [ + "--run-timeout", + str(args.run_timeout), + "--cloud-experiment-name", + experiment_name, + "--cloud-experiment-bucket", + "oss-fuzz-gcb-experiment-run-logs", + "--template-directory", + "prompts/template_xml", + "--work-dir", + local_results_dir, + "--num-samples", + str(args.num_samples), + "--delay", + str(args.delay), + "--context", + "--temperature-list", + *[str(temp) for temp in vary_temperature], + "--model", + args.model, + ] + if args.agent: + cmd.append("--agent") + + # Run the experiment and redirect to file if indicated. + if args.redirect_outs: + with open(f"{local_results_dir}/logs-from-run.txt", "w") as outfile: + process = subprocess.run( + cmd, stdout=outfile, stderr=outfile, env=environ, check=False + ) + ret_val = process.returncode + else: + process = subprocess.run(cmd, env=environ, check=False) + ret_val = process.returncode + + os.environ["ret_val"] = str(ret_val) + + with open("/experiment_ended", "w"): + pass + logging.info("Shutting down introspector") try: - subprocess.run(["curl", "--silent", "http://localhost:8080/api/shutdown"], - check=False, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL) + subprocess.run( + ["curl", "--silent", "http://localhost:8080/api/shutdown"], + check=False, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) except Exception: - pass - - # Wait for the report process to finish uploading. 
- report_process.wait() - - trends_cmd = [ - python_path, "-m", "report.trends_report.upload_summary", "--results-dir", - local_results_dir, "--output-path", - f"gs://oss-fuzz-gcb-experiment-run-logs/trend-reports/" - f"{gcs_trend_report_path}", "--name", experiment_name, "--date", date, - "--url", f"https://llm-exp.oss-fuzz.com/Result-reports/{gcs_report_dir}", - "--benchmark-set", args.benchmark_set, "--run-timeout", - str(args.run_timeout), "--num-samples", - str(args.num_samples), "--llm-fix-limit", - str(args.llm_fix_limit), "--model", args.model, "--tags", - args.frequency_label, args.sub_dir, "--commit-hash", - subprocess.check_output(["git", "rev-parse", - "HEAD"]).decode().strip(), "--commit-date", - subprocess.check_output(["git", "show", "--no-patch", "--format=%cs" - ]).decode().strip(), "--git-branch", - subprocess.check_output(["git", "branch", "--show"]).decode().strip() - ] - - subprocess.run(trends_cmd, check=False) - - # Exit with the return value of `./run_all_experiments`. - return ret_val + pass + + # Wait for the report process to finish uploading. + report_process.wait() + + trends_cmd = [ + python_path, + "-m", + "report.trends_report.upload_summary", + "--results-dir", + local_results_dir, + "--output-path", + f"gs://oss-fuzz-gcb-experiment-run-logs/trend-reports/" + f"{gcs_trend_report_path}", + "--name", + experiment_name, + "--date", + date, + "--url", + f"https://llm-exp.oss-fuzz.com/Result-reports/{gcs_report_dir}", + "--run-timeout", + str(args.run_timeout), + "--num-samples", + str(args.num_samples), + "--llm-fix-limit", + str(args.llm_fix_limit), + "--model", + args.model, + "--tags", + args.frequency_label, + args.sub_dir, + "--commit-hash", + subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip(), + "--commit-date", + subprocess.check_output(["git", "show", "--no-patch", "--format=%cs"]) + .decode() + .strip(), + "--git-branch", + subprocess.check_output(["git", "branch", "--show"]).decode().strip(), + ] + + subprocess.run(trends_cmd, check=False) + + # Exit with the return value of `./run_all_experiments`. + return ret_val + + +def run_standard(cmd=None): + """The main function.""" + args = _parse_args(cmd) + + # Uses python3 by default and /venv/bin/python3 for Docker containers. + python_path = ( + "/venv/bin/python3" if os.path.exists("/venv/bin/python3") else "python3" + ) + os.environ["PYTHON"] = python_path + + _authorize_gcloud() + _log_common_args(args) + + if args.local_introspector: + os.environ["BENCHMARK_SET"] = args.benchmark_set + introspector_endpoint = "http://127.0.0.1:8080/api" + logging.info("LOCAL_INTROSPECTOR is enabled: %s", introspector_endpoint) + _run_command(["bash", "report/launch_local_introspector.sh"], shell=True) + else: + introspector_endpoint = "https://introspector.oss-fuzz.com/api" + logging.info( + "LOCAL_INTROSPECTOR was not specified. Defaulting to %s.", + introspector_endpoint, + ) + + logging.info("NUM_SAMPLES is %s.", args.num_samples) + + if args.llm_fix_limit: + os.environ["LLM_FIX_LIMIT"] = str(args.llm_fix_limit) + logging.info("LLM_FIX_LIMIT is set to %s.", args.llm_fix_limit) + + vary_temperature = [0.5, 0.6, 0.7, 0.8, 0.9] if args.vary_temperature else [] + + date = datetime.datetime.now().strftime("%Y-%m-%d") + local_results_dir = "results" + + # Experiment name is used to label the Cloud Builds and as part of the + # GCS directory that build logs are stored in. 
+ # + # Example directory: 2023-12-02-daily-comparison + experiment_name = f"{date}-{args.frequency_label}-{args.benchmark_set}" + + # Report directory uses the same name as experiment. + # See upload_report.sh on how this is used. + gcs_report_dir = f"{args.sub_dir}/{experiment_name}" + + # Trends report use a similarly named path. + gcs_trend_report_path = f"{args.sub_dir}/{experiment_name}.json" + + # Generate a report and upload it to GCS + report_process = subprocess.Popen( + [ + "bash", + "report/upload_report.sh", + local_results_dir, + gcs_report_dir, + args.benchmark_set, + args.model, + ] + ) + + # Prepare the command to run experiments + run_cmd = [ + python_path, + "run_all_experiments.py", + "--benchmarks-directory", + f"benchmark-sets/{args.benchmark_set}", + "--run-timeout", + str(args.run_timeout), + "--cloud-experiment-name", + experiment_name, + "--cloud-experiment-bucket", + "oss-fuzz-gcb-experiment-run-logs", + "--template-directory", + "prompts/template_xml", + "--work-dir", + local_results_dir, + "--num-samples", + str(args.num_samples), + "--delay", + str(args.delay), + "--context", + "--introspector-endpoint", + introspector_endpoint, + "--temperature-list", + *[str(temp) for temp in vary_temperature], + "--model", + args.model, + "--max-round", + str(args.max_round), + ] + + if args.agent: + run_cmd.append("--agent") + + if args.additional_args: + run_cmd.extend(args.additional_args) + + # Run the experiment and redirect to file if indicated. + if args.redirect_outs: + with open(f"{local_results_dir}/logs-from-run.txt", "w") as outfile: + process = subprocess.run( + run_cmd, stdout=outfile, stderr=outfile, check=False + ) + ret_val = process.returncode + else: + process = subprocess.run(run_cmd, check=False) + ret_val = process.returncode + + os.environ["ret_val"] = str(ret_val) + + with open("/experiment_ended", "w"): + pass + + if args.local_introspector: + logging.info("Shutting down introspector") + try: + subprocess.run( + ["curl", "--silent", "http://localhost:8080/api/shutdown"], + check=False, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + except Exception: + pass + + # Wait for the report process to finish uploading. + report_process.wait() + + trends_cmd = [ + python_path, + "-m", + "report.trends_report.upload_summary", + "--results-dir", + local_results_dir, + "--output-path", + f"gs://oss-fuzz-gcb-experiment-run-logs/trend-reports/" + f"{gcs_trend_report_path}", + "--name", + experiment_name, + "--date", + date, + "--url", + f"https://llm-exp.oss-fuzz.com/Result-reports/{gcs_report_dir}", + "--benchmark-set", + args.benchmark_set, + "--run-timeout", + str(args.run_timeout), + "--num-samples", + str(args.num_samples), + "--llm-fix-limit", + str(args.llm_fix_limit), + "--model", + args.model, + "--tags", + args.frequency_label, + args.sub_dir, + "--commit-hash", + subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip(), + "--commit-date", + subprocess.check_output(["git", "show", "--no-patch", "--format=%cs"]) + .decode() + .strip(), + "--git-branch", + subprocess.check_output(["git", "branch", "--show"]).decode().strip(), + ] + + subprocess.run(trends_cmd, check=False) + + # Exit with the return value of `./run_all_experiments`. 
+ return ret_val if __name__ == "__main__": - sys.exit(main()) + sys.exit(main()) diff --git a/report/trends_report/update_index.py b/report/trends_report/update_index.py index d98cf7e772..95fd7bb643 100644 --- a/report/trends_report/update_index.py +++ b/report/trends_report/update_index.py @@ -20,38 +20,39 @@ def trends_report_index(event, context): - """Read all the trends reports in GCS and write an index at the root.""" - # Don't trigger on changes to index.json or other top level files - if len(event['attributes']['objectId'].split('/')) < 3: - return '' - - index = {} - bucket = storage.Client().bucket('oss-fuzz-gcb-experiment-run-logs') - for b in bucket.list_blobs(prefix='trend-reports/'): - # Skip reading index.json or other top level files - if len(b.name.split('/')) < 3: - continue - - print(f'Reading {b.name}') - try: - # e.g. trend-reports/scheduled/2024-11-02-weekly-all.json -> scheduled - directory = b.name.split('/')[1] - report = json.loads(b.download_as_text()) - index[report['name']] = { - 'directory': directory, - 'name': report['name'], - 'url': report['url'], - 'date': report['date'], - 'benchmark_set': report['benchmark_set'], - 'llm_model': report['llm_model'], - 'tags': report['tags'], - } - except: - print('****************************', file=sys.stderr) - print(f'Issue when reading {b.name}', file=sys.stderr) - print('****************************', file=sys.stderr) - - bucket.blob('trend-reports/index.json').upload_from_string( - json.dumps(index), content_type='application/json') - - return '' + """Read all the trends reports in GCS and write an index at the root.""" + # Don't trigger on changes to index.json or other top level files + if len(event["attributes"]["objectId"].split("/")) < 3: + return "" + + index = {} + bucket = storage.Client().bucket("oss-fuzz-gcb-experiment-run-logs") + for b in bucket.list_blobs(prefix="trend-reports/"): + # Skip reading index.json or other top level files + if len(b.name.split("/")) < 3: + continue + + print(f"Reading {b.name}") + try: + # e.g. 
trend-reports/scheduled/2024-11-02-weekly-all.json -> scheduled + directory = b.name.split("/")[1] + report = json.loads(b.download_as_text()) + index[report["name"]] = { + "directory": directory, + "name": report["name"], + "url": report["url"], + "date": report["date"], + "benchmark_set": report["benchmark_set"], + "llm_model": report["llm_model"], + "tags": report["tags"], + } + except: + print("****************************", file=sys.stderr) + print(f"Issue when reading {b.name}", file=sys.stderr) + print("****************************", file=sys.stderr) + + bucket.blob("trend-reports/index.json").upload_from_string( + json.dumps(index), content_type="application/json" + ) + + return "" diff --git a/report/trends_report/update_web.py b/report/trends_report/update_web.py index c751ce6ed0..7438c3d9ba 100644 --- a/report/trends_report/update_web.py +++ b/report/trends_report/update_web.py @@ -21,30 +21,30 @@ from google.cloud import storage -REPO_ZIP_LINK = 'https://github.com/google/oss-fuzz-gen/archive/refs/heads/main.zip' -ZIP_DIR = 'oss-fuzz-gen-trends-report' +REPO_ZIP_LINK = "https://github.com/google/oss-fuzz-gen/archive/refs/heads/main.zip" +ZIP_DIR = "oss-fuzz-gen-trends-report" def trends_report_web(event, context): - """Update trends report web page files from GitHub.""" - bucket = storage.Client().bucket('oss-fuzz-gcb-experiment-run-logs') + """Update trends report web page files from GitHub.""" + bucket = storage.Client().bucket("oss-fuzz-gcb-experiment-run-logs") - with urllib.request.urlopen(REPO_ZIP_LINK) as response: - zip_contents = response.read() + with urllib.request.urlopen(REPO_ZIP_LINK) as response: + zip_contents = response.read() - with tempfile.TemporaryDirectory() as temp: - with zipfile.ZipFile(io.BytesIO(zip_contents)) as zip_file: - zip_file.extractall(temp) - for path in zip_file.namelist(): - parts = path.split('/report/trends_report_web/') + with tempfile.TemporaryDirectory() as temp: + with zipfile.ZipFile(io.BytesIO(zip_contents)) as zip_file: + zip_file.extractall(temp) + for path in zip_file.namelist(): + parts = path.split("/report/trends_report_web/") - # Upload files under report/trends_report_web/ - if len(parts) > 1 and parts[1] != '': - fname = parts[1] - print(f'uploading {path} to trend-reports/{fname}') - blob = bucket.blob(f'trend-reports/{fname}') - blob.upload_from_filename(os.path.join(temp, path)) + # Upload files under report/trends_report_web/ + if len(parts) > 1 and parts[1] != "": + fname = parts[1] + print(f"uploading {path} to trend-reports/{fname}") + blob = bucket.blob(f"trend-reports/{fname}") + blob.upload_from_filename(os.path.join(temp, path)) if __name__ == "__main__": - trends_report_web(None, None) + trends_report_web(None, None) diff --git a/report/trends_report/upload_summary.py b/report/trends_report/upload_summary.py index 9dd06d96e8..ce22dd1d0a 100644 --- a/report/trends_report/upload_summary.py +++ b/report/trends_report/upload_summary.py @@ -26,130 +26,143 @@ @dataclasses.dataclass class Summary: - """Summary of the experiment for trends report.""" - benchmarks: List[Dict[str, Any]] - accumulated_results: Dict[str, Any] - projects: List[Dict[str, Any]] + """Summary of the experiment for trends report.""" + + benchmarks: List[Dict[str, Any]] + accumulated_results: Dict[str, Any] + projects: List[Dict[str, Any]] def generate_summary(results_util: Results) -> Summary: - """Returns a summary object from the experiment results.""" - benchmarks = [] - benchmark_summaries = [] - for benchmark_id in 
results_util.list_benchmark_ids(): - results, targets = results_util.get_results(benchmark_id) - benchmark = results_util.match_benchmark(benchmark_id, results, targets) - benchmarks.append(benchmark) - benchmark_summaries.append({ - 'id': benchmark.id, - 'project': benchmark.project, - 'function': benchmark.function, - 'signature': benchmark.signature, - 'build_success_rate': benchmark.result.build_success_rate, - 'crash_rate': benchmark.result.crash_rate, - 'found_bug': benchmark.result.found_bug, - 'max_coverage': benchmark.result.max_coverage, - 'max_line_coverage_diff': benchmark.result.max_line_coverage_diff, - }) - - accumulated_results = dataclasses.asdict( - results_util.get_macro_insights(benchmarks)) - projects = list( - map(dataclasses.asdict, results_util.get_project_summary(benchmarks))) - return Summary(benchmark_summaries, accumulated_results, projects) + """Returns a summary object from the experiment results.""" + benchmarks = [] + benchmark_summaries = [] + for benchmark_id in results_util.list_benchmark_ids(): + results, targets = results_util.get_results(benchmark_id) + benchmark = results_util.match_benchmark(benchmark_id, results, targets) + benchmarks.append(benchmark) + benchmark_summaries.append( + { + "id": benchmark.id, + "project": benchmark.project, + "function": benchmark.function, + "signature": benchmark.signature, + "build_success_rate": benchmark.result.build_success_rate, + "crash_rate": benchmark.result.crash_rate, + "found_bug": benchmark.result.found_bug, + "max_coverage": benchmark.result.max_coverage, + "max_line_coverage_diff": benchmark.result.max_line_coverage_diff, + } + ) + + accumulated_results = dataclasses.asdict( + results_util.get_macro_insights(benchmarks) + ) + projects = list( + map(dataclasses.asdict, results_util.get_project_summary(benchmarks)) + ) + return Summary(benchmark_summaries, accumulated_results, projects) def _parse_arguments() -> argparse.Namespace: - """Parses command line args.""" - parser = argparse.ArgumentParser(description=( - 'Report generation tool reads raw experiment output files and ' - 'generates a summary json file used for trends report.')) - - parser.add_argument('--results-dir', - help='Directory with results from OSS-Fuzz-gen.', - required=True) - parser.add_argument( - '--output-path', - help='Full path to store the summary json for trends report.', - required=True) - parser.add_argument('--date', help='Date of the experiment.', required=True) - parser.add_argument('--name', - help='Name used for the benchmark results.', - required=True) - parser.add_argument('--url', - help='Name used for the benchmark results.', - required=True) - parser.add_argument('--benchmark-set', - help='Directory with benchmarks used for the experiment.', - required=True) - parser.add_argument('--run-timeout', - help='Timeout the experiment uses for each fuzz test.', - required=True, - type=int) - parser.add_argument( - '--num-samples', - help='Number of samples the experiment requests from the LLM.', - required=True, - type=int) - parser.add_argument( - '--llm-fix-limit', - help='How many times the experiment asks the LLM to fix broken tests.', - required=True, - type=int) - parser.add_argument('--model', - help='Model used for the experiment.', - required=True) - parser.add_argument('--commit-hash', - help='Commit hash of the currect git checkout.', - required=True) - parser.add_argument('--commit-date', - help='Commit date of the currect git checkout.', - required=True) - parser.add_argument('--git-branch', - help='Git branch 
of the currect checkout.', - required=True) - parser.add_argument('--tags', - help='Additional tags for this experiment.', - nargs="*", - type=str) - - return parser.parse_args() + """Parses command line args.""" + parser = argparse.ArgumentParser( + description=( + "Report generation tool reads raw experiment output files and " + "generates a summary json file used for trends report." + ) + ) + + parser.add_argument( + "--results-dir", help="Directory with results from OSS-Fuzz-gen.", required=True + ) + parser.add_argument( + "--output-path", + help="Full path to store the summary json for trends report.", + required=True, + ) + parser.add_argument("--date", help="Date of the experiment.", required=True) + parser.add_argument( + "--name", help="Name used for the benchmark results.", required=True + ) + parser.add_argument( + "--url", help="Name used for the benchmark results.", required=True + ) + parser.add_argument( + "--benchmark-set", + help="Directory with benchmarks used for the experiment.", + required=True, + ) + parser.add_argument( + "--run-timeout", + help="Timeout the experiment uses for each fuzz test.", + required=True, + type=int, + ) + parser.add_argument( + "--num-samples", + help="Number of samples the experiment requests from the LLM.", + required=True, + type=int, + ) + parser.add_argument( + "--llm-fix-limit", + help="How many times the experiment asks the LLM to fix broken tests.", + required=True, + type=int, + ) + parser.add_argument("--model", help="Model used for the experiment.", required=True) + parser.add_argument( + "--commit-hash", help="Commit hash of the currect git checkout.", required=True + ) + parser.add_argument( + "--commit-date", help="Commit date of the currect git checkout.", required=True + ) + parser.add_argument( + "--git-branch", help="Git branch of the currect checkout.", required=True + ) + parser.add_argument( + "--tags", help="Additional tags for this experiment.", nargs="*", type=str + ) + + return parser.parse_args() def main(): - args = _parse_arguments() - summary = dataclasses.asdict( - generate_summary( - Results(results_dir=args.results_dir, - benchmark_set=args.benchmark_set))) - tags = [args.model, args.benchmark_set] - if args.tags: - tags.extend(args.tags) - build_info = { - 'branch': args.git_branch, - 'commit_hash': args.commit_hash, - 'commit_date': args.commit_date, - } - summary_json = { - 'name': args.name, - 'date': args.date, - 'benchmark_set': args.benchmark_set, - 'llm_model': args.model, - 'url': args.url, - 'run_parameters': { - 'run_timeout': args.run_timeout, - 'num_samples': args.num_samples, - 'llm_fix_limit': args.llm_fix_limit, - }, - 'build_info': build_info, - 'tags': tags, - **summary, - } - - with FileSystem(args.output_path).open('w', encoding='utf-8') as f: - json.dump(summary_json, f) - - -if __name__ == '__main__': - logging.getLogger().setLevel(os.environ.get('LOGLEVEL', 'WARN').upper()) - main() + args = _parse_arguments() + summary = dataclasses.asdict( + generate_summary( + Results(results_dir=args.results_dir, benchmark_set=args.benchmark_set) + ) + ) + tags = [args.model, args.benchmark_set] + if args.tags: + tags.extend(args.tags) + build_info = { + "branch": args.git_branch, + "commit_hash": args.commit_hash, + "commit_date": args.commit_date, + } + summary_json = { + "name": args.name, + "date": args.date, + "benchmark_set": args.benchmark_set, + "llm_model": args.model, + "url": args.url, + "run_parameters": { + "run_timeout": args.run_timeout, + "num_samples": args.num_samples, + 
"llm_fix_limit": args.llm_fix_limit, + }, + "build_info": build_info, + "tags": tags, + **summary, + } + + with FileSystem(args.output_path).open("w", encoding="utf-8") as f: + json.dump(summary_json, f) + + +if __name__ == "__main__": + logging.getLogger().setLevel(os.environ.get("LOGLEVEL", "WARN").upper()) + main() diff --git a/report/web.py b/report/web.py index 7ab2173786..7bb5dcadce 100644 --- a/report/web.py +++ b/report/web.py @@ -29,311 +29,349 @@ import jinja2 -from report.common import (AccumulatedResult, Benchmark, FileSystem, Project, - Results, Sample, Target) +from report.common import ( + AccumulatedResult, + Benchmark, + FileSystem, + Project, + Results, + Sample, + Target, +) -LOCAL_HOST = '127.0.0.1' +LOCAL_HOST = "127.0.0.1" class JinjaEnv: - """JinjaEnv wraps the set up of a jinja2 environment.""" - - @staticmethod - def _urlencode_filter(s): - return urllib.parse.quote(s, safe='') - - @staticmethod - def _percent(num: float): - return f'{num*100:.2f}' - - @staticmethod - def _cov_report_link(link: str): - """Get URL to coverage report""" - if not link: - return '#' - - if 'gcb-experiment' not in link: - # In local rusn we don't overwrite the path - link_path = link - else: - path = link.removeprefix('gs://oss-fuzz-gcb-experiment-run-logs/') - link_path = f'https://llm-exp.oss-fuzz.com/{path}/report/linux/' - - # Check if this is a java benchmark, which will always have a period in - # the path, where C/C++ wont. - # TODO(David) refactor to have paths for links more controlled. - if '.' in link_path: - return link_path + 'index.html' - return link_path + 'report.html' - - @staticmethod - def _remove_trailing_empty_lines(code: str) -> str: - """Remove trailing empty lines from code.""" - if not code: - return "" - - try: - lines = code.splitlines() - while lines and not lines[-1].strip(): - lines.pop() - - return '\n'.join(lines) - except Exception as e: - logging.warning("Error in remove_trailing_empty_lines filter: %s", e) - return code # Return original code on error - - @staticmethod - def _splitlines(text: str) -> list: - """Split text into lines, similar to Python's splitlines().""" - if not text: - return [] - try: - return text.splitlines() - except Exception as e: - logging.warning("Error in splitlines filter: %s", e) - return [text] # Return original as single line on error - - def __init__(self, template_globals: Optional[Dict[str, Any]] = None): - self._env = jinja2.Environment( - loader=jinja2.FileSystemLoader("report/templates"), - autoescape=jinja2.select_autoescape()) - - self._env.filters['urlencode_filter'] = self._urlencode_filter - self._env.filters['percent'] = self._percent - self._env.filters['cov_report_link'] = self._cov_report_link - self._env.filters[ - 'remove_trailing_empty_lines'] = self._remove_trailing_empty_lines - self._env.filters['splitlines'] = self._splitlines - - if template_globals: - for key, val in template_globals.items(): - self._env.globals[key] = val - - def get_template_search_path(self) -> Optional[List[str]]: - """Returns the search path of the Jinja2 FileSystemLoader.""" - if isinstance(self._env.loader, jinja2.FileSystemLoader): - return self._env.loader.searchpath - return None - - def render(self, template_name: str, **kwargs): - """Render a template with variables provides through kwargs.""" - return self._env.get_template(template_name).render(**kwargs) + """JinjaEnv wraps the set up of a jinja2 environment.""" + + @staticmethod + def _urlencode_filter(s): + return urllib.parse.quote(s, safe="") + + @staticmethod + 
def _percent(num: float): + return f"{num*100:.2f}" + + @staticmethod + def _cov_report_link(link: str): + """Get URL to coverage report""" + if not link: + return "#" + + if "gcb-experiment" not in link: + # In local rusn we don't overwrite the path + link_path = link + else: + path = link.removeprefix("gs://oss-fuzz-gcb-experiment-run-logs/") + link_path = f"https://llm-exp.oss-fuzz.com/{path}/report/linux/" + + # Check if this is a java benchmark, which will always have a period in + # the path, where C/C++ wont. + # TODO(David) refactor to have paths for links more controlled. + if "." in link_path: + return link_path + "index.html" + return link_path + "report.html" + + @staticmethod + def _remove_trailing_empty_lines(code: str) -> str: + """Remove trailing empty lines from code.""" + if not code: + return "" + + try: + lines = code.splitlines() + while lines and not lines[-1].strip(): + lines.pop() + + return "\n".join(lines) + except Exception as e: + logging.warning("Error in remove_trailing_empty_lines filter: %s", e) + return code # Return original code on error + + @staticmethod + def _splitlines(text: str) -> list: + """Split text into lines, similar to Python's splitlines().""" + if not text: + return [] + try: + return text.splitlines() + except Exception as e: + logging.warning("Error in splitlines filter: %s", e) + return [text] # Return original as single line on error + + def __init__(self, template_globals: Optional[Dict[str, Any]] = None): + self._env = jinja2.Environment( + loader=jinja2.FileSystemLoader("report/templates"), + autoescape=jinja2.select_autoescape(), + ) + + self._env.filters["urlencode_filter"] = self._urlencode_filter + self._env.filters["percent"] = self._percent + self._env.filters["cov_report_link"] = self._cov_report_link + self._env.filters["remove_trailing_empty_lines"] = ( + self._remove_trailing_empty_lines + ) + self._env.filters["splitlines"] = self._splitlines + + if template_globals: + for key, val in template_globals.items(): + self._env.globals[key] = val + + def get_template_search_path(self) -> Optional[List[str]]: + """Returns the search path of the Jinja2 FileSystemLoader.""" + if isinstance(self._env.loader, jinja2.FileSystemLoader): + return self._env.loader.searchpath + return None + + def render(self, template_name: str, **kwargs): + """Render a template with variables provides through kwargs.""" + return self._env.get_template(template_name).render(**kwargs) class GenerateReport: - """ - GenerateReport helps generate an HTML report of experiment results. - - Args: - results: A Results object which takes care of reading results files and - processing them for the HTML templates. - output_dir: The directory to for the HTML report files. - jinja_env: A JinjaEnv object which provides the template render function. 
- """ - - def __init__(self, - results: Results, - jinja_env: JinjaEnv, - results_dir: str, - output_dir: str = 'results-report'): - self._results = results - self._output_dir = output_dir - self._jinja = jinja_env - self.results_dir = results_dir - - def read_timings(self): - with open(os.path.join(self.results_dir, 'report.json'), 'r') as f: - timings_dict = json.loads(f.read()) - return timings_dict - - def _copy_and_set_coverage_report(self, benchmark, sample): - """Prepares coverage reports in local runs.""" - coverage_path = os.path.join(self.results_dir, benchmark.id, - 'code-coverage-reports') - if not os.path.isdir(coverage_path): - return - - coverage_report = '' - for l in os.listdir(coverage_path): - if l.split('.')[0] == sample.id: - coverage_report = os.path.join(coverage_path, l) - - # On cloud runs there are two folders in code coverage reports, (report, - # textcov). If we have three files/dirs (linux, style.cssand textcov), then - # it's a local run. In that case copy over the code coverage reports so - # they are visible in the HTML page. - if coverage_report and os.path.isdir(coverage_report) and len( - os.listdir(coverage_report)) > 2: - # Copy coverage to reports out - dst = os.path.join(self._output_dir, 'sample', benchmark.id, 'coverage') - os.makedirs(dst, exist_ok=True) - dst = os.path.join(dst, sample.id) - - shutil.copytree(coverage_report, dst, dirs_exist_ok=True) - sample.result.coverage_report_path = \ - f'/sample/{benchmark.id}/coverage/{sample.id}/linux/' - - def _read_static_file(self, file_path_in_templates_subdir: str) -> str: - """Reads a static file from the templates directory.""" - - search_path = self._jinja.get_template_search_path() - - if not search_path: - logging.error( - 'Jinja FileSystemLoader\'s searchpath is empty, or loader is not ' - 'FileSystemLoader.') - return '' - templates_base_dir = search_path[0] - - full_file_path = os.path.join(templates_base_dir, - file_path_in_templates_subdir) - - try: - with open(full_file_path, 'r', encoding='utf-8') as f: - return f.read() - except FileNotFoundError: - logging.warning('Static file not found: %s', full_file_path) - return '' - except Exception as e: - logging.error('Error reading static file %s: %s', full_file_path, e) - return '' - - def generate(self): - """Generate and write every report file.""" - benchmarks = [] - samples_with_bugs = [] - for benchmark_id in self._results.list_benchmark_ids(): - results, targets = self._results.get_results(benchmark_id) - benchmark = self._results.match_benchmark(benchmark_id, results, targets) - benchmarks.append(benchmark) - samples = self._results.get_samples(results, targets) - prompt = self._results.get_prompt(benchmark.id) - - for sample in samples: - # If this is a local run then we need to set up coverage reports. 
- self._copy_and_set_coverage_report(benchmark, sample) - - self._write_benchmark_index(benchmark, samples, prompt) - self._write_benchmark_crash(benchmark, samples) - - for sample in samples: - if sample.result.crashes: - samples_with_bugs.append({'benchmark': benchmark, 'sample': sample}) - sample_targets = self._results.get_targets(benchmark.id, sample.id) - self._write_benchmark_sample(benchmark, sample, sample_targets) - - accumulated_results = self._results.get_macro_insights(benchmarks) - projects = self._results.get_project_summary(benchmarks) - coverage_language_gains = self._results.get_coverage_language_gains() - - time_results = self.read_timings() - - self._write_index_html(benchmarks, accumulated_results, time_results, - projects, samples_with_bugs, coverage_language_gains) - self._write_index_json(benchmarks) - - def _write(self, output_path: str, content: str): - """Utility write to filesystem function.""" - full_path = os.path.join(self._output_dir, output_path) - - parent_dir = os.path.dirname(full_path) - if not FileSystem(parent_dir).exists(): - FileSystem(parent_dir).makedirs() - - if not FileSystem(parent_dir).isdir(): - raise Exception( - f'Writing to {full_path} but {parent_dir} is not a directory!') - - with FileSystem(full_path).open('w', encoding='utf-8') as f: - f.write(content) - - def _write_index_html(self, benchmarks: List[Benchmark], - accumulated_results: AccumulatedResult, - time_results: dict[str, Any], projects: list[Project], - samples_with_bugs: list[dict[str, Any]], - coverage_language_gains: dict[str, Any]): - """Generate the report index.html and write to filesystem.""" - index_css_content = self._read_static_file('index/index.css') - index_js_content = self._read_static_file('index/index.js') - - rendered = self._jinja.render( - 'index/index.html', - benchmarks=benchmarks, - accumulated_results=accumulated_results, - time_results=time_results, - projects=projects, - samples_with_bugs=samples_with_bugs, - coverage_language_gains=coverage_language_gains, - index_css_content=index_css_content, - index_js_content=index_js_content) - self._write('index.html', rendered) - - def _write_index_json(self, benchmarks: List[Benchmark]): - """Generate the report index.json and write to filesystem.""" - rendered = self._jinja.render('index.json', benchmarks=benchmarks) - self._write('index.json', rendered) - - def _write_benchmark_index(self, benchmark: Benchmark, samples: List[Sample], - prompt: Optional[str]): - """Generate the benchmark index.html and write to filesystem.""" - benchmark_css_content = self._read_static_file('benchmark/benchmark.css') - benchmark_js_content = self._read_static_file('benchmark/benchmark.js') - - rendered = self._jinja.render('benchmark/benchmark.html', - benchmark=benchmark.id, - samples=samples, - prompt=prompt, - benchmark_css_content=benchmark_css_content, - benchmark_js_content=benchmark_js_content) - self._write(f'benchmark/{benchmark.id}/index.html', rendered) - - def _write_benchmark_crash(self, benchmark: Benchmark, samples: List[Sample]): - """Generate the benchmark crash.json and write to filesystem.""" - try: - rendered = self._jinja.render('crash.json', - benchmark=benchmark.signature, - samples=samples, - get_benchmark_final_target_code=partial( - self._results.get_final_target_code, - benchmark.id)) - self._write(f'benchmark/{benchmark.id}/crash.json', rendered) - except Exception as e: - logging.error('Failed to write benchmark/%s/crash.json:\n%s', - benchmark.id, e) - - def _write_benchmark_sample(self, 
benchmark: Benchmark, sample: Sample, - sample_targets: List[Target]): - """Generate the sample page and write to filesystem.""" - try: - # Ensure all required variables are available - logs = self._results.get_logs(benchmark.id, sample.id) or [] - run_logs = self._results.get_run_logs(benchmark.id, sample.id) or "" - triage = self._results.get_triage(benchmark.id, sample.id) or { - "result": "", - "triager_prompt": "" - } - - sample_css_content = self._read_static_file('sample/sample.css') - sample_js_content = self._read_static_file('sample/sample.js') - - rendered = self._jinja.render('sample/sample.html', - benchmark=benchmark, - benchmark_id=benchmark.id, - sample=sample, - logs=logs, - run_logs=run_logs, - triage=triage, - targets=sample_targets, - sample_css_content=sample_css_content, - sample_js_content=sample_js_content) - self._write(f'sample/{benchmark.id}/{sample.id}.html', rendered) - except Exception as e: - logging.error('Failed to write sample/%s/%s:\n%s', benchmark.id, - sample.id, e) - logging.error('Exception details: %s', traceback.format_exc()) - # Create a simple error page so users see something - try: - error_html = f""" + """ + GenerateReport helps generate an HTML report of experiment results. + + Args: + results: A Results object which takes care of reading results files and + processing them for the HTML templates. + output_dir: The directory to for the HTML report files. + jinja_env: A JinjaEnv object which provides the template render function. + """ + + def __init__( + self, + results: Results, + jinja_env: JinjaEnv, + results_dir: str, + output_dir: str = "results-report", + ): + self._results = results + self._output_dir = output_dir + self._jinja = jinja_env + self.results_dir = results_dir + + def read_timings(self): + with open(os.path.join(self.results_dir, "report.json"), "r") as f: + timings_dict = json.loads(f.read()) + return timings_dict + + def _copy_and_set_coverage_report(self, benchmark, sample): + """Prepares coverage reports in local runs.""" + coverage_path = os.path.join( + self.results_dir, benchmark.id, "code-coverage-reports" + ) + if not os.path.isdir(coverage_path): + return + + coverage_report = "" + for l in os.listdir(coverage_path): + if l.split(".")[0] == sample.id: + coverage_report = os.path.join(coverage_path, l) + + # On cloud runs there are two folders in code coverage reports, (report, + # textcov). If we have three files/dirs (linux, style.cssand textcov), then + # it's a local run. In that case copy over the code coverage reports so + # they are visible in the HTML page. + if ( + coverage_report + and os.path.isdir(coverage_report) + and len(os.listdir(coverage_report)) > 2 + ): + # Copy coverage to reports out + dst = os.path.join(self._output_dir, "sample", benchmark.id, "coverage") + os.makedirs(dst, exist_ok=True) + dst = os.path.join(dst, sample.id) + + shutil.copytree(coverage_report, dst, dirs_exist_ok=True) + sample.result.coverage_report_path = ( + f"/sample/{benchmark.id}/coverage/{sample.id}/linux/" + ) + + def _read_static_file(self, file_path_in_templates_subdir: str) -> str: + """Reads a static file from the templates directory.""" + + search_path = self._jinja.get_template_search_path() + + if not search_path: + logging.error( + "Jinja FileSystemLoader's searchpath is empty, or loader is not " + "FileSystemLoader." 
+ ) + return "" + templates_base_dir = search_path[0] + + full_file_path = os.path.join(templates_base_dir, file_path_in_templates_subdir) + + try: + with open(full_file_path, "r", encoding="utf-8") as f: + return f.read() + except FileNotFoundError: + logging.warning("Static file not found: %s", full_file_path) + return "" + except Exception as e: + logging.error("Error reading static file %s: %s", full_file_path, e) + return "" + + def generate(self): + """Generate and write every report file.""" + benchmarks = [] + samples_with_bugs = [] + for benchmark_id in self._results.list_benchmark_ids(): + results, targets = self._results.get_results(benchmark_id) + benchmark = self._results.match_benchmark(benchmark_id, results, targets) + benchmarks.append(benchmark) + samples = self._results.get_samples(results, targets) + prompt = self._results.get_prompt(benchmark.id) + + for sample in samples: + # If this is a local run then we need to set up coverage reports. + self._copy_and_set_coverage_report(benchmark, sample) + + self._write_benchmark_index(benchmark, samples, prompt) + self._write_benchmark_crash(benchmark, samples) + + for sample in samples: + if sample.result.crashes: + samples_with_bugs.append({"benchmark": benchmark, "sample": sample}) + sample_targets = self._results.get_targets(benchmark.id, sample.id) + self._write_benchmark_sample(benchmark, sample, sample_targets) + + accumulated_results = self._results.get_macro_insights(benchmarks) + projects = self._results.get_project_summary(benchmarks) + coverage_language_gains = self._results.get_coverage_language_gains() + + time_results = self.read_timings() + + self._write_index_html( + benchmarks, + accumulated_results, + time_results, + projects, + samples_with_bugs, + coverage_language_gains, + ) + self._write_index_json(benchmarks) + + def _write(self, output_path: str, content: str): + """Utility write to filesystem function.""" + full_path = os.path.join(self._output_dir, output_path) + + parent_dir = os.path.dirname(full_path) + if not FileSystem(parent_dir).exists(): + FileSystem(parent_dir).makedirs() + + if not FileSystem(parent_dir).isdir(): + raise Exception( + f"Writing to {full_path} but {parent_dir} is not a directory!" 
+ ) + + with FileSystem(full_path).open("w", encoding="utf-8") as f: + f.write(content) + + def _write_index_html( + self, + benchmarks: List[Benchmark], + accumulated_results: AccumulatedResult, + time_results: dict[str, Any], + projects: list[Project], + samples_with_bugs: list[dict[str, Any]], + coverage_language_gains: dict[str, Any], + ): + """Generate the report index.html and write to filesystem.""" + index_css_content = self._read_static_file("index/index.css") + index_js_content = self._read_static_file("index/index.js") + + rendered = self._jinja.render( + "index/index.html", + benchmarks=benchmarks, + accumulated_results=accumulated_results, + time_results=time_results, + projects=projects, + samples_with_bugs=samples_with_bugs, + coverage_language_gains=coverage_language_gains, + index_css_content=index_css_content, + index_js_content=index_js_content, + ) + self._write("index.html", rendered) + + def _write_index_json(self, benchmarks: List[Benchmark]): + """Generate the report index.json and write to filesystem.""" + rendered = self._jinja.render("index.json", benchmarks=benchmarks) + self._write("index.json", rendered) + + def _write_benchmark_index( + self, benchmark: Benchmark, samples: List[Sample], prompt: Optional[str] + ): + """Generate the benchmark index.html and write to filesystem.""" + benchmark_css_content = self._read_static_file("benchmark/benchmark.css") + benchmark_js_content = self._read_static_file("benchmark/benchmark.js") + + rendered = self._jinja.render( + "benchmark/benchmark.html", + benchmark=benchmark.id, + samples=samples, + prompt=prompt, + benchmark_css_content=benchmark_css_content, + benchmark_js_content=benchmark_js_content, + ) + self._write(f"benchmark/{benchmark.id}/index.html", rendered) + + def _write_benchmark_crash(self, benchmark: Benchmark, samples: List[Sample]): + """Generate the benchmark crash.json and write to filesystem.""" + try: + rendered = self._jinja.render( + "crash.json", + benchmark=benchmark.signature, + samples=samples, + get_benchmark_final_target_code=partial( + self._results.get_final_target_code, benchmark.id + ), + ) + self._write(f"benchmark/{benchmark.id}/crash.json", rendered) + except Exception as e: + logging.error( + "Failed to write benchmark/%s/crash.json:\n%s", benchmark.id, e + ) + + def _write_benchmark_sample( + self, benchmark: Benchmark, sample: Sample, sample_targets: List[Target] + ): + """Generate the sample page and write to filesystem.""" + try: + # Ensure all required variables are available + logs = self._results.get_logs(benchmark.id, sample.id) or [] + run_logs = self._results.get_run_logs(benchmark.id, sample.id) or "" + triage = self._results.get_triage(benchmark.id, sample.id) or { + "result": "", + "triager_prompt": "", + } + + sample_css_content = self._read_static_file("sample/sample.css") + sample_js_content = self._read_static_file("sample/sample.js") + + rendered = self._jinja.render( + "sample/sample.html", + benchmark=benchmark, + benchmark_id=benchmark.id, + sample=sample, + logs=logs, + run_logs=run_logs, + triage=triage, + targets=sample_targets, + sample_css_content=sample_css_content, + sample_js_content=sample_js_content, + ) + self._write(f"sample/{benchmark.id}/{sample.id}.html", rendered) + except Exception as e: + logging.error( + "Failed to write sample/%s/%s:\n%s", benchmark.id, sample.id, e + ) + logging.error("Exception details: %s", traceback.format_exc()) + # Create a simple error page so users see something + try: + error_html = f""" Error rendering 
{benchmark.id}/{sample.id} @@ -343,90 +381,96 @@ def _write_benchmark_sample(self, benchmark: Benchmark, sample: Sample, """ - self._write(f'sample/{benchmark.id}/{sample.id}.html', error_html) - except Exception: - pass # Ignore errors in error handling + self._write(f"sample/{benchmark.id}/{sample.id}.html", error_html) + except Exception: + pass # Ignore errors in error handling def generate_report(args: argparse.Namespace) -> None: - """Generates static web server files.""" - logging.info('Generating web page files in %s', args.output_dir) - results = Results(results_dir=args.results_dir, - benchmark_set=args.benchmark_set) - jinja_env = JinjaEnv(template_globals={'model': args.model}) - gr = GenerateReport(results=results, - jinja_env=jinja_env, - results_dir=args.results_dir, - output_dir=args.output_dir) - gr.generate() + """Generates static web server files.""" + logging.info("Generating web page files in %s", args.output_dir) + results = Results(results_dir=args.results_dir, benchmark_set=args.benchmark_set) + jinja_env = JinjaEnv(template_globals={"model": args.model}) + gr = GenerateReport( + results=results, + jinja_env=jinja_env, + results_dir=args.results_dir, + output_dir=args.output_dir, + ) + gr.generate() def launch_webserver(args): - """Launches a local web server to browse results.""" - logging.info('Launching webserver at %s:%d', LOCAL_HOST, args.port) - server = ThreadingHTTPServer((LOCAL_HOST, args.port), - partial(SimpleHTTPRequestHandler, - directory=args.output_dir)) - server.serve_forever() + """Launches a local web server to browse results.""" + logging.info("Launching webserver at %s:%d", LOCAL_HOST, args.port) + server = ThreadingHTTPServer( + (LOCAL_HOST, args.port), + partial(SimpleHTTPRequestHandler, directory=args.output_dir), + ) + server.serve_forever() def _parse_arguments() -> argparse.Namespace: - """Parses command line args.""" - parser = argparse.ArgumentParser(description=( - 'Report generation tool reads raw experiment output files and ' - 'generates a report in the form of HTML files in a directory hierarchy.')) - - parser.add_argument('--results-dir', - '-r', - help='Directory with results from OSS-Fuzz-gen.', - required=True) - parser.add_argument( - '--output-dir', - '-o', - help='Directory to store statically generated web report.', - default='results-report') - parser.add_argument('--benchmark-set', - '-b', - help='Directory with benchmarks used for the experiment.', - default='') - parser.add_argument('--model', - '-m', - help='Model used for the experiment.', - default='') - parser.add_argument('--serve', - '-s', - help='Will launch a web server if set.', - action='store_true') - parser.add_argument('--port', - '-p', - help='Port to launch webserver on.', - type=int, - default=8012) - - return parser.parse_args() + """Parses command line args.""" + parser = argparse.ArgumentParser( + description=( + "Report generation tool reads raw experiment output files and " + "generates a report in the form of HTML files in a directory hierarchy." 
+ ) + ) + + parser.add_argument( + "--results-dir", + "-r", + help="Directory with results from OSS-Fuzz-gen.", + required=True, + ) + parser.add_argument( + "--output-dir", + "-o", + help="Directory to store statically generated web report.", + default="results-report", + ) + parser.add_argument( + "--benchmark-set", + "-b", + help="Directory with benchmarks used for the experiment.", + default="", + ) + parser.add_argument( + "--model", "-m", help="Model used for the experiment.", default="" + ) + parser.add_argument( + "--serve", "-s", help="Will launch a web server if set.", action="store_true" + ) + parser.add_argument( + "--port", "-p", help="Port to launch webserver on.", type=int, default=8012 + ) + + return parser.parse_args() def main(): - args = _parse_arguments() - - if not args.serve: - generate_report(args) - else: - logging.getLogger().setLevel(os.environ.get('LOGLEVEL', 'INFO').upper()) - # Launch web server - thread = threading.Thread(target=launch_webserver, args=(args,)) - thread.start() - - # Generate results continuously while the process runs. - while True: - generate_report(args) - try: - time.sleep(90) - except KeyboardInterrupt: - logging.info('Exiting.') - os._exit(0) - - -if __name__ == '__main__': - logging.getLogger().setLevel(os.environ.get('LOGLEVEL', 'WARN').upper()) - main() + args = _parse_arguments() + + if not args.serve: + generate_report(args) + else: + logging.getLogger().setLevel(os.environ.get("LOGLEVEL", "INFO").upper()) + # Launch web server + thread = threading.Thread(target=launch_webserver, args=(args,)) + thread.start() + + # Generate results continuously while the process runs. + while True: + generate_report(args) + try: + time.sleep(90) + except KeyboardInterrupt: + logging.info("Exiting.") + os._exit(0) + + +if __name__ == "__main__": + logging.getLogger().setLevel(os.environ.get("LOGLEVEL", "WARN").upper()) + main() diff --git a/results.py b/results.py index 84b14d5535..f2b13385a8 100644 --- a/results.py +++ b/results.py @@ -23,716 +23,760 @@ class Result: - """A benchmark generation result.""" - benchmark: Benchmark - trial: int - work_dirs: WorkDirs - fuzz_target_source: str - build_script_source: str - author: Any - chat_history: dict - _repr_exclude = {'_repr_exclude', 'chat_history'} - - def __init__(self, - benchmark: Benchmark, - trial: int, - work_dirs: WorkDirs, - fuzz_target_source: str = '', - build_script_source: str = '', - author: Any = None, - chat_history: Optional[dict] = None, - default_success: bool = False) -> None: - self.benchmark = benchmark - self.trial = trial - self.work_dirs = work_dirs - self.fuzz_target_source = fuzz_target_source - self.build_script_source = build_script_source - self.author = author - self.chat_history = chat_history or {} - self.default_success = default_success - - def __repr__(self) -> str: - attributes = [ - f'{k}={v!r}' for k, v in vars(self).items() - if k not in self._repr_exclude - ] - return f'{self.__class__.__name__}({", ".join(attributes)})' - - @property - def success(self): - return self.default_success - - def to_dict(self) -> dict: - return { - 'function_signature': self.benchmark.function_signature, - 'project': self.benchmark.project, - 'project_commit': self.benchmark.commit, - 'project_language': self.benchmark.language, - 'trial': self.trial, - 'fuzz_target_source': self.fuzz_target_source, - 'build_script_source': self.build_script_source, - 'author': self.author.name if self.author else '', - 'chat_history': self.chat_history, - } + """A benchmark generation result.""" 
+ + benchmark: Benchmark + trial: int + work_dirs: WorkDirs + fuzz_target_source: str + build_script_source: str + author: Any + chat_history: dict + _repr_exclude = {"_repr_exclude", "chat_history"} + + def __init__( + self, + benchmark: Benchmark, + trial: int, + work_dirs: WorkDirs, + fuzz_target_source: str = "", + build_script_source: str = "", + author: Any = None, + chat_history: Optional[dict] = None, + default_success: bool = False, + ) -> None: + self.benchmark = benchmark + self.trial = trial + self.work_dirs = work_dirs + self.fuzz_target_source = fuzz_target_source + self.build_script_source = build_script_source + self.author = author + self.chat_history = chat_history or {} + self.default_success = default_success + + def __repr__(self) -> str: + attributes = [ + f"{k}={v!r}" for k, v in vars(self).items() if k not in self._repr_exclude + ] + return f'{self.__class__.__name__}({", ".join(attributes)})' + + @property + def success(self): + return self.default_success + + def to_dict(self) -> dict: + return { + "function_signature": self.benchmark.function_signature, + "project": self.benchmark.project, + "project_commit": self.benchmark.commit, + "project_language": self.benchmark.language, + "trial": self.trial, + "fuzz_target_source": self.fuzz_target_source, + "build_script_source": self.build_script_source, + "author": self.author.name if self.author else "", + "chat_history": self.chat_history, + } # TODO: Make this class an attribute of Result, avoid too many attributes in one # class. class BuildResult(Result): - """A benchmark generation result with build info.""" - compiles: bool # Build success/failure. - compile_error: str # Build error message. - compile_log: str # Build full output. - binary_exists: bool # Fuzz target binary generated successfully. - is_function_referenced: bool # Fuzz target references function-under-test. - _repr_exclude = Result._repr_exclude | {'compile_log', 'compile_error'} - - def __init__(self, - benchmark: Benchmark, - trial: int, - work_dirs: WorkDirs, - compiles: bool = False, - compile_error: str = '', - compile_log: str = '', - binary_exists: bool = False, - is_function_referenced: bool = False, - fuzz_target_source: str = '', - build_script_source: str = '', - author: Any = None, - chat_history: Optional[dict] = None) -> None: - super().__init__(benchmark, trial, work_dirs, fuzz_target_source, - build_script_source, author, chat_history) - self.compiles = compiles - self.compile_error = compile_error - self.compile_log = compile_log - self.binary_exists = binary_exists - self.is_function_referenced = is_function_referenced - - def to_dict(self) -> dict: - return super().to_dict() | { - 'compiles': self.success, - 'compile_error': self.compile_error, - 'compile_log': self.compile_log, - 'binary_exists': self.binary_exists, - 'is_function_referenced': self.is_function_referenced, - } - - @property - def success(self): - return self.compiles and self.binary_exists and self.is_function_referenced + """A benchmark generation result with build info.""" + + compiles: bool # Build success/failure. + compile_error: str # Build error message. + compile_log: str # Build full output. + binary_exists: bool # Fuzz target binary generated successfully. + is_function_referenced: bool # Fuzz target references function-under-test. 
+ _repr_exclude = Result._repr_exclude | {"compile_log", "compile_error"} + + def __init__( + self, + benchmark: Benchmark, + trial: int, + work_dirs: WorkDirs, + compiles: bool = False, + compile_error: str = "", + compile_log: str = "", + binary_exists: bool = False, + is_function_referenced: bool = False, + fuzz_target_source: str = "", + build_script_source: str = "", + author: Any = None, + chat_history: Optional[dict] = None, + ) -> None: + super().__init__( + benchmark, + trial, + work_dirs, + fuzz_target_source, + build_script_source, + author, + chat_history, + ) + self.compiles = compiles + self.compile_error = compile_error + self.compile_log = compile_log + self.binary_exists = binary_exists + self.is_function_referenced = is_function_referenced + + def to_dict(self) -> dict: + return super().to_dict() | { + "compiles": self.success, + "compile_error": self.compile_error, + "compile_log": self.compile_log, + "binary_exists": self.binary_exists, + "is_function_referenced": self.is_function_referenced, + } + + @property + def success(self): + return self.compiles and self.binary_exists and self.is_function_referenced # TODO: Make this class an attribute of Result, avoid too many attributes in one # class. class RunResult(BuildResult): - """The fuzzing run-time result info.""" - crashes: bool - run_error: str - crash_func: dict - run_log: str - coverage_summary: dict - coverage: float - line_coverage_diff: float - reproducer_path: str - artifact_path: str - sanitizer: str - textcov_diff: Optional[textcov.Textcov] - log_path: str - corpus_path: str - coverage_report_path: str - cov_pcs: int - total_pcs: int - _repr_exclude = BuildResult._repr_exclude | {'textcov_diff'} - err_type: str - crash_sypmtom: str - crash_stacks: Optional[list[list[str]]] - - def __init__( - self, - benchmark: Benchmark, - trial: int, - work_dirs: WorkDirs, - compiles: bool = False, - compile_error: str = '', - compile_log: str = '', - binary_exists: bool = False, - is_function_referenced: bool = False, - crashes: bool = False, # Runtime crash. - run_error: str = '', # Runtime crash error message. - crash_func: Optional[dict] = None, - run_log: str = '', # Full fuzzing output. 
- coverage_summary: Optional[dict] = None, - coverage: float = 0.0, - line_coverage_diff: float = 0.0, - textcov_diff: Optional[textcov.Textcov] = None, - reproducer_path: str = '', - artifact_path: str = '', - sanitizer: str = '', - log_path: str = '', - corpus_path: str = '', - coverage_report_path: str = '', - cov_pcs: int = 0, - total_pcs: int = 0, - err_type: str = SemanticCheckResult.NOT_APPLICABLE, - crash_sypmtom: str = '', - crash_stacks: Optional[list[list[str]]] = None, - fuzz_target_source: str = '', - build_script_source: str = '', - author: Any = None, - chat_history: Optional[dict] = None) -> None: - super().__init__(benchmark, trial, work_dirs, compiles, compile_error, - compile_log, binary_exists, is_function_referenced, - fuzz_target_source, build_script_source, author, - chat_history) - self.crashes = crashes - self.run_error = run_error - self.crash_func = crash_func or {} - self.run_log = run_log - self.coverage_summary = coverage_summary or {} - self.coverage = coverage - self.line_coverage_diff = line_coverage_diff - self.reproducer_path = reproducer_path - self.artifact_path = artifact_path - self.sanitizer = sanitizer - self.textcov_diff = textcov_diff - self.log_path = log_path - self.corpus_path = corpus_path - self.coverage_report_path = coverage_report_path - self.cov_pcs = cov_pcs - self.total_pcs = total_pcs - self.err_type = err_type - self.crash_sypmtom = crash_sypmtom - self.crash_stacks = crash_stacks or [] - - @property - def artifact_name(self) -> str: - return os.path.basename(self.artifact_path) - - def to_dict(self) -> dict: - return super().to_dict() | { - 'crashes': - self.crashes, - 'run_error': - self.run_error, - 'crash_func': - self.crash_func or {}, - 'run_log': - self.run_log, - 'coverage_summary': - self.coverage_summary or {}, - 'coverage': - self.coverage, - 'line_coverage_diff': - self.line_coverage_diff, - 'reproducer_path': - self.reproducer_path, - 'artifact_path': - self.artifact_path, - 'artifact_name': - self.artifact_name, - 'sanitizer': - self.sanitizer, - 'textcov_diff': - dataclasses.asdict(self.textcov_diff) if self.textcov_diff else '', - 'log_path': - self.log_path, - 'corpus_path': - self.corpus_path, - 'coverage_report_path': - self.coverage_report_path, - 'cov_pcs': - self.cov_pcs, - 'total_pcs': - self.total_pcs, - 'err_type': - self.err_type, - 'crash_sypmtom': - self.crash_sypmtom, - 'crash_stacks': - self.crash_stacks, - } - - # TODO(dongge): Define success property to show if the fuzz target was run. + """The fuzzing run-time result info.""" + + crashes: bool + run_error: str + crash_func: dict + run_log: str + coverage_summary: dict + coverage: float + line_coverage_diff: float + reproducer_path: str + artifact_path: str + sanitizer: str + textcov_diff: Optional[textcov.Textcov] + log_path: str + corpus_path: str + coverage_report_path: str + cov_pcs: int + total_pcs: int + _repr_exclude = BuildResult._repr_exclude | {"textcov_diff"} + err_type: str + crash_sypmtom: str + crash_stacks: Optional[list[list[str]]] + + def __init__( + self, + benchmark: Benchmark, + trial: int, + work_dirs: WorkDirs, + compiles: bool = False, + compile_error: str = "", + compile_log: str = "", + binary_exists: bool = False, + is_function_referenced: bool = False, + crashes: bool = False, # Runtime crash. + run_error: str = "", # Runtime crash error message. + crash_func: Optional[dict] = None, + run_log: str = "", # Full fuzzing output. 
+ coverage_summary: Optional[dict] = None, + coverage: float = 0.0, + line_coverage_diff: float = 0.0, + textcov_diff: Optional[textcov.Textcov] = None, + reproducer_path: str = "", + artifact_path: str = "", + sanitizer: str = "", + log_path: str = "", + corpus_path: str = "", + coverage_report_path: str = "", + cov_pcs: int = 0, + total_pcs: int = 0, + err_type: str = SemanticCheckResult.NOT_APPLICABLE, + crash_sypmtom: str = "", + crash_stacks: Optional[list[list[str]]] = None, + fuzz_target_source: str = "", + build_script_source: str = "", + author: Any = None, + chat_history: Optional[dict] = None, + ) -> None: + super().__init__( + benchmark, + trial, + work_dirs, + compiles, + compile_error, + compile_log, + binary_exists, + is_function_referenced, + fuzz_target_source, + build_script_source, + author, + chat_history, + ) + self.crashes = crashes + self.run_error = run_error + self.crash_func = crash_func or {} + self.run_log = run_log + self.coverage_summary = coverage_summary or {} + self.coverage = coverage + self.line_coverage_diff = line_coverage_diff + self.reproducer_path = reproducer_path + self.artifact_path = artifact_path + self.sanitizer = sanitizer + self.textcov_diff = textcov_diff + self.log_path = log_path + self.corpus_path = corpus_path + self.coverage_report_path = coverage_report_path + self.cov_pcs = cov_pcs + self.total_pcs = total_pcs + self.err_type = err_type + self.crash_sypmtom = crash_sypmtom + self.crash_stacks = crash_stacks or [] + + @property + def artifact_name(self) -> str: + return os.path.basename(self.artifact_path) + + def to_dict(self) -> dict: + return super().to_dict() | { + "crashes": self.crashes, + "run_error": self.run_error, + "crash_func": self.crash_func or {}, + "run_log": self.run_log, + "coverage_summary": self.coverage_summary or {}, + "coverage": self.coverage, + "line_coverage_diff": self.line_coverage_diff, + "reproducer_path": self.reproducer_path, + "artifact_path": self.artifact_path, + "artifact_name": self.artifact_name, + "sanitizer": self.sanitizer, + "textcov_diff": ( + dataclasses.asdict(self.textcov_diff) if self.textcov_diff else "" + ), + "log_path": self.log_path, + "corpus_path": self.corpus_path, + "coverage_report_path": self.coverage_report_path, + "cov_pcs": self.cov_pcs, + "total_pcs": self.total_pcs, + "err_type": self.err_type, + "crash_sypmtom": self.crash_sypmtom, + "crash_stacks": self.crash_stacks, + } + + # TODO(dongge): Define success property to show if the fuzz target was run. class CrashResult(Result): - """The fuzzing run-time result with crash info.""" - stacktrace: str - true_bug: bool # True/False positive crash - insight: str # Reason and fixes for crashes - - def __init__(self, - *args, - stacktrace: str = '', - true_bug: bool = False, - insight: str = '', - **kwargs): - super().__init__(*args, **kwargs) - self.stacktrace = stacktrace - self.true_bug = true_bug - self.insight = insight - - def to_dict(self) -> dict: - return { - 'stacktrace': self.stacktrace, - 'true_bug': self.true_bug, - 'insight': self.insight, - } - - -class CoverageResult(): - """The fuzzing run-time result with code coverage info.""" - improve_required: bool = False - insight: str = '' # Reason and fixes for low coverage - suggestions: str = '' # Suggestions to fix fuzz target. 
- _repr_exclude = set() - - def to_dict(self) -> dict: - return { - 'improve_required': self.improve_required, - 'insights': self.insight, - 'suggestions': self.suggestions - } - - def __repr__(self) -> str: - attributes = [ - f'{k}={v!r}' for k, v in vars(self).items() - if k not in self._repr_exclude - ] - return f'{self.__class__.__name__}({", ".join(attributes)})' + """The fuzzing run-time result with crash info.""" + + stacktrace: str + true_bug: bool # True/False positive crash + insight: str # Reason and fixes for crashes + + def __init__( + self, + *args, + stacktrace: str = "", + true_bug: bool = False, + insight: str = "", + **kwargs, + ): + super().__init__(*args, **kwargs) + self.stacktrace = stacktrace + self.true_bug = true_bug + self.insight = insight + + def to_dict(self) -> dict: + return { + "stacktrace": self.stacktrace, + "true_bug": self.true_bug, + "insight": self.insight, + } + + +class CoverageResult: + """The fuzzing run-time result with code coverage info.""" + + improve_required: bool = False + insight: str = "" # Reason and fixes for low coverage + suggestions: str = "" # Suggestions to fix fuzz target. + _repr_exclude = set() + + def to_dict(self) -> dict: + return { + "improve_required": self.improve_required, + "insights": self.insight, + "suggestions": self.suggestions, + } + + def __repr__(self) -> str: + attributes = [ + f"{k}={v!r}" for k, v in vars(self).items() if k not in self._repr_exclude + ] + return f'{self.__class__.__name__}({", ".join(attributes)})' # TODO: Make this class an attribute of Result, avoid too many attributes in one # class. class AnalysisResult(Result): - """Analysis of the fuzzing run-time result.""" - run_result: RunResult - semantic_result: Optional[SemanticCheckResult] - crash_result: Optional[CrashResult] - coverage_result: Optional[CoverageResult] - - def __init__(self, - author: Any, - run_result: RunResult, - semantic_result: Optional[SemanticCheckResult] = None, - crash_result: Optional[CrashResult] = None, - coverage_result: Optional[CoverageResult] = None, - chat_history: Optional[dict] = None) -> None: - super().__init__(run_result.benchmark, run_result.trial, - run_result.work_dirs, run_result.fuzz_target_source, - run_result.build_script_source, author, chat_history) - self.run_result = run_result - self.semantic_result = semantic_result - self.crash_result = crash_result - self.coverage_result = coverage_result - - def to_dict(self) -> dict: - return self.run_result.to_dict() | { - 'semantic_result': - self.semantic_result.to_dict() if self.semantic_result else {}, - 'crash_result': - self.crash_result.to_dict() if self.crash_result else {}, - 'coverage_result': - self.coverage_result.to_dict() if self.coverage_result else {}, - } - - # TODO(maoyi): maybe we should redefine success property or - # rename the property - @property - def success(self) -> bool: - if self.semantic_result: - return not self.semantic_result.has_err - if self.coverage_result: - return not self.coverage_result.improve_required - return False - - @property - def crashes(self) -> bool: - return self.run_result.crashes - - @property - def coverage(self) -> float: - return self.run_result.coverage - - @property - def line_coverage_diff(self) -> float: - return self.run_result.line_coverage_diff - - @property - def run_log(self) -> str: - return self.run_result.run_log - - @property - def log_path(self) -> str: - return self.run_result.log_path + """Analysis of the fuzzing run-time result.""" + + run_result: RunResult + semantic_result: 
Optional[SemanticCheckResult] + crash_result: Optional[CrashResult] + coverage_result: Optional[CoverageResult] + + def __init__( + self, + author: Any, + run_result: RunResult, + semantic_result: Optional[SemanticCheckResult] = None, + crash_result: Optional[CrashResult] = None, + coverage_result: Optional[CoverageResult] = None, + chat_history: Optional[dict] = None, + ) -> None: + super().__init__( + run_result.benchmark, + run_result.trial, + run_result.work_dirs, + run_result.fuzz_target_source, + run_result.build_script_source, + author, + chat_history, + ) + self.run_result = run_result + self.semantic_result = semantic_result + self.crash_result = crash_result + self.coverage_result = coverage_result + + def to_dict(self) -> dict: + return self.run_result.to_dict() | { + "semantic_result": ( + self.semantic_result.to_dict() if self.semantic_result else {} + ), + "crash_result": self.crash_result.to_dict() if self.crash_result else {}, + "coverage_result": ( + self.coverage_result.to_dict() if self.coverage_result else {} + ), + } + + # TODO(maoyi): maybe we should redefine success property or + # rename the property + @property + def success(self) -> bool: + if self.semantic_result: + return not self.semantic_result.has_err + if self.coverage_result: + return not self.coverage_result.improve_required + return False + + @property + def crashes(self) -> bool: + return self.run_result.crashes + + @property + def coverage(self) -> float: + return self.run_result.coverage + + @property + def line_coverage_diff(self) -> float: + return self.run_result.line_coverage_diff + + @property + def run_log(self) -> str: + return self.run_result.run_log + + @property + def log_path(self) -> str: + return self.run_result.log_path class TrialResult: - """All history results for a trial of a benchmark in an experiment.""" - benchmark: Benchmark - trial: int - work_dirs: WorkDirs - result_history: list[Result] - - def __init__(self, - benchmark: Benchmark, - trial: int, - work_dirs: WorkDirs, - result_history: Optional[list[Result]] = None) -> None: - self.benchmark = benchmark - self.trial = trial - self.work_dirs = work_dirs - self.result_history = result_history or [] - - @property - def function_signature(self) -> str: - """Function signature of the benchmark.""" - return self.benchmark.function_signature - - @property - def project(self) -> str: - """Project name of the benchmark.""" - return self.benchmark.project - - @property - def project_commit(self) -> str: - """Project commit of the benchmark.""" - return self.benchmark.commit or '' - - @property - def project_language(self) -> str: - """Project language of the benchmark.""" - return self.benchmark.language - - @property - def best_analysis_result(self) -> Optional[AnalysisResult]: - """Last AnalysisResult in trial, prefer crashed and a non-semantic error.""" - # 1. Crashed for a non-semantic error - for result in self.result_history[::-1]: - #TODO(dongge): Refine this logic for coverage - if (isinstance(result, AnalysisResult) and result.crashes and - result.semantic_result and not result.semantic_result.has_err): - return result - - # 2. Crashed - for result in self.result_history[::-1]: - if isinstance(result, AnalysisResult) and result.crashes: - return result - - # 3. AnalysisResult - for result in self.result_history[::-1]: - if isinstance(result, AnalysisResult): - return result - return None - - @property - def best_result(self) -> Result: - """Best result in trial based on coverage.""" - # Preference order: - # 1. 
Highest coverage diff (AnalysisResult) - # 2. Highest coverage diff (RunResult) - # 3. Highest coverage (AnalysisResult) - # 3. Highest coverage (RunResult) - # 4. Last Build success (BuildResult) - # 5. Last Result - best_result = None - - max_cov_diff = 0 - for result in self.result_history: - if (isinstance(result, (RunResult, AnalysisResult)) and - result.line_coverage_diff >= max_cov_diff): - max_cov_diff = result.line_coverage_diff - best_result = result - if best_result: - return best_result - - max_cov = 0 - for result in self.result_history: - if (isinstance(result, (RunResult, AnalysisResult)) and - result.coverage >= max_cov): - max_cov = result.coverage - best_result = result - if best_result: - return best_result - - for result in self.result_history[::-1]: - if isinstance(result, BuildResult) and result.success: - return result - - return self.result_history[-1] - - @property - def fuzz_target_source(self) -> str: - """The best fuzz target source code.""" - result = self.best_result - if isinstance(result, AnalysisResult): - return result.run_result.fuzz_target_source - return self.best_result.fuzz_target_source - - @property - def build_script_source(self) -> str: - """The best build script source code.""" - result = self.best_result - if isinstance(result, AnalysisResult): - return result.run_result.build_script_source - return self.best_result.build_script_source - - @property - def author(self) -> Any: - """The author of the best result.""" - return self.best_result.author - - @property - def chat_history(self) -> dict: - """The chat history of the best result.""" - return self.best_result.chat_history - - @property - def build_success(self) -> bool: - """True if there is any build success.""" - return any(result.success - for result in self.result_history - if isinstance(result, BuildResult)) - - @property - def crashes(self) -> bool: - """True if there is any runtime crash.""" - return any(result.crashes - for result in self.result_history - if isinstance(result, RunResult)) - - @property - def coverage(self) -> float: - """Max line coverage diff.""" - return max((result.coverage + """All history results for a trial of a benchmark in an experiment.""" + + benchmark: Benchmark + trial: int + work_dirs: WorkDirs + result_history: list[Result] + + def __init__( + self, + benchmark: Benchmark, + trial: int, + work_dirs: WorkDirs, + result_history: Optional[list[Result]] = None, + ) -> None: + self.benchmark = benchmark + self.trial = trial + self.work_dirs = work_dirs + self.result_history = result_history or [] + + @property + def function_signature(self) -> str: + """Function signature of the benchmark.""" + return self.benchmark.function_signature + + @property + def project(self) -> str: + """Project name of the benchmark.""" + return self.benchmark.project + + @property + def project_commit(self) -> str: + """Project commit of the benchmark.""" + return self.benchmark.commit or "" + + @property + def project_language(self) -> str: + """Project language of the benchmark.""" + return self.benchmark.language + + @property + def best_analysis_result(self) -> Optional[AnalysisResult]: + """Last AnalysisResult in trial, prefer crashed and a non-semantic error.""" + # 1. Crashed for a non-semantic error + for result in self.result_history[::-1]: + # TODO(dongge): Refine this logic for coverage + if ( + isinstance(result, AnalysisResult) + and result.crashes + and result.semantic_result + and not result.semantic_result.has_err + ): + return result + + # 2. 
Crashed + for result in self.result_history[::-1]: + if isinstance(result, AnalysisResult) and result.crashes: + return result + + # 3. AnalysisResult + for result in self.result_history[::-1]: + if isinstance(result, AnalysisResult): + return result + return None + + @property + def best_result(self) -> Result: + """Best result in trial based on coverage.""" + # Preference order: + # 1. Highest coverage diff (AnalysisResult) + # 2. Highest coverage diff (RunResult) + # 3. Highest coverage (AnalysisResult) + # 3. Highest coverage (RunResult) + # 4. Last Build success (BuildResult) + # 5. Last Result + best_result = None + + max_cov_diff = 0 + for result in self.result_history: + if ( + isinstance(result, (RunResult, AnalysisResult)) + and result.line_coverage_diff >= max_cov_diff + ): + max_cov_diff = result.line_coverage_diff + best_result = result + if best_result: + return best_result + + max_cov = 0 + for result in self.result_history: + if ( + isinstance(result, (RunResult, AnalysisResult)) + and result.coverage >= max_cov + ): + max_cov = result.coverage + best_result = result + if best_result: + return best_result + + for result in self.result_history[::-1]: + if isinstance(result, BuildResult) and result.success: + return result + + return self.result_history[-1] + + @property + def fuzz_target_source(self) -> str: + """The best fuzz target source code.""" + result = self.best_result + if isinstance(result, AnalysisResult): + return result.run_result.fuzz_target_source + return self.best_result.fuzz_target_source + + @property + def build_script_source(self) -> str: + """The best build script source code.""" + result = self.best_result + if isinstance(result, AnalysisResult): + return result.run_result.build_script_source + return self.best_result.build_script_source + + @property + def author(self) -> Any: + """The author of the best result.""" + return self.best_result.author + + @property + def chat_history(self) -> dict: + """The chat history of the best result.""" + return self.best_result.chat_history + + @property + def build_success(self) -> bool: + """True if there is any build success.""" + return any( + result.success + for result in self.result_history + if isinstance(result, BuildResult) + ) + + @property + def crashes(self) -> bool: + """True if there is any runtime crash.""" + return any( + result.crashes + for result in self.result_history + if isinstance(result, RunResult) + ) + + @property + def coverage(self) -> float: + """Max line coverage diff.""" + return max( + ( + result.coverage for result in self.result_history - if isinstance(result, RunResult)), - default=0) - - @property - def line_coverage_diff(self) -> float: - """Max line coverage diff.""" - return max((result.line_coverage_diff + if isinstance(result, RunResult) + ), + default=0, + ) + + @property + def line_coverage_diff(self) -> float: + """Max line coverage diff.""" + return max( + ( + result.line_coverage_diff for result in self.result_history - if isinstance(result, RunResult)), - default=0) - - @property - def cov_pcs(self) -> int: - """Log path of the best result if it is RunResult.""" - return max((result.cov_pcs + if isinstance(result, RunResult) + ), + default=0, + ) + + @property + def cov_pcs(self) -> int: + """Log path of the best result if it is RunResult.""" + return max( + ( + result.cov_pcs for result in self.result_history - if isinstance(result, RunResult)), - default=0) - - @property - def total_pcs(self) -> int: - """Log path of the best result if it is RunResult.""" - return 
max((result.total_pcs + if isinstance(result, RunResult) + ), + default=0, + ) + + @property + def total_pcs(self) -> int: + """Log path of the best result if it is RunResult.""" + return max( + ( + result.total_pcs for result in self.result_history - if isinstance(result, RunResult)), - default=0) - - @property - def line_coverage_report(self) -> str: - """Max line coverage diff report.""" - for result in self.result_history: - if not isinstance(result, RunResult): - continue - if result.line_coverage_diff == self.line_coverage_diff: - return result.coverage_report_path - return '' - - @property - def textcov_diff(self) -> textcov.Textcov: - """Sum textcov diff.""" - all_textcov = textcov.Textcov() - for result in self.result_history: - if isinstance(result, RunResult) and result.textcov_diff: - all_textcov.merge(result.textcov_diff) - return all_textcov - - @property - def run_error(self) -> str: - """Run error of the best result if it is RunResult.""" - result = self.best_result - if isinstance(result, RunResult): - return result.run_error - if isinstance(result, AnalysisResult): - return result.run_result.run_error - return '' - - @property - def run_log(self) -> str: - """Run log of the best result if it is RunResult.""" - result = self.best_result - if isinstance(result, (RunResult, AnalysisResult)): - return result.run_log - return '' - - @property - def log_path(self) -> str: - """Log path of the best result if it is RunResult.""" - result = self.best_result - if isinstance(result, (RunResult, AnalysisResult)): - return result.log_path - return '' - - @property - def is_semantic_error(self) -> bool: - """Validates if the best AnalysisResult has semantic error.""" - result = self.best_analysis_result - if result and result.semantic_result: - return result.semantic_result.has_err - return False - - @property - def semantic_error(self) -> str: - """Semantic error type of the best AnalysisResult.""" - result = self.best_analysis_result - if result and result.semantic_result: - return result.semantic_result.type - return '-' - - def to_dict(self) -> dict: - return { - 'trial': - self.trial, - 'function_signature': - self.function_signature, - 'project': - self.project, - 'project_commit': - self.project_commit, - 'project_language': - self.project_language, - 'fuzz_target_source': - self.fuzz_target_source, - 'build_script_source': - self.build_script_source, - 'author': - self.author.name if self.author else '', - 'chat_history': - self.chat_history, - 'compiles': - self.build_success, - 'crashes': - self.crashes, - 'coverage': - self.coverage, - 'line_coverage_diff': - self.line_coverage_diff, - 'cov_pcs': - self.cov_pcs, - 'total_pcs': - self.total_pcs, - 'line_coverage_report': - self.line_coverage_report, - 'textcov_diff': - dataclasses.asdict(self.textcov_diff) if self.textcov_diff else '', - 'run_error': - self.run_error, - 'run_log': - self.run_log, - 'log_path': - self.log_path, - 'is_semantic_error': - self.is_semantic_error, - 'semantic_error': - self.semantic_error, - } + if isinstance(result, RunResult) + ), + default=0, + ) + + @property + def line_coverage_report(self) -> str: + """Max line coverage diff report.""" + for result in self.result_history: + if not isinstance(result, RunResult): + continue + if result.line_coverage_diff == self.line_coverage_diff: + return result.coverage_report_path + return "" + + @property + def textcov_diff(self) -> textcov.Textcov: + """Sum textcov diff.""" + all_textcov = textcov.Textcov() + for result in self.result_history: + if 
isinstance(result, RunResult) and result.textcov_diff: + all_textcov.merge(result.textcov_diff) + return all_textcov + + @property + def run_error(self) -> str: + """Run error of the best result if it is RunResult.""" + result = self.best_result + if isinstance(result, RunResult): + return result.run_error + if isinstance(result, AnalysisResult): + return result.run_result.run_error + return "" + + @property + def run_log(self) -> str: + """Run log of the best result if it is RunResult.""" + result = self.best_result + if isinstance(result, (RunResult, AnalysisResult)): + return result.run_log + return "" + + @property + def log_path(self) -> str: + """Log path of the best result if it is RunResult.""" + result = self.best_result + if isinstance(result, (RunResult, AnalysisResult)): + return result.log_path + return "" + + @property + def is_semantic_error(self) -> bool: + """Validates if the best AnalysisResult has semantic error.""" + result = self.best_analysis_result + if result and result.semantic_result: + return result.semantic_result.has_err + return False + + @property + def semantic_error(self) -> str: + """Semantic error type of the best AnalysisResult.""" + result = self.best_analysis_result + if result and result.semantic_result: + return result.semantic_result.type + return "-" + + def to_dict(self) -> dict: + return { + "trial": self.trial, + "function_signature": self.function_signature, + "project": self.project, + "project_commit": self.project_commit, + "project_language": self.project_language, + "fuzz_target_source": self.fuzz_target_source, + "build_script_source": self.build_script_source, + "author": self.author.name if self.author else "", + "chat_history": self.chat_history, + "compiles": self.build_success, + "crashes": self.crashes, + "coverage": self.coverage, + "line_coverage_diff": self.line_coverage_diff, + "cov_pcs": self.cov_pcs, + "total_pcs": self.total_pcs, + "line_coverage_report": self.line_coverage_report, + "textcov_diff": ( + dataclasses.asdict(self.textcov_diff) if self.textcov_diff else "" + ), + "run_error": self.run_error, + "run_log": self.run_log, + "log_path": self.log_path, + "is_semantic_error": self.is_semantic_error, + "semantic_error": self.semantic_error, + } class PreWritingResult(Result): - """ The result of the function analyzer.""" - result_available: bool - requirements: list[str] - explanation: str - - def __init__(self, - benchmark: Benchmark, - trial: int, - work_dirs: WorkDirs, - result_available: bool, - requirements: Optional[list[str]] = None, - explanation: str = '', - fuzz_target_source: str = '', - build_script_source: str = '', - author: Any = None, - chat_history: Optional[dict] = None, - default_success: bool = False) -> None: - - super().__init__(benchmark, trial, work_dirs, fuzz_target_source, - build_script_source, author, chat_history, default_success) - - self.result_available = result_available - if result_available and requirements is not None: - self.requirements = requirements - self.explanation = explanation + """The result of the function analyzer.""" + + result_available: bool + requirements: list[str] + explanation: str + + def __init__( + self, + benchmark: Benchmark, + trial: int, + work_dirs: WorkDirs, + result_available: bool, + requirements: Optional[list[str]] = None, + explanation: str = "", + fuzz_target_source: str = "", + build_script_source: str = "", + author: Any = None, + chat_history: Optional[dict] = None, + default_success: bool = False, + ) -> None: + + super().__init__( + benchmark, + 
trial, + work_dirs, + fuzz_target_source, + build_script_source, + author, + chat_history, + default_success, + ) + + self.result_available = result_available + if result_available and requirements is not None: + self.requirements = requirements + self.explanation = explanation class BenchmarkResult: - """All trial results for a benchmark in an experiment.""" - benchmark: Benchmark - work_dirs: WorkDirs - trial_results: list[TrialResult] - - def __init__(self, - benchmark: Benchmark, - work_dirs: WorkDirs, - trial_results: Optional[list[TrialResult]] = None) -> None: - self.benchmark = benchmark - self.work_dirs = work_dirs - self.trial_results = trial_results or [] - - @property - def trial_count(self) -> int: - """Total number of trials.""" - return len(self.trial_results) - - @property - def build_success_count(self) -> int: - """Build success count.""" - return sum(result.build_success for result in self.trial_results) - - @property - def build_success_rate(self) -> float: - """Build success Ratio.""" - if self.trial_count == 0: - return 0 - return self.build_success_count / self.trial_count - - @property - def crash_rate(self) -> float: - """True if there is any run crash not caused by semantic error.""" - if self.trial_count == 0: - return 0 - return sum( - result.crashes for result in self.trial_results) / self.trial_count - - @property - def coverage(self) -> float: - """Max line coverage diff.""" - return max((result.coverage for result in self.trial_results), default=0) - - @property - def line_coverage_diff(self) -> float: - """Max line coverage diff.""" - return max((result.line_coverage_diff for result in self.trial_results), - default=0) - - @property - def line_coverage_report(self) -> str: - """Max line coverage diff report.""" - for result in self.trial_results: - if result.line_coverage_diff == self.line_coverage_diff: - return result.line_coverage_report - return '' - - @property - def textcov_diff(self) -> textcov.Textcov: - """Sum textcov diff.""" - all_textcov = textcov.Textcov() - for result in self.trial_results: - all_textcov.merge(result.textcov_diff) - return all_textcov + """All trial results for a benchmark in an experiment.""" + + benchmark: Benchmark + work_dirs: WorkDirs + trial_results: list[TrialResult] + + def __init__( + self, + benchmark: Benchmark, + work_dirs: WorkDirs, + trial_results: Optional[list[TrialResult]] = None, + ) -> None: + self.benchmark = benchmark + self.work_dirs = work_dirs + self.trial_results = trial_results or [] + + @property + def trial_count(self) -> int: + """Total number of trials.""" + return len(self.trial_results) + + @property + def build_success_count(self) -> int: + """Build success count.""" + return sum(result.build_success for result in self.trial_results) + + @property + def build_success_rate(self) -> float: + """Build success Ratio.""" + if self.trial_count == 0: + return 0 + return self.build_success_count / self.trial_count + + @property + def crash_rate(self) -> float: + """True if there is any run crash not caused by semantic error.""" + if self.trial_count == 0: + return 0 + return sum(result.crashes for result in self.trial_results) / self.trial_count + + @property + def coverage(self) -> float: + """Max line coverage diff.""" + return max((result.coverage for result in self.trial_results), default=0) + + @property + def line_coverage_diff(self) -> float: + """Max line coverage diff.""" + return max( + (result.line_coverage_diff for result in self.trial_results), default=0 + ) + + @property + def 
line_coverage_report(self) -> str: + """Max line coverage diff report.""" + for result in self.trial_results: + if result.line_coverage_diff == self.line_coverage_diff: + return result.line_coverage_report + return "" + + @property + def textcov_diff(self) -> textcov.Textcov: + """Sum textcov diff.""" + all_textcov = textcov.Textcov() + for result in self.trial_results: + all_textcov.merge(result.textcov_diff) + return all_textcov diff --git a/run_all_experiments.py b/run_all_experiments.py index 03710a4af4..23abdb9a04 100755 --- a/run_all_experiments.py +++ b/run_all_experiments.py @@ -41,7 +41,7 @@ # NUM_EXP controls the number of experiments in parallel, while each experiment # will evaluate {run_one_experiment.NUM_EVA, default 3} fuzz targets in # parallel. -NUM_EXP = int(os.getenv('LLM_NUM_EXP', '2')) +NUM_EXP = int(os.getenv("LLM_NUM_EXP", "2")) # Default LLM hyper-parameters. MAX_TOKENS: int = run_one_experiment.MAX_TOKENS @@ -50,551 +50,606 @@ TEMPERATURE: float = run_one_experiment.TEMPERATURE RESULTS_DIR: str = run_one_experiment.RESULTS_DIR -JSON_REPORT = 'report.json' -TIME_STAMP_FMT = '%Y-%m-%d %H:%M:%S' +JSON_REPORT = "report.json" +TIME_STAMP_FMT = "%Y-%m-%d %H:%M:%S" -WORK_DIR = '' +WORK_DIR = "" -LOG_LEVELS = ['debug', 'info'] -LOG_FMT = ('%(asctime)s.%(msecs)03d %(levelname)s ' - '%(module)s - %(funcName)s: %(message)s') +LOG_LEVELS = ["debug", "info"] +LOG_FMT = ( + "%(asctime)s.%(msecs)03d %(levelname)s " "%(module)s - %(funcName)s: %(message)s" +) class Result: - benchmark: benchmarklib.Benchmark - result: run_one_experiment.AggregatedResult | str + benchmark: benchmarklib.Benchmark + result: run_one_experiment.AggregatedResult | str - def __init__(self, benchmark, result): - self.benchmark = benchmark - self.result = result + def __init__(self, benchmark, result): + self.benchmark = benchmark + self.result = result def generate_benchmarks(args: argparse.Namespace) -> None: - """Generates benchmarks, write to filesystem and set args benchmark dir.""" - logger.info('Generating benchmarks.') - benchmark_dir = introspector.get_next_generated_benchmarks_dir() - logger.info('Setting benchmark directory to %s.', benchmark_dir) - os.makedirs(benchmark_dir) - args.benchmarks_directory = benchmark_dir - benchmark_oracles = [ - heuristic.strip() for heuristic in args.generate_benchmarks.split(',') - ] - projects_to_target = [ - project.strip() - for project in args.generate_benchmarks_projects.split(',') - ] - for project in projects_to_target: - project_lang = oss_fuzz_checkout.get_project_language(project) - benchmarks = introspector.populate_benchmarks_using_introspector( - project, project_lang, args.generate_benchmarks_max, benchmark_oracles) - if benchmarks: - benchmarklib.Benchmark.to_yaml(benchmarks, outdir=benchmark_dir) + """Generates benchmarks, write to filesystem and set args benchmark dir.""" + logger.info("Generating benchmarks.") + benchmark_dir = introspector.get_next_generated_benchmarks_dir() + logger.info("Setting benchmark directory to %s.", benchmark_dir) + os.makedirs(benchmark_dir) + args.benchmarks_directory = benchmark_dir + benchmark_oracles = [ + heuristic.strip() for heuristic in args.generate_benchmarks.split(",") + ] + projects_to_target = [ + project.strip() for project in args.generate_benchmarks_projects.split(",") + ] + for project in projects_to_target: + project_lang = oss_fuzz_checkout.get_project_language(project) + benchmarks = introspector.populate_benchmarks_using_introspector( + project, project_lang, args.generate_benchmarks_max, 
benchmark_oracles + ) + if benchmarks: + benchmarklib.Benchmark.to_yaml(benchmarks, outdir=benchmark_dir) def prepare_experiment_targets( - args: argparse.Namespace) -> list[benchmarklib.Benchmark]: - """Constructs a list of experiment configs based on the |BENCHMARK_DIR| and + args: argparse.Namespace, +) -> list[benchmarklib.Benchmark]: + """Constructs a list of experiment configs based on the |BENCHMARK_DIR| and |args| setting.""" - benchmark_yamls = [] - if args.benchmark_yaml: - logger.info( - 'A benchmark yaml file %s is provided. Will use it and ignore ' - 'the files in %s.', args.benchmark_yaml, args.benchmarks_directory) - benchmark_yamls = [args.benchmark_yaml] - else: - if args.generate_benchmarks: - generate_benchmarks(args) - - benchmark_yamls = [ - os.path.join(args.benchmarks_directory, file) - for file in os.listdir(args.benchmarks_directory) - if file.endswith('.yaml') or file.endswith('yml') - ] - experiment_configs = [] - for benchmark_file in benchmark_yamls: - experiment_configs.extend(benchmarklib.Benchmark.from_yaml(benchmark_file)) + benchmark_yamls = [] + if args.benchmark_yaml: + logger.info( + "A benchmark yaml file %s is provided. Will use it and ignore " + "the files in %s.", + args.benchmark_yaml, + args.benchmarks_directory, + ) + benchmark_yamls = [args.benchmark_yaml] + else: + if args.generate_benchmarks: + generate_benchmarks(args) - return experiment_configs + benchmark_yamls = [ + os.path.join(args.benchmarks_directory, file) + for file in os.listdir(args.benchmarks_directory) + if file.endswith(".yaml") or file.endswith("yml") + ] + experiment_configs = [] + for benchmark_file in benchmark_yamls: + experiment_configs.extend(benchmarklib.Benchmark.from_yaml(benchmark_file)) + + return experiment_configs def run_experiments(benchmark: benchmarklib.Benchmark, args) -> Result: - """Runs an experiment based on the |benchmark| config.""" - try: - work_dirs = WorkDirs(os.path.join(args.work_dir, f'output-{benchmark.id}')) - args.work_dirs = work_dirs - model = models.LLM.setup( - ai_binary=args.ai_binary, - name=args.model, - max_tokens=MAX_TOKENS, - num_samples=args.num_samples, - temperature=args.temperature, - temperature_list=args.temperature_list, + """Runs an experiment based on the |benchmark| config.""" + try: + work_dirs = WorkDirs(os.path.join(args.work_dir, f"output-{benchmark.id}")) + args.work_dirs = work_dirs + model = models.LLM.setup( + ai_binary=args.ai_binary, + name=args.model, + max_tokens=MAX_TOKENS, + num_samples=args.num_samples, + temperature=args.temperature, + temperature_list=args.temperature_list, + ) + + result = run_one_experiment.run( + benchmark=benchmark, model=model, args=args, work_dirs=work_dirs + ) + return Result(benchmark, result) + except Exception as e: + logger.error("Exception while running experiment: %s", str(e)) + traceback.print_exc() + return Result(benchmark, f"Exception while running experiment: {e}") + + +def parse_args() -> argparse.Namespace: + """Parses command line arguments.""" + parser = argparse.ArgumentParser( + description="Run all experiments that evaluates all target functions." + ) + parser.add_argument( + "-n", + "--num-samples", + type=int, + default=NUM_SAMPLES, + help="The number of samples to request from LLM.", + ) + parser.add_argument( + "-t", + "--temperature", + type=float, + default=TEMPERATURE, + help=( + "A value between 0 and 1 representing the variety of the targets " + "generated by LLM." 
+ ), + ) + parser.add_argument( + "-tr", + "--temperature-list", + nargs="*", + type=float, + default=[], + help=( + "A list of values representing the temperatures will be used by " + "each sample LLM query." + ), + ) + parser.add_argument( + "-c", + "--cloud-experiment-name", + type=str, + default="", + help="The name of the cloud experiment.", + ) + parser.add_argument( + "-cb", + "--cloud-experiment-bucket", + type=str, + default="", + help="A gcloud bucket to store experiment files.", + ) + parser.add_argument("-b", "--benchmarks-directory", type=str) + parser.add_argument( + "-y", "--benchmark-yaml", type=str, help="A benchmark YAML file." + ) + parser.add_argument("-to", "--run-timeout", type=int, default=RUN_TIMEOUT) + parser.add_argument( + "-a", + "--ai-binary", + required=False, + nargs="?", + const=os.getenv("AI_BINARY", ""), + default="", + type=str, + ) + parser.add_argument( + "-l", + "--model", + default=models.DefaultModel.name, + help=("Models available: " f'{", ".join(models.LLM.all_llm_names())}.'), + ) + parser.add_argument( + "-td", + "--template-directory", + type=str, + default=prompt_builder.DEFAULT_TEMPLATE_DIR, + ) + parser.add_argument("-w", "--work-dir", default=RESULTS_DIR) + parser.add_argument( + "--context", + action="store_true", + default=False, + help="Add context to function under test.", + ) + parser.add_argument( + "-e", + "--introspector-endpoint", + type=str, + default=introspector.DEFAULT_INTROSPECTOR_ENDPOINT, + ) + parser.add_argument( + "-lo", + "--log-level", + help=f'Sets the logging level. Options available: {", ".join(LOG_LEVELS)}.', + default="info", + ) + parser.add_argument( + "-of", + "--oss-fuzz-dir", + help="OSS-Fuzz dir path to use. Create temporary directory by default.", + default="", + ) + parser.add_argument( + "-g", + "--generate-benchmarks", + help=( + "Generate benchmarks and use those for analysis. This is a string " + "of comma-separated heuristics to use when identifying benchmark " + "targets. Options available: " + f'{", ".join(introspector.get_oracle_dict().keys())}.' + ), + type=str, + ) + parser.add_argument( + "-gp", + "--generate-benchmarks-projects", + help="Projects to generate benchmarks for in a comma separated string.", + type=str, + ) + parser.add_argument( + "-gm", + "--generate-benchmarks-max", + help="Max targets to generate per benchmark heuristic.", + type=int, + default=5, + ) + parser.add_argument( + "--delay", + type=int, + default=0, + help=( + "Delay each experiment by certain seconds (e.g., 10s) to avoid " + "exceeding quota limit in large scale experiments." + ), + ) + parser.add_argument( + "-p", + "--prompt-builder", + help="The prompt builder to use for harness generation.", + default="DEFAULT", + ) + parser.add_argument( + "-ag", + "--agent", + action="store_true", + default=False, + help="Enables agent enhancement.", + ) + parser.add_argument( + "-mr", "--max-round", type=int, default=100, help="Max trial round for agents." ) - result = run_one_experiment.run(benchmark=benchmark, - model=model, - args=args, - work_dirs=work_dirs) - return Result(benchmark, result) - except Exception as e: - logger.error('Exception while running experiment: %s', str(e)) - traceback.print_exc() - return Result(benchmark, f'Exception while running experiment: {e}') + args = parser.parse_args() + if args.num_samples: + assert args.num_samples > 0, "--num-samples must take a positive integer." + + if args.temperature: + assert 2 >= args.temperature >= 0, "--temperature must be within 0 and 2." 
+ + benchmark_yaml = args.benchmark_yaml + if benchmark_yaml: + assert benchmark_yaml.endswith(".yaml") or benchmark_yaml.endswith( + "yml" + ), "--benchmark-yaml needs to take a YAML file." + + bench_yml = bool(benchmark_yaml) + bench_dir = bool(args.benchmarks_directory) + bench_gen = bool(args.generate_benchmarks) + num_options = int(bench_yml) + int(bench_dir) + int(bench_gen) + assert num_options == 1, ( + "One and only one of --benchmark-yaml, --benchmarks-directory and " + "--generate-benchmarks must be provided. --benchmark-yaml takes one benchmark YAML file, " + "--benchmarks-directory takes a directory of them and " + "--generate-benchmarks generates them during analysis." + ) + # Validate templates. + assert os.path.isdir( + args.template_directory + ), "--template-directory must be an existing directory." -def parse_args() -> argparse.Namespace: - """Parses command line arguments.""" - parser = argparse.ArgumentParser( - description='Run all experiments that evaluates all target functions.') - parser.add_argument('-n', - '--num-samples', - type=int, - default=NUM_SAMPLES, - help='The number of samples to request from LLM.') - parser.add_argument( - '-t', - '--temperature', - type=float, - default=TEMPERATURE, - help=('A value between 0 and 1 representing the variety of the targets ' - 'generated by LLM.')) - parser.add_argument( - '-tr', - '--temperature-list', - nargs='*', - type=float, - default=[], - help=('A list of values representing the temperatures will be used by ' - 'each sample LLM query.')) - parser.add_argument('-c', - '--cloud-experiment-name', - type=str, - default='', - help='The name of the cloud experiment.') - parser.add_argument('-cb', - '--cloud-experiment-bucket', - type=str, - default='', - help='A gcloud bucket to store experiment files.') - parser.add_argument('-b', '--benchmarks-directory', type=str) - parser.add_argument('-y', - '--benchmark-yaml', - type=str, - help='A benchmark YAML file.') - parser.add_argument('-to', '--run-timeout', type=int, default=RUN_TIMEOUT) - parser.add_argument('-a', - '--ai-binary', - required=False, - nargs='?', - const=os.getenv('AI_BINARY', ''), - default='', - type=str) - parser.add_argument('-l', - '--model', - default=models.DefaultModel.name, - help=('Models available: ' - f'{", ".join(models.LLM.all_llm_names())}.')) - parser.add_argument('-td', - '--template-directory', - type=str, - default=prompt_builder.DEFAULT_TEMPLATE_DIR) - parser.add_argument('-w', '--work-dir', default=RESULTS_DIR) - parser.add_argument('--context', - action='store_true', - default=False, - help='Add context to function under test.') - parser.add_argument('-e', - '--introspector-endpoint', - type=str, - default=introspector.DEFAULT_INTROSPECTOR_ENDPOINT) - parser.add_argument( - '-lo', - '--log-level', - help= - f'Sets the logging level. Options available: {", ".join(LOG_LEVELS)}.', - default='info') - parser.add_argument( - '-of', - '--oss-fuzz-dir', - help='OSS-Fuzz dir path to use. Create temporary directory by default.', - default='') - parser.add_argument( - '-g', - '--generate-benchmarks', - help=('Generate benchmarks and use those for analysis. This is a string ' - 'of comma-separated heuristics to use when identifying benchmark ' - 'targets. 
Options available: ' - f'{", ".join(introspector.get_oracle_dict().keys())}.'), - type=str) - parser.add_argument( - '-gp', - '--generate-benchmarks-projects', - help='Projects to generate benchmarks for in a comma separated string.', - type=str) - parser.add_argument('-gm', - '--generate-benchmarks-max', - help='Max targets to generate per benchmark heuristic.', - type=int, - default=5) - parser.add_argument( - '--delay', - type=int, - default=0, - help=('Delay each experiment by certain seconds (e.g., 10s) to avoid ' - 'exceeding quota limit in large scale experiments.')) - parser.add_argument('-p', - '--prompt-builder', - help='The prompt builder to use for harness generation.', - default='DEFAULT') - parser.add_argument('-ag', - '--agent', - action='store_true', - default=False, - help='Enables agent enhancement.') - parser.add_argument('-mr', - '--max-round', - type=int, - default=100, - help='Max trial round for agents.') - - args = parser.parse_args() - if args.num_samples: - assert args.num_samples > 0, '--num-samples must take a positive integer.' - - if args.temperature: - assert 2 >= args.temperature >= 0, '--temperature must be within 0 and 2.' - - benchmark_yaml = args.benchmark_yaml - if benchmark_yaml: - assert (benchmark_yaml.endswith('.yaml') or - benchmark_yaml.endswith('yml')), ( - "--benchmark-yaml needs to take an YAML file.") - - bench_yml = bool(benchmark_yaml) - bench_dir = bool(args.benchmarks_directory) - bench_gen = bool(args.generate_benchmarks) - num_options = int(bench_yml) + int(bench_dir) + int(bench_gen) - assert num_options == 1, ( - 'One and only one of --benchmark-yaml, --benchmarks-directory and ' - '--generate-benchmarks. --benchmark-yaml takes one benchmark YAML file, ' - '--benchmarks-directory takes: a directory of them and ' - '--generate-benchmarks generates them during analysis.') - - # Validate templates. - assert os.path.isdir(args.template_directory), ( - '--template-directory must be an existing directory.') - - # Validate cloud experiment configs. - assert ( - bool(args.cloud_experiment_name) == bool(args.cloud_experiment_bucket) - ), ('Cannot accept exactly one of --args.cloud-experiment-name and ' - '--args.cloud-experiment-bucket: Local experiment requires neither of ' - 'them, cloud experiment needs both.') - return args + # Validate cloud experiment configs. + assert bool(args.cloud_experiment_name) == bool(args.cloud_experiment_bucket), ( + "Cannot accept exactly one of --args.cloud-experiment-name and " + "--args.cloud-experiment-bucket: Local experiment requires neither of " + "them, cloud experiment needs both." + ) + return args def extend_report_with_coverage_gains() -> None: - """Process total gain from all generated harnesses for each projects and - update summary report. 
This makes it possible to view per-project stats - as experiments complete rather than only after all experiments run.""" - coverage_gain_dict = _process_total_coverage_gain() - existing_oss_fuzz_cov = introspector.query_introspector_language_stats() - - total_new_covgains = {} - for project_dict in coverage_gain_dict.values(): - lang_gains = total_new_covgains.get(project_dict.get('language', 'c'), 0) - lang_gains += project_dict.get('coverage_ofg_total_new_covered_lines', 0) - total_new_covgains[project_dict.get('language', 'c')] = lang_gains - - comparative_cov_gains = {} - for language, lang_cov_gain in total_new_covgains.items(): - try: - total_coverage_increase = round( - (lang_cov_gain / existing_oss_fuzz_cov[language]['total']) * 100.0, - 10) - except (KeyError, ZeroDivisionError): - total_coverage_increase = 0 - - try: - relative_coverage_increase = round( - (lang_cov_gain / existing_oss_fuzz_cov[language]['covered']) * 100.0, - 10) - except (KeyError, ZeroDivisionError): - relative_coverage_increase = 0 - comparative_cov_gains[language] = { - 'total_coverage_increase': total_coverage_increase, - 'relative_coverage_increase': relative_coverage_increase, - } - add_to_json_report(WORK_DIR, 'coverage_gains_per_language', - total_new_covgains) - add_to_json_report(WORK_DIR, 'project_summary', coverage_gain_dict) - add_to_json_report(WORK_DIR, 'oss_fuzz_language_status', - existing_oss_fuzz_cov) - add_to_json_report(WORK_DIR, 'comperative_coverage_gains', - comparative_cov_gains) + """Process total gain from all generated harnesses for each projects and + update summary report. This makes it possible to view per-project stats + as experiments complete rather than only after all experiments run.""" + coverage_gain_dict = _process_total_coverage_gain() + existing_oss_fuzz_cov = introspector.query_introspector_language_stats() + + total_new_covgains = {} + for project_dict in coverage_gain_dict.values(): + lang_gains = total_new_covgains.get(project_dict.get("language", "c"), 0) + lang_gains += project_dict.get("coverage_ofg_total_new_covered_lines", 0) + total_new_covgains[project_dict.get("language", "c")] = lang_gains + + comparative_cov_gains = {} + for language, lang_cov_gain in total_new_covgains.items(): + try: + total_coverage_increase = round( + (lang_cov_gain / existing_oss_fuzz_cov[language]["total"]) * 100.0, 10 + ) + except (KeyError, ZeroDivisionError): + total_coverage_increase = 0 + + try: + relative_coverage_increase = round( + (lang_cov_gain / existing_oss_fuzz_cov[language]["covered"]) * 100.0, 10 + ) + except (KeyError, ZeroDivisionError): + relative_coverage_increase = 0 + comparative_cov_gains[language] = { + "total_coverage_increase": total_coverage_increase, + "relative_coverage_increase": relative_coverage_increase, + } + add_to_json_report(WORK_DIR, "coverage_gains_per_language", total_new_covgains) + add_to_json_report(WORK_DIR, "project_summary", coverage_gain_dict) + add_to_json_report(WORK_DIR, "oss_fuzz_language_status", existing_oss_fuzz_cov) + add_to_json_report(WORK_DIR, "comperative_coverage_gains", comparative_cov_gains) def extend_report_with_coverage_gains_process(): - """A process that continuously runs to update coverage gains in the - background.""" - while True: - time.sleep(300) # 5 minutes. 
- try: - extend_report_with_coverage_gains() - except Exception: - logger.error('Failed to extend report with coverage gains') - traceback.print_exc() + """A process that continuously runs to update coverage gains in the + background.""" + while True: + time.sleep(300) # 5 minutes. + try: + extend_report_with_coverage_gains() + except Exception: + logger.error("Failed to extend report with coverage gains") + traceback.print_exc() def _print_experiment_result(result: Result): - """Prints the |result| of a single experiment.""" - logger.info('\n**** Finished benchmark %s, %s ****\n%s', - result.benchmark.project, result.benchmark.function_signature, - result.result) + """Prints the |result| of a single experiment.""" + logger.info( + "\n**** Finished benchmark %s, %s ****\n%s", + result.benchmark.project, + result.benchmark.function_signature, + result.result, + ) + +def _print_experiment_results( + results: list[Result], cov_gain: dict[str, dict[str, Any]] +): + """Prints the |results| of multiple experiments.""" + logger.info("\n\n**** FINAL RESULTS: ****\n\n") + for result in results: + logger.info( + "%s\n*%s, %s*\n%s\n", + "=" * 80, + result.benchmark.project, + result.benchmark.function_signature, + result.result, + ) + + logger.info("**** TOTAL COVERAGE GAIN: ****") + for project in cov_gain: + logger.info("*%s: %s", project, cov_gain[project]["coverage_diff"]) + + +def _setup_logging(verbose: str = "info", is_cloud: bool = False) -> None: + """Set up logging level.""" + + if is_cloud: + try: + client = cloud_logging.Client() + client.setup_logging() + except Exception as e: + # For local runs we continue + logger.warning("Error setting up cloud logging client: %s", e) + + if verbose == "debug": + log_level = logging.DEBUG + else: + log_level = logging.INFO + logging.basicConfig( + level=log_level, + format=LOG_FMT, + datefmt="%Y-%m-%d %H:%M:%S", + ) + # Set the base logger level + logging.getLogger("").setLevel(log_level) -def _print_experiment_results(results: list[Result], - cov_gain: dict[str, dict[str, Any]]): - """Prints the |results| of multiple experiments.""" - logger.info('\n\n**** FINAL RESULTS: ****\n\n') - for result in results: - logger.info('%s\n*%s, %s*\n%s\n', '=' * 80, result.benchmark.project, - result.benchmark.function_signature, result.result) - logger.info('**** TOTAL COVERAGE GAIN: ****') - for project in cov_gain: - logger.info('*%s: %s', project, cov_gain[project]["coverage_diff"]) +def add_to_json_report(outdir: str, key: str, value: Any) -> None: + """Adds a key/value pair to JSON report.""" + os.makedirs(outdir, exist_ok=True) + json_report_path = os.path.join(outdir, JSON_REPORT) + if os.path.isfile(json_report_path): + with open(json_report_path, "r") as f: + json_report = json.load(f) + else: + json_report = {} + json_report[key] = value -def _setup_logging(verbose: str = 'info', is_cloud: bool = False) -> None: - """Set up logging level.""" + # Overwrite the new json file + with open(json_report_path, "w") as f: + f.write(json.dumps(json_report)) - if is_cloud: - try: - client = cloud_logging.Client() - client.setup_logging() - except Exception as e: - # For local runs we continue - logger.warning('Error setting up cloud logging client: %s', e) - - if verbose == "debug": - log_level = logging.DEBUG - else: - log_level = logging.INFO - logging.basicConfig( - level=log_level, - format=LOG_FMT, - datefmt='%Y-%m-%d %H:%M:%S', - ) - # Set the base logger level - logging.getLogger('').setLevel(log_level) +def _process_total_coverage_gain() -> dict[str, 
dict[str, Any]]: + """Processes and calculates the total coverage gain for each project.""" + textcov_dict: dict[str, list[textcov.Textcov]] = {} + + # Load all the textcov dirs + for benchmark_dir in os.listdir(WORK_DIR): + if not os.path.isdir(os.path.join(WORK_DIR, benchmark_dir)): + continue + + result_benchmark_used_path = os.path.join( + os.path.join(WORK_DIR, benchmark_dir, "benchmark.yaml") + ) + if not os.path.isfile(result_benchmark_used_path): + continue + + project_name = "" + ignore_patterns = [] + + benchmark_used = benchmarklib.Benchmark.from_yaml(result_benchmark_used_path) + if not benchmark_used: + logger.info("Did not find benchmark for %s", benchmark_dir) + try: + project_name = "-".join(benchmark_dir.split("-")[1:-1]) + except: + continue + else: + logger.info("Found benchmark for %s", benchmark_dir) + project_name = benchmark_used[0].project + target_basename = os.path.basename(benchmark_used[0].target_path) + ignore_patterns = [re.compile(r"^" + re.escape(target_basename) + ":")] + + coverage_reports = os.path.join( + WORK_DIR, benchmark_dir, "code-coverage-reports" + ) + if not os.path.isdir(coverage_reports): + continue + + if project_name not in textcov_dict: + textcov_dict[project_name] = [] + for sample in os.listdir(coverage_reports): + summary = os.path.join(coverage_reports, sample, "textcov") + if not os.path.isdir(summary): + continue + + for textcov_file in os.listdir(summary): + if textcov_file.endswith(".covreport"): + with open(os.path.join(summary, textcov_file), "rb") as f: + if benchmark_used[0].language != "rust": + textcov_dict[project_name].append( + textcov.Textcov.from_file(f) + ) + else: + textcov_dict[project_name].append( + textcov.Textcov.from_rust_file( + f, ignore_function_patterns=ignore_patterns + ) + ) + elif textcov_file == "all_cov.json": + with open(os.path.join(summary, textcov_file)) as f: + textcov_dict[project_name].append( + textcov.Textcov.from_python_file(f) + ) + elif textcov_file == "jacoco.xml": + with open(os.path.join(summary, textcov_file)) as f: + textcov_dict[project_name].append( + textcov.Textcov.from_jvm_file(f) + ) + + if not textcov_dict: + return {} + + coverage_gain: dict[str, dict[str, Any]] = {} + for project, cov_list in textcov_dict.items(): + total_cov = textcov.Textcov() + for cov in cov_list: + total_cov.merge(cov) + existing_textcov = evaluator.load_existing_textcov(project) + coverage_summary = evaluator.load_existing_coverage_summary(project) + + try: + coverage_summary_files = coverage_summary["data"][0]["files"] + lines = [f["summary"]["lines"]["count"] for f in coverage_summary_files] + except (KeyError, TypeError): + lines = [] + + total_existing_lines = sum(lines) + total_cov_covered_lines_before_subtraction = total_cov.covered_lines + total_cov.subtract_covered_lines(existing_textcov) + try: + cov_relative_gain = total_cov.covered_lines / existing_textcov.covered_lines + except ZeroDivisionError: + cov_relative_gain = 0.0 + + total_lines = max(total_cov.total_lines, total_existing_lines) + + if total_lines: + coverage_gain[project] = { + "language": oss_fuzz_checkout.get_project_language(project), + "coverage_diff": total_cov.covered_lines / total_lines, + "coverage_relative_gain": cov_relative_gain, + "coverage_ofg_total_covered_lines": total_cov_covered_lines_before_subtraction, + "coverage_ofg_total_new_covered_lines": total_cov.covered_lines, + "coverage_existing_total_covered_lines": existing_textcov.covered_lines, + "coverage_existing_total_lines": total_existing_lines, + } + else: + # Fail 
safe when total_lines is 0 because of invalid coverage report + logger.warning( + "Line coverage information missing from the coverage report." + ) + coverage_gain[project] = {"coverage_diff": 0.0} + + return coverage_gain -def add_to_json_report(outdir: str, key: str, value: Any) -> None: - """Adds a key/value pair to JSON report.""" - os.makedirs(outdir, exist_ok=True) - json_report_path = os.path.join(outdir, JSON_REPORT) - if os.path.isfile(json_report_path): - with open(json_report_path, 'r') as f: - json_report = json.load(f) - else: - json_report = {} - json_report[key] = value +def main(): + global WORK_DIR - # Overwrite the new json file - with open(json_report_path, 'w') as f: - f.write(json.dumps(json_report)) + args = parse_args() + _setup_logging(args.log_level, is_cloud=args.cloud_experiment_name != "") + logger.info("Starting experiments on PR branch") + # Capture time at start + start = time.time() + add_to_json_report( + args.work_dir, "start_time", time.strftime(TIME_STAMP_FMT, time.gmtime(start)) + ) + # Add num_samples to report.json + add_to_json_report(args.work_dir, "num_samples", args.num_samples) -def _process_total_coverage_gain() -> dict[str, dict[str, Any]]: - """Processes and calculates the total coverage gain for each project.""" - textcov_dict: dict[str, list[textcov.Textcov]] = {} - - # Load all the textcov dirs - for benchmark_dir in os.listdir(WORK_DIR): - if not os.path.isdir(os.path.join(WORK_DIR, benchmark_dir)): - continue - - result_benchmark_used_path = os.path.join( - os.path.join(WORK_DIR, benchmark_dir, 'benchmark.yaml')) - if not os.path.isfile(result_benchmark_used_path): - continue - - project_name = '' - ignore_patterns = [] - - benchmark_used = benchmarklib.Benchmark.from_yaml( - result_benchmark_used_path) - if not benchmark_used: - logger.info('Did not find benchmark for %s', benchmark_dir) - try: - project_name = '-'.join(benchmark_dir.split('-')[1:-1]) - except: - continue - else: - logger.info('Found benchmark for %s', benchmark_dir) - project_name = benchmark_used[0].project - target_basename = os.path.basename(benchmark_used[0].target_path) - ignore_patterns = [re.compile(r'^' + re.escape(target_basename) + ':')] - - coverage_reports = os.path.join(WORK_DIR, benchmark_dir, - 'code-coverage-reports') - if not os.path.isdir(coverage_reports): - continue - - if project_name not in textcov_dict: - textcov_dict[project_name] = [] - for sample in os.listdir(coverage_reports): - summary = os.path.join(coverage_reports, sample, 'textcov') - if not os.path.isdir(summary): - continue - - for textcov_file in os.listdir(summary): - if textcov_file.endswith('.covreport'): - with open(os.path.join(summary, textcov_file), 'rb') as f: - if benchmark_used[0].language != 'rust': - textcov_dict[project_name].append(textcov.Textcov.from_file(f)) - else: - textcov_dict[project_name].append( - textcov.Textcov.from_rust_file( - f, ignore_function_patterns=ignore_patterns)) - elif textcov_file == 'all_cov.json': - with open(os.path.join(summary, textcov_file)) as f: - textcov_dict[project_name].append( - textcov.Textcov.from_python_file(f)) - elif textcov_file == 'jacoco.xml': - with open(os.path.join(summary, textcov_file)) as f: - textcov_dict[project_name].append(textcov.Textcov.from_jvm_file(f)) - - if not textcov_dict: - return {} - - coverage_gain: dict[str, dict[str, Any]] = {} - for project, cov_list in textcov_dict.items(): - total_cov = textcov.Textcov() - for cov in cov_list: - total_cov.merge(cov) - existing_textcov = 
evaluator.load_existing_textcov(project) - coverage_summary = evaluator.load_existing_coverage_summary(project) + # Set introspector endpoint before performing any operations to ensure the + # right API endpoint is used throughout. + introspector.set_introspector_endpoints(args.introspector_endpoint) - try: - coverage_summary_files = coverage_summary['data'][0]['files'] - lines = [f['summary']['lines']['count'] for f in coverage_summary_files] - except (KeyError, TypeError): - lines = [] - - total_existing_lines = sum(lines) - total_cov_covered_lines_before_subtraction = total_cov.covered_lines - total_cov.subtract_covered_lines(existing_textcov) - try: - cov_relative_gain = (total_cov.covered_lines / - existing_textcov.covered_lines) - except ZeroDivisionError: - cov_relative_gain = 0.0 - - total_lines = max(total_cov.total_lines, total_existing_lines) - - if total_lines: - coverage_gain[project] = { - 'language': - oss_fuzz_checkout.get_project_language(project), - 'coverage_diff': - total_cov.covered_lines / total_lines, - 'coverage_relative_gain': - cov_relative_gain, - 'coverage_ofg_total_covered_lines': - total_cov_covered_lines_before_subtraction, - 'coverage_ofg_total_new_covered_lines': - total_cov.covered_lines, - 'coverage_existing_total_covered_lines': - existing_textcov.covered_lines, - 'coverage_existing_total_lines': - total_existing_lines, - } + run_one_experiment.prepare(args.oss_fuzz_dir) + + experiment_targets = prepare_experiment_targets(args) + if oss_fuzz_checkout.ENABLE_CACHING: + oss_fuzz_checkout.prepare_cached_images(experiment_targets) + + logger.info( + "Running %s experiment(s) in parallels of %s.", + len(experiment_targets), + str(NUM_EXP), + ) + + # Set global variables that are updated throughout experiment runs. + WORK_DIR = args.work_dir + + # Start parallel coverage aggregate analysis + coverage_gains_process = Process(target=extend_report_with_coverage_gains_process) + coverage_gains_process.start() + + experiment_results = [] + if NUM_EXP == 1: + for target_benchmark in experiment_targets: + result = run_experiments(target_benchmark, args) + _print_experiment_result(result) + experiment_results.append(result) else: - # Fail safe when total_lines is 0 because of invalid coverage report - logger.warning( - 'Line coverage information missing from the coverage report.') - coverage_gain[project] = {'coverage_diff': 0.0} + experiment_tasks = [] + with Pool(NUM_EXP, maxtasksperchild=1) as p: + for target_benchmark in experiment_targets: + experiment_task = p.apply_async( + run_experiments, + (target_benchmark, args), + callback=_print_experiment_result, + ) + experiment_tasks.append(experiment_task) + time.sleep(args.delay) + + experiment_results = [task.get() for task in experiment_tasks] + + # Signal that no more work will be submitte to the pool. + p.close() + + # Wait for all workers to complete. + p.join() + + if coverage_gains_process: + # Do a final coverage aggregation. 
+ coverage_gains_process.kill() + extend_report_with_coverage_gains() + + # Capture time at end + end = time.time() + add_to_json_report( + args.work_dir, + "completion_time", + time.strftime(TIME_STAMP_FMT, time.gmtime(end)), + ) + add_to_json_report( + args.work_dir, "total_run_time", str(timedelta(seconds=end - start)) + ) - return coverage_gain + coverage_gain_dict = _process_total_coverage_gain() + _print_experiment_results(experiment_results, coverage_gain_dict) -def main(): - global WORK_DIR - - args = parse_args() - _setup_logging(args.log_level, is_cloud=args.cloud_experiment_name != '') - logger.info('Starting experiments on PR branch') - - # Capture time at start - start = time.time() - add_to_json_report(args.work_dir, 'start_time', - time.strftime(TIME_STAMP_FMT, time.gmtime(start))) - # Add num_samples to report.json - add_to_json_report(args.work_dir, 'num_samples', args.num_samples) - - # Set introspector endpoint before performing any operations to ensure the - # right API endpoint is used throughout. - introspector.set_introspector_endpoints(args.introspector_endpoint) - - run_one_experiment.prepare(args.oss_fuzz_dir) - - experiment_targets = prepare_experiment_targets(args) - if oss_fuzz_checkout.ENABLE_CACHING: - oss_fuzz_checkout.prepare_cached_images(experiment_targets) - - logger.info('Running %s experiment(s) in parallels of %s.', - len(experiment_targets), str(NUM_EXP)) - - # Set global variables that are updated throughout experiment runs. - WORK_DIR = args.work_dir - - # Start parallel coverage aggregate analysis - coverage_gains_process = Process( - target=extend_report_with_coverage_gains_process) - coverage_gains_process.start() - - experiment_results = [] - if NUM_EXP == 1: - for target_benchmark in experiment_targets: - result = run_experiments(target_benchmark, args) - _print_experiment_result(result) - experiment_results.append(result) - else: - experiment_tasks = [] - with Pool(NUM_EXP, maxtasksperchild=1) as p: - for target_benchmark in experiment_targets: - experiment_task = p.apply_async(run_experiments, - (target_benchmark, args), - callback=_print_experiment_result) - experiment_tasks.append(experiment_task) - time.sleep(args.delay) - - experiment_results = [task.get() for task in experiment_tasks] - - # Signal that no more work will be submitte to the pool. - p.close() - - # Wait for all workers to complete. - p.join() - - if coverage_gains_process: - # Do a final coverage aggregation. - coverage_gains_process.kill() - extend_report_with_coverage_gains() - - # Capture time at end - end = time.time() - add_to_json_report(args.work_dir, 'completion_time', - time.strftime(TIME_STAMP_FMT, time.gmtime(end))) - add_to_json_report(args.work_dir, 'total_run_time', - str(timedelta(seconds=end - start))) - - coverage_gain_dict = _process_total_coverage_gain() - _print_experiment_results(experiment_results, coverage_gain_dict) - - -if __name__ == '__main__': - sys.exit(main()) +if __name__ == "__main__": + sys.exit(main()) diff --git a/run_one_experiment.py b/run_one_experiment.py index cb1e9a0464..e4e83505ef 100644 --- a/run_one_experiment.py +++ b/run_one_experiment.py @@ -43,7 +43,7 @@ # NUM_EVA controls the number of fuzz targets to evaluate in parallel by each # experiment, while {run_all_experiments.NUM_EXP, default 2} experiments will # run in parallel. -NUM_EVA = int(os.getenv('LLM_NUM_EVA', '3')) +NUM_EVA = int(os.getenv("LLM_NUM_EVA", "3")) # Default LLM hyper-parameters. 
# #182 shows Gemini returns NUM_SAMPLES independent responses via repeated @@ -58,130 +58,151 @@ RUN_TIMEOUT: int = 30 TEMPERATURE: float = 0.4 -RESULTS_DIR = './results' +RESULTS_DIR = "./results" # TODO(dongge): Move this to results.py @dataclasses.dataclass class AggregatedResult: - """Aggregated evaluation result.""" - build_success_count: int = 0 - build_success_rate: float = 0.0 - crash_rate: float = 0.0 - found_bug: int = 0 - max_coverage: float = 0.0 - max_line_coverage_diff: float = 0.0 - max_coverage_sample: str = '' - max_coverage_diff_sample: str = '' - max_coverage_diff_report: str = '' - full_textcov_diff: textcov.Textcov = dataclasses.field( - default_factory=textcov.Textcov) - - def __str__(self): - return ( - f'build success rate: {self.build_success_rate}, ' - f'crash rate: {self.crash_rate}, ' - f'found bug: {self.found_bug}, ' - f'max coverage: {self.max_coverage}, ' - f'max line coverage diff: {self.max_line_coverage_diff}\n' - f'max coverage sample: {self.max_coverage_sample}\n' - f'max coverage diff sample: {self.max_coverage_diff_sample}\n' - f'max coverage diff report: {self.max_coverage_diff_report or "None"}') - - @classmethod - def from_benchmark_result( - cls, benchmark_result: BenchmarkResult) -> 'AggregatedResult': - """Aggregates experiment history results of all samples.""" + """Aggregated evaluation result.""" + + build_success_count: int = 0 + build_success_rate: float = 0.0 + crash_rate: float = 0.0 + found_bug: int = 0 + max_coverage: float = 0.0 + max_line_coverage_diff: float = 0.0 + max_coverage_sample: str = "" + max_coverage_diff_sample: str = "" + max_coverage_diff_report: str = "" + full_textcov_diff: textcov.Textcov = dataclasses.field( + default_factory=textcov.Textcov + ) - return AggregatedResult( - build_success_count=benchmark_result.build_success_count, - build_success_rate=benchmark_result.build_success_rate, - crash_rate=benchmark_result.crash_rate, - max_coverage=benchmark_result.coverage, - max_line_coverage_diff=benchmark_result.line_coverage_diff, - max_coverage_diff_report=benchmark_result.line_coverage_report, - full_textcov_diff=benchmark_result.textcov_diff) - - -def generate_targets(benchmark: Benchmark, model: models.LLM, - prompt: prompts.Prompt, work_dirs: WorkDirs, - builder: prompt_builder.PromptBuilder) -> list[str]: - """Generates fuzz target with LLM.""" - logging.info('Generating targets for %s %s using %s..', benchmark.project, - benchmark.function_signature, model.name) - model.query_llm(prompt, response_dir=work_dirs.raw_targets) - - _, target_ext = os.path.splitext(benchmark.target_path) - generated_targets = [] - for file in os.listdir(work_dirs.raw_targets): - if not output_parser.is_raw_output(file): - continue - raw_output = os.path.join(work_dirs.raw_targets, file) - target_code = output_parser.parse_code(raw_output) - target_code = builder.post_process_generated_code(target_code) - target_id, _ = os.path.splitext(raw_output) - target_file = f'{target_id}{target_ext}' - target_path = os.path.join(work_dirs.raw_targets, target_file) - output_parser.save_output(target_code, target_path) - generated_targets.append(target_path) - - if generated_targets: - targets_relpath = map(os.path.relpath, generated_targets) - targets_relpath_str = '\n '.join(targets_relpath) - logging.info('Generated:\n %s', targets_relpath_str) - else: - logging.info('Failed to generate targets: %s', generated_targets) - return generated_targets + def __str__(self): + return ( + f"build success rate: {self.build_success_rate}, " + f"crash 
rate: {self.crash_rate}, " + f"found bug: {self.found_bug}, " + f"max coverage: {self.max_coverage}, " + f"max line coverage diff: {self.max_line_coverage_diff}\n" + f"max coverage sample: {self.max_coverage_sample}\n" + f"max coverage diff sample: {self.max_coverage_diff_sample}\n" + f'max coverage diff report: {self.max_coverage_diff_report or "None"}' + ) + + @classmethod + def from_benchmark_result( + cls, benchmark_result: BenchmarkResult + ) -> "AggregatedResult": + """Aggregates experiment history results of all samples.""" + + return AggregatedResult( + build_success_count=benchmark_result.build_success_count, + build_success_rate=benchmark_result.build_success_rate, + crash_rate=benchmark_result.crash_rate, + max_coverage=benchmark_result.coverage, + max_line_coverage_diff=benchmark_result.line_coverage_diff, + max_coverage_diff_report=benchmark_result.line_coverage_report, + full_textcov_diff=benchmark_result.textcov_diff, + ) + + +def generate_targets( + benchmark: Benchmark, + model: models.LLM, + prompt: prompts.Prompt, + work_dirs: WorkDirs, + builder: prompt_builder.PromptBuilder, +) -> list[str]: + """Generates fuzz target with LLM.""" + logging.info( + "Generating targets for %s %s using %s..", + benchmark.project, + benchmark.function_signature, + model.name, + ) + model.query_llm(prompt, response_dir=work_dirs.raw_targets) + + _, target_ext = os.path.splitext(benchmark.target_path) + generated_targets = [] + for file in os.listdir(work_dirs.raw_targets): + if not output_parser.is_raw_output(file): + continue + raw_output = os.path.join(work_dirs.raw_targets, file) + target_code = output_parser.parse_code(raw_output) + target_code = builder.post_process_generated_code(target_code) + target_id, _ = os.path.splitext(raw_output) + target_file = f"{target_id}{target_ext}" + target_path = os.path.join(work_dirs.raw_targets, target_file) + output_parser.save_output(target_code, target_path) + generated_targets.append(target_path) + + if generated_targets: + targets_relpath = map(os.path.relpath, generated_targets) + targets_relpath_str = "\n ".join(targets_relpath) + logging.info("Generated:\n %s", targets_relpath_str) + else: + logging.info("Failed to generate targets: %s", generated_targets) + return generated_targets def fix_code(work_dirs: WorkDirs, generated_targets: List[str]) -> List[str]: - """Copies the generated target to the fixed target directory for simple + """Copies the generated target to the fixed target directory for simple code fixes.""" - fixed_targets = [] - # Prepare all LLM-generated targets for code fixes. 
- for file in generated_targets: - fixed_target = os.path.join(work_dirs.fixed_targets, os.path.basename(file)) - shutil.copyfile(file, fixed_target) - fixed_targets.append(fixed_target) - - return fixed_targets - - -def aggregate_results(target_stats: list[tuple[int, exp_evaluator.Result]], - generated_targets: list[str]) -> AggregatedResult: - """Aggregates experiment status and results of a targets.""" - build_success_count = sum([int(stat.compiles) for _, stat in target_stats]) - build_success_rate = build_success_count / len(target_stats) - crash_rate = sum([int(stat.crashes) for _, stat in target_stats - ]) / len(target_stats) - found_bug = sum([ - int(stat.crashes and not stat.is_semantic_error) - for _, stat in target_stats - ]) - max_coverage = max([stat.coverage for _, stat in target_stats]) - max_line_coverage_diff = max( - [stat.line_coverage_diff for _, stat in target_stats]) - - max_coverage_sample = '' - max_coverage_diff_sample = '' - max_coverage_diff_report = '' - - all_textcov = textcov.Textcov() - for i, stat in target_stats: - if stat.coverage == max_coverage: - max_coverage_sample = generated_targets[i] - - if stat.line_coverage_diff == max_line_coverage_diff: - max_coverage_diff_sample = generated_targets[i] - max_coverage_diff_report = stat.coverage_report_path - - if isinstance(stat.textcov_diff, textcov.Textcov): - all_textcov.merge(stat.textcov_diff) - - return AggregatedResult(build_success_count, build_success_rate, crash_rate, - found_bug, max_coverage, max_line_coverage_diff, - max_coverage_sample, max_coverage_diff_sample, - max_coverage_diff_report, all_textcov) + fixed_targets = [] + # Prepare all LLM-generated targets for code fixes. + for file in generated_targets: + fixed_target = os.path.join(work_dirs.fixed_targets, os.path.basename(file)) + shutil.copyfile(file, fixed_target) + fixed_targets.append(fixed_target) + + return fixed_targets + + +def aggregate_results( + target_stats: list[tuple[int, exp_evaluator.Result]], generated_targets: list[str] +) -> AggregatedResult: + """Aggregates experiment status and results of all targets.""" + build_success_count = sum([int(stat.compiles) for _, stat in target_stats]) + build_success_rate = build_success_count / len(target_stats) + crash_rate = sum([int(stat.crashes) for _, stat in target_stats]) / len( + target_stats + ) + found_bug = sum( + [int(stat.crashes and not stat.is_semantic_error) for _, stat in target_stats] + ) + max_coverage = max([stat.coverage for _, stat in target_stats]) + max_line_coverage_diff = max([stat.line_coverage_diff for _, stat in target_stats]) + + max_coverage_sample = "" + max_coverage_diff_sample = "" + max_coverage_diff_report = "" + + all_textcov = textcov.Textcov() + for i, stat in target_stats: + if stat.coverage == max_coverage: + max_coverage_sample = generated_targets[i] + + if stat.line_coverage_diff == max_line_coverage_diff: + max_coverage_diff_sample = generated_targets[i] + max_coverage_diff_report = stat.coverage_report_path + + if isinstance(stat.textcov_diff, textcov.Textcov): + all_textcov.merge(stat.textcov_diff) + + return AggregatedResult( + build_success_count, + build_success_rate, + crash_rate, + found_bug, + max_coverage, + max_line_coverage_diff, + max_coverage_sample, + max_coverage_diff_sample, + max_coverage_diff_report, + all_textcov, + ) def check_targets( @@ -189,134 +210,143 @@ def check_targets( ai_binary, benchmark: Benchmark, work_dirs: WorkDirs, generated_targets: List[str], - cloud_experiment_name: str = '', - cloud_experiment_bucket: str = '', + 
cloud_experiment_name: str = "", + cloud_experiment_bucket: str = "", run_timeout: int = RUN_TIMEOUT, fixer_model_name: str = models.DefaultModel.name, ) -> Optional[AggregatedResult]: - """Builds all targets in the fixed target directory.""" - target_stats = [] - - if cloud_experiment_name: - builder_runner = builder_runner_lib.CloudBuilderRunner( - benchmark, - work_dirs, - run_timeout, - fixer_model_name, - experiment_name=cloud_experiment_name, - experiment_bucket=cloud_experiment_bucket, + """Builds all targets in the fixed target directory.""" + target_stats = [] + + if cloud_experiment_name: + builder_runner = builder_runner_lib.CloudBuilderRunner( + benchmark, + work_dirs, + run_timeout, + fixer_model_name, + experiment_name=cloud_experiment_name, + experiment_bucket=cloud_experiment_bucket, + ) + else: + builder_runner = builder_runner_lib.BuilderRunner( + benchmark, work_dirs, run_timeout, fixer_model_name + ) + + evaluator = exp_evaluator.Evaluator(builder_runner, benchmark, work_dirs) + + ai_target_pairs = [(ai_binary, target) for target in generated_targets] + with pool.ThreadPool(NUM_EVA) as p: + for i, target_stat in enumerate( + p.starmap(evaluator.check_target, ai_target_pairs) + ): + if target_stat is None: + logging.error( + "This should never happen: Error evaluating target: %s", + generated_targets[i], + ) + target_stat = exp_evaluator.Result() + + target_stats.append((i, target_stat)) + + if len(target_stats) > 0: + return aggregate_results(target_stats, generated_targets) + + logging.info("No targets to check.") + return None + + +def prepare(oss_fuzz_dir: str) -> None: + """Prepares the experiment environment.""" + oss_fuzz_checkout.clone_oss_fuzz(oss_fuzz_dir) + oss_fuzz_checkout.postprocess_oss_fuzz() + + +def _fuzzing_pipeline( + benchmark: Benchmark, + model: models.LLM, + args: argparse.Namespace, + work_dirs: WorkDirs, + trial: int, +) -> TrialResult: + """Runs the predefined 3-stage pipeline for one trial.""" + trial_logger = logger.get_trial_logger(trial=trial, level=logging.DEBUG) + trial_logger.info("Trial Starts") + if args.agent: + p = pipeline.Pipeline( + args=args, + trial=trial, + writing_stage_agents=[ + Prototyper(trial=trial, llm=model, args=args), + Enhancer(trial=trial, llm=model, args=args), + ], + analysis_stage_agents=[ + SemanticAnalyzer(trial=trial, llm=model, args=args), + CoverageAnalyzer(trial=trial, llm=model, args=args), + CrashAnalyzer(trial=trial, llm=model, args=args), + ], + ) + else: + p = pipeline.Pipeline( + args=args, + trial=trial, + writing_stage_agents=[ + OnePromptPrototyper(trial=trial, llm=model, args=args), + OnePromptEnhancer(trial=trial, llm=model, args=args), + ], + analysis_stage_agents=[ + SemanticAnalyzer(trial=trial, llm=model, args=args), + ], + ) + + results = p.execute( + result_history=[Result(benchmark=benchmark, trial=trial, work_dirs=work_dirs)] ) - else: - builder_runner = builder_runner_lib.BuilderRunner(benchmark, work_dirs, - run_timeout, - fixer_model_name) - evaluator = exp_evaluator.Evaluator(builder_runner, benchmark, work_dirs) + trial_result = TrialResult( + benchmark=benchmark, trial=trial, work_dirs=work_dirs, result_history=results + ) + trial_logger.write_result( + result_status_dir=trial_result.best_result.work_dirs.status, + result=trial_result, + finished=True, + ) + return trial_result - ai_target_pairs = [(ai_binary, target) for target in generated_targets] - with pool.ThreadPool(NUM_EVA) as p: - for i, target_stat in enumerate( - p.starmap(evaluator.check_target, ai_target_pairs)): - if 
target_stat is None: - logging.error('This should never happen: Error evaluating target: %s', - generated_targets[i]) - target_stat = exp_evaluator.Result() - target_stats.append((i, target_stat)) +def _fuzzing_pipelines( + benchmark: Benchmark, + model: models.LLM, + args: argparse.Namespace, + work_dirs: WorkDirs, +) -> BenchmarkResult: + """Runs all trial experiments in their pipelines.""" + # Create a pool of worker processes + with pool.ThreadPool(processes=NUM_EVA) as p: + # Initialize thread-local storage in each worker before processing + task_args = [ + (benchmark, model, args, work_dirs, trial) + for trial in range(1, args.num_samples + 1) + ] + trial_results = p.starmap(_fuzzing_pipeline, task_args) + return BenchmarkResult( + benchmark=benchmark, work_dirs=work_dirs, trial_results=trial_results + ) - if len(target_stats) > 0: - return aggregate_results(target_stats, generated_targets) - logging.info('No targets to check.') - return None +def run( + benchmark: Benchmark, + model: models.LLM, + args: argparse.Namespace, + work_dirs: WorkDirs, +) -> Optional[AggregatedResult]: + """Generates code via LLM, and evaluates them.""" + model.cloud_setup() + # Save the benchmark in the WorkDir base. This is saved to the working + # directory, and should not be deleted in future executions. As such, + # from here on, do not erase all WorkDir contents. + Benchmark.to_yaml([benchmark], outdir=work_dirs.base, out_basename="benchmark.yaml") -def prepare(oss_fuzz_dir: str) -> None: - """Prepares the experiment environment.""" - oss_fuzz_checkout.clone_oss_fuzz(oss_fuzz_dir) - oss_fuzz_checkout.postprocess_oss_fuzz() - - -def _fuzzing_pipeline(benchmark: Benchmark, model: models.LLM, - args: argparse.Namespace, work_dirs: WorkDirs, - trial: int) -> TrialResult: - """Runs the predefined 3-stage pipeline for one trial.""" - trial_logger = logger.get_trial_logger(trial=trial, level=logging.DEBUG) - trial_logger.info('Trial Starts') - if args.agent: - p = pipeline.Pipeline(args=args, - trial=trial, - writing_stage_agents=[ - Prototyper(trial=trial, llm=model, args=args), - Enhancer(trial=trial, llm=model, args=args), - ], - analysis_stage_agents=[ - SemanticAnalyzer(trial=trial, - llm=model, - args=args), - CoverageAnalyzer(trial=trial, - llm=model, - args=args), - CrashAnalyzer(trial=trial, llm=model, args=args), - ]) - else: - p = pipeline.Pipeline(args=args, - trial=trial, - writing_stage_agents=[ - OnePromptPrototyper(trial=trial, - llm=model, - args=args), - OnePromptEnhancer(trial=trial, - llm=model, - args=args), - ], - analysis_stage_agents=[ - SemanticAnalyzer(trial=trial, - llm=model, - args=args), - ]) - - results = p.execute(result_history=[ - Result(benchmark=benchmark, trial=trial, work_dirs=work_dirs) - ]) - - trial_result = TrialResult(benchmark=benchmark, - trial=trial, - work_dirs=work_dirs, - result_history=results) - trial_logger.write_result( - result_status_dir=trial_result.best_result.work_dirs.status, - result=trial_result, - finished=True) - return trial_result - - -def _fuzzing_pipelines(benchmark: Benchmark, model: models.LLM, - args: argparse.Namespace, - work_dirs: WorkDirs) -> BenchmarkResult: - """Runs all trial experiments in their pipelines.""" - # Create a pool of worker processes - with pool.ThreadPool(processes=NUM_EVA) as p: - # Initialize thread-local storage in each worker before processing - task_args = [(benchmark, model, args, work_dirs, trial) - for trial in range(1, args.num_samples + 1)] - trial_results = p.starmap(_fuzzing_pipeline, task_args) - return 
BenchmarkResult(benchmark=benchmark, - work_dirs=work_dirs, - trial_results=trial_results) - - -def run(benchmark: Benchmark, model: models.LLM, args: argparse.Namespace, - work_dirs: WorkDirs) -> Optional[AggregatedResult]: - """Generates code via LLM, and evaluates them.""" - model.cloud_setup() - - # Save the benchmark in the WorkDir base. This is saved to the working - # directory, and should not be deleted in future executions. As such, - # from here on, do not erase all WorkDir contents. - Benchmark.to_yaml([benchmark], - outdir=work_dirs.base, - out_basename='benchmark.yaml') - - return AggregatedResult.from_benchmark_result( - _fuzzing_pipelines(benchmark, model, args, work_dirs)) + return AggregatedResult.from_benchmark_result( + _fuzzing_pipelines(benchmark, model, args, work_dirs) + ) diff --git a/scripts/run-new-oss-fuzz-project/helper.py b/scripts/run-new-oss-fuzz-project/helper.py index 1b0b37ea34..8c1469726e 100644 --- a/scripts/run-new-oss-fuzz-project/helper.py +++ b/scripts/run-new-oss-fuzz-project/helper.py @@ -20,19 +20,19 @@ source_dir = sys.argv[1] -TARGET_OSS_FUZZ = 'work/oss-fuzz' +TARGET_OSS_FUZZ = "work/oss-fuzz" -projects_to_exec = '' +projects_to_exec = "" for empty_oss_fuzz in os.listdir(source_dir): - dst = os.path.join(TARGET_OSS_FUZZ, 'projects', empty_oss_fuzz) - if os.path.isdir(dst): - shutil.rmtree(dst) + dst = os.path.join(TARGET_OSS_FUZZ, "projects", empty_oss_fuzz) + if os.path.isdir(dst): + shutil.rmtree(dst) - shutil.copytree(os.path.join(source_dir, empty_oss_fuzz), dst) - projects_to_exec += empty_oss_fuzz + ' ' + shutil.copytree(os.path.join(source_dir, empty_oss_fuzz), dst) + projects_to_exec += empty_oss_fuzz + " " # Launch the runner -cmd = f'scripts/run-new-oss-fuzz-project/run-project.sh {projects_to_exec}' +cmd = f"scripts/run-new-oss-fuzz-project/run-project.sh {projects_to_exec}" # Call with shell to ensure we have the right environment. subprocess.check_call(cmd, shell=True) diff --git a/stage/analysis_stage.py b/stage/analysis_stage.py index a57eafec17..5e8861d4b5 100644 --- a/stage/analysis_stage.py +++ b/stage/analysis_stage.py @@ -20,32 +20,33 @@ class AnalysisStage(BaseStage): - """Analyzes the runtime performance of fuzz targets and suggests improvements. - This stage examines whether crashes are due to bugs in the fuzz target or - if there are significant code blocks left uncovered. Based on this analysis, - it provides recommendations for refining the fuzz target in subsequent stages. - Additionally, it prepares to terminate the experiment if the fuzz target - crashes due to a bug in the project under test or if all major code paths have - been sufficiently covered.""" + """Analyzes the runtime performance of fuzz targets and suggests improvements. + This stage examines whether crashes are due to bugs in the fuzz target or + if there are significant code blocks left uncovered. Based on this analysis, + it provides recommendations for refining the fuzz target in subsequent stages. 
+ Additionally, it prepares to terminate the experiment if the fuzz target + crashes due to a bug in the project under test or if all major code paths have + been sufficiently covered.""" - def execute(self, result_history: list[Result]) -> Result: - """Selects agent based on run result and executes it.""" - self.logger.info('Analysis Stage') - last_result = result_history[-1] - assert isinstance(last_result, RunResult) + def execute(self, result_history: list[Result]) -> Result: + """Selects agent based on run result and executes it.""" + self.logger.info("Analysis Stage") + last_result = result_history[-1] + assert isinstance(last_result, RunResult) - if last_result.crashes: - agent = self.get_agent(agent_name='CrashAnalyzer') - else: - try: - agent = self.get_agent(agent_name='CoverageAnalyzer') - except RuntimeError: - agent = self.get_agent(agent_name='SemanticAnalyzer') - analysis_result = self._execute_agent(agent, result_history) + if last_result.crashes: + agent = self.get_agent(agent_name="CrashAnalyzer") + else: + try: + agent = self.get_agent(agent_name="CoverageAnalyzer") + except RuntimeError: + agent = self.get_agent(agent_name="SemanticAnalyzer") + analysis_result = self._execute_agent(agent, result_history) - # TODO(dongge): Save logs and more info into workdir. - self.logger.write_chat_history(analysis_result) - self.logger.debug('Analysis stage completed with with result:\n%s', - analysis_result) + # TODO(dongge): Save logs and more info into workdir. + self.logger.write_chat_history(analysis_result) + self.logger.debug( + "Analysis stage completed with result:\n%s", analysis_result + ) - return analysis_result + return analysis_result diff --git a/stage/base_stage.py b/stage/base_stage.py index 8e82239e4d..50b601d6ef 100644 --- a/stage/base_stage.py +++ b/stage/base_stage.py @@ -23,51 +23,53 @@ class BaseStage(ABC): - """The abstract base class for stages in fuzzing pipeline.""" + """The abstract base class for stages in the fuzzing pipeline.""" - def __init__(self, - args: argparse.Namespace, - trail: int, - agents: Optional[list[BaseAgent]] = None, - name: str = '') -> None: - self.args = args - self.trial = trail - self.agents: list[BaseAgent] = agents or [] - self.logger = logger.get_trial_logger(trial=trail) - self.name: str = name or self.__class__.__name__ + def __init__( + self, + args: argparse.Namespace, + trail: int, + agents: Optional[list[BaseAgent]] = None, + name: str = "", + ) -> None: + self.args = args + self.trial = trail + self.agents: list[BaseAgent] = agents or [] + self.logger = logger.get_trial_logger(trial=trail) + self.name: str = name or self.__class__.__name__ - def __repr__(self) -> str: - return self.__class__.__name__ + def __repr__(self) -> str: + return self.__class__.__name__ - def add_agent(self, agent: BaseAgent) -> 'BaseStage': - """Adds an agent for the stage.""" - agent.args = agent.args or self.args - self.agents.append(agent) - return self + def add_agent(self, agent: BaseAgent) -> "BaseStage": + """Adds an agent for the stage.""" + agent.args = agent.args or self.args + self.agents.append(agent) + return self - def get_agent(self, index: int = 0, agent_name: str = '') -> BaseAgent: - """Finds the agent by its name.""" - if not agent_name: 
return self.agents[index] + for agent in self.agents: + if agent.name == agent_name: + return agent + raise RuntimeError(f"Agent {agent_name} is undefined") - def _execute_agent_cloud(self, agent: BaseAgent, - result_history: list[Result]) -> Result: - """Executes agent in cloud build.""" - cloud_builder = CloudBuilder(self.args) - dill_dir = result_history[-1].work_dirs.dills - result = cloud_builder.run(agent, result_history, dill_dir) - return result + def _execute_agent_cloud( + self, agent: BaseAgent, result_history: list[Result] + ) -> Result: + """Executes agent in cloud build.""" + cloud_builder = CloudBuilder(self.args) + dill_dir = result_history[-1].work_dirs.dills + result = cloud_builder.run(agent, result_history, dill_dir) + return result - def _execute_agent(self, agent: BaseAgent, - result_history: list[Result]) -> Result: - if self.args.cloud_experiment_name: - return self._execute_agent_cloud(agent, result_history) - return agent.execute(result_history) + def _execute_agent(self, agent: BaseAgent, result_history: list[Result]) -> Result: + if self.args.cloud_experiment_name: + return self._execute_agent_cloud(agent, result_history) + return agent.execute(result_history) - @abstractmethod - def execute(self, result_history: list[Result]) -> Result: - """Executes the stage-specific actions using agents.""" + @abstractmethod + def execute(self, result_history: list[Result]) -> Result: + """Executes the stage-specific actions using agents.""" diff --git a/stage/execution_stage.py b/stage/execution_stage.py index a7666728ae..fa0297c66f 100644 --- a/stage/execution_stage.py +++ b/stage/execution_stage.py @@ -25,162 +25,181 @@ class ExecutionStage(BaseStage): - """Executes fuzz targets and build scripts. This stage takes a fuzz target - and its build script, runs them locally or on the cloud with OSS-Fuzz infra, - and outputs code coverage report and run-time crash information for later - stages to analyze and improve on. It uses OSS-Fuzz infra to perform these - tasks.""" - - def execute(self, result_history: list[Result]) -> Result: - """Executes the fuzz target and build script in the latest result.""" - last_result = result_history[-1] - benchmark = last_result.benchmark - if self.args.cloud_experiment_name: - builder_runner = builder_runner_lib.CloudBuilderRunner( - benchmark=benchmark, - work_dirs=last_result.work_dirs, - run_timeout=self.args.run_timeout, - experiment_name=self.args.cloud_experiment_name, - experiment_bucket=self.args.cloud_experiment_bucket, - ) - else: - builder_runner = builder_runner_lib.BuilderRunner( - benchmark=benchmark, - work_dirs=last_result.work_dirs, - run_timeout=self.args.run_timeout, - ) - - evaluator = Evaluator(builder_runner, benchmark, last_result.work_dirs) - generated_target_name = os.path.basename(benchmark.target_path) - generated_oss_fuzz_project = f'{benchmark.id}-{last_result.trial}' - generated_oss_fuzz_project = oss_fuzz_checkout.rectify_docker_tag( - generated_oss_fuzz_project) - - fuzz_target_path = os.path.join(last_result.work_dirs.fuzz_targets, - f'{self.trial:02d}.fuzz_target') - build_script_path = os.path.join(last_result.work_dirs.fuzz_targets, - f'{self.trial:02d}.build_script') - evaluator.create_ossfuzz_project(benchmark, generated_oss_fuzz_project, - fuzz_target_path, build_script_path) - - status_path = os.path.join(last_result.work_dirs.status, - f'{self.trial:02d}') - os.makedirs(status_path, exist_ok=True) - - # Try building and running the new target. - - # TODO: Log build failure. 
- # TODO: Log run success/failure. - - # 1. Evaluating generated driver. - if not isinstance(last_result, BuildResult): - self.logger.error('RunResult must follow a BuildResult') - raise TypeError - - try: - build_result, run_result = evaluator.builder_runner.build_and_run( - generated_oss_fuzz_project, - fuzz_target_path, - 0, - benchmark.language, - cloud_build_tags=[ - str(self.trial), - 'Execution', - 'ofg', - # TODO(dongge): Tag function name, compatible with tag format. - last_result.benchmark.project, - ], - trial=self.trial) - if not run_result: - raise Exception('No RunResult received from build_and_run') - if run_result.coverage_summary is None or run_result.coverage is None: - self.logger.warning('No cov info in run result of %s', - generated_oss_fuzz_project) - raise Exception(f'No Coverage or Coverage Summary in {run_result}') - - if run_result.coverage_summary: - total_lines = evaluator_lib.compute_total_lines_without_fuzz_targets( - run_result.coverage_summary, generated_target_name) - else: - total_lines = 0 - - if run_result.total_pcs: - coverage_percent = run_result.cov_pcs / run_result.total_pcs - self.logger.info('coverage percent == %s in %s.', coverage_percent, - generated_oss_fuzz_project) - else: - self.logger.warning('total_pcs == 0 in %s.', generated_oss_fuzz_project) - coverage_percent = 0.0 - - existing_textcov = evaluator.load_existing_textcov() - run_result.coverage.subtract_covered_lines(existing_textcov) - - if total_lines: - coverage_diff = run_result.coverage.covered_lines / total_lines - self.logger.info('coverage diff == %s in %s.', coverage_diff, - generated_oss_fuzz_project) - else: - self.logger.warning('total_lines == 0 in %s', - generated_oss_fuzz_project) - coverage_diff = 0.0 - - if run_result.log_path and os.path.isfile(run_result.log_path): - with open(run_result.log_path, 'r') as f: - run_log_lines = f.readlines() - if len(run_log_lines) > 30: - run_log_lines = (run_log_lines[:20] + [ - f'...({len(run_log_lines) - 30} lines of fuzzing log truncated)' - '...' - ] + run_log_lines[-10:]) - run_log_content = ''.join(run_log_lines) - else: - run_log_content = '' - - runresult = RunResult( - benchmark=benchmark, - trial=self.trial, - work_dirs=last_result.work_dirs, - fuzz_target_source=last_result.fuzz_target_source, - build_script_source=last_result.build_script_source, - author=self, - compiles=last_result.compiles, - compile_error=last_result.compile_error, - compile_log=last_result.compile_log, - binary_exists=last_result.binary_exists, - is_function_referenced=last_result.is_function_referenced, - crashes=run_result.crashes, - run_error=run_result.crash_info, - crash_func=run_result.semantic_check.crash_func, - # TODO: This should be the content of log_path. 
- run_log=run_log_content, - coverage_summary=run_result.coverage_summary, - coverage=coverage_percent, - line_coverage_diff=coverage_diff, - reproducer_path=run_result.reproducer_path, - artifact_path=run_result.artifact_path, - sanitizer=run_result.sanitizer, - textcov_diff=run_result.coverage, - log_path=run_result.log_path, - corpus_path=run_result.corpus_path, - coverage_report_path=run_result.coverage_report_path, - cov_pcs=run_result.cov_pcs, - total_pcs=run_result.total_pcs, - chat_history={self.name: run_log_content}) - except Exception as e: - self.logger.error('Exception %s occurred on %s', e, last_result) - runresult = RunResult( - benchmark=benchmark, - trial=self.trial, - work_dirs=last_result.work_dirs, - fuzz_target_source=last_result.fuzz_target_source, - build_script_source=last_result.build_script_source, - chat_history={self.name: 'Exuection Failed'}, - author=self, - compiles=last_result.compiles, - compile_error=last_result.compile_error, - compile_log=last_result.compile_log, - binary_exists=last_result.binary_exists, - is_function_referenced=last_result.is_function_referenced) - - self.logger.write_chat_history(runresult) - return runresult + """Executes fuzz targets and build scripts. This stage takes a fuzz target + and its build script, runs them locally or on the cloud with OSS-Fuzz infra, + and outputs code coverage report and run-time crash information for later + stages to analyze and improve on. It uses OSS-Fuzz infra to perform these + tasks.""" + + def execute(self, result_history: list[Result]) -> Result: + """Executes the fuzz target and build script in the latest result.""" + last_result = result_history[-1] + benchmark = last_result.benchmark + if self.args.cloud_experiment_name: + builder_runner = builder_runner_lib.CloudBuilderRunner( + benchmark=benchmark, + work_dirs=last_result.work_dirs, + run_timeout=self.args.run_timeout, + experiment_name=self.args.cloud_experiment_name, + experiment_bucket=self.args.cloud_experiment_bucket, + ) + else: + builder_runner = builder_runner_lib.BuilderRunner( + benchmark=benchmark, + work_dirs=last_result.work_dirs, + run_timeout=self.args.run_timeout, + ) + + evaluator = Evaluator(builder_runner, benchmark, last_result.work_dirs) + generated_target_name = os.path.basename(benchmark.target_path) + generated_oss_fuzz_project = f"{benchmark.id}-{last_result.trial}" + generated_oss_fuzz_project = oss_fuzz_checkout.rectify_docker_tag( + generated_oss_fuzz_project + ) + + fuzz_target_path = os.path.join( + last_result.work_dirs.fuzz_targets, f"{self.trial:02d}.fuzz_target" + ) + build_script_path = os.path.join( + last_result.work_dirs.fuzz_targets, f"{self.trial:02d}.build_script" + ) + evaluator.create_ossfuzz_project( + benchmark, generated_oss_fuzz_project, fuzz_target_path, build_script_path + ) + + status_path = os.path.join(last_result.work_dirs.status, f"{self.trial:02d}") + os.makedirs(status_path, exist_ok=True) + + # Try building and running the new target. + + # TODO: Log build failure. + # TODO: Log run success/failure. + + # 1. Evaluating generated driver. + if not isinstance(last_result, BuildResult): + self.logger.error("RunResult must follow a BuildResult") + raise TypeError + + try: + build_result, run_result = evaluator.builder_runner.build_and_run( + generated_oss_fuzz_project, + fuzz_target_path, + 0, + benchmark.language, + cloud_build_tags=[ + str(self.trial), + "Execution", + "ofg", + # TODO(dongge): Tag function name, compatible with tag format. 
+ last_result.benchmark.project, + ], + trial=self.trial, + ) + if not run_result: + raise Exception("No RunResult received from build_and_run") + if run_result.coverage_summary is None or run_result.coverage is None: + self.logger.warning( + "No cov info in run result of %s", generated_oss_fuzz_project + ) + raise Exception(f"No Coverage or Coverage Summary in {run_result}") + + if run_result.coverage_summary: + total_lines = evaluator_lib.compute_total_lines_without_fuzz_targets( + run_result.coverage_summary, generated_target_name + ) + else: + total_lines = 0 + + if run_result.total_pcs: + coverage_percent = run_result.cov_pcs / run_result.total_pcs + self.logger.info( + "coverage percent == %s in %s.", + coverage_percent, + generated_oss_fuzz_project, + ) + else: + self.logger.warning("total_pcs == 0 in %s.", generated_oss_fuzz_project) + coverage_percent = 0.0 + + existing_textcov = evaluator.load_existing_textcov() + run_result.coverage.subtract_covered_lines(existing_textcov) + + if total_lines: + coverage_diff = run_result.coverage.covered_lines / total_lines + self.logger.info( + "coverage diff == %s in %s.", + coverage_diff, + generated_oss_fuzz_project, + ) + else: + self.logger.warning( + "total_lines == 0 in %s", generated_oss_fuzz_project + ) + coverage_diff = 0.0 + + if run_result.log_path and os.path.isfile(run_result.log_path): + with open(run_result.log_path, "r") as f: + run_log_lines = f.readlines() + if len(run_log_lines) > 30: + run_log_lines = ( + run_log_lines[:20] + + [ + f"...({len(run_log_lines) - 30} lines of fuzzing log truncated)" + "..." + ] + + run_log_lines[-10:] + ) + run_log_content = "".join(run_log_lines) + else: + run_log_content = "" + + runresult = RunResult( + benchmark=benchmark, + trial=self.trial, + work_dirs=last_result.work_dirs, + fuzz_target_source=last_result.fuzz_target_source, + build_script_source=last_result.build_script_source, + author=self, + compiles=last_result.compiles, + compile_error=last_result.compile_error, + compile_log=last_result.compile_log, + binary_exists=last_result.binary_exists, + is_function_referenced=last_result.is_function_referenced, + crashes=run_result.crashes, + run_error=run_result.crash_info, + crash_func=run_result.semantic_check.crash_func, + # TODO: This should be the content of log_path. 
+ run_log=run_log_content, + coverage_summary=run_result.coverage_summary, + coverage=coverage_percent, + line_coverage_diff=coverage_diff, + reproducer_path=run_result.reproducer_path, + artifact_path=run_result.artifact_path, + sanitizer=run_result.sanitizer, + textcov_diff=run_result.coverage, + log_path=run_result.log_path, + corpus_path=run_result.corpus_path, + coverage_report_path=run_result.coverage_report_path, + cov_pcs=run_result.cov_pcs, + total_pcs=run_result.total_pcs, + chat_history={self.name: run_log_content}, + ) + except Exception as e: + self.logger.error("Exception %s occurred on %s", e, last_result) + runresult = RunResult( + benchmark=benchmark, + trial=self.trial, + work_dirs=last_result.work_dirs, + fuzz_target_source=last_result.fuzz_target_source, + build_script_source=last_result.build_script_source, + chat_history={self.name: "Execution Failed"}, + author=self, + compiles=last_result.compiles, + compile_error=last_result.compile_error, + compile_log=last_result.compile_log, + binary_exists=last_result.binary_exists, + is_function_referenced=last_result.is_function_referenced, + ) + + self.logger.write_chat_history(runresult) + return runresult diff --git a/stage/writing_stage.py b/stage/writing_stage.py index 69799af8e5..90c23367ce 100644 --- a/stage/writing_stage.py +++ b/stage/writing_stage.py @@ -23,39 +23,39 @@ class WritingStage(BaseStage): - """Handles the creation and refinement of fuzz targets and build scripts. - Initially, this stage outputs a new fuzz target and its build script for a - function under test. In later cycles, it uses run-time results and insights - from previous iterations to produce a revised fuzz target. - It leverages LLM agents to perform these tasks.""" - - def _write_new_fuzz_target(self, result_history: list[Result]) -> Result: - """Writes a new fuzz target.""" - agent = self.get_agent() - - if self.args.cloud_experiment_name: - return self._execute_agent_cloud(agent, result_history) - return agent.execute(result_history) - - def _refine_given_fuzz_targets(self, result_history: list[Result]) -> Result: - """Writes a new fuzz target.""" - agent = self.get_agent(index=1) - if self.args.cloud_experiment_name: - return self._execute_agent_cloud(agent, result_history) - return agent.execute(result_history) - - def execute(self, result_history: list[Result]) -> Result: - """Executes the writing stage.""" - if result_history and result_history[-1].fuzz_target_source: - agent = self.get_agent(index=1) - else: - agent = self.get_agent() - agent_result = self._execute_agent(agent, result_history) - build_result = cast(BuildResult, agent_result) - - # TODO(dongge): Save logs and more info into workdir. - self.logger.write_fuzz_target(build_result) - self.logger.write_build_script(build_result) - self.logger.write_chat_history(build_result) - self.logger.debug('Writing stage completed with result:\n%s', build_result) - return build_result + """Handles the creation and refinement of fuzz targets and build scripts. + Initially, this stage outputs a new fuzz target and its build script for a + function under test. In later cycles, it uses run-time results and insights + from previous iterations to produce a revised fuzz target.
+ It leverages LLM agents to perform these tasks.""" + + def _write_new_fuzz_target(self, result_history: list[Result]) -> Result: + """Writes a new fuzz target.""" + agent = self.get_agent() + + if self.args.cloud_experiment_name: + return self._execute_agent_cloud(agent, result_history) + return agent.execute(result_history) + + def _refine_given_fuzz_targets(self, result_history: list[Result]) -> Result: + """Refines the given fuzz targets.""" + agent = self.get_agent(index=1) + if self.args.cloud_experiment_name: + return self._execute_agent_cloud(agent, result_history) + return agent.execute(result_history) + + def execute(self, result_history: list[Result]) -> Result: + """Executes the writing stage.""" + if result_history and result_history[-1].fuzz_target_source: + agent = self.get_agent(index=1) + else: + agent = self.get_agent() + agent_result = self._execute_agent(agent, result_history) + build_result = cast(BuildResult, agent_result) + + # TODO(dongge): Save logs and more info into workdir. + self.logger.write_fuzz_target(build_result) + self.logger.write_build_script(build_result) + self.logger.write_chat_history(build_result) + self.logger.debug("Writing stage completed with result:\n%s", build_result) + return build_result diff --git a/tool/base_tool.py b/tool/base_tool.py index b3a8a977fa..89c02422b0 100644 --- a/tool/base_tool.py +++ b/tool/base_tool.py @@ -19,29 +19,28 @@ from experiment.benchmark import Benchmark -TOOL_TUTORIAL_DIR = os.path.join(os.path.dirname(__file__), '../prompts', - 'tool') +TOOL_TUTORIAL_DIR = os.path.join(os.path.dirname(__file__), "../prompts", "tool") class BaseTool(ABC): - """Abstract base class for tools used by LLM agents to interact with various - environments or perform actions. Provides a common interface for creating - tool-specific guides and executing commands.""" - - def __init__(self, benchmark: Benchmark, name: str = '') -> None: - self.benchmark = benchmark - # The name of the tool. - self.name: str = name or self.__class__.__name__ - - def _get_tutorial_file_content(self, filename: str) -> str: - tutorial_path = os.path.join(TOOL_TUTORIAL_DIR, filename) - with open(tutorial_path) as tool_tutorial_path: - return tool_tutorial_path.read() - - @abstractmethod - def tutorial(self) -> str: - """Constructs a guide for LLM, e.g., based on self.command_usages.""" - - @abstractmethod - def execute(self, command: str) -> Any: - """Executes tool based on the command.""" + """Abstract base class for tools used by LLM agents to interact with various + environments or perform actions. Provides a common interface for creating + tool-specific guides and executing commands.""" + + def __init__(self, benchmark: Benchmark, name: str = "") -> None: + self.benchmark = benchmark + # The name of the tool.
+ self.name: str = name or self.__class__.__name__ + + def _get_tutorial_file_content(self, filename: str) -> str: + tutorial_path = os.path.join(TOOL_TUTORIAL_DIR, filename) + with open(tutorial_path) as tool_tutorial_path: + return tool_tutorial_path.read() + + @abstractmethod + def tutorial(self) -> str: + """Constructs a guide for LLM, e.g., based on self.command_usages.""" + + @abstractmethod + def execute(self, command: str) -> Any: + """Executes tool based on the command.""" diff --git a/tool/bash_tool.py b/tool/bash_tool.py index 0188be3c4d..aa8f74c585 100644 --- a/tool/bash_tool.py +++ b/tool/bash_tool.py @@ -16,4 +16,4 @@ class BashTool(BaseTool): - pass + pass diff --git a/tool/container_tool.py b/tool/container_tool.py index 5c75897e4a..a37432dcd8 100644 --- a/tool/container_tool.py +++ b/tool/container_tool.py @@ -23,132 +23,159 @@ class ProjectContainerTool(BaseTool): - """A tool for LLM agents to interact within a project's docker container.""" - - def __init__(self, - benchmark: Benchmark, - name: str = '', - project_name: str = '') -> None: - super().__init__(benchmark, name) - self.project_name = project_name or benchmark.project - self.image_name = self._prepare_project_image(self.project_name) - self.container_id = self._start_docker_container() - self.build_script_path = '/src/build.sh' - self._backup_default_build_script() - self.project_dir = self._get_project_dir() - - def tutorial(self) -> str: - """Constructs a tool guide tutorial for LLM agents.""" - return self._get_tutorial_file_content('container_tool.txt').replace( - '{FUZZ_TARGET_PATH}', self.benchmark.target_path) - - def _prepare_project_image(self, project_name: str) -> str: - """Prepares the project's OSS-Fuzz docker image and returns the image name. - """ - image_name = oss_fuzz_checkout.prepare_project_image(self.benchmark) - if image_name: - return image_name - raise Exception(f'Failed to build image for {project_name}') - - def _execute_command_in_container(self, - command: list[str]) -> sp.CompletedProcess: - """Executes the |command| in subprocess and log output.""" - try: - result = sp.run(command, - stdout=sp.PIPE, - stderr=sp.PIPE, - check=False, - text=True, - encoding='utf-8', - errors='ignore') - - logger.debug( - 'Executing command (%s) in container %s: Return code %d. STDOUT: %s, ' - 'STDERR: %s', command, self.container_id, result.returncode, - result.stdout, result.stderr) - return result - except Exception as e: - logger.error( - 'Executing command (%s) in container failed with Exception: %s', - command, e) - return sp.CompletedProcess(command, returncode=1, stdout='', stderr='') - - def _execute_command(self, command: list[str]) -> sp.CompletedProcess: - """Executes the |command| in subprocess and log output.""" - try: - result = sp.run(command, - stdout=sp.PIPE, - stderr=sp.PIPE, - check=False, - text=True, - encoding='utf-8', - errors='ignore') - - logger.debug( - 'Executing command (%s): Return code %d. 
STDOUT: %s, STDERR: %s', - command, result.returncode, result.stdout, result.stderr) - return result - except Exception as e: - logger.error('Executing command (%s) failed with Exception: %s', command, - e) - return sp.CompletedProcess(command, returncode=1, stdout='', stderr='') - - def _backup_default_build_script(self) -> None: - """Creates a copy of the human-written /src/build.sh for LLM to use.""" - backup_command = f'cp {self.build_script_path} /src/build.bk.sh' - process = self.execute(backup_command) - if process.returncode: - logger.error('Failed to create a backup of %s: %s', - self.build_script_path, self.image_name) - - def _get_project_dir(self) -> str: - """Returns the project-under-test's source code directory.""" - pwd_command = 'pwd' - process = self.execute(pwd_command) - if process.returncode: - logger.error('Failed to get the WORKDIR: %s', self.image_name) - return '' - return process.stdout.strip() - - def _start_docker_container(self) -> str: - """Runs the project's OSS-Fuzz image as a background container and returns - the container ID.""" - run_container_command = [ - 'docker', 'run', '-d', '-t', '--entrypoint=/bin/bash', '-e', - f'FUZZING_LANGUAGE={self.benchmark.language}', self.image_name - ] - result = self._execute_command(run_container_command) - if result.returncode: - logger.error('Failed to start container of image: %s', self.image_name) - container_id = result.stdout.strip() - return container_id - - def execute(self, command: str) -> sp.CompletedProcess: - """Executes the |command| in the container and returns the output.""" - logger.debug('Executing command (%s) in %s: ', command, self.container_id) - execute_command_in_container = [ - 'docker', 'exec', self.container_id, '/bin/bash', '-c', command - ] - process = self._execute_command_in_container(execute_command_in_container) - process.args = command - return process - - def compile(self, extra_commands: str = '') -> sp.CompletedProcess: - """Compiles the fuzz target.""" - command = 'compile > /dev/null' + extra_commands - compile_process = self.execute(command) - # Hide Compilation command so that LLM won't reuse it in the inspection tool - # and be distracted by irrelevant errors, e.g., `build/ already exits`. - compile_process.args = '# Compiles the fuzz target.' 
- return compile_process - - def terminate(self) -> bool: - """Terminates the container.""" - terminate_container_command = ['docker', 'stop', self.container_id] - result = self._execute_command(terminate_container_command) - return result.returncode == 0 - - def write_to_file(self, content: str, file_path: str) -> None: - replace_file_content_command = ( - f'cat << "OFG_EOF" > {file_path}\n{content}\nOFG_EOF') - self.execute(replace_file_content_command) + """A tool for LLM agents to interact within a project's docker container.""" + + def __init__( + self, benchmark: Benchmark, name: str = "", project_name: str = "" + ) -> None: + super().__init__(benchmark, name) + self.project_name = project_name or benchmark.project + self.image_name = self._prepare_project_image(self.project_name) + self.container_id = self._start_docker_container() + self.build_script_path = "/src/build.sh" + self._backup_default_build_script() + self.project_dir = self._get_project_dir() + + def tutorial(self) -> str: + """Constructs a tool guide tutorial for LLM agents.""" + return self._get_tutorial_file_content("container_tool.txt").replace( + "{FUZZ_TARGET_PATH}", self.benchmark.target_path + ) + + def _prepare_project_image(self, project_name: str) -> str: + """Prepares the project's OSS-Fuzz docker image and returns the image name.""" + image_name = oss_fuzz_checkout.prepare_project_image(self.benchmark) + if image_name: + return image_name + raise Exception(f"Failed to build image for {project_name}") + + def _execute_command_in_container(self, command: list[str]) -> sp.CompletedProcess: + """Executes the |command| in subprocess and log output.""" + try: + result = sp.run( + command, + stdout=sp.PIPE, + stderr=sp.PIPE, + check=False, + text=True, + encoding="utf-8", + errors="ignore", + ) + + logger.debug( + "Executing command (%s) in container %s: Return code %d. STDOUT: %s, " + "STDERR: %s", + command, + self.container_id, + result.returncode, + result.stdout, + result.stderr, + ) + return result + except Exception as e: + logger.error( + "Executing command (%s) in container failed with Exception: %s", + command, + e, + ) + return sp.CompletedProcess(command, returncode=1, stdout="", stderr="") + + def _execute_command(self, command: list[str]) -> sp.CompletedProcess: + """Executes the |command| in subprocess and log output.""" + try: + result = sp.run( + command, + stdout=sp.PIPE, + stderr=sp.PIPE, + check=False, + text=True, + encoding="utf-8", + errors="ignore", + ) + + logger.debug( + "Executing command (%s): Return code %d. 
STDOUT: %s, STDERR: %s", + command, + result.returncode, + result.stdout, + result.stderr, + ) + return result + except Exception as e: + logger.error("Executing command (%s) failed with Exception: %s", command, e) + return sp.CompletedProcess(command, returncode=1, stdout="", stderr="") + + def _backup_default_build_script(self) -> None: + """Creates a copy of the human-written /src/build.sh for LLM to use.""" + backup_command = f"cp {self.build_script_path} /src/build.bk.sh" + process = self.execute(backup_command) + if process.returncode: + logger.error( + "Failed to create a backup of %s: %s", + self.build_script_path, + self.image_name, + ) + + def _get_project_dir(self) -> str: + """Returns the project-under-test's source code directory.""" + pwd_command = "pwd" + process = self.execute(pwd_command) + if process.returncode: + logger.error("Failed to get the WORKDIR: %s", self.image_name) + return "" + return process.stdout.strip() + + def _start_docker_container(self) -> str: + """Runs the project's OSS-Fuzz image as a background container and returns + the container ID.""" + run_container_command = [ + "docker", + "run", + "-d", + "-t", + "--entrypoint=/bin/bash", + "-e", + f"FUZZING_LANGUAGE={self.benchmark.language}", + self.image_name, + ] + result = self._execute_command(run_container_command) + if result.returncode: + logger.error("Failed to start container of image: %s", self.image_name) + container_id = result.stdout.strip() + return container_id + + def execute(self, command: str) -> sp.CompletedProcess: + """Executes the |command| in the container and returns the output.""" + logger.debug("Executing command (%s) in %s: ", command, self.container_id) + execute_command_in_container = [ + "docker", + "exec", + self.container_id, + "/bin/bash", + "-c", + command, + ] + process = self._execute_command_in_container(execute_command_in_container) + process.args = command + return process + + def compile(self, extra_commands: str = "") -> sp.CompletedProcess: + """Compiles the fuzz target.""" + command = "compile > /dev/null" + extra_commands + compile_process = self.execute(command) + # Hide Compilation command so that LLM won't reuse it in the inspection tool + # and be distracted by irrelevant errors, e.g., `build/ already exists`. + compile_process.args = "# Compiles the fuzz target."
+ return compile_process + + def terminate(self) -> bool: + """Terminates the container.""" + terminate_container_command = ["docker", "stop", self.container_id] + result = self._execute_command(terminate_container_command) + return result.returncode == 0 + + def write_to_file(self, content: str, file_path: str) -> None: + replace_file_content_command = ( + f'cat << "OFG_EOF" > {file_path}\n{content}\nOFG_EOF' + ) + self.execute(replace_file_content_command) diff --git a/tool/fuzz_introspector_tool.py b/tool/fuzz_introspector_tool.py index e36e45f7fc..b41302676c 100644 --- a/tool/fuzz_introspector_tool.py +++ b/tool/fuzz_introspector_tool.py @@ -17,24 +17,24 @@ class FuzzIntrospectorTool(BaseTool): - """Calls FI API with params.""" + """Calls FI API with params.""" - def _source_code(self, filename: str, start_line: int, end_line: int) -> str: - """Calls the source code API of the Fuzz Introspector.""" - # A placeholder - raise NotImplementedError + def _source_code(self, filename: str, start_line: int, end_line: int) -> str: + """Calls the source code API of the Fuzz Introspector.""" + # A placeholder + raise NotImplementedError - def _xrefs(self, function_signature: str) -> list[str]: - """Calls the xrefs API of the Fuzz Introspector.""" - # A placeholder - raise NotImplementedError + def _xrefs(self, function_signature: str) -> list[str]: + """Calls the xrefs API of the Fuzz Introspector.""" + # A placeholder + raise NotImplementedError - def _types_def(self, function_signature: str) -> list[str]: - """Calls the type API of the Fuzz Introspector.""" - # A placeholder - raise NotImplementedError + def _types_def(self, function_signature: str) -> list[str]: + """Calls the type API of the Fuzz Introspector.""" + # A placeholder + raise NotImplementedError - def _function_signature(self, function_name: str) -> str: - """Calls the function signature API of the Fuzz Introspector.""" - # A placeholder - raise NotImplementedError + def _function_signature(self, function_name: str) -> str: + """Calls the function signature API of the Fuzz Introspector.""" + # A placeholder + raise NotImplementedError diff --git a/tool/gbucket_tool.py b/tool/gbucket_tool.py index ae59a66545..52cd34fdfd 100644 --- a/tool/gbucket_tool.py +++ b/tool/gbucket_tool.py @@ -16,14 +16,14 @@ class GBucketTool(BaseTool): - """Fetches file content from GBucket.""" + """Fetches file content from GBucket.""" - def human_targets(self, project: str) -> list[str]: - """Human written fuzz targets of |project|.""" - # A placeholder. - raise NotImplementedError + def human_targets(self, project: str) -> list[str]: + """Human written fuzz targets of |project|.""" + # A placeholder. + raise NotImplementedError - def llm_targets(self, project: str) -> list[str]: - """LLM generated fuzz targets of |project|.""" - # A placeholder. - raise NotImplementedError + def llm_targets(self, project: str) -> list[str]: + """LLM generated fuzz targets of |project|.""" + # A placeholder. 
+ raise NotImplementedError diff --git a/tool/lldb_tool.py b/tool/lldb_tool.py index 710d9751e1..962ad946fc 100644 --- a/tool/lldb_tool.py +++ b/tool/lldb_tool.py @@ -24,40 +24,49 @@ class LLDBTool(ProjectContainerTool): - """A tool for LLM agents to interact within a LLDB.""" - - def __init__(self, - benchmark: Benchmark, - result: RunResult, - name: str = '', - project_name: str = '') -> None: - super().__init__(benchmark, name, project_name) - self.result = result - - def tutorial(self) -> str: - """Constructs a tool guide tutorial for LLM agents.""" - return self._get_tutorial_file_content('lldb_tool.txt')\ - .replace('{AFTIFACT_NAME}', self.result.artifact_name)\ - .replace('{TARGET_NAME}', self.benchmark.target_name) - - def execute(self, command: str) -> sp.CompletedProcess: - """Executes the |command| in the container and returns the output.""" - logger.debug('Executing command (%s) in %s: ', command, self.container_id) - execute_command_in_container = [ - 'docker', 'exec', self.container_id, '/bin/bash', '-c', command - ] - process = self._execute_command_in_container(execute_command_in_container) - process.args = command - return process - - def execute_in_screen(self, lldb_command: str) -> sp.CompletedProcess: - """Sends a command to the lldb_session screen and returns LLDB output.""" - self.execute('screen -S lldb_session -X logfile flush 0') - self.execute('truncate -s 0 /tmp/lldb_log.txt') - - safe_cmd = lldb_command.replace('"', '\\"') + '\r' - self.execute(f'screen -S lldb_session -X stuff "{safe_cmd}"') - - time.sleep(1.0) - self.execute('screen -S lldb_session -X logfile flush 0') - return self.execute('cat /tmp/lldb_log.txt') + """A tool for LLM agents to interact within a LLDB.""" + + def __init__( + self, + benchmark: Benchmark, + result: RunResult, + name: str = "", + project_name: str = "", + ) -> None: + super().__init__(benchmark, name, project_name) + self.result = result + + def tutorial(self) -> str: + """Constructs a tool guide tutorial for LLM agents.""" + return ( + self._get_tutorial_file_content("lldb_tool.txt") + .replace("{AFTIFACT_NAME}", self.result.artifact_name) + .replace("{TARGET_NAME}", self.benchmark.target_name) + ) + + def execute(self, command: str) -> sp.CompletedProcess: + """Executes the |command| in the container and returns the output.""" + logger.debug("Executing command (%s) in %s: ", command, self.container_id) + execute_command_in_container = [ + "docker", + "exec", + self.container_id, + "/bin/bash", + "-c", + command, + ] + process = self._execute_command_in_container(execute_command_in_container) + process.args = command + return process + + def execute_in_screen(self, lldb_command: str) -> sp.CompletedProcess: + """Sends a command to the lldb_session screen and returns LLDB output.""" + self.execute("screen -S lldb_session -X logfile flush 0") + self.execute("truncate -s 0 /tmp/lldb_log.txt") + + safe_cmd = lldb_command.replace('"', '\\"') + "\r" + self.execute(f'screen -S lldb_session -X stuff "{safe_cmd}"') + + time.sleep(1.0) + self.execute("screen -S lldb_session -X logfile flush 0") + return self.execute("cat /tmp/lldb_log.txt") diff --git a/utils.py b/utils.py index c71e720f70..acd7eeb40a 100644 --- a/utils.py +++ b/utils.py @@ -22,42 +22,43 @@ import dill -def serialize_to_dill(variable: Any, dill_path: str = '') -> str: - """Serializes |variable| to a dill file under |path_prefix| and returns - the file path.""" - path_prefix = os.path.dirname(dill_path) - os.makedirs(path_prefix, exist_ok=True) - with open(dill_path, 'wb') 
as f: - dill.dump(variable, f) - logging.info('Serialized %s to %s', variable, dill_path) - return dill_path +def serialize_to_dill(variable: Any, dill_path: str = "") -> str: + """Serializes |variable| to a dill file under |path_prefix| and returns + the file path.""" + path_prefix = os.path.dirname(dill_path) + os.makedirs(path_prefix, exist_ok=True) + with open(dill_path, "wb") as f: + dill.dump(variable, f) + logging.info("Serialized %s to %s", variable, dill_path) + return dill_path def deserialize_from_dill(dill_path: Any) -> Any: - """Serializes |variable| to a dill file under |path_prefix| and returns - the file path.""" - try: - with open(dill_path, 'rb') as f: - obj = dill.load(f) - logging.info('Deserialized %s to %s', dill_path, obj) - return obj - except FileNotFoundError as e: - logging.error('Failed to deserialize %s: File does not exist: %s', - dill_path, e) - return None + """Deserializes the object stored in the dill file at |dill_path| + and returns it.""" + try: + with open(dill_path, "rb") as f: + obj = dill.load(f) + logging.info("Deserialized %s to %s", dill_path, obj) + return obj + except FileNotFoundError as e: + logging.error("Failed to deserialize %s: File does not exist: %s", dill_path, e) + return None def _default_retry_delay_fn(e: Exception, n: int): - """Delays retry by a random seconds between 1 to 2 minutes.""" - del e, n - return random.uniform(60, 120) + """Delays the retry by a random duration between 1 and 2 minutes.""" + del e, n + return random.uniform(60, 120) -def retryable(exceptions=None, - default_attempts=5, - delay_fn=_default_retry_delay_fn, - other_exceptions=None): - """ +def retryable( + exceptions=None, + default_attempts=5, + delay_fn=_default_retry_delay_fn, + other_exceptions=None, +): + """ Decorator that retries the function on specified exceptions. :param exceptions: List/Set of exceptions or a dictionary of exceptions with custom retry counts. @@ -65,45 +66,70 @@ def retryable(exceptions=None, :param delay_fn: Function to determine the delay between retries. Default is random between 0-60 seconds. """ - exception_config = { - exc: default_attempts for exc in exceptions or {} - } | (other_exceptions or {}) - - def decorator(func): - - @wraps(func) - def wrapper(*args, **kwargs): - attempt = 1 # TODO(dongge): A separate counter for each exception. - while True: - try: - return func(*args, **kwargs) - except Exception as e: - # Expected exceptions and their subclass. - num_attempts = next( - (attempts for exc_type, attempts in exception_config.items() - if type(e) is exc_type), 1) # pylint: disable=unidiomatic-typecheck - - logging.error( - 'Exception %s (%s) on function %s(args=%s, kwargs=%s), attempt ' - '%d/%d', type(e), e, func.__name__, args, kwargs, attempt, - num_attempts) - - if attempt >= num_attempts: - logging.error( - 'Max attempts %d/%d reached for %s(args=%s, kwargs=%s) due to ' - '%s (%s)', attempt, num_attempts, func.__name__, args, kwargs, - type(e), e) - raise - - attempt += 1 - - delay_time = delay_fn(e, attempt) - logging.warning( - 'Delay %d seconds before re-attempting (%d/%d) function call %s(' - 'args=%s, kwargs=%s)', delay_time, attempt, num_attempts, - func.__name__, args, kwargs) - time.sleep(delay_time) - - return wrapper - - return decorator + exception_config = {exc: default_attempts for exc in exceptions or {}} | ( + other_exceptions or {} + ) + + def decorator(func): + + @wraps(func) + def wrapper(*args, **kwargs): + attempt = 1 # TODO(dongge): A separate counter for each exception.
+ while True: + try: + return func(*args, **kwargs) + except Exception as e: + # Expected exceptions and their subclass. + num_attempts = next( + ( + attempts + for exc_type, attempts in exception_config.items() + if type(e) is exc_type + ), + 1, + ) # pylint: disable=unidiomatic-typecheck + + logging.error( + "Exception %s (%s) on function %s(args=%s, kwargs=%s), attempt " + "%d/%d", + type(e), + e, + func.__name__, + args, + kwargs, + attempt, + num_attempts, + ) + + if attempt >= num_attempts: + logging.error( + "Max attempts %d/%d reached for %s(args=%s, kwargs=%s) due to " + "%s (%s)", + attempt, + num_attempts, + func.__name__, + args, + kwargs, + type(e), + e, + ) + raise + + attempt += 1 + + delay_time = delay_fn(e, attempt) + logging.warning( + "Delay %d seconds before re-attempting (%d/%d) function call %s(" + "args=%s, kwargs=%s)", + delay_time, + attempt, + num_attempts, + func.__name__, + args, + kwargs, + ) + time.sleep(delay_time) + + return wrapper + + return decorator
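
Editor's note (not part of the diff): the sketch below illustrates how a caller might use the retryable decorator from utils.py that is reformatted above. The decorator and its parameters (exceptions, default_attempts, delay_fn, other_exceptions) come from the source; the flaky_probe helper, the chosen exception types, and the 1-second delay are hypothetical values invented for this example.

import random

from utils import retryable


@retryable(
    exceptions=[ConnectionError],  # up to default_attempts (5) total attempts
    other_exceptions={TimeoutError: 3},  # up to 3 total attempts
    delay_fn=lambda e, n: 1,  # 1-second pause instead of the default 60-120s
)
def flaky_probe() -> str:
    """Hypothetical helper that fails transiently; the decorator re-invokes it."""
    if random.random() < 0.5:
        raise ConnectionError("simulated transient network error")
    return "ok"


if __name__ == "__main__":
    print(flaky_probe())

Because the wrapper looks up attempts with type(e) is exc_type, only the exact exception types listed in exceptions or other_exceptions are retried; any other exception (including subclasses) falls back to a single attempt and is re-raised.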