# script.py (1.9 KB)
  1. import base64
  2. from io import BytesIO
  3. import gradio as gr
  4. import modules.chat as chat
  5. import modules.shared as shared
  6. import torch
  7. from transformers import BlipForConditionalGeneration, BlipProcessor
  8. # If 'state' is True, will hijack the next chat generation with
  9. # custom input text given by 'value' in the format [text, visible_text]
  10. input_hijack = {
  11. 'state': False,
  12. 'value': ["", ""]
  13. }
  14. processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
  15. model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
  16. def caption_image(raw_image):
  17. inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
  18. out = model.generate(**inputs, max_new_tokens=100)
  19. return processor.decode(out[0], skip_special_tokens=True)
  20. def generate_chat_picture(picture, name1, name2):
  21. text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
  22. buffer = BytesIO()
  23. picture.save(buffer, format="JPEG")
  24. img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
  25. visible_text = f'<img src="data:image/jpeg;base64,{img_str}">'
  26. return text, visible_text
  27. def ui():
  28. picture_select = gr.Image(label='Send a picture', type='pil')
  29. function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
  30. # Prepare the hijack with custom inputs
  31. picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None)
  32. # Call the generation function
  33. picture_select.upload(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
  34. # Clear the picture from the upload field
  35. picture_select.upload(lambda : None, [], [picture_select], show_progress=False)