diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml
index ebc9ca45..a333ff59 100644
--- a/charts/model-engine/templates/service_template_config_map.yaml
+++ b/charts/model-engine/templates/service_template_config_map.yaml
@@ -126,6 +126,8 @@ data:
             - "${FORWARDER_PORT}"
             - --num-workers
             - "${FORWARDER_WORKER_COUNT}"
+            - --max-concurrency
+            - "${CONCURRENT_REQUESTS_PER_WORKER}"
             - --set
             - "forwarder.sync.predict_route=${PREDICT_ROUTE}"
             - --set
@@ -172,6 +174,8 @@ data:
             - "${FORWARDER_PORT}"
             - --num-workers
             - "${FORWARDER_WORKER_COUNT}"
+            - --max-concurrency
+            - "${CONCURRENT_REQUESTS_PER_WORKER}"
             - --set
             - "forwarder.sync.predict_route=${PREDICT_ROUTE}"
             - --set
@@ -558,6 +562,8 @@ data:
             - "${FORWARDER_PORT}"
             - --num-workers
             - "${FORWARDER_WORKER_COUNT}"
+            - --max-concurrency
+            - "${CONCURRENT_REQUESTS_PER_WORKER}"
             - --set
             - "forwarder.sync.predict_route=${PREDICT_ROUTE}"
             - --set
@@ -604,6 +610,8 @@ data:
             - "${FORWARDER_PORT}"
             - --num-workers
             - "${FORWARDER_WORKER_COUNT}"
+            - --max-concurrency
+            - "${CONCURRENT_REQUESTS_PER_WORKER}"
             - --set
             - "forwarder.sync.predict_route=${PREDICT_ROUTE}"
             - --set
diff --git a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py
index ea846643..d321249c 100644
--- a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py
+++ b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py
@@ -129,6 +129,7 @@ def validate_deployment_resources(
 def validate_concurrent_requests_per_worker(
     concurrent_requests_per_worker: Optional[int],
     endpoint_type: ModelEndpointType,
+    per_worker: Optional[int] = None,
 ):
     if (
         endpoint_type == ModelEndpointType.ASYNC
@@ -138,6 +139,18 @@ def validate_concurrent_requests_per_worker(
         raise EndpointResourceInvalidRequestException(
             f"Requested concurrent requests per worker {concurrent_requests_per_worker} too high"
         )
+
+    # Validation for sync/streaming endpoints to prevent autoscaling issues
+    if (
+        endpoint_type in {ModelEndpointType.SYNC, ModelEndpointType.STREAMING}
+        and concurrent_requests_per_worker is not None
+        and per_worker is not None
+        and per_worker >= concurrent_requests_per_worker / 2
+    ):
+        raise EndpointResourceInvalidRequestException(
+            f"For sync/streaming endpoints, per_worker ({per_worker}) must be less than "
+            f"concurrent_requests_per_worker/2 ({concurrent_requests_per_worker/2}) to prevent autoscaling issues"
+        )
 
 
 @dataclass
@@ -299,7 +312,7 @@ async def execute(
         if concurrent_requests_per_worker is None:
             concurrent_requests_per_worker = min(request.per_worker, MAX_ASYNC_CONCURRENT_TASKS)
         validate_concurrent_requests_per_worker(
-            concurrent_requests_per_worker, request.endpoint_type
+            concurrent_requests_per_worker, request.endpoint_type, request.per_worker
         )
 
         if request.labels is None:
@@ -488,7 +501,7 @@ async def execute(
         )
 
         validate_concurrent_requests_per_worker(
-            request.concurrent_requests_per_worker, endpoint_record.endpoint_type
+            request.concurrent_requests_per_worker, endpoint_record.endpoint_type, request.per_worker
        )
 
         if request.metadata is not None and CONVERTED_FROM_ARTIFACT_LIKE_KEY in request.metadata:
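Note on the new `per_worker` guard in `validate_concurrent_requests_per_worker`: the rule is `per_worker < concurrent_requests_per_worker / 2`, so with `concurrent_requests_per_worker=50` the largest accepted `per_worker` is 24. A standalone sketch of the check for reference; the enum and exception below are simplified stand-ins for the real `model_engine_server` types:

```python
# Sketch of the new sync/streaming guard; ModelEndpointType and ValueError here
# are stand-ins for the real model_engine_server types.
from enum import Enum
from typing import Optional


class ModelEndpointType(str, Enum):
    ASYNC = "async"
    SYNC = "sync"
    STREAMING = "streaming"


def check_sync_autoscaling(
    per_worker: Optional[int],
    concurrent_requests_per_worker: Optional[int],
    endpoint_type: ModelEndpointType,
) -> None:
    # Mirrors the diff's condition: reject per_worker >= concurrent_requests_per_worker / 2
    if (
        endpoint_type in {ModelEndpointType.SYNC, ModelEndpointType.STREAMING}
        and concurrent_requests_per_worker is not None
        and per_worker is not None
        and per_worker >= concurrent_requests_per_worker / 2
    ):
        raise ValueError("per_worker must be < concurrent_requests_per_worker / 2")


check_sync_autoscaling(24, 50, ModelEndpointType.SYNC)  # passes: 24 < 25
try:
    check_sync_autoscaling(25, 50, ModelEndpointType.SYNC)  # rejected: 25 >= 25
except ValueError as e:
    print(e)
```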
diff --git a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py
index 89fcb3fb..e0ffbaa3 100644
--- a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py
+++ b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py
@@ -58,8 +58,15 @@ def get_streaming_forwarder_loader(
 
 @lru_cache()
 def get_concurrency_limiter() -> MultiprocessingConcurrencyLimiter:
-    config = get_config()
-    concurrency = int(config.get("max_concurrency", 100))
+    # Check if concurrency is set via command line (environment variable)
+    concurrency_from_env = os.environ.get("FORWARDER_MAX_CONCURRENCY")
+    if concurrency_from_env:
+        concurrency = int(concurrency_from_env)
+    else:
+        # Fall back to config file
+        config = get_config()
+        concurrency = int(config.get("max_concurrency", 100))
+
     return MultiprocessingConcurrencyLimiter(
         concurrency=concurrency, fail_on_concurrency_limit=True
     )
@@ -241,12 +248,18 @@ def entrypoint():  # pragma: no cover
     parser.add_argument("--port", type=int, default=5000)
     parser.add_argument("--set", type=str, action="append")
     parser.add_argument("--graceful-timeout", type=int, default=600)
+    parser.add_argument("--max-concurrency", type=int, default=None,
+                        help="Maximum concurrent requests per worker")
 
     args, extra_args = parser.parse_known_args()
 
     os.environ["CONFIG_FILE"] = args.config
     if args.set is not None:
         os.environ["CONFIG_OVERRIDES"] = ";".join(args.set)
+
+    # Set concurrency in environment for get_concurrency_limiter to use
+    if args.max_concurrency is not None:
+        os.environ["FORWARDER_MAX_CONCURRENCY"] = str(args.max_concurrency)
 
     asyncio.run(
         run_server(
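The resolution order this introduces: `--max-concurrency` is exported by `entrypoint()` as `FORWARDER_MAX_CONCURRENCY`, which `get_concurrency_limiter()` reads before falling back to the config file's `max_concurrency` (default 100). A minimal sketch of that precedence, with `get_config` stubbed out for illustration:

```python
# Sketch of the CLI/env-over-config precedence; get_config() is a stub, while the
# env var name and the 100 default come from the diff.
import os


def get_config() -> dict:  # stand-in for the forwarder's real config loader
    return {"max_concurrency": 100}


def resolve_max_concurrency() -> int:
    concurrency_from_env = os.environ.get("FORWARDER_MAX_CONCURRENCY")
    if concurrency_from_env:
        return int(concurrency_from_env)
    return int(get_config().get("max_concurrency", 100))


os.environ["FORWARDER_MAX_CONCURRENCY"] = "8"  # as if --max-concurrency 8 was passed
assert resolve_max_concurrency() == 8
del os.environ["FORWARDER_MAX_CONCURRENCY"]
assert resolve_max_concurrency() == 100  # falls back to the config file value
```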
diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py
index 45ab0d73..8e1e7a36 100644
--- a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py
+++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py
@@ -261,6 +261,24 @@ def maybe_get_forwarder_container_from_deployment_config(deployment_config):
     return None  # Don't expect to get here since everything should have a forwarder
 
 
+def get_concurrent_requests_per_worker_from_sync_deployment(deployment_config):
+    """Extract concurrent_requests_per_worker from sync/streaming endpoint HTTP forwarder."""
+    forwarder_container = maybe_get_forwarder_container_from_deployment_config(deployment_config)
+    if forwarder_container is None or forwarder_container.name != "http-forwarder":
+        return None
+
+    command = forwarder_container.command
+    if command is None:
+        return None
+
+    # Look for --max-concurrency argument in the command
+    try:
+        concurrency_index = command.index("--max-concurrency")
+        return int(command[concurrency_index + 1])
+    except (ValueError, IndexError):
+        return None
+
+
 def get_leader_container_from_lws_template(lws_template: Dict[str, Any]):
     containers = lws_template["spec"]["leaderWorkerTemplate"]["leaderTemplate"]["spec"][
         "containers"
     ]
@@ -1913,6 +1931,7 @@ def _get_async_autoscaling_params(
     @staticmethod
     def _get_sync_autoscaling_params(
         hpa_config: V2beta2HorizontalPodAutoscaler,
+        concurrent_requests_per_worker: Optional[int] = None,
     ) -> HorizontalAutoscalingEndpointParams:
         spec = hpa_config.spec
         per_worker = get_per_worker_value_from_target_concurrency(
@@ -1922,12 +1941,13 @@ def _get_sync_autoscaling_params(
             max_workers=spec.max_replicas,
             min_workers=spec.min_replicas,
             per_worker=per_worker,
-            concurrent_requests_per_worker=FAKE_SYNC_CONCURRENT_REQUESTS_PER_WORKER,
+            concurrent_requests_per_worker=concurrent_requests_per_worker or FAKE_SYNC_CONCURRENT_REQUESTS_PER_WORKER,
         )
 
     @staticmethod
     def _get_sync_autoscaling_params_from_keda(
         keda_config,
+        concurrent_requests_per_worker: Optional[int] = None,
     ) -> HorizontalAutoscalingEndpointParams:
         spec = keda_config["spec"]
         concurrency = 1
@@ -1941,7 +1961,7 @@ def _get_sync_autoscaling_params_from_keda(
             max_workers=spec.get("maxReplicaCount"),
             min_workers=spec.get("minReplicaCount"),
             per_worker=concurrency,
-            concurrent_requests_per_worker=FAKE_SYNC_CONCURRENT_REQUESTS_PER_WORKER,
+            concurrent_requests_per_worker=concurrent_requests_per_worker or FAKE_SYNC_CONCURRENT_REQUESTS_PER_WORKER,
         )
 
     async def _get_resources(
@@ -2021,11 +2041,18 @@ async def _get_resources_from_deployment_type(
             )
         except ApiException:
             keda_scaled_object_config = None
+        # Extract concurrent_requests_per_worker from the HTTP forwarder container
+        concurrent_requests_per_worker = get_concurrent_requests_per_worker_from_sync_deployment(
+            deployment_config
+        )
+
         if hpa_config is not None:
-            horizontal_autoscaling_params = self._get_sync_autoscaling_params(hpa_config)
+            horizontal_autoscaling_params = self._get_sync_autoscaling_params(
+                hpa_config, concurrent_requests_per_worker
+            )
         elif keda_scaled_object_config is not None:
             horizontal_autoscaling_params = self._get_sync_autoscaling_params_from_keda(
-                keda_scaled_object_config
+                keda_scaled_object_config, concurrent_requests_per_worker
             )
         else:
             raise EndpointResourceInfraException(
@@ -2245,11 +2272,19 @@ async def _get_all_resources(
                 # TODO I think this is correct but only barely, it introduces a coupling between
                 # an HPA (or keda SO) existing and an endpoint being a sync endpoint. The "more correct"
                 # thing to do is to query the db to get the endpoints, but it doesn't belong here
-                horizontal_autoscaling_params = self._get_sync_autoscaling_params(hpa_config)
+                concurrent_requests_per_worker = get_concurrent_requests_per_worker_from_sync_deployment(
+                    deployment_config
+                )
+                horizontal_autoscaling_params = self._get_sync_autoscaling_params(
+                    hpa_config, concurrent_requests_per_worker
+                )
             elif keda_scaled_object_config:
                 # Also assume it's a sync endpoint
+                concurrent_requests_per_worker = get_concurrent_requests_per_worker_from_sync_deployment(
+                    deployment_config
+                )
                 horizontal_autoscaling_params = self._get_sync_autoscaling_params_from_keda(
-                    keda_scaled_object_config
+                    keda_scaled_object_config, concurrent_requests_per_worker
                 )
             else:
                 horizontal_autoscaling_params = self._get_async_autoscaling_params(
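For reference, the extraction in `get_concurrent_requests_per_worker_from_sync_deployment` is a plain positional scan of the container command. A quick check of that scan against a fabricated command list; the real helper reads `command` off the forwarder's `V1Container` in the deployment spec:

```python
# Sketch of the --max-concurrency scan; `cmd` below is fabricated for illustration.
from typing import List, Optional


def extract_max_concurrency(command: Optional[List[str]]) -> Optional[int]:
    if not command:
        return None
    try:
        idx = command.index("--max-concurrency")  # ValueError if the flag is absent
        return int(command[idx + 1])  # IndexError if the flag is the last token
    except (ValueError, IndexError):  # ValueError also covers a non-integer value
        return None


cmd = ["python", "-m", "http_forwarder", "--num-workers", "2", "--max-concurrency", "10"]
assert extract_max_concurrency(cmd) == 10
assert extract_max_concurrency(["python"]) is None  # flag absent
assert extract_max_concurrency(None) is None  # container has no command
```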