Explorar el Código

SD Api Pics extension, v.1.1 (#596)

Φφ hace 2 años
padre
commit
ffd102e5c0

+ 78 - 0
extensions/sd_api_pictures/README.MD

@@ -0,0 +1,78 @@
+## Description:
+TL;DR: Lets the bot answer you with a picture!  
+
+Stable Diffusion API pictures for TextGen, v.1.1.0  
+An extension to [oobabooga's textgen-webui](https://github.com/oobabooga/text-generation-webui) allowing you to receive pics generated by [Automatic1111's SD-WebUI API](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
+
+<details>
+<summary>Interface overview</summary>
+
+![Interface](https://raw.githubusercontent.com/Brawlence/texgen-webui-SD_api_pics/main/illust/Interface.jpg)
+
+</details>
+
+Load it in the `--chat` mode with `--extension sd_api_pictures` alongside `send_pictures` (it's not really required, but completes the picture, *pun intended*).  
+
+The image generation is triggered either:  
+- manually through the 'Force the picture response' button while in `Manual` or `Immersive/Interactive` modes OR  
+- automatically in `Immersive/Interactive` mode if the words `'send|main|message|me'` are followed by `'image|pic|picture|photo|snap|snapshot|selfie|meme'` in the user's prompt  
+- always on in Picturebook/Adventure mode (if not currently suppressed by 'Suppress the picture response')  
+
+## Prerequisites
+
+One needs an available instance of Automatic1111's webui running with an `--api` flag. Ain't tested with a notebook / cloud hosted one but should be possible.   
+To run it locally in parallel on the same machine, specify custom `--listen-port` for either Auto1111's or ooba's webUIs.  
+
+## Features:
+- API detection (press enter in the API box)  
+- VRAM management (model shuffling)  
+- Three different operation modes (manual, interactive, always-on)  
+- persistent settings via settings.json
+
+The model input is modified only in the interactive mode; other two are unaffected. The output pic description is presented differently for Picture-book / Adventure mode.  
+
+Connection check (insert the Auto1111's address and press Enter):  
+![API-check](https://raw.githubusercontent.com/Brawlence/texgen-webui-SD_api_pics/main/illust/API-check.gif) 
+
+### Persistents settings
+
+Create or modify the `settings.json` in the `text-generation-webui` root directory to override the defaults
+present in script.py, ex:
+
+```json
+{
+    "sd_api_pictures-manage_VRAM": 1,
+    "sd_api_pictures-save_img": 1,
+    "sd_api_pictures-prompt_prefix": "(Masterpiece:1.1), detailed, intricate, colorful, (solo:1.1)",
+    "sd_api_pictures-sampler_name": "DPM++ 2M Karras"
+}
+```
+
+will automatically set the `Manage VRAM` & `Keep original images` checkboxes and change the texts in `Prompt Prefix` and `Sampler name` on load.
+
+---
+
+## Demonstrations:
+
+Those are examples of the version 1.0.0, but the core functionality is still the same
+
+<details>
+<summary>Conversation 1</summary>
+
+![EXA1](https://user-images.githubusercontent.com/42910943/224866564-939a3bcb-e7cf-4ac0-a33f-b3047b55054d.jpg)
+![EXA2](https://user-images.githubusercontent.com/42910943/224866566-38394054-1320-45cf-9515-afa76d9d7745.jpg)
+![EXA3](https://user-images.githubusercontent.com/42910943/224866568-10ea47b7-0bac-4269-9ec9-22c387a13b59.jpg)
+![EXA4](https://user-images.githubusercontent.com/42910943/224866569-326121ad-1ea1-4874-9f6b-4bca7930a263.jpg)
+
+
+</details>
+
+<details>
+<summary>Conversation 2</summary>
+
+![Hist1](https://user-images.githubusercontent.com/42910943/224865517-c6966b58-bc4d-4353-aab9-6eb97778d7bf.jpg)
+![Hist2](https://user-images.githubusercontent.com/42910943/224865527-b2fe7c2e-0da5-4c2e-b705-42e233b07084.jpg)
+![Hist3](https://user-images.githubusercontent.com/42910943/224865535-a38d94e7-8975-4a46-a655-1ae1de41f85d.jpg)
+
+</details>
+

+ 181 - 76
extensions/sd_api_pictures/script.py

@@ -1,34 +1,78 @@
 import base64
 import io
 import re
+import time
+from datetime import date
 from pathlib import Path
 
 import gradio as gr
+import modules.shared as shared
 import requests
 import torch
+from modules.models import reload_model, unload_model
 from PIL import Image
 
-from modules import chat, shared
-
 torch._C._jit_set_profiling_mode(False)
 
 # parameters which can be customized in settings.json of webui
 params = {
-    'enable_SD_api': False,
     'address': 'http://127.0.0.1:7860',
+    'mode': 0,  # modes of operation: 0 (Manual only), 1 (Immersive/Interactive - looks for words to trigger), 2 (Picturebook Adventure - Always on)
+    'manage_VRAM': False,
     'save_img': False,
-    'SD_model': 'NeverEndingDream',  # not really used right now
-    'prompt_prefix': '(Masterpiece:1.1), (solo:1.3), detailed, intricate, colorful',
+    'SD_model': 'NeverEndingDream',  # not used right now
+    'prompt_prefix': '(Masterpiece:1.1), detailed, intricate, colorful',
     'negative_prompt': '(worst quality, low quality:1.3)',
-    'side_length': 512,
-    'restore_faces': False
+    'width': 512,
+    'height': 512,
+    'restore_faces': False,
+    'seed': -1,
+    'sampler_name': 'DDIM',
+    'steps': 32,
+    'cfg_scale': 7
 }
 
+
+def give_VRAM_priority(actor):
+    global shared, params
+
+    if actor == 'SD':
+        unload_model()
+        print("Requesting Auto1111 to re-load last checkpoint used...")
+        response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='')
+        response.raise_for_status()
+
+    elif actor == 'LLM':
+        print("Requesting Auto1111 to vacate VRAM...")
+        response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='')
+        response.raise_for_status()
+        reload_model()
+
+    elif actor == 'set':
+        print("VRAM mangement activated -- requesting Auto1111 to vacate VRAM...")
+        response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='')
+        response.raise_for_status()
+
+    elif actor == 'reset':
+        print("VRAM mangement deactivated -- requesting Auto1111 to reload checkpoint")
+        response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='')
+        response.raise_for_status()
+
+    else:
+        raise RuntimeError(f'Managing VRAM: "{actor}" is not a known state!')
+
+    response.raise_for_status()
+    del response
+
+
+if params['manage_VRAM']:
+    give_VRAM_priority('set')
+
+samplers = ['DDIM', 'DPM++ 2M Karras']  # TODO: get the availible samplers with http://{address}}/sdapi/v1/samplers
 SD_models = ['NeverEndingDream']  # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
 
 streaming_state = shared.args.no_stream  # remember if chat streaming was enabled
 picture_response = False  # specifies if the next model response should appear as a picture
-pic_id = 0
 
 
 def remove_surrounded_chars(string):
@@ -36,7 +80,13 @@ def remove_surrounded_chars(string):
     # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
     return re.sub('\*[^\*]*?(\*|$)', '', string)
 
-# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
+
+def triggers_are_in(string):
+    string = remove_surrounded_chars(string)
+    # regex searches for send|main|message|me (at the end of the word) followed by
+    # a whole word of image|pic|picture|photo|snap|snapshot|selfie|meme(s),
+    # (?aims) are regex parser flags
+    return bool(re.search('(?aims)(send|mail|message|me)\\b.+?\\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)s?\\b', string))
 
 
 def input_modifier(string):
@@ -44,75 +94,80 @@ def input_modifier(string):
     This function is applied to your text inputs before
     they are fed into the model.
     """
-    global params, picture_response
-    if not params['enable_SD_api']:
-        return string
 
-    commands = ['send', 'mail', 'me']
-    mediums = ['image', 'pic', 'picture', 'photo']
-    subjects = ['yourself', 'own']
-    lowstr = string.lower()
+    global params
 
-    # TODO: refactor out to separate handler and also replace detection with a regexp
-    if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums):  # trigger the generation if a command signature and a medium signature is found
-        picture_response = True
-        shared.args.no_stream = True                                                               # Disable streaming cause otherwise the SD-generated picture would return as a dud
-        shared.processing_message = "*Is sending a picture...*"
-        string = "Please provide a detailed description of your surroundings, how you look and the situation you're in and what you are doing right now"
-        if any(target in lowstr for target in subjects):                                           # the focus of the image should be on the sending character
-            string = "Please provide a detailed and vivid description of how you look and what you are wearing"
+    if not params['mode'] == 1:  # if not in immersive/interactive mode, do nothing
+        return string
+
+    if triggers_are_in(string):  # if we're in it, check for trigger words
+        toggle_generation(True)
+        string = string.lower()
+        if "of" in string:
+            subject = string.split('of', 1)[1]  # subdivide the string once by the first 'of' instance and get what's coming after it
+            string = "Please provide a detailed and vivid description of " + subject
+        else:
+            string = "Please provide a detailed description of your appearance, your surroundings and what you are doing right now"
 
     return string
 
 # Get and save the Stable Diffusion-generated picture
-
-
 def get_SD_pictures(description):
 
-    global params, pic_id
+    global params
+
+    if params['manage_VRAM']:
+        give_VRAM_priority('SD')
 
     payload = {
         "prompt": params['prompt_prefix'] + description,
-        "seed": -1,
-        "sampler_name": "DPM++ 2M Karras",
-        "steps": 32,
-        "cfg_scale": 7,
-        "width": params['side_length'],
-        "height": params['side_length'],
+        "seed": params['seed'],
+        "sampler_name": params['sampler_name'],
+        "steps": params['steps'],
+        "cfg_scale": params['cfg_scale'],
+        "width": params['width'],
+        "height": params['height'],
         "restore_faces": params['restore_faces'],
         "negative_prompt": params['negative_prompt']
     }
 
+    print(f'Prompting the image generator via the API on {params["address"]}...')
     response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
+    response.raise_for_status()
     r = response.json()
 
     visible_result = ""
     for img_str in r['images']:
         image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0])))
         if params['save_img']:
-            output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
+            variadic = f'{date.today().strftime("%Y_%m_%d")}/{shared.character}_{int(time.time())}'
+            output_file = Path(f'extensions/sd_api_pictures/outputs/{variadic}.png')
+            output_file.parent.mkdir(parents=True, exist_ok=True)
             image.save(output_file.as_posix())
-            pic_id += 1
-        # lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
-        image.thumbnail((300, 300))
-        buffered = io.BytesIO()
-        image.save(buffered, format="JPEG")
-        buffered.seek(0)
-        image_bytes = buffered.getvalue()
-        img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
-        visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
+            visible_result = visible_result + f'<img src="/file/extensions/sd_api_pictures/outputs/{variadic}.png" alt="{description}" style="max-width: unset; max-height: unset;">\n'
+        else:
+            # lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
+            image.thumbnail((300, 300))
+            buffered = io.BytesIO()
+            image.save(buffered, format="JPEG")
+            buffered.seek(0)
+            image_bytes = buffered.getvalue()
+            img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
+            visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
+
+    if params['manage_VRAM']:
+        give_VRAM_priority('LLM')
 
     return visible_result
 
 # TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
 # and replace it with 'text' for the purposes of logging?
-
-
 def output_modifier(string):
     """
     This function is applied to the model outputs.
     """
-    global pic_id, picture_response, streaming_state
+
+    global picture_response, params
 
     if not picture_response:
         return string
@@ -125,17 +180,18 @@ def output_modifier(string):
 
     if string == '':
         string = 'no viable description in reply, try regenerating'
+        return string
 
-    # I can't for the love of all that's holy get the name from shared.gradio['name1'], so for now it will be like this
-    text = f'*Description: "{string}"*'
-
-    image = get_SD_pictures(string)
+    text = ""
+    if (params['mode'] < 2):
+        toggle_generation(False)
+        text = f'*Sends a picture which portrays: “{string}”*'
+    else:
+        text = string
 
-    picture_response = False
+    string = get_SD_pictures(string) + "\n" + text
 
-    shared.processing_message = "*Is typing...*"
-    shared.args.no_stream = streaming_state
-    return image + "\n" + text
+    return string
 
 
 def bot_prefix_modifier(string):
@@ -148,42 +204,91 @@ def bot_prefix_modifier(string):
     return string
 
 
-def force_pic():
-    global picture_response
-    picture_response = True
+def toggle_generation(*args):
+    global picture_response, shared, streaming_state
+
+    if not args:
+        picture_response = not picture_response
+    else:
+        picture_response = args[0]
+
+    shared.args.no_stream = True if picture_response else streaming_state  # Disable streaming cause otherwise the SD-generated picture would return as a dud
+    shared.processing_message = "*Is sending a picture...*" if picture_response else "*Is typing...*"
+
+
+def filter_address(address):
+    address = address.strip()
+    # address = re.sub('http(s)?:\/\/|\/$','',address) # remove starting http:// OR https:// OR trailing slash
+    address = re.sub('\/$', '', address)  # remove trailing /s
+    if not address.startswith('http'):
+        address = 'http://' + address
+    return address
+
+
+def SD_api_address_update(address):
+
+    global params
+
+    msg = "✔️ SD API is found on:"
+    address = filter_address(address)
+    params.update({"address": address})
+    try:
+        response = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models')
+        response.raise_for_status()
+        # r = response.json()
+    except:
+        msg = "❌ No SD API endpoint on:"
+
+    return gr.Textbox.update(label=msg)
 
 
 def ui():
 
     # Gradio elements
-    with gr.Accordion("Stable Diffusion api integration", open=True):
+    # gr.Markdown('### Stable Diffusion API Pictures') # Currently the name of extension is shown as the title
+    with gr.Accordion("Parameters", open=True):
         with gr.Row():
-            with gr.Column():
-                enable = gr.Checkbox(value=params['enable_SD_api'], label='Activate SD Api integration')
-                save_img = gr.Checkbox(value=params['save_img'], label='Keep original received images in the outputs subdir')
-            with gr.Column():
-                address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Stable Diffusion host address')
+            address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Auto1111\'s WebUI address')
+            mode = gr.Dropdown(["Manual", "Immersive/Interactive", "Picturebook/Adventure"], value="Manual", label="Mode of operation", type="index")
+            with gr.Column(scale=1, min_width=300):
+                manage_VRAM = gr.Checkbox(value=params['manage_VRAM'], label='Manage VRAM')
+                save_img = gr.Checkbox(value=params['save_img'], label='Keep original images and use them in chat')
 
-        with gr.Row():
-            force_btn = gr.Button("Force the next response to be a picture")
-            generate_now_btn = gr.Button("Generate an image response to the input")
+            force_pic = gr.Button("Force the picture response")
+            suppr_pic = gr.Button("Suppress the picture response")
 
         with gr.Accordion("Generation parameters", open=False):
             prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
             with gr.Row():
-                negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
-                dimensions = gr.Slider(256, 702, value=params['side_length'], step=64, label='Image dimensions')
-                # model = gr.Dropdown(value=SD_models[0], choices=SD_models, label='Model')
+                with gr.Column():
+                    negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
+                    sampler_name = gr.Textbox(placeholder=params['sampler_name'], value=params['sampler_name'], label='Sampler')
+                with gr.Column():
+                    width = gr.Slider(256, 768, value=params['width'], step=64, label='Width')
+                    height = gr.Slider(256, 768, value=params['height'], step=64, label='Height')
+            with gr.Row():
+                steps = gr.Number(label="Steps:", value=params['steps'])
+                seed = gr.Number(label="Seed:", value=params['seed'])
+                cfg_scale = gr.Number(label="CFG Scale:", value=params['cfg_scale'])
 
     # Event functions to update the parameters in the backend
-    enable.change(lambda x: params.update({"enable_SD_api": x}), enable, None)
+    address.change(lambda x: params.update({"address": filter_address(x)}), address, None)
+    mode.select(lambda x: params.update({"mode": x}), mode, None)
+    mode.select(lambda x: toggle_generation(x > 1), inputs=mode, outputs=None)
+    manage_VRAM.change(lambda x: params.update({"manage_VRAM": x}), manage_VRAM, None)
+    manage_VRAM.change(lambda x: give_VRAM_priority('set' if x else 'reset'), inputs=manage_VRAM, outputs=None)
     save_img.change(lambda x: params.update({"save_img": x}), save_img, None)
-    address.change(lambda x: params.update({"address": x}), address, None)
+
+    address.submit(fn=SD_api_address_update, inputs=address, outputs=address)
     prompt_prefix.change(lambda x: params.update({"prompt_prefix": x}), prompt_prefix, None)
     negative_prompt.change(lambda x: params.update({"negative_prompt": x}), negative_prompt, None)
-    dimensions.change(lambda x: params.update({"side_length": x}), dimensions, None)
-    # model.change(lambda x: params.update({"SD_model": x}), model, None)
+    width.change(lambda x: params.update({"width": x}), width, None)
+    height.change(lambda x: params.update({"height": x}), height, None)
+
+    sampler_name.change(lambda x: params.update({"sampler_name": x}), sampler_name, None)
+    steps.change(lambda x: params.update({"steps": x}), steps, None)
+    seed.change(lambda x: params.update({"seed": x}), seed, None)
+    cfg_scale.change(lambda x: params.update({"cfg_scale": x}), cfg_scale, None)
 
-    force_btn.click(force_pic)
-    generate_now_btn.click(force_pic)
-    generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
+    force_pic.click(lambda x: toggle_generation(True), inputs=force_pic, outputs=None)
+    suppr_pic.click(lambda x: toggle_generation(False), inputs=suppr_pic, outputs=None)

+ 1 - 8
modules/LoRA.py

@@ -4,14 +4,7 @@ import torch
 from peft import PeftModel
 
 import modules.shared as shared
-from modules.models import load_model
-from modules.text_generation import clear_torch_cache
-
-
-def reload_model():
-    shared.model = shared.tokenizer = None
-    clear_torch_cache()
-    shared.model, shared.tokenizer = load_model(shared.model_name)
+from modules.models import reload_model
 
 
 def add_lora_to_model(lora_name):

+ 19 - 2
modules/models.py

@@ -1,3 +1,4 @@
+import gc
 import json
 import os
 import re
@@ -16,11 +17,10 @@ import modules.shared as shared
 
 transformers.logging.set_verbosity_error()
 
-local_rank = None
-
 if shared.args.flexgen:
     from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
 
+local_rank = None
 if shared.args.deepspeed:
     import deepspeed
     from transformers.deepspeed import (HfDeepSpeedConfig,
@@ -182,6 +182,23 @@ def load_model(model_name):
     return model, tokenizer
 
 
+def clear_torch_cache():
+    gc.collect()
+    if not shared.args.cpu:
+        torch.cuda.empty_cache()
+
+
+def unload_model():
+    shared.model = shared.tokenizer = None
+    clear_torch_cache()
+
+
+def reload_model():
+    shared.model = shared.tokenizer = None
+    clear_torch_cache()
+    shared.model, shared.tokenizer = load_model(shared.model_name)
+
+
 def load_soft_prompt(name):
     if name == 'None':
         shared.soft_prompt = False

+ 1 - 8
modules/text_generation.py

@@ -1,4 +1,3 @@
-import gc
 import re
 import time
 import traceback
@@ -12,7 +11,7 @@ from modules.callbacks import (Iteratorize, Stream,
                                _SentinelTokenStoppingCriteria)
 from modules.extensions import apply_extensions
 from modules.html_generator import generate_4chan_html, generate_basic_html
-from modules.models import local_rank
+from modules.models import clear_torch_cache, local_rank
 
 
 def get_max_prompt_length(tokens):
@@ -101,12 +100,6 @@ def formatted_outputs(reply, model_name):
         return reply
 
 
-def clear_torch_cache():
-    gc.collect()
-    if not shared.args.cpu:
-        torch.cuda.empty_cache()
-
-
 def set_manual_seed(seed):
     if seed != -1:
         torch.manual_seed(seed)

+ 2 - 8
server.py

@@ -18,9 +18,8 @@ import modules.extensions as extensions_module
 from modules import api, chat, shared, training, ui
 from modules.html_generator import chat_html_wrapper
 from modules.LoRA import add_lora_to_model
-from modules.models import load_model, load_soft_prompt
-from modules.text_generation import (clear_torch_cache, generate_reply,
-                                     stop_everything_event)
+from modules.models import load_model, load_soft_prompt, unload_model
+from modules.text_generation import generate_reply, stop_everything_event
 
 # Loading custom settings
 settings_file = None
@@ -79,11 +78,6 @@ def get_available_loras():
     return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
 
 
-def unload_model():
-    shared.model = shared.tokenizer = None
-    clear_torch_cache()
-
-
 def load_model_wrapper(selected_model):
     if selected_model != shared.model_name:
         shared.model_name = selected_model