script.py

import base64
from io import BytesIO

import gradio as gr
import torch
from transformers import BlipForConditionalGeneration, BlipProcessor

from modules import chat, shared
# If 'state' is True, will hijack the next chat generation with
# custom input text given by 'value' in the format [text, visible_text]
input_hijack = {
    'state': False,
    'value': ["", ""]
}
# Load the BLIP image-captioning model on the CPU in full precision
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
def caption_image(raw_image):
    # Produce a text caption for the given PIL image with BLIP
    inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
    out = model.generate(**inputs, max_new_tokens=100)
    return processor.decode(out[0], skip_special_tokens=True)
def generate_chat_picture(picture, name1, name2):
    # Caption the image and wrap it in a narration line for the prompt
    text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'

    # Lower the resolution of sent images for the chat, otherwise the log size
    # gets out of control quickly with all the base64 values in visible history
    picture.thumbnail((300, 300))
    buffer = BytesIO()
    picture.save(buffer, format="JPEG")

    # Embed the downscaled image as a base64 data URI for the visible history
    img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">'
    return text, visible_text
def ui():
    picture_select = gr.Image(label='Send a picture', type='pil')

    # Prepare the hijack with custom inputs
    picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None)

    # Call the generation function
    picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)

    # Clear the picture from the upload field
    picture_select.upload(lambda: None, [], [picture_select], show_progress=False)