
Commit bfe137c
Fix bugs found from more testing
1 parent e49e2f4

5 files changed, +38 -32 lines changed

README.md

Lines changed: 3 additions & 1 deletion
@@ -54,11 +54,13 @@ $ python main.py
 | Training batch size | 128 | --batch-size 128 |
 | Testing batch size | 128 | --test-batch-size 128 |
 | Loss threshold | 0.001 | --loss-threshold 0.001 |
-| Log interval | 1 | --log-interval 1 |
+| Log interval | 10 | --log-interval 10 |
+| Disables CUDA training | false | --no-cuda |
 | Num. of convolutional channel | 256 | --num-conv-channel 256 |
 | Num. of primary unit | 8 | --num-primary-unit 8 |
 | Primary unit size | 1152 | --primary-unit-size 1152 |
 | Output unit size | 16 | --output-unit-size 16 |
+| Num. routing iteration | 3 | --num-routing 3 |
 
 ## Results
 Coming soon!
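For reference, a run that sets the changed options explicitly would look like the line below (a hypothetical invocation assembled from the table above; every flag falls back to its documented default when omitted):

    $ python main.py --log-interval 10 --no-cuda --num-routing 3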

capsule_layer.py

Lines changed: 9 additions & 5 deletions
@@ -17,14 +17,17 @@ class CapsuleLayer(nn.Module):
     """
     The core implementation of the idea of capsules
     """
-    def __init__(self, in_unit, in_channel, num_unit, unit_size, use_routing, cuda):
+
+    def __init__(self, in_unit, in_channel, num_unit, unit_size, use_routing,
+                 num_routing, cuda_enabled):
         super(CapsuleLayer, self).__init__()
 
         self.in_unit = in_unit
         self.in_channel = in_channel
         self.num_unit = num_unit
         self.use_routing = use_routing
-        self.cuda = cuda
+        self.num_routing = num_routing
+        self.cuda_enabled = cuda_enabled
 
         if self.use_routing:
             """
@@ -50,7 +53,8 @@ def create_conv_unit(idx):
                 self.add_module("conv_unit" + str(idx), unit)
                 return unit
 
-            self.conv_units = [create_conv_unit(u) for u in range(self.num_unit)]
+            self.conv_units = [create_conv_unit(
+                u) for u in range(self.num_unit)]
 
     @staticmethod
     def squash(sj):
@@ -88,12 +92,12 @@ def routing(self, x):
         # All the routing logits (b_ij in the paper) are initialized to zero.
         b_ij = Variable(torch.zeros(
             1, self.in_channel, self.num_unit, 1))
-        if self.cuda:
+        if self.cuda_enabled:
             b_ij = b_ij.cuda()
 
         # From the paper in the "Capsules on MNIST" section,
         # the sample MNIST test reconstructions of a CapsNet with 3 routing iterations.
-        num_iterations = 3
+        num_iterations = self.num_routing
 
         for iteration in range(num_iterations):
             # Routing algorithm
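The change above threads a configurable num_routing into the layer instead of the previously hard-coded 3 iterations. As a rough illustration of where that value acts, here is a minimal, self-contained sketch of the routing-by-agreement loop; the tensor layout and the names routing_sketch and u_hat are illustrative assumptions, not the repo's exact implementation:

    import torch
    import torch.nn.functional as F

    def squash(s, dim=-1):
        # v = (|s|^2 / (1 + |s|^2)) * (s / |s|), the capsule non-linearity.
        sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
        return (sq_norm / (1.0 + sq_norm)) * s / torch.sqrt(sq_norm + 1e-9)

    def routing_sketch(u_hat, num_routing, cuda_enabled):
        # u_hat: [batch, in_channel, num_unit, unit_size] (illustrative layout).
        b_ij = torch.zeros(1, u_hat.size(1), u_hat.size(2), 1)
        if cuda_enabled:
            b_ij = b_ij.cuda()
        for _ in range(num_routing):  # was a fixed 3 before this commit
            c_ij = F.softmax(b_ij, dim=2)                  # coupling coefficients
            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)  # weighted sum over inputs
            v_j = squash(s_j)                              # squashed capsule output
            b_ij = b_ij + (u_hat * v_j).sum(dim=-1, keepdim=True)  # agreement update
        return v_j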

main.py

Lines changed: 15 additions & 19 deletions
@@ -13,8 +13,6 @@
 
 from __future__ import print_function
 import argparse
-import sys
-import time
 
 import torch
 import torch.optim as optim
@@ -57,8 +55,7 @@ def train(model, data_loader, optimizer, epoch):
         optimizer.step()
 
         if batch_idx % args.log_interval == 0:
-            mesg = '{}\tEpoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                time.ctime(),
+            mesg = 'Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                 epoch,
                 batch_idx * len(data),
                 len(data_loader.dataset),
@@ -87,7 +84,7 @@ def test(model, data_loader):
     for data, target in data_loader:
         target_indices = target
         target_one_hot = utils.one_hot_encode(
-            target_indices, length=model.digits.num_units)
+            target_indices, length=model.digits.num_unit)
 
         data, target = Variable(data, volatile=True), Variable(target_one_hot)
 
@@ -133,12 +130,12 @@ def main():
                         default=128, help='testing batch size. default=128')
     parser.add_argument('--loss-threshold', type=float, default=0.0001,
                         help='stop training if loss goes below this threshold. default=0.0001')
-    parser.add_argument("--log-interval", type=int, default=1,
-                        help='number of images after which the training loss is logged, default is 1')
-    parser.add_argument('--cuda', action='store_true',
-                        help='set it to 1 for running on GPU, 0 for CPU')
+    parser.add_argument('--log-interval', type=int, default=10,
+                        help='how many batches to wait before logging training status, default=10')
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training, default=false')
     parser.add_argument('--threads', type=int, default=4,
-                        help='number of threads for data loader to use')
+                        help='number of threads for data loader to use, default=4')
     parser.add_argument('--seed', type=int, default=42,
                         help='random seed for training. default=42')
     parser.add_argument('--num-conv-channel', type=int, default=256,
@@ -149,20 +146,18 @@ def main():
                         default=1152, help='primary unit size. default=1152')
     parser.add_argument('--output-unit-size', type=int,
                         default=16, help='output unit size. default=16')
+    parser.add_argument('--num-routing', type=int,
+                        default=3, help='number of routing iteration. default=3')
 
     args = parser.parse_args()
 
     print(args)
 
     # Check GPU or CUDA is available
-    cuda = args.cuda
-    if cuda and not torch.cuda.is_available():
-        print(
-            "ERROR: No GPU/cuda is not available. Try running on CPU or run without --cuda")
-        sys.exit(1)
+    args.cuda = not args.no_cuda and torch.cuda.is_available()
 
     torch.manual_seed(args.seed)
-    if cuda:
+    if args.cuda:
         torch.cuda.manual_seed(args.seed)
 
     # Load data
@@ -174,10 +169,11 @@ def main():
                 num_primary_unit=args.num_primary_unit,
                 primary_unit_size=args.primary_unit_size,
                 output_unit_size=args.output_unit_size,
-                cuda=args.cuda)
+                num_routing=args.num_routing,
+                cuda_enabled=args.cuda)
 
-    if cuda:
-        model = model.cuda()
+    if args.cuda:
+        model.cuda()
 
     optimizer = optim.Adam(model.parameters(), lr=args.lr)
 
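Replacing the hard --cuda check with args.cuda = not args.no_cuda and torch.cuda.is_available() follows the common PyTorch-examples pattern: the user can only opt out, availability is probed automatically, and the script no longer needs to sys.exit() on CPU-only machines. A minimal standalone sketch of just that pattern (argument list trimmed to the relevant flag):

    import argparse

    import torch

    parser = argparse.ArgumentParser(description='no-cuda pattern sketch')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training, default=false')
    args = parser.parse_args()

    # Use CUDA only when the user did not opt out AND a GPU is actually present.
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    print('Training on {}'.format('GPU' if args.cuda else 'CPU'))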

model.py

Lines changed: 9 additions & 5 deletions
@@ -19,14 +19,16 @@ class Net(nn.Module):
     """
     A simple CapsNet with 3 layers
     """
-    def __init__(self, num_conv_channel, num_primary_unit, primary_unit_size, output_unit_size, cuda):
+
+    def __init__(self, num_conv_channel, num_primary_unit, primary_unit_size,
+                 output_unit_size, num_routing, cuda_enabled):
         """
         In the constructor we instantiate one ConvLayer module and two CapsuleLayer modules
         and assign them as member variables.
         """
         super(Net, self).__init__()
 
-        self.cuda = cuda
+        self.cuda_enabled = cuda_enabled
 
         self.conv1 = ConvLayer(in_channel=1,
                                out_channel=num_conv_channel,
@@ -38,15 +40,17 @@ def __init__(self, num_conv_channel, num_primary_unit, primary_unit_size, output
                                     num_unit=num_primary_unit,
                                     unit_size=primary_unit_size,
                                     use_routing=False,
-                                    cuda=cuda)
+                                    num_routing=num_routing,
+                                    cuda_enabled=cuda_enabled)
 
         # DigitCaps
         self.digits = CapsuleLayer(in_unit=num_primary_unit,
                                    in_channel=primary_unit_size,
                                    num_unit=10,
                                    unit_size=output_unit_size,
                                    use_routing=True,
-                                   cuda=cuda)
+                                   num_routing=num_routing,
+                                   cuda_enabled=cuda_enabled)
 
     def forward(self, x):
         """
@@ -74,7 +78,7 @@ def margin_loss(self, input, target, size_average=True):
 
         # Calculate left and right max() terms.
         zero = Variable(torch.zeros(1))
-        if self.cuda:
+        if self.cuda_enabled:
             zero = zero.cuda()
         m_plus = 0.9
         m_minus = 0.1
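The margin_loss hunk only renames the flag, but it marks the one place where the zero constant must live on the same device as the activations. For context, the loss it computes is the CapsNet margin loss L_k = T_k * max(0, m+ - |v_k|)^2 + lambda * (1 - T_k) * max(0, |v_k| - m-)^2 with m+ = 0.9 and m- = 0.1. A minimal sketch under the assumption that v_norm holds per-class capsule lengths and target is one-hot (names are illustrative, not the repo's):

    import torch

    def margin_loss_sketch(v_norm, target, m_plus=0.9, m_minus=0.1, lam=0.5):
        # v_norm: [batch, num_classes] capsule lengths; target: one-hot, same shape.
        zero = torch.zeros(1)
        if v_norm.is_cuda:
            zero = zero.cuda()  # keep the constant on the same device
        left = torch.max(zero, m_plus - v_norm) ** 2    # penalty when the true class is weak
        right = torch.max(zero, v_norm - m_minus) ** 2  # penalty when a wrong class is strong
        loss = target * left + lam * (1.0 - target) * right
        return loss.sum(dim=1).mean()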

utils.py

Lines changed: 2 additions & 2 deletions
@@ -46,13 +46,13 @@ def load_mnist(args):
 
     print('===> Loading training datasets')
     training_set = datasets.MNIST(
-        '../data', train=True, download=True, transform=data_transform)
+        './data', train=True, download=True, transform=data_transform)
     training_data_loader = DataLoader(
         training_set, batch_size=args.batch_size, shuffle=True, **kwargs)
 
     print('===> Loading testing datasets')
     testing_set = datasets.MNIST(
-        '../data', train=False, download=True, transform=data_transform)
+        './data', train=False, download=True, transform=data_transform)
     testing_data_loader = DataLoader(
         testing_set, batch_size=args.test_batch_size, shuffle=True, **kwargs)
 
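The path change simply anchors the MNIST download directory to the working directory instead of its parent. A self-contained sketch of the loading pattern, with an assumed data_transform (the repo's exact transform is not shown in this diff):

    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),  # common MNIST mean/std
    ])

    # './data' resolves relative to the current working directory, so running
    # from the repo root downloads into ./data on first use.
    training_set = datasets.MNIST(
        './data', train=True, download=True, transform=data_transform)
    training_data_loader = DataLoader(training_set, batch_size=128, shuffle=True)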
