
Multi-branched model training through random seed manipulation at points of rapid model performance change


Darakhsh1999/Torch-Checkpointer


Torch-Checkpointer

Multi-branch model training

Systematically branches out torch models by re-training epochs with a changed random seed at the points where the absolute performance delta between successive epochs is highest. The TorchCheckpointer implements a binary tree in the CheckpointTree module. Each node in the tree points to a saved model and carries information about its epoch, its performance difference to the previous epoch, and its model test accuracy. The tree is parametrized by a tree_height, which dictates the number of trained epochs, and tree_branches, which determines how many branchings are performed. During runtime, both the model weights and the optimizer state are saved under the base-10 encoding of the binary position of the node, where going to the left child is represented by a (1) and going to the right child by a (0). For example, the binary node position (read from left to right) $11001_{2} = 25_{10}$ corresponds to root(1)->left(1)->right(0)->right(0)->left(1).
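The node-position encoding above can be sketched in a few lines of Python. This is a minimal illustration, not the repository's implementation; the function name `node_id` and the `'L'`/`'R'` path notation are assumptions for the example.

```python
def node_id(path):
    """Encode a tree path as the base-10 value of its binary position.

    path: string of moves below the root, 'L' for left, 'R' for right.
    The root contributes a leading 1; each left move appends a 1 and
    each right move appends a 0, read from left to right.
    """
    bits = "1" + "".join("1" if move == "L" else "0" for move in path)
    return int(bits, 2)

# root(1) -> left(1) -> right(0) -> right(0) -> left(1) = 0b11001
print(node_id("LRRL"))  # 25
```

A checkpoint file for this node would then carry the suffix 25, which uniquely identifies its position in the tree.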


2D representation of the TorchCheckpointer, where the main branch (branch = 1) creates alternative models through random seed perturbation. The branches are displayed in chronological order: each new branch is started at the node whose absolute difference to its subsequent node is the highest in the current tree.


3D representation of the same tree with the z-value corresponding to model accuracy.

