diff --git a/malsim/mal_simulator.py b/malsim/mal_simulator.py index c331fc67..983ca796 100644 --- a/malsim/mal_simulator.py +++ b/malsim/mal_simulator.py @@ -7,6 +7,8 @@ from types import MappingProxyType from typing import Any, Optional +from .visualization.malsim_gui_client import MalSimGUIClient + from maltoolbox import neo4j_configs from maltoolbox.ingestors import neo4j from maltoolbox.attackgraph import (AttackGraph, AttackGraphNode, @@ -189,6 +191,9 @@ def __init__( if prune_unviable_unnecessary: apriori.prune_unviable_and_unnecessary_nodes(attack_graph) + # Initialize the REST API client + self.rest_api_client = MalSimGUIClient() + # Keep a backup attack graph to use when resetting self.attack_graph_backup = copy.deepcopy(attack_graph) @@ -220,6 +225,9 @@ def reset( # Reset agents self._reset_agents() + # Upload initial state to the REST API + self.rest_api_client.upload_initial_state(self.attack_graph) + return self.agent_states def _create_attacker_state( @@ -641,6 +649,11 @@ def step( logger.info("Removing agent %s", agent_state.name) self._alive_agents.remove(agent_state.name) + self.rest_api_client.upload_performed_nodes( + list(step_all_compromised_nodes | step_enabled_defenses), + self.cur_iter + ) + self.cur_iter += 1 return self.agent_states diff --git a/malsim/visualization/malsim_gui_client.py b/malsim/visualization/malsim_gui_client.py new file mode 100644 index 00000000..94e77fa6 --- /dev/null +++ b/malsim/visualization/malsim_gui_client.py @@ -0,0 +1,155 @@ +"""TYR Monitor REST API client""" + +from __future__ import annotations +import logging +from typing import Optional +import json + +import requests +from maltoolbox.attackgraph import AttackGraph, AttackGraphNode +from maltoolbox.model import Model + +logger = logging.getLogger(__name__) +DEFAULT_TIMEOUT = 30 + +class MalSimGUIClient(): + """A client that can talk to the TYR Monitor REST API""" + + def __init__( + self, + host="localhost", + port=8888, + password="letmein", + ): + self.protocol = "http" + self.host = host + self.port = port + self.password = password + + def _create_url(self, endpoint): + return f"{self.protocol}://{self.host}:{self.port}/{endpoint}" + + def _send_request(self, method, endpoint, json_content=None): + """Send a request to the REST API""" + url = self._create_url(endpoint) + if method == 'GET': + res = requests.get(url, timeout=DEFAULT_TIMEOUT) + elif method == 'POST': + res = requests.post( + url, json=json_content, timeout=DEFAULT_TIMEOUT + ) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + res.raise_for_status() + return res + + @classmethod + def from_config(cls, config, record=False) -> Optional[MalsimGUIClient]: + """Load REST API client from config dict + + Returns the client, or None if no REST API details set in config file + """ + + if 'tyr-monitor-restapi' not in config: + logger.warning( + "No REST API connection details specified in config file") + return None + + tyr_restapi_config = config['tyr-monitor-restapi'] + return cls( + tyr_restapi_config.get('host'), + tyr_restapi_config.get('port'), + tyr_restapi_config.get('password') + ) + + def upload_model(self, model: Model) -> None: + """Uploads a serialized model to the POST endpoint of the API""" + self._send_request( + 'POST', + 'model', + json_content=model._to_dict() + ) + + def upload_attack_graph(self, attack_graph: AttackGraph) -> None: + """Uploads a serialized graph to the POST endpoint of the API""" + self._send_request( + 'POST', + 'attack_graph', + json_content=attack_graph._to_dict() + ) + + def upload_performed_nodes( + self, new_performed_nodes: list[AttackGraphNode], iteration: int + ) -> None: + """Uploads newly performed nodes to API""" + self._send_request( + 'POST', + 'performed_nodes', + json_content=[ + {'node_id': n.id, 'iteration': iteration} + for n in new_performed_nodes + ] + ) + + def upload_latest_attack_steps( + self, latest_steps: dict[int, list[dict]]) -> None: + """Upload dict of latest attack step ids mapped to alert logs""" + self._send_request( + 'POST', + 'latest_attack_steps', + json_content=latest_steps + ) + + def upload_defender_suggestions( + self, suggestions: dict[str, dict[int, dict]]) -> None: + """ + Uploads dict of suggestions that maps agent name + to a dict mapping node id to suggestion info. + """ + self._send_request( + 'POST', + 'defender_suggestions', + json_content=suggestions + ) + + def get_defender_action(self) -> dict: + """Get selected defender action from API""" + return self._send_request( + 'GET', + 'defender_action' + ).json() + + def clear_defender_action(self) -> None: + """Post None to selected action from API""" + self._send_request( + 'POST', + 'defender_action', + json_content={'iteration': -1, 'node_id': None} + ) + + def get_reward_value(self) -> dict: + """Get reward value for the current iteration from API""" + return self._send_request( + 'GET', + 'reward_value' + ).json() + + def set_reward_value(self, iteration: int, reward: float) -> None: + """Post reward value for iteration to API""" + self._send_request( + 'POST', + 'reward_value', + json_content={'iteration': iteration, 'reward': reward} + ) + + def reset(self) -> None: + """Reset the API contents""" + self._send_request('POST', 'reset') + + def upload_initial_state( + self, attack_graph: AttackGraph + ): + """Reset the REST API and upload model and graph""" + self.reset() + self.upload_model(attack_graph.model) + self.upload_attack_graph(attack_graph) diff --git a/pyproject.toml b/pyproject.toml index a9482dce..7b815e81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,11 @@ requires-python = ">=3.10" dependencies = [ "py2neo>=2021.2.3", "mal-toolbox~=0.3.0", - "PyYAML>=6.0.1" + "PyYAML>=6.0.1", + "fastapi", + "pydantic", + "uvicorn", + "requests" ] license = {text = "Apache Software License"} keywords = ["mal"]