From 3666ebb3371c1ff49a9d0a3ed83d5339f9aa6a4b Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 19 Apr 2023 17:51:35 -0700 Subject: [PATCH] Reland of "Python binding to set/get CUDA rng state offset" (#99565) Summary: Why? * To reduce the latency of hot path in https://github.com/pytorch/pytorch/pull/97377 Concern - I had to add `set_offset` in all instances of `GeneratorImpl`. I don't know if there is a better way. ~~~~ import torch torch.cuda.manual_seed(123) print(torch.cuda.get_rng_state()) torch.cuda.set_rng_state_offset(40) print(torch.cuda.get_rng_state()) tensor([123, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.uint8) tensor([123, 0, 0, 0, 0, 0, 0, 0, 40, 0, 0, 0, 0, 0, 0, 0], dtype=torch.uint8) ~~~~ Reland of https://github.com/pytorch/pytorch/pull/98965 (cherry picked from commit 8214fe07e8a200e0fe9ca4264bb6fca985c4911e) X-link: https://github.com/pytorch/pytorch/pull/99565 Reviewed By: anijain2305 Differential Revision: D45130271 Pulled By: malfet fbshipit-source-id: a31caebf1cecd6a6f9d4552d98a2b9a6eb7690e5 --- torchcsprng/csrc/kernels_commons.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchcsprng/csrc/kernels_commons.h b/torchcsprng/csrc/kernels_commons.h index c94167e..3e74d35 100644 --- a/torchcsprng/csrc/kernels_commons.h +++ b/torchcsprng/csrc/kernels_commons.h @@ -39,6 +39,8 @@ struct CSPRNGGeneratorImpl : public c10::GeneratorImpl { void set_state(const c10::TensorImpl& new_state) override { throw std::runtime_error("not implemented"); } c10::intrusive_ptr<c10::TensorImpl> get_state() const override { throw std::runtime_error("not implemented"); } + /* RNG offset set/get is not implemented for this generator: both overrides throw std::runtime_error. */ void set_offset(uint64_t offset) override { throw std::runtime_error("not implemented"); } + uint64_t get_offset() const override { throw std::runtime_error("not implemented"); } bool use_rd_; std::random_device rd_; std::mt19937 mt_;