diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..1d66aaf
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,18 @@
+FROM tensorflow/tensorflow:latest-gpu
+
+ENV DEBIAN_FRONTEND noninteractive
+
+WORKDIR /app
+
+COPY requirements.txt /app
+
+RUN pip install --prefer-binary --no-cache-dir -q -r requirements.txt && \
+    rm -rf ~/.cache
+
+COPY . /app/
+
+VOLUME /root/.keras
+
+EXPOSE 7860
+
+CMD ["python", "webui.py"]
diff --git a/README.md b/README.md
index 31af1c6..d791c10 100644
--- a/README.md
+++ b/README.md
@@ -139,6 +139,12 @@ If you want to use a different name, use the `--output` flag.
 
 Check out the `img2img.py` file for more options, including the number of steps.
 
+### Using the WebUI
+
+```bash
+python webui.py
+```
+
 ## Example outputs
 
 The following outputs have been generated using this implementation:
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 0000000..1ce8dcb
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,22 @@
+version: '3'
+
+services:
+  stable-diffusion-tf:
+    image: stable-diffusion-tf
+    build: .
+    volumes:
+      - keras:/root/.keras
+    ports:
+      - "7860:7860"
+    command: python3 webui.py
+    restart: unless-stopped
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              device_ids: ['0']
+              capabilities: [gpu]
+
+volumes:
+  keras:
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 1590124..891e542 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,5 @@ Pillow==9.2.0
 tqdm==4.64.1
 ftfy==6.1.1
 regex==2022.9.13
+gradio==3.3.1
 tensorflow-addons==0.17.1
diff --git a/webui.py b/webui.py
new file mode 100644
index 0000000..b496d6e
--- /dev/null
+++ b/webui.py
@@ -0,0 +1,60 @@
+import gradio as gr
+from stable_diffusion_tf.stable_diffusion import StableDiffusion
+
+
+generator = StableDiffusion(img_height=512, img_width=512, jit_compile=False)
+
+
+def infer(prompt, samples, steps, scale, seed):
+    return generator.generate(
+        prompt,
+        num_steps=steps,
+        unconditional_guidance_scale=scale,
+        temperature=1,
+        batch_size=samples,
+        seed=seed,
+    )
+
+
+block = gr.Blocks()
+
+with block:
+    with gr.Group():
+        with gr.Box():
+            with gr.Row().style(equal_height=True):
+                text = gr.Textbox(
+                    label="Enter your prompt",
+                    show_label=False,
+                    max_lines=1,
+                    placeholder="Enter your prompt",
+                ).style(
+                    border=(True, False, True, True),
+                    rounded=(True, False, False, True),
+                    container=False,
+                )
+                btn = gr.Button("Generate image").style(
+                    margin=False,
+                    rounded=(False, True, True, False),
+                )
+        gallery = gr.Gallery(
+            label="Generated images", show_label=False, elem_id="gallery"
+        ).style(grid=[2], height="auto")
+
+    advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
+
+    with gr.Row(elem_id="advanced-options"):
+        samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
+        steps = gr.Slider(label="Steps", minimum=1, maximum=200, value=50, step=1)
+        scale = gr.Slider(label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1)
+        seed = gr.Slider(
+            label="Seed",
+            minimum=0,
+            maximum=2147483647,
+            step=1,
+            randomize=True
+        )
+    text.submit(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
+    btn.click(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
+    advanced_button.click(None, [], text)
+
+block.launch(server_name='0.0.0.0')
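For reference, a minimal sketch of how the containerized WebUI added by this patch could be run, assuming Docker (with the NVIDIA Container Toolkit for GPU access) is available on the host; the named `keras` volume keeps downloaded model weights between runs:

```bash
# Build the image and start the WebUI via Compose
# (GPU device 0 is reserved in docker-compose.yml)
docker compose up --build -d

# Or build and run the image directly
docker build -t stable-diffusion-tf .
docker run --gpus all -p 7860:7860 -v keras:/root/.keras stable-diffusion-tf

# The Gradio interface is then reachable at http://localhost:7860
```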