Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 179 additions & 23 deletions ui_trt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import html
import os

Expand All @@ -13,21 +14,150 @@
from modules.ui_components import FormRow


def export_unet_to_onnx(filename, opset):
    """Export the currently loaded UNet to an ONNX file.

    Args:
        filename: Destination path. When empty, the file is named after the
            loaded checkpoint's model name and placed in the ``Unet-onnx``
            folder under the models path.
        opset: ONNX opset version forwarded to the exporter.

    Returns:
        A ``(message, info)`` pair; ``info`` is always an empty string.
    """
    target = filename
    if not target:
        checkpoint_name = shared.sd_model.sd_checkpoint_info.model_name
        target = os.path.join(paths_internal.models_path, "Unet-onnx", checkpoint_name + ".onnx")

    export_onnx.export_current_unet_to_onnx(target, opset)

    return 'Saved as {}'.format(target), ''


def get_trt_filename(filename, onnx_filename):
def export_unet_to_onnx(filename, opset, batch_run, batch_directory):
    """Export the UNet to ONNX, once or for every checkpoint in a folder.

    Args:
        filename: Output path used in single mode. When empty, the file is
            named after the loaded checkpoint and written to ``Unet-onnx``
            under the models path. Ignored in batch mode.
        opset: ONNX opset version forwarded to the exporter.
        batch_run: When truthy, process every ``.safetensors``/``.ckpt`` file
            of ``batch_directory`` that has no same-named file in
            ``Unet-onnx`` yet.
        batch_directory: Source folder for batch mode. When empty, the
            ``Stable-diffusion`` folder under the models path is used.

    Returns:
        A ``(message, info)`` pair; ``info`` is always an empty string.
    """
    print('Starting Conversion to .onnx')

    # Make sure the destination folder exists before listing or writing to it.
    unet_onnx_path = os.path.join(paths_internal.models_path, "Unet-onnx")
    os.makedirs(unet_onnx_path, exist_ok=True)

    if not batch_run:
        print("--Single Model mode--")
        # An explicit filename must be honoured; only fall back to the
        # checkpoint-derived name when none was given.
        onnx_filename = filename
        if not onnx_filename:
            modelname = shared.sd_model.sd_checkpoint_info.model_name + ".onnx"
            onnx_filename = os.path.join(unet_onnx_path, modelname)
        print(f"Target ONNX filename: {onnx_filename}\n")  # Debug line

        export_onnx.export_current_unet_to_onnx(onnx_filename, opset)

        print(f'Done! Model saved as {onnx_filename}')
        return f'Done! Model saved as {onnx_filename}', ''

    print("--Batch Models mode--")
    if not batch_directory:
        batch_directory = os.path.join(paths_internal.models_path, "Stable-diffusion")

    # Compare on the name without its final extension so that dotted model
    # names (e.g. "v1.5-a" vs "v1.5-b") are not mistaken for one another.
    already_exported = {os.path.splitext(name)[0] for name in os.listdir(unet_onnx_path)}
    if not already_exported:
        print(f"Unet-onnx is empty, adding all .safetensors and .ckpt files from {batch_directory}\n")  # Debug line

    onnx_files_to_process = [
        name
        for name in os.listdir(batch_directory)
        if name.endswith((".safetensors", ".ckpt"))
        and os.path.splitext(name)[0] not in already_exported
    ]
    print(f"Files to process:\n{onnx_files_to_process}\n")

    if not onnx_files_to_process:
        message = 'No files to convert...\nPlease uncheck the "Run Batch" checkbox or use a folder containing models.'
        print(message)
        return message, ''

    # NOTE(review): the exporter is only given the target path; nothing in
    # this function makes `file` the active checkpoint, so each output
    # presumably contains the UNet that is currently loaded. Confirm against
    # export_onnx and load each checkpoint before exporting if that is so.
    for file in onnx_files_to_process:
        print(f"Converting model file: {file}")  # Debug line
        modelname = os.path.splitext(file)[0] + ".onnx"
        onnx_filename = os.path.join(unet_onnx_path, modelname)
        print(f"Target ONNX filename: {onnx_filename}\n")  # Debug line

        export_onnx.export_current_unet_to_onnx(onnx_filename, opset)

    print(f'Batch conversion completed for files in {batch_directory}')
    return f'Batch conversion completed for files in {batch_directory}', ''


def convert_onnx_to_trt(filename, onnx_filename, add_shape_to_filename, batch_run, batch_directory, *args):
    """Convert one ONNX model, or a folder of them, to TensorRT engines.

    Args:
        filename: Output path for single mode; when empty the name is derived
            from ``onnx_filename``. Not used for naming in batch mode.
        onnx_filename: Source ONNX file for single mode.
        add_shape_to_filename: When truthy, ask ``get_trt_filename`` to append
            the shape suffix to derived names.
        batch_run: When truthy, convert every ``.onnx`` file of
            ``batch_directory`` that has no matching file in ``Unet-trt`` yet.
        batch_directory: Source folder for batch mode. When empty, the
            ``Unet-onnx`` folder under the models path is used.
        *args: Conversion settings forwarded unchanged to
            ``get_trt_filename`` and ``export_trt.get_trt_command``.

    Returns:
        A ``(message, info)`` pair; ``info`` is always an empty string.

    Raises:
        AssertionError: If extension access is disabled.
    """
    # Raised explicitly (instead of using `assert`) so the gate is not
    # stripped when Python runs with -O; the exception type is unchanged.
    if cmd_opts.disable_extension_access:
        raise AssertionError("Won't run the command to create TensorRT file because extension access is disabled (use --enable-insecure-extension-access)")
    print('Starting Conversion to .trt')

    # Make sure the destination folder exists before listing or writing to it.
    unet_trt_path = os.path.join(paths_internal.models_path, "Unet-trt")
    os.makedirs(unet_trt_path, exist_ok=True)

    if not batch_run:
        print("--Single Model mode--")
        # batch_run must be passed explicitly: omitting it shifts every
        # following positional argument by one slot in get_trt_filename.
        trt_filename = get_trt_filename(filename, onnx_filename, False, add_shape_to_filename, *args)
        print(f"Target TRT filename: {trt_filename}\n")  # Debug line
        command = export_trt.get_trt_command(trt_filename, onnx_filename, *args)

        launch.run(command, live=True)

        print(f'Done! Model saved as {trt_filename}')
        return f'Done! Model saved as {trt_filename}', ''

    print("--Batch Models mode--")
    if not batch_directory:
        batch_directory = os.path.join(paths_internal.models_path, "Unet-onnx")

    trt_files = os.listdir(unet_trt_path)
    if not trt_files:
        print(f"Unet-trt is empty, adding all .onnx files from {batch_directory}\n")  # Debug line

    # An ONNX model counts as converted when a TRT file carries the same name,
    # or the same name followed by one "_<suffix>" (the optional shape suffix).
    # The heuristic cannot tell a shape suffix from an underscore that is part
    # of the model name itself.
    already_converted = set()
    for trt_file in trt_files:
        stem = os.path.splitext(trt_file)[0]
        already_converted.add(stem)
        prefix, separator, _ = stem.rpartition('_')
        if separator:
            already_converted.add(prefix)

    trt_files_to_process = [
        name
        for name in os.listdir(batch_directory)
        if name.endswith(".onnx") and os.path.splitext(name)[0] not in already_converted
    ]
    print(f"Files to process:\n{trt_files_to_process}\n")

    if not trt_files_to_process:
        message = 'No files to convert...\nPlease uncheck the "Run Batch" checkbox or use a folder containing models.'
        print(message)
        return message, ''

    for file in trt_files_to_process:
        onnx_file = os.path.join(batch_directory, file)
        print(f"Converting ONNX file: {onnx_file}")  # Debug line
        default_target = os.path.join(unet_trt_path, os.path.splitext(file)[0] + ".trt")

        trt_filename = get_trt_filename(default_target, onnx_file, batch_run, add_shape_to_filename, *args)
        print(f"Target TRT filename: {trt_filename}\n")  # Debug line
        command = export_trt.get_trt_command(trt_filename, onnx_file, *args)

        launch.run(command, live=True)

    print(f'Batch conversion completed for files in {batch_directory}')
    return f'Batch conversion completed for files in {batch_directory}', ''


def get_trt_filename(filename, onnx_filename, batch_run=False, add_shape_to_filename=False, *args):
    """Decide where the TensorRT engine for an ONNX model is written.

    Args:
        filename: Explicit output path. Returned unchanged when set and
            ``batch_run`` is falsy.
        onnx_filename: Source ONNX path; its base name (without extension) is
            the stem of the derived output name.
        batch_run: When truthy, always derive the name inside ``Unet-trt``,
            ignoring ``filename``.
        add_shape_to_filename: When truthy, append
            ``_<max_bs>x<max_width>x<max_height>`` to the derived stem.
        *args: Conversion settings in the order (min_bs, max_bs,
            min_token_count, max_token_count, min_width, max_width,
            min_height, max_height, use_fp16, trt_extra_args).

    Returns:
        The output path as a string.

    Raises:
        ValueError: If the shape suffix is requested but fewer than eight
            settings were supplied.
    """
    modelname = os.path.splitext(os.path.basename(onnx_filename))[0]
    print("Shape args: ", args)  # args: (1, 1, 75, 750, 512, 768, 512, 960, True, '')
    if add_shape_to_filename:
        if len(args) < 8:
            raise ValueError("add_shape_to_filename needs the batch size, width and height settings in args")
        # Indices 1, 5 and 7 are max_bs, max_width and max_height.
        modelname += f'_{args[1]}x{args[5]}x{args[7]}'
    modelname += ".trt"

    if not batch_run and filename:
        return filename

    # Reuse the name built above so a requested shape suffix is kept.
    return os.path.join(paths_internal.models_path, "Unet-trt", modelname)


Expand Down Expand Up @@ -60,15 +190,21 @@ def get_trt_command(filename, onnx_filename, *args):
"""


def convert_onnx_to_trt(filename, onnx_filename, *args):
assert not cmd_opts.disable_extension_access, "won't run the command to create TensorRT file because extension access is dsabled (use --enable-insecure-extension-access)"
async def calculate_and_check_constraints(max_width, max_height, max_batch_size):
    """Compute the figure displayed against the fixed limit of 92160.

    The figure is ``(max_batch_size * 2) * 4 * (max_height / 8) *
    (max_width / 8)``, truncated to an integer.

    NOTE(review): the meaning of the factor 4 and of the division by 8 is
    not stated in this file; confirm before relying on the number.

    Returns:
        A pair of the text ``"<figure> / 92160"`` and an empty string.
    """
    doubled_batch = max_batch_size * 2
    factor = 4
    scaled_height = max_height / 8
    scaled_width = max_width / 8

    # Same left-to-right multiplication order as before, then truncation.
    figure = int(doubled_batch * factor * scaled_height * scaled_width)

    return "{} / 92160".format(figure), ""

launch.run(command, live=True)

return f'Saved as {filename}', ''
def on_button_clicked(max_width, max_height, max_bs):
    """Run the constraint calculation and return its text for the UI.

    Args:
        max_width: Maximum width setting.
        max_height: Maximum height setting.
        max_bs: Maximum batch size setting.

    Returns:
        A pair of the calculated text and an empty string.
    """
    # asyncio.run() creates and closes its own event loop, so this also works
    # on threads that have no current loop, where
    # asyncio.get_event_loop() raises RuntimeError.
    calculated_value, _ = asyncio.run(calculate_and_check_constraints(max_width, max_height, max_bs))
    return calculated_value, ""


def on_ui_tabs():
Expand All @@ -82,12 +218,16 @@ def on_ui_tabs():
onnx_filename = gr.Textbox(label='Filename', value="", elem_id="onnx_filename", info="Leave empty to use the same name as model and put results into models/Unet-onnx directory")
onnx_opset = gr.Number(label='ONNX opset version', precision=0, value=17, info="Leave this alone unless you know what you are doing")

batch_run_onnx = gr.Checkbox(label='Run Batch', value=False)
batch_directory_onnx = gr.Textbox(label='Directory', value="", info="Input directory containing models. Leave empty to use the default 'models/Stable-diffusion' folder as source.")

button_export_unet = gr.Button(value="Convert Unet to ONNX", variant='primary', elem_id="onnx_export_unet")

with gr.Tab(label="Convert ONNX to TensorRT"):
trt_source_filename = gr.Textbox(label='Onnx model filename', value="", elem_id="trt_source_filename")
trt_filename = gr.Textbox(label='Output filename', value="", elem_id="trt_filename", info="Leave empty to use the same name as onnx and put results into models/Unet-trt directory")

add_shape_to_filename = gr.Checkbox(label='Add Shape to end of filename', value=False)

with gr.Column(elem_id="trt_width"):
min_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Minimum width", value=512, elem_id="trt_min_width")
max_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Maximum width", value=512, elem_id="trt_max_width")
Expand All @@ -104,27 +244,43 @@ def on_ui_tabs():
min_token_count = gr.Slider(minimum=75, maximum=750, step=75, label="Minimum prompt token count", value=75, elem_id="trt_min_token_count")
max_token_count = gr.Slider(minimum=75, maximum=750, step=75, label="Maximum prompt token count", value=75, elem_id="trt_max_token_count")

button_check_constraints = gr.Button(value="Calculate values limit for conversion", variant='secondary', elem_id="calculate")

with gr.Column(elem_id="trt_calculated_value"):
calculated_value_label = gr.Label(elem_id="calculated_value", value="32768 / 92160", label="Current Value / Limit", show_label=True)

trt_extra_args = gr.Textbox(label='Extra arguments', value="", elem_id="trt_extra_args", info="Extra arguments for trtexec command in plain text form")

with FormRow(elem_classes="checkboxes-row", variant="compact"):
use_fp16 = gr.Checkbox(label='Use half floats', value=True, elem_id="trt_fp16")

batch_run_trt = gr.Checkbox(label='Run Batch', value=False)
batch_directory_trt = gr.Textbox(label='Directory', value="", info="Input directory containing models. Leave empty to use the default 'models/Unet-onnx' folder as source.")

button_export_trt = gr.Button(value="Convert ONNX to TensorRT", variant='primary', elem_id="trt_convert_from_onnx")
button_show_trt_command = gr.Button(value="Show command for conversion", variant='secondary', elem_id="trt_convert_from_onnx")

with gr.Column(variant='panel'):
trt_result = gr.Label(elem_id="trt_result", value="", show_label=False)
trt_info = gr.HTML(elem_id="trt_info", value="")
button_check_constraints
calculated_value_label

button_export_unet.click(
wrap_gradio_gpu_call(export_unet_to_onnx, extra_outputs=["Conversion failed"]),
inputs=[onnx_filename, onnx_opset],
inputs=[onnx_filename, onnx_opset, batch_run_onnx, batch_directory_onnx],
outputs=[trt_result, trt_info],
)

button_check_constraints.click(
on_button_clicked,
inputs=[max_width, max_height, max_bs],
outputs=[calculated_value_label, trt_info],
)

button_export_trt.click(
wrap_gradio_gpu_call(convert_onnx_to_trt, extra_outputs=[""]),
inputs=[trt_filename, trt_source_filename, min_bs, max_bs, min_token_count, max_token_count, min_width, max_width, min_height, max_height, use_fp16, trt_extra_args],
inputs=[trt_filename, trt_source_filename, add_shape_to_filename, batch_run_trt, batch_directory_trt, min_bs, max_bs, min_token_count, max_token_count, min_width, max_width, min_height, max_height, use_fp16, trt_extra_args],
outputs=[trt_result, trt_info],
)

Expand All @@ -135,4 +291,4 @@ def on_ui_tabs():
)

return [(trt_interface, "TensorRT", "tensorrt")]