@@ -126,6 +126,8 @@ data:
 - "${FORWARDER_PORT}"
 - --num-workers
 - "${FORWARDER_WORKER_COUNT}"
+- --max-concurrency
+- "${CONCURRENT_REQUESTS_PER_WORKER}"
 - --set
 - "forwarder.sync.predict_route=${PREDICT_ROUTE}"
 - --set
@@ -172,6 +174,8 @@ data:
 - "${FORWARDER_PORT}"
 - --num-workers
 - "${FORWARDER_WORKER_COUNT}"
+- --max-concurrency
+- "${CONCURRENT_REQUESTS_PER_WORKER}"
 - --set
 - "forwarder.sync.predict_route=${PREDICT_ROUTE}"
 - --set
@@ -558,6 +562,8 @@ data:
 - "${FORWARDER_PORT}"
 - --num-workers
 - "${FORWARDER_WORKER_COUNT}"
+- --max-concurrency
+- "${CONCURRENT_REQUESTS_PER_WORKER}"
 - --set
 - "forwarder.sync.predict_route=${PREDICT_ROUTE}"
 - --set
@@ -604,6 +610,8 @@ data:
 - "${FORWARDER_PORT}"
 - --num-workers
 - "${FORWARDER_WORKER_COUNT}"
+- --max-concurrency
+- "${CONCURRENT_REQUESTS_PER_WORKER}"
 - --set
 - "forwarder.sync.predict_route=${PREDICT_ROUTE}"
 - --set
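These four template blocks are the same forwarder command rendered for different endpoint flavors; each now forwards the endpoint's `CONCURRENT_REQUESTS_PER_WORKER` setting as `--max-concurrency`. A minimal sketch of how the templated arguments render, with invented values for the template variables (substitution is done by the deployment templating, not by this code):

```python
import string

# Hypothetical rendering of the templated forwarder args; values are invented.
template_args = [
    "--port", "${FORWARDER_PORT}",
    "--num-workers", "${FORWARDER_WORKER_COUNT}",
    "--max-concurrency", "${CONCURRENT_REQUESTS_PER_WORKER}",
]
values = {
    "FORWARDER_PORT": "5000",
    "FORWARDER_WORKER_COUNT": "2",
    "CONCURRENT_REQUESTS_PER_WORKER": "10",
}
rendered = [string.Template(arg).safe_substitute(values) for arg in template_args]
# ['--port', '5000', '--num-workers', '2', '--max-concurrency', '10']
```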
@@ -129,6 +129,7 @@ def validate_deployment_resources(
 def validate_concurrent_requests_per_worker(
     concurrent_requests_per_worker: Optional[int],
     endpoint_type: ModelEndpointType,
+    per_worker: Optional[int] = None,
 ):
     if (
         endpoint_type == ModelEndpointType.ASYNC
@@ -138,6 +139,18 @@ def validate_concurrent_requests_per_worker(
         raise EndpointResourceInvalidRequestException(
             f"Requested concurrent requests per worker {concurrent_requests_per_worker} too high"
         )
+
+    # Validation for sync/streaming endpoints to prevent autoscaling issues
+    if (
+        endpoint_type in {ModelEndpointType.SYNC, ModelEndpointType.STREAMING}
+        and concurrent_requests_per_worker is not None
+        and per_worker is not None
+        and per_worker >= concurrent_requests_per_worker / 2
+    ):
+        raise EndpointResourceInvalidRequestException(
+            f"For sync/streaming endpoints, per_worker ({per_worker}) must be less than "
+            f"concurrent_requests_per_worker/2 ({concurrent_requests_per_worker/2}) to prevent autoscaling issues"
+        )


 @dataclass
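In effect, the new guard enforces `per_worker < concurrent_requests_per_worker / 2` for sync and streaming endpoints, keeping the autoscaling target well below the forwarder's hard concurrency cap. A sketch of the boundary, assuming the repo's `ModelEndpointType` and exception types are importable:

```python
# Hypothetical usage; ModelEndpointType and EndpointResourceInvalidRequestException
# are the repo's own types. Callers in the diff pass these positionally.
validate_concurrent_requests_per_worker(
    concurrent_requests_per_worker=10,
    endpoint_type=ModelEndpointType.SYNC,
    per_worker=4,
)  # passes: 4 < 10 / 2

validate_concurrent_requests_per_worker(
    concurrent_requests_per_worker=10,
    endpoint_type=ModelEndpointType.SYNC,
    per_worker=5,
)  # raises EndpointResourceInvalidRequestException: 5 >= 10 / 2
```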
@@ -299,7 +312,7 @@ async def execute(
         if concurrent_requests_per_worker is None:
             concurrent_requests_per_worker = min(request.per_worker, MAX_ASYNC_CONCURRENT_TASKS)
         validate_concurrent_requests_per_worker(
-            concurrent_requests_per_worker, request.endpoint_type
+            concurrent_requests_per_worker, request.endpoint_type, request.per_worker
         )

         if request.labels is None:
@@ -488,7 +501,7 @@ async def execute(
         )

         validate_concurrent_requests_per_worker(
-            request.concurrent_requests_per_worker, endpoint_record.endpoint_type
+            request.concurrent_requests_per_worker, endpoint_record.endpoint_type, request.per_worker
         )

         if request.metadata is not None and CONVERTED_FROM_ARTIFACT_LIKE_KEY in request.metadata:
@@ -58,8 +58,15 @@ def get_streaming_forwarder_loader(

 @lru_cache()
 def get_concurrency_limiter() -> MultiprocessingConcurrencyLimiter:
-    config = get_config()
-    concurrency = int(config.get("max_concurrency", 100))
+    # Check if concurrency is set via command line (environment variable)
+    concurrency_from_env = os.environ.get("FORWARDER_MAX_CONCURRENCY")
+    if concurrency_from_env:
+        concurrency = int(concurrency_from_env)
+    else:
+        # Fall back to config file
+        config = get_config()
+        concurrency = int(config.get("max_concurrency", 100))
+
     return MultiprocessingConcurrencyLimiter(
         concurrency=concurrency, fail_on_concurrency_limit=True
     )
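One subtlety in the forwarder process: `get_concurrency_limiter()` is wrapped in `lru_cache()`, so the environment variable is read exactly once and the resulting limiter is reused. A sketch of the implication, using the function defined above:

```python
import os

# Because get_concurrency_limiter() is lru_cache()'d, the env var must be set
# before the first call; later changes have no effect on the cached limiter.
os.environ["FORWARDER_MAX_CONCURRENCY"] = "8"
limiter = get_concurrency_limiter()           # built with concurrency=8, then cached
os.environ["FORWARDER_MAX_CONCURRENCY"] = "16"
assert get_concurrency_limiter() is limiter   # still the concurrency=8 instance
```

This is why the entrypoint below sets `FORWARDER_MAX_CONCURRENCY` during startup, before any request handling begins.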
@@ -241,12 +248,18 @@ def entrypoint(): # pragma: no cover
     parser.add_argument("--port", type=int, default=5000)
     parser.add_argument("--set", type=str, action="append")
     parser.add_argument("--graceful-timeout", type=int, default=600)
+    parser.add_argument("--max-concurrency", type=int, default=None,
+                        help="Maximum concurrent requests per worker")

     args, extra_args = parser.parse_known_args()

     os.environ["CONFIG_FILE"] = args.config
     if args.set is not None:
         os.environ["CONFIG_OVERRIDES"] = ";".join(args.set)
+
+    # Set concurrency in environment for get_concurrency_limiter to use
+    if args.max_concurrency is not None:
+        os.environ["FORWARDER_MAX_CONCURRENCY"] = str(args.max_concurrency)

     asyncio.run(
         run_server(
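The flag-to-env handoff in `entrypoint()` can be exercised in isolation. A minimal standalone reproduction of the argparse wiring added above (not the repo's actual parser, which has more arguments):

```python
import argparse
import os

# Re-creation of the two pieces added to entrypoint() above.
parser = argparse.ArgumentParser()
parser.add_argument("--max-concurrency", type=int, default=None,
                    help="Maximum concurrent requests per worker")
args, _extra = parser.parse_known_args(["--max-concurrency", "10"])

if args.max_concurrency is not None:
    os.environ["FORWARDER_MAX_CONCURRENCY"] = str(args.max_concurrency)

assert os.environ["FORWARDER_MAX_CONCURRENCY"] == "10"
```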
@@ -261,6 +261,24 @@ def maybe_get_forwarder_container_from_deployment_config(deployment_config):
     return None  # Don't expect to get here since everything should have a forwarder


+def get_concurrent_requests_per_worker_from_sync_deployment(deployment_config):
+    """Extract concurrent_requests_per_worker from sync/streaming endpoint HTTP forwarder."""
+    forwarder_container = maybe_get_forwarder_container_from_deployment_config(deployment_config)
+    if forwarder_container is None or forwarder_container.name != "http-forwarder":
+        return None
+
+    command = forwarder_container.command
+    if command is None:
+        return None
+
+    # Look for --max-concurrency argument in the command
+    try:
+        concurrency_index = command.index("--max-concurrency")
+        return int(command[concurrency_index + 1])
+    except (ValueError, IndexError):
+        return None
+
+
 def get_leader_container_from_lws_template(lws_template: Dict[str, Any]):
     containers = lws_template["spec"]["leaderWorkerTemplate"]["leaderTemplate"]["spec"][
         "containers"
@@ -1913,6 +1931,7 @@ def _get_async_autoscaling_params(
     @staticmethod
     def _get_sync_autoscaling_params(
         hpa_config: V2beta2HorizontalPodAutoscaler,
+        concurrent_requests_per_worker: Optional[int] = None,
     ) -> HorizontalAutoscalingEndpointParams:
         spec = hpa_config.spec
         per_worker = get_per_worker_value_from_target_concurrency(
@@ -1922,12 +1941,13 @@ def _get_sync_autoscaling_params(
             max_workers=spec.max_replicas,
             min_workers=spec.min_replicas,
             per_worker=per_worker,
-            concurrent_requests_per_worker=FAKE_SYNC_CONCURRENT_REQUESTS_PER_WORKER,
+            concurrent_requests_per_worker=concurrent_requests_per_worker or FAKE_SYNC_CONCURRENT_REQUESTS_PER_WORKER,
         )

     @staticmethod
     def _get_sync_autoscaling_params_from_keda(
         keda_config,
+        concurrent_requests_per_worker: Optional[int] = None,
     ) -> HorizontalAutoscalingEndpointParams:
         spec = keda_config["spec"]
         concurrency = 1
@@ -1941,7 +1961,7 @@ def _get_sync_autoscaling_params_from_keda(
             max_workers=spec.get("maxReplicaCount"),
             min_workers=spec.get("minReplicaCount"),
             per_worker=concurrency,
-            concurrent_requests_per_worker=FAKE_SYNC_CONCURRENT_REQUESTS_PER_WORKER,
+            concurrent_requests_per_worker=concurrent_requests_per_worker or FAKE_SYNC_CONCURRENT_REQUESTS_PER_WORKER,
         )

     async def _get_resources(
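Both autoscaling-param helpers now use `x or FALLBACK` rather than the constant directly, so deployments created before this change (whose forwarder command lacks the flag) still report the fake placeholder. A sketch of the fallback behavior, with a stand-in constant:

```python
# Stand-in value; the real FAKE_SYNC_CONCURRENT_REQUESTS_PER_WORKER lives in the repo.
FAKE_SYNC_CONCURRENT_REQUESTS_PER_WORKER = 1

def resolve(concurrent_requests_per_worker):
    return concurrent_requests_per_worker or FAKE_SYNC_CONCURRENT_REQUESTS_PER_WORKER

assert resolve(None) == 1   # old deployments without --max-concurrency
assert resolve(10) == 10    # new deployments report the real setting
# Note: `or` would also map 0 to the fallback; harmless here, since a
# concurrency of 0 is never a valid configuration.
```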
@@ -2021,11 +2041,18 @@ async def _get_resources_from_deployment_type(
             )
         except ApiException:
             keda_scaled_object_config = None
+        # Extract concurrent_requests_per_worker from the HTTP forwarder container
+        concurrent_requests_per_worker = get_concurrent_requests_per_worker_from_sync_deployment(
+            deployment_config
+        )
+
         if hpa_config is not None:
-            horizontal_autoscaling_params = self._get_sync_autoscaling_params(hpa_config)
+            horizontal_autoscaling_params = self._get_sync_autoscaling_params(
+                hpa_config, concurrent_requests_per_worker
+            )
         elif keda_scaled_object_config is not None:
             horizontal_autoscaling_params = self._get_sync_autoscaling_params_from_keda(
-                keda_scaled_object_config
+                keda_scaled_object_config, concurrent_requests_per_worker
             )
         else:
             raise EndpointResourceInfraException(
@@ -2245,11 +2272,19 @@ async def _get_all_resources(
                 # TODO I think this is correct but only barely, it introduces a coupling between
                 # an HPA (or keda SO) existing and an endpoint being a sync endpoint. The "more correct"
                 # thing to do is to query the db to get the endpoints, but it doesn't belong here
-                horizontal_autoscaling_params = self._get_sync_autoscaling_params(hpa_config)
+                concurrent_requests_per_worker = get_concurrent_requests_per_worker_from_sync_deployment(
+                    deployment_config
+                )
+                horizontal_autoscaling_params = self._get_sync_autoscaling_params(
+                    hpa_config, concurrent_requests_per_worker
+                )
             elif keda_scaled_object_config:
                 # Also assume it's a sync endpoint
+                concurrent_requests_per_worker = get_concurrent_requests_per_worker_from_sync_deployment(
+                    deployment_config
+                )
                 horizontal_autoscaling_params = self._get_sync_autoscaling_params_from_keda(
-                    keda_scaled_object_config
+                    keda_scaled_object_config, concurrent_requests_per_worker
                 )
             else:
                 horizontal_autoscaling_params = self._get_async_autoscaling_params(
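Taken together, the write path (template the flag into the forwarder command) and the read path (parse it back out of the deployment config) form a round trip, so the setting can be recovered from the Kubernetes object itself without the database query the TODO above alludes to. A property-style sketch of that invariant, with simplified stand-ins for both sides:

```python
def write_command(concurrent_requests_per_worker: int) -> list:
    # Simplified stand-in for the templated container command.
    return ["--num-workers", "2", "--max-concurrency", str(concurrent_requests_per_worker)]

def read_command(command: list):
    # Same parsing logic as get_concurrent_requests_per_worker_from_sync_deployment.
    try:
        i = command.index("--max-concurrency")
        return int(command[i + 1])
    except (ValueError, IndexError):
        return None

for value in (1, 10, 100):
    assert read_command(write_command(value)) == value  # round trip holds
```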