Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion model-engine/model_engine_server/api/v2/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ async def chat_completion(
)
else:
logger.info(
f"POST /v2/chat/completion ({('stream' if request.stream else 'sync')}) with {request} to endpoint {model_endpoint_name} for {auth}"
f"POST /v2/chat/completion ({('stream' if request.stream else 'sync')}) with request {request} to endpoint {model_endpoint_name} for {auth}"
)

if request.stream:
Expand Down
2 changes: 1 addition & 1 deletion model-engine/model_engine_server/api/v2/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ async def completion(
)
else:
logger.info(
f"POST /v2/completion ({('stream' if request.stream else 'sync')}) with {request} to endpoint {model_endpoint_name} for {auth}"
f"POST /v2/completion ({('stream' if request.stream else 'sync')}) with request {request} to endpoint {model_endpoint_name} for {auth}"
)

if request.stream:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@
}


NUM_DOWNSTREAM_REQUEST_RETRIES = 80 # has to be high enough so that the retries take the 5 minutes
DOWNSTREAM_REQUEST_TIMEOUT_SECONDS = 5 * 60 # 5 minutes
# Retry/timeout budget for requests forwarded to the downstream model server.
# The retry count must be high enough that retries can cover the entire
# DOWNSTREAM_REQUEST_TIMEOUT_SECONDS window (previously 80 retries / 5 min;
# both were scaled 12x together).
NUM_DOWNSTREAM_REQUEST_RETRIES = 80 * 12
DOWNSTREAM_REQUEST_TIMEOUT_SECONDS = 60 * 60  # 1 hour

DEFAULT_BATCH_COMPLETIONS_NODES_PER_WORKER = 1

Expand Down Expand Up @@ -377,6 +377,87 @@ def check_docker_image_exists_for_image_tag(
tag=framework_image_tag,
)

async def create_sglang_multinode_bundle(
    self,
    user: User,
    model_name: str,
    framework_image_tag: str,
    endpoint_unique_name: str,
    num_shards: int,
    nodes_per_worker: int,
    quantize: Optional[Quantization],
    checkpoint_path: Optional[str],
    chat_template_override: Optional[str],
    additional_args: Optional[SGLangEndpointAdditionalArgs] = None,
):
    """Create a multinode (leader + worker) SGLang bundle and return its model_bundle_id.

    Mirrors the vLLM multinode bundle: the leader node runs the HTTP/streaming
    server and worker nodes join it via the startup script's rendezvous ports.

    NOTE(review): num_shards, quantize, chat_template_override and
    additional_args are accepted for signature parity with the vLLM path but
    are not currently forwarded to the startup script — confirm before
    relying on them.
    """
    # Fix: the model and node count were previously hard-coded
    # ("deepseek-ai/DeepSeek-R1-0528", --nnodes 2), silently ignoring the
    # request's model/checkpoint and nodes_per_worker. Derive them instead.
    model_arg = checkpoint_path or model_name

    def _startup_command(node_rank):
        # Build the sglang startup command for one node; rank 0 is the leader.
        return [
            "python3",
            "/root/sglang-startup-script.py",
            "--model",
            model_arg,
            "--nnodes",
            str(nodes_per_worker),
            "--node-rank",
            str(node_rank),
            "--worker-port",
            "5005",
            "--leader-port",
            "5002",
        ]

    leader_command = _startup_command(0)
    # NOTE(review): every worker pod receives rank 1; for nodes_per_worker > 2
    # the startup script / orchestration must disambiguate worker ranks — confirm.
    worker_command = _startup_command(1)

    # NOTE: the most important env var SGLANG_HOST_IP is already established in the sglang startup script

    common_sglang_envs = {  # these are for debugging
        "NCCL_SOCKET_IFNAME": "eth0",
        "GLOO_SOCKET_IFNAME": "eth0",
    }

    # This is same as VLLM multinode bundle
    create_model_bundle_v2_request = CreateModelBundleV2Request(
        name=endpoint_unique_name,
        schema_location="TBA",
        flavor=StreamingEnhancedRunnableImageFlavor(
            flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE,
            repository=hmi_config.sglang_repository,
            tag=framework_image_tag,
            command=leader_command,
            streaming_command=leader_command,
            protocol="http",
            readiness_initial_delay_seconds=10,
            healthcheck_route="/health",
            predict_route="/predict",
            streaming_predict_route="/stream",
            extra_routes=[OPENAI_CHAT_COMPLETION_PATH, OPENAI_COMPLETION_PATH],
            env=common_sglang_envs,
            worker_command=worker_command,
            worker_env=common_sglang_envs,
        ),
        metadata={},
    )

    return (
        await self.create_model_bundle_use_case.execute(
            user,
            create_model_bundle_v2_request,
            do_auth_check=False,
        )
    ).model_bundle_id

async def execute(
self,
user: User,
Expand All @@ -400,7 +481,10 @@ async def execute(
self.check_docker_image_exists_for_image_tag(
framework_image_tag, INFERENCE_FRAMEWORK_REPOSITORY[framework]
)
if multinode and framework != LLMInferenceFramework.VLLM:
if multinode and framework not in [
LLMInferenceFramework.VLLM,
LLMInferenceFramework.SGLANG,
]:
raise ObjectHasInvalidValueException(
f"Multinode is not supported for framework {framework}."
)
Expand Down Expand Up @@ -481,16 +565,30 @@ async def execute(
if additional_args
else None
)
bundle_id = await self.create_sglang_bundle(
user,
model_name,
framework_image_tag,
endpoint_name,
num_shards,
checkpoint_path,
chat_template_override,
additional_args=additional_sglang_args,
)
if multinode:
bundle_id = await self.create_sglang_multinode_bundle(
user,
model_name,
framework_image_tag,
endpoint_name,
num_shards,
nodes_per_worker,
quantize,
checkpoint_path,
chat_template_override,
additional_args=additional_sglang_args,
)
else:
bundle_id = await self.create_sglang_bundle(
user,
model_name,
framework_image_tag,
endpoint_name,
num_shards,
checkpoint_path,
chat_template_override,
additional_args=additional_sglang_args,
)
case _:
assert_never(framework)
raise ObjectHasInvalidValueException(
Expand Down Expand Up @@ -1321,10 +1419,10 @@ async def execute(
request.inference_framework
)

if (
request.nodes_per_worker > 1
and not request.inference_framework == LLMInferenceFramework.VLLM
):
# Reject multinode requests for frameworks that don't support it.
# Fixed: use the idiomatic `not in`, and keep the user-facing error message
# in sync with the allowed list (SGLang was added but the message still
# claimed VLLM-only).
if request.nodes_per_worker > 1 and request.inference_framework not in [
    LLMInferenceFramework.VLLM,
    LLMInferenceFramework.SGLANG,
]:
    raise ObjectHasInvalidValueException(
        "Multinode endpoints are only supported for VLLM and SGLang models."
    )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ async def forward(self, json_payload: Any) -> Any:
logger.info(f"Accepted request, forwarding {json_payload_repr=}")

try:
async with aiohttp.ClientSession(json_serialize=_serialize_json) as aioclient:
async with aiohttp.ClientSession(
json_serialize=_serialize_json, timeout=aiohttp.ClientTimeout(total=60 * 60)
) as aioclient:
response_raw = await aioclient.post(
self.predict_endpoint,
json=json_payload,
Expand Down Expand Up @@ -430,7 +432,9 @@ async def forward(self, json_payload: Any) -> AsyncGenerator[Any, None]: # prag

try:
response: aiohttp.ClientResponse
async with aiohttp.ClientSession(json_serialize=_serialize_json) as aioclient:
async with aiohttp.ClientSession(
json_serialize=_serialize_json, timeout=aiohttp.ClientTimeout(total=60 * 60)
) as aioclient:
response = await aioclient.post(
self.predict_endpoint,
json=json_payload,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

logger = make_logger(logger_name())

LOG_SENSITIVE_DATA = False


def get_config():
overrides = os.getenv("CONFIG_OVERRIDES")
Expand Down Expand Up @@ -90,7 +92,10 @@ async def predict(
)
return response
except Exception:
logger.error(f"Failed to decode payload from: {request}")
# Only include the raw request in the log when explicitly allowed.
if LOG_SENSITIVE_DATA:
    logger.error(f"Failed to decode payload from: {request}")
else:
    # Fixed: was an f-string with no placeholders (ruff F541).
    logger.error("Failed to decode payload")
raise


Expand All @@ -103,10 +108,16 @@ async def stream(
try:
payload = request.model_dump()
except Exception:
logger.error(f"Failed to decode payload from: {request}")
# Only include the raw request in the log when explicitly allowed.
if LOG_SENSITIVE_DATA:
    logger.error(f"Failed to decode payload from: {request}")
else:
    # Fixed: was an f-string with no placeholders (ruff F541).
    logger.error("Failed to decode payload")
raise
else:
logger.debug(f"Received request: {payload}")
# Only include the raw request in the debug log when explicitly allowed.
if LOG_SENSITIVE_DATA:
    logger.debug(f"Received request: {request}")
else:
    # Fixed: was an f-string with no placeholders (ruff F541).
    logger.debug("Received request")

responses = forwarder.forward(payload)
# We fetch the first response to check if upstream request was successful
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
FROM 692474966980.dkr.ecr.us-west-2.amazonaws.com/sglang:v0.4.1.post7-cu124
# FROM lmsysorg/sglang:v0.4.6.post5-cu124 -- this image did not work in our environment (reason unrecorded); using v0.4.5.post3-cu121 below instead
FROM lmsysorg/sglang:v0.4.5.post3-cu121

# These aren't all needed but good to have for debugging purposes
RUN apt-get -yq update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
Expand Down Expand Up @@ -35,7 +36,7 @@ RUN apt-get -yq update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
tk-dev \
libffi-dev \
liblzma-dev \
python-openssl \
python3-openssl \
moreutils \
libcurl4-openssl-dev \
libssl-dev \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def wait_for_dns(dns_name: str, max_retries: int = 20, sleep_seconds: int = 3):
sleeping sleep_seconds between attempts.
Raises RuntimeError if resolution fails repeatedly.
"""
for attempt in range(1, max_retries + 1):
for attempt in range(1, max_retries + 2):
try:
# Use AF_UNSPEC to allow both IPv4 and IPv6
socket.getaddrinfo(dns_name, None, socket.AF_UNSPEC)
Expand Down Expand Up @@ -107,7 +107,7 @@ def main(
"--tp",
str(tp),
"--host",
"::",
"0.0.0.0",
"--port",
str(worker_port),
"--dist-init-addr",
Expand Down