Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 78 additions & 4 deletions script.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@

torch._C._jit_set_profiling_mode(False)


sd_models = [] # list of models reported by SD-server (fetched in fetch_models_on_sd_server)


# parameters which can be customized in settings.json of webui
params = {
'address': 'http://127.0.0.1:7860',
Expand All @@ -34,7 +38,7 @@
'seed': -1,
'sampler_name': 'DDIM',
'steps': 32,
'cfg_scale': 7
'cfg_scale': 7
}


Expand Down Expand Up @@ -74,7 +78,6 @@ def give_VRAM_priority(actor):
give_VRAM_priority('set')

samplers = ['DDIM', 'DPM++ 2M Karras'] # TODO: get the availible samplers with http://{address}}/sdapi/v1/samplers
SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select

picture_response = False # specifies if the next model response should appear as a picture

Expand Down Expand Up @@ -130,7 +133,7 @@ def get_SD_pictures(description):
give_VRAM_priority('SD')

payload = {
"prompt": params['prompt_prefix'] + description,
"prompt": params['prompt_prefix'] + ", " + description,
"seed": params['seed'],
"sampler_name": params['sampler_name'],
"enable_hr": params['enable_hr'],
Expand Down Expand Up @@ -163,7 +166,7 @@ def get_SD_pictures(description):
with open(output_file.as_posix(), 'wb') as f:
f.write(img_data)

visible_result = visible_result + f'<img src="/file/extensions/sd_api_pictures/outputs/{variadic}.png" alt="{description}" style="max-width: unset; max-height: unset;">\n'
visible_result = visible_result + f'[<a target="_blank" href="/file/extensions/sd_api_pictures/outputs/{variadic}.png">Attachment</a>]\n'
else:
image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0])))
# lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
Expand All @@ -180,6 +183,35 @@ def get_SD_pictures(description):

return visible_result

# function filters out common conversational words from string
# NOTE: the local array substring_to_remove should be customizable
def filter_out_conversational_words(string):
# convert the string to lowercase for case-insensitive matching (sd models ignore capitalization anyway, as far as I know)
string = string.lower()

# define the list of substrings to remove
# NB. there are a lot of words that might be filtered depending on character, chatbot, and user's purpose but I have kept this list to contain mostly words that seem to 'confuse' the SD models I know
substrings_to_remove = [" i'm ", " i'd ", " a ", " an ", " i ", " me ", " my ", " mine ", " you ", " your ", " they ", "they", "'re ", "their", " at ", " the ", " that's ", "this", " who ", " and ", " but ", " all ", " it's", " i've ", " it ", " in ", " to ", " there ", " there's ", " these ", " those " "where's ", " from ", " is ", " am ", " are ", " was ", " were ", " my ", " me ", " you " " will ", " be ", " can ", " could ", " has ", " or ", " that ", " photos", " pictures" , " of ", "okay", " ok ", " here", " go ", " done ", "danbooru", " wtf", " put ", " what ", " why ", " would ", "should ", " good ", " one ", " oh ", " yeah ", " now ", " tag ", " tags ", " tagged ", " tagged as ", " description ", " describe ", " also", "without", " while ", " goes ", "anyways", "because", " still ", " going ", " so ", " then ", " these ", " else ", " might ", "http", " let ", " try ", " let's ", "see ", " name ", " hello ", " do ", " where ", " represents ", " got ", " about ", " how ", " much ", " well ", " um ", " umm "]
# define the list of special character substrings to remove
trailing_characters_to_remove = [ " - ","--",".", ", ,",",,"," , ",",,", "!", "?", ";", ":", ",,", "&", "(", ")", "<", ">", "/", "\\"]

# loop through each substring in the list and remove it from the string
for substring in substrings_to_remove:
string = string.replace(substring, ", ")
string = string.replace(" " + substring.strip()+",", ", ")
string = string.replace(" " + substring.strip()+".", ", ")

# removing resulting trailing characters
for substring in trailing_characters_to_remove:
string = string.replace(substring, "")

string = string.replace(" , ", " ")
string = string.replace(" ", " ")

# return the filtered string
return string


# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
# and replace it with 'text' for the purposes of logging?
def output_modifier(string):
Expand Down Expand Up @@ -260,6 +292,43 @@ def SD_api_address_update(address):

return gr.Textbox.update(label=msg)

# Gets the list of available SD models on the SD-server.
# Saves the titles of the response in the models lists.
def fetch_models_on_sd_server():

response = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models')
response.raise_for_status()
response_json = response.json()

#transer each title from the payload to the models (list of strings)
for item in response_json:
sd_models.append(item['title'])


def fetch_current_model_on_sd_server():
response = requests.get(url=f'{params["address"]}/sdapi/v1/options')
response.raise_for_status()
response_json = response.json()

if response_json["sd_model_checkpoint"]:
params['SD_model'] = response_json["sd_model_checkpoint"]


# Loads model on SD-server
def load_sd_model_remote(name):

payload = {
"sd_model_checkpoint" : name
}

params['SD_model'] = name

response = requests.post(url=f'{params["address"]}/sdapi/v1/options', json=payload)

# Initialization of list of sd-models and fetching the loaded model.
fetch_models_on_sd_server()
fetch_current_model_on_sd_server()


def ui():

Expand Down Expand Up @@ -287,6 +356,9 @@ def ui():
with gr.Column():
sampler_name = gr.Textbox(placeholder=params['sampler_name'], value=params['sampler_name'], label='Sampling method', elem_id="sampler_box")
steps = gr.Slider(1, 150, value=params['steps'], step=1, label="Sampling steps")
with gr.Row():
with gr.Column():
model_dropdown = gr.Dropdown(sd_models,value=params['SD_model'],label="Stable Diffusion Model",type="value")
with gr.Row():
seed = gr.Number(label="Seed", value=params['seed'], elem_id="seed_box")
cfg_scale = gr.Number(label="CFG Scale", value=params['cfg_scale'], elem_id="cfg_box")
Expand Down Expand Up @@ -318,6 +390,8 @@ def ui():
enable_hr.change(lambda x: params.update({"enable_hr": x}), enable_hr, None)
enable_hr.change(lambda x: hr_options.update(visible=params["enable_hr"]), enable_hr, hr_options)

model_dropdown.select(lambda x: load_sd_model_remote(x), model_dropdown, outputs=None)

sampler_name.change(lambda x: params.update({"sampler_name": x}), sampler_name, None)
steps.change(lambda x: params.update({"steps": x}), steps, None)
seed.change(lambda x: params.update({"seed": x}), seed, None)
Expand Down