diff --git a/datashuttle/configs/config_class.py b/datashuttle/configs/config_class.py index 406db5c1d..fc0ec83ce 100644 --- a/datashuttle/configs/config_class.py +++ b/datashuttle/configs/config_class.py @@ -62,7 +62,6 @@ def __init__( self.logging_path: Path self.hostkeys_path: Path - self.ssh_key_path: Path self.project_metadata_path: Path def setup_after_load(self) -> None: @@ -293,8 +292,6 @@ def init_paths(self) -> None: self.project_name ) - self.ssh_key_path = datashuttle_path / f"{self.project_name}_ssh_key" - self.hostkeys_path = datashuttle_path / "hostkeys" self.logging_path = self.make_and_get_logging_path() diff --git a/datashuttle/datashuttle_class.py b/datashuttle/datashuttle_class.py index 8df133bf1..0052653d5 100644 --- a/datashuttle/datashuttle_class.py +++ b/datashuttle/datashuttle_class.py @@ -27,7 +27,6 @@ TopLevelFolder, ) -import paramiko import yaml from datashuttle.configs import ( @@ -800,46 +799,26 @@ def setup_ssh_connection(self) -> None: "setup-ssh-connection-to-central-server", local_vars=locals() ) - verified = ssh.verify_ssh_central_host( + verified = ssh.verify_ssh_central_host_api( self.cfg["central_host_id"], self.cfg.hostkeys_path, log=True, ) if verified: - ssh.setup_ssh_key(self.cfg, log=True) - self._setup_rclone_central_ssh_config(log=True) + private_key_str = ssh.setup_ssh_key_api(self.cfg, log=True) + + self._setup_rclone_central_ssh_config(private_key_str, log=True) rclone.check_successful_connection_and_raise_error_on_fail( self.cfg ) - ds_logger.close_log_filehandler() - - @requires_ssh_configs - @check_is_not_local_project - def write_public_key(self, filepath: str) -> None: - """Save the public SSH key to a specified filepath. - - By default, only the SSH private key is stored in the - datashuttle configs folder. Use this function to save - the public key. - - Parameters - ---------- - filepath - Full filepath (including filename) to write the - public key to. 
- - """ - key: paramiko.RSAKey - key = paramiko.RSAKey.from_private_key_file( - self.cfg.ssh_key_path.as_posix() - ) + utils.log_and_message( + "SSH key pair setup successfully. SSH key saved to the RClone config file." + ) - with open(filepath, "w") as public: - public.write(key.get_base64()) - public.close() + ds_logger.close_log_filehandler() # ------------------------------------------------------------------------- # Google Drive @@ -1524,11 +1503,13 @@ def _make_project_metadata_if_does_not_exist(self) -> None: """ folders.create_folders(self.cfg.project_metadata_path, log=False) - def _setup_rclone_central_ssh_config(self, log: bool) -> None: + def _setup_rclone_central_ssh_config( + self, private_key_str: str, log: bool + ) -> None: rclone.setup_rclone_config_for_ssh( self.cfg, self.cfg.get_rclone_config_name("ssh"), - self.cfg.ssh_key_path, + private_key_str, log=log, ) diff --git a/datashuttle/tui/interface.py b/datashuttle/tui/interface.py index e9be18031..5db19930e 100644 --- a/datashuttle/tui/interface.py +++ b/datashuttle/tui/interface.py @@ -495,10 +495,14 @@ def setup_key_pair_and_rclone_config( ) -> InterfaceOutput: """Set up SSH key pair and associated rclone configuration.""" try: + rsa_key, private_key_str = ssh.generate_ssh_key_strings() + ssh.add_public_key_to_central_authorized_keys( - self.project.cfg, password, log=False + self.project.cfg, rsa_key, password, log=False + ) + self.project._setup_rclone_central_ssh_config( + private_key_str, log=False ) - self.project._setup_rclone_central_ssh_config(log=False) rclone.check_successful_connection_and_raise_error_on_fail( self.project.cfg diff --git a/datashuttle/tui/screens/setup_ssh.py b/datashuttle/tui/screens/setup_ssh.py index cd944e951..3ab5025e8 100644 --- a/datashuttle/tui/screens/setup_ssh.py +++ b/datashuttle/tui/screens/setup_ssh.py @@ -152,10 +152,7 @@ def use_password_to_setup_ssh_key_pairs(self) -> None: ) if success: - message = ( - f"Connection successful! 
SSH key " - f"saved to {self.interface.get_configs().ssh_key_path}" - ) + message = "Connection successful! SSH key saved to the RClone config file." self.query_one("#setup_ssh_ok_button").label = "Finish" self.query_one("#setup_ssh_cancel_button").disabled = True self.stage += 1 diff --git a/datashuttle/utils/rclone.py b/datashuttle/utils/rclone.py index b7c5981db..664ba44bb 100644 --- a/datashuttle/utils/rclone.py +++ b/datashuttle/utils/rclone.py @@ -3,8 +3,6 @@ from typing import TYPE_CHECKING, Dict, List, Literal, Optional if TYPE_CHECKING: - from pathlib import Path - from datashuttle.configs.config_class import Configs from datashuttle.utils.custom_types import TopLevelFolder @@ -176,7 +174,7 @@ def setup_rclone_config_for_local_filesystem( def setup_rclone_config_for_ssh( cfg: Configs, rclone_config_name: str, - ssh_key_path: Path, + private_key_str: str, log: bool = True, ) -> None: """Set the RClone remote config for ssh. @@ -194,25 +192,27 @@ def setup_rclone_config_for_ssh( canonical config name, generated by datashuttle.cfg.get_rclone_config_name() - ssh_key_path - path to the ssh key used for connecting to - ssh central filesystem + private_key_str + PEM encoded ssh private key to pass to RClone. log whether to log, if True logger must already be initialised. 
""" - call_rclone( + key_escaped = private_key_str.replace("\n", "\\n") + + command = ( f"config create " f"{rclone_config_name} " f"sftp " f"host {cfg['central_host_id']} " f"user {cfg['central_host_username']} " f"port {canonical_configs.get_default_ssh_port()} " - f"key_file {ssh_key_path.as_posix()}", - pipe_std=True, + f'-- key_pem "{key_escaped}"' ) + call_rclone(command, pipe_std=True) + if log: log_rclone_config_output() @@ -274,7 +274,7 @@ def setup_rclone_config_for_gdrive( f"{client_secret_key_value}" f"scope drive " f"root_folder_id {cfg['gdrive_root_folder_id']} " - f"{extra_args}", + f"{extra_args}" ) return process diff --git a/datashuttle/utils/ssh.py b/datashuttle/utils/ssh.py index 074ca731c..c3085d5b3 100644 --- a/datashuttle/utils/ssh.py +++ b/datashuttle/utils/ssh.py @@ -3,9 +3,11 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + from pathlib import Path + from datashuttle.configs.config_class import Configs -from pathlib import Path +from io import StringIO from typing import Optional import paramiko @@ -14,53 +16,115 @@ from datashuttle.utils import utils # ----------------------------------------------------------------------------- -# Core Functions +# Setup SSH - API Wrappers # ----------------------------------------------------------------------------- -# These functions are called by both API and TUI. -# Unfortunately it is not possible for TUI to call API directly in the case of -# setting up SSH, because it requires user input to proceed. +# These functions wrap core SSH setup functions (above) for the API. See +# tui/screens/setup_ssh for the TUI equivalents. -def connect_client_core( - client: paramiko.SSHClient, - cfg: Configs, - password: Optional[str] = None, -) -> None: - """Connect to the client. +def verify_ssh_central_host_api( + central_host_id: str, hostkeys_path: Path, log: bool = True +) -> bool: + """Prompt the user to verify and cache the SSH server's host key. 
- A centralised function to connect to a paramiko client. + This function retrieves the SSH server's key and asks the user to + manually validate and accept it. Accepting the key caches it locally + to ensure secure future connections. Parameters ---------- - client - Paramiko client to connect to. + central_host_id + Hostname or IP address of the SSH server. + + hostkeys_path + Path to the local file where known host keys are stored. + + log + Whether to log the verification messages. + + Returns + ------- + bool + True if the host key was accepted and saved, False otherwise. + + """ + key = get_remote_server_key(central_host_id) + + message = ( + f"The host key is not cached for this server: " + f"{central_host_id}.\nYou have no guarantee " + f"that the server is the computer you think it is.\n" + f"The server's {key.get_name()} key fingerprint is: " + f"{key.get_base64()}\nIf you trust this host, to connect" + f" and cache the host key, press y: " + ) + input_ = utils.get_user_input(message) + + if input_ == "y": + save_hostkey_locally(key, central_host_id, hostkeys_path) + success = True + utils.print_message_to_user("Host accepted.") + else: + utils.print_message_to_user("Host not accepted. No connection made.") + success = False + + if log: + if success: + utils.log(f"{message}") + utils.log(f"Hostkeys saved at:{hostkeys_path.as_posix()}") + else: + utils.log("Host not accepted. No connection made.") + + return success + + +def setup_ssh_key_api( + cfg: Configs, + log: bool = True, +) -> str: + """Set up an SSH private / public key pair with central server. + + First, a private key is generated in memory and returned as a string to the caller. + Next a connection requiring input password is made, and the public part of the key + is added to ~/.ssh/authorized_keys. + Parameters + ---------- cfg - Datashuttle Configs. + datashuttle config UserDict - password - Password (if required) to establish the connection. + log + log if True, logger must already be initialised. 
""" - client.get_host_keys().load(cfg.hostkeys_path.as_posix()) - client.set_missing_host_key_policy(paramiko.RejectPolicy()) - - client.connect( - cfg["central_host_id"], - username=cfg["central_host_username"], - password=password, - key_filename=( - cfg.ssh_key_path.as_posix() - if isinstance(cfg.ssh_key_path, Path) - else None + rsa_key, private_key = generate_ssh_key_strings() + + server_password = utils.get_connection_secret_from_user( + connection_method_name="SSH", + key_name_full="password", + key_name_short="password", + key_info=( + "You are required to enter the password to your central host to add the public key. " + "You will not have to enter your password again." ), - look_for_keys=True, - port=canonical_configs.get_default_ssh_port(), + log_status=log, ) + add_public_key_to_central_authorized_keys(cfg, rsa_key, server_password) + + return private_key + + +# ----------------------------------------------------------------------------- +# Core Functions +# ----------------------------------------------------------------------------- +# These functions are called by both API and TUI. +# Unfortunately it is not possible for TUI to call API directly in the case of +# setting up SSH, because it requires user input to proceed. + def add_public_key_to_central_authorized_keys( - cfg: Configs, password: str, log=True + cfg: Configs, rsa_key: paramiko.RSAKey, server_password: str, log=True ) -> None: """Append the public part of key to central server ~/.ssh/authorized_keys. @@ -69,63 +133,88 @@ def add_public_key_to_central_authorized_keys( cfg Datashuttle Configs object. - password + rsa_key + The RSAKey key, the public part to add to `~/.ssh/authorized_keys.` + + server_password Password to the central server. log If `True`, log the client connection process. 
""" - generate_and_write_ssh_key(cfg.ssh_key_path) - - key = paramiko.RSAKey.from_private_key_file(cfg.ssh_key_path.as_posix()) - client: paramiko.SSHClient with paramiko.SSHClient() as client: - if log: - connect_client_with_logging(client, cfg, password=password) - else: - connect_client_core(client, cfg, password=password) + connect_client(client, cfg, password=server_password, log=log) client.exec_command("mkdir -p ~/.ssh/") client.exec_command( # double >> for concatenate - f'echo "{key.get_name()} {key.get_base64()}" ' + f'echo "{rsa_key.get_name()} {rsa_key.get_base64()}" ' f">> ~/.ssh/authorized_keys" ) client.exec_command("chmod 644 ~/.ssh/authorized_keys") client.exec_command("chmod 700 ~/.ssh/") -def generate_and_write_ssh_key(ssh_key_path: Path) -> None: - """Generate an RSA SSH key and save it to the specified file path. +def generate_ssh_key_strings(): + """Generate a private and public SSH key pair.""" + rsa_key = generate_ssh_key() - Parameters - ---------- - ssh_key_path - The full file path where the private SSH key will be saved. + private_key_io = StringIO() + rsa_key.write_private_key(private_key_io) - """ - key = paramiko.RSAKey.generate(4096) - key.write_private_key_file(ssh_key_path.as_posix()) + private_key_io.seek(0) + private_key_io.seek(0) + private_key_str = private_key_io.read() -def get_remote_server_key(central_host_id: str): - """Get the remove server host key for validation before connection. + return rsa_key, private_key_str - Parameters - ---------- - central_host_id - The hostname or IP address of the central host. 
- """ - transport: paramiko.Transport - with paramiko.Transport( - (central_host_id, canonical_configs.get_default_ssh_port()) - ) as transport: - transport.connect() - key = transport.get_remote_server_key() - return key +def generate_ssh_key() -> paramiko.RSAKey: + """Generate and return a 4096-bit RSA SSH key (nothing is written to disk).""" + return paramiko.RSAKey.generate(4096) + + +def connect_client( + client: paramiko.SSHClient, + cfg: Configs, + password: Optional[str] = None, + log=True, +) -> None: + """Connect client to central server using paramiko.""" + try: + client.get_host_keys().load(cfg.hostkeys_path.as_posix()) + client.set_missing_host_key_policy(paramiko.RejectPolicy()) + + client.connect( + cfg["central_host_id"], + username=cfg["central_host_username"], + password=password, + key_filename=None, + look_for_keys=True, + port=canonical_configs.get_default_ssh_port(), + ) + + utils.print_message_to_user( + f"Connection to {cfg['central_host_id']} made successfully." + ) + + except Exception as e: + raise_func = utils.log_and_raise_error if log else utils.raise_error + + raise_func( + f"Could not connect to server. Ensure that \n" + f"1) You have run setup_ssh_connection() \n" + f"2) You are on VPN network if required. \n" + f"3) The central_host_id: {cfg['central_host_id']} is" + f" correct.\n" + f"4) The central username:" + f" {cfg['central_host_username']}, and password are correct." + f"Original error: {e}", + ConnectionError, + ) def save_hostkey_locally(key, central_host_id, hostkeys_path) -> None: @@ -160,147 +249,20 @@ def save_hostkey_locally(key, central_host_id, hostkeys_path) -> None: client.get_host_keys().save(hostkeys_path.as_posix()) -# ----------------------------------------------------------------------------- -# Setup SSH - API Wrappers -# ----------------------------------------------------------------------------- -# These functions wrap core SSH setup functions (above) for the API. 
See -# tui/screens/setup_ssh for the TUI equivalents. - - -def setup_ssh_key( - cfg: Configs, - log: bool = True, -) -> None: - """Set up an SSH private / public key pair with central server. - - First, a private key is generated and saved in the .datashuttle config path. - Next a connection requiring input password made, and the public part of the key - added to ~/.ssh/authorized_keys. - - Parameters - ---------- - ssh_key_path - path to the ssh private key - - hostkeys_path - path to the ssh host key, once the user - has confirmed the key ID this is saved so verification - is not required each time. - - cfg - datashuttle config UserDict - - log - log if True, logger must already be initialised. - - """ - password = utils.get_connection_secret_from_user( - connection_method_name="SSH", - key_name_full="password", - key_name_short="password", - key_info=( - "You are required to enter the password to your central host to add the public key. " - "You will not have to enter your password again." - ), - log_status=log, - ) - - add_public_key_to_central_authorized_keys(cfg, password) - - success_message = ( - f"SSH key pair setup successfully. " - f"Private key at: {cfg.ssh_key_path.as_posix()}" - ) - - utils.print_message_to_user(success_message) - - if log: - utils.log(f"\n{success_message}") - - -def connect_client_with_logging( - client: paramiko.SSHClient, - cfg: Configs, - password: Optional[str] = None, - message_on_sucessful_connection: bool = True, -) -> None: - """Connect client to central server using paramiko. - - Accept either password or path to private key, but not both. - Paramiko does not support pathlib. - """ - try: - connect_client_core(client, cfg, password) - if message_on_sucessful_connection: - utils.print_message_to_user( - f"Connection to {cfg['central_host_id']} made successfully." - ) - - except Exception as e: - utils.log_and_raise_error( - f"Could not connect to server. 
Ensure that \n" - f"1) You have run setup_ssh_connection() \n" - f"2) You are on VPN network if required. \n" - f"3) The central_host_id: {cfg['central_host_id']} is" - f" correct.\n" - f"4) The central username:" - f" {cfg['central_host_username']}, and password are correct." - f"Original error: {e}", - ConnectionError, - ) - - -def verify_ssh_central_host( - central_host_id: str, hostkeys_path: Path, log: bool = True -) -> bool: - """Prompt the user to verify and cache the SSH server's host key. - - This function retrieves the SSH server's key and asks the user to - manually validate and accept it. Accepting the key caches it locally - to ensure secure future connections. +def get_remote_server_key(central_host_id: str): + """Get the remote server host key for validation before connection. Parameters ---------- central_host_id - Hostname or IP address of the SSH server. - - hostkeys_path - Path to the local file where known host keys are stored. - - log - Whether to log the verification messages. - - Returns - ------- - bool - True if the host key was accepted and saved, False otherwise. + The hostname or IP address of the central host. """ - key = get_remote_server_key(central_host_id) - - message = ( - f"The host key is not cached for this server: " - f"{central_host_id}.\nYou have no guarantee " - f"that the server is the computer you think it is.\n" - f"The server's {key.get_name()} key fingerprint is: " - f"{key.get_base64()}\nIf you trust this host, to connect" - f" and cache the host key, press y: " - ) - input_ = utils.get_user_input(message) - - if input_ == "y": - save_hostkey_locally(key, central_host_id, hostkeys_path) - success = True - utils.print_message_to_user("Host accepted.") - else: - utils.print_message_to_user("Host not accepted. No connection made.") - success = False - - if log: - if success: - utils.log(f"{message}") - utils.log(f"Hostkeys saved at:{hostkeys_path.as_posix()}") - else: - utils.log("Host not accepted. 
No connection made.") + transport: paramiko.Transport + with paramiko.Transport( + (central_host_id, canonical_configs.get_default_ssh_port()) + ) as transport: + transport.connect() + key = transport.get_remote_server_key() - return success + return key diff --git a/tests/test_utils.py b/tests/test_utils.py index 94d0da5ef..710998fba 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -40,11 +40,7 @@ def setup_project_default_configs( project.make_config_file(**default_configs) - rclone.setup_rclone_config_for_ssh( - project.cfg, - project.cfg.get_rclone_config_name("ssh"), - project.cfg.ssh_key_path, - ) + project._setup_rclone_central_local_filesystem_config() if local_path: os.makedirs(local_path, exist_ok=True) diff --git a/tests/tests_transfers/ssh/ssh_test_utils.py b/tests/tests_transfers/ssh/ssh_test_utils.py index c8e5057fa..b57f18fcc 100644 --- a/tests/tests_transfers/ssh/ssh_test_utils.py +++ b/tests/tests_transfers/ssh/ssh_test_utils.py @@ -47,24 +47,24 @@ def setup_ssh_connection(project, setup_ssh_key_pair=True): sys.stdin.isatty = lambda: True # Run setup - verified = ssh.verify_ssh_central_host( + verified = ssh.verify_ssh_central_host_api( project.cfg["central_host_id"], project.cfg.hostkeys_path, log=True ) if setup_ssh_key_pair: - ssh.setup_ssh_key(project.cfg, log=False) + private_key_str = ssh.setup_ssh_key_api(project.cfg, log=False) + + rclone.setup_rclone_config_for_ssh( + project.cfg, + project.cfg.get_rclone_config_name("ssh"), + private_key_str, + ) # Restore functions builtins.input = orig_builtin utils.get_connection_secret_from_user = orig_get_secret sys.stdin.isatty = orig_isatty - rclone.setup_rclone_config_for_ssh( - project.cfg, - project.cfg.get_rclone_config_name("ssh"), - project.cfg.ssh_key_path, - ) - return verified diff --git a/tests/tests_transfers/ssh/test_ssh_setup.py b/tests/tests_transfers/ssh/test_ssh_setup.py index 69e70e527..58c1459be 100644 --- a/tests/tests_transfers/ssh/test_ssh_setup.py +++ 
b/tests/tests_transfers/ssh/test_ssh_setup.py @@ -4,8 +4,6 @@ import pytest -from datashuttle.utils import ssh - from ... import test_utils from . import ssh_test_utils from .base_ssh import BaseSSHTransfer @@ -81,15 +79,3 @@ def test_verify_ssh_central_host_accept(self, capsys, project): assert ( f"[{project.cfg['central_host_id']}]:3306 ssh-ed25519 " in hostkey ) - - def test_generate_and_write_ssh_key(self, project): - """Check ssh key for passwordless connection is written - to file. - """ - path_to_save = project.cfg["local_path"] / "test" - ssh.generate_and_write_ssh_key(path_to_save) - - with open(path_to_save) as file: - first_line = file.readlines()[0] - - assert first_line == "-----BEGIN RSA PRIVATE KEY-----\n"