Systematically branches torch model training: the epoch with the highest absolute performance change between successive epochs is re-trained with a different random seed. The TorchCheckpointer implements a binary tree in the CheckpointTree module. Each node in the tree points to a saved model and records its epoch, its difference to the previous epoch, and the model's test accuracy. The tree is parametrized by tree_height, which dictates the number of trained epochs, and tree_branches, which determines how many branchings we perform.
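A minimal sketch of what a node in such a tree might hold, assuming plain Python dataclasses. The names CheckpointNode, TreeConfig and build_chain are hypothetical and not taken from the CheckpointTree module. The convention that the left child continues training with the same seed while the right child is a re-seeded alternative is also an assumption.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class CheckpointNode:
    """One saved model in the tree (hypothetical field names)."""
    epoch: int                                # epoch at which the model was saved
    delta: float                              # test-accuracy change vs. the previous epoch
    test_accuracy: float                      # model test accuracy after this epoch
    left: Optional["CheckpointNode"] = None   # assumed: continue with the same seed
    right: Optional["CheckpointNode"] = None  # assumed: re-seeded alternative


@dataclass
class TreeConfig:
    """Hypothetical holder for the two tree parameters."""
    tree_height: int    # number of trained epochs
    tree_branches: int  # number of branchings performed


def build_chain(accuracies: List[float]) -> CheckpointNode:
    """Link one node per epoch along the left path and record each delta."""
    root = CheckpointNode(epoch=0, delta=0.0, test_accuracy=accuracies[0])
    node = root
    for epoch, acc in enumerate(accuracies[1:], start=1):
        # delta is signed here; branch selection would use its absolute value
        child = CheckpointNode(epoch=epoch, delta=acc - node.test_accuracy,
                               test_accuracy=acc)
        node.left = child
        node = child
    return root
```

Storing the signed delta keeps the direction of the change available, while the selection rule only needs its absolute value.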
During runtime, both the model weights and the optimizer state are saved under the base-10 encoding of the node's binary position, where a step to the left child is represented as (1) and a step to the right child as (0). For example, the binary node position (read from left to right)
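The position encoding described above can be sketched as follows. The sketch assumes a path is written as a string of 'L' and 'R' moves from the root. It also prepends a sentinel 1 for the root (an assumption, not stated in the text) so that paths such as 'R' and 'RR' do not both collapse to 0. The model_N.pt / optimizer_N.pt file-name pattern is hypothetical.

```python
from typing import Tuple


def encode_position(path: str) -> int:
    """Map a root-to-node path of 'L'/'R' moves to a base-10 integer.

    Left is written as bit 1 and right as bit 0, read from the root down.
    A leading sentinel 1 stands for the root (an assumption) so that, e.g.,
    'R' and 'RR' encode to different numbers.
    """
    if any(step not in "LR" for step in path):
        raise ValueError(f"invalid path: {path!r}")
    bits = "1" + "".join("1" if step == "L" else "0" for step in path)
    return int(bits, 2)


def decode_position(code: int) -> str:
    """Invert encode_position: drop the sentinel bit, map 1 -> 'L', 0 -> 'R'."""
    if code < 1:
        raise ValueError("code must be a positive integer")
    bits = bin(code)[2:]  # strip the '0b' prefix
    return "".join("L" if bit == "1" else "R" for bit in bits[1:])


def checkpoint_names(path: str) -> Tuple[str, str]:
    """Hypothetical file names for the model weights and optimizer state."""
    code = encode_position(path)
    return f"model_{code}.pt", f"optimizer_{code}.pt"


print(encode_position("LR"))   # 6
print(checkpoint_names("L"))   # ('model_3.pt', 'optimizer_3.pt')
```

In a training loop, the two state dicts could then be written with torch.save(model.state_dict(), model_file) and torch.save(optimizer.state_dict(), optimizer_file).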
[Figure: 2D representation of the TorchCheckpointer, where the main branch (branch = 1) creates alternative models through random seed perturbation. Branches are displayed in chronological order; each next branch starts at the node whose absolute difference to its subsequent node is the highest in the current tree.]
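The branch-selection rule from the caption can be sketched with a toy stand-in for training. Here fake_epoch is a hypothetical placeholder that nudges accuracy by a seeded random amount; in the real project each step would train and evaluate a torch model. The sketch assumes that branching at a node means re-training from that node's checkpoint with a new seed up to tree_height. It also assumes that only nodes with a single successor are eligible, which keeps the tree binary.

```python
import random
from typing import Dict, List, Tuple


def fake_epoch(accuracy: float, rng: random.Random) -> float:
    """Hypothetical stand-in for one training epoch plus evaluation."""
    return min(1.0, max(0.0, accuracy + rng.uniform(-0.05, 0.10)))


def grow_tree(tree_height: int, tree_branches: int, base_seed: int = 0):
    """Grow a main branch, then branch up to tree_branches times.

    Returns (epochs, accuracies, children) indexed by node id; node 0 is
    the starting model at epoch 0.
    """
    epochs: List[int] = [0]
    accuracies: List[float] = [0.1]  # assumed starting accuracy
    children: Dict[int, List[int]] = {0: []}

    def train_from(node: int, seed: int) -> None:
        # Re-train epoch by epoch from `node` until tree_height is reached.
        rng = random.Random(seed)
        while epochs[node] < tree_height:
            new_id = len(epochs)
            epochs.append(epochs[node] + 1)
            accuracies.append(fake_epoch(accuracies[node], rng))
            children[node].append(new_id)
            children[new_id] = []
            node = new_id

    train_from(0, base_seed)  # main branch
    for offset in range(1, tree_branches + 1):
        # Absolute accuracy change from each eligible node to its single successor.
        candidates: List[Tuple[float, int]] = [
            (abs(accuracies[kids[0]] - accuracies[n]), n)
            for n, kids in children.items()
            if len(kids) == 1
        ]
        if not candidates:
            break
        _, start = max(candidates)               # node with the largest jump
        train_from(start, base_seed + offset)    # changed random seed
    return epochs, accuracies, children


if __name__ == "__main__":
    epochs, accs, children = grow_tree(tree_height=5, tree_branches=2)
    print([n for n, kids in children.items() if len(kids) == 2])  # branch points
```

Taking max over (delta, node) tuples breaks ties in favor of the higher node id; any deterministic tie-break would serve.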
[Figure: 3D representation of the same tree, with the z-axis corresponding to model accuracy.]