Skip to content

swin ported. ./run.sh to reproduce #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open

swin ported. ./run.sh to reproduce #1

wants to merge 5 commits into from

Conversation

jano1906
Copy link
Collaborator

No description provided.

@jano1906 jano1906 requested a review from apardyl December 21, 2023 18:43
@@ -65,6 +64,7 @@ def build_dataset(is_train, data_path, args):
dataset = datasets.ImageFolder(root, transform=transform)
nb_classes = 1000
elif args.data_set == 'HUGGINGFACE':
from datasets import load_dataset
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved import here to reduce number of dependencies

@@ -228,6 +228,9 @@ def get_args_parser():
help='use scale-aware embeds')
parser.add_argument('--grid-to-random-ratio', default=0.7, type=float, help='hybrid sampler grid to random ratio')

parser.add_argument('--model-type', default='deit', type=str, choices=["deit", "swin"])
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added special flag to decouple arguments of deit from "add on" models like swin (pyramid vit in future) which have vastly different config shapes.

Comment on lines +376 to +390
elif args.model_type == "swin":
from swin.config import get_config
from swin.models import build_model
class _Args:
pass
_args = _Args()
cfg_path = "swin/swin_base_patch4_window7_224.yaml"
setattr(_args, "cfg", cfg_path)
setattr(_args, "opts", [])
setattr(_args, "local_rank", 0)
setattr(_args, "data_path", args.data_path)
config = get_config(_args)
model = build_model(config)
else:
assert False
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hardcoded loading of swin vit of size B. We will not need other swin architectures, for simplicity implemented only this one.

run.sh Outdated
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

./run.sh achevec 83.5 accuracy on imagenet - reproduces original model

Comment on lines 595 to 603
assert keeps is None, "Swin transformer works only on merged patches with soft dropout. Set args --hard_dropout=0, --merge_patches=1 to proceed."
x = self.patch_embed(x)
if self.ape:
if coords is not None:
quantized = self.quantize_coords(coords)
x = x + self.absolute_pos_embed[[0] ,quantized, :]
else:
x = x + self.absolute_pos_embed
# ==========================================
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only modification from original file.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants