diff --git a/text2image.py b/text2image.py index 0e79fcc..a2163c3 100644 --- a/text2image.py +++ b/text2image.py @@ -46,6 +46,10 @@ "--steps", type=int, default=50, help="number of ddim sampling steps" ) +parser.add_argument( + "--batch", type=int, default=1, help="number of images to generate" +) + parser.add_argument( "--seed", type=int, @@ -71,8 +75,20 @@ num_steps=args.steps, unconditional_guidance_scale=args.scale, temperature=1, - batch_size=1, + batch_size=args.batch, seed=args.seed, ) -Image.fromarray(img[0]).save(args.output) -print(f"saved at {args.output}") + +if(args.batch <= 1): + Image.fromarray(img[0]).save(args.output) + print(f"saved at {args.output}") +else: + split_filename = args.output.split(".") + filename = ''.join(split_filename[0:-1]) + extension = split_filename[-1] + + for i in range(args.batch): + generated_filename = f"{filename}-{i+1}.{extension}" + Image.fromarray(img[i]).save(generated_filename) + print(f"saved at {generated_filename}") +