diff --git a/eval_model.py b/eval_model.py index ba3842b..6f10ac9 100644 --- a/eval_model.py +++ b/eval_model.py @@ -15,17 +15,17 @@ def get_args_parser(): parser = argparse.ArgumentParser('Singleto3D', add_help=False) parser.add_argument('--arch', default='resnet18', type=str) - parser.add_argument('--max_iter', default=10000, type=str) - parser.add_argument('--vis_freq', default=1000, type=str) - parser.add_argument('--batch_size', default=1, type=str) - parser.add_argument('--num_workers', default=0, type=str) + parser.add_argument('--max_iter', default=10000, type=int) + parser.add_argument('--vis_freq', default=1000, type=int) + parser.add_argument('--batch_size', default=1, type=int) + parser.add_argument('--num_workers', default=0, type=int) parser.add_argument('--type', default='vox', choices=['vox', 'point', 'mesh'], type=str) parser.add_argument('--n_points', default=5000, type=int) parser.add_argument('--w_chamfer', default=1.0, type=float) - parser.add_argument('--w_smooth', default=0.1, type=float) - parser.add_argument('--load_checkpoint', action='store_true') - parser.add_argument('--device', default='cuda', type=str) - parser.add_argument('--load_feat', action='store_true') + parser.add_argument('--w_smooth', default=0.1, type=float) + parser.add_argument('--load_checkpoint', action='store_true') + parser.add_argument('--device', default='cuda', type=str) + parser.add_argument('--load_feat', action='store_true') return parser def preprocess(feed_dict, args): diff --git a/fit_data.py b/fit_data.py index 9d8ad44..6f241a0 100644 --- a/fit_data.py +++ b/fit_data.py @@ -31,7 +31,7 @@ def fit_mesh(mesh_src, mesh_tgt, args): start_iter = 0 start_time = time.time() - deform_vertices_src = torch.zeros(mesh_src.verts_packed().shape, requires_grad=True, device='cuda') + deform_vertices_src = torch.zeros(mesh_src.verts_packed().shape, requires_grad=True, device=args.device) optimizer = torch.optim.Adam([deform_vertices_src], lr = args.lr) print("Starting training !") for step in range(start_iter, args.max_iter): @@ -56,8 +56,9 @@ def fit_mesh(mesh_src, mesh_tgt, args): loss_vis = loss.cpu().item() - print("[%4d/%4d]; ttime: %.0f (%.2f); loss: %.3f" % (step, args.max_iter, total_time, iter_time, loss_vis)) - + if (step % args.log_freq) == 0: + print("[%4d/%4d]; ttime: %.0f (%.2f); loss: %.3f" % (step, args.max_iter, total_time, iter_time, loss_vis)) + mesh_src.offset_verts_(deform_vertices_src) print('Done!') @@ -81,8 +82,9 @@ def fit_pointcloud(pointclouds_src, pointclouds_tgt, args): loss_vis = loss.cpu().item() - print("[%4d/%4d]; ttime: %.0f (%.2f); loss: %.3f" % (step, args.max_iter, total_time, iter_time, loss_vis)) - + if (step % args.log_freq) == 0: + print("[%4d/%4d]; ttime: %.0f (%.2f); loss: %.3f" % (step, args.max_iter, total_time, iter_time, loss_vis)) + print('Done!') @@ -104,8 +106,9 @@ def fit_voxel(voxels_src, voxels_tgt, args): loss_vis = loss.cpu().item() - print("[%4d/%4d]; ttime: %.0f (%.2f); loss: %.3f" % (step, args.max_iter, total_time, iter_time, loss_vis)) - + if (step % args.log_freq) == 0: + print("[%4d/%4d]; ttime: %.0f (%.2f); loss: %.3f" % (step, args.max_iter, total_time, iter_time, loss_vis)) + print('Done!') diff --git a/train_model.py b/train_model.py index cc86c4e..bb1b7b6 100644 --- a/train_model.py +++ b/train_model.py @@ -13,19 +13,19 @@ def get_args_parser(): parser = argparse.ArgumentParser('Singleto3D', add_help=False) # Model parameters parser.add_argument('--arch', default='resnet18', type=str) - parser.add_argument('--lr', default=4e-4, type=str) - parser.add_argument('--max_iter', default=10000, type=str) - parser.add_argument('--log_freq', default=1000, type=str) - parser.add_argument('--batch_size', default=2, type=str) - parser.add_argument('--num_workers', default=0, type=str) + parser.add_argument('--lr', default=4e-4, type=float) + parser.add_argument('--max_iter', default=10000, type=int) + parser.add_argument('--log_freq', default=1000, type=int) + parser.add_argument('--batch_size', default=2, type=int) + parser.add_argument('--num_workers', default=0, type=int) parser.add_argument('--type', default='vox', choices=['vox', 'point', 'mesh'], type=str) parser.add_argument('--n_points', default=5000, type=int) parser.add_argument('--w_chamfer', default=1.0, type=float) parser.add_argument('--w_smooth', default=0.1, type=float) - parser.add_argument('--save_freq', default=10000, type=int) - parser.add_argument('--device', default='cuda', type=str) - parser.add_argument('--load_feat', action='store_true') - parser.add_argument('--load_checkpoint', action='store_true') + parser.add_argument('--save_freq', default=10000, type=int) + parser.add_argument('--device', default='cuda', type=str) + parser.add_argument('--load_feat', action='store_true') + parser.add_argument('--load_checkpoint', action='store_true') return parser def preprocess(feed_dict,args): @@ -126,7 +126,8 @@ def train_model(args): 'optimizer_state_dict': optimizer.state_dict() }, f'checkpoint_{args.type}.pth') - print("[%4d/%4d]; ttime: %.0f (%.2f, %.2f); loss: %.3f" % (step, args.max_iter, total_time, read_time, iter_time, loss_vis)) + if (step % args.log_freq) == 0: + print("[%4d/%4d]; ttime: %.0f (%.2f, %.2f); loss: %.3f" % (step, args.max_iter, total_time, read_time, iter_time, loss_vis)) print('Done!')