From 95bf220beb51e9f486d386da533bab82798e795e Mon Sep 17 00:00:00 2001
From: Masaru Kimura
Date: Fri, 30 May 2025 15:19:26 +0900
Subject: [PATCH 1/3] Fix torch.jit.ScriptModule.zero_grad.

TorchSharp 0.105.0 does not implement torch.jit.ScriptModule.zero_grad, so
calls incorrectly fall back to torch.nn.Module.zero_grad and then terminate
silently, most probably because JITModule is not compatible with NNModule in
LibTorchSharp. As reported in https://github.com/pytorch/pytorch/issues/27144,
libtorch has no torch::jit::Module::zero_grad either.

As a workaround, manually loop over the parameters and zero out their
gradients, the same way the optimizer does.
---
 src/Native/LibTorchSharp/THSJIT.cpp            | 17 +++++++++++++++++
 src/Native/LibTorchSharp/THSJIT.h              |  1 +
 src/TorchSharp/JIT/ScriptModule.cs             | 17 +++++++++++++++++
 src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs |  3 +++
 4 files changed, 38 insertions(+)

diff --git a/src/Native/LibTorchSharp/THSJIT.cpp b/src/Native/LibTorchSharp/THSJIT.cpp
index a0a4a5d0c..56b161895 100644
--- a/src/Native/LibTorchSharp/THSJIT.cpp
+++ b/src/Native/LibTorchSharp/THSJIT.cpp
@@ -68,6 +68,23 @@ int THSJIT_Module_is_training(JITModule module)
     return (*module)->is_training();
 }
 
+void THSJIT_Module_zero_grad(const JITModule module, bool set_to_none)
+{
+    // According to https://github.com/pytorch/pytorch/issues/27144,
+    // torch::jit::Module has no zero_grad().
+    // As a workaround, manually loop over the parameters and zero them out like optimizer does;
+    // https://github.com/pytorch/pytorch/blob/v2.5.1/torch/csrc/api/src/optim/optimizer.cpp#L123
+    for (auto& p : (*module)->parameters()) {
+        if (p.mutable_grad().defined()) {
+            p.mutable_grad().detach_();
+            if (set_to_none)
+                p.mutable_grad().reset();
+            else
+                p.mutable_grad().zero_();
+        }
+    }
+}
+
 void THSJIT_Module_train(JITModule module, bool on)
 {
     (*module)->train(on);
diff --git a/src/Native/LibTorchSharp/THSJIT.h b/src/Native/LibTorchSharp/THSJIT.h
index 81e6d51ad..25d7cea32 100644
--- a/src/Native/LibTorchSharp/THSJIT.h
+++ b/src/Native/LibTorchSharp/THSJIT.h
@@ -44,6 +44,7 @@ EXPORT_API(void) THSJIT_Module_invoke(const JITModule module, const char* name,
 EXPORT_API(void) THSJIT_CompilationUnit_Invoke(const JITCompilationUnit module, const char* method, const TensorOrScalar* tensorPtrs, const int length, TensorOrScalar* (*allocator)(int32_t idx, size_t length), int8_t* typeCode, int32_t idx);
 
 EXPORT_API(int) THSJIT_Module_is_training(JITModule module);
+EXPORT_API(void) THSJIT_Module_zero_grad(const JITModule module, bool set_to_none);
 EXPORT_API(void) THSJIT_Module_train(JITModule module, bool on);
 EXPORT_API(void) THSJIT_Module_eval(JITModule module);
 
diff --git a/src/TorchSharp/JIT/ScriptModule.cs b/src/TorchSharp/JIT/ScriptModule.cs
index 14e5d4773..7166febeb 100644
--- a/src/TorchSharp/JIT/ScriptModule.cs
+++ b/src/TorchSharp/JIT/ScriptModule.cs
@@ -143,6 +143,23 @@ public override bool training {
                 }
             }
 
+            public override void zero_grad(bool set_to_none = true)
+            {
+                THSJIT_Module_zero_grad(handle, set_to_none);
+                CheckForErrors();
+
+                foreach (var (_, p) in named_parameters()) {
+                    using var grad = p.grad;
+                    if (grad is not null) {
+                        if (set_to_none) {
+                            p.grad = null;
+                        } else {
+                            grad.zero_();
+                        }
+                    }
+                }
+            }
+
             protected internal override nn.Module _to(Device device, ScalarType dtype, bool non_blocking)
             {
                 if (device.type != DeviceType.CUDA) { device = new Device(device.type, -1); };
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs
index 074fcc247..4cdc25e82 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs
@@ -57,6 +57,9 @@ internal static partial class NativeMethods
         [return: MarshalAs(UnmanagedType.U1)]
         internal static extern bool THSJIT_Module_is_training(torch.nn.Module.HType module);
 
+        [DllImport("LibTorchSharp")]
+        internal static extern void THSJIT_Module_zero_grad(torch.nn.Module.HType module, [MarshalAs(UnmanagedType.U1)] bool set_to_none);
+
         [DllImport("LibTorchSharp")]
         internal static extern void THSJIT_Module_to_device(torch.nn.Module.HType module, long deviceType, long deviceIndex);
 
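To make the user-visible effect of this patch concrete, here is a minimal C# sketch (not part of the patch) of how the fixed override is expected to behave. It assumes a TorchScript module exported from Python to a hypothetical file "model.pt", an arbitrary 8x10 float input, and the generic jit.load/call surface of TorchSharp 0.105:

    using TorchSharp;
    using static TorchSharp.torch;

    // Load a TorchScript module traced/scripted in Python (hypothetical path).
    using var model = jit.load<Tensor, Tensor>("model.pt");
    model.train();

    using var input = rand(8, 10);
    using var loss = model.call(input).sum();
    loss.backward();

    // With the fix, this dispatches to the ScriptModule override above
    // instead of falling through to the incompatible nn.Module fallback
    // and terminating silently.
    model.zero_grad(set_to_none: true);

    foreach (var (_, p) in model.named_parameters()) {
        // set_to_none: true detaches each gradient and drops it entirely.
        System.Diagnostics.Debug.Assert(p.grad is null);
    }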
From 03d5cf5d5e73f02b469cbffce3c2d96ec3cadbbb Mon Sep 17 00:00:00 2001
From: Masaru Kimura
Date: Fri, 30 May 2025 17:24:08 +0900
Subject: [PATCH 2/3] Try to fix build failure on Ubuntu and macOS.

Most likely, torch::jit::Module::parameters() yields tensors by value, and
GCC/Clang reject binding such a temporary to a non-const lvalue reference
(MSVC accepts it as an extension), so take the loop variable by const
reference instead.
---
 src/Native/LibTorchSharp/THSJIT.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/Native/LibTorchSharp/THSJIT.cpp b/src/Native/LibTorchSharp/THSJIT.cpp
index 56b161895..8a40aa75b 100644
--- a/src/Native/LibTorchSharp/THSJIT.cpp
+++ b/src/Native/LibTorchSharp/THSJIT.cpp
@@ -74,7 +74,7 @@ void THSJIT_Module_zero_grad(const JITModule module, bool set_to_none)
     // torch::jit::Module has no zero_grad().
     // As a workaround, manually loop over the parameters and zero them out like optimizer does;
     // https://github.com/pytorch/pytorch/blob/v2.5.1/torch/csrc/api/src/optim/optimizer.cpp#L123
-    for (auto& p : (*module)->parameters()) {
+    for (const auto& p : (*module)->parameters()) {
         if (p.mutable_grad().defined()) {
             p.mutable_grad().detach_();
             if (set_to_none)
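For completeness, a companion sketch under the same assumptions as the one after PATCH 1/3, showing the other branch: with set_to_none: false, both the managed loop and the detach_()/zero_() branch of THSJIT_Module_zero_grad keep each gradient allocated but zero-filled instead of dropping it:

    using var input = rand(8, 10);
    using var loss = model.call(input).sum();
    loss.backward();

    model.zero_grad(set_to_none: false);

    foreach (var (_, p) in model.named_parameters()) {
        using var grad = p.grad;
        // Gradients remain defined; every element has been zeroed in place.
        if (grad is not null)
            System.Diagnostics.Debug.Assert(grad.count_nonzero().item<long>() == 0);
    }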
From 05028a35e4bb0fff4d98f6cb201bce205995cf8f Mon Sep 17 00:00:00 2001
From: Masaru Kimura
Date: Thu, 3 Jul 2025 09:53:07 +0900
Subject: [PATCH 3/3] Update RELEASENOTES.md.

---
 RELEASENOTES.md | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/RELEASENOTES.md b/RELEASENOTES.md
index 3e9c01a83..5c4ddba66 100644
--- a/RELEASENOTES.md
+++ b/RELEASENOTES.md
@@ -1,6 +1,12 @@
 ## TorchSharp Release Notes
 
 Releases, starting with 9/2/2021, are listed with the most recent release at the top.
+# NuGet Version 0.105.2
+
+__API Changes__:
+
+Fix torch.jit.ScriptModule.zero_grad.
+
 # NuGet Version 0.105.1
 
 __Bug Fixes__: