Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions malsim/mal_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from types import MappingProxyType
from typing import Any, Optional

from .visualization.malsim_gui_client import MalSimGUIClient

from maltoolbox import neo4j_configs
from maltoolbox.ingestors import neo4j
from maltoolbox.attackgraph import (AttackGraph, AttackGraphNode,
Expand Down Expand Up @@ -189,6 +191,9 @@ def __init__(
if prune_unviable_unnecessary:
apriori.prune_unviable_and_unnecessary_nodes(attack_graph)

# Initialize the REST API client
self.rest_api_client = MalSimGUIClient()

# Keep a backup attack graph to use when resetting
self.attack_graph_backup = copy.deepcopy(attack_graph)

Expand Down Expand Up @@ -220,6 +225,9 @@ def reset(
# Reset agents
self._reset_agents()

# Upload initial state to the REST API
self.rest_api_client.upload_initial_state(self.attack_graph)

return self.agent_states

def _create_attacker_state(
Expand Down Expand Up @@ -641,6 +649,11 @@ def step(
logger.info("Removing agent %s", agent_state.name)
self._alive_agents.remove(agent_state.name)

self.rest_api_client.upload_performed_nodes(
list(step_all_compromised_nodes | step_enabled_defenses),
self.cur_iter
)

self.cur_iter += 1
return self.agent_states

Expand Down
155 changes: 155 additions & 0 deletions malsim/visualization/malsim_gui_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""TYR Monitor REST API client"""

from __future__ import annotations
import logging
from typing import Optional
import json

import requests
from maltoolbox.attackgraph import AttackGraph, AttackGraphNode
from maltoolbox.model import Model

logger = logging.getLogger(__name__)
DEFAULT_TIMEOUT = 30

class MalSimGUIClient():
"""A client that can talk to the TYR Monitor REST API"""

def __init__(
self,
host="localhost",
port=8888,
password="letmein",
):
self.protocol = "http"
self.host = host
self.port = port
self.password = password

def _create_url(self, endpoint):
return f"{self.protocol}://{self.host}:{self.port}/{endpoint}"

def _send_request(self, method, endpoint, json_content=None):
"""Send a request to the REST API"""
url = self._create_url(endpoint)
if method == 'GET':
res = requests.get(url, timeout=DEFAULT_TIMEOUT)
elif method == 'POST':
res = requests.post(
url, json=json_content, timeout=DEFAULT_TIMEOUT
)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
res.raise_for_status()
return res

@classmethod
def from_config(cls, config, record=False) -> Optional[MalsimGUIClient]:
"""Load REST API client from config dict

Returns the client, or None if no REST API details set in config file
"""

if 'tyr-monitor-restapi' not in config:
logger.warning(
"No REST API connection details specified in config file")
return None

tyr_restapi_config = config['tyr-monitor-restapi']
return cls(
tyr_restapi_config.get('host'),
tyr_restapi_config.get('port'),
tyr_restapi_config.get('password')
)

def upload_model(self, model: Model) -> None:
"""Uploads a serialized model to the POST endpoint of the API"""
self._send_request(
'POST',
'model',
json_content=model._to_dict()
)

def upload_attack_graph(self, attack_graph: AttackGraph) -> None:
"""Uploads a serialized graph to the POST endpoint of the API"""
self._send_request(
'POST',
'attack_graph',
json_content=attack_graph._to_dict()
)

def upload_performed_nodes(
self, new_performed_nodes: list[AttackGraphNode], iteration: int
) -> None:
"""Uploads newly performed nodes to API"""
self._send_request(
'POST',
'performed_nodes',
json_content=[
{'node_id': n.id, 'iteration': iteration}
for n in new_performed_nodes
]
)

def upload_latest_attack_steps(
self, latest_steps: dict[int, list[dict]]) -> None:
"""Upload dict of latest attack step ids mapped to alert logs"""
self._send_request(
'POST',
'latest_attack_steps',
json_content=latest_steps
)

def upload_defender_suggestions(
self, suggestions: dict[str, dict[int, dict]]) -> None:
"""
Uploads dict of suggestions that maps agent name
to a dict mapping node id to suggestion info.
"""
self._send_request(
'POST',
'defender_suggestions',
json_content=suggestions
)

def get_defender_action(self) -> dict:
"""Get selected defender action from API"""
return self._send_request(
'GET',
'defender_action'
).json()

def clear_defender_action(self) -> None:
"""Post None to selected action from API"""
self._send_request(
'POST',
'defender_action',
json_content={'iteration': -1, 'node_id': None}
)

def get_reward_value(self) -> dict:
"""Get reward value for the current iteration from API"""
return self._send_request(
'GET',
'reward_value'
).json()

def set_reward_value(self, iteration: int, reward: float) -> None:
"""Post reward value for iteration to API"""
self._send_request(
'POST',
'reward_value',
json_content={'iteration': iteration, 'reward': reward}
)

def reset(self) -> None:
"""Reset the API contents"""
self._send_request('POST', 'reset')

def upload_initial_state(
self, attack_graph: AttackGraph
):
"""Reset the REST API and upload model and graph"""
self.reset()
self.upload_model(attack_graph.model)
self.upload_attack_graph(attack_graph)
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ requires-python = ">=3.10"
dependencies = [
"py2neo>=2021.2.3",
"mal-toolbox~=0.3.0",
"PyYAML>=6.0.1"
"PyYAML>=6.0.1",
"fastapi",
"pydantic",
"uvicorn",
"requests"
]
license = {text = "Apache Software License"}
keywords = ["mal"]
Expand Down
Loading