Style improvements. #9410

Merged: 1 commit, merged on Jun 26, 2025
35 changes: 0 additions & 35 deletions torch_xla/csrc/runtime/env_vars.cpp
@@ -1,36 +1 @@
#include "torch_xla/csrc/runtime/env_vars.h"

namespace torch_xla {
namespace runtime {
namespace env {

const char* const kEnvNumTpu = "TPU_NUM_DEVICES";
const char* const kEnvNumGpu = "GPU_NUM_DEVICES";
const char* const kEnvNumCpu = "CPU_NUM_DEVICES";
const char* const kEnvTpuvmMode = "TPUVM_MODE";
const char* const kEnvPjRtDevice = "PJRT_DEVICE";
const char* const kEnvPjRtTpuMaxInflightComputations =
"PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS";
const char* const kEnvPjrtAsyncCpuClient = "PJRT_CPU_ASYNC_CLIENT";
const char* const kEnvPjrtAsyncGpuClient = "PJRT_GPU_ASYNC_CLIENT";
const char* const kEnvTpuLibraryPath = "TPU_LIBRARY_PATH";
const char* const kEnvInferredTpuLibraryPath = "PTXLA_TPU_LIBRARY_PATH";
const char* const kEnvXpuLibraryPath = "XPU_LIBRARY_PATH";
const char* const kEnvNeuronLibraryPath = "NEURON_LIBRARY_PATH";
const char* const kEnvPjrtDistServiceAddr = "PJRT_DIST_SERVICE_ADDR";
const char* const kEnvPjRtLocalProcessCount = "PJRT_LOCAL_PROCESS_COUNT";
const char* const kEnvPjRtLocalRank = "PJRT_LOCAL_PROCESS_RANK";
const char* const kEnvPjrtAllocatorCudaAsync = "PJRT_ALLOCATOR_CUDA_ASYNC";
const char* const kEnvPjrtAllocatorPreallocate = "PJRT_ALLOCATOR_PREALLOCATE";
const char* const kEnvPjrtAllocatorFraction = "PJRT_ALLOCATOR_FRACTION";
const char* const kEnvPjrtDynamicPlugins = "PJRT_DYNAMIC_PLUGINS";
const char* const kEnvDistSvcHeartbeatIntervalInSec =
"DIST_SERVICE_HEARTBEAT_INTERVAL_IN_SEC";
const char* const kEnvDistSvcMaxMissingHeartbeats =
"DIST_SERVICE_MAX_MISSING_HEARTBEATS";
const char* const kEnvDistSvcShutdownTimeoutInMin =
"DIST_SERVICE_SHUTDOWN_TIMEOUT_IN_MIN";

} // namespace env
} // namespace runtime
} // namespace torch_xla
64 changes: 32 additions & 32 deletions torch_xla/csrc/runtime/env_vars.h
@@ -1,42 +1,42 @@
// Names of environment variables.

#ifndef XLA_CLIENT_ENV_VARS_H_
#define XLA_CLIENT_ENV_VARS_H_

namespace torch_xla {
namespace runtime {
namespace env {

extern const char* const kEnvNumTpu;
extern const char* const kEnvNumGpu;
extern const char* const kEnvNumCpu;
extern const char* const kEnvLocalWorker;
extern const char* const kEnvTpuConfig;
extern const char* const kEnvDeviceMap;
extern const char* const kEnvWorkers;
extern const char* const kEnvMeshService;
extern const char* const kEnvWorldSize;
extern const char* const kEnvMpDevice;
extern const char* const kEnvHostOrdinal;
extern const char* const kEnvShardOrdinal;
extern const char* const kEnvStartService;
extern const char* const kEnvTpuvmMode;
extern const char* const kEnvPjRtDevice;
extern const char* const kEnvPjRtTpuMaxInflightComputations;
extern const char* const kEnvPjrtAsyncCpuClient;
extern const char* const kEnvPjrtAsyncGpuClient;
extern const char* const kEnvTpuLibraryPath;
extern const char* const kEnvInferredTpuLibraryPath;
extern const char* const kEnvXpuLibraryPath;
extern const char* const kEnvNeuronLibraryPath;
extern const char* const kEnvPjrtDistServiceAddr;
extern const char* const kEnvPjRtLocalProcessCount;
extern const char* const kEnvPjRtLocalRank;
extern const char* const kEnvPjrtAllocatorCudaAsync;
extern const char* const kEnvPjrtAllocatorPreallocate;
extern const char* const kEnvPjrtAllocatorFraction;
extern const char* const kEnvPjrtDynamicPlugins;
extern const char* const kEnvDistSvcHeartbeatIntervalInSec;
extern const char* const kEnvDistSvcMaxMissingHeartbeats;
extern const char* const kEnvDistSvcShutdownTimeoutInMin;
inline constexpr char kEnvLocalWorker[] = "LOCAL_WORKER";
inline constexpr char kEnvTpuConfig[] = "TPU_CONFIG";
inline constexpr char kEnvNumTpu[] = "TPU_NUM_DEVICES";
inline constexpr char kEnvNumGpu[] = "GPU_NUM_DEVICES";
inline constexpr char kEnvNumCpu[] = "CPU_NUM_DEVICES";
inline constexpr char kEnvTpuvmMode[] = "TPUVM_MODE";
inline constexpr char kEnvPjRtDevice[] = "PJRT_DEVICE";
inline constexpr char kEnvPjRtTpuMaxInflightComputations[] =
"PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS";
inline constexpr char kEnvPjrtAsyncCpuClient[] = "PJRT_CPU_ASYNC_CLIENT";
inline constexpr char kEnvPjrtAsyncGpuClient[] = "PJRT_GPU_ASYNC_CLIENT";
inline constexpr char kEnvTpuLibraryPath[] = "TPU_LIBRARY_PATH";
inline constexpr char kEnvInferredTpuLibraryPath[] = "PTXLA_TPU_LIBRARY_PATH";
inline constexpr char kEnvXpuLibraryPath[] = "XPU_LIBRARY_PATH";
inline constexpr char kEnvNeuronLibraryPath[] = "NEURON_LIBRARY_PATH";
inline constexpr char kEnvPjrtDistServiceAddr[] = "PJRT_DIST_SERVICE_ADDR";
inline constexpr char kEnvPjRtLocalProcessCount[] = "PJRT_LOCAL_PROCESS_COUNT";
inline constexpr char kEnvPjRtLocalRank[] = "PJRT_LOCAL_PROCESS_RANK";
inline constexpr char kEnvPjrtAllocatorCudaAsync[] =
"PJRT_ALLOCATOR_CUDA_ASYNC";
inline constexpr char kEnvPjrtAllocatorPreallocate[] =
"PJRT_ALLOCATOR_PREALLOCATE";
inline constexpr char kEnvPjrtAllocatorFraction[] = "PJRT_ALLOCATOR_FRACTION";
inline constexpr char kEnvPjrtDynamicPlugins[] = "PJRT_DYNAMIC_PLUGINS";
inline constexpr char kEnvDistSvcHeartbeatIntervalInSec[] =
"DIST_SERVICE_HEARTBEAT_INTERVAL_IN_SEC";
inline constexpr char kEnvDistSvcMaxMissingHeartbeats[] =
"DIST_SERVICE_MAX_MISSING_HEARTBEATS";
inline constexpr char kEnvDistSvcShutdownTimeoutInMin[] =
"DIST_SERVICE_SHUTDOWN_TIMEOUT_IN_MIN";

} // namespace env
} // namespace runtime
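
The header change above replaces the extern declarations (and their out-of-line definitions in env_vars.cpp, deleted earlier in this diff) with C++17 inline constexpr arrays, so each constant is defined exactly once, directly in the header. A minimal sketch of the same idiom, using a hypothetical constant name rather than one from this PR:

// example_env.h -- illustrative header, not part of this PR.
#pragma once

namespace example {
// C++17 inline variable: every translation unit that includes this header
// shares the same definition, so no companion .cpp file is needed.
inline constexpr char kEnvExampleDevice[] = "EXAMPLE_DEVICE";
}  // namespace example

// example_env_use.cpp -- hypothetical caller.
#include <cstdio>
#include <cstdlib>
#include "example_env.h"

int main() {
  const char* value = std::getenv(example::kEnvExampleDevice);
  std::printf("%s=%s\n", example::kEnvExampleDevice, value ? value : "<unset>");
  return 0;
}

Because the arrays are constexpr, they are also usable in constant expressions, which the previous const char* const form did not allow.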
33 changes: 18 additions & 15 deletions torch_xla/csrc/runtime/runtime.cpp
@@ -1,7 +1,8 @@
#include "torch_xla/csrc/runtime/runtime.h"

#include <torch/csrc/lazy/backend/backend_device.h>

#include "absl/log/absl_check.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/ifrt_computation_client.h"
@@ -10,12 +11,17 @@

namespace torch_xla::runtime {

std::atomic<bool> g_computation_client_initialized(false);
static std::atomic<bool> g_computation_client_initialized(false);

// Creates a new instance of a `ComputationClient` (e.g.
// `PjRtComputationClient`), and initializes the computation client
// `PjRtComputationClient`), and initializes the computation client.
// Can only be called when g_computation_client_initialized is false.
static absl::StatusOr<ComputationClient * absl_nonnull>
InitializeComputationClient() {
ABSL_CHECK(!g_computation_client_initialized)
<< "InitializeComputationClient() can only be called once.";
g_computation_client_initialized = true;

if (sys_util::GetEnvBool("XLA_DUMP_FATAL_STACK", false)) {
tsl::testing::InstallStacktraceHandler();
}
@@ -25,27 +31,24 @@ InitializeComputationClient() {
//
// static bool use_ifrt = sys_util::GetEnvBool("XLA_USE_IFRT", false);
const bool use_ifrt = false;
if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") {
auto* client =
(use_ifrt)
? static_cast<ComputationClient*>(new IfrtComputationClient())
: static_cast<ComputationClient*>(new PjRtComputationClient());
g_computation_client_initialized = true;
return client;
} else {
if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") == "") {
return absl::FailedPreconditionError("$PJRT_DEVICE is not set.");
}

if (use_ifrt) {
return new IfrtComputationClient();
}
return new PjRtComputationClient();
}

absl::StatusOr<ComputationClient * absl_nonnull> GetComputationClient() {
const absl::StatusOr<ComputationClient * absl_nonnull>& GetComputationClient() {
// Reference to singleton Status-wrapped ComputationClient instance.
//
// Since we only allow a single initialization, as soon as this function is
// called, we store the initialization result in this trivially destructible
// reference.
static auto& maybe_client =
*new absl::StatusOr<ComputationClient * absl_nonnull>(
InitializeComputationClient());
static const auto& maybe_client =
*new absl::StatusOr<ComputationClient*>(InitializeComputationClient());
return maybe_client;
}
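
The rewritten GetComputationClient() above returns a reference to a StatusOr that is heap-allocated once and intentionally never destroyed, so the reference stays valid even during static destruction at shutdown. A self-contained sketch of that idiom, with a stand-in Widget type instead of ComputationClient (all names here are illustrative, not from this PR):

// leaky_singleton_example.cpp -- illustrative only.
#include <string>
#include "absl/status/statusor.h"

struct Widget {  // stand-in for ComputationClient
  std::string name;
};

absl::StatusOr<Widget*> MakeWidget() {  // stand-in for InitializeComputationClient()
  return new Widget{"example"};
}

const absl::StatusOr<Widget*>& GetWidget() {
  // Thread-safe, one-time initialization (C++11 function-local static).
  // The StatusOr is leaked on purpose: it stays reachable for the whole
  // program lifetime and never runs a destructor at exit.
  static const auto& instance = *new absl::StatusOr<Widget*>(MakeWidget());
  return instance;
}

Returning a const reference also means that repeated calls after a failed initialization keep reporting the original error instead of retrying.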

13 changes: 8 additions & 5 deletions torch_xla/csrc/runtime/runtime.h
@@ -1,12 +1,14 @@
#ifndef XLA_CLIENT_RUNTIME_H_
#define XLA_CLIENT_RUNTIME_H_

#include "absl/base/attributes.h"
#include "absl/status/statusor.h"
#include "torch_xla/csrc/runtime/computation_client.h"

namespace torch_xla::runtime {

// Returns the ComputationClient singleton.
absl::StatusOr<ComputationClient * absl_nonnull> GetComputationClient();
const absl::StatusOr<ComputationClient * absl_nonnull>& GetComputationClient();

ABSL_DEPRECATED(
"Use status::GetComputationClient(), instead. "
@@ -15,12 +17,13 @@ ABSL_DEPRECATED(
"safer.")
ComputationClient* absl_nonnull GetComputationClientOrDie();

// Returns the ComputationClient singleton, if successfully initialized.
// Returns a nullptr, if the ComputationClient wasn't initialized yet, or
// if there was an error on initialization.
// Returns the ComputationClient singleton if it was successfully initialized.
// Returns a nullptr if the ComputationClient wasn't initialized yet.
// Throws an exception if the ComputationClient was initialized but the
// initialization failed.
ComputationClient* GetComputationClientIfInitialized();

// Run the XRT local service, this will block the caller unitl the server
// Runs the XRT local service, this will block the caller unitl the server
// being stopped.
void RunLocalService(uint64_t service_port);
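
A hedged caller-side sketch for the Status-based accessor declared above; the control flow and the error message are illustrative, not taken from this PR:

// caller_example.cpp -- illustrative usage, assuming the headers in this PR.
#include <iostream>
#include "torch_xla/csrc/runtime/runtime.h"

void UseClient() {
  const auto& maybe_client = torch_xla::runtime::GetComputationClient();
  if (!maybe_client.ok()) {
    // e.g. $PJRT_DEVICE was not set, or client construction failed.
    std::cerr << "no computation client: " << maybe_client.status() << "\n";
    return;
  }
  torch_xla::runtime::ComputationClient* client = *maybe_client;
  (void)client;  // issue work through the client here
}

GetComputationClientOrDie() remains available but is deprecated in favor of this Status-returning form.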

24 changes: 12 additions & 12 deletions torch_xla/csrc/runtime/sys_util.cpp
@@ -10,12 +10,12 @@ namespace torch_xla {
namespace runtime {
namespace sys_util {

std::string GetEnvString(const char* name, const std::string& defval) {
const char* env = std::getenv(name);
std::string GetEnvString(const char* const name, const std::string& defval) {
const char* const env = std::getenv(name);
return env != nullptr ? env : defval;
}

std::string GetEnvOrdinalPath(const char* name, const std::string& defval,
std::string GetEnvOrdinalPath(const char* const name, const std::string& defval,
const int64_t ordinal) {
std::string path = GetEnvString(name, defval);
if (!path.empty()) {
@@ -26,23 +26,23 @@ std::string GetEnvOrdinalPath(const char* name, const std::string& defval,
return path;
}

std::string GetEnvOrdinalPath(const char* name, const std::string& defval,
const char* ordinal_env) {
std::string GetEnvOrdinalPath(const char* const name, const std::string& defval,
const char* const ordinal_env) {
return GetEnvOrdinalPath(name, defval, GetEnvInt(ordinal_env, -1));
}

int64_t GetEnvInt(const char* name, int64_t defval) {
const char* env = std::getenv(name);
int64_t GetEnvInt(const char* const name, const int64_t defval) {
const char* const env = std::getenv(name);
return env != nullptr ? std::atol(env) : defval;
}

double GetEnvDouble(const char* name, double defval) {
const char* env = std::getenv(name);
double GetEnvDouble(const char* const name, const double defval) {
const char* const env = std::getenv(name);
return env != nullptr ? std::atof(env) : defval;
}

bool GetEnvBool(const char* name, bool defval) {
const char* env = std::getenv(name);
bool GetEnvBool(const char* const name, const bool defval) {
const char* const env = std::getenv(name);
if (env == nullptr) {
return defval;
}
@@ -56,7 +56,7 @@ bool GetEnvBool(const char* name, bool defval) {
}

int64_t NowNs() {
auto now = std::chrono::high_resolution_clock::now();
const auto now = std::chrono::high_resolution_clock::now();
return std::chrono::duration_cast<std::chrono::nanoseconds>(
now.time_since_epoch())
.count();
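
For context, an illustrative caller that combines these env helpers with the constants from env_vars.h; the include path for sys_util.h and the chosen defaults are assumptions for this sketch, not values taken from the PR:

// env_read_example.cpp -- illustrative only.
#include <string>
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/sys_util.h"  // assumed header for the functions above

namespace demo {

void ReadPjrtConfig() {
  namespace env = torch_xla::runtime::env;
  namespace sys_util = torch_xla::runtime::sys_util;

  // Empty-string / false defaults are arbitrary choices for this example.
  const std::string device = sys_util::GetEnvString(env::kEnvPjRtDevice, "");
  const bool async_cpu = sys_util::GetEnvBool(env::kEnvPjrtAsyncCpuClient, false);
  (void)device;
  (void)async_cpu;
}

}  // namespace demo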