Diffusion Plugin #421
base: main
Conversation
api.py (outdated)
recipes,
users,
)

from transformerlab.routers.experiment import diffusion
We don't import it like this. Look inside the routers/experiment/experiment.py file, see how the other routers under experiment are imported, and follow that structure.
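For reference, the usual FastAPI pattern would be to include the sub-router from routers/experiment/experiment.py rather than importing it in api.py. A rough sketch only (the variable names are assumptions; check experiment.py for the actual structure):

```python
# Sketch only -- inside routers/experiment/experiment.py, mirroring how the
# other experiment sub-routers are included. Names here are assumptions.
from transformerlab.routers.experiment import diffusion

router.include_router(diffusion.router)
```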
@@ -452,6 +452,10 @@ def set_job_completion_status(
self.add_to_job_data("completion_status", completion_status)
self.add_to_job_data("completion_details", completion_details)

# Update the job status field if there's a failure
if completion_status == "failed":
Is this a duplicate change?
Maybe remove this as we do the same on the next line
The duplicate is removed.
self.params[key] = arg
key = None

def save_generated_images(self, metadata: dict = None, suffix: str = None):
Not sure what this is being used for? If it's a dataset then it is valid, but if you're trying to save a generated image with this, then it might be wrong.
I removed this function and the get_output_file_path function, but kept diffusion.py since it might come in handy for future diffusion plugins.
def get_output_file_path(self, suffix="", dir_only=False):
self._ensure_args_parsed()

workspace_dir = os.environ.get("_TFL_WORKSPACE_DIR", "./")
Instead of this, could you import WORKSPACE_DIR from transformerlab.plugin and use that?
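Roughly, the swap would look like this (a minimal sketch):

```python
# Use the shared constant instead of reading the env var directly
from transformerlab.plugin import WORKSPACE_DIR

workspace_dir = WORKSPACE_DIR
```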
solved (comment above)
print(f"[DIFFUSION] Metadata saved to: {metadata_file}")
return metadata_file

def get_output_file_path(self, suffix="", dir_only=False):
I understand why we add this here, but if this is exactly the same code with only a difference of "diffusion" vs "datasets", then it would be wiser to move it out to the base TLabPlugin class and add a parameter so GenTLab and DiffusionTLab can both use it.
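A rough sketch of that refactor, assuming the method only differs in the subdirectory name (the attribute and key names here are hypothetical, not the actual TLabPlugin code):

```python
import os
from transformerlab.plugin import WORKSPACE_DIR

class TLabPlugin:
    # Subclasses override this to choose where outputs land under the workspace
    output_subdir = "datasets"

    def get_output_file_path(self, suffix="", dir_only=False):
        # self.params is assumed to be populated by the SDK's argument handling
        out_dir = os.path.join(WORKSPACE_DIR, self.output_subdir, str(self.params.get("run_name", "output")))
        os.makedirs(out_dir, exist_ok=True)
        return out_dir if dir_only else os.path.join(out_dir, f"output{suffix}")

class DiffusionTLab(TLabPlugin):
    output_subdir = "diffusion"
```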
solved (comment above)
@@ -720,7 +483,11 @@ def get_pipeline(
# Load LoRA adaptor if provided - same code for local and HF Hub!
if adaptor and adaptor.strip():
try:
adaptor_dir = os.path.join(os.environ.get("_TFL_WORKSPACE_DIR"), "adaptors", secure_filename(model))
adaptor_dir = os.path.join(
os.environ.get("_TFL_WORKSPACE_DIR"),
This should be imported from transformerlab.plugin, as in the earlier comment.
Done
@@ -1097,7 +1160,22 @@ def run_pipe():
generation_start = time.time()
log_print("Starting image generation...")

progress_queue = queue.Queue()

async def monitor_progress():
What is this used for? I thought we do the progress update inside the callback.
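For reference, a step-callback-based progress update could look roughly like this (a sketch; callback_on_step_end is the diffusers hook, and report_progress is a placeholder for whatever the job SDK exposes):

```python
# Sketch: report progress from the diffusers step callback instead of a
# separate monitor task. report_progress() is a placeholder.
def make_progress_callback(num_inference_steps, report_progress):
    def on_step_end(pipe, step, timestep, callback_kwargs):
        report_progress(int((step + 1) / num_inference_steps * 100))
        return callback_kwargs
    return on_step_end

# images = pipe(prompt, num_inference_steps=steps,
#               callback_on_step_end=make_progress_callback(steps, report_progress)).images
```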
removed

def parse_args():
We should not be parsing any args inside a plugin.
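For context, the plugin would lean on the SDK instead of argparse; a sketch (the parameter names are assumptions):

```python
# Sketch: no argparse in the plugin -- the async_job_wrapper populates
# tlab_diffusion.params from the job, and the plugin just reads it.
@tlab_diffusion.async_job_wrapper(progress_start=0, progress_end=100)
async def diffusion_generate_job():
    model = tlab_diffusion.params.get("model", "")
    num_images = int(tlab_diffusion.params.get("num_images", 1))
    # ... run the pipeline with these values ...
```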
done
parser.add_argument("--plugin_dir", type=str, required=True)
parser.add_argument("--job_id", type=str, required=True)
parser.add_argument("--experiment_name", type=str, required=True)
parser.add_argument("--run_name", type=str, default="diffused")
I guess this is stale code and you need to remove it?
removed
config = job_config.get("config", {})

# Convert base64 images to files and update config
base64_fields = {
Why do we need to add this to config? Why can't the plugin fetch them directly?
We need to save these images to files because if the input/mask image is passed as a huge base64 string via CLI args, it hits a wall: Linux has a system-imposed limit on command-line length (typically around 128 KB to 256 KB total across all args and env vars).
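For illustration, the conversion could look roughly like this (the field names and helper are hypothetical):

```python
import base64
import os
import tempfile

def materialize_base64_images(config, base64_fields):
    """Write base64 image fields to temp files so only short paths cross the CLI boundary."""
    for field, path_key in base64_fields.items():
        data = config.pop(field, None)
        if not data:
            continue
        fd, path = tempfile.mkstemp(suffix=".png")
        with os.fdopen(fd, "wb") as f:
            f.write(base64.b64decode(data))
        # The plugin receives a small file path instead of hundreds of KB of base64
        config[path_key] = path
    return config
```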
return {"controlnets": models}
def main():
print("Starting image diffusion...", flush=True)
asyncio.run(async_main())
Why do we need to do this? And if you do need it, look at how we have an async_job_wrapper.
modified
@tlab_diffusion.async_job_wrapper(progress_start=0, progress_end=100)
async def diffusion_generate_job():
job_config = {
"plugin": tlab_diffusion.params.get("plugin", "image_diffusion"),
We probably don't need this?
removed

@tlab_diffusion.async_job_wrapper(progress_start=0, progress_end=100)
async def diffusion_generate_job():
job_config = {
Why aren't we using tlab_diffusion.params directly instead of maintaining this?
Done
"mask_image_path": "mask_image", | ||
} | ||
|
||
for path_key, base64_key in image_path_keys.items(): |
Just to clarify on the previous comment: this converts the images back to base64 for processing by diffusers, right?
Yes, that's correct.
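For reference, the reverse conversion is roughly this (a sketch; the dict and key names follow the diff above but are otherwise assumptions):

```python
import base64
import os

# Read the saved files back into base64 strings for the pipeline
for path_key, base64_key in image_path_keys.items():
    path = config.get(path_key)
    if path and os.path.exists(path):
        with open(path, "rb") as f:
            config[base64_key] = base64.b64encode(f.read()).decode("utf-8")
```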
Just a note: we left this blank for now because everything is already installed by the original requirements.
import asyncio
import threading
import gc
from transformerlab.plugin import WORKSPACE_DIR
Maybe we standardize on using this throughout the file?
Done
self._parser.add_argument("--diffusion_model", default="local", type=str, help="Diffusion model to use")
self._parser.add_argument("--model", type=str, default="")
self._parser.add_argument(
"--diffusion_type", default="txt2img", type=str, help="Type of diffusion task (txt2img, img2img, etc.)"
Do we use this param somewhere?
Removed
Adding this here because it's probably not included in the diff: now that we include diffusion_worker in the plugin, we should not have a copy in the shared folder.
except Exception as e:
await db_jobs.job_update_status(job_id, "FAILED")
print(f"[DIFFUSION] Job {job_id} execution error: {e}")
return {"status": "error", "job_id": job_id, "message": "Diffusion job failed"}
For error reporting, there are cases, such as using an SDXL model where the height and width are not multiples of 8, where this will error out, and we want to show a correct error about why it failed. The current implementation only prints the error; maybe we could send the error out via the HTTP response somehow and show it via notifications?
Additionally, we will also need to stream output from this job, as is done for other jobs; maybe we include a button that shows the raw output of the generation, using the same ViewOutputStreamingModal on the app side?
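One possible shape for this, keeping the existing except block (a sketch; only the message format changes):

```python
# Sketch: include the actual exception text in the result so the app can show
# why generation failed (e.g. SDXL needs height/width divisible by 8)
except Exception as e:
    await db_jobs.job_update_status(job_id, "FAILED")
    print(f"[DIFFUSION] Job {job_id} execution error: {e}")
    return {"status": "error", "job_id": job_id, "message": f"Diffusion job failed: {e}"}
```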
Either we do the HTTPResponse thing, or we make use of how errors are reported on /generate right now in the except block and show the errors from the job there as well when a job has failed status.
Done

router = APIRouter(prefix="/diffusion", tags=["diffusion"])

UNIVERSIAL_GENERATION_ID = str(uuid.uuid4())
I don't know how we ended up using this, but it should not be used here. This is why we have the /generate_id route, which generates a generation_id beforehand; you either use that or one provided with /generate, but the UNIVERSIAL_GENERATION_ID should not be used.
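i.e., something along these lines (a sketch; the request field name is an assumption):

```python
# Sketch: hand out ids from /generate_id and fall back to a per-request id in
# /generate, rather than a module-level UNIVERSIAL_GENERATION_ID constant.
@router.get("/generate_id")
async def generate_id():
    return {"generation_id": str(uuid.uuid4())}

# inside /generate:
generation_id = getattr(request, "generation_id", None) or str(uuid.uuid4())
```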
Removed
…ab-api into add/change-diffusion-to-plugin
This is almost ready to merge; just a couple of small changes to ensure the plugins we created around this still work.
raise HTTPException(status_code=400, detail="num_images must be between 1 and 10")

# Validate diffusion type
if request.plugin == "image_diffusion":
A minor change here: once you check the name of request.plugin, you should also check whether it is installed or not (refer to the plugins.py router for how this is checked). If the plugin is not installed, return an error. This will then fix the support issue we have with the dataset_imagegen plugin.
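Something along these lines (a sketch; the lookup helper is hypothetical, see plugins.py for the real check):

```python
# Sketch: verify the requested plugin is installed before starting generation.
# plugin_is_installed() is a placeholder for the check used in plugins.py.
if not await plugin_is_installed(request.plugin):
    raise HTTPException(
        status_code=400,
        detail=f"Plugin '{request.plugin}' is not installed; install it before generating.",
    )
```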
Maybe after this, in dataset_imagegen, we show the error from the request in case generation fails because the plugin is not installed.
diffusion_logger.info(message)


class DiffusionOutputCapture:
We don't use this anymore, so maybe we remove it?
This worked for me!