View source code

Merge branch 'oobabooga:main' into stt-extension

Elias Vincent Simon 2 years ago
Parent
Commit
3b4145966d

+ 8 - 9
README.md

@@ -1,6 +1,6 @@
 # Text generation web UI
 
-A gradio web UI for running Large Language Models like GPT-J 6B, OPT, GALACTICA, GPT-Neo, and Pygmalion.
+A gradio web UI for running Large Language Models like GPT-J 6B, OPT, GALACTICA, LLaMA, and Pygmalion.
 
 Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) of text generation.
 
@@ -27,6 +27,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
 * [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen).
 * [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed).
 * Get responses via API, [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) or [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming.
+* [Supports the LLaMA model, including 4-bit mode](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
 * [Supports the RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model).
 * Supports softprompts.
 * [Supports extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions).
@@ -53,7 +54,7 @@ The third line assumes that you have an NVIDIA GPU.
 pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2
 ```
 
-* If you are running in CPU mode, replace the third command with this one:
+* If you are running it in CPU mode, replace the third command with this one:
 
 ```
 conda install pytorch torchvision torchaudio git -c pytorch
@@ -137,6 +138,8 @@ Optionally, you can use the following command-line flags:
 | `--cai-chat`  | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
 | `--cpu`       | Use the CPU to generate text.|
 | `--load-in-8bit`  | Load the model with 8-bit precision.|
+| `--load-in-4bit`  | Load the model with 4-bit precision. Currently only works with LLaMA.|
+| `--gptq-bits GPTQ_BITS`  |  Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA. |
 | `--bf16`  | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
 | `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
 | `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
@@ -176,14 +179,10 @@ Check the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/System-
 
 Pull requests, suggestions, and issue reports are welcome.
 
-Before reporting a bug, make sure that you have created a conda environment and installed the dependencies exactly as in the *Installation* section above.
+Before reporting a bug, make sure that you have:
 
-These issues are known:
-
-* 8-bit doesn't work properly on Windows or older GPUs.
-* DeepSpeed doesn't work properly on Windows.
-
-For these two, please try commenting on an existing issue instead of creating a new one.
+1. Created a conda environment and installed the dependencies exactly as in the *Installation* section above.
+2. [Searched](https://github.com/oobabooga/text-generation-webui/issues) to see if an issue already exists for the issue you encountered.
 
 ## Credits
 

+ 14 - 6
download-model.py

@@ -5,7 +5,9 @@ Example:
 python download-model.py facebook/opt-1.3b
 
 '''
+
 import argparse
+import base64
 import json
 import multiprocessing
 import re
@@ -93,23 +95,28 @@ facebook/opt-1.3b
 def get_download_links_from_huggingface(model, branch):
     base = "https://huggingface.co"
     page = f"/api/models/{model}/tree/{branch}?cursor="
+    cursor = b""
 
     links = []
     classifications = []
     has_pytorch = False
     has_safetensors = False
-    while page is not None:
-        content = requests.get(f"{base}{page}").content
+    while True:
+        content = requests.get(f"{base}{page}{cursor.decode()}").content
+
         dict = json.loads(content)
+        if len(dict) == 0:
+            break
 
         for i in range(len(dict)):
             fname = dict[i]['path']
 
             is_pytorch = re.match("pytorch_model.*\.bin", fname)
             is_safetensors = re.match("model.*\.safetensors", fname)
-            is_text = re.match(".*\.(txt|json)", fname)
+            is_tokenizer = re.match("tokenizer.*\.model", fname)
+            is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer
 
-            if is_text or is_safetensors or is_pytorch:
+            if any((is_pytorch, is_safetensors, is_text, is_tokenizer)):
                 if is_text:
                     links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
                     classifications.append('text')
@@ -123,8 +130,9 @@ def get_download_links_from_huggingface(model, branch):
                         has_pytorch = True
                         classifications.append('pytorch')
 
-        #page = dict['nextUrl']
-        page = None
+        cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
+        cursor = base64.b64encode(cursor)
+        cursor = cursor.replace(b'=', b'%3D')
 
     # If both pytorch and safetensors are available, download safetensors only
     if has_pytorch and has_safetensors:
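
For context on the pagination change above: instead of following a `nextUrl`, the script now builds its own cursor from the last file name returned on each page and keeps requesting the Hugging Face tree API until an empty page comes back. A minimal standalone sketch of the same idea (function name is illustrative; only `requests` is assumed):

```python
import base64
import json

import requests


def list_model_files(model, branch="main"):
    """Yield file paths from https://huggingface.co/api/models/{model}/tree/{branch}, page by page."""
    base = "https://huggingface.co"
    cursor = b""
    while True:
        url = f"{base}/api/models/{model}/tree/{branch}?cursor={cursor.decode()}"
        page = json.loads(requests.get(url).content)
        if len(page) == 0:
            break  # an empty page means the previous one was the last
        for entry in page:
            yield entry["path"]
        # The cursor names the last file seen, double base64-encoded,
        # with '=' percent-encoded so it survives the query string.
        cursor = base64.b64encode(f'{{"file_name":"{page[-1]["path"]}"}}'.encode()) + b":50"
        cursor = base64.b64encode(cursor).replace(b"=", b"%3D")


# e.g. print(list(list_model_files("facebook/opt-1.3b")))
```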

+ 18 - 0
extensions/llama_prompts/script.py

@@ -0,0 +1,18 @@
+import gradio as gr
+import modules.shared as shared
+import pandas as pd
+
+df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")
+
+def get_prompt_by_name(name):
+    if name == 'None':
+        return ''
+    else:
+        return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
+
+def ui():
+    if not shared.args.chat or shared.args.cai_chat:
+        choices = ['None'] + list(df['Prompt name'])
+
+        prompts_menu = gr.Dropdown(value=choices[0], choices=choices, label='Prompt')
+        prompts_menu.change(get_prompt_by_name, prompts_menu, shared.gradio['textbox'])
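
The new extension simply mirrors a community prompt list into a dropdown and pastes the selected prompt into the main textbox. To sanity-check the lookup outside Gradio, a sketch like this works (it assumes the CSV keeps its 'Prompt name' and 'Prompt' columns):

```python
import pandas as pd

# Same public CSV the extension loads at import time.
df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")

def get_prompt_by_name(name):
    # Prompts store newlines as literal "\n"; unescape them before use.
    if name == 'None':
        return ''
    return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')

print(list(df['Prompt name'])[:3])               # a few available prompt names
print(get_prompt_by_name(df['Prompt name'][0]))  # the first prompt, ready to paste
```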

+ 121 - 9
extensions/silero_tts/script.py

@@ -1,21 +1,45 @@
+import re
+import time
 from pathlib import Path
 
 import gradio as gr
 import torch
 
+import modules.chat as chat
+import modules.shared as shared
+
 torch._C._jit_set_profiling_mode(False)
 
 params = {
     'activate': True,
-    'speaker': 'en_56',
+    'speaker': 'en_5',
     'language': 'en',
     'model_id': 'v3_en',
     'sample_rate': 48000,
     'device': 'cpu',
+    'show_text': False,
+    'autoplay': True,
+    'voice_pitch': 'medium',
+    'voice_speed': 'medium',
 }
+
 current_params = params.copy()
 voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
-wav_idx = 0
+voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
+voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
+last_msg_id = 0
+
+# Used for making text xml compatible, needed for voice pitch and speed control
+table = str.maketrans({
+    "<": "&lt;",
+    ">": "&gt;",
+    "&": "&amp;",
+    "'": "&apos;",
+    '"': "&quot;",
+})
+
+def xmlesc(txt):
+    return txt.translate(table)
 
 def load_model():
     model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
@@ -33,12 +57,59 @@ def remove_surrounded_chars(string):
             new_string += char
     return new_string
 
+def remove_tts_from_history():
+    suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
+    for i, entry in enumerate(shared.history['internal']):
+        reply = entry[1]
+        reply = re.sub("(<USER>|<user>|{{user}})", shared.settings[f'name1{suffix}'], reply)
+        if shared.args.chat:
+            reply = reply.replace('\n', '<br>')
+        shared.history['visible'][i][1] = reply
+
+    if shared.args.cai_chat:
+        return chat.generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name1{suffix}'], shared.character)
+    else:
+        return shared.history['visible']
+
+def toggle_text_in_history():
+    suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
+    audio_str='\n\n' # The '\n\n' used after </audio>
+    if shared.args.chat:
+         audio_str='<br><br>'
+
+    if params['show_text']==True:
+        #for i, entry in enumerate(shared.history['internal']):
+        for i, entry in enumerate(shared.history['visible']):
+            vis_reply = entry[1]
+            if vis_reply.startswith('<audio'):
+                reply = shared.history['internal'][i][1]
+                reply = re.sub("(<USER>|<user>|{{user}})", shared.settings[f'name1{suffix}'], reply)
+                if shared.args.chat:
+                    reply = reply.replace('\n', '<br>')
+                shared.history['visible'][i][1] = vis_reply.split(audio_str,1)[0]+audio_str+reply
+    else:
+        for i, entry in enumerate(shared.history['visible']):
+            vis_reply = entry[1]
+            if vis_reply.startswith('<audio'):
+                shared.history['visible'][i][1] = vis_reply.split(audio_str,1)[0]+audio_str
+
+    if shared.args.cai_chat:
+        return chat.generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name1{suffix}'], shared.character)
+    else:
+        return shared.history['visible']
+
 def input_modifier(string):
     """
     This function is applied to your text inputs before
     they are fed into the model.
     """
 
+    # Remove autoplay from previous chat history
+    if (shared.args.chat or shared.args.cai_chat)and len(shared.history['internal'])>0:
+        [visible_text, visible_reply] = shared.history['visible'][-1]
+        vis_rep_clean = visible_reply.replace('controls autoplay>','controls>')
+        shared.history['visible'][-1] = [visible_text, vis_rep_clean]
+
     return string
 
 def output_modifier(string):
@@ -46,7 +117,7 @@ def output_modifier(string):
     This function is applied to the model outputs.
     """
 
-    global wav_idx, model, current_params
+    global model, current_params
 
     for i in params:
         if params[i] != current_params[i]:
@@ -57,20 +128,34 @@ def output_modifier(string):
     if params['activate'] == False:
         return string
 
+    orig_string = string
     string = remove_surrounded_chars(string)
     string = string.replace('"', '')
     string = string.replace('“', '')
     string = string.replace('\n', ' ')
     string = string.strip()
 
+    silent_string = False # Used to prevent unnecessary audio file generation
     if string == '':
         string = 'empty reply, try regenerating'
+        silent_string = True
+
+    pitch = params['voice_pitch']
+    speed = params['voice_speed']
+    prosody=f'<prosody rate="{speed}" pitch="{pitch}">'
+    string = '<speak>'+prosody+xmlesc(string)+'</prosody></speak>'
 
-    output_file = Path(f'extensions/silero_tts/outputs/{wav_idx:06d}.wav')
-    model.save_wav(text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
+    if not shared.still_streaming and not silent_string:
+        output_file = Path(f'extensions/silero_tts/outputs/{shared.character}_{int(time.time())}.wav')
+        model.save_wav(ssml_text=string, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
+        autoplay_str = ' autoplay' if params['autoplay'] else ''
+        string = f'<audio src="file/{output_file.as_posix()}" controls{autoplay_str}></audio>\n\n'
+    else:
+        # Placeholder so text doesn't shift around so much
+        string = '<audio controls></audio>\n\n'
 
-    string = f'<audio src="file/{output_file.as_posix()}" controls></audio>'
-    wav_idx += 1
+    if params['show_text']:
+        string += orig_string
 
     return string
 
@@ -85,9 +170,36 @@ def bot_prefix_modifier(string):
 
 def ui():
     # Gradio elements
-    activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
-    voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
+    with gr.Accordion("Silero TTS"):
+        with gr.Row():
+            activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
+            autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
+        show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
+        voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
+        with gr.Row():
+            v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
+            v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
+        with gr.Row():
+            convert = gr.Button('Permanently replace chat history audio with message text')
+            convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
+            convert_cancel = gr.Button('Cancel', visible=False)
+
+    # Convert history with confirmation
+    convert_arr = [convert_confirm, convert, convert_cancel]
+    convert.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
+    convert_confirm.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
+    convert_confirm.click(remove_tts_from_history, [], shared.gradio['display'])
+    convert_confirm.click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
+    convert_cancel.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
+
+    # Toggle message text in history
+    show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
+    show_text.change(toggle_text_in_history, [], shared.gradio['display'])
+    show_text.change(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
 
     # Event functions to update the parameters in the backend
     activate.change(lambda x: params.update({"activate": x}), activate, None)
+    autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
     voice.change(lambda x: params.update({"speaker": x}), voice, None)
+    v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
+    v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)
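
The new pitch/speed controls work by XML-escaping the reply and wrapping it in SSML `<prosody>` tags before handing it to Silero's `save_wav(ssml_text=...)`, exactly as in `output_modifier` above. A small sketch of that step in isolation (the `generate_speech` helper and output path are illustrative; the model is fetched via `torch.hub` the same way `load_model()` does):

```python
import torch

# XML-escape so user text cannot break the SSML markup (same table as the extension).
table = str.maketrans({"<": "&lt;", ">": "&gt;", "&": "&amp;", "'": "&apos;", '"': "&quot;"})

def to_ssml(text, pitch="medium", speed="medium"):
    prosody = f'<prosody rate="{speed}" pitch="{pitch}">'
    return '<speak>' + prosody + text.translate(table) + '</prosody></speak>'

def generate_speech(text, wav_path="sample.wav"):
    # Downloads the Silero v3_en model on first use via torch.hub.
    model, _ = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts',
                              language='en', speaker='v3_en')
    model.save_wav(ssml_text=to_ssml(text, pitch="high", speed="slow"),
                   speaker='en_5', sample_rate=48000, audio_path=wav_path)

# generate_speech("Hello there & welcome!")
```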

+ 6 - 42
modules/RWKV.py

@@ -1,12 +1,11 @@
 import os
 from pathlib import Path
-from queue import Queue
-from threading import Thread
 
 import numpy as np
 from tokenizers import Tokenizer
 
 import modules.shared as shared
+from modules.callbacks import Iteratorize
 
 np.set_printoptions(precision=4, suppress=True, linewidth=200)
 
@@ -49,11 +48,11 @@ class RWKVModel:
         return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
 
     def generate_with_streaming(self, **kwargs):
-        iterable = Iteratorize(self.generate, kwargs, callback=None)
-        reply = kwargs['context']
-        for token in iterable:
-            reply += token
-            yield reply
+        with Iteratorize(self.generate, kwargs, callback=None) as generator:
+            reply = kwargs['context']
+            for token in generator:
+                reply += token
+                yield reply
 
 class RWKVTokenizer:
     def __init__(self):
@@ -73,38 +72,3 @@ class RWKVTokenizer:
 
     def decode(self, ids):
         return self.tokenizer.decode(ids)
-
-class Iteratorize:
-
-    """
-    Transforms a function that takes a callback
-    into a lazy iterator (generator).
-    """
-
-    def __init__(self, func, kwargs={}, callback=None):
-        self.mfunc=func
-        self.c_callback=callback
-        self.q = Queue(maxsize=1)
-        self.sentinel = object()
-        self.kwargs = kwargs
-
-        def _callback(val):
-            self.q.put(val)
-
-        def gentask():
-            ret = self.mfunc(callback=_callback, **self.kwargs)
-            self.q.put(self.sentinel)
-            if self.c_callback:
-                self.c_callback(ret)
-
-        Thread(target=gentask).start()
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        obj = self.q.get(True,None)
-        if obj is self.sentinel:
-            raise StopIteration
-        else:
-            return obj

+ 98 - 0
modules/callbacks.py

@@ -0,0 +1,98 @@
+import gc
+from queue import Queue
+from threading import Thread
+
+import torch
+import transformers
+
+import modules.shared as shared
+
+# Copied from https://github.com/PygmalionAI/gradio-ui/
+class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
+
+    def __init__(self, sentinel_token_ids: torch.LongTensor,
+                 starting_idx: int):
+        transformers.StoppingCriteria.__init__(self)
+        self.sentinel_token_ids = sentinel_token_ids
+        self.starting_idx = starting_idx
+
+    def __call__(self, input_ids: torch.LongTensor,
+                 _scores: torch.FloatTensor) -> bool:
+        for sample in input_ids:
+            trimmed_sample = sample[self.starting_idx:]
+            # Can't unfold, output is still too tiny. Skip.
+            if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
+                continue
+
+            for window in trimmed_sample.unfold(
+                    0, self.sentinel_token_ids.shape[-1], 1):
+                if torch.all(torch.eq(self.sentinel_token_ids, window)):
+                    return True
+        return False
+
+class Stream(transformers.StoppingCriteria):
+    def __init__(self, callback_func=None):
+        self.callback_func = callback_func
+
+    def __call__(self, input_ids, scores) -> bool:
+        if self.callback_func is not None:
+            self.callback_func(input_ids[0])
+        return False
+
+class Iteratorize:
+
+    """
+    Transforms a function that takes a callback
+    into a lazy iterator (generator).
+    """
+
+    def __init__(self, func, kwargs={}, callback=None):
+        self.mfunc=func
+        self.c_callback=callback
+        self.q = Queue()
+        self.sentinel = object()
+        self.kwargs = kwargs
+        self.stop_now = False
+
+        def _callback(val):
+            if self.stop_now:
+                raise ValueError
+            self.q.put(val)
+
+        def gentask():
+            try:
+                ret = self.mfunc(callback=_callback, **self.kwargs)
+            except ValueError:
+                pass
+            clear_torch_cache()
+            self.q.put(self.sentinel)
+            if self.c_callback:
+                self.c_callback(ret)
+
+        self.thread = Thread(target=gentask)
+        self.thread.start()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        obj = self.q.get(True,None)
+        if obj is self.sentinel:
+            raise StopIteration
+        else:
+            return obj
+
+    def __del__(self):
+        clear_torch_cache()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stop_now = True
+        clear_torch_cache()
+
+def clear_torch_cache():
+    gc.collect()
+    if not shared.args.cpu:
+        torch.cuda.empty_cache()
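
`Iteratorize` is the piece that turns a callback-style function into a generator: the wrapped function runs in a background thread, each callback value is pushed onto a queue, and `__next__` pops values until the sentinel arrives; leaving the `with` block sets `stop_now` so the next callback raises and the thread winds down. A toy usage sketch with a plain Python producer standing in for `model.generate` (illustrative names; run it from the repository root so `modules.callbacks` imports):

```python
import time

from modules.callbacks import Iteratorize  # the class added above


def slow_producer(callback=None, n=5):
    """Stand-in for a callback-based generate(): emits n values via callback, then returns."""
    for i in range(n):
        time.sleep(0.1)
        callback(f"token-{i}")
    return "done"


# Consume the callback stream lazily, exactly as RWKV.generate_with_streaming now does.
with Iteratorize(slow_producer, {"n": 5}, callback=None) as generator:
    for value in generator:
        print(value)   # token-0 ... token-4, printed as they are produced
```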

+ 13 - 5
modules/chat.py

@@ -84,6 +84,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
         tmp = f"\n{asker}:"
         tmp = f"\n{asker}:"
         for j in range(1, len(tmp)):
         for j in range(1, len(tmp)):
             if reply[-j:] == tmp[:j]:
             if reply[-j:] == tmp[:j]:
+                reply = reply[:-j]
                 substring_found = True
                 substring_found = True
 
 
     return reply, next_character_found, substring_found
     return reply, next_character_found, substring_found
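
The one-line fix above matters during streaming: when the model has emitted only the first characters of the next speaker's tag (a trailing "\nYo" before "\nYou:" is complete), the partial tag is now cut off the visible reply instead of being left dangling. A self-contained sketch of that trimming rule (hypothetical helper name, not the module's own):

```python
def trim_partial_stop(reply, asker="You"):
    # If reply ends with a prefix of "\n{asker}:", drop that partial tag.
    tmp = f"\n{asker}:"
    found = False
    for j in range(1, len(tmp)):
        if reply[-j:] == tmp[:j]:
            reply = reply[:-j]
            found = True
    return reply, found


print(trim_partial_stop("Sure, here is the answer.\nYo"))   # ('Sure, here is the answer.', True)
print(trim_partial_stop("Sure, here is the answer."))       # ('Sure, here is the answer.', False)
```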
@@ -91,7 +92,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
 def stop_everything_event():
     shared.stop_everything = True
 
-def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
+def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
     shared.stop_everything = False
     just_started = True
     eos_token = '\n' if check else None
@@ -120,6 +121,10 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
     else:
         prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
 
+    if not regenerate:
+        # Display user input and "*is typing...*" immediately
+        yield shared.history['visible']+[[visible_text, '*Is typing...*']]
+
     # Generate
     reply = ''
     for i in range(chat_generation_attempts):
@@ -158,6 +163,9 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
 
     prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
 
+    # Display "*is typing...*" immediately
+    yield '*Is typing...*'
+
     reply = ''
     for i in range(chat_generation_attempts):
         for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
@@ -182,7 +190,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi
         last_visible = shared.history['visible'].pop()
         last_internal = shared.history['internal'].pop()
 
-        for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
+        for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
             if shared.args.cai_chat:
                 shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
                 yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
@@ -291,7 +299,7 @@ def save_history(timestamp=True):
         fname = f"{prefix}persistent.json"
         fname = f"{prefix}persistent.json"
     if not Path('logs').exists():
     if not Path('logs').exists():
         Path('logs').mkdir()
         Path('logs').mkdir()
-    with open(Path(f'logs/{fname}'), 'w') as f:
+    with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
         f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
         f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
     return Path(f'logs/{fname}')
     return Path(f'logs/{fname}')
 
 
@@ -332,7 +340,7 @@ def load_character(_character, name1, name2):
     shared.history['visible'] = []
     if _character != 'None':
         shared.character = _character
-        data = json.loads(open(Path(f'characters/{_character}.json'), 'r').read())
+        data = json.loads(open(Path(f'characters/{_character}.json'), 'r', encoding='utf-8').read())
         name2 = data['char_name']
         if 'char_persona' in data and data['char_persona'] != '':
             context += f"{data['char_name']}'s Persona: {data['char_persona']}\n"
@@ -372,7 +380,7 @@ def upload_character(json_file, img, tavern=False):
         i += 1
     if tavern:
         outfile_name = f'TavernAI-{outfile_name}'
-    with open(Path(f'characters/{outfile_name}.json'), 'w') as f:
+    with open(Path(f'characters/{outfile_name}.json'), 'w', encoding='utf-8') as f:
         f.write(json_file)
     if img is not None:
         img = Image.open(io.BytesIO(img))

+ 8 - 1
modules/models.py

@@ -1,5 +1,6 @@
 import json
 import os
+import sys
 import time
 import zipfile
 from pathlib import Path
@@ -41,7 +42,7 @@ def load_model(model_name):
     shared.is_RWKV = model_name.lower().startswith('rwkv-')
 
     # Default settings
-    if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen or shared.is_RWKV):
+    if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.gptq_bits > 0, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
         if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
             model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
         else:
@@ -86,6 +87,12 @@ def load_model(model_name):
 
         return model, tokenizer
 
+    # 4-bit LLaMA
+    elif shared.args.gptq_bits > 0 or shared.args.load_in_4bit:
+        from modules.quantized_LLaMA import load_quantized_LLaMA
+
+        model = load_quantized_LLaMA(model_name)
+
     # Custom
     else:
         command = "AutoModelForCausalLM.from_pretrained"

+ 60 - 0
modules/quantized_LLaMA.py

@@ -0,0 +1,60 @@
+import os
+import sys
+from pathlib import Path
+
+import accelerate
+import torch
+
+import modules.shared as shared
+
+sys.path.insert(0, os.path.abspath(Path("repositories/GPTQ-for-LLaMa")))
+from llama import load_quant
+
+
+# 4-bit LLaMA
+def load_quantized_LLaMA(model_name):
+    if shared.args.load_in_4bit:
+        bits = 4
+    else:
+        bits = shared.args.gptq_bits
+
+    path_to_model = Path(f'models/{model_name}')
+    pt_model = ''
+    if path_to_model.name.lower().startswith('llama-7b'):
+        pt_model = f'llama-7b-{bits}bit.pt'
+    elif path_to_model.name.lower().startswith('llama-13b'):
+        pt_model = f'llama-13b-{bits}bit.pt'
+    elif path_to_model.name.lower().startswith('llama-30b'):
+        pt_model = f'llama-30b-{bits}bit.pt'
+    elif path_to_model.name.lower().startswith('llama-65b'):
+        pt_model = f'llama-65b-{bits}bit.pt'
+    else:
+        pt_model = f'{model_name}-{bits}bit.pt'
+
+    # Try to find the .pt both in models/ and in the subfolder
+    pt_path = None
+    for path in [Path(p) for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
+        if path.exists():
+            pt_path = path
+
+    if not pt_path:
+        print(f"Could not find {pt_model}, exiting...")
+        exit()
+
+    model = load_quant(path_to_model, os.path.abspath(pt_path), bits)
+
+    # Multi-GPU setup
+    if shared.args.gpu_memory:
+        max_memory = {}
+        for i in range(len(shared.args.gpu_memory)):
+            max_memory[i] = f"{shared.args.gpu_memory[i]}GiB"
+        max_memory['cpu'] = f"{shared.args.cpu_memory or '99'}GiB"
+
+        device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LLaMADecoderLayer"])
+        model = accelerate.dispatch_model(model, device_map=device_map)
+
+    # Single GPU
+    else:
+        model = model.to(torch.device('cuda:0'))
+
+    return model
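
On the multi-GPU path above, the split is driven entirely by the `max_memory` mapping handed to `accelerate.infer_auto_device_map`. A tiny illustration of the mapping the loader builds from the memory flags (the flag values below are examples, not defaults):

```python
# Roughly what `--gpu-memory 10 5 --cpu-memory 32` turns into before
# infer_auto_device_map() decides where each LLaMADecoderLayer goes.
gpu_memory = ["10", "5"]   # GiB per GPU, in GPU index order (example values)
cpu_memory = "32"          # GiB of system RAM allowed as overflow

max_memory = {i: f"{gpu_memory[i]}GiB" for i in range(len(gpu_memory))}
max_memory["cpu"] = f"{cpu_memory or '99'}GiB"
print(max_memory)   # {0: '10GiB', 1: '5GiB', 'cpu': '32GiB'}
```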

+ 6 - 2
modules/shared.py

@@ -11,6 +11,7 @@ is_RWKV = False
 history = {'internal': [], 'visible': []}
 character = 'None'
 stop_everything = False
+still_streaming = False
 
 # UI elements (buttons, sliders, HTML, etc)
 gradio = {}
@@ -42,12 +43,12 @@ settings = {
         'default': 'NovelAI-Sphinx Moth',
         'pygmalion-*': 'Pygmalion',
         'RWKV-*': 'Naive',
-        '(rosey|chip|joi)_.*_instruct.*': 'Instruct Joi (Contrastive Search)'
     },
     'prompts': {
         'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
         '^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n',
-        '(rosey|chip|joi)_.*_instruct.*': 'User: \n'
+        '(rosey|chip|joi)_.*_instruct.*': 'User: \n',
+        'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>'
     }
 }
 
@@ -68,6 +69,8 @@ parser.add_argument('--chat', action='store_true', help='Launch the web UI in ch
 parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
 parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
 parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
+parser.add_argument('--load-in-4bit', action='store_true', help='Load the model with 4-bit precision. Currently only works with LLaMA.')
+parser.add_argument('--gptq-bits', type=int, default=0, help='Load a pre-quantized model with specified precision. 2, 3, 4 and 8bit are supported. Currently only works with LLaMA.')
 parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
 parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
 parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
@@ -90,4 +93,5 @@ parser.add_argument('--listen', action='store_true', help='Make the web UI reach
 parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
 parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
 parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
+parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch')
 args = parser.parse_args()

+ 0 - 32
modules/stopping_criteria.py

@@ -1,32 +0,0 @@
-'''
-This code was copied from
-
-https://github.com/PygmalionAI/gradio-ui/
-
-'''
-
-import torch
-import transformers
-
-
-class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
-
-    def __init__(self, sentinel_token_ids: torch.LongTensor,
-                 starting_idx: int):
-        transformers.StoppingCriteria.__init__(self)
-        self.sentinel_token_ids = sentinel_token_ids
-        self.starting_idx = starting_idx
-
-    def __call__(self, input_ids: torch.LongTensor,
-                 _scores: torch.FloatTensor) -> bool:
-        for sample in input_ids:
-            trimmed_sample = sample[self.starting_idx:]
-            # Can't unfold, output is still too tiny. Skip.
-            if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
-                continue
-
-            for window in trimmed_sample.unfold(
-                    0, self.sentinel_token_ids.shape[-1], 1):
-                if torch.all(torch.eq(self.sentinel_token_ids, window)):
-                    return True
-        return False

+ 92 - 62
modules/text_generation.py

@@ -5,13 +5,13 @@ import time
 import numpy as np
 import torch
 import transformers
-from tqdm import tqdm
 
 import modules.shared as shared
+from modules.callbacks import (Iteratorize, Stream,
+                               _SentinelTokenStoppingCriteria)
 from modules.extensions import apply_extensions
 from modules.html_generator import generate_4chan_html, generate_basic_html
 from modules.models import local_rank
-from modules.stopping_criteria import _SentinelTokenStoppingCriteria
 
 
 def get_max_prompt_length(tokens):
@@ -92,19 +92,22 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     # These models are not part of Hugging Face, so we handle them
     # separately and terminate the function call earlier
     if shared.is_RWKV:
-        if shared.args.no_stream:
-            reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
-            yield formatted_outputs(reply, shared.model_name)
-        else:
-            yield formatted_outputs(question, shared.model_name)
-            # RWKV has proper streaming, which is very nice.
-            # No need to generate 8 tokens at a time.
-            for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
+        try:
+            if shared.args.no_stream:
+                reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
                 yield formatted_outputs(reply, shared.model_name)
-
-        t1 = time.time()
-        print(f"Output generated in {(t1-t0):.2f} seconds.")
-        return
+            else:
+                yield formatted_outputs(question, shared.model_name)
+                # RWKV has proper streaming, which is very nice.
+                # No need to generate 8 tokens at a time.
+                for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
+                    yield formatted_outputs(reply, shared.model_name)
+        finally:
+            t1 = time.time()
+            output = encode(reply)[0]
+            input_ids = encode(question)
+            print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
+            return
 
     original_question = question
     if not (shared.args.chat or shared.args.cai_chat):
@@ -113,24 +116,22 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         print(f"\n\n{question}\n--------------------\n")
         print(f"\n\n{question}\n--------------------\n")
 
 
     input_ids = encode(question, max_new_tokens)
     input_ids = encode(question, max_new_tokens)
+    original_input_ids = input_ids
+    output = input_ids[0]
     cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
     cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
-    n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1])
+    eos_token_ids = [shared.tokenizer.eos_token_id]
+    if eos_token is not None:
+        eos_token_ids.append(int(encode(eos_token)[0][-1]))
+    stopping_criteria_list = transformers.StoppingCriteriaList()
     if stopping_string is not None:
-        # The stopping_criteria code below was copied from
-        # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
+        # Copied from https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
         t = encode(stopping_string, 0, add_special_tokens=False)
-        stopping_criteria_list = transformers.StoppingCriteriaList([
-            _SentinelTokenStoppingCriteria(
-                sentinel_token_ids=t,
-                starting_idx=len(input_ids[0])
-            )
-        ])
-    else:
-        stopping_criteria_list = None
+        stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
 
     if not shared.args.flexgen:
         generate_params = [
-            f"eos_token_id={n}",
+            f"max_new_tokens=max_new_tokens",
+            f"eos_token_id={eos_token_ids}",
             f"stopping_criteria=stopping_criteria_list",
             f"stopping_criteria=stopping_criteria_list",
             f"do_sample={do_sample}",
             f"do_sample={do_sample}",
             f"temperature={temperature}",
             f"temperature={temperature}",
@@ -147,44 +148,23 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         ]
     else:
         generate_params = [
+            f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}",
             f"do_sample={do_sample}",
             f"temperature={temperature}",
-            f"stop={n}",
+            f"stop={eos_token_ids[-1]}",
         ]
     if shared.args.deepspeed:
         generate_params.append("synced_gpus=True")
-    if shared.args.no_stream:
-        generate_params.append("max_new_tokens=max_new_tokens")
-    else:
-        generate_params.append("max_new_tokens=8")
     if shared.soft_prompt:
         inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
         generate_params.insert(0, "inputs_embeds=inputs_embeds")
-        generate_params.insert(0, "filler_input_ids")
+        generate_params.insert(0, "inputs=filler_input_ids")
     else:
-        generate_params.insert(0, "input_ids")
-
-    # Generate the entire reply at once
-    if shared.args.no_stream:
-        with torch.no_grad():
-            output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
-        if shared.soft_prompt:
-            output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
-
-        reply = decode(output)
-        if not (shared.args.chat or shared.args.cai_chat):
-            reply = original_question + apply_extensions(reply[len(question):], "output")
-
-        t1 = time.time()
-        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
-        yield formatted_outputs(reply, shared.model_name)
-
-    # Generate the reply 8 tokens at a time
-    else:
-        yield formatted_outputs(original_question, shared.model_name)
-        for i in tqdm(range(max_new_tokens//8+1)):
-            clear_torch_cache()
+        generate_params.insert(0, "inputs=input_ids")
 
 
+    try:
+        # Generate the entire reply at once.
+        if shared.args.no_stream:
             with torch.no_grad():
             with torch.no_grad():
                 output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
                 output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
             if shared.soft_prompt:
             if shared.soft_prompt:
@@ -193,16 +173,66 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             reply = decode(output)
             if not (shared.args.chat or shared.args.cai_chat):
                 reply = original_question + apply_extensions(reply[len(question):], "output")
+
             yield formatted_outputs(reply, shared.model_name)
 
-            if not shared.args.flexgen:
-                if output[-1] == n:
-                    break
-                input_ids = torch.reshape(output, (1, output.shape[0]))
-            else:
-                if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
+        # Stream the reply 1 token at a time.
+        # This is based on the trick of using 'stopping_criteria' to create an iterator.
+        elif not shared.args.flexgen:
+
+            def generate_with_callback(callback=None, **kwargs):
+                kwargs['stopping_criteria'].append(Stream(callback_func=callback))
+                clear_torch_cache()
+                with torch.no_grad():
+                    shared.model.generate(**kwargs)
+
+            def generate_with_streaming(**kwargs):
+                return Iteratorize(generate_with_callback, kwargs, callback=None)
+
+            shared.still_streaming = True
+            yield formatted_outputs(original_question, shared.model_name)
+            with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator:
+                for output in generator:
+                    if shared.soft_prompt:
+                        output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+                    reply = decode(output)
+
+                    if not (shared.args.chat or shared.args.cai_chat):
+                        reply = original_question + apply_extensions(reply[len(question):], "output")
+
+                    if output[-1] in eos_token_ids:
+                        break
+                    yield formatted_outputs(reply, shared.model_name)
+
+                shared.still_streaming = False
+                yield formatted_outputs(reply, shared.model_name)
+
+        # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
+        else:
+            shared.still_streaming = True
+            for i in range(max_new_tokens//8+1):
+                clear_torch_cache()
+                with torch.no_grad():
+                    output = eval(f"shared.model.generate({', '.join(generate_params)})")[0]
+                if shared.soft_prompt:
+                    output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+                reply = decode(output)
+
+                if not (shared.args.chat or shared.args.cai_chat):
+                    reply = original_question + apply_extensions(reply[len(question):], "output")
+
+                if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
                     break
+                yield formatted_outputs(reply, shared.model_name)
+
                 input_ids = np.reshape(output, (1, output.shape[0]))
+                if shared.soft_prompt:
+                    inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
 
-            if shared.soft_prompt:
-                inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
+            shared.still_streaming = False
+            yield formatted_outputs(reply, shared.model_name)
+
+    finally:
+        t1 = time.time()
+        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")
+        return
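
The key idea in the new streaming branch is that `transformers` calls every `StoppingCriteria` after each generated token, so a criterion that merely forwards the tokens to a callback (the `Stream` class in `modules/callbacks.py`) combined with `Iteratorize` turns the blocking `model.generate()` call into a token-by-token iterator. A condensed sketch of that trick outside the web UI (run from the repository root; the model name and prompt are illustrative, any small causal LM works):

```python
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

from modules.callbacks import Iteratorize, Stream  # added in this commit

model_name = "gpt2"  # small stand-in model for the demo
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

input_ids = tokenizer("The web UI streams text by", return_tensors="pt").input_ids

def generate_with_callback(callback=None, **kwargs):
    # The Stream criterion never stops generation; it only reports each new step.
    kwargs["stopping_criteria"].append(Stream(callback_func=callback))
    with torch.no_grad():
        model.generate(**kwargs)

with Iteratorize(generate_with_callback,
                 dict(inputs=input_ids,
                      max_new_tokens=20,
                      stopping_criteria=transformers.StoppingCriteriaList())) as generator:
    for output in generator:
        print(tokenizer.decode(output), end="\r")  # redraw the partial reply each step
```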

+ 6 - 4
requirements.txt

@@ -1,9 +1,11 @@
-accelerate==0.16.0
+accelerate==0.17.0
 bitsandbytes==0.37.0
 flexgen==0.1.7
 gradio==3.18.0
 numpy
-rwkv==0.1.0
-safetensors==0.2.8
+requests
+rwkv==0.3.1
+safetensors==0.3.0
 sentencepiece
-git+https://github.com/oobabooga/transformers@llama_push
+tqdm
+git+https://github.com/zphang/transformers@llama_push

+ 8 - 10
server.py

@@ -18,9 +18,6 @@ from modules.html_generator import generate_chat_html
 from modules.models import load_model, load_soft_prompt
 from modules.text_generation import generate_reply
 
-if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream:
-    print('Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n')
-    
 # Loading custom settings
 settings_file = None
 if shared.args.settings is not None and Path(shared.args.settings).exists():
@@ -37,7 +34,7 @@ def get_available_models():
     if shared.args.flexgen:
         return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower)
     else:
-        return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np'))], key=str.lower)
+        return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt'))], key=str.lower)
 
 def get_available_presets():
     return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
@@ -272,10 +269,10 @@ if shared.args.chat or shared.args.cai_chat:
 
         function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
 
-        gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen'))
-        gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
-        gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
-        gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
+        gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=False, api_name='textgen'))
+        gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=False))
+        gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=False))
+        gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=False))
         shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)
 
         shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
@@ -309,6 +306,7 @@ if shared.args.chat or shared.args.cai_chat:
         reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
         shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
         shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
+        shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
 
         shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
         shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
@@ -372,9 +370,9 @@ else:
 
 shared.gradio['interface'].queue()
 if shared.args.listen:
-    shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port)
+    shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
 else:
-    shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port)
+    shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
 
 # I think that I will need this later
 while True:

+ 2 - 1
settings-template.json

@@ -29,6 +29,7 @@
     "prompts": {
     "prompts": {
         "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
         "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
         "^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n",
         "^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n",
-        "(rosey|chip|joi)_.*_instruct.*": "User: \n"
+        "(rosey|chip|joi)_.*_instruct.*": "User: \n",
+        "oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>"
     }
     }
 }
 }