diff --git a/lib/THCUNN/PReLU.cu b/lib/THCUNN/PReLU.cu index 395e4a1e..0a286b76 100644 --- a/lib/THCUNN/PReLU.cu +++ b/lib/THCUNN/PReLU.cu @@ -68,7 +68,7 @@ struct PReLUAccGradParametersShared { __device__ __forceinline__ void operator()(T *gradInput, T *input, T *gradOutput) { - *gradInput = (*input) * (*gradOutput) * (*input <= 0); + *gradInput = *input <= 0 ? (*input) * (*gradOutput) : 0; } }; @@ -83,7 +83,7 @@ struct PReLUAccGradParameters __device__ __forceinline__ void operator()(T *gradInput, T *input, T *gradOutput) { - *gradInput = (*input) * (*gradOutput) * scale * (*input <= 0); + *gradInput = *input <= 0 ? (*input) * (*gradOutput) * scale : 0; } }; @@ -98,7 +98,7 @@ struct PReLUAccGradParameters1to1 __device__ __forceinline__ void operator()(T *gradWeight, T *input, T *gradOutput) { - *gradWeight += (*input) * (*gradOutput) * scale * (*input <= 0); + *gradWeight += *input <= 0 ? (*input) * (*gradOutput) * scale : 0; } };