script.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import base64
  2. from io import BytesIO
  3. import gradio as gr
  4. import modules.chat as chat
  5. import modules.shared as shared
  6. from modules.bot_picture import caption_image
  7. params = {
  8. }
  9. # If 'state' is True, will hijack the next chat generation with
  10. # custom input text
  11. input_hijack = {
  12. 'state': False,
  13. 'value': ["", ""]
  14. }
  15. def generate_chat_picture(picture, name1, name2):
  16. text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
  17. buffer = BytesIO()
  18. picture.save(buffer, format="JPEG")
  19. img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
  20. visible_text = f'<img src="data:image/jpeg;base64,{img_str}">'
  21. return text, visible_text
  22. def input_modifier(string):
  23. """
  24. This function is applied to your text inputs before
  25. they are fed into the model.
  26. """
  27. return string
  28. def output_modifier(string):
  29. """
  30. This function is applied to the model outputs.
  31. """
  32. return string
  33. def bot_prefix_modifier(string):
  34. """
  35. This function is only applied in chat mode. It modifies
  36. the prefix text for the Bot and can be used to bias its
  37. behavior.
  38. """
  39. return string
  40. def ui():
  41. picture_select = gr.Image(label='Send a picture', type='pil')
  42. function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
  43. picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None)
  44. picture_select.upload(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
  45. picture_select.upload(lambda : None, [], [picture_select], show_progress=False)