Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions generation_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from .niche_stable_diffusion import NicheStableDiffusion
from .niche_stable_diffusion_xl import NicheStableDiffusionXL
from .niche_go_journey import NicheGoJourney
from .niche_sticker_maker import NicheStickerMaker


# Public API of the generation_models package; keep in sync with the
# imports above when a new model wrapper is added.
__all__ = [
    "NicheStableDiffusion",
    "NicheStableDiffusionXL",
    "NicheGoJourney",
    "NicheStickerMaker",
]
5 changes: 5 additions & 0 deletions generation_models/configs/model_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ GoJourney:
params:
supporting_pipelines:
- "gojourney"
StickerMaker:
target: generation_models.NicheStickerMaker
params:
supporting_pipelines:
- "txt2img"
Gemma7b:
target: ""
repo_id: "google/gemma-7b-it"
Expand Down
1 change: 1 addition & 0 deletions generation_models/niche_go_journey.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def load_model(self, *args, **kwargs):
def __call__(self, *args, **kwargs):
    """Invoke the inference function produced by load_model().

    All positional and keyword arguments are forwarded untouched; the
    return value is whatever the underlying inference function returns.
    """
    return self.inference_function(*args, **kwargs)


def load_imagine(self, *args, **kwargs):
imagine_endpoint = "https://api.midjourneyapi.xyz/mj/v2/imagine"
fetch_endpoint = "https://api.midjourneyapi.xyz/mj/v2/fetch"
Expand Down
20 changes: 20 additions & 0 deletions generation_models/niche_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,23 @@ def process_conditional_image(self, **kwargs) -> Image.Image:
resolution = kwargs.get("resolution", 512)
conditional_image = resize_for_condition_image(conditional_image, resolution)
return conditional_image

if __name__ == "__main__":
    # Manual smoke test: build a RealisticVision pipeline and render one prompt.
    model_kwargs = {
        "checkpoint_file": "checkpoints/RealisticVision.safetensors",
        "download_url": (
            "https://civitai.com/api/download/models/130072"
            "?type=Model&format=SafeTensor&size=pruned&fp=fp16"
        ),
        "scheduler": "dpm++2m",
        "supporting_pipelines": ["txt2img"],
    }
    pipe = NicheStableDiffusion(**model_kwargs)

    generation_kwargs = {
        "pipeline_type": "txt2img",
        "prompt": "a cat",
        "num_inference_steps": 25,
    }
    image = pipe(**generation_kwargs)
    image.save("debug.webp")
2 changes: 2 additions & 0 deletions generation_models/niche_stable_diffusion_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,5 @@ def process_conditional_image(self, **kwargs) -> Image.Image:
resolution = kwargs.get("resolution", 768)
conditional_image = resize_for_condition_image(conditional_image, resolution)
return conditional_image


68 changes: 68 additions & 0 deletions generation_models/niche_sticker_maker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from .base_model import BaseModel
from generation_models.utils_api_comfyui import *
import uuid
from PIL import Image
import io
import base64


# Default ComfyUI endpoint and client id used when the caller supplies none.
# NOTE(review): hard-coded public IP and client id in source — consider
# moving these to configuration or environment variables.
server_address = "82.67.70.191:40892"
client_id = "13c08530-8911-4e38-8489-7cded8eddd9d"

class NicheStickerMaker(BaseModel):
    """Sticker generator that delegates txt2img work to a remote ComfyUI server.

    This class only patches the workflow JSON with the caller's prompt/seed,
    queues it over a websocket, and returns the first rendered image.
    """

    def __init__(self, *args, **kwargs):
        # Fall back to the module-level defaults so existing callers that
        # pass neither value keep working.
        # BUG FIX: these attributes were previously stored but never used —
        # the inference closure read the module globals instead, silently
        # ignoring any caller-supplied server_address/client_id.
        self.server_address = kwargs.get("server_address") or server_address
        self.client_id = kwargs.get("client_id") or client_id
        self.inference_function = self.load_model(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        """Run one generation request; see load_image() for accepted kwargs."""
        return self.inference_function(*args, **kwargs)

    def load_model(self, *args, **kwargs):
        """Return the callable that performs a single generation."""
        return self.load_image(*args, **kwargs)

    def load_image(self, *args, **kwargs):
        """Build the inference closure.

        The returned function accepts:
            prompt (str): positive prompt injected into workflow node "2".
            seed (int, optional): sampler seed for node "4"; defaults to 0
                (the original raised KeyError when seed was omitted).
        and returns the first PIL.Image produced by the workflow.
        """

        def inference_function(*args, **kwargs):
            with open("generation_models/workflow-json/sticker_maker.json", "r") as file:
                workflow = json.load(file)

            workflow["2"]["inputs"]["positive"] = kwargs["prompt"]
            # BUG FIX: a missing seed used to raise KeyError; default to 0.
            workflow["4"]["inputs"]["seed"] = kwargs.get("seed", 0)

            # NOTE(review): queue_prompt/get_history in utils_api_comfyui
            # still use the module-level server_address; a custom
            # server_address here only affects the websocket — confirm/unify.
            ws = websocket.WebSocket()
            ws.connect(
                "ws://{}/ws?clientId={}".format(self.server_address, self.client_id)
            )
            try:
                images = get_images(ws, workflow)
            finally:
                # BUG FIX: always release the socket, even on server error.
                ws.close()

            imgs = [
                Image.open(io.BytesIO(image_data))
                for node_images in images.values()
                for image_data in node_images
            ]
            # Assumes the workflow produces at least one image — TODO confirm.
            return imgs[0]

        return inference_function

if __name__ == "__main__":
    # Manual smoke test against the default ComfyUI server.
    params = {
        "supporting_pipelines": ["txt2img"],
    }
    pipe = NicheStickerMaker(**params)

    input_dict = {
        "pipeline_type": "txt2img",
        "prompt": "a cat",
        # BUG FIX: the inference function reads kwargs["seed"]; omitting it
        # raised KeyError. Provide an explicit seed for a reproducible run.
        "seed": 0,
    }

    image = pipe(**input_dict)
    image.save("debug.webp")

94 changes: 94 additions & 0 deletions generation_models/utils_api_comfyui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json
import urllib.request
import urllib.parse

# ComfyUI server endpoint and client id used by every helper in this module.
# NOTE(review): hard-coded public IP and client id duplicated in
# niche_sticker_maker.py — consider a single configurable source of truth.
server_address = "82.67.70.191:40892"
client_id = "13c08530-8911-4e38-8489-7cded8eddd9d"

def queue_prompt(prompt):
    """POST *prompt* to the ComfyUI /prompt endpoint.

    Args:
        prompt: the workflow dict to queue.

    Returns:
        The parsed JSON reply (contains the server-assigned 'prompt_id').
    """
    payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode("utf-8")
    req = urllib.request.Request("http://{}/prompt".format(server_address), data=payload)
    # BUG FIX: close the HTTP response (the original leaked it); this also
    # matches the `with` style used by get_image/get_history below.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())

def get_image(filename, subfolder, folder_type):
    """Download one rendered image from the ComfyUI /view endpoint.

    Returns the raw image bytes.
    """
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    url = "http://{}/view?{}".format(server_address, query)
    with urllib.request.urlopen(url) as response:
        return response.read()

def get_history(prompt_id):
    """Fetch the execution-history record for *prompt_id* from the server."""
    url = "http://{}/history/{}".format(server_address, prompt_id)
    with urllib.request.urlopen(url) as response:
        return json.loads(response.read())

def load_workflow(workflow_path):
    """Read a ComfyUI workflow JSON file.

    Returns the parsed dict, or None (after printing a diagnostic) when the
    file is missing or does not contain valid JSON.
    """
    try:
        with open(workflow_path, "r") as fh:
            return json.load(fh)
    except FileNotFoundError:
        print(f"The file {workflow_path} was not found.")
    except json.JSONDecodeError:
        print(f"The file {workflow_path} contains invalid JSON.")
    return None

def get_images(ws, prompt):
    """Queue *prompt* on the ComfyUI server and collect its image outputs.

    Blocks on the websocket until the server reports that execution of our
    prompt has finished, then downloads every image listed in the prompt's
    history record.

    Args:
        ws: an already-connected websocket to the ComfyUI /ws endpoint.
        prompt: the workflow dict to queue.

    Returns:
        dict mapping output node id -> list of raw image bytes.
    """
    prompt_id = queue_prompt(prompt)['prompt_id']

    # Drain websocket messages until the server signals completion of our
    # prompt; binary frames are live previews and are ignored.
    while True:
        out = ws.recv()
        if not isinstance(out, str):
            continue  # previews are binary data
        message = json.loads(out)
        if message['type'] == 'executing':
            data = message['data']
            if data['node'] is None and data['prompt_id'] == prompt_id:
                break  # execution is done

    # BUG FIX: the original wrapped this in an extra `for o in
    # history['outputs']` loop whose variable was unused, re-downloading
    # every image once per output node. A single pass is sufficient.
    output_images = {}
    history = get_history(prompt_id)[prompt_id]
    for node_id, node_output in history['outputs'].items():
        if 'images' in node_output:
            output_images[node_id] = [
                get_image(image['filename'], image['subfolder'], image['type'])
                for image in node_output['images']
            ]
    return output_images
if __name__ == "__main__":
    # Manual smoke test: run the sticker workflow once against the
    # hard-coded dev server and save the outputs locally.
    with open("generation_models/workflow-json/sticker_maker.json", "r") as fh:
        workflow = json.loads(fh.read())

    # Set the positive prompt (node "2") and the sampler seed (node "4").
    workflow["2"]["inputs"]["positive"] = "a dog"
    workflow["4"]["inputs"]["seed"] = 7

    ws = websocket.WebSocket()
    ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
    print(ws)

    images = get_images(ws, workflow)

    # Save every returned image as Output<i>.webp.
    counter = 0
    for node_id, node_images in images.items():
        for image_data in node_images:
            counter += 1
            from PIL import Image
            import io
            image = Image.open(io.BytesIO(image_data))
            # image.show()
            print("out")
            image.save(f"Output{counter}.webp")
Loading