diff --git a/image/resnet/auto_parallel/auto_parallel_demo.py b/image/resnet/auto_parallel/auto_parallel_demo.py
index 429a99e..288430b 100644
--- a/image/resnet/auto_parallel/auto_parallel_demo.py
+++ b/image/resnet/auto_parallel/auto_parallel_demo.py
@@ -16,19 +16,17 @@ from titans.utils import barrier_context
 from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
 from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
-from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions, DataloaderOption
 from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
 from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
 
 DATA_ROOT = Path(os.environ.get('DATA', './data'))
-BATCH_SIZE = 1024
-NUM_EPOCHS = 10
 
 
 def main():
-    colossalai.launch_from_torch(config={})
+    colossalai.launch_from_torch(config='./config.py')
 
     logger = get_dist_logger()
 
@@ -52,16 +50,16 @@ def main():
 
     train_dataloader = get_dataloader(
         dataset=train_dataset,
-        add_sampler=False,
+        add_sampler=True,
         shuffle=True,
-        batch_size=BATCH_SIZE,
+        batch_size=gpc.config.BATCH_SIZE,
         pin_memory=True,
     )
 
     test_dataloader = get_dataloader(
         dataset=test_dataset,
         add_sampler=False,
-        batch_size=BATCH_SIZE,
+        batch_size=gpc.config.BATCH_SIZE,
         pin_memory=True,
     )
 
@@ -73,13 +71,13 @@ def main():
     # trace the model with meta data
     tracer = ColoTracer()
     model = resnet50(num_classes=10).cuda()
-    input_sample = {'x': torch.rand([1024, 3, 32, 32]).to('meta')}
+    input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
     graph = tracer.trace(root=model, meta_args=input_sample)
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
 
     # prepare info for solver
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED)
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     cost_graph = CostGraph(strategies_constructor.leaf_strategies)
@@ -106,9 +104,9 @@ def main():
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
 
     # lr_scheduler
-    lr_scheduler = CosineAnnealingLR(optimizer, total_steps=NUM_EPOCHS)
+    lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)
 
-    for epoch in range(NUM_EPOCHS):
+    for epoch in range(gpc.config.NUM_EPOCHS):
         gm.train()
         if gpc.get_global_rank() == 0:
             train_dl = tqdm(train_dataloader)
@@ -121,6 +119,7 @@ def main():
             output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
             train_loss = criterion(output, label)
-            train_loss.backward(train_loss)
+            train_loss.backward()
+            torch.cuda.synchronize()
             optimizer.step()
             lr_scheduler.step()
 
diff --git a/image/resnet/auto_parallel/config.py b/image/resnet/auto_parallel/config.py
new file mode 100644
index 0000000..feaef04
--- /dev/null
+++ b/image/resnet/auto_parallel/config.py
@@ -0,0 +1,2 @@
+BATCH_SIZE = 128
+NUM_EPOCHS = 10
\ No newline at end of file
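
For reference, a minimal sketch of how the new config.py is consumed at runtime, assuming the legacy gpc-based launch API this example already uses; the 4-GPU launch command is illustrative, not part of the change:

    # Sketch: launch_from_torch(config='./config.py') loads the config file
    # into the global parallel context, so every rank reads the same values
    # through gpc.config after launch.
    import colossalai
    from colossalai.core import global_context as gpc

    colossalai.launch_from_torch(config='./config.py')
    print(gpc.config.BATCH_SIZE, gpc.config.NUM_EPOCHS)  # 128 10

    # Launched under torchrun, e.g.:
    #   torchrun --nproc_per_node=4 auto_parallel_demo.py

Note that with add_sampler=True each rank now draws BATCH_SIZE samples per step through a distributed sampler, which is why the meta input traced above uses gpc.config.BATCH_SIZE * torch.distributed.get_world_size() as the global batch dimension.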