Alexander Hristov Hristov 2 years ago
Parent
Commit
63c5a139a2
10 changed files with 159 additions and 82 deletions
  1. .github/FUNDING.yml (+1, -0)
  2. README.md (+11, -12)
  3. extensions/silero_tts/script.py (+85, -11)
  4. modules/RWKV.py (+4, -4)
  5. modules/chat.py (+25, -31)
  6. modules/quantized_LLaMA.py (+3, -4)
  7. modules/shared.py (+2, -2)
  8. modules/text_generation.py (+20, -10)
  9. requirements.txt (+4, -4)
  10. server.py (+4, -4)

+ 1 - 0
.github/FUNDING.yml

@@ -0,0 +1 @@
+ko_fi: oobabooga

+ 11 - 12
README.md

@@ -27,7 +27,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
 * [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen).
 * [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed).
 * Get responses via API, [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) or [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming.
-* [Supports the LLaMA model](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
+* [Supports the LLaMA model, including 4-bit mode](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
 * [Supports the RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model).
 * Supports softprompts.
 * [Supports extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions).
@@ -60,11 +60,13 @@ pip3 install torch torchvision torchaudio --extra-index-url https://download.pyt
 conda install pytorch torchvision torchaudio git -c pytorch
 ```
 
+See also: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings).
+
 ## Installation option 2: one-click installers
 
-[oobabooga-windows.zip](https://github.com/oobabooga/text-generation-webui/releases/download/installers/oobabooga-windows.zip)
+[oobabooga-windows.zip](https://github.com/oobabooga/one-click-installers/archive/refs/heads/oobabooga-windows.zip)
 
-[oobabooga-linux.zip](https://github.com/oobabooga/text-generation-webui/releases/download/installers/oobabooga-linux.zip)
+[oobabooga-linux.zip](https://github.com/oobabooga/one-click-installers/archive/refs/heads/oobabooga-linux.zip)
 
 Just download the zip above, extract it, and double click on "install". The web UI and all its dependencies will be installed in the same folder.
 
 
@@ -139,7 +141,7 @@ Optionally, you can use the following command-line flags:
 | `--cpu`       | Use the CPU to generate text.|
 | `--load-in-8bit`  | Load the model with 8-bit precision.|
 | `--load-in-4bit`  | Load the model with 4-bit precision. Currently only works with LLaMA.|
-| `--gptq-bits`  |  Load a pre-quantized model with specified precision. 2, 3, 4 and 8bit are supported. Currently only works with LLaMA. |
+| `--gptq-bits GPTQ_BITS`  |  Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA. |
 | `--bf16`  | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
 | `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
 | `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
@@ -155,12 +157,13 @@ Optionally, you can use the following command-line flags:
 | `--local_rank LOCAL_RANK`    | DeepSpeed: Optional argument for distributed setups. |
 |  `--rwkv-strategy RWKV_STRATEGY`         |    RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". |
 |  `--rwkv-cuda-on`                        |   RWKV: Compile the CUDA kernel for better performance. |
-| `--no-stream`   | Don't stream the text output in real time. This improves the text generation performance.|
+| `--no-stream`   | Don't stream the text output in real time. |
 | `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.|
 |  `--extensions EXTENSIONS [EXTENSIONS ...]` |  The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
 | `--listen`   | Make the web UI reachable from your local network.|
 |  `--listen-port LISTEN_PORT` | The listening port that the server will use. |
 | `--share`   | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
+| `--auto-launch` | Open the web UI in the default browser upon launch. |
 | `--verbose`   | Print the prompts to the terminal. |
 
 Out of memory errors? [Check this guide](https://github.com/oobabooga/text-generation-webui/wiki/Low-VRAM-guide).
@@ -179,14 +182,10 @@ Check the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/System-
 
 
 Pull requests, suggestions, and issue reports are welcome.
 
-Before reporting a bug, make sure that you have created a conda environment and installed the dependencies exactly as in the *Installation* section above.
-
-These issues are known:
-
-* 8-bit doesn't work properly on Windows or older GPUs.
-* DeepSpeed doesn't work properly on Windows.
+Before reporting a bug, make sure that you have:
 
-For these two, please try commenting on an existing issue instead of creating a new one.
+1. Created a conda environment and installed the dependencies exactly as in the *Installation* section above.
+2. [Searched](https://github.com/oobabooga/text-generation-webui/issues) to see if an issue already exists for the issue you encountered.
 
 ## Credits
 
 

+ 85 - 11
extensions/silero_tts/script.py

@@ -1,8 +1,12 @@
+import time
 from pathlib import Path
 
 import gradio as gr
 import torch
 
+import modules.chat as chat
+import modules.shared as shared
+
 torch._C._jit_set_profiling_mode(False)
 
 params = {
@@ -12,10 +16,28 @@ params = {
     'model_id': 'v3_en',
     'sample_rate': 48000,
     'device': 'cpu',
+    'show_text': False,
+    'autoplay': True,
+    'voice_pitch': 'medium',
+    'voice_speed': 'medium',
 }
+
 current_params = params.copy()
 voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
-wav_idx = 0
+voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
+voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
+
+# Used for making text xml compatible, needed for voice pitch and speed control
+table = str.maketrans({
+    "<": "&lt;",
+    ">": "&gt;",
+    "&": "&amp;",
+    "'": "&apos;",
+    '"': "&quot;",
+})
+
+def xmlesc(txt):
+    return txt.translate(table)
 
 
 def load_model():
     model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
@@ -33,12 +55,32 @@ def remove_surrounded_chars(string):
             new_string += char
     return new_string
 
 
+def remove_tts_from_history(name1, name2):
+    for i, entry in enumerate(shared.history['internal']):
+        shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
+    return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
+
+def toggle_text_in_history(name1, name2):
+    for i, entry in enumerate(shared.history['visible']):
+        visible_reply = entry[1]
+        if visible_reply.startswith('<audio'):
+            if params['show_text']:
+                reply = shared.history['internal'][i][1]
+                shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
+            else:
+                shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
+    return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
+
 def input_modifier(string):
     """
     This function is applied to your text inputs before
     they are fed into the model.
     """
 
 
+    # Remove autoplay from the last reply
+    if (shared.args.chat or shared.args.cai_chat) and len(shared.history['internal']) > 0:
+        shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>','controls>')]
+
     return string
 
 
 def output_modifier(string):
@@ -46,7 +88,7 @@ def output_modifier(string):
     This function is applied to the model outputs.
     """
 
-    global wav_idx, model, current_params
+    global model, current_params
 
     for i in params:
         if params[i] != current_params[i]:
@@ -57,6 +99,7 @@ def output_modifier(string):
     if params['activate'] == False:
         return string
 
+    original_string = string
     string = remove_surrounded_chars(string)
     string = string.replace('"', '')
     string = string.replace('“', '')
@@ -64,13 +107,17 @@ def output_modifier(string):
     string = string.strip()
 
     if string == '':
-        string = 'empty reply, try regenerating'
-
-    output_file = Path(f'extensions/silero_tts/outputs/{wav_idx:06d}.wav')
-    model.save_wav(text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
-
-    string = f'<audio src="file/{output_file.as_posix()}" controls></audio>'
-    wav_idx += 1
+        string = '*Empty reply, try regenerating*'
+    else:
+        output_file = Path(f'extensions/silero_tts/outputs/{shared.character}_{int(time.time())}.wav')
+        prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch'])
+        silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>'
+        model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
+
+        autoplay = 'autoplay' if params['autoplay'] else ''
+        string = f'<audio src="file/{output_file.as_posix()}" controls {autoplay}></audio>'
+        if params['show_text']:
+            string += f'\n\n{original_string}'
 
 
     return string
 
 
@@ -85,9 +132,36 @@ def bot_prefix_modifier(string):
 
 
 def ui():
     # Gradio elements
-    activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
-    voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
+    with gr.Accordion("Silero TTS"):
+        with gr.Row():
+            activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
+            autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
+        show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
+        voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
+        with gr.Row():
+            v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
+            v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
+        with gr.Row():
+            convert = gr.Button('Permanently replace audios with the message texts')
+            convert_cancel = gr.Button('Cancel', visible=False)
+            convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
+
+    # Convert history with confirmation
+    convert_arr = [convert_confirm, convert, convert_cancel]
+    convert.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
+    convert_confirm.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
+    convert_confirm.click(remove_tts_from_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
+    convert_confirm.click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
+    convert_cancel.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
+
+    # Toggle message text in history
+    show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
+    show_text.change(toggle_text_in_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
+    show_text.change(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
 
 
     # Event functions to update the parameters in the backend
     activate.change(lambda x: params.update({"activate": x}), activate, None)
+    autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
     voice.change(lambda x: params.update({"speaker": x}), voice, None)
+    v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
+    v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)
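
Note: the rewritten `output_modifier` above no longer writes sequentially numbered wav files; it wraps the reply in an SSML document so Silero can apply the selected pitch and speed via `save_wav(ssml_text=...)`, escaping XML special characters first. A minimal standalone sketch of that string construction, mirroring the `table`/`xmlesc` and prosody format from the diff (`build_ssml` is a hypothetical helper for illustration, no Silero model is loaded):

```python
# Escape table copied from the extension: characters that would break the SSML markup
table = str.maketrans({
    "<": "&lt;",
    ">": "&gt;",
    "&": "&amp;",
    "'": "&apos;",
    '"': "&quot;",
})

def xmlesc(txt):
    return txt.translate(table)

def build_ssml(text, speed="medium", pitch="medium"):
    # Same <speak><prosody> wrapper the new output_modifier builds before save_wav(ssml_text=...)
    prosody = f'<prosody rate="{speed}" pitch="{pitch}">'
    return f"<speak>{prosody}{xmlesc(text)}</prosody></speak>"

print(build_ssml('She said "hi" & left', speed="fast", pitch="low"))
# <speak><prosody rate="fast" pitch="low">She said &quot;hi&quot; &amp; left</prosody></speak>
```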

+ 4 - 4
modules/RWKV.py

@@ -25,10 +25,10 @@ class RWKVModel:
         tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
 
         if shared.args.rwkv_strategy is None:
-            model = RWKV(model=os.path.abspath(path), strategy=f'{device} {dtype}')
+            model = RWKV(model=str(path), strategy=f'{device} {dtype}')
         else:
-            model = RWKV(model=os.path.abspath(path), strategy=shared.args.rwkv_strategy)
-        pipeline = PIPELINE(model, os.path.abspath(tokenizer_path))
+            model = RWKV(model=str(path), strategy=shared.args.rwkv_strategy)
+        pipeline = PIPELINE(model, str(tokenizer_path))
 
         result = self()
         result.pipeline = pipeline
@@ -61,7 +61,7 @@ class RWKVTokenizer:
     @classmethod
     def from_pretrained(self, path):
         tokenizer_path = path / "20B_tokenizer.json"
-        tokenizer = Tokenizer.from_file(os.path.abspath(tokenizer_path))
+        tokenizer = Tokenizer.from_file(str(tokenizer_path))
 
         result = self()
         result.tokenizer = tokenizer
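
Note: this file now passes plain `str(Path(...))` values instead of `os.path.abspath(...)` to the `rwkv` and `tokenizers` APIs, which expect string paths. A hedged sketch of the resulting loading pattern (the imports and file names below are assumptions, they are not shown in this diff):

```python
from pathlib import Path

from rwkv.model import RWKV          # pip package "rwkv" (see the requirements.txt bump to 0.3.1)
from rwkv.utils import PIPELINE
from tokenizers import Tokenizer

model_path = Path("models/rwkv-model.pth")                # placeholder path
tokenizer_path = model_path.parent / "20B_tokenizer.json"

# Strategy strings follow the --rwkv-strategy examples: "cpu fp32", "cuda fp16", "cuda fp16i8"
model = RWKV(model=str(model_path), strategy="cuda fp16")
pipeline = PIPELINE(model, str(tokenizer_path))
tokenizer = Tokenizer.from_file(str(tokenizer_path))
```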

+ 25 - 31
modules/chat.py

@@ -22,6 +22,12 @@ def clean_chat_message(text):
     text = text.strip()
     return text
 
 
+def generate_chat_output(history, name1, name2, character):
+    if shared.args.cai_chat:
+        return generate_chat_html(history, name1, name2, character)
+    else:
+        return history
+
 def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False):
     user_input = clean_chat_message(user_input)
     rows = [f"{context.strip()}\n"]
@@ -53,7 +59,6 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
 
 
 def extract_message_from_reply(question, reply, name1, name2, check, impersonate=False):
     next_character_found = False
-    substring_found = False
 
     asker = name1 if not impersonate else name2
     replier = name2 if not impersonate else name1
@@ -79,15 +84,15 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
             next_character_found = True
         reply = clean_chat_message(reply)
 
-        # Detect if something like "\nYo" is generated just before
-        # "\nYou:" is completed
-        tmp = f"\n{asker}:"
-        for j in range(1, len(tmp)):
-            if reply[-j:] == tmp[:j]:
+        # If something like "\nYo" is generated just before "\nYou:"
+        # is completed, trim it
+        next_turn = f"\n{asker}:"
+        for j in range(len(next_turn)-1, 0, -1):
+            if reply[-j:] == next_turn[:j]:
                 reply = reply[:-j]
-                substring_found = True
+                break
 
-    return reply, next_character_found, substring_found
+    return reply, next_character_found
 
 def stop_everything_event():
     shared.stop_everything = True
@@ -122,7 +127,6 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
         prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
 
 
     if not regenerate:
-        # Display user input and "*is typing...*" imediately
         yield shared.history['visible']+[[visible_text, '*Is typing...*']]
 
     # Generate
@@ -131,7 +135,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
         for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
 
 
             # Extracting the reply
-            reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check)
+            reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check)
             visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
             visible_reply = apply_extensions(visible_reply, "output")
             if shared.args.chat:
@@ -148,7 +152,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
 
 
             shared.history['internal'][-1] = [text, reply]
             shared.history['visible'][-1] = [visible_text, visible_reply]
-            if not substring_found and not shared.args.no_stream:
+            if not shared.args.no_stream:
                 yield shared.history['visible']
             if next_character_found:
                 break
@@ -163,15 +167,12 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
 
 
     prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
 
 
-    # Display "*is typing...*" imediately
-    yield '*Is typing...*'
-
     reply = ''
+    yield '*Is typing...*'
     for i in range(chat_generation_attempts):
         for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
-            reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True)
-            if not substring_found:
-                yield reply
+            reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True)
+            yield reply
             if next_character_found:
                 break
         yield reply
@@ -182,21 +183,18 @@ def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
 
 
 def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
     if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
-        if shared.args.cai_chat:
-            yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
-        else:
-            yield shared.history['visible']
+        yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
     else:
         last_visible = shared.history['visible'].pop()
         last_internal = shared.history['internal'].pop()
 
 
+        yield generate_chat_output(shared.history['visible']+[[last_visible[0], '*Is typing...*']], name1, name2, shared.character)
         for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
             if shared.args.cai_chat:
                 shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
-                yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
             else:
                 shared.history['visible'][-1] = (last_visible[0], _history[-1][1])
-                yield shared.history['visible']
+            yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
 
 
 def remove_last_message(name1, name2):
     if len(shared.history['visible']) > 0 and not shared.history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':
@@ -204,6 +202,7 @@ def remove_last_message(name1, name2):
         shared.history['internal'].pop()
     else:
         last = ['', '']
+
     if shared.args.cai_chat:
         return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0]
     else:
@@ -223,10 +222,7 @@ def replace_last_reply(text, name1, name2):
             shared.history['visible'][-1] = (shared.history['visible'][-1][0], text)
         shared.history['internal'][-1][1] = apply_extensions(text, "input")
 
 
-    if shared.args.cai_chat:
-        return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
-    else:
-        return shared.history['visible']
+    return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
 
 
 def clear_html():
     return generate_chat_html([], "", "", shared.character)
@@ -246,10 +242,8 @@ def clear_chat_log(name1, name2):
     else:
         shared.history['internal'] = []
         shared.history['visible'] = []
-    if shared.args.cai_chat:
-        return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
-    else:
-        return shared.history['visible']
+
+    return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
 
 
 def redraw_html(name1, name2):
     return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
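
Note: the two recurring changes in this file are the new `generate_chat_output()` helper, which collapses the repeated `cai_chat` branching into one call, and the simplified partial-stop-string handling, which trims a partially generated next-turn marker instead of tracking a separate `substring_found` flag. A standalone illustration of that trimming logic (the function name here is hypothetical):

```python
def trim_partial_turn(reply, asker):
    # If the end of the streamed reply is a partial match of the next turn marker
    # (e.g. "\nYo" just before "\nYou:" is completed), cut the longest such suffix.
    next_turn = f"\n{asker}:"
    for j in range(len(next_turn) - 1, 0, -1):
        if reply[-j:] == next_turn[:j]:
            return reply[:-j]
    return reply

print(trim_partial_turn("Sure, I can help.\nYo", "You"))  # -> "Sure, I can help."
print(trim_partial_turn("Sure, I can help.", "You"))      # -> unchanged
```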

+ 3 - 4
modules/quantized_LLaMA.py

@@ -1,4 +1,3 @@
-import os
 import sys
 from pathlib import Path
 
 
@@ -7,7 +6,7 @@ import torch
 
 
 import modules.shared as shared
 
-sys.path.insert(0, os.path.abspath(Path("repositories/GPTQ-for-LLaMa")))
+sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
 from llama import load_quant
 
 
 
 
@@ -41,9 +40,9 @@ def load_quantized_LLaMA(model_name):
         print(f"Could not find {pt_model}, exiting...")
         exit()
 
-    model = load_quant(path_to_model, os.path.abspath(pt_path), bits)
+    model = load_quant(str(path_to_model), str(pt_path), bits)
 
-    # Multi-GPU setup
+    # Multiple GPUs or GPU+CPU
     if shared.args.gpu_memory:
         max_memory = {}
         for i in range(len(shared.args.gpu_memory)):
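
Note: the hunk above is cut off inside the multi-GPU branch. The general Accelerate pattern for splitting a model across GPUs and CPU is to build a `max_memory` mapping and derive a device map from it. A hedged sketch of that pattern, not the repository's exact code (all values are placeholders):

```python
from accelerate import dispatch_model, infer_auto_device_map

gpu_memory = ["10", "10"]   # hypothetical per-GPU limits, e.g. from --gpu-memory
cpu_memory = "30"           # hypothetical CPU limit

# One entry per GPU index plus a "cpu" entry, in the string format Accelerate accepts
max_memory = {i: f"{mem}GiB" for i, mem in enumerate(gpu_memory)}
max_memory["cpu"] = f"{cpu_memory}GiB"
print(max_memory)   # {0: '10GiB', 1: '10GiB', 'cpu': '30GiB'}

# With a loaded model, the mapping is typically consumed like this:
# device_map = infer_auto_device_map(model, max_memory=max_memory)
# model = dispatch_model(model, device_map=device_map)
```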

+ 2 - 2
modules/shared.py

@@ -85,12 +85,12 @@ parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory t
 parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
 parser.add_argument('--rwkv-strategy', type=str, default=None, help='RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8".')
 parser.add_argument('--rwkv-cuda-on', action='store_true', help='RWKV: Compile the CUDA kernel for better performance.')
-parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
+parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.')
 parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.')
 parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
 parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
 parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
 parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
+parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
 parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
-parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch')
 args = parser.parse_args()
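
Note: `--auto-launch` is moved above `--verbose` and its help text gains a trailing period. The diff does not show how the flag is consumed; presumably it is forwarded to Gradio's `launch()`, which accepts an `inbrowser` argument. A hedged, minimal sketch of that wiring (the Blocks object below is a placeholder, not the repository's code):

```python
import argparse

import gradio as gr

parser = argparse.ArgumentParser()
parser.add_argument('--auto-launch', action='store_true', default=False,
                    help='Open the web UI in the default browser upon launch.')
args = parser.parse_args()

with gr.Blocks() as demo:
    gr.Markdown("placeholder UI")

# inbrowser=True makes Gradio open the default browser once the server starts
demo.launch(inbrowser=args.auto_launch)
```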

+ 20 - 10
modules/text_generation.py

@@ -37,9 +37,13 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
             return input_ids.cuda()
 
 def decode(output_ids):
-    reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
-    reply = reply.replace(r'<|endoftext|>', '')
-    return reply
+    # Open Assistant relies on special tokens like <|endoftext|>
+    if re.match('oasst-*', shared.model_name.lower()):
+        return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
+    else:
+        reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
+        reply = reply.replace(r'<|endoftext|>', '')
+        return reply
 
 
 def generate_softprompt_input_tensors(input_ids):
     inputs_embeds = shared.model.transformer.wte(input_ids)
@@ -119,7 +123,9 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     original_input_ids = input_ids
     output = input_ids[0]
     cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
-    n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1])
+    eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
+    if eos_token is not None:
+        eos_token_ids.append(int(encode(eos_token)[0][-1]))
     stopping_criteria_list = transformers.StoppingCriteriaList()
     if stopping_string is not None:
         # Copied from https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
@@ -129,7 +135,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     if not shared.args.flexgen:
         generate_params = [
             f"max_new_tokens=max_new_tokens",
-            f"eos_token_id={n}",
+            f"eos_token_id={eos_token_ids}",
             f"stopping_criteria=stopping_criteria_list",
             f"do_sample={do_sample}",
             f"temperature={temperature}",
@@ -149,7 +155,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}",
             f"do_sample={do_sample}",
             f"temperature={temperature}",
-            f"stop={n}",
+            f"stop={eos_token_ids[-1]}",
         ]
     if shared.args.deepspeed:
         generate_params.append("synced_gpus=True")
@@ -196,10 +202,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
 
 
                     if not (shared.args.chat or shared.args.cai_chat):
                         reply = original_question + apply_extensions(reply[len(question):], "output")
-                    yield formatted_outputs(reply, shared.model_name)
 
 
-                    if output[-1] == n:
+                    if output[-1] in eos_token_ids:
                         break
+                    yield formatted_outputs(reply, shared.model_name)
+
+                yield formatted_outputs(reply, shared.model_name)
 
 
         # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
         else:
@@ -213,15 +221,17 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
 
 
                 if not (shared.args.chat or shared.args.cai_chat):
                     reply = original_question + apply_extensions(reply[len(question):], "output")
-                yield formatted_outputs(reply, shared.model_name)
 
 
-                if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
+                if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
                     break
+                yield formatted_outputs(reply, shared.model_name)
 
 
                 input_ids = np.reshape(output, (1, output.shape[0]))
                 if shared.soft_prompt:
                     inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
 
 
+            yield formatted_outputs(reply, shared.model_name)
+
     finally:
         t1 = time.time()
         print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")
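
Note: the single eos id `n` is replaced by an `eos_token_ids` list (the tokenizer's eos id plus the optional custom `eos_token`), streaming stops when the last generated id is in that list, and the FlexGen branch counts eos occurrences with `np.isin` instead of a `== n` comparison. A standalone illustration of both checks (the ids below are made up):

```python
import numpy as np

eos_token_ids = [2]        # hypothetical tokenizer.eos_token_id
eos_token_ids.append(13)   # hypothetical id of a custom eos_token such as "\n"

output = np.array([5, 42, 7, 13])   # hypothetical generated sequence
print(output[-1] in eos_token_ids)  # True -> stop streaming

# FlexGen branch: stop once the output contains more eos ids than the prompt did
input_ids = np.array([5, 42])
stop = np.count_nonzero(np.isin(output, eos_token_ids)) > np.count_nonzero(np.isin(input_ids, eos_token_ids))
print(stop)                         # True
```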

+ 4 - 4
requirements.txt

@@ -1,12 +1,12 @@
-accelerate==0.16.0
+accelerate==0.17.0
 bitsandbytes==0.37.0
 flexgen==0.1.7
 gradio==3.18.0
 numpy
 requests
-rwkv==0.1.0
-safetensors==0.2.8
+rwkv==0.3.1
+safetensors==0.3.0
 sentencepiece
 tqdm
 markdown
-git+https://github.com/zphang/transformers@llama_push
+git+https://github.com/zphang/transformers.git@68d640f7c368bcaaaecfc678f11908ebbd3d6176

+ 4 - 4
server.py

@@ -269,10 +269,10 @@ if shared.args.chat or shared.args.cai_chat:
 
 
         function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
 
 
-        gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=False, api_name='textgen'))
-        gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=False))
-        gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=False))
-        gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=False))
+        gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen'))
+        gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
+        gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
+        gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
         shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)
 
 
         shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)