diff --git a/src/Ssh/Ssh/Common/SshBaseCmdlet.cs b/src/Ssh/Ssh/Common/SshBaseCmdlet.cs
index 84c6901cbff8..ea909e0cda92 100644
--- a/src/Ssh/Ssh/Common/SshBaseCmdlet.cs
+++ b/src/Ssh/Ssh/Common/SshBaseCmdlet.cs
@@ -518,54 +518,6 @@ protected internal string GetClientApplicationPath(string command)
return appInfo.Path;
}
- protected internal string GetClientSideProxy()
- {
- string proxyPath = null;
- string oldProxyPattern = null;
- string requestUrl = null;
-
- GetProxyUrlAndFilename(ref proxyPath, ref oldProxyPattern, ref requestUrl);
-
- if (!File.Exists(proxyPath))
- {
- string proxyDir = Path.GetDirectoryName(proxyPath);
-
- if (!Directory.Exists(proxyDir))
- {
- Directory.CreateDirectory(proxyDir);
- }
- else
- {
- var files = Directory.GetFiles(proxyDir, oldProxyPattern);
- foreach (string file in files)
- {
- try
- {
- File.Delete(file);
- }
- catch (Exception exception)
- {
- WriteWarning(String.Format(Resources.FailedToDeleteOldProxy, file, exception.Message));
- }
- }
- }
-
- try
- {
- WebClient wc = new WebClient();
- wc.DownloadFile(new Uri(requestUrl), proxyPath);
- }
- catch (Exception exception)
- {
- string errorMessage = String.Format(Resources.FailedToDownloadProxy, requestUrl, exception.Message);
- throw new AzPSApplicationException(errorMessage);
- }
-
- ValidateSshProxy(proxyPath);
- }
- return proxyPath;
- }
-
protected internal void DeleteFile(string fileName, string warningMessage = null)
{
if (File.Exists(fileName))
@@ -622,6 +574,35 @@ protected internal bool IsArc()
return false;
}
+ ///
+ /// Get the path of the required Ssh Proxy from the Az.Ssh.ArcProxy module
+ /// that should be installed, per pre-reqs.
+ ///
+ /// Path to Proxy Executable
+ ///
+ protected internal string GetInstalledProxyModulePath()
+ {
+ var results = InvokeCommand.InvokeScript(
+ script: "(Get-module -ListAvailable -Name Az.Ssh.ArcProxy).Path");
+
+ foreach (var result in results)
+ {
+ if (result?.BaseObject is string tempPath)
+ {
+ string proxyPath = GetProxyPath(tempPath);
+
+ if (!File.Exists(proxyPath) || !ValidateSshProxy(proxyPath))
+ {
+ continue;
+ }
+
+ return proxyPath;
+ }
+ }
+
+ throw new AzPSApplicationException("Unable to find a valid proxy");
+ }
+
#endregion
#region Private Methods
@@ -807,10 +788,7 @@ private string CreateTempFolder()
return dirname;
}
- private void GetProxyUrlAndFilename(
- ref string proxyPath,
- ref string oldProxyPattern,
- ref string requestUrl)
+ private string GetProxyPath(string modulePath)
{
string os;
string architecture;
@@ -842,23 +820,19 @@ private void GetProxyUrlAndFilename(
architecture = "386";
}
- string proxyName = "sshProxy_" + os + "_" + architecture;
- requestUrl = clientProxyStorageUrl + "/" + clientProxyRelease + "/" + proxyName + "_" + clientProxyVersion;
-
- string installPath = proxyName + "_" + clientProxyVersion.Replace('.', '_');
- oldProxyPattern = proxyName + "*";
+ string proxyName = $"sshProxy_{os}_{architecture}_{clientProxyVersion.Replace('.', '_')}";
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
- requestUrl = requestUrl + ".exe";
- installPath = installPath + ".exe";
- oldProxyPattern = oldProxyPattern + ".exe";
+ proxyName += ".exe";
}
- proxyPath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.UserProfile), Path.Combine(".clientsshproxy", installPath));
+ string parentDirectory = Directory.GetParent(modulePath).FullName;
+
+ return Path.Combine(parentDirectory, proxyName);
}
- private void ValidateSshProxy(string path)
+ private bool ValidateSshProxy(string path)
{
string hashString;
@@ -892,13 +866,7 @@ private void ValidateSshProxy(string path)
isValid = hashString.Equals(sshproxy_darwin_amd64_sha256_hash);
break;
}
-
- if (!isValid)
- {
- WriteWarning($"Validation of SSH Proxy {path} failed. Removing file from system.");
- DeleteFile(path);
- throw new AzPSApplicationException("Failed to download valid SSH Proxy. Unable to continue cmdlet execution.");
- }
+ return isValid;
}
#endregion
diff --git a/src/Ssh/Ssh/SshCommands/EnterAzVMCommand.cs b/src/Ssh/Ssh/SshCommands/EnterAzVMCommand.cs
index 4e176167f1a9..1ebbc5997a4b 100644
--- a/src/Ssh/Ssh/SshCommands/EnterAzVMCommand.cs
+++ b/src/Ssh/Ssh/SshCommands/EnterAzVMCommand.cs
@@ -74,7 +74,7 @@ public override void ExecuteCmdlet()
}
if (IsArc())
{
- proxyPath = GetClientSideProxy();
+ proxyPath = GetInstalledProxyModulePath();
UpdateProgressBar(record, $"Dowloaded SSH Proxy, saved to {proxyPath}", 25);
GetRelayInformation();
UpdateProgressBar(record, $"Retrieved Relay Information", 50);
diff --git a/src/Ssh/Ssh/SshCommands/ExportAzSshConfig.cs b/src/Ssh/Ssh/SshCommands/ExportAzSshConfig.cs
index 8dc2cee3dab9..018ab0a46b12 100644
--- a/src/Ssh/Ssh/SshCommands/ExportAzSshConfig.cs
+++ b/src/Ssh/Ssh/SshCommands/ExportAzSshConfig.cs
@@ -75,7 +75,7 @@ public override void ExecuteCmdlet()
}
if (IsArc())
{
- proxyPath = GetClientSideProxy();
+ proxyPath = GetInstalledProxyModulePath();
UpdateProgressBar(record, $"Downloaded proxy to {proxyPath}", 25);
GetRelayInformation();
UpdateProgressBar(record, "Retrieved relay information", 50);