diff --git a/graph.py b/graph.py new file mode 100644 index 0000000000..5d47261674 --- /dev/null +++ b/graph.py @@ -0,0 +1,1304 @@ +# this version adds activation_manager context for offloading activations. + +import abc +import contextlib +import functools +import logging +import threading +from collections import defaultdict, deque +from collections.abc import Generator, Iterable, Iterator, MutableMapping, Sequence +from typing import ( + Any, + Callable, + cast, + Literal, + NamedTuple, + Optional, + TYPE_CHECKING, + Union, +) +from typing_extensions import TypeAlias +from weakref import WeakKeyDictionary, WeakValueDictionary + +import torch +from torch.autograd.variable import Variable +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils.hooks import RemovableHandle + + +if TYPE_CHECKING: + from torch._ops import OpOverload + + +__all__ = [ + "saved_tensors_hooks", + "save_on_cpu", + "disable_saved_tensors_hooks", + "register_multi_grad_hook", + "allow_mutation_on_saved_tensors", + "Node", + "GradientEdge", + "get_gradient_edge", + "increment_version", +] + + +log = logging.getLogger(__name__) + + +class Node(abc.ABC): + @abc.abstractmethod + def name(self) -> str: + r"""Return the name. + + Example:: + + >>> import torch + >>> a = torch.tensor([0., 0., 0.], requires_grad=True) + >>> b = a.clone() + >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node) + >>> print(b.grad_fn.name()) + CloneBackward0 + """ + raise NotImplementedError + + @property + @abc.abstractmethod + def next_functions(self) -> tuple[tuple[Optional["Node"], int], ...]: + raise NotImplementedError + + @abc.abstractmethod + def metadata(self) -> dict: + r"""Return the metadata.""" + raise NotImplementedError + + @property + @abc.abstractmethod + def _input_metadata(self) -> list[Any]: + raise NotImplementedError + + @abc.abstractmethod + def _register_hook_dict(self, tensor: torch.Tensor) -> None: + raise NotImplementedError + + @abc.abstractmethod + def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle: + r"""Register a backward hook. + + The hook will be called every time a gradient with respect to the + Node is computed. The hook should have the following signature:: + + hook(grad_inputs: Tuple[Tensor], grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None + + + The hook should not modify its argument, but it can optionally return + a new gradient which will be used in place of :attr:`grad_inputs`. + + This function returns a handle with a method ``handle.remove()`` + that removes the hook from the module. + + .. note:: + See :ref:`backward-hooks-execution` for more information on how when this hook + is executed, and how its execution is ordered relative to other hooks. + + .. note:: + In the rare case where the hook is registered while the Node has already + begun execution, there is no longer any guarantee on :attr:`grad_outputs` + content (it might be as usual or empty depending on other factors). The + hook can still optionally return a new gradient to be used in place of + :attr:`grad_inputs` independent of :attr:`grad_outputs`. 
+ + Example:: + + >>> import torch + >>> a = torch.tensor([0., 0., 0.], requires_grad=True) + >>> b = a.clone() + >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node) + >>> handle = b.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,)) + >>> b.sum().backward(retain_graph=True) + >>> print(a.grad) + tensor([2., 2., 2.]) + >>> handle.remove() # Removes the hook + >>> a.grad = None + >>> b.sum().backward(retain_graph=True) + >>> print(a.grad) + tensor([1., 1., 1.]) + """ + raise NotImplementedError + + @abc.abstractmethod + def register_prehook(self, fn: Callable[..., Any]) -> RemovableHandle: + r"""Register a backward pre-hook. + + The hook will be called every time a gradient with respect to the + Node is computed. The hook should have the following signature:: + + hook(grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None + + The hook should not modify its argument, but it can optionally return + a new gradient which will be used in place of :attr:`grad_outputs`. + + This function returns a handle with a method ``handle.remove()`` + that removes the hook from the module. + + .. note:: + See :ref:`backward-hooks-execution` for more information on how when this hook + is executed, and how its execution is ordered relative to other hooks. + + Example:: + + >>> a = torch.tensor([0., 0., 0.], requires_grad=True) + >>> b = a.clone() + >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node) + >>> handle = b.grad_fn.register_prehook(lambda gI: (gI[0] * 2,)) + >>> b.sum().backward(retain_graph=True) + >>> print(a.grad) + tensor([2., 2., 2.]) + >>> handle.remove() + >>> a.grad = None + >>> b.sum().backward(retain_graph=True) + >>> print(a.grad) + tensor([1., 1., 1.]) + """ + raise NotImplementedError + + @classmethod + def __subclasshook__(cls, subclass: type) -> bool: + if cls is Node and ( + ( + subclass is not None + and subclass is getattr(torch._C._functions, subclass.__name__, None) + ) + or issubclass(subclass, torch.autograd.function.BackwardCFunction) + ): + return True + return NotImplemented + + +def _get_grad_fn_or_grad_acc(t: Union[torch.Tensor, "GradientEdge"]) -> Node: + if isinstance(t, GradientEdge): + return t.node + if t.requires_grad and t.grad_fn is None: + with torch.enable_grad(): + node = t.view_as(t).grad_fn.next_functions[0][0] # type: ignore[union-attr] + else: + node = t.grad_fn + assert node is not None + return node + + +class GradientEdge(NamedTuple): + """Object representing a given gradient edge within the autograd graph. + + To get the gradient edge where a given Tensor gradient will be computed, + you can do ``edge = autograd.graph.get_gradient_edge(tensor)``. + """ + + node: Node + output_nr: int + + +def get_gradient_edge(tensor: torch.Tensor) -> GradientEdge: + """Get the gradient edge for computing the gradient of the given Tensor. + + In particular, it is equivalent to call + ``g = autograd.grad(loss, input)`` and ``g = autograd.grad(loss, get_gradient_edge(input))``. + """ + if not tensor.requires_grad: + raise RuntimeError( + "It is not possible to get the gradient edge for a Tensor " + "that does not require gradients", + ) + grad_fn = _get_grad_fn_or_grad_acc(tensor) + + # Note that output_nr default to 0 which is the right value + # for the AccumulateGrad node. + return GradientEdge(grad_fn, tensor.output_nr) + + +def increment_version(tensor: Union[torch.Tensor, Iterable[torch.Tensor]]) -> None: + """Update autograd metadata tracking whether the given Tensor was modified in place. 
+ + This is to enable more accurate error checking within the autograd engine. + It is already done automatically by PyTorch functions and within custom Function + when mark_dirty() is called appropriately so you only need to call this explicitly + if you are doing inplace operation on the Tensor data in a way that Pytorch doesn't + know about. For example a custom kernel that reads the Tensor data_ptr and modifies + the memory inplace based on this pointer. Can accept either a tensor, or a list of tensors. + + Note that incrementing the version counter multiple times for a single inplace operation + is not problematic. + + Note that if you pass in tensor constructed under torch.inference_mode(), + we will not bump its version counter (because your tensor does not have one). + """ + if isinstance(tensor, torch.Tensor): + tensor = (tensor,) + torch._C._increment_version(tensor) + + +class saved_tensors_hooks: + """Context-manager that sets a pair of pack / unpack hooks for saved tensors. + + Use this context-manager to define how intermediary results of an operation + should be packed before saving, and unpacked on retrieval. + + In that context, the ``pack_hook`` function will be called everytime an + operation saves a tensor for backward (this includes intermediary results + saved using + :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but + also those recorded by a PyTorch-defined operation). The output of + ``pack_hook`` is then stored in the computation graph instead of the + original tensor. + + The ``unpack_hook`` is called when the saved tensor needs to be accessed, + namely when executing :func:`torch.Tensor.backward()` or + :func:`torch.autograd.grad()`. It takes as argument the *packed* object + returned by ``pack_hook`` and should return a tensor which has the same + content as the original tensor (passed as input to the corresponding + ``pack_hook``). + + The hooks should have the following signatures: + + pack_hook(tensor: Tensor) -> Any + + unpack_hook(Any) -> Tensor + + where the return value of ``pack_hook`` is a valid input to ``unpack_hook``. + + In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms + of value, size, dtype and device. + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) + >>> def pack_hook(x): + ... print("Packing", x) + ... return x.detach() + >>> + >>> def unpack_hook(x): + ... print("Unpacking", x) + ... return x + >>> + >>> a = torch.ones(5, requires_grad=True) + >>> b = torch.ones(5, requires_grad=True) * 2 + >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): + ... y = a * b + Packing tensor([1., 1., 1., 1., 1.], requires_grad=True) + Packing tensor([2., 2., 2., 2., 2.], grad_fn=) + >>> y.sum().backward() + Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True) + Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=) + + .. warning :: + Performing an inplace operation on the input to either hooks may lead + to undefined behavior. + + .. warning :: + Only one pair of hooks is allowed at a time. When recursively nesting this + context-manager, only the inner-most pair of hooks will be applied. + + .. warning :: + To avoid reference cycle, the return value of ``pack_hook`` cannot hold a + reference to the input tensor. For example, use `lambda x: x.detach()` + instead of `lambda x: x` as the pack hook. 
+ """ + + def __init__( + self, + pack_hook: Callable[[torch.Tensor], Any], + unpack_hook: Callable[[Any], torch.Tensor], + ) -> None: + self.pack_hook = pack_hook + self.unpack_hook = unpack_hook + + def __enter__(self) -> None: + torch._C._autograd._push_saved_tensors_default_hooks( + self.pack_hook, self.unpack_hook + ) + + def __exit__(self, *args: object) -> None: + torch._C._autograd._pop_saved_tensors_default_hooks() + + +class save_on_cpu(saved_tensors_hooks): + """Context manager under which tensors saved by the forward pass will be stored on cpu, then retrieved for backward. + + When performing operations within this context manager, intermediary + results saved in the graph during the forward pass will be moved to CPU, + then copied back to the original device when needed for the backward pass. + If the graph was already on CPU, no tensor copy is performed. + + Use this context-manager to trade compute for GPU memory usage (e.g. + when your model doesn't fit in GPU memory during training). + + Args: + pin_memory (bool): If ``True`` tensors will be saved to CPU pinned memory + during packing and copied to GPU asynchronously during unpacking. + Defaults to ``False``. + Also see :ref:`cuda-memory-pinning`. + + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) + >>> a = torch.randn(5, requires_grad=True, device="cuda") + >>> b = torch.randn(5, requires_grad=True, device="cuda") + >>> c = torch.randn(5, requires_grad=True, device="cuda") + >>> + >>> def f(a, b, c): + ... prod_1 = a * b # a and b are saved on GPU + ... with torch.autograd.graph.save_on_cpu(): + ... prod_2 = prod_1 * c # prod_1 and c are saved on CPU + ... y = prod_2 * a # prod_2 and a are saved on GPU + ... 
return y + >>> + >>> y = f(a, b, c) + >>> del a, b, c # for illustration only + >>> # the content of a, b, and prod_2 are still alive on GPU + >>> # the content of prod_1 and c only live on CPU + >>> y.sum().backward() # all CPU tensors are moved back to GPU, for backward + >>> # all intermediary tensors are released (deleted) after the call to backward + """ + + def __init__(self, pin_memory: bool = False, device_type: str = "cuda") -> None: + device_module = getattr(torch, device_type, torch.cuda) + + def pack_to_cpu(tensor: torch.Tensor) -> tuple[torch.device, torch.Tensor]: + if not pin_memory: + return (tensor.device, tensor.cpu()) + packed = torch.empty( + tensor.size(), + dtype=tensor.dtype, + layout=tensor.layout, + pin_memory=(device_module.is_available() and not tensor.is_sparse), + ) + packed.copy_(tensor) + return (tensor.device, packed) + + def unpack_from_cpu(packed: tuple[torch.device, torch.Tensor]) -> torch.Tensor: + device, tensor = packed + return tensor.to(device, non_blocking=pin_memory) + + super().__init__(pack_to_cpu, unpack_from_cpu) + +# ================= activation manager ================= + +import time +import secrets + +def perf_timer(func): + def wrapper(*args, **kwargs): + start = time.perf_counter() + output = func(*args, **kwargs) + elapsed_time = time.perf_counter() - start + print(elapsed_time) + return output, elapsed_time + + return wrapper + +class manage_activations(saved_tensors_hooks): + """Context manager under which activation tensors created in the forward pass will be managed + """ + + def __init__(self, pin_memory: bool = True, device_type: str = "cuda") -> None: + device_module = getattr(torch, device_type, torch.cuda) + + self.caching: bool = True # now we're managing cpu cached memory blocks + self.min_tensor_size_bytes = 1024 # we don't want to bother with small tensors + self.tracker = {} # tensor_id = (new_tensor, dtype, if_modified, device) ---> track what saved/offloaded/compressed tensors, are where + self.mem_offload_cache = {} # cache of available memory blocks for tensors, keyed by (shape, dtype) + self.gb = 1024 * 1024 * 1024 # bytes in a gigabyte + self.ignore_types = [torch.complex64, torch.int64] # high precision and thus not good for quantization + self.is_first_forward = True + self.is_first_backward = True + # metrics + self.timing: bool = True + self.forward_start_time = 0 + self.backward_start_time = 0 + # tensor id counter + self.tensor_id_counter = 0 + self.active_tensors = 0 + self.pin_memory = pin_memory + self.cache_hits = 0 + self.cache_misses = 0 + + # platform util functions + def get_tensor_id()-> str: + # create a unique id for each tensor we are managing using an incrementing counter + tensor_id = str(self.tensor_id_counter) + self.tensor_id_counter += 1 + self.active_tensors += 1 + return tensor_id + + def get_cache_key(x: torch.Tensor)-> tuple[tuple, torch.dtype]: + # get the tensor shape and dtype as a tuple for cached memory re-use + return (tuple(x.size()), x.dtype) + + def get_num_bytes_tensor(x: torch.Tensor) -> int: + # get the number of bytes in a tensor, for memory management purposes + return x.element_size() * x.nelement() + + # -------- core pack / unpack work -------- + def pack_tensor(activation: torch.Tensor) -> str: + # activations are passed in during forward pass - from here we take over and return a unique id + if self.is_first_forward: + if self.timing: + if self.backward_start_time: + end_backward_time = time.perf_counter() + print(f"***** backward pass took {(end_backward_time - 
self.backward_start_time):.3f} seconds") + self.forward_start_time = time.perf_counter() + + print(f"total managed activations {len(self.tracker)=}") + print(f"cache stats: hits={self.cache_hits}, misses={self.cache_misses}") + print(f"cache size: {sum(len(v) for v in self.mem_offload_cache.values())} tensors") + + # Clear tracker but keep the memory cache + self.tracker.clear() + self.active_tensors = 0 + self.cache_hits = 0 + self.cache_misses = 0 + + print("***** first forward") + self.is_first_forward = False + self.is_first_backward = True + + # query for basic tensor info + activation_dtype = activation.dtype + num_bytes = get_num_bytes_tensor(activation) + sizes = activation.size() + tensor_id = get_tensor_id() + original_device = activation.device + cache_key = get_cache_key(activation) + + # skipping complex types, small tensors, and tensors with unsupported dtypes + if num_bytes < self.min_tensor_size_bytes or (activation_dtype in self.ignore_types): + print(f"skipping activation of {num_bytes}, size= {sizes}, {activation_dtype=}") + + # Store on CPU + if self.pin_memory: + cpu_tensor = torch.empty( + activation.size(), + dtype=activation.dtype, + layout=activation.layout, + pin_memory=True + ) + cpu_tensor.copy_(activation) + else: + cpu_tensor = activation.cpu().clone().detach() + + self.tracker[tensor_id] = (cpu_tensor, activation.dtype, False, original_device) # False = not modified + return tensor_id + else: + # main activation management code + print(f"Storing activation {sizes}, {num_bytes=}, {activation.dtype=} as {tensor_id}") + + # Try to reuse cached memory + cpu_tensor = None + if self.caching and cache_key in self.mem_offload_cache and len(self.mem_offload_cache[cache_key]) > 0: + # Reuse existing tensor from cache + cpu_tensor = self.mem_offload_cache[cache_key].pop() + cpu_tensor.copy_(activation) + self.cache_hits += 1 + print(f"Cache hit for {cache_key}") + else: + # Create new tensor + if self.pin_memory: + cpu_tensor = torch.empty( + activation.size(), + dtype=activation.dtype, + layout=activation.layout, + pin_memory=True + ) + cpu_tensor.copy_(activation) + else: + cpu_tensor = activation.cpu().clone().detach() + self.cache_misses += 1 + print(f"Cache miss for {cache_key}") + + self.tracker[tensor_id] = (cpu_tensor, activation_dtype, True, original_device) # True = (in future) modified + return tensor_id + + def unpack_tensor(unpack_tensor_id: str) -> torch.Tensor: + # backward pass - we are called with the tensor_id. 
+ # We then use the tensor_id to retrieve the saved/offloaded/compressed tensor + # and return it in original state (or near original for quantized) + if self.is_first_backward: + self.is_first_backward = False + self.is_first_forward = True + if self.timing: + end_forward_time = time.perf_counter() + print(f"***** forward took {(end_forward_time - self.forward_start_time):.3f} seconds") + print(f"***** first backward, managing {len(self.tracker)} tensors") + self.backward_start_time = time.perf_counter() + + # retrieve the saved/offloaded/compressed tensor + assert unpack_tensor_id in self.tracker, f"untracked tensor, {unpack_tensor_id}" + cpu_tensor, dtype, modified, original_device = self.tracker[unpack_tensor_id] + print(f"Unpacking {unpack_tensor_id}, {cpu_tensor.size()}, {cpu_tensor.dtype=}, {modified=}") + + # Move tensor back to original device + gpu_tensor = cpu_tensor.to(original_device, non_blocking=self.pin_memory) + + # Add the CPU tensor back to the cache for reuse + if self.caching: + cache_key = get_cache_key(cpu_tensor) # Use the same function as in pack_tensor + if cache_key not in self.mem_offload_cache: + self.mem_offload_cache[cache_key] = [] + self.mem_offload_cache[cache_key].append(cpu_tensor) + print(f"Added tensor to cache with key {cache_key}") + else: + # If not caching, we don't need the CPU tensor anymore + del cpu_tensor + + # clear tensor from tracking + del self.tracker[unpack_tensor_id] + # decrement active tensor count + self.active_tensors -= 1 + print(f"Active tensors: {self.active_tensors}") + return gpu_tensor + + + super().__init__(pack_tensor, unpack_tensor) + + +class manage_activations(saved_tensors_hooks): + """Context manager under which activation tensors created in the forward pass will be managed + """ + + def __init__(self, pin_memory: bool = True, device_type: str = "cuda", offload_ratio: float = 1.0) -> None: + device_module = getattr(torch, device_type, torch.cuda) + + self.caching: bool = False # are we managing cpu cached memory blocks + self.min_tensor_size_bytes = 4096 # we don't want to bother with small tensors + self.tracker = {} # tensor_id = (new_tensor, dtype, if_modified, device) ---> track what saved/offloaded/compressed tensors, are where + self.mem_offload_cache = {} # cache of available memory blocks for tensors + self.gb = 1024 * 1024 * 1024 # bytes in a gigabyte + self.ignore_types = [torch.complex64, torch.int64] # high precision and thus not good for quantization + self.is_first_forward = True + self.is_first_backward = True + # metrics + self.timing: bool = True + self.forward_start_time = 0 + self.backward_start_time = 0 + # tensor id counter + self.tensor_id_counter = 0 + self.active_tensors = 0 + self.pin_memory = pin_memory + # offload ratio (0.0 = keep all on GPU, 1.0 = offload all to CPU) + self.offload_ratio = max(0.0, min(1.0, offload_ratio)) + self.offloaded_count = 0 + self.kept_on_gpu_count = 0 + + # platform util functions + def get_tensor_id()-> str: + # create a unique id for each tensor we are managing using an incrementing counter + tensor_id = str(self.tensor_id_counter) + self.tensor_id_counter += 1 + self.active_tensors += 1 + return tensor_id + + def get_tensor_size_id( x: torch.Tensor)-> tuple[int, tuple]: + # get the tensor shape and total bytes as a tuple for cached memory re-use + num_bytes = get_num_bytes_tensor(x) + return (num_bytes, x.size()) + + def get_num_bytes_tensor( x: torch.Tensor) -> int: + # get the number of bytes in a tensor, for memory management purposes + return 
x.element_size() * x.nelement() #x.element_size() * x._base_storage().nbytes() + + def get_bytes_per_dimension(x: torch.Tensor)-> tuple[int, ...]: + # this might be too slow but is a way to provide a full byte signature for a tensor + # and used to match available memory sizes for caching + # alternative = (total_bytes, tensor.shape) which does not account for strides + element_size = x.element_size() + shape = x.shape + stride = x.stride() + + bytes_per_dim = [] + for dim, (size, stride_val) in enumerate(zip(shape, stride)): + bytes_in_dim = size * stride_val * element_size + bytes_per_dim.append(bytes_in_dim) + + return tuple(bytes_per_dim) + + # -------- core pack / unpack work -------- + def pack_tensor(activation: torch.Tensor) -> str: + # activations are passed in during forward pass - from here we take over and return a unique id + if self.is_first_forward: + if self.timing: + if self.backward_start_time: + end_backward_time = time.perf_counter() + print(f"***** backward pass took {(end_backward_time - self.backward_start_time):.3f} seconds") + self.forward_start_time = time.perf_counter() + + print(f"total managed activations {len(self.tracker)=}") + print(f"offload ratio: {self.offload_ratio}, offloaded: {self.offloaded_count}, kept on GPU: {self.kept_on_gpu_count}") + #if not self.caching: + self.tracker.clear() + self.active_tensors = 0 + self.offloaded_count = 0 + self.kept_on_gpu_count = 0 + + print("***** first forward") + self.is_first_forward = False + self.is_first_backward = True + + # query for basic tensor info + activation_dtype = activation.dtype + num_bytes = get_num_bytes_tensor(activation) + sizes = activation.size() + tensor_id = get_tensor_id() + original_device = activation.device + + # skipping complex types, small tensors, and tensors with unsupported dtypes + if num_bytes < self.min_tensor_size_bytes or (activation_dtype in self.ignore_types): + print(f"skipping activation of {num_bytes}, size= {sizes}, {activation_dtype=}") + + # Store on CPU + if self.pin_memory: + cpu_tensor = torch.empty( + activation.size(), + dtype=activation.dtype, + layout=activation.layout, + pin_memory=True + ) + cpu_tensor.copy_(activation) + else: + cpu_tensor = activation.cpu().clone().detach() + + self.tracker[tensor_id] = (cpu_tensor, activation.dtype, False, original_device) # False = not modified + return tensor_id + else: + # main activation management code + print(f"Storing activation {sizes}, {num_bytes=}, {activation.dtype=} as {tensor_id}") + + # Decide whether to offload this tensor based on the ratio + should_offload = (self.offloaded_count / (self.offloaded_count + self.kept_on_gpu_count + 1) < self.offload_ratio) + + if should_offload: + # Store on CPU + if self.pin_memory: + cpu_tensor = torch.empty( + activation.size(), + dtype=activation.dtype, + layout=activation.layout, + pin_memory=True + ) + cpu_tensor.copy_(activation) + else: + cpu_tensor = activation.cpu().clone().detach() + + self.tracker[tensor_id] = (cpu_tensor, activation_dtype, True, original_device) + self.offloaded_count += 1 + print(f"Offloaded tensor {tensor_id} to CPU") + else: + # Keep on GPU + gpu_tensor = activation.clone().detach() + self.tracker[tensor_id] = (gpu_tensor, activation_dtype, True, original_device) + self.kept_on_gpu_count += 1 + print(f"Kept tensor {tensor_id} on GPU") + + return tensor_id + + def unpack_tensor(unpack_tensor_id: str) -> torch.Tensor: + # backward pass - we are called with the tensor_id. 
+ # We then use the tensor_id to retrieve the saved/offloaded/compressed tensor + # and return it in original state (or near original for quantized) + if self.is_first_backward: + self.is_first_backward = False + self.is_first_forward = True + if self.timing: + end_forward_time = time.perf_counter() + print(f"***** forward took {(end_forward_time - self.forward_start_time):.3f} seconds") + print(f"***** first backward, managing {len(self.tracker)} tensors") + print(f"offloaded: {self.offloaded_count}, kept on GPU: {self.kept_on_gpu_count}") + self.backward_start_time = time.perf_counter() + + # retrieve the saved/offloaded/compressed tensor + assert unpack_tensor_id in self.tracker, f"untracked tensor, {unpack_tensor_id}" + tensor, dtype, modified, original_device = self.tracker[unpack_tensor_id] + print(f"Unpacking {unpack_tensor_id}, {tensor.size()}, {tensor.dtype=}, {modified=}") + + # If tensor is on CPU, move it back to original device + if tensor.device.type != original_device.type: + gpu_tensor = tensor.to(original_device, non_blocking=self.pin_memory) + else: + # Already on the right device + gpu_tensor = tensor + + # clear tensor from tracking + del self.tracker[unpack_tensor_id] + # decrement active tensor count + self.active_tensors -= 1 + print(f"Active tensors: {self.active_tensors}") + return gpu_tensor + + super().__init__(pack_tensor, unpack_tensor) + + +class manage_activations_old(saved_tensors_hooks): + """Context manager under which activation tensors created in the forward pass will be managed + """ + + def __init__(self, pin_memory: bool = False, device_type: str = "cuda") -> None: + device_module = getattr(torch, device_type, torch.cuda) + + self.caching: bool = False # are we managing cpu cached memory blocks + self.min_tensor_size_bytes = 1024 # we don't want to bother with small tensors + self.tracker = {} # tensor_id = (new_tensor, dtype, if_modified) ---> track what saved/offloaded/compressed tensors, are where + self.mem_offload_cache = {} # cache of available memory blocks for tensors + self.gb = 1024 * 1024 * 1024 # bytes in a gigabyte + self.ignore_types = [torch.complex64, torch.int64] # high precision and thus not good for quantization + self.is_first_forward = True + self.is_first_backward = True + # metrics + self.timing: bool = True + self.forward_start_time = 0 + self.backward_start_time = 0 + + # platform util functions + def get_tensor_id()-> str: + # create a unique id for each tensor we are managing + return secrets.token_urlsafe(nbytes=8) + + def get_tensor_size_id( x: torch.Tensor)-> Tuple[int]: + # get the tensor shape and total bytes as a tuple for cached memory re-use + num_bytes = self.get_num_bytes_tensor(x) + return tuple(num_bytes, x.size()) + + def get_num_bytes_tensor( x: torch.Tensor) -> int: + # get the number of bytes in a tensor, for memory management purposes + return x.element_size() * x.nelement() #x.element_size() * x._base_storage().nbytes() + + def get_bytes_per_dimension(x: torch.Tensor)-> Tuple[int]: + # this might be too slow but is a way to provide a full byte signature for a tensor + # and used to match available memory sizes for caching + # alternative = (total_bytes, tensor.shape) which does not account for strides + element_size = x.element_size() + shape = x.shape + stride = x.stride() + + bytes_per_dim = [] + for dim, (size, stride_val) in enumerate(zip(shape, stride)): + bytes_in_dim = size * stride_val * element_size + bytes_per_dim.append(bytes_in_dim) + + return tuple(bytes_per_dim) + + # -------- core pack / 
unpack work -------- + def pack_tensor(activation: torch.Tensor) -> str: + # activations are passed in during forward pass - from here we take over and return a unique id + if self.is_first_forward: + if self.timing: + if self.backward_start_time: + end_backward_time = time.perf_counter() + print(f"***** backward pass took {(end_backward_time - self.backward_start_time):.3f} seconds") + self.forward_start_time = time.perf_counter() + + print(f"total managed activations {len(self.tracker)=}") + #if not self.caching: + self.tracker.clear() + + print("***** first forward") + self.is_first_forward = False + self.is_first_backward = True + + # query for basic tensor info + activation_dtype = activation.dtype + num_bytes = get_num_bytes_tensor(activation) + sizes = activation.size() + tensor_id = get_tensor_id() + + # skipping complex types, small tensors, and tensors with unsupported dtypes + if num_bytes < self.min_tensor_size_bytes or (activation_dtype in self.ignore_types): + print(f"skipping activation of {num_bytes}, size= {sizes}, {activation_dtype=}") + + gpu_clone = activation.clone().detach() + self.tracker[tensor_id] = (gpu_clone, activation.dtype, False) # False = not modified + return tensor_id + else: + # main activation management code + print(f"Storing activation {sizes}, {num_bytes=}, {activation.dtype=} as {tensor_id}") + gpu_clone = activation.clone().detach() + self.tracker[tensor_id] = (gpu_clone, activation_dtype, True) # True = (in future) modified + return tensor_id + + def unpack_tensor(unpack_tensor_id: str) -> torch.Tensor: + # backward pass - we are called with the tensor_id. + # We then use the tensor_id to retrieve the saved/offloaded/compressed tensor + # and return it in original state (or near original for quantized) + if self.is_first_backward: + self.is_first_backward = False + self.is_first_forward = True + if self.timing: + end_forward_time = time.perf_counter() + print(f"***** forward took {(end_forward_time - self.forward_start_time):.3f} seconds") + print(f"***** first backward, managing {len(self.tracker)} tensors") + self.backward_start_time = time.perf_counter() + + # retrieve the saved/offloaded/compressed tensor + assert unpack_tensor_id in self.tracker, f"untracked tensor, {unpack_tensor_id}" + gpu_tensor, dtype, modified = self.tracker[unpack_tensor_id] + print(f"Unpacking {unpack_tensor_id}, {gpu_tensor.size()}, {gpu_tensor.dtype=}, {modified=}") + # clear tensor from tracking + del self.tracker[unpack_tensor_id] + return gpu_tensor + + super().__init__(pack_tensor, unpack_tensor) + + + + +# ================= backward hooks ================= +@contextlib.contextmanager +def disable_saved_tensors_hooks(error_message: str) -> Generator[None, None, None]: + """Context-manager that disables the saved tensors default hooks feature. + + Useful for if you are creating a feature that does not work with saved + tensors default hooks. + + Args: + error_message (str): When saved tensors default hooks are used when they + have been are disabled, a RuntimeError with this + error message gets raised. + + Example:: + + >>> # xdoctest: +SKIP(failing) + >>> message = "saved tensors default hooks are disabled" + >>> with torch.autograd.graph.disable_saved_tensors_hooks(message): + ... # Raises RuntimeError: saved tensors default hooks are disabled + ... with torch.autograd.graph.save_on_cpu(): + ... 
pass + """ + maybe_prev_message = None + try: + maybe_prev_message = ( + torch._C._autograd._saved_tensors_hooks_get_disabled_error_message() + ) + torch._C._autograd._saved_tensors_hooks_disable(error_message) + yield + finally: + # See NOTE: [disabled_error_message invariant] + if maybe_prev_message is None: + torch._C._autograd._saved_tensors_hooks_enable() + else: + torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message) + + +class _MultiHandle(RemovableHandle): + handles: tuple[RemovableHandle, ...] + + def __init__(self, handles: tuple[RemovableHandle, ...]) -> None: + self.handles = handles + + def remove(self) -> None: + for handle in self.handles: + handle.remove() + + def __getstate__(self) -> tuple[RemovableHandle, ...]: + return self.handles + + def __setstate__(self, state: tuple[RemovableHandle, ...]) -> None: + self.handles = state + + +def register_multi_grad_hook( + tensors: Sequence[torch.Tensor], + fn: Union[ + Callable[[Sequence[Optional[torch.Tensor]]], None], + Callable[[torch.Tensor], None], + ], + *, + mode: Literal["all", "any"] = "all", +) -> RemovableHandle: + r"""Register a multi-grad backward hook. + + There are two supported modes: ``"all"`` and ``"any"``. + + Under the ``"all"`` mode, the hook will be called after gradients with respect to every tensor in + :attr:`tensors` have been computed. If a tensor is in :attr:`tensors` but + is not part of the graph, or if a tensor is not needed to compute the gradients + for any ``inputs`` specified for the current ``.backward()`` or ``.grad()`` call, + this tensor will be ignored and the hook will not wait for its gradient to be + computed. + + After every non-ignored tensor's gradient has been computed, :attr:`fn` will be + called with those gradients. ``None`` will be passed for tensors that did not + have their gradients computed. + + Under the ``"any"`` mode, the hook will be called after the first gradient + with respect to a tensor in :attr:`tensors` has been computed. The hook + will be called with that gradient as its argument. + + The hook should not modify its arguments. + + This function returns a handle with a method ``handle.remove()`` that removes the hook. + + .. note:: + See :ref:`backward-hooks-execution` for more information on how when this hook + is executed, and how its execution is ordered relative to other hooks. + + Example:: + + >>> import torch + >>> + >>> a = torch.rand(2, 3, requires_grad=True) + >>> b = torch.rand(2, 3, requires_grad=True) + >>> c = a * b + >>> d = a * b + >>> + >>> def fn(grads): + ... print([g is not None for g in grads]) + ... 
+ >>> torch.autograd.graph.register_multi_grad_hook((a, b, c, d), fn) + >>> + >>> c.sum().backward(retain_graph=True) + [True, True, True, False] + >>> c.sum().backward(inputs=(a,), retain_graph=True) + [True, False, True, False] + >>> + """ + supported_modes = ("all", "any") + lock = threading.Lock() + + if mode not in supported_modes: + raise ValueError(f"Expects mode to be one of {supported_modes} but got {mode}") + + if mode == "all": + count: dict[int, int] = {} + nb_calls = None + buffer: dict[int, list[Optional[torch.Tensor]]] = {} + + grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors)) + len_tensors = len(tensors) + + def get_inner_hook(idx: int) -> Callable[[torch.Tensor], None]: + def inner_hook(grad: torch.Tensor) -> None: + nonlocal count, nb_calls, buffer, fn + id = torch._C._current_graph_task_id() + assert ( + id != -1 + ), "expected this hook to be called inside a backward call" + count[id] = count.get(id, 0) + buffer[id] = buffer.get(id, [None] * len_tensors) + + with lock: + curr_count, count[id] = count[id], count[id] + 1 + + if curr_count == 0: + # On the first call, compute the actual nb_calls and buffer + nb_calls = sum( + map(torch._C._will_engine_execute_node, grad_fns) + ) + + buffer[id][idx] = grad + + assert nb_calls is not None + if curr_count == nb_calls - 1: + fn = cast(Callable[[Sequence[Optional[torch.Tensor]]], None], fn) + fn(buffer[id]) + del count[id] + del buffer[id] + + return inner_hook + + handles = tuple( + t.register_hook(get_inner_hook(i)) for i, t in enumerate(tensors) + ) + elif mode == "any": + fn = cast(Callable[[torch.Tensor], None], fn) + ran_hook: dict[int, bool] = defaultdict(bool) + + @functools.wraps(fn) + def wrapped_fn(grad: torch.Tensor) -> None: + nonlocal ran_hook + id = torch._C._current_graph_task_id() + assert id != -1, "expected this hook to be called inside a backward call" + with lock: + prev, ran_hook[id] = ran_hook[id], True + if prev: + return + fn(grad) + + handles = tuple( + tensor.register_hook(wrapped_fn) + for tensor in tensors + if tensor.requires_grad + ) + + return _MultiHandle(handles) # type: ignore[possibly-undefined] + + +# NOTE [Allow mutation on tensors saved for backward] +# +# 1. Tensor gets saved for backward +# - remember the python object id and the version of the tensor +# - remember aliasing information (data_ptr of base + version) +# - save the original so we control its lifetime +# 2. Any time a tensor gets in-placed +# - for each tensor aliased to it: +# - check using its object id and version to see if it has been saved +# - if it has been saved, clone it +# - delete the reference to the original +# 3. during backward +# - if the clone exists, the tensor must've been modified in-place +_allow_mutation_on_saved_tensors_enabled: bool = False + + +_TID: TypeAlias = tuple[int, int, int] +_SID: TypeAlias = tuple[int, int] + + +def _get_tid(tensor: torch.Tensor) -> _TID: + # FIXME: This is almost definitely a bug. + if isinstance( + tensor, + ( + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ), + ): + data_ptr = 0 + else: + data_ptr = tensor.data_ptr() + return (id(tensor), data_ptr, tensor._version) + + +def _get_sid(tensor: torch.Tensor) -> _SID: + # FIXME: This is almost definitely a bug. 
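+    # (Presumably fake/functional tensor wrappers do not expose a usable
+    # storage pointer, so data_ptr() is skipped and 0 is recorded instead.)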
+ if isinstance( + tensor, + ( + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ), + ): + data_ptr = 0 + else: + data_ptr = tensor.data_ptr() + return (data_ptr, tensor._version) + + +class _Handle: + pass + + +class _swap_with_cloned(saved_tensors_hooks): + def __init__(self, ctx: "_AllowMutationOnSavedContext") -> None: + def pack_hook(tensor: torch.Tensor) -> _Handle: + tid = _get_tid(tensor) + sid = _get_sid(tensor) + # Tensors saved for backward have an entry in _tid_to_weakhandle + handle: Optional[_Handle] = None + + # Save aliasing information + ctx.sid_to_tid[sid].add(tid) + + # NB: The same tensor (of the same version) can be saved multiple times + if tid not in ctx.tid_to_weakhandle: + handle = _Handle() + ctx.tid_to_weakhandle[tid] = handle + ctx.original[handle] = tensor + else: + # Store an additional strong reference to the handle + handle = ctx.tid_to_weakhandle[tid] + return handle + + def unpack_hook(handle: _Handle) -> torch.Tensor: + error_msg = ( + "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context" + "in which the graph was originally recorded." + ) + assert _allow_mutation_on_saved_tensors_enabled, error_msg + if handle in ctx.cloned: + res = ctx.cloned[handle] + else: + assert handle in ctx.original, error_msg + res = ctx.original[handle] + return res + + super().__init__(pack_hook, unpack_hook) + + +class _CloneArgBeforeMutateMode(TorchDispatchMode): + def __init__(self, ctx: "_AllowMutationOnSavedContext") -> None: + self.ctx = ctx + + def __torch_dispatch__( + self, + func: "OpOverload", + types: Iterable[type], + args: tuple[Any, ...] = (), + kwargs: Optional[dict[Any, Any]] = None, + ) -> Any: + kwargs = kwargs or {} + + def maybe_clone(t: torch.Tensor) -> None: + tid = _get_tid(t) + sid = _get_sid(t) + ctx = self.ctx + if sid in ctx.sid_to_tid: + for tid in ctx.sid_to_tid[sid]: + if tid not in ctx.tid_to_weakhandle: + # We know that if tid is in sid_to_tid, then it must also be in + # tid_to_weakhandle. However, it is possible for the tensor to be + # saved at one point, but cleared by backward before it is modified + # in-place. Consider the following example: + # + # >>> a = torch.randn(2, 3, requires_grad=True).clone() + # >>> out = (a**2).sum() + # >>> out.backward() + # >>> a.sin_() + continue + handle = ctx.tid_to_weakhandle[tid] + if handle in ctx.cloned: + # The same exact tensor has been cloned already + continue + ctx.cloned[handle] = ctx.original[handle].clone() + del ctx.original[handle] + + for idx, arg in enumerate(func._schema.arguments): + if arg.alias_info is not None and arg.alias_info.is_write: + if arg.is_out: + maybe_clone(kwargs["out"]) + elif isinstance(args[idx], list): + # Foreach case. (Possible optimization: if most of the + # tensors need to be cloned, use a for each clone?) 
+ for t in args[idx]: + maybe_clone(t) + else: + maybe_clone(args[idx]) + + return func(*args, **kwargs) + + +class _AllowMutationOnSavedContext: + def __init__(self) -> None: + self.cloned: MutableMapping[_Handle, torch.Tensor] = WeakKeyDictionary() + self.original: MutableMapping[_Handle, torch.Tensor] = WeakKeyDictionary() + self.tid_to_weakhandle: MutableMapping[_TID, _Handle] = WeakValueDictionary() + self.sid_to_tid: dict[_SID, set[_TID]] = defaultdict(set) + + def clear(self) -> None: + self.cloned.clear() + self.original.clear() + self.tid_to_weakhandle.clear() + self.sid_to_tid.clear() + + +@contextlib.contextmanager +def allow_mutation_on_saved_tensors() -> ( + Generator[_AllowMutationOnSavedContext, None, None] +): + """Context manager under which mutating tensors saved for backward is allowed. + + Under this context manager, tensors saved for backward are cloned on mutation, + so the original version can still be used during backward. Normally, mutating a tensor + saved for backward will result in an error raised when it's used during backward. + + To ensure the correct behavior, both the forward and backward should be run under + the same context manager. + + Returns: + An _AllowMutationOnSavedContext object storing the state managed by this + context manager. This object can be useful for debugging purposes. The state + managed by the context manager is automatically cleared upon exiting. + + Example:: + + >>> import torch + >>> with torch.autograd.graph.allow_mutation_on_saved_tensors(): + ... # forward + ... a = torch.ones(2, 3, requires_grad=True) + ... b = a.clone() + ... out = (b**2).sum() + ... b.sin_() + ... # backward + ... out.sum().backward() + ... + tensor([[0.8415, 0.8415, 0.8415], + [0.8415, 0.8415, 0.8415]], grad_fn=) + """ + global _allow_mutation_on_saved_tensors_enabled + + ctx = _AllowMutationOnSavedContext() + + with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx): + try: + if _allow_mutation_on_saved_tensors_enabled: + raise RuntimeError( + "allow_mutation_on_saved_tensors contexts cannot be nested" + ) + _allow_mutation_on_saved_tensors_enabled = True + yield ctx + finally: + ctx.clear() + _allow_mutation_on_saved_tensors_enabled = False + + +def _register_logging_hooks_on_whole_graph( + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], +) -> Callable[[], None]: + grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs)) + + def iter_graph(roots: list[Node]) -> Iterator[Node]: + if not roots: + return + seen: set[Node] = set() + q: deque[Node] = deque() + for node in roots: + if node is not None: + seen.add(node) + q.append(node) + + while q: + node = q.popleft() + for fn, _ in node.next_functions: + if fn in seen or fn is None: + continue + seen.add(fn) + q.append(fn) + + yield node + + def fmt(t: Optional[torch.Tensor]) -> str: + # Avoid circular import + from torch.utils._dtype_abbrs import dtype_abbrs + + if t is None: + return "None" + return f"{dtype_abbrs[t.dtype]}[{', '.join(map(str, t.shape))}]" + + def prehook(grad_outputs: Sequence[Optional[torch.Tensor]]) -> None: + node = torch._C._current_autograd_node() + grad_outputs_str = f"[{','.join(fmt(t) for t in grad_outputs)}]" + log_str = f"Executing: {node} with grad_outputs: {grad_outputs_str}" + log.debug(log_str) + + handles = [node.register_prehook(prehook) for node in iter_graph(grad_fns)] + + def unregister_hooks() -> None: + for handle in handles: + handle.remove() + + return unregister_hooks + + +def _engine_run_backward( + t_outputs: Sequence[Union[torch.Tensor, 
GradientEdge]], + *args: Any, + **kwargs: Any, +) -> tuple[torch.Tensor, ...]: + attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG + if attach_logging_hooks: + unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) + try: + return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass + t_outputs, *args, **kwargs + ) # Calls into the C++ engine to run the backward pass + finally: + if attach_logging_hooks: + unregister_hooks() # type: ignore[possibly-undefined] diff --git a/run_train.sh b/run_train.sh index fbed394ebb..619466f19c 100755 --- a/run_train.sh +++ b/run_train.sh @@ -12,7 +12,7 @@ set -ex # LOG_RANK=0,1 NGPU=4 ./run_train.sh NGPU=${NGPU:-"8"} export LOG_RANK=${LOG_RANK:-0} -CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"} +CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/llama3_8b.toml"} overrides="" if [ $# -ne 0 ]; then diff --git a/torchtitan/models/llama3/model.py b/torchtitan/models/llama3/model.py index 20026a690b..b49234e665 100644 --- a/torchtitan/models/llama3/model.py +++ b/torchtitan/models/llama3/model.py @@ -17,6 +17,7 @@ from torchtitan.config_manager import JobConfig from torchtitan.models.attention import build_attention, init_attention_mask from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol +from torchtitan.offloading import activation_offload_with_overlap @dataclass @@ -362,6 +363,7 @@ def forward( torch.Tensor: Output tensor after applying attention and feedforward layers. """ + #with activation_offload_with_overlap(self): h = x + self.attention(self.attention_norm(x), freqs_cis) out = h + self.feed_forward(self.ffn_norm(h)) return out diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index 6cfe61bc70..f28ba83e76 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -53,7 +53,7 @@ export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] -mode = "selective" # ["none", "selective", "full"] +mode = "none" # ["none", "selective", "full"] selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy [float8] diff --git a/torchtitan/offloading.py b/torchtitan/offloading.py new file mode 100644 index 0000000000..9bf2a76e8f --- /dev/null +++ b/torchtitan/offloading.py @@ -0,0 +1,174 @@ +import torch +import logging +from torch.nn import Module +from torch.autograd.graph import saved_tensors_hooks +from typing import NamedTuple +from collections import defaultdict + + +logger = logging.getLogger(__name__) + + +class PackInfo(NamedTuple): + # Record an event in the offload stream for the default stream to wait on + # before freeing the device tensor + d2h_event: torch.cuda.Event + # Keep a ref to the device tensor until the event has been waited on + device_tensor: torch.Tensor + + +class UnpackInfo(NamedTuple): + # Record an event during preallocation for the offload stream to wait on + # before copying to the device tensor + prealloc_event: torch.cuda.Event + # Preallocate the device tensor memory so it can be allocated in the + # default stream (instead of offload stream) to avoid fragmentation + device_tensor: torch.Tensor + + +# TODO: Remove these from global namespace and register on modules. Consider +# using module state as identifier instead of int ID. 
+# Used or overlapping H2D/D2H copy with compute +offload_stream: torch.cuda.Stream = torch.cuda.Stream() +# Used for module ordering +module_id_to_module: dict[int, Module] = {} +next_module_id = 0 +# Used in forward to keep device tensors alive through D2H copies +module_to_pack_infos: dict[Module, list[PackInfo]] = defaultdict(list) +# Appended to in forward and used in backward to know which CPU tensors will be +# copied H2D in backward to preallocate their device memory +module_to_cpu_tensors: dict[Module, list[torch.Tensor]] = defaultdict(list) +# Used in backward to preallocate device tensors in the default stream +cpu_tensor_to_unpack_info: dict[torch.Tensor, UnpackInfo] = {} + + +class activation_offload_with_overlap(saved_tensors_hooks): + """ + In forward, we overlap the current module's D2H copies with the next + module's forward compute. + + In backward, we overlap the current module's H2D copies with the previous + module's backward compute. + + In backward, since we need to allocate new device memory for the H2D + destinations, we can either do so in the offload stream or in the default + stream. Naively, we may do so in the offload stream, but this fragments the + memory pool since memory blocks are not shared across streams. As such, we + instead choose to do so in the default stream. This requires preallocation + and a CUDA event to ensure that the H2D copy does not start too early, + using the default stream memory before it should. + + """ + + def __init__(self, module: Module) -> None: + global next_module_id + + module_id = next_module_id + module_id_to_module[module_id] = module + next_module_id += 1 + + # logger.info(f"This is module {id(module):#x}, {module_id}.") + + def get_num_bytes_tensor( x: torch.Tensor ) -> int: + # get the number of bytes in a tensor, for memory management purposes + return x.element_size() * x.nelement() #x.element_size() * x._base_storage().nbytes() + + def pack_to_cpu(tensor: torch.Tensor) -> tuple[torch.device, torch.Tensor]: + if tensor.device.type == "cpu": + # logger.info(f"") + return (tensor.device, tensor) + + num_bytes = get_num_bytes_tensor(tensor) + sizes = tensor.size() + + device_tensor = tensor # rename for clarity + del tensor + + # TODO: Insert optional policy for deciding whether to offload. + # Migrate to be like non-reentrant activation checkpointing in the + # future to reuse the selective activation checkpointing logic. 
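+            # A possible shape for that policy hook (hypothetical sketch only,
+            # not wired up; the hard-coded element-count check below is what
+            # actually runs):
+            #
+            #   def should_offload(t: torch.Tensor) -> bool:
+            #       # e.g. offload only activations of at least 1M elements
+            #       return t.numel() >= 1 * 1024 * 1024
+            #
+            #   if not should_offload(device_tensor):
+            #       return (device_tensor.device, device_tensor)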
+ if device_tensor.numel() < 1 * 1024 * 1024: + # logger.info(f"Ignoring activation tensor of {num_bytes} bytes, size = {sizes}, dtype = {device_tensor.dtype}") + return (device_tensor.device, device_tensor) + + current_stream = torch.cuda.current_stream() + + module_id_to_free = module_id - 1 + if module_id_to_free in module_id_to_module: + # Have the first of module i to free all of module i-1 + # logger.info(f"Trying to free {module_id_to_free}...") + module_to_free = module_id_to_module[module_id_to_free] + self.free_packed_device_tensors(module_to_free) + + offload_stream.wait_stream(current_stream) + with torch.cuda.stream(offload_stream): + # logger.info(f"Copying activation tensor of {num_bytes} bytes, size = {sizes}, dtype = {device_tensor.dtype} to CPU...") + cpu_tensor = device_tensor.to(torch.device("cpu"), non_blocking=True) + # logger.info(f"Record d2h event.") + d2h_event = offload_stream.record_event() + + module_to_cpu_tensors[module].append(cpu_tensor) + module_to_pack_infos[module].append(PackInfo(d2h_event, device_tensor)) + return (device_tensor.device, cpu_tensor) + + def unpack_from_cpu(packed) -> torch.Tensor: + device, tensor = packed + if tensor.device == device: + return tensor + assert tensor.device == torch.device("cpu"), f"{tensor.device}" + + cpu_tensor = tensor # rename for clarity + del tensor + + # Clear any existing refs from forward (this should only happen for + # the last module) + self.free_packed_device_tensors(module) + + current_stream = torch.cuda.current_stream() + module_id_to_prealloc = module_id - 1 + + if module_id_to_prealloc in module_id_to_module: + module_to_prealloc = module_id_to_module[module_id_to_prealloc] + if module_to_prealloc in module_to_cpu_tensors: + cpu_tensors = module_to_cpu_tensors[module_to_prealloc] + for _cpu_tensor in cpu_tensors: + cpu_tensor_to_unpack_info[_cpu_tensor] = UnpackInfo( + current_stream.record_event(), + torch.empty_like(_cpu_tensor, device=device), + ) + del module_to_cpu_tensors[module_to_prealloc] + + if cpu_tensor in cpu_tensor_to_unpack_info: # prefetched + event, device_tensor = cpu_tensor_to_unpack_info[cpu_tensor] + offload_stream.wait_event(event) + del cpu_tensor_to_unpack_info[cpu_tensor] + else: + device_tensor = torch.empty_like(cpu_tensor, device=device) + # Preallocate the rest of the 1st backward module + for _cpu_tensor in module_to_cpu_tensors[module]: + if _cpu_tensor is cpu_tensor: + continue + cpu_tensor_to_unpack_info[_cpu_tensor] = UnpackInfo( + current_stream.record_event(), + torch.empty_like(_cpu_tensor, device=device), + ) + del module_to_cpu_tensors[module] + offload_stream.wait_stream(current_stream) + + with torch.cuda.stream(offload_stream): + device_tensor.copy_(cpu_tensor, non_blocking=True) + current_stream.wait_stream(offload_stream) + + return device_tensor + + super().__init__(pack_to_cpu, unpack_from_cpu) + + def free_packed_device_tensors(self, module: torch.nn.Module): + if module in module_to_pack_infos: + # logger.info(f"Trying to free packed device tensors from module {id(module):#x}") + if infos := module_to_pack_infos[module]: + # Make sure that the default stream does not reuse any of + # the previous activation memory until the D2H finish + torch.cuda.current_stream().wait_event(infos[-1].d2h_event) + del module_to_pack_infos[module] + diff --git a/torchtitan/train.py b/torchtitan/train.py index d0228ae782..4088fe9d9a 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in 
the # LICENSE file in the root directory of this source tree. +import contextlib import importlib import os import time @@ -11,10 +12,10 @@ from typing import Any, Generator, Iterable, Optional import torch -from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.components.ft as ft import torchtitan.protocols.train_spec as train_spec_module +from torch.distributed.elastic.multiprocessing.errors import record from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.metrics import ( build_metrics_processor, @@ -359,13 +360,15 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor): ) else: # Non-PP forward / backward - with self.train_context(optional_context_parallel_ctx): - assert len(model_parts) == 1 - pred = model_parts[0](inputs) - loss = self.loss_fn(pred, labels) - # need to free to before bwd to avoid peaking memory - del pred - loss.backward() + #with torch.autograd.graph.manage_activations(): + with contextlib.nullcontext(): + with self.train_context(optional_context_parallel_ctx): + assert len(model_parts) == 1 + pred = model_parts[0](inputs) + loss = self.loss_fn(pred, labels) + # need to free to before bwd to avoid peaking memory + del pred + loss.backward() dist_utils.clip_grad_norm_( [p for m in model_parts for p in m.parameters()],
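Usage sketch (not part of the patch): the call sites above are commented out, so for reference this is roughly how the two context managers are meant to be driven. It assumes the modified graph.py is importable as torch.autograd.graph (as the commented-out line in train.py implies), that a CUDA device is available, and that the module and tensor shapes are purely illustrative.

    import torch
    from torch.autograd.graph import manage_activations                 # added in graph.py above
    from torchtitan.offloading import activation_offload_with_overlap   # added in offloading.py above

    model = torch.nn.Linear(4096, 4096, device="cuda")
    inputs = torch.randn(1024, 4096, device="cuda", requires_grad=True)

    # Option 1: blanket CPU offload of eligible saved activations (mirrors the
    # commented-out `with torch.autograd.graph.manage_activations():` in train.py).
    with manage_activations(pin_memory=True):
        loss = model(inputs).sum()
    loss.backward()  # unpack hooks fire here and copy saved activations back to the GPU

    # Option 2: per-module offload with D2H/H2D stream overlap (mirrors the
    # commented-out `with activation_offload_with_overlap(self):` in model.py,
    # where it wraps each TransformerBlock's forward).
    with activation_offload_with_overlap(model):
        loss = model(inputs).sum()
    loss.backward()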