Commit 5ae7958: WIP
1 parent: 2246912

6 files changed: +72 / -33 lines

codegen/xla_native_functions.yaml
Lines changed: 2 additions & 2 deletions

@@ -35,7 +35,9 @@ full_codegen:
   - cholesky
   - clamp
   - clamp.Tensor
+  - clamp_max
   - clamp_max.Tensor
+  - clamp_min
   - clamp_min.Tensor
   - _conj_copy
   - cos
@@ -184,8 +186,6 @@ supported:
   - cat
   - celu
   - celu_
-  - clamp_max
-  - clamp_min
   - clone
   - constant_pad_nd
   - convolution_backward_overrideable
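
Moving clamp_max and clamp_min from the supported: section to the full_codegen: section means their ATen-level bindings are now emitted by the LTC codegen from this yaml entry; only the Lower() implementations and output-shape functions (added below in ops_lower_fn.cpp and ops_xla_shape_fn.cpp) remain hand-written. For orientation, a minimal sketch of roughly what the codegen emits for a full_codegen op, assuming helper names like GetIrValueForScalar and CreateFrom (the exact generated source may differ):

// Hypothetical sketch of the codegen-emitted binding for clamp_max;
// the real generated file and helper names may differ.
at::Tensor XLANativeFunctions::clamp_max(const at::Tensor& self,
                                         const at::Scalar& max) {
  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
  XLATensorPtr xla_self = bridge::GetXlaTensor(self);
  // Build the ClampMax IR node; ClampMax::Lower() and
  // ClampMaxOutputShape() below supply the XLA lowering and shape.
  torch::lazy::NodePtr node = torch_xla::MakeNode<ClampMax>(
      xla_self->GetIrValue(),
      XLAGraphExecutor::Get()->GetIrValueForScalar(
          max, xla_self->shape(), xla_self->GetDevice()));
  return bridge::AtenFromXlaTensor(xla_self->CreateFrom(node));
}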

test/cpp/BUILD
Lines changed: 32 additions & 17 deletions

@@ -130,22 +130,37 @@ ptxla_cc_test(
     ],
 )
 
+ptxla_cc_test(
+    name = "test_aten_xla_tensor_5",
+    size = "enormous",
+    srcs = ["test_aten_xla_tensor_5.cpp"],
+    deps = [
+        ":cpp_test_util",
+        ":torch_xla_test",
+        "//torch_xla/csrc/runtime:metrics",
+        "//torch_xla/csrc:tensor",
+        "//torch_xla/csrc:aten_cuda_functions",
+        "@com_google_googletest//:gtest_main",
+        "@xla//xla:permutation_util",
+    ],
+)
+
 # This test is very large so it's split into shards.
 # To make it run fast, please add new shards when needed.
-[
-    ptxla_cc_test(
-        name = test[:-len(".cpp")],
-        size = "enormous",
-        srcs = [test],
-        deps = [
-            ":cpp_test_util",
-            ":torch_xla_test",
-            "//torch_xla/csrc/runtime:metrics",
-            "//torch_xla/csrc:tensor",
-            "//torch_xla/csrc:aten_cuda_functions",
-            "@com_google_googletest//:gtest_main",
-            "@xla//xla:permutation_util",
-        ],
-    )
-    for test in glob(["test_aten_xla_tensor*cpp"])
-]
+# [
+#     ptxla_cc_test(
+#         name = test[:-len(".cpp")],
+#         size = "enormous",
+#         srcs = [test],
+#         deps = [
+#             ":cpp_test_util",
+#             ":torch_xla_test",
+#             "//torch_xla/csrc/runtime:metrics",
+#             "//torch_xla/csrc:tensor",
+#             "//torch_xla/csrc:aten_cuda_functions",
+#             "@com_google_googletest//:gtest_main",
+#             "@xla//xla:permutation_util",
+#         ],
+#     )
+#     for test in glob(["test_aten_xla_tensor*cpp"])
+# ]
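
The new shard is materialized as an explicit target while the glob-driven list comprehension that used to generate one target per test_aten_xla_tensor*.cpp file is left commented out, consistent with the commit being marked WIP. Assuming a standard Bazel setup (device/environment flags omitted), the shard would be run on its own with something like:

bazel test //test/cpp:test_aten_xla_tensor_5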

torch_xla/csrc/aten_xla_type.cpp
Lines changed: 0 additions & 14 deletions

@@ -1312,20 +1312,6 @@ at::Tensor& XLANativeFunctions::celu_(at::Tensor& self,
   return self;
 }
 
-at::Tensor XLANativeFunctions::clamp_max(const at::Tensor& self,
-                                         const at::Scalar& max) {
-  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::clamp(bridge::GetXlaTensor(self), std::nullopt, max));
-}
-
-at::Tensor XLANativeFunctions::clamp_min(const at::Tensor& self,
-                                         const at::Scalar& min) {
-  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::clamp(bridge::GetXlaTensor(self), min, std::nullopt));
-}
-
 at::Tensor XLANativeFunctions::clone(
     const at::Tensor& self,
     std::optional<at::MemoryFormat> /* memory_format */) {
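
These hand-written bindings can be deleted because the codegen now emits equivalents from the full_codegen entry above. Note that both old versions funneled into tensor_methods::clamp with one bound left as std::nullopt, which matches the dedicated min/max lowerings introduced below.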

torch_xla/csrc/ops/ops_lower_fn.cpp
Lines changed: 12 additions & 0 deletions

@@ -401,12 +401,24 @@ torch_xla::XlaOpVector ClampTensor::Lower(LoweringContext* loctx) const {
   return ReturnOp(res, loctx);
 }
 
+torch_xla::XlaOpVector ClampMax::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
+  return ReturnOp(xla::Min(xla_input, xla_other), loctx);
+}
+
 torch_xla::XlaOpVector ClampMaxTensor::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
   return ReturnOp(xla::Min(xla_input, xla_other), loctx);
 }
 
+torch_xla::XlaOpVector ClampMin::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
+  return ReturnOp(xla::Max(xla_input, xla_other), loctx);
+}
+
 torch_xla::XlaOpVector ClampMinTensor::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
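
The lowerings rely on the identities clamp_max(x, m) = min(x, m) and clamp_min(x, m) = max(x, m), applied elementwise by xla::Min and xla::Max. A minimal, self-contained scalar sketch of the same identities in plain C++ (illustration only, independent of XLA):

#include <algorithm>
#include <cassert>

int main() {
  float x = 3.0f;
  assert(std::min(x, 2.0f) == 2.0f);  // clamp_max(3, 2) -> 2
  assert(std::max(x, 5.0f) == 5.0f);  // clamp_min(3, 5) -> 5
  assert(std::min(x, 4.0f) == 3.0f);  // clamp_max is a no-op when x <= max
  return 0;
}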

torch_xla/csrc/ops/ops_xla_shape_fn.cpp
Lines changed: 20 additions & 0 deletions

@@ -454,6 +454,16 @@ xla::Shape ClampTensorOutputShape(
   return InferOutputShape(shapes, lower_for_shape_fn);
 }
 
+xla::Shape ClampMaxOutputShape(const torch::lazy::Value& input,
+                               const torch::lazy::Value& other) {
+  auto lower_for_shape_fn =
+      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return xla::Min(operands[0], operands[1]);
+  };
+  return InferOutputShape({GetXlaShape(input), GetXlaShape(other)},
+                          lower_for_shape_fn);
+}
+
 xla::Shape ClampMaxTensorOutputShape(const torch::lazy::Value& input,
                                      const torch::lazy::Value& other) {
   auto lower_for_shape_fn =
@@ -464,6 +474,16 @@ xla::Shape ClampMaxTensorOutputShape(const torch::lazy::Value& input,
                           lower_for_shape_fn);
 }
 
+xla::Shape ClampMinOutputShape(const torch::lazy::Value& input,
+                               const torch::lazy::Value& other) {
+  auto lower_for_shape_fn =
+      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return xla::Max(operands[0], operands[1]);
+  };
+  return InferOutputShape({GetXlaShape(input), GetXlaShape(other)},
+                          lower_for_shape_fn);
+}
+
 xla::Shape ClampMinTensorOutputShape(const torch::lazy::Value& input,
                                      const torch::lazy::Value& other) {
   auto lower_for_shape_fn =
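
Both helpers follow the existing InferOutputShape pattern: the same Min/Max lambda used in the lowering is run on placeholder operands of the input shapes, and the resulting XLA shape is read back. This keeps the scalar and tensor variants consistent, since the output shape comes from the XLA builder itself rather than being restated by hand.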

torch_xla/csrc/ops/ops_xla_shape_fn.h
Lines changed: 6 additions & 0 deletions

@@ -112,9 +112,15 @@ xla::Shape ClampTensorOutputShape(const torch::lazy::Value& input,
                                   const std::optional<torch::lazy::Value>& min,
                                   const std::optional<torch::lazy::Value>& max);
 
+xla::Shape ClampMaxOutputShape(const torch::lazy::Value& input,
+                               const torch::lazy::Value& target);
+
 xla::Shape ClampMaxTensorOutputShape(const torch::lazy::Value& input,
                                      const torch::lazy::Value& target);
 
+xla::Shape ClampMinOutputShape(const torch::lazy::Value& input,
+                               const torch::lazy::Value& target);
+
 xla::Shape ClampMinTensorOutputShape(const torch::lazy::Value& input,
                                      const torch::lazy::Value& target);
