# script.py -- pasted from a file viewer (original size: 2.1 KB)

import base64
import html
from io import BytesIO

import gradio as gr
import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor

import modules.chat as chat
import modules.shared as shared

  9. # If 'state' is True, will hijack the next chat generation with
  10. # custom input text given by 'value' in the format [text, visible_text]
  11. input_hijack = {
  12. 'state': False,
  13. 'value': ["", ""]
  14. }
  15. processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
  16. model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
  17. def caption_image(raw_image):
  18. inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
  19. out = model.generate(**inputs, max_new_tokens=100)
  20. return processor.decode(out[0], skip_special_tokens=True)
  21. def generate_chat_picture(picture, name1, name2):
  22. text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
  23. # lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
  24. picture.thumbnail((300, 300))
  25. buffer = BytesIO()
  26. picture.save(buffer, format="JPEG")
  27. img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
  28. visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">'
  29. return text, visible_text
  30. def ui():
  31. picture_select = gr.Image(label='Send a picture', type='pil')
  32. # Prepare the hijack with custom inputs
  33. picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None)
  34. # Call the generation function
  35. picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
  36. # Clear the picture from the upload field
  37. picture_select.upload(lambda : None, [], [picture_select], show_progress=False)