diff --git a/src/Ssh/Ssh/Common/SshBaseCmdlet.cs b/src/Ssh/Ssh/Common/SshBaseCmdlet.cs index 84c6901cbff8..ea909e0cda92 100644 --- a/src/Ssh/Ssh/Common/SshBaseCmdlet.cs +++ b/src/Ssh/Ssh/Common/SshBaseCmdlet.cs @@ -518,54 +518,6 @@ protected internal string GetClientApplicationPath(string command) return appInfo.Path; } - protected internal string GetClientSideProxy() - { - string proxyPath = null; - string oldProxyPattern = null; - string requestUrl = null; - - GetProxyUrlAndFilename(ref proxyPath, ref oldProxyPattern, ref requestUrl); - - if (!File.Exists(proxyPath)) - { - string proxyDir = Path.GetDirectoryName(proxyPath); - - if (!Directory.Exists(proxyDir)) - { - Directory.CreateDirectory(proxyDir); - } - else - { - var files = Directory.GetFiles(proxyDir, oldProxyPattern); - foreach (string file in files) - { - try - { - File.Delete(file); - } - catch (Exception exception) - { - WriteWarning(String.Format(Resources.FailedToDeleteOldProxy, file, exception.Message)); - } - } - } - - try - { - WebClient wc = new WebClient(); - wc.DownloadFile(new Uri(requestUrl), proxyPath); - } - catch (Exception exception) - { - string errorMessage = String.Format(Resources.FailedToDownloadProxy, requestUrl, exception.Message); - throw new AzPSApplicationException(errorMessage); - } - - ValidateSshProxy(proxyPath); - } - return proxyPath; - } - protected internal void DeleteFile(string fileName, string warningMessage = null) { if (File.Exists(fileName)) @@ -622,6 +574,35 @@ protected internal bool IsArc() return false; } + /// + /// Get the path of the required Ssh Proxy from the Az.Ssh.ArcProxy module + /// that should be installed, per pre-reqs. + /// + /// Path to Proxy Executable + /// + protected internal string GetInstalledProxyModulePath() + { + var results = InvokeCommand.InvokeScript( + script: "(Get-module -ListAvailable -Name Az.Ssh.ArcProxy).Path"); + + foreach (var result in results) + { + if (result?.BaseObject is string tempPath) + { + string proxyPath = GetProxyPath(tempPath); + + if (!File.Exists(proxyPath) || !ValidateSshProxy(proxyPath)) + { + continue; + } + + return proxyPath; + } + } + + throw new AzPSApplicationException("Unable to find a valid proxy"); + } + #endregion #region Private Methods @@ -807,10 +788,7 @@ private string CreateTempFolder() return dirname; } - private void GetProxyUrlAndFilename( - ref string proxyPath, - ref string oldProxyPattern, - ref string requestUrl) + private string GetProxyPath(string modulePath) { string os; string architecture; @@ -842,23 +820,19 @@ private void GetProxyUrlAndFilename( architecture = "386"; } - string proxyName = "sshProxy_" + os + "_" + architecture; - requestUrl = clientProxyStorageUrl + "/" + clientProxyRelease + "/" + proxyName + "_" + clientProxyVersion; - - string installPath = proxyName + "_" + clientProxyVersion.Replace('.', '_'); - oldProxyPattern = proxyName + "*"; + string proxyName = $"sshProxy_{os}_{architecture}_{clientProxyVersion.Replace('.', '_')}"; if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { - requestUrl = requestUrl + ".exe"; - installPath = installPath + ".exe"; - oldProxyPattern = oldProxyPattern + ".exe"; + proxyName += ".exe"; } - proxyPath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.UserProfile), Path.Combine(".clientsshproxy", installPath)); + string parentDirectory = Directory.GetParent(modulePath).FullName; + + return Path.Combine(parentDirectory, proxyName); } - private void ValidateSshProxy(string path) + private bool ValidateSshProxy(string path) { string hashString; @@ -892,13 +866,7 @@ private void ValidateSshProxy(string path) isValid = hashString.Equals(sshproxy_darwin_amd64_sha256_hash); break; } - - if (!isValid) - { - WriteWarning($"Validation of SSH Proxy {path} failed. Removing file from system."); - DeleteFile(path); - throw new AzPSApplicationException("Failed to download valid SSH Proxy. Unable to continue cmdlet execution."); - } + return isValid; } #endregion diff --git a/src/Ssh/Ssh/SshCommands/EnterAzVMCommand.cs b/src/Ssh/Ssh/SshCommands/EnterAzVMCommand.cs index 4e176167f1a9..1ebbc5997a4b 100644 --- a/src/Ssh/Ssh/SshCommands/EnterAzVMCommand.cs +++ b/src/Ssh/Ssh/SshCommands/EnterAzVMCommand.cs @@ -74,7 +74,7 @@ public override void ExecuteCmdlet() } if (IsArc()) { - proxyPath = GetClientSideProxy(); + proxyPath = GetInstalledProxyModulePath(); UpdateProgressBar(record, $"Dowloaded SSH Proxy, saved to {proxyPath}", 25); GetRelayInformation(); UpdateProgressBar(record, $"Retrieved Relay Information", 50); diff --git a/src/Ssh/Ssh/SshCommands/ExportAzSshConfig.cs b/src/Ssh/Ssh/SshCommands/ExportAzSshConfig.cs index 8dc2cee3dab9..018ab0a46b12 100644 --- a/src/Ssh/Ssh/SshCommands/ExportAzSshConfig.cs +++ b/src/Ssh/Ssh/SshCommands/ExportAzSshConfig.cs @@ -75,7 +75,7 @@ public override void ExecuteCmdlet() } if (IsArc()) { - proxyPath = GetClientSideProxy(); + proxyPath = GetInstalledProxyModulePath(); UpdateProgressBar(record, $"Downloaded proxy to {proxyPath}", 25); GetRelayInformation(); UpdateProgressBar(record, "Retrieved relay information", 50);