diff --git a/gym_http_server.py b/gym_http_server.py index 5c2384d..2b36a89 100644 --- a/gym_http_server.py +++ b/gym_http_server.py @@ -59,6 +59,8 @@ def list_all(self): def reset(self, instance_id): env = self._lookup_env(instance_id) obs = env.reset() + if (isinstance(obs, tuple)): + obs = [obs] return env.observation_space.to_jsonable(obs) def step(self, instance_id, action, render): @@ -70,6 +72,9 @@ def step(self, instance_id, action, render): if render: env.render() [observation, reward, done, info] = env.step(nice_action) + + if isinstance(observation, tuple): + observation = [observation] obs_jsonable = env.observation_space.to_jsonable(observation) return [obs_jsonable, reward, done, info]