script.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. import base64
  2. import io
  3. import re
  4. from pathlib import Path
  5. import gradio as gr
  6. import requests
  7. import torch
  8. from PIL import Image
  9. from modules import chat, shared
# Turn off TorchScript JIT profiling mode (private torch API) for the whole process.
# NOTE(review): this is an import-time side effect on global torch state and the
# reason is not visible in this file — confirm it is still needed.
torch._C._jit_set_profiling_mode(False)

# parameters which can be customized in settings.json of webui
params = {
    'enable_SD_api': False,  # master switch; input_modifier does nothing while False
    'address': 'http://127.0.0.1:7860',  # base URL; '/sdapi/v1/txt2img' is appended
    'save_img': False,  # also save full-size images to extensions/sd_api_pictures/outputs/
    'SD_model': 'NeverEndingDream',  # not really used right now
    'prompt_prefix': '(Masterpiece:1.1), (solo:1.3), detailed, intricate, colorful',  # put in front of every prompt
    'negative_prompt': '(worst quality, low quality:1.3)',
    'side_length': 512,  # sent as both "width" and "height" of the requested image
    'restore_faces': False
}

SD_models = ['NeverEndingDream']  # TODO: get with http://{address}/sdapi/v1/sd-models and allow user to select

# Snapshot of shared.args.no_stream taken when this module is loaded (True means
# streaming is DISABLED). output_modifier restores this value after a picture reply.
streaming_state = shared.args.no_stream
picture_response = False  # specifies if the next model response should appear as a picture
# Counter used to name saved images ({pic_id:06d}.png).
# NOTE(review): restarts at 0 on every load, so previously saved files may be overwritten.
pic_id = 0
  26. def remove_surrounded_chars(string):
  27. # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
  28. # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
  29. return re.sub('\*[^\*]*?(\*|$)', '', string)
  30. # I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
  31. def input_modifier(string):
  32. """
  33. This function is applied to your text inputs before
  34. they are fed into the model.
  35. """
  36. global params, picture_response
  37. if not params['enable_SD_api']:
  38. return string
  39. commands = ['send', 'mail', 'me']
  40. mediums = ['image', 'pic', 'picture', 'photo']
  41. subjects = ['yourself', 'own']
  42. lowstr = string.lower()
  43. # TODO: refactor out to separate handler and also replace detection with a regexp
  44. if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
  45. picture_response = True
  46. shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud
  47. shared.processing_message = "*Is sending a picture...*"
  48. string = "Please provide a detailed description of your surroundings, how you look and the situation you're in and what you are doing right now"
  49. if any(target in lowstr for target in subjects): # the focus of the image should be on the sending character
  50. string = "Please provide a detailed and vivid description of how you look and what you are wearing"
  51. return string
  52. # Get and save the Stable Diffusion-generated picture
  53. def get_SD_pictures(description):
  54. global params, pic_id
  55. payload = {
  56. "prompt": params['prompt_prefix'] + description,
  57. "seed": -1,
  58. "sampler_name": "DPM++ 2M Karras",
  59. "steps": 32,
  60. "cfg_scale": 7,
  61. "width": params['side_length'],
  62. "height": params['side_length'],
  63. "restore_faces": params['restore_faces'],
  64. "negative_prompt": params['negative_prompt']
  65. }
  66. response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
  67. r = response.json()
  68. visible_result = ""
  69. for img_str in r['images']:
  70. image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0])))
  71. if params['save_img']:
  72. output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
  73. image.save(output_file.as_posix())
  74. pic_id += 1
  75. # lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
  76. image.thumbnail((300, 300))
  77. buffered = io.BytesIO()
  78. image.save(buffered, format="JPEG")
  79. buffered.seek(0)
  80. image_bytes = buffered.getvalue()
  81. img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
  82. visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
  83. return visible_result
  84. # TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
  85. # and replace it with 'text' for the purposes of logging?
  86. def output_modifier(string):
  87. """
  88. This function is applied to the model outputs.
  89. """
  90. global pic_id, picture_response, streaming_state
  91. if not picture_response:
  92. return string
  93. string = remove_surrounded_chars(string)
  94. string = string.replace('"', '')
  95. string = string.replace('“', '')
  96. string = string.replace('\n', ' ')
  97. string = string.strip()
  98. if string == '':
  99. string = 'no viable description in reply, try regenerating'
  100. # I can't for the love of all that's holy get the name from shared.gradio['name1'], so for now it will be like this
  101. text = f'*Description: "{string}"*'
  102. image = get_SD_pictures(string)
  103. picture_response = False
  104. shared.processing_message = "*Is typing...*"
  105. shared.args.no_stream = streaming_state
  106. return image + "\n" + text
  107. def bot_prefix_modifier(string):
  108. """
  109. This function is only applied in chat mode. It modifies
  110. the prefix text for the Bot and can be used to bias its
  111. behavior.
  112. """
  113. return string
  114. def force_pic():
  115. global picture_response
  116. picture_response = True
  117. def ui():
  118. # Gradio elements
  119. with gr.Accordion("Stable Diffusion api integration", open=True):
  120. with gr.Row():
  121. with gr.Column():
  122. enable = gr.Checkbox(value=params['enable_SD_api'], label='Activate SD Api integration')
  123. save_img = gr.Checkbox(value=params['save_img'], label='Keep original received images in the outputs subdir')
  124. with gr.Column():
  125. address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Stable Diffusion host address')
  126. with gr.Row():
  127. force_btn = gr.Button("Force the next response to be a picture")
  128. generate_now_btn = gr.Button("Generate an image response to the input")
  129. with gr.Accordion("Generation parameters", open=False):
  130. prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
  131. with gr.Row():
  132. negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
  133. dimensions = gr.Slider(256, 702, value=params['side_length'], step=64, label='Image dimensions')
  134. # model = gr.Dropdown(value=SD_models[0], choices=SD_models, label='Model')
  135. # Event functions to update the parameters in the backend
  136. enable.change(lambda x: params.update({"enable_SD_api": x}), enable, None)
  137. save_img.change(lambda x: params.update({"save_img": x}), save_img, None)
  138. address.change(lambda x: params.update({"address": x}), address, None)
  139. prompt_prefix.change(lambda x: params.update({"prompt_prefix": x}), prompt_prefix, None)
  140. negative_prompt.change(lambda x: params.update({"negative_prompt": x}), negative_prompt, None)
  141. dimensions.change(lambda x: params.update({"side_length": x}), dimensions, None)
  142. # model.change(lambda x: params.update({"SD_model": x}), model, None)
  143. force_btn.click(force_pic)
  144. generate_now_btn.click(force_pic)
  145. generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)