Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions kafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,13 @@ def send_pending_requests_v2(self):
total_bytes = self._send_bytes(self._send_buffer)
self._send_buffer = self._send_buffer[total_bytes:]

# If all data was sent, we need to get the new data from the protocol now, otherwise
# this function would return True, indicating that there are no more pending
# requests. This could cause the calling thread to wait indefinitely as it won't
# know that there is still buffered data to send.
if not self._send_buffer:
self._send_buffer = self._protocol.send_bytes()

if self._sensors:
self._sensors.bytes_sent.record(total_bytes)
# Return True iff send buffer is empty
Expand Down
52 changes: 52 additions & 0 deletions test/test_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

from kafka.conn import BrokerConnection, ConnectionStates
from kafka.future import Future
from kafka.conn import BrokerConnection, ConnectionStates, SSLWantWriteError
from kafka.metrics.metrics import Metrics
from kafka.metrics.stats.sensor import Sensor
from kafka.protocol.api import RequestHeader
from kafka.protocol.group import HeartbeatResponse
from kafka.protocol.metadata import MetadataRequest
Expand Down Expand Up @@ -43,6 +46,15 @@ def _socket(mocker):
mocker.patch('socket.socket', return_value=socket)
return socket

def metrics(mocker):
metrics = mocker.MagicMock(Metrics)
metrics.mocked_sensors = {}
def sensor(name, **kwargs):
if name not in metrics.mocked_sensors:
metrics.mocked_sensors[name] = mocker.MagicMock(Sensor)
return metrics.mocked_sensors[name]
metrics.sensor.side_effect = sensor
return metrics

@pytest.fixture
def conn(_socket, dns_lookup, mocker):
Expand Down Expand Up @@ -228,6 +240,46 @@ def test_send_response(_socket, conn):
assert len(conn.in_flight_requests) == 1


def test_send_async_request_while_other_request_is_already_in_buffer(_socket, conn, metrics):
conn.connect()
assert conn.state is ConnectionStates.CONNECTED
assert 'node-0.bytes-sent' in metrics.mocked_sensors
bytes_sent_sensor = metrics.mocked_sensors['node-0.bytes-sent']

req1 = MetadataRequest[0](topics='foo')
header1 = RequestHeader(req1, client_id=conn.config['client_id'])
payload_bytes1 = len(header1.encode()) + len(req1.encode())
req2 = MetadataRequest[0]([])
header2 = RequestHeader(req2, client_id=conn.config['client_id'])
payload_bytes2 = len(header2.encode()) + len(req2.encode())

# The first call to the socket will raise a transient SSL exception. This will make the first
# request to be kept in the internal buffer to be sent in the next call of
# send_pending_requests_v2.
_socket.send.side_effect = [SSLWantWriteError, 4 + payload_bytes1, 4 + payload_bytes2]

conn.send(req1, blocking=False)
# This won't send any bytes because of the SSL exception and the request bytes will be kept in
# the buffer.
assert conn.send_pending_requests_v2() is False
assert bytes_sent_sensor.record.call_args_list[0].args == (0,)

conn.send(req2, blocking=False)
# This will send the remaining bytes in the buffer from the first request, but should notice
# that the second request was queued, therefore it should return False.
bytes_sent_sensor.record.reset_mock()
assert conn.send_pending_requests_v2() is False
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is where the test would fail before the fix.

bytes_sent_sensor.record.assert_called_once_with(4 + payload_bytes1)

bytes_sent_sensor.record.reset_mock()
assert conn.send_pending_requests_v2() is True
bytes_sent_sensor.record.assert_called_once_with(4 + payload_bytes2)

bytes_sent_sensor.record.reset_mock()
assert conn.send_pending_requests_v2() is True
bytes_sent_sensor.record.assert_called_once_with(0)


def test_send_error(_socket, conn):
conn.connect()
assert conn.state is ConnectionStates.CONNECTED
Expand Down