Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/changelog-fragments/756.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Added a ``pylibsshext.session.connect()`` parameter ``open_session_retries`` -- by :user:`justin-stephenson`.

Added a ``pylibsshext.session.connect()`` parameter ``timeout_usec`` to set SSH_OPTIONS_TIMEOUT_USEC -- by :user:`justin-stephenson`.
36 changes: 26 additions & 10 deletions src/pylibsshext/channel.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ from libc.string cimport memset

from pylibsshext.errors cimport LibsshChannelException
from pylibsshext.errors import LibsshChannelReadFailure
from pylibsshext.session cimport get_libssh_session
from pylibsshext.session cimport get_libssh_session, get_session_retries

from subprocess import CompletedProcess

Expand Down Expand Up @@ -63,12 +63,20 @@ cdef class Channel:

if self._libssh_channel is NULL:
raise MemoryError
rc = libssh.ssh_channel_open_session(self._libssh_channel)

if rc != libssh.SSH_OK:
libssh.ssh_channel_free(self._libssh_channel)
self._libssh_channel = NULL
raise LibsshChannelException("Failed to open_session: [%d]" % rc)
retry = get_session_retries(session)

for attempt in range(retry + 1):
rc = libssh.ssh_channel_open_session(self._libssh_channel)
if rc == libssh.SSH_OK:
break
if rc == libssh.SSH_AGAIN and attempt < retry:
continue
# either SSH_ERROR, or SSH_AGAIN with final attempt
if rc != libssh.SSH_OK:
libssh.ssh_channel_free(self._libssh_channel)
self._libssh_channel = NULL
raise LibsshChannelException("Failed to open_session: [%d]" % rc)

def __dealloc__(self):
if self._libssh_channel is not NULL:
Expand Down Expand Up @@ -164,10 +172,18 @@ cdef class Channel:
if channel is NULL:
raise MemoryError

rc = libssh.ssh_channel_open_session(channel)
if rc != libssh.SSH_OK:
libssh.ssh_channel_free(channel)
raise LibsshChannelException("Failed to open_session: [{0}]".format(rc))
retry = get_session_retries(self._session)

for attempt in range(retry + 1):
rc = libssh.ssh_channel_open_session(channel)
if rc == libssh.SSH_OK:
break
if rc == libssh.SSH_AGAIN and attempt < retry:
continue
# either SSH_ERROR, or SSH_AGAIN with final attempt
if rc != libssh.SSH_OK:
libssh.ssh_channel_free(channel)
raise LibsshChannelException("Failed to open_session: [{0}]".format(rc))

result = CompletedProcess(args=command, returncode=-1, stdout=b'', stderr=b'')

Expand Down
2 changes: 2 additions & 0 deletions src/pylibsshext/session.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ cdef class Session:
cdef _hash_py
cdef _fingerprint_py
cdef _keytype_py
cdef _retries
cdef _channel_callbacks

cdef libssh.ssh_session get_libssh_session(Session session)
cdef int get_session_retries(Session session)
18 changes: 17 additions & 1 deletion src/pylibsshext/session.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ OPTS_MAP = {
"user": libssh.SSH_OPTIONS_USER,
"port": libssh.SSH_OPTIONS_PORT,
"timeout": libssh.SSH_OPTIONS_TIMEOUT,
"timeout_usec": libssh.SSH_OPTIONS_TIMEOUT_USEC,
"knownhosts": libssh.SSH_OPTIONS_KNOWNHOSTS,
"proxycommand": libssh.SSH_OPTIONS_PROXYCOMMAND,
"key_exchange_algorithms": libssh.SSH_OPTIONS_KEY_EXCHANGE,
Expand Down Expand Up @@ -108,6 +109,7 @@ cdef class Session(object):
self._hash_py = None
self._fingerprint_py = None
self._keytype_py = None
self._retries = 0
# Due to delayed freeing of channels, some older libssh versions might expect
# the callbacks to be around even after we free the underlying channels so
# we should free them only when we terminate the session.
Expand Down Expand Up @@ -175,7 +177,7 @@ cdef class Session(object):
elif key == "port":
value_uint = value
libssh.ssh_options_set(self._libssh_session, key_m, &value_uint)
elif key == "timeout":
elif key == "timeout" or key == "timeout_usec":
value_long = value
libssh.ssh_options_set(self._libssh_session, key_m, &value_long)
else:
Expand Down Expand Up @@ -235,9 +237,17 @@ cdef class Session(object):
file should be validated. It defaults to True
:type host_key_checking: boolean

:param open_session_retries: The number of retries to attempt when libssh
channel function ssh_channel_open_session() returns SSH_AGAIN. It defaults
to 0, no retries attempted.
:type open_session_retries: integer

:param timeout: The timeout in seconds for the TCP connect
:type timeout: long integer

:param timeout_usec: The timeout in microseconds for the TCP connect
:type timeout_usec: long integer

:param port: The ssh server port to connect to
:type port: integer

Expand All @@ -261,6 +271,9 @@ cdef class Session(object):
libssh.ssh_disconnect(self._libssh_session)
raise

if kwargs.get('open_session_retries'):
self._retries = kwargs.get('open_session_retries')

# We need to userauth_none before we can query the available auth types
rc = libssh.ssh_userauth_none(self._libssh_session, NULL)
if rc == libssh.SSH_AUTH_SUCCESS:
Expand Down Expand Up @@ -553,3 +566,6 @@ cdef class Session(object):

cdef libssh.ssh_session get_libssh_session(Session session):
return session._libssh_session

cdef int get_session_retries(Session session):
return session._retries
2 changes: 2 additions & 0 deletions tests/_service_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def wait_for_svc_ready_state(

def ensure_ssh_session_connected( # noqa: WPS317
ssh_session, sshd_addr, ssh_clientkey_path, # noqa: WPS318
ssh_session_retries=0,
):
"""Attempt connecting to the SSH server until successful.

Expand All @@ -89,4 +90,5 @@ def ensure_ssh_session_connected( # noqa: WPS317
private_key=ssh_clientkey_path.read_bytes(),
host_key_checking=False,
look_for_keys=False,
open_session_retries=ssh_session_retries,
)
19 changes: 19 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,25 @@ def ssh_client_session(ssh_session_connect):
del ssh_session # noqa: WPS420


@pytest.fixture
def ssh_session_connect_retries(sshd_addr, ssh_clientkey_path):
"""
Authenticate existing session object against SSHD with a private SSH key.

This sets ssh_session_retries parameter to 10 and it returns a function
that takes session as parameter.

:returns: Function that will connect the session.
:rtype: Callback
"""
return partial(
ensure_ssh_session_connected,
sshd_addr=sshd_addr,
ssh_clientkey_path=ssh_clientkey_path,
ssh_session_retries=10,
)


@pytest.fixture
def ssh_session_connect(sshd_addr, ssh_clientkey_path):
"""
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/channel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import pytest

from pylibsshext.errors import LibsshChannelException
from pylibsshext.session import Session


Expand All @@ -33,6 +34,35 @@ def ssh_channel(ssh_client_session):
chan.close()


def test_open_session_timeout(ssh_session_connect):
"""Test opening a new channel with a low timeout value.

This generates an exception from ssh_channel_open_session()
returning SSH_AGAIN with the usec timeout and default
open_session_retries value of 0.
"""
ssh_session = Session()
ssh_session_connect(ssh_session)
ssh_session.set_ssh_options('timeout_usec', 10000)
error_msg = '^Failed to open_session'
with pytest.raises(LibsshChannelException, match=error_msg):
ssh_channel = ssh_session.new_channel()
ssh_channel.close()


def test_open_session_with_retries(ssh_session_connect_retries):
"""Test with a low timeout value and retries set.

This sets 'open_session_retries=10' and with the retries
ssh_channel_open_session() will succeed.
"""
ssh_session = Session()
ssh_session_connect_retries(ssh_session)
ssh_session.set_ssh_options('timeout_usec', 10000)
ssh_channel = ssh_session.new_channel()
ssh_channel.close()


def exec_second_command(ssh_channel):
"""Check the standard output of ``exec_command()`` as a string."""
u_cmd = ssh_channel.exec_command('echo -n Hello Again')
Expand Down
Loading