Ver Fonte

Merge pull request #489 from Brawlence/ext-fixes

Extensions performance & memory optimisations
oobabooga há 2 anos atrás
pai
commit
d5fc1bead7

+ 6 - 11
extensions/elevenlabs_tts/script.py

@@ -1,11 +1,11 @@
+import re
 from pathlib import Path
 
 import gradio as gr
+import modules.shared as shared
 from elevenlabslib import ElevenLabsUser
 from elevenlabslib.helpers import save_bytes_to_path
 
-import modules.shared as shared
-
 params = {
     'activate': True,
     'api_key': '12345',
@@ -52,14 +52,9 @@ def refresh_voices():
         return
 
 def remove_surrounded_chars(string):
-    new_string = ""
-    in_star = False
-    for char in string:
-        if char == '*':
-            in_star = not in_star
-        elif not in_star:
-            new_string += char
-    return new_string
+    # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
+    # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
+    return re.sub('\*[^\*]*?(\*|$)','',string)
 
 def input_modifier(string):
     """
@@ -115,4 +110,4 @@ def ui():
     voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
     api_key.change(lambda x: params.update({'api_key': x}), api_key, None)
     connect.click(check_valid_api, [], connection_status)
-    connect.click(refresh_voices, [], voice)
+    connect.click(refresh_voices, [], voice)

+ 11 - 15
extensions/sd_api_pictures/script.py

@@ -1,15 +1,15 @@
 import base64
 import io
+import re
 from pathlib import Path
 
 import gradio as gr
+import modules.chat as chat
+import modules.shared as shared
 import requests
 import torch
 from PIL import Image
 
-import modules.chat as chat
-import modules.shared as shared
-
 torch._C._jit_set_profiling_mode(False)
 
 # parameters which can be customized in settings.json of webui  
@@ -31,14 +31,9 @@ picture_response = False # specifies if the next model response should appear as
 pic_id = 0
 
 def remove_surrounded_chars(string):
-    new_string = ""
-    in_star = False
-    for char in string:
-        if char == '*':
-            in_star = not in_star
-        elif not in_star:
-            new_string += char
-    return new_string
+    # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
+    # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
+    return re.sub('\*[^\*]*?(\*|$)','',string)
 
 # I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
 def input_modifier(string):
@@ -54,6 +49,8 @@ def input_modifier(string):
     mediums = ['image', 'pic', 'picture', 'photo']
     subjects = ['yourself', 'own']
     lowstr = string.lower()
+
+    # TODO: refactor out to separate handler and also replace detection with a regexp
     if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
         picture_response = True
         shared.args.no_stream = True                                                               # Disable streaming cause otherwise the SD-generated picture would return as a dud
@@ -91,9 +88,8 @@ def get_SD_pictures(description):
             output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
             image.save(output_file.as_posix())
             pic_id += 1
-        # lower the resolution of received images for the chat, otherwise the history size gets out of control quickly with all the base64 values
-        newsize = (300, 300)
-        image = image.resize(newsize, Image.LANCZOS)
+        # lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
+        image.thumbnail((300, 300))
         buffered = io.BytesIO()
         image.save(buffered, format="JPEG")
         buffered.seek(0)
@@ -180,4 +176,4 @@ def ui():
 
     force_btn.click(force_pic)
     generate_now_btn.click(force_pic)
-    generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
+    generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)

+ 6 - 4
extensions/send_pictures/script.py

@@ -2,11 +2,11 @@ import base64
 from io import BytesIO
 
 import gradio as gr
-import torch
-from transformers import BlipForConditionalGeneration, BlipProcessor
-
 import modules.chat as chat
 import modules.shared as shared
+import torch
+from PIL import Image
+from transformers import BlipForConditionalGeneration, BlipProcessor
 
 # If 'state' is True, will hijack the next chat generation with
 # custom input text given by 'value' in the format [text, visible_text]
@@ -25,10 +25,12 @@ def caption_image(raw_image):
 
 def generate_chat_picture(picture, name1, name2):
     text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
+    # lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
+    picture.thumbnail((300, 300))
     buffer = BytesIO()
     picture.save(buffer, format="JPEG")
     img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
-    visible_text = f'<img src="data:image/jpeg;base64,{img_str}">'
+    visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">'
     return text, visible_text
 
 def ui():

+ 6 - 11
extensions/silero_tts/script.py

@@ -1,11 +1,11 @@
+import re
 import time
 from pathlib import Path
 
 import gradio as gr
-import torch
-
 import modules.chat as chat
 import modules.shared as shared
+import torch
 
 torch._C._jit_set_profiling_mode(False)
 
@@ -46,14 +46,9 @@ def load_model():
 model = load_model()
 
 def remove_surrounded_chars(string):
-    new_string = ""
-    in_star = False
-    for char in string:
-        if char == '*':
-            in_star = not in_star
-        elif not in_star:
-            new_string += char
-    return new_string
+    # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
+    # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
+    return re.sub('\*[^\*]*?(\*|$)','',string)
 
 def remove_tts_from_history(name1, name2):
     for i, entry in enumerate(shared.history['internal']):
@@ -166,4 +161,4 @@ def ui():
     autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
     voice.change(lambda x: params.update({"speaker": x}), voice, None)
     v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
-    v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)
+    v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)