Skip to content

Fix when base_url is used in combination with a gateway client #1550

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions jupyter_server/gateway/gateway_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from traitlets.config import LoggingConfigurable, SingletonConfigurable

from jupyter_server import DEFAULT_EVENTS_SCHEMA_PATH, JUPYTER_SERVER_EVENTS_URI
from jupyter_server.utils import url_path_join

ERROR_STATUS = "error"
SUCCESS_STATUS = "success"
Expand Down Expand Up @@ -170,6 +171,18 @@ def _ws_url_validate(self, proposal):
raise TraitError(message)
return value

base_url = Unicode(
default_value="/",
config=True,
help="""The gateway API base_url for fixing default kernel endpoints""",
)

@observe("base_url")
def _base_url(self, change):
self.kernels_endpoint = self._kernels_endpoint_default()
self.kernelspecs_endpoint = self._kernelspecs_endpoint_default()
self.kernelspecs_resource_endpoint = self._kernelspecs_resource_endpoint_default()

kernels_endpoint_default_value = "/api/kernels"
kernels_endpoint_env = "JUPYTER_GATEWAY_KERNELS_ENDPOINT"
kernels_endpoint = Unicode(
Expand All @@ -180,7 +193,10 @@ def _ws_url_validate(self, proposal):

@default("kernels_endpoint")
def _kernels_endpoint_default(self):
return os.environ.get(self.kernels_endpoint_env, self.kernels_endpoint_default_value)
return os.environ.get(
self.kernels_endpoint_env,
url_path_join(self.base_url, self.kernels_endpoint_default_value),
)

kernelspecs_endpoint_default_value = "/api/kernelspecs"
kernelspecs_endpoint_env = "JUPYTER_GATEWAY_KERNELSPECS_ENDPOINT"
Expand All @@ -193,7 +209,8 @@ def _kernels_endpoint_default(self):
@default("kernelspecs_endpoint")
def _kernelspecs_endpoint_default(self):
return os.environ.get(
self.kernelspecs_endpoint_env, self.kernelspecs_endpoint_default_value
self.kernelspecs_endpoint_env,
url_path_join(self.base_url, self.kernelspecs_endpoint_default_value),
)

kernelspecs_resource_endpoint_default_value = "/kernelspecs"
Expand All @@ -209,7 +226,7 @@ def _kernelspecs_endpoint_default(self):
def _kernelspecs_resource_endpoint_default(self):
return os.environ.get(
self.kernelspecs_resource_endpoint_env,
self.kernelspecs_resource_endpoint_default_value,
url_path_join(self.base_url, self.kernelspecs_resource_endpoint_default_value),
)

connect_timeout_default_value = 40.0
Expand Down
5 changes: 4 additions & 1 deletion jupyter_server/gateway/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def _default_shared_context(self):
def __init__(self, **kwargs):
"""Initialize a gateway mapping kernel manager."""
super().__init__(**kwargs)
GatewayClient.instance().base_url = self.parent.base_url
self.kernels_url = url_path_join(
GatewayClient.instance().url or "", GatewayClient.instance().kernels_endpoint or ""
)
Expand Down Expand Up @@ -218,6 +219,8 @@ class GatewayKernelSpecManager(KernelSpecManager):
def __init__(self, **kwargs):
"""Initialize a gateway kernel spec manager."""
super().__init__(**kwargs)
GatewayClient.instance().base_url = self.parent.base_url

base_endpoint = url_path_join(
GatewayClient.instance().url or "", GatewayClient.instance().kernelspecs_endpoint
)
Expand Down Expand Up @@ -248,7 +251,7 @@ def _replace_path_kernelspec_resources(self, kernel_specs):
resources = kernelspecs[kernel_name]["resources"]
for resource_name in resources:
original_path = resources[resource_name]
split_eg_base_url = str.rsplit(original_path, sep="/kernelspecs/", maxsplit=1)
split_eg_base_url = str.rsplit(original_path, sep="/kernelspecs", maxsplit=1)
if len(split_eg_base_url) > 1:
new_path = url_path_join(
self.parent.base_url, "kernelspecs", split_eg_base_url[1]
Expand Down
28 changes: 22 additions & 6 deletions tests/test_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@
from jupyter_server.gateway.gateway_client import GatewayTokenRenewerBase, NoOpTokenRenewer
from jupyter_server.gateway.managers import ChannelQueue, GatewayClient, GatewayKernelManager
from jupyter_server.services.kernels.websocket import KernelWebsocketHandler
from jupyter_server.utils import url_path_join

from .utils import expected_http_error

pytest_plugins = ["jupyter_events.pytest_plugin"]


def generate_kernelspec(name):
def generate_kernelspec(name, kernelspecs_endpoint):
argv_stanza = ["python", "-m", "ipykernel_launcher", "-f", "{connection_file}"]
spec_stanza = {
"spec": {
Expand All @@ -52,6 +53,7 @@ def generate_kernelspec(name):
"resources": {
"logo-64x64": f"f/kernelspecs/{name}/logo-64x64.png",
"url": "https://example.com/example-url",
"kernelspec": kernelspecs_endpoint,
},
}
return kernelspec_stanza
Expand All @@ -61,8 +63,11 @@ def generate_kernelspec(name):
kernelspecs: dict = {
"default": "kspec_foo",
"kernelspecs": {
"kspec_foo": generate_kernelspec("kspec_foo"),
"kspec_bar": generate_kernelspec("kspec_bar"),
"kspec_foo": generate_kernelspec("kspec_foo", "/foo/kernelspecs"),
"kspec_bar": generate_kernelspec("kspec_bar", "/bar/kernelspecs/"),
"kspec_baz": generate_kernelspec(
"kspec_baz", GatewayClient.kernelspecs_endpoint_default_value
),
},
}

Expand Down Expand Up @@ -437,11 +442,20 @@ async def test_gateway_get_kernelspecs(init_gateway, jp_fetch, jp_serverapp):
assert r.code == 200
content = json.loads(r.body.decode("utf-8"))
kspecs = content.get("kernelspecs")
assert len(kspecs) == 2
assert len(kspecs) == 3
assert kspecs.get("kspec_bar").get("name") == "kspec_bar"
assert (
kspecs.get("kspec_bar").get("resources")["logo-64x64"].startswith(jp_serverapp.base_url)
)
assert (
kspecs.get("kspec_bar").get("resources")["kernelspec"].startswith(jp_serverapp.base_url)
)
assert (
kspecs.get("kspec_foo").get("resources")["kernelspec"].startswith(jp_serverapp.base_url)
)
assert (
kspecs.get("kspec_baz").get("resources")["kernelspec"].startswith(jp_serverapp.base_url)
)


async def test_gateway_get_named_kernelspec(init_gateway, jp_fetch):
Expand Down Expand Up @@ -750,8 +764,10 @@ async def test_websocket_connection_with_session_id(init_gateway, jp_serverapp,
handler.connection = conn
await conn.connect()
assert conn.session_id != None
expected_ws_url = (
f"{mock_gateway_ws_url}/api/kernels/{kernel_id}/channels?session_id={conn.session_id}"
expected_ws_url = url_path_join(
mock_gateway_ws_url,
jp_serverapp.base_url,
f"/api/kernels/{kernel_id}/channels?session_id={conn.session_id}",
)
assert (
expected_ws_url in caplog.text
Expand Down
Loading