Skip to content

Commit 328cafa

Browse files
committed
refactor lpc_cpu
1 parent 675c23a commit 328cafa

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed

torchlpc/csrc/scan_cpu.cpp

Lines changed: 13 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -103,24 +103,22 @@ void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out)
103103
const scalar_t *a_ptr = a_contiguous.const_data_ptr<scalar_t>();
104104
scalar_t *out_ptr = padded_out.mutable_data_ptr<scalar_t>();
105105

106-
// at::parallel_for(0, B, 1, [&](int64_t start, int64_t end)
107-
// {
108-
#pragma omp parallel for
109-
for (auto b = 0; b < B; b++)
110-
{
111-
auto out_offset = b * (T + order) + order;
112-
auto a_offset = b * T * order;
113-
for (int64_t t = 0; t < T; t++)
106+
at::parallel_for(0, B, 1, [&](int64_t start, int64_t end)
107+
{
108+
for (auto b = start; b < end; b++)
114109
{
115-
scalar_t y = out_ptr[out_offset + t];
116-
for (int64_t i = 0; i < order; i++)
110+
auto out_offset = out_ptr + b * (T + order) + order;
111+
auto a_offset = a_ptr + b * T * order;
112+
for (int64_t t = 0; t < T; t++)
117113
{
118-
y -= a_ptr[a_offset + t * order + i] *
119-
out_ptr[out_offset + t - i - 1];
114+
scalar_t y = out_offset[t];
115+
for (int64_t i = 0; i < order; i++)
116+
{
117+
y -= a_offset[t * order + i] * out_offset [t - i - 1];
118+
}
119+
out_offset[t] = y;
120120
}
121-
out_ptr[out_offset + t] = y;
122-
}
123-
};
121+
}; });
124122
}
125123

126124
at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,

0 commit comments

Comments
 (0)