diff --git a/torch_xla/csrc/runtime/env_vars.cpp b/torch_xla/csrc/runtime/env_vars.cpp index 31c066d6f0d..74add068255 100644 --- a/torch_xla/csrc/runtime/env_vars.cpp +++ b/torch_xla/csrc/runtime/env_vars.cpp @@ -1,36 +1 @@ #include "torch_xla/csrc/runtime/env_vars.h" - -namespace torch_xla { -namespace runtime { -namespace env { - -const char* const kEnvNumTpu = "TPU_NUM_DEVICES"; -const char* const kEnvNumGpu = "GPU_NUM_DEVICES"; -const char* const kEnvNumCpu = "CPU_NUM_DEVICES"; -const char* const kEnvTpuvmMode = "TPUVM_MODE"; -const char* const kEnvPjRtDevice = "PJRT_DEVICE"; -const char* const kEnvPjRtTpuMaxInflightComputations = - "PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS"; -const char* const kEnvPjrtAsyncCpuClient = "PJRT_CPU_ASYNC_CLIENT"; -const char* const kEnvPjrtAsyncGpuClient = "PJRT_GPU_ASYNC_CLIENT"; -const char* const kEnvTpuLibraryPath = "TPU_LIBRARY_PATH"; -const char* const kEnvInferredTpuLibraryPath = "PTXLA_TPU_LIBRARY_PATH"; -const char* const kEnvXpuLibraryPath = "XPU_LIBRARY_PATH"; -const char* const kEnvNeuronLibraryPath = "NEURON_LIBRARY_PATH"; -const char* const kEnvPjrtDistServiceAddr = "PJRT_DIST_SERVICE_ADDR"; -const char* const kEnvPjRtLocalProcessCount = "PJRT_LOCAL_PROCESS_COUNT"; -const char* const kEnvPjRtLocalRank = "PJRT_LOCAL_PROCESS_RANK"; -const char* const kEnvPjrtAllocatorCudaAsync = "PJRT_ALLOCATOR_CUDA_ASYNC"; -const char* const kEnvPjrtAllocatorPreallocate = "PJRT_ALLOCATOR_PREALLOCATE"; -const char* const kEnvPjrtAllocatorFraction = "PJRT_ALLOCATOR_FRACTION"; -const char* const kEnvPjrtDynamicPlugins = "PJRT_DYNAMIC_PLUGINS"; -const char* const kEnvDistSvcHeartbeatIntervalInSec = - "DIST_SERVICE_HEARTBEAT_INTERVAL_IN_SEC"; -const char* const kEnvDistSvcMaxMissingHeartbeats = - "DIST_SERVICE_MAX_MISSING_HEARTBEATS"; -const char* const kEnvDistSvcShutdownTimeoutInMin = - "DIST_SERVICE_SHUTDOWN_TIMEOUT_IN_MIN"; - -} // namespace env -} // namespace runtime -} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/env_vars.h 
b/torch_xla/csrc/runtime/env_vars.h index a777846b6da..827c4822d49 100644 --- a/torch_xla/csrc/runtime/env_vars.h +++ b/torch_xla/csrc/runtime/env_vars.h @@ -1,3 +1,5 @@ +// Names of environment variables. + #ifndef XLA_CLIENT_ENV_VARS_H_ #define XLA_CLIENT_ENV_VARS_H_ @@ -5,38 +7,36 @@ namespace torch_xla { namespace runtime { namespace env { -extern const char* const kEnvNumTpu; -extern const char* const kEnvNumGpu; -extern const char* const kEnvNumCpu; -extern const char* const kEnvLocalWorker; -extern const char* const kEnvTpuConfig; -extern const char* const kEnvDeviceMap; -extern const char* const kEnvWorkers; -extern const char* const kEnvMeshService; -extern const char* const kEnvWorldSize; -extern const char* const kEnvMpDevice; -extern const char* const kEnvHostOrdinal; -extern const char* const kEnvShardOrdinal; -extern const char* const kEnvStartService; -extern const char* const kEnvTpuvmMode; -extern const char* const kEnvPjRtDevice; -extern const char* const kEnvPjRtTpuMaxInflightComputations; -extern const char* const kEnvPjrtAsyncCpuClient; -extern const char* const kEnvPjrtAsyncGpuClient; -extern const char* const kEnvTpuLibraryPath; -extern const char* const kEnvInferredTpuLibraryPath; -extern const char* const kEnvXpuLibraryPath; -extern const char* const kEnvNeuronLibraryPath; -extern const char* const kEnvPjrtDistServiceAddr; -extern const char* const kEnvPjRtLocalProcessCount; -extern const char* const kEnvPjRtLocalRank; -extern const char* const kEnvPjrtAllocatorCudaAsync; -extern const char* const kEnvPjrtAllocatorPreallocate; -extern const char* const kEnvPjrtAllocatorFraction; -extern const char* const kEnvPjrtDynamicPlugins; -extern const char* const kEnvDistSvcHeartbeatIntervalInSec; -extern const char* const kEnvDistSvcMaxMissingHeartbeats; -extern const char* const kEnvDistSvcShutdownTimeoutInMin; +inline constexpr char kEnvLocalWorker[] = "LOCAL_WORKER"; +inline constexpr char kEnvTpuConfig[] = "TPU_CONFIG"; +inline constexpr char 
kEnvNumTpu[] = "TPU_NUM_DEVICES"; +inline constexpr char kEnvNumGpu[] = "GPU_NUM_DEVICES"; +inline constexpr char kEnvNumCpu[] = "CPU_NUM_DEVICES"; +inline constexpr char kEnvTpuvmMode[] = "TPUVM_MODE"; +inline constexpr char kEnvPjRtDevice[] = "PJRT_DEVICE"; +inline constexpr char kEnvPjRtTpuMaxInflightComputations[] = + "PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS"; +inline constexpr char kEnvPjrtAsyncCpuClient[] = "PJRT_CPU_ASYNC_CLIENT"; +inline constexpr char kEnvPjrtAsyncGpuClient[] = "PJRT_GPU_ASYNC_CLIENT"; +inline constexpr char kEnvTpuLibraryPath[] = "TPU_LIBRARY_PATH"; +inline constexpr char kEnvInferredTpuLibraryPath[] = "PTXLA_TPU_LIBRARY_PATH"; +inline constexpr char kEnvXpuLibraryPath[] = "XPU_LIBRARY_PATH"; +inline constexpr char kEnvNeuronLibraryPath[] = "NEURON_LIBRARY_PATH"; +inline constexpr char kEnvPjrtDistServiceAddr[] = "PJRT_DIST_SERVICE_ADDR"; +inline constexpr char kEnvPjRtLocalProcessCount[] = "PJRT_LOCAL_PROCESS_COUNT"; +inline constexpr char kEnvPjRtLocalRank[] = "PJRT_LOCAL_PROCESS_RANK"; +inline constexpr char kEnvPjrtAllocatorCudaAsync[] = + "PJRT_ALLOCATOR_CUDA_ASYNC"; +inline constexpr char kEnvPjrtAllocatorPreallocate[] = + "PJRT_ALLOCATOR_PREALLOCATE"; +inline constexpr char kEnvPjrtAllocatorFraction[] = "PJRT_ALLOCATOR_FRACTION"; +inline constexpr char kEnvPjrtDynamicPlugins[] = "PJRT_DYNAMIC_PLUGINS"; +inline constexpr char kEnvDistSvcHeartbeatIntervalInSec[] = + "DIST_SERVICE_HEARTBEAT_INTERVAL_IN_SEC"; +inline constexpr char kEnvDistSvcMaxMissingHeartbeats[] = + "DIST_SERVICE_MAX_MISSING_HEARTBEATS"; +inline constexpr char kEnvDistSvcShutdownTimeoutInMin[] = + "DIST_SERVICE_SHUTDOWN_TIMEOUT_IN_MIN"; } // namespace env } // namespace runtime diff --git a/torch_xla/csrc/runtime/runtime.cpp b/torch_xla/csrc/runtime/runtime.cpp index c40d3bfb1d0..24c11db4e1a 100644 --- a/torch_xla/csrc/runtime/runtime.cpp +++ b/torch_xla/csrc/runtime/runtime.cpp @@ -1,7 +1,8 @@ +#include "torch_xla/csrc/runtime/runtime.h" + #include #include 
"absl/log/absl_check.h" -#include "torch_xla/csrc/device.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/ifrt_computation_client.h" @@ -10,12 +11,17 @@ namespace torch_xla::runtime { -std::atomic g_computation_client_initialized(false); +static std::atomic g_computation_client_initialized(false); // Creates a new instance of a `ComputationClient` (e.g. -// `PjRtComputationClient`), and initializes the computation client +// `PjRtComputationClient`), and initializes the computation client. +// Can only be called when g_computation_client_initialized is false. static absl::StatusOr InitializeComputationClient() { + ABSL_CHECK(!g_computation_client_initialized) + << "InitializeComputationClient() can only be called once."; + g_computation_client_initialized = true; + if (sys_util::GetEnvBool("XLA_DUMP_FATAL_STACK", false)) { tsl::testing::InstallStacktraceHandler(); } @@ -25,27 +31,24 @@ InitializeComputationClient() { // // static bool use_ifrt = sys_util::GetEnvBool("XLA_USE_IFRT", false); const bool use_ifrt = false; - if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") { - auto* client = - (use_ifrt) - ? static_cast(new IfrtComputationClient()) - : static_cast(new PjRtComputationClient()); - g_computation_client_initialized = true; - return client; - } else { + if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") == "") { return absl::FailedPreconditionError("$PJRT_DEVICE is not set."); } + + if (use_ifrt) { + return new IfrtComputationClient(); + } + return new PjRtComputationClient(); } -absl::StatusOr GetComputationClient() { +const absl::StatusOr& GetComputationClient() { // Reference to singleton Status-wrapped ComputationClient instance. // // Since we only allow a single initialization, as soon as this function is // called, we store the initialization result in this trivially destructible // reference. 
- static auto& maybe_client = - *new absl::StatusOr( - InitializeComputationClient()); + static const auto& maybe_client = + *new absl::StatusOr(InitializeComputationClient()); return maybe_client; } diff --git a/torch_xla/csrc/runtime/runtime.h b/torch_xla/csrc/runtime/runtime.h index 0db40cc2d58..f6af26cb66f 100644 --- a/torch_xla/csrc/runtime/runtime.h +++ b/torch_xla/csrc/runtime/runtime.h @@ -1,12 +1,14 @@ #ifndef XLA_CLIENT_RUNTIME_H_ #define XLA_CLIENT_RUNTIME_H_ +#include "absl/base/attributes.h" +#include "absl/status/statusor.h" #include "torch_xla/csrc/runtime/computation_client.h" namespace torch_xla::runtime { // Returns the ComputationClient singleton. -absl::StatusOr GetComputationClient(); +const absl::StatusOr& GetComputationClient(); ABSL_DEPRECATED( "Use status::GetComputationClient(), instead. " @@ -15,12 +17,13 @@ ABSL_DEPRECATED( "safer.") ComputationClient* absl_nonnull GetComputationClientOrDie(); -// Returns the ComputationClient singleton, if successfully initialized. -// Returns a nullptr, if the ComputationClient wasn't initialized yet, or -// if there was an error on initialization. +// Returns the ComputationClient singleton if it was successfully initialized. +// Returns a nullptr if the ComputationClient wasn't initialized yet. +// Throws an exception if the ComputationClient was initialized but the +// initialization failed. ComputationClient* GetComputationClientIfInitialized(); -// Run the XRT local service, this will block the caller unitl the server +// Runs the XRT local service, this will block the caller until the server // being stopped.
void RunLocalService(uint64_t service_port); diff --git a/torch_xla/csrc/runtime/sys_util.cpp b/torch_xla/csrc/runtime/sys_util.cpp index 32f4d8916bc..fa23d68a13f 100644 --- a/torch_xla/csrc/runtime/sys_util.cpp +++ b/torch_xla/csrc/runtime/sys_util.cpp @@ -10,12 +10,12 @@ namespace torch_xla { namespace runtime { namespace sys_util { -std::string GetEnvString(const char* name, const std::string& defval) { - const char* env = std::getenv(name); +std::string GetEnvString(const char* const name, const std::string& defval) { + const char* const env = std::getenv(name); return env != nullptr ? env : defval; } -std::string GetEnvOrdinalPath(const char* name, const std::string& defval, +std::string GetEnvOrdinalPath(const char* const name, const std::string& defval, const int64_t ordinal) { std::string path = GetEnvString(name, defval); if (!path.empty()) { @@ -26,23 +26,23 @@ std::string GetEnvOrdinalPath(const char* name, const std::string& defval, return path; } -std::string GetEnvOrdinalPath(const char* name, const std::string& defval, - const char* ordinal_env) { +std::string GetEnvOrdinalPath(const char* const name, const std::string& defval, + const char* const ordinal_env) { return GetEnvOrdinalPath(name, defval, GetEnvInt(ordinal_env, -1)); } -int64_t GetEnvInt(const char* name, int64_t defval) { - const char* env = std::getenv(name); +int64_t GetEnvInt(const char* const name, const int64_t defval) { + const char* const env = std::getenv(name); return env != nullptr ? std::atol(env) : defval; } -double GetEnvDouble(const char* name, double defval) { - const char* env = std::getenv(name); +double GetEnvDouble(const char* const name, const double defval) { + const char* const env = std::getenv(name); return env != nullptr ? 
std::atof(env) : defval; } -bool GetEnvBool(const char* name, bool defval) { - const char* env = std::getenv(name); +bool GetEnvBool(const char* const name, const bool defval) { + const char* const env = std::getenv(name); if (env == nullptr) { return defval; } @@ -56,7 +56,7 @@ bool GetEnvBool(const char* name, bool defval) { } int64_t NowNs() { - auto now = std::chrono::high_resolution_clock::now(); + const auto now = std::chrono::high_resolution_clock::now(); return std::chrono::duration_cast( now.time_since_epoch()) .count();