Kaynağa Gözat

Move the example dialogue to the chat history, and keep it hidden.

This greatly improves the performance of text generation, as
histories can be quite long. It also makes more sense to implement
it this way.
oobabooga 3 yıl önce
ebeveyn
işleme
990ee54ddd
3 değiştirilmiş dosya ile 39 ekleme ve 8 silme
  1. 2 2
      README.md
  2. 9 1
      html_generator.py
  3. 28 5
      server.py

+ 2 - 2
README.md

@@ -139,9 +139,9 @@ Optionally, you can use the following command-line flags:
 | `--load-in-8bit`  | Load the model with 8-bit precision.|
 | `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
 | `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
-| `--disk-cache-dir DISK_CACHE_DIR` | Directory which you want the disk cache to load to. |
+| `--disk-cache-dir DISK_CACHE_DIR` | Directory to save the disk cache to. Defaults to `cache/`. |
 | `--gpu-memory GPU_MEMORY` | Maximum GPU memory in GiB to allocate. This is useful if you get out of memory errors while trying to generate text. Must be an integer number. |
-| `--cpu-memory CPU_MEMORY`    | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. |
+| `--cpu-memory CPU_MEMORY`    | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99 GiB.|
 | `--no-stream`   | Don't stream the text output in real time. This slightly improves the text generation performance.|
 | `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.|
 | `--listen`   | Make the web UI reachable from your local network.|

+ 9 - 1
html_generator.py

@@ -6,6 +6,7 @@ This is a library for formatting GPT-4chan and chat outputs as nice HTML.
 
 import re
 from pathlib import Path
+import copy
 
 def generate_basic_html(s):
     s = '\n'.join([f'<p style="margin-bottom: 20px">{line}</p>' for line in s.split('\n')])
@@ -160,7 +161,7 @@ def generate_4chan_html(f):
 
     return output
 
-def generate_chat_html(history, name1, name2, character):
+def generate_chat_html(_history, name1, name2, character):
     css = """
     .chat {
       margin-left: auto;
@@ -233,6 +234,13 @@ def generate_chat_html(history, name1, name2, character):
             img = f'<img src="file/{i}">'
             break
 
+    history = copy.deepcopy(_history)
+    for i in range(len(history)):
+        if '<|BEGIN-VISIBLE-CHAT|>' in history[i][0]:
+            history[i][0] = history[i][0].replace('<|BEGIN-VISIBLE-CHAT|>', '')
+            history = history[i:]
+            break
+
     for i,_row in enumerate(history[::-1]):
         row = _row.copy()
         row[0] = re.sub(r"[\\]*\*", r"*", row[0])

+ 28 - 5
server.py

@@ -26,9 +26,9 @@ parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate
 parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
 parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
 parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
-parser.add_argument('--disk-cache-dir', type=str, help='Directory which you want the disk cache to load to.')
+parser.add_argument('--disk-cache-dir', type=str, help='Directory to save the disk cache to. Defaults to "cache/".')
 parser.add_argument('--gpu-memory', type=int, help='Maximum GPU memory in GiB to allocate. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.')
-parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number.')
+parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99 GiB.')
 parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This slightly improves the text generation performance.')
 parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
 parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
@@ -262,6 +262,7 @@ if args.chat or args.cai_chat:
             rows.pop(1)
 
         question = ''.join(rows)
+        question = question.replace('<|BEGIN-VISIBLE-CHAT|>', '')
         return question
 
     def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
@@ -336,6 +337,26 @@ if args.chat or args.cai_chat:
         global history
         history = json.loads(file.decode('utf-8'))['data']
 
+    def tokenize_example_dialogue(dialogue, name1, name2):
+        dialogue = re.sub('<START>', '', dialogue)
+        dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
+
+        idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)]
+        messages = []
+        for i in range(len(idx)-1):
+            messages.append(dialogue[idx[i]:idx[i+1]].strip())
+        history = []
+        entry = ['', '']
+        for i in messages:
+            if i.startswith(f'{name1}:'):
+                entry[0] = i[len(f'{name1}:'):].strip()
+            elif i.startswith(f'{name2}:'):
+                entry[1] = i[len(f'{name2}:'):].strip()
+                if not (len(entry[0]) == 0 and len(entry[1]) == 0):
+                    history.append(entry)
+                entry = ['', '']
+        return history
+
     def load_character(_character, name1, name2):
         global history, character
         context = ""
@@ -351,9 +372,11 @@ if args.chat or args.cai_chat:
                 context += f"Scenario: {data['world_scenario']}\n"
             context = f"{context.strip()}\n<START>\n"
             if 'example_dialogue' in data and data['example_dialogue'] != '':
-                context += f"{data['example_dialogue'].strip()}\n"
-            if 'char_greeting' in data:
-                history = [['', data['char_greeting']]]
+                history = tokenize_example_dialogue(data['example_dialogue'], name1, name2)
+            if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0:
+                history += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
+            else:
+                history += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
         else:
             character = None
             context = settings['context_pygmalion']