swin ported. ./run.sh to reproduce #1
base: master
Changes from all commits: 8d397ce, bc9e800, 70cd52d, 64984f8, 52158f1
@@ -0,0 +1 @@
+DEBUG=False
@@ -228,6 +228,9 @@ def get_args_parser():
                         help='use scale-aware embeds')
     parser.add_argument('--grid-to-random-ratio', default=0.7, type=float, help='hybrid sampler grid to random ratio')
 
+    parser.add_argument('--model-type', default='deit', type=str, choices=["deit", "swin"])
Review comment: Added a special flag to decouple deit's arguments from "add-on" models like swin (and pyramid ViT in the future), which have vastly different config shapes. A sketch of the idea follows the hunk.
+
+
 
 
     return parser
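To make the decoupling concrete, here is a minimal sketch assuming only the standard library; the "deit-only" group name and the --drop-path flag are illustrative placeholders, not the actual options in main.py:

import argparse

def get_args_parser():
    parser = argparse.ArgumentParser('toy dispatch example')
    # Downstream code branches on this flag instead of threading every
    # deit-specific option through the swin path.
    parser.add_argument('--model-type', default='deit', type=str,
                        choices=['deit', 'swin'])
    # deit-only options live in their own group; the swin branch ignores
    # them and reads its configuration from a YAML file instead.
    deit = parser.add_argument_group('deit-only')
    deit.add_argument('--drop-path', default=0.1, type=float)  # placeholder flag
    return parser

args = get_args_parser().parse_args(['--model-type', 'swin'])
print(args.model_type)  # 'swin'; the deit group keeps its defaults and is ignored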
@@ -356,19 +359,35 @@ def log_to_wandb(log_dict, step):
         mixup_fn = None
 
     print(f"Creating model: {args.model}")
-    model = create_model(
-        args.model,
-        pretrained=False,
-        num_classes=args.nb_classes,
-        drop_rate=args.drop,
-        drop_path_rate=args.drop_path,
-        drop_block_rate=None,
-        img_size=args.input_size,
-        Patch_layer=PatchEmbedHybrid,
-        use_learned_pos_embed=args.use_learned_pos_embed,
-        quantize_pos_embed=args.quantize_pos_embed
-    )
+    if args.model_type == "deit":
+        model = create_model(
+            args.model,
+            pretrained=False,
+            num_classes=args.nb_classes,
+            drop_rate=args.drop,
+            drop_path_rate=args.drop_path,
+            drop_block_rate=None,
+            img_size=args.input_size,
+            Patch_layer=PatchEmbedHybrid,
+            use_learned_pos_embed=args.use_learned_pos_embed,
+            quantize_pos_embed=args.quantize_pos_embed
+        )
+    elif args.model_type == "swin":
+        from swin.config import get_config
+        from swin.models import build_model
+        class _Args:
+            pass
+        _args = _Args()
+        cfg_path = "swin/swin_base_patch4_window7_224.yaml"
+        setattr(_args, "cfg", cfg_path)
+        setattr(_args, "opts", [])
+        setattr(_args, "local_rank", 0)
+        setattr(_args, "data_path", args.data_path)
+        config = get_config(_args)
+        model = build_model(config)
+    else:
+        assert False
Review comment on lines +376 to +390: Hardcoded loading of the Swin ViT of size B. We will not need other Swin architectures, so for simplicity only this one is implemented. A sketch of an equivalent shim follows the hunk.
 
 
     if args.finetune:
         if args.finetune.startswith('https'):
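As a side note, the same shim can be built with argparse.Namespace from the standard library instead of a throwaway class; this is a sketch under the assumption that get_config reads only the four attributes set in the diff:

from argparse import Namespace

# Namespace accepts arbitrary keyword attributes, so no helper class or
# setattr calls are needed to build the shim object.
_args = Namespace(
    cfg="swin/swin_base_patch4_window7_224.yaml",  # hardcoded Swin-B config
    opts=[],         # no command-line overrides for the YAML config
    local_rank=0,    # single-process evaluation
    data_path=None,  # stands in for args.data_path from main.py
)
# config = get_config(_args); model = build_model(config)  # as in the diff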
@@ -0,0 +1,54 @@
+#!/bin/bash
+
+model_name=STD_SWINB
+model_path=models/swin_base_patch4_window7_224.pth
+input_size=224
+patch_size=16
+patch_shake=0
+patch_dropout=0
+patch_select=0
+patch_select_mode=drop
+patch_zoom=0
+
+
+#model_name=$1
+#model_path=$2
+#input_size=$3
+#patch_size=$4
+#patch_shake=$5
+#patch_dropout=$6
+#patch_select=$7
+#patch_select_mode=$8
+#patch_zoom=$9
+
+run_name="${model_name}_GRID${input_size}x${patch_size}_SHAKE${patch_shake}_DROP${patch_dropout}_SELECT${patch_select}_MODE${patch_select_mode}_ZOOM${patch_zoom}"
+echo $run_name
+if [[ $model_name == *"STD"* ]]; then
+    quantize_pos_embed=1
+    use_learned_pos_embed=1
+else
+    quantize_pos_embed=0
+    use_learned_pos_embed=0
+fi
+
+python main.py \
+    --model-type swin \
+    --no-log-to-wandb \
+    --hard_dropout 0 \
+    --merge_patches 1 \
+    --model _ \
+    --data-path ~/datasets/imagenet/ \
+    --batch 16 --input-size $input_size --eval-crop-ratio 1.0 --seed 0 \
+    --grid-patch-size $patch_size \
+    --num_workers 10 \
+    --eval --resume=$model_path \
+    --patch_shake=$patch_shake \
+    --patch_dropout=$patch_dropout \
+    --patch_select=$patch_select \
+    --patch_select_mode=$patch_select_mode \
+    --quantize_pos_embed=$quantize_pos_embed \
+    --use_learned_pos_embed=$use_learned_pos_embed \
+    --patch_zoom=$patch_zoom \
+    --modify_patches
+
+#--log-to-wandb --wandb_run_name=$run_name --wandb_model_name=$model_name \
@@ -1,18 +1,20 @@
 #!/bin/bash
 
-source scripts/sweep_grid_scale.sh
+set -e
 
-#source scripts/sweep_zoom_exp.sh
-#source scripts/sweep_dropout_exp.sh
-#source scripts/sweep_shake_exp.sh
+#source scripts/sweep_grid_scale.sh
+
+source scripts/sweep_zoom_exp.sh
+source scripts/sweep_dropout_exp.sh
+source scripts/sweep_shake_exp.sh
 
 source scripts/sweep_zoom_shake_exp.sh
 source scripts/sweep_zoom_dropout_exp.sh
-#source scripts/sweep_dropout_shake_exp.sh
+source scripts/sweep_dropout_shake_exp.sh
 #source random
 
 #source edge counting
 #source scripts/sweep_central_sampling_exp.sh
 
 #source 30/70, 70/30 vits
-#source scripts/sweep_dropout_vs_rescale_exp.sh
+source scripts/sweep_dropout_vs_rescale_exp.sh
Review comment: Moved the import here to reduce the number of dependencies; a minimal sketch of the pattern follows.
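For reference, a runnable sketch of the deferred-import idea behind that comment; the helper below is hypothetical, and the point is that the deit path never pays the import cost of the swin package:

import importlib.util

def swin_available() -> bool:
    # find_spec consults the import machinery without executing the
    # package, so this check is cheap and side-effect free.
    return importlib.util.find_spec("swin") is not None

if __name__ == "__main__":
    # Prints False on a machine without the swin/ package; either way,
    # nothing from swin (or its dependency chain) is actually imported,
    # mirroring the function-local import in the diff above.
    print("swin importable:", swin_available())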