Преглед изворни кода

Add support for extensions

This is experimental.
oobabooga пре 3 година
родитељ
комит
6b5dcd46c5
3 измењених фајлова са 101 додато и 53 уклоњено
  1. 1 0
      README.md
  2. 14 0
      extensions/example/script.py
  3. 86 53
      server.py

+ 1 - 0
README.md

@@ -133,6 +133,7 @@ Optionally, you can use the following command-line flags:
 | `--cpu-memory CPU_MEMORY`    | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.|
 | `--no-stream`   | Don't stream the text output in real time. This improves the text generation performance.|
 | `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.|
+| `--extensions EXTENSIONS` | The list of extensions to load. If you want to load more than one extension, write the names separated by commas and between quotation marks, "like,this". |
 | `--listen`   | Make the web UI reachable from your local network.|
 | `--share`   | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
 | `--verbose`   | Print the prompts to the terminal. |

+ 14 - 0
extensions/example/script.py

@@ -0,0 +1,14 @@
+def input_modifier(string):
+    """
+    This function is applied to your text inputs before
+    they are fed into the model.
+    """ 
+
+    return string.replace(' ', '#')
+
+def output_modifier(string):
+    """
+    This function is applied to the model outputs.
+    """
+
+    return string.replace(' ', '_')

+ 86 - 53
server.py

@@ -5,6 +5,7 @@ import glob
 import torch
 import argparse
 import json
+import sys
 from sys import exit
 from pathlib import Path
 import gradio as gr
@@ -32,6 +33,7 @@ parser.add_argument('--gpu-memory', type=int, help='Maximum GPU memory in GiB to
 parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.')
 parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
 parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
+parser.add_argument('--extensions', type=str, help='The list of extensions to load. If you want to load more than one extension, write the names separated by commas and between quotation marks, "like,this".')
 parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
 parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
 parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
@@ -165,6 +167,9 @@ def formatted_outputs(reply, model_name):
 def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None, stopping_string=None):
     global model, tokenizer, model_name, loaded_preset, preset
 
+    original_question = question
+    if not (args.chat or args.cai_chat):
+        question = apply_extensions(question, "input")
     if args.verbose:
         print(f"\n\n{question}\n--------------------\n")
 
@@ -203,20 +208,36 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
         reply = decode(output[0])
         t1 = time.time()
         print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0):.2f} it/s)")
+        if not (args.chat or args.cai_chat):
+            reply = original_question + apply_extensions(reply[len(question):], "output")
         yield formatted_outputs(reply, model_name)
 
     # Generate the reply 1 token at a time
     else:
-        yield formatted_outputs(question, model_name)
+        yield formatted_outputs(original_question, model_name)
         preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=8')
         for i in tqdm(range(tokens//8+1)):
             output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
             reply = decode(output[0])
+            if not (args.chat or args.cai_chat):
+                reply = original_question + apply_extensions(reply[len(question):], "output")
             yield formatted_outputs(reply, model_name)
             input_ids = output
             if output[0][-1] == n:
                 break
 
+def apply_extensions(text, typ):
+    global available_extensions, extension_state
+    for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
+        if extension_state[ext][0] == True:
+            ext_string = f"extensions.{ext}.script"
+            exec(f"import {ext_string}")
+            if typ == "input":
+                text = eval(f"{ext_string}.input_modifier(text)")
+            else:
+                text = eval(f"{ext_string}.output_modifier(text)")
+    return text
+
 def get_available_models():
     return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
 
@@ -226,9 +247,19 @@ def get_available_presets():
 def get_available_characters():
     return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
 
+def get_available_extensions():
+    return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
+
 available_models = get_available_models()
 available_presets = get_available_presets()
 available_characters = get_available_characters()
+available_extensions = get_available_extensions()
+extension_state = {}
+if args.extensions is not None:
+    for i,ext in enumerate(args.extensions.split(',')):
+        if ext in available_extensions:
+            print(f'The extension "{ext}" is enabled.')
+            extension_state[ext] = [True, i]
 
 # Choosing the default model
 if args.model is not None:
@@ -256,7 +287,7 @@ description = f"\n\n# Text generation lab\nGenerate text using Large Language Mo
 css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem} #refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} #download-label, #upload-label {min-height: 0}"
 
 if args.chat or args.cai_chat:
-    history = []
+    history = {'internal': [], 'visible': []}
     character = None
 
     # This gets the new line characters right.
@@ -270,13 +301,13 @@ if args.chat or args.cai_chat:
         text = clean_chat_message(text)
 
         rows = [f"{context.strip()}\n"]
-        i = len(history)-1
+        i = len(history['internal'])-1
         count = 0
         while i >= 0 and len(encode(''.join(rows), tokens)[0]) < 2048-tokens:
-            rows.insert(1, f"{name2}: {history[i][1].strip()}\n")
+            rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n")
             count += 1
-            if not (history[i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
-                rows.insert(1, f"{name1}: {history[i][0].strip()}\n")
+            if not (history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
+                rows.insert(1, f"{name1}: {history['internal'][i][0].strip()}\n")
                 count += 1
             i -= 1
             if history_size != 0 and count >= history_size:
@@ -291,18 +322,12 @@ if args.chat or args.cai_chat:
         question = ''.join(rows)
         return question
 
-    def remove_example_dialogue_from_history(history):
-        _history = copy.deepcopy(history)
-        for i in range(len(_history)):
-            if '<|BEGIN-VISIBLE-CHAT|>' in _history[i][0]:
-                _history[i][0] = _history[i][0].replace('<|BEGIN-VISIBLE-CHAT|>', '')
-                _history = _history[i:]
-                break
-        return _history
-
     def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
+        original_text = text
+        text = apply_extensions(text, "input")
         question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
-        history.append(['', ''])
+        history['internal'].append(['', ''])
+        history['visible'].append(['', ''])
         eos_token = '\n' if check else None
         for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name1}:"):
             next_character_found = False
@@ -312,7 +337,6 @@ if args.chat or args.cai_chat:
             idx = idx[len(previous_idx)-1]
 
             reply = reply[idx + len(f"\n{name2}:"):]
-
             if check:
                 reply = reply.split('\n')[0].strip()
             else:
@@ -322,7 +346,8 @@ if args.chat or args.cai_chat:
                     next_character_found = True
                 reply = clean_chat_message(reply)
 
-            history[-1] = [text, reply]
+            history['internal'][-1] = [text, reply]
+            history['visible'][-1] = [original_text, apply_extensions(reply, "output")]
             if next_character_found:
                 break
 
@@ -335,16 +360,17 @@ if args.chat or args.cai_chat:
                     next_character_substring_found = True
 
             if not next_character_substring_found:
-                yield remove_example_dialogue_from_history(history)
+                yield history['visible']
 
-        yield remove_example_dialogue_from_history(history)
+        yield history['visible']
 
     def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
-        for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
-            yield generate_chat_html(history, name1, name2, character)
+        for _history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
+            yield generate_chat_html(_history, name1, name2, character)
 
     def regenerate_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
-        last = history.pop()
+        last = history['internal'].pop()
+        history['visible'].pop()
         text = last[0]
         if args.cai_chat:
             for i in cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
@@ -354,12 +380,15 @@ if args.chat or args.cai_chat:
                 yield i
 
     def remove_last_message(name1, name2):
-        last = history.pop()
-        _history = remove_example_dialogue_from_history(history)
+        if not history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':
+            last = history['visible'].pop()
+            history['internal'].pop()
+        else:
+            last = ['', '']
         if args.cai_chat:
-            return generate_chat_html(_history, name1, name2, character), last[0]
+            return generate_chat_html(history['visible'], name1, name2, character), last[0]
         else:
-            return _history, last[0]
+            return history['visible'], last[0]
 
     def clear_html():
         return generate_chat_html([], "", "", character)
@@ -367,28 +396,31 @@ if args.chat or args.cai_chat:
     def clear_chat_log(_character, name1, name2):
         global history
         if _character != 'None':
-            load_character(_character, name1, name2)
+            for i in range(len(history['internal'])):
+                if '<|BEGIN-VISIBLE-CHAT|>' in history['internal'][i][0]:
+                    history['visible'] = [['', history['internal'][i][1]]]
+                    history['internal'] = history['internal'][:i+1]
+                    break
         else:
-            history = []
-        _history = remove_example_dialogue_from_history(history)
+            history['internal'] = []
+            history['visible'] = []
         if args.cai_chat:
-            return generate_chat_html(_history, name1, name2, character)
+            return generate_chat_html(history['visible'], name1, name2, character)
         else:
-            return _history
+            return history['visible'] 
 
     def redraw_html(name1, name2):
         global history
-        _history = remove_example_dialogue_from_history(history)
-        return generate_chat_html(_history, name1, name2, character)
+        return generate_chat_html(history['visible'], name1, name2, character)
 
     def tokenize_dialogue(dialogue, name1, name2):
-        history = []
+        _history = []
 
         dialogue = re.sub('<START>', '', dialogue)
         dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
         idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)]
         if len(idx) == 0:
-            return history
+            return _history
 
         messages = []
         for i in range(len(idx)-1):
@@ -402,16 +434,16 @@ if args.chat or args.cai_chat:
             elif i.startswith(f'{name2}:'):
                 entry[1] = i[len(f'{name2}:'):].strip()
                 if not (len(entry[0]) == 0 and len(entry[1]) == 0):
-                    history.append(entry)
+                    _history.append(entry)
                 entry = ['', '']
 
-        return history
+        return _history
 
     def save_history():
         if not Path('logs').exists():
             Path('logs').mkdir()
         with open(Path('logs/conversation.json'), 'w') as f:
-            f.write(json.dumps({'data': history}, indent=2))
+            f.write(json.dumps({'data': history['internal']}, indent=2))
         return Path('logs/conversation.json')
 
     def upload_history(file, name1, name2):
@@ -420,21 +452,22 @@ if args.chat or args.cai_chat:
         try:
             j = json.loads(file)
             if 'data' in j:
-                history = j['data']
+                history['internal'] = j['data']
             # Compatibility with Pygmalion AI's official web UI
             elif 'chat' in j:
-                history = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
+                history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
                 if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
-                    history = [['<|BEGIN-VISIBLE-CHAT|>', history[0]]] + [[history[i], history[i+1]] for i in range(1, len(history)-1, 2)]
+                    history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', history['internal'][0]]] + [[history['internal'][i], history['internal'][i+1]] for i in range(1, len(history['internal'])-1, 2)]
                 else:
-                    history = [[history[i], history[i+1]] for i in range(0, len(history)-1, 2)]
+                    history['internal'] = [[history['internal'][i], history['internal'][i+1]] for i in range(0, len(history['internal'])-1, 2)]
         except:
-            history = tokenize_dialogue(file, name1, name2)
+            history['internal'] = tokenize_dialogue(file, name1, name2)
 
     def load_character(_character, name1, name2):
         global history, character
         context = ""
-        history = []
+        history['internal'] = []
+        history['visible'] = []
         if _character != 'None':
             character = _character
             with open(Path(f'characters/{_character}.json'), 'r') as f:
@@ -446,24 +479,24 @@ if args.chat or args.cai_chat:
                 context += f"Scenario: {data['world_scenario']}\n"
             context = f"{context.strip()}\n<START>\n"
             if 'example_dialogue' in data and data['example_dialogue'] != '':
-                history = tokenize_dialogue(data['example_dialogue'], name1, name2)
+                history['internal'] = tokenize_dialogue(data['example_dialogue'], name1, name2)
             if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0:
-                history += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
+                history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
+                history['visible'] += [['', data['char_greeting']]]
             else:
-                history += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
+                history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
+                history['visible'] += [['', "Hello there!"]]
         else:
             character = None
             context = settings['context_pygmalion']
             name2 = settings['name2_pygmalion']
 
-        _history = remove_example_dialogue_from_history(history)
         if args.cai_chat:
-            return name2, context, generate_chat_html(_history, name1, name2, character)
+            return name2, context, generate_chat_html(history['visible'], name1, name2, character)
         else:
-            return name2, context, _history
+            return name2, context, history['visible']
 
     def upload_character(file, name1, name2):
-        global history
         file = file.decode('utf-8')
         data = json.loads(file)
         outfile_name = data["char_name"]
@@ -543,7 +576,7 @@ if args.chat or args.cai_chat:
         if args.cai_chat:
             upload.upload(redraw_html, [name1, name2], [display1])
         else:
-            upload.upload(lambda : remove_example_dialogue_from_history(history), [], [display1])
+            upload.upload(lambda : history['visible'], [], [display1])
 
 elif args.notebook:
     with gr.Blocks(css=css, analytics_enabled=False) as interface: