script.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. import base64
  2. import io
  3. import re
  4. import time
  5. from datetime import date
  6. from pathlib import Path
  7. import gradio as gr
  8. import modules.shared as shared
  9. import requests
  10. import torch
  11. from modules.models import reload_model, unload_model
  12. from PIL import Image
  13. torch._C._jit_set_profiling_mode(False)
  14. # parameters which can be customized in settings.json of webui
  15. params = {
  16. 'address': 'http://127.0.0.1:7860',
  17. 'mode': 0, # modes of operation: 0 (Manual only), 1 (Immersive/Interactive - looks for words to trigger), 2 (Picturebook Adventure - Always on)
  18. 'manage_VRAM': False,
  19. 'save_img': False,
  20. 'SD_model': 'NeverEndingDream', # not used right now
  21. 'prompt_prefix': '(Masterpiece:1.1), detailed, intricate, colorful',
  22. 'negative_prompt': '(worst quality, low quality:1.3)',
  23. 'width': 512,
  24. 'height': 512,
  25. 'restore_faces': False,
  26. 'seed': -1,
  27. 'sampler_name': 'DDIM',
  28. 'steps': 32,
  29. 'cfg_scale': 7
  30. }
  31. def give_VRAM_priority(actor):
  32. global shared, params
  33. if actor == 'SD':
  34. unload_model()
  35. print("Requesting Auto1111 to re-load last checkpoint used...")
  36. response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='')
  37. response.raise_for_status()
  38. elif actor == 'LLM':
  39. print("Requesting Auto1111 to vacate VRAM...")
  40. response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='')
  41. response.raise_for_status()
  42. reload_model()
  43. elif actor == 'set':
  44. print("VRAM mangement activated -- requesting Auto1111 to vacate VRAM...")
  45. response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='')
  46. response.raise_for_status()
  47. elif actor == 'reset':
  48. print("VRAM mangement deactivated -- requesting Auto1111 to reload checkpoint")
  49. response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='')
  50. response.raise_for_status()
  51. else:
  52. raise RuntimeError(f'Managing VRAM: "{actor}" is not a known state!')
  53. response.raise_for_status()
  54. del response
  55. if params['manage_VRAM']:
  56. give_VRAM_priority('set')
  57. samplers = ['DDIM', 'DPM++ 2M Karras'] # TODO: get the availible samplers with http://{address}}/sdapi/v1/samplers
  58. SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
  59. streaming_state = shared.args.no_stream # remember if chat streaming was enabled
  60. picture_response = False # specifies if the next model response should appear as a picture
  61. def remove_surrounded_chars(string):
  62. # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
  63. # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
  64. return re.sub('\*[^\*]*?(\*|$)', '', string)
  65. def triggers_are_in(string):
  66. string = remove_surrounded_chars(string)
  67. # regex searches for send|main|message|me (at the end of the word) followed by
  68. # a whole word of image|pic|picture|photo|snap|snapshot|selfie|meme(s),
  69. # (?aims) are regex parser flags
  70. return bool(re.search('(?aims)(send|mail|message|me)\\b.+?\\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)s?\\b', string))
  71. def input_modifier(string):
  72. """
  73. This function is applied to your text inputs before
  74. they are fed into the model.
  75. """
  76. global params
  77. if not params['mode'] == 1: # if not in immersive/interactive mode, do nothing
  78. return string
  79. if triggers_are_in(string): # if we're in it, check for trigger words
  80. toggle_generation(True)
  81. string = string.lower()
  82. if "of" in string:
  83. subject = string.split('of', 1)[1] # subdivide the string once by the first 'of' instance and get what's coming after it
  84. string = "Please provide a detailed and vivid description of " + subject
  85. else:
  86. string = "Please provide a detailed description of your appearance, your surroundings and what you are doing right now"
  87. return string
  88. # Get and save the Stable Diffusion-generated picture
  89. def get_SD_pictures(description):
  90. global params
  91. if params['manage_VRAM']:
  92. give_VRAM_priority('SD')
  93. payload = {
  94. "prompt": params['prompt_prefix'] + description,
  95. "seed": params['seed'],
  96. "sampler_name": params['sampler_name'],
  97. "steps": params['steps'],
  98. "cfg_scale": params['cfg_scale'],
  99. "width": params['width'],
  100. "height": params['height'],
  101. "restore_faces": params['restore_faces'],
  102. "negative_prompt": params['negative_prompt']
  103. }
  104. print(f'Prompting the image generator via the API on {params["address"]}...')
  105. response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
  106. response.raise_for_status()
  107. r = response.json()
  108. visible_result = ""
  109. for img_str in r['images']:
  110. image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0])))
  111. if params['save_img']:
  112. variadic = f'{date.today().strftime("%Y_%m_%d")}/{shared.character}_{int(time.time())}'
  113. output_file = Path(f'extensions/sd_api_pictures/outputs/{variadic}.png')
  114. output_file.parent.mkdir(parents=True, exist_ok=True)
  115. image.save(output_file.as_posix())
  116. visible_result = visible_result + f'<img src="/file/extensions/sd_api_pictures/outputs/{variadic}.png" alt="{description}" style="max-width: unset; max-height: unset;">\n'
  117. else:
  118. # lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
  119. image.thumbnail((300, 300))
  120. buffered = io.BytesIO()
  121. image.save(buffered, format="JPEG")
  122. buffered.seek(0)
  123. image_bytes = buffered.getvalue()
  124. img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
  125. visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
  126. if params['manage_VRAM']:
  127. give_VRAM_priority('LLM')
  128. return visible_result
  129. # TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
  130. # and replace it with 'text' for the purposes of logging?
  131. def output_modifier(string):
  132. """
  133. This function is applied to the model outputs.
  134. """
  135. global picture_response, params
  136. if not picture_response:
  137. return string
  138. string = remove_surrounded_chars(string)
  139. string = string.replace('"', '')
  140. string = string.replace('“', '')
  141. string = string.replace('\n', ' ')
  142. string = string.strip()
  143. if string == '':
  144. string = 'no viable description in reply, try regenerating'
  145. return string
  146. text = ""
  147. if (params['mode'] < 2):
  148. toggle_generation(False)
  149. text = f'*Sends a picture which portrays: “{string}”*'
  150. else:
  151. text = string
  152. string = get_SD_pictures(string) + "\n" + text
  153. return string
  154. def bot_prefix_modifier(string):
  155. """
  156. This function is only applied in chat mode. It modifies
  157. the prefix text for the Bot and can be used to bias its
  158. behavior.
  159. """
  160. return string
  161. def toggle_generation(*args):
  162. global picture_response, shared, streaming_state
  163. if not args:
  164. picture_response = not picture_response
  165. else:
  166. picture_response = args[0]
  167. shared.args.no_stream = True if picture_response else streaming_state # Disable streaming cause otherwise the SD-generated picture would return as a dud
  168. shared.processing_message = "*Is sending a picture...*" if picture_response else "*Is typing...*"
  169. def filter_address(address):
  170. address = address.strip()
  171. # address = re.sub('http(s)?:\/\/|\/$','',address) # remove starting http:// OR https:// OR trailing slash
  172. address = re.sub('\/$', '', address) # remove trailing /s
  173. if not address.startswith('http'):
  174. address = 'http://' + address
  175. return address
  176. def SD_api_address_update(address):
  177. global params
  178. msg = "✔️ SD API is found on:"
  179. address = filter_address(address)
  180. params.update({"address": address})
  181. try:
  182. response = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models')
  183. response.raise_for_status()
  184. # r = response.json()
  185. except:
  186. msg = "❌ No SD API endpoint on:"
  187. return gr.Textbox.update(label=msg)
  188. def ui():
  189. # Gradio elements
  190. # gr.Markdown('### Stable Diffusion API Pictures') # Currently the name of extension is shown as the title
  191. with gr.Accordion("Parameters", open=True):
  192. with gr.Row():
  193. address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Auto1111\'s WebUI address')
  194. mode = gr.Dropdown(["Manual", "Immersive/Interactive", "Picturebook/Adventure"], value="Manual", label="Mode of operation", type="index")
  195. with gr.Column(scale=1, min_width=300):
  196. manage_VRAM = gr.Checkbox(value=params['manage_VRAM'], label='Manage VRAM')
  197. save_img = gr.Checkbox(value=params['save_img'], label='Keep original images and use them in chat')
  198. force_pic = gr.Button("Force the picture response")
  199. suppr_pic = gr.Button("Suppress the picture response")
  200. with gr.Accordion("Generation parameters", open=False):
  201. prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
  202. with gr.Row():
  203. with gr.Column():
  204. negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
  205. sampler_name = gr.Textbox(placeholder=params['sampler_name'], value=params['sampler_name'], label='Sampler')
  206. with gr.Column():
  207. width = gr.Slider(256, 768, value=params['width'], step=64, label='Width')
  208. height = gr.Slider(256, 768, value=params['height'], step=64, label='Height')
  209. with gr.Row():
  210. steps = gr.Number(label="Steps:", value=params['steps'])
  211. seed = gr.Number(label="Seed:", value=params['seed'])
  212. cfg_scale = gr.Number(label="CFG Scale:", value=params['cfg_scale'])
  213. # Event functions to update the parameters in the backend
  214. address.change(lambda x: params.update({"address": filter_address(x)}), address, None)
  215. mode.select(lambda x: params.update({"mode": x}), mode, None)
  216. mode.select(lambda x: toggle_generation(x > 1), inputs=mode, outputs=None)
  217. manage_VRAM.change(lambda x: params.update({"manage_VRAM": x}), manage_VRAM, None)
  218. manage_VRAM.change(lambda x: give_VRAM_priority('set' if x else 'reset'), inputs=manage_VRAM, outputs=None)
  219. save_img.change(lambda x: params.update({"save_img": x}), save_img, None)
  220. address.submit(fn=SD_api_address_update, inputs=address, outputs=address)
  221. prompt_prefix.change(lambda x: params.update({"prompt_prefix": x}), prompt_prefix, None)
  222. negative_prompt.change(lambda x: params.update({"negative_prompt": x}), negative_prompt, None)
  223. width.change(lambda x: params.update({"width": x}), width, None)
  224. height.change(lambda x: params.update({"height": x}), height, None)
  225. sampler_name.change(lambda x: params.update({"sampler_name": x}), sampler_name, None)
  226. steps.change(lambda x: params.update({"steps": x}), steps, None)
  227. seed.change(lambda x: params.update({"seed": x}), seed, None)
  228. cfg_scale.change(lambda x: params.update({"cfg_scale": x}), cfg_scale, None)
  229. force_pic.click(lambda x: toggle_generation(True), inputs=force_pic, outputs=None)
  230. suppr_pic.click(lambda x: toggle_generation(False), inputs=suppr_pic, outputs=None)