swin ported. ./run.sh to reproduce #1
base: master
@@ -65,6 +64,7 @@ def build_dataset(is_train, data_path, args):
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    elif args.data_set == 'HUGGINGFACE':
        from datasets import load_dataset
Moved the import here so that `datasets` is only needed when the HUGGINGFACE data set is selected, which reduces the number of required dependencies.
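A minimal sketch of the deferred-import pattern the comment above describes, assuming only that the optional package is importable under the name `datasets`. The helper `load_optional` and the simplified `build_dataset` signature are hypothetical and are not taken from this PR.

```python
import importlib


def load_optional(module_name, feature):
    """Import a module on demand and explain which feature needs it (hypothetical helper)."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"{feature} requires the optional package '{module_name}'"
        ) from exc


def build_dataset(data_set):
    """Simplified stand-in for the real build_dataset; returns placeholders."""
    if data_set == "IMNET":
        # No optional dependency is touched on this path.
        return "ImageFolder", 1000
    if data_set == "HUGGINGFACE":
        # The import happens only when this branch is actually taken.
        hf_datasets = load_optional("datasets", "--data-set HUGGINGFACE")
        return hf_datasets.load_dataset, None
    raise ValueError(f"unknown data set: {data_set!r}")
```

With this shape, users who never select HUGGINGFACE can run the script without installing the extra package, and those who do get an error message naming the missing dependency.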
@@ -228,6 +228,9 @@ def get_args_parser():
                        help='use scale-aware embeds')
    parser.add_argument('--grid-to-random-ratio', default=0.7, type=float, help='hybrid sampler grid to random ratio')

    parser.add_argument('--model-type', default='deit', type=str, choices=["deit", "swin"])
Added a dedicated flag to decouple the DeiT arguments from those of "add-on" models such as Swin (and Pyramid ViT in the future), which have very different config shapes.
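For illustration, here is how the flag behaves on its own. This is a sketch that reproduces only the `--model-type` argument from the diff; the rest of the parser is omitted and `parse_model_type` is a hypothetical helper.

```python
import argparse


def get_args_parser():
    # Only the flag added in this PR; the real parser defines many more arguments.
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--model-type', default='deit', type=str,
                        choices=["deit", "swin"])
    return parser


def parse_model_type(argv):
    """Hypothetical helper: return the selected model type for a given argv list."""
    args = get_args_parser().parse_args(argv)
    # argparse exposes '--model-type' as the attribute 'model_type'.
    return args.model_type
```

Because `choices` is set, argparse rejects any other value at parse time (it exits with a usage error), so the dispatch code only ever sees "deit" or "swin".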
    elif args.model_type == "swin":
        from swin.config import get_config
        from swin.models import build_model
        class _Args:
            pass
        _args = _Args()
        cfg_path = "swin/swin_base_patch4_window7_224.yaml"
        setattr(_args, "cfg", cfg_path)
        setattr(_args, "opts", [])
        setattr(_args, "local_rank", 0)
        setattr(_args, "data_path", args.data_path)
        config = get_config(_args)
        model = build_model(config)
    else:
        assert False
Hardcoded loading of the Swin-B config. We will not need other Swin architectures, so for simplicity only this one is implemented.
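As a possible simplification, the ad hoc `_Args` class could be replaced by `types.SimpleNamespace` from the standard library. The sketch below sets the same four attributes as the diff; `make_swin_args` is a hypothetical name, and whether `get_config` needs further attributes depends on the ported `swin/config.py`, which is not shown here.

```python
from types import SimpleNamespace

# Path taken from the diff above.
SWIN_B_CFG = "swin/swin_base_patch4_window7_224.yaml"


def make_swin_args(data_path, cfg_path=SWIN_B_CFG):
    """Hypothetical helper: build the minimal args object passed to get_config."""
    return SimpleNamespace(
        cfg=cfg_path,        # YAML config for the Swin-B variant
        opts=[],             # no command-line overrides
        local_rank=0,
        data_path=data_path,
    )
```

Usage would then be `config = get_config(make_swin_args(args.data_path))`, assuming `get_config` only reads these attributes.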
run.sh
Outdated
./run.sh achieves 83.5% accuracy on ImageNet, which reproduces the original model.
swin/models/swin_transformer.py
Outdated
        assert keeps is None, "Swin transformer works only on merged patches with soft dropout. Set args --hard_dropout=0, --merge_patches=1 to proceed."
        x = self.patch_embed(x)
        if self.ape:
            if coords is not None:
                quantized = self.quantize_coords(coords)
                x = x + self.absolute_pos_embed[[0], quantized, :]
            else:
                x = x + self.absolute_pos_embed
        # ==========================================
This is the only modification to the original file.
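To make the indexing in `absolute_pos_embed[[0], quantized, :]` concrete, here is a small sketch that uses NumPy as a stand-in for PyTorch (this advanced-indexing case behaves the same way in both). The `quantize_coords` below is a hypothetical implementation, since the real method is not shown in this diff; it assumes coordinates are (y, x) pairs normalized to [0, 1) on a square grid.

```python
import numpy as np


def quantize_coords(coords, grid_size):
    """Hypothetical stand-in: map normalized (y, x) pairs to flat grid indices."""
    cells = np.clip((coords * grid_size).astype(np.int64), 0, grid_size - 1)
    return cells[..., 0] * grid_size + cells[..., 1]  # shape (B, N)


def gather_pos_embed(absolute_pos_embed, coords, grid_size):
    """Pick one positional embedding per token, as in the modified forward pass."""
    quantized = quantize_coords(coords, grid_size)
    # [0] and quantized broadcast together to shape (B, N); the trailing ':'
    # keeps the channel axis, so the result has shape (B, N, C).
    return absolute_pos_embed[[0], quantized, :]


grid_size, dim = 4, 8
# Shape (1, L, C) with L = grid_size ** 2, mirroring the learned table.
absolute_pos_embed = np.arange(grid_size * grid_size * dim, dtype=np.float32)
absolute_pos_embed = absolute_pos_embed.reshape(1, grid_size * grid_size, dim)

# Two samples, two tokens each, given as (y, x) in [0, 1).
coords = np.array([[[0.0, 0.0], [0.99, 0.99]],
                   [[0.5, 0.25], [0.0, 0.75]]])
gathered = gather_pos_embed(absolute_pos_embed, coords, grid_size)
```

The result can be added directly to `x` of shape (B, N, C), which is what the `coords is not None` branch does.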
No description provided.