script.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
import base64
import html
import io
import json
import re
from io import BytesIO
from pathlib import Path

import gradio as gr
import requests
import torch
from PIL import Image, PngImagePlugin

import modules.chat as chat
import modules.shared as shared
  12. torch._C._jit_set_profiling_mode(False)
  13. # parameters which can be customized in settings.json of webui
  14. params = {
  15. 'enable_SD_api': False,
  16. 'address': 'http://127.0.0.1:7860',
  17. 'save_img': False,
  18. 'SD_model': 'NeverEndingDream', # not really used right now
  19. 'prompt_prefix': '(Masterpiece:1.1), (solo:1.3), detailed, intricate, colorful',
  20. 'negative_prompt': '(worst quality, low quality:1.3)',
  21. 'side_length': 512,
  22. 'restore_faces': False
  23. }
  24. SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
  25. streaming_state = shared.args.no_stream # remember if chat streaming was enabled
  26. picture_response = False # specifies if the next model response should appear as a picture
  27. pic_id = 0
  28. def remove_surrounded_chars(string):
  29. new_string = ""
  30. in_star = False
  31. for char in string:
  32. if char == '*':
  33. in_star = not in_star
  34. elif not in_star:
  35. new_string += char
  36. return new_string
  37. # I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
  38. def input_modifier(string):
  39. """
  40. This function is applied to your text inputs before
  41. they are fed into the model.
  42. """
  43. global params, picture_response
  44. if not params['enable_SD_api']:
  45. return string
  46. commands = ['send', 'mail', 'me']
  47. mediums = ['image', 'pic', 'picture', 'photo']
  48. subjects = ['yourself', 'own']
  49. lowstr = string.lower()
  50. if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
  51. picture_response = True
  52. shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud
  53. shared.processing_message = "*Is sending a picture...*"
  54. string = "Please provide a detailed description of your surroundings, how you look and the situation you're in and what you are doing right now"
  55. if any(target in lowstr for target in subjects): # the focus of the image should be on the sending character
  56. string = "Please provide a detailed and vivid description of how you look and what you are wearing"
  57. return string
  58. # Get and save the Stable Diffusion-generated picture
  59. def get_SD_pictures(description):
  60. global params, pic_id
  61. payload = {
  62. "prompt": params['prompt_prefix'] + description,
  63. "seed": -1,
  64. "sampler_name": "DPM++ 2M Karras",
  65. "steps": 32,
  66. "cfg_scale": 7,
  67. "width": params['side_length'],
  68. "height": params['side_length'],
  69. "restore_faces": params['restore_faces'],
  70. "negative_prompt": params['negative_prompt']
  71. }
  72. response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
  73. r = response.json()
  74. visible_result = ""
  75. for img_str in r['images']:
  76. image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",",1)[0])))
  77. if params['save_img']:
  78. output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
  79. image.save(output_file.as_posix())
  80. pic_id += 1
  81. # lower the resolution of received images for the chat, otherwise the history size gets out of control quickly with all the base64 values
  82. newsize = (300, 300)
  83. image = image.resize(newsize, Image.LANCZOS)
  84. buffered = io.BytesIO()
  85. image.save(buffered, format="JPEG")
  86. buffered.seek(0)
  87. image_bytes = buffered.getvalue()
  88. img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
  89. visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
  90. return visible_result
  91. # TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
  92. # and replace it with 'text' for the purposes of logging?
  93. def output_modifier(string):
  94. """
  95. This function is applied to the model outputs.
  96. """
  97. global pic_id, picture_response, streaming_state
  98. if not picture_response:
  99. return string
  100. string = remove_surrounded_chars(string)
  101. string = string.replace('"', '')
  102. string = string.replace('“', '')
  103. string = string.replace('\n', ' ')
  104. string = string.strip()
  105. if string == '':
  106. string = 'no viable description in reply, try regenerating'
  107. # I can't for the love of all that's holy get the name from shared.gradio['name1'], so for now it will be like this
  108. text = f'*Description: "{string}"*'
  109. image = get_SD_pictures(string)
  110. picture_response = False
  111. shared.processing_message = "*Is typing...*"
  112. shared.args.no_stream = streaming_state
  113. return image + "\n" + text
  114. def bot_prefix_modifier(string):
  115. """
  116. This function is only applied in chat mode. It modifies
  117. the prefix text for the Bot and can be used to bias its
  118. behavior.
  119. """
  120. return string
  121. def force_pic():
  122. global picture_response
  123. picture_response = True
  124. def ui():
  125. # Gradio elements
  126. with gr.Accordion("Stable Diffusion api integration", open=True):
  127. with gr.Row():
  128. with gr.Column():
  129. enable = gr.Checkbox(value=params['enable_SD_api'], label='Activate SD Api integration')
  130. save_img = gr.Checkbox(value=params['save_img'], label='Keep original received images in the outputs subdir')
  131. with gr.Column():
  132. address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Stable Diffusion host address')
  133. with gr.Row():
  134. force_btn = gr.Button("Force the next response to be a picture")
  135. generate_now_btn = gr.Button("Generate an image response to the input")
  136. with gr.Accordion("Generation parameters", open=False):
  137. prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
  138. with gr.Row():
  139. negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
  140. dimensions = gr.Slider(256,702,value=params['side_length'],step=64,label='Image dimensions')
  141. # model = gr.Dropdown(value=SD_models[0], choices=SD_models, label='Model')
  142. # Event functions to update the parameters in the backend
  143. enable.change(lambda x: params.update({"enable_SD_api": x}), enable, None)
  144. save_img.change(lambda x: params.update({"save_img": x}), save_img, None)
  145. address.change(lambda x: params.update({"address": x}), address, None)
  146. prompt_prefix.change(lambda x: params.update({"prompt_prefix": x}), prompt_prefix, None)
  147. negative_prompt.change(lambda x: params.update({"negative_prompt": x}), negative_prompt, None)
  148. dimensions.change(lambda x: params.update({"side_length": x}), dimensions, None)
  149. # model.change(lambda x: params.update({"SD_model": x}), model, None)
  150. force_btn.click(force_pic)
  151. generate_now_btn.click(force_pic)
  152. generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)