
Diffusion Plugin #421


Open · wants to merge 41 commits into main

Conversation

@Sourenm (Contributor) commented Jul 8, 2025

No description provided.


codecov bot commented Jul 8, 2025

Codecov Report

Attention: Patch coverage is 22.66667% with 928 lines in your changes missing coverage. Please review.

Files with missing lines                          Patch %   Lines
transformerlab/plugins/image_diffusion/main.py    1.58%     620 Missing ⚠️
transformerlab/routers/experiment/diffusion.py    50.87%    226 Missing and 26 partials ⚠️
transformerlab/shared/shared.py                   0.00%     56 Missing ⚠️


@Sourenm Sourenm marked this pull request as ready for review July 11, 2025 17:31
api.py Outdated
recipes,
users,
)

from transformerlab.routers.experiment import diffusion
Contributor:

We don't import it like this. Look at routers/experiment/experiment.py to see how other routes under experiment are imported, and follow that structure.

@@ -452,6 +452,10 @@ def set_job_completion_status(
self.add_to_job_data("completion_status", completion_status)
self.add_to_job_data("completion_details", completion_details)

# Update the job status field if there's a failure
if completion_status == "failed":
Contributor:

Is this a duplicate change? Maybe remove it, since we do the same thing on the next line.

Contributor Author:

The duplicate is removed.

self.params[key] = arg
key = None

def save_generated_images(self, metadata: dict = None, suffix: str = None):
Contributor:

Not sure what this is being used for. If it's a dataset, this is valid, but if you're trying to save a generated image with it, this might be wrong.

Contributor Author:

I removed this function and the get_output_file_path function but kept diffusion.py since it might come in handy for future diffusion plugins

def get_output_file_path(self, suffix="", dir_only=False):
self._ensure_args_parsed()

workspace_dir = os.environ.get("_TFL_WORKSPACE_DIR", "./")
Contributor:

Instead of this could you import WORKSPACE_DIR from transformerlab.plugin and use that?

Contributor Author:

solved (comment above)

print(f"[DIFFUSION] Metadata saved to: {metadata_file}")
return metadata_file

def get_output_file_path(self, suffix="", dir_only=False):
Contributor:

I understand why we add this here, but if it is exactly the same except for "diffusion" vs. "datasets", it would be wiser to move it out to the base TLabPlugin class and add a parameter so GenTLab and DiffusionTLab can both use it.

Contributor Author:

solved (comment above)
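The suggested refactor could look roughly like this (a sketch: the class names follow the review comment, but the attribute name and method body are illustrative assumptions, not the actual TransformerLab code):

```python
import os

class TLabPlugin:
    # Subclasses override this; only the subdirectory name differs between
    # the generation and diffusion variants of get_output_file_path.
    output_subdir = "datasets"

    def get_output_file_path(self, workspace_dir: str, suffix: str = "", dir_only: bool = False) -> str:
        out_dir = os.path.join(workspace_dir, self.output_subdir)
        if dir_only:
            return out_dir
        return os.path.join(out_dir, f"output{suffix}")

class GenTLab(TLabPlugin):
    pass  # inherits the "datasets" default

class DiffusionTLab(TLabPlugin):
    output_subdir = "diffusion"  # only the subdirectory differs
```

One shared implementation, parameterized by a class attribute instead of copy-pasted per subclass.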

@@ -720,7 +483,11 @@ def get_pipeline(
# Load LoRA adaptor if provided - same code for local and HF Hub!
if adaptor and adaptor.strip():
try:
adaptor_dir = os.path.join(os.environ.get("_TFL_WORKSPACE_DIR"), "adaptors", secure_filename(model))
adaptor_dir = os.path.join(
os.environ.get("_TFL_WORKSPACE_DIR"),
Contributor:

This should be imported from transformerlab.plugin as an earlier comment

Contributor Author:

Done

@@ -1097,7 +1160,22 @@ def run_pipe():
generation_start = time.time()
log_print("Starting image generation...")

progress_queue = queue.Queue()

async def monitor_progress():
Contributor:

What is this used for? I thought we do the progress update inside the callback.

Contributor Author:

removed


def parse_args():
Contributor:

We should not be parsing any args inside a plugin.

Contributor Author:

done

parser.add_argument("--plugin_dir", type=str, required=True)
parser.add_argument("--job_id", type=str, required=True)
parser.add_argument("--experiment_name", type=str, required=True)
parser.add_argument("--run_name", type=str, default="diffused")
Contributor:

I guess this is stale code and you need to remove it?

Contributor Author:

removed

config = job_config.get("config", {})

# Convert base64 images to files and update config
base64_fields = {
Contributor:

Why do we need to add this to config? Why can't the plugin fetch them directly?

Contributor Author:

We need to save these images to files because, if the input/mask image is passed as a huge base64 string via CLI args, it hits a wall: Linux has a system-imposed limit on command-line length (typically around 128 KB to 256 KB total across all args and env vars).
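A minimal sketch of that workaround (the helper name and default suffix are assumptions): decode the base64 payload to a temp file once, then pass only the short file path on the command line instead of the full image data.

```python
import base64
import os
import tempfile

def base64_to_file(b64_data: str, suffix: str = ".png") -> str:
    """Write a base64 payload to a temp file and return its path, so the
    CLI only carries a short path instead of megabytes of image data."""
    fd, path = tempfile.mkstemp(suffix=suffix)
    with os.fdopen(fd, "wb") as f:
        f.write(base64.b64decode(b64_data))
    return path  # e.g. pass this via a hypothetical --input_image_path flag
```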

return {"controlnets": models}
def main():
print("Starting image diffusion...", flush=True)
asyncio.run(async_main())
Contributor:

Why do we need to do this? And if you do need it, look at how we have an async_job_wrapper.

Contributor Author:

modified

@tlab_diffusion.async_job_wrapper(progress_start=0, progress_end=100)
async def diffusion_generate_job():
job_config = {
"plugin": tlab_diffusion.params.get("plugin", "image_diffusion"),
Contributor:

We probably don't need this?

Contributor Author:

removed


@tlab_diffusion.async_job_wrapper(progress_start=0, progress_end=100)
async def diffusion_generate_job():
job_config = {
Contributor:

Why aren't we using tlab_diffusion.params directly instead of maintaining this?

Contributor Author:

Done

"mask_image_path": "mask_image",
}

for path_key, base64_key in image_path_keys.items():
Contributor:

Just to clarify on the previous comment: this converts the images back to base64 for processing by diffusers, right?

Contributor Author:

Yes that's correct

Contributor:

Just a note: we left this blank for now because all dependencies are installed by the original requirements.

import asyncio
import threading
import gc
from transformerlab.plugin import WORKSPACE_DIR
Contributor:

Maybe we standardize on using this throughout the file?

Contributor Author:

Done

self._parser.add_argument("--diffusion_model", default="local", type=str, help="Diffusion model to use")
self._parser.add_argument("--model", type=str, default="")
self._parser.add_argument(
"--diffusion_type", default="txt2img", type=str, help="Type of diffusion task (txt2img, img2img, etc.)"
Contributor:

Do we use this param somewhere?

Contributor Author:

Removed

Contributor:

Adding it here because it's probably not included in the diff: now that we include diffusion_worker in the plugin, we should not keep a copy in the shared folder.

except Exception as e:
await db_jobs.job_update_status(job_id, "FAILED")
print(f"[DIFFUSION] Job {job_id} execution error: {e}")
return {"status": "error", "job_id": job_id, "message": "Diffusion job failed"}
Contributor:

For error reporting: there are cases, e.g. using an SDXL model with a height and width that are not multiples of 8, where this will error out, and we want to show a correct error explaining why. The current implementation only prints the error, but maybe we could send the HTTPResponse for the error somehow and surface it via notifications?

Additionally, we will also need to stream output from this job, as is done for other jobs. Maybe we include a button on the progress view to see the raw output of the generation, using the same ViewOutputStreamingModal on the app side?

Contributor:

Either we do the HTTPResponse approach, or we reuse how errors are currently reported on /generate in the except block and also surface the job's errors there when the job has failed status.

Contributor Author:

Done


router = APIRouter(prefix="/diffusion", tags=["diffusion"])

UNIVERSIAL_GENERATION_ID = str(uuid.uuid4())
Contributor:

I don't know how we ended up using this, but it should not be used here.
This is why we have the /generate_id route, which generates a generation_id beforehand; you either use that or the one provided with /generate, but UNIVERSIAL_GENERATION_ID should not be used.
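The underlying issue is that a module-level constant like this is evaluated once at import time, so every request would share the same ID. A per-request helper avoids that (a sketch; /generate_id's real implementation may differ):

```python
import uuid

def new_generation_id() -> str:
    # Called once per request, so each generation gets its own ID, unlike a
    # module-level constant that is fixed when the module is first imported.
    return str(uuid.uuid4())
```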

Contributor Author:

Removed

@deep1401 (Contributor) left a comment:

This is almost ready to merge; just a couple of small changes to ensure the plugins we created around this still work.

raise HTTPException(status_code=400, detail="num_images must be between 1 and 10")

# Validate diffusion type
if request.plugin == "image_diffusion":
Contributor:

A minor change here: once you check the name of request.plugin, you should also check whether it is installed (refer to the plugins.py router for how this is checked). If the plugin is not installed, return an error. This will fix the support issue we have with the dataset_imagegen plugin.
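The check could look roughly like this (a sketch: the plugins directory layout and helper name are assumptions here; the real logic lives in the plugins.py router):

```python
import os

def plugin_is_installed(plugin_name: str, workspace_dir: str) -> bool:
    # Treat a plugin as installed if its directory exists under the
    # workspace's plugins folder (simplified stand-in for the real check).
    return os.path.isdir(os.path.join(workspace_dir, "plugins", plugin_name))
```

The route would then return a 400-style error when this returns False, instead of letting the job fail later.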

Contributor:

Maybe after this, in dataset_imagegen, we show the error from the request in case generation fails because the plugin is not installed.

diffusion_logger.info(message)


class DiffusionOutputCapture:
Contributor:

We do not use this anymore, so maybe we remove it?

@deep1401 deep1401 self-requested a review July 21, 2025 14:40
@dadmobile (Member):

This worked for me!

3 participants