
Merge branch 'main' into Brawlence-main

oobabooga 3 years ago
Parent
Commit
eab8de0d4a

+ 4 - 0
.gitignore

@@ -4,12 +4,15 @@ extensions/silero_tts/outputs/*
 extensions/elevenlabs_tts/outputs/*
 extensions/sd_api_pictures/outputs/*
 logs/*
+loras/*
 models/*
 softprompts/*
 torch-dumps/*
 *pycache*
 */*pycache*
 */*/pycache*
+venv/
+.venv/
 
 settings.json
 img_bot*
@@ -17,6 +20,7 @@ img_me*
 
 !characters/Example.json
 !characters/Example.png
+!loras/place-your-loras-here.txt
 !models/place-your-models-here.txt
 !softprompts/place-your-softprompts-here.txt
 !torch-dumps/place-your-pt-models-here.txt

+ 88 - 56
README.md

@@ -19,52 +19,76 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
 * Generate Markdown output for [GALACTICA](https://github.com/paperswithcode/galai), including LaTeX support.
 * Support for [Pygmalion](https://huggingface.co/models?search=pygmalionai/pygmalion) and custom characters in JSON or TavernAI Character Card formats ([FAQ](https://github.com/oobabooga/text-generation-webui/wiki/Pygmalion-chat-model-FAQ)).
 * Advanced chat features (send images, get audio responses with TTS).
-* Stream the text output in real time.
+* Stream the text output in real time very efficiently.
 * Load parameter presets from text files.
-* Load large models in 8-bit mode (see [here](https://github.com/oobabooga/text-generation-webui/issues/147#issuecomment-1456040134), [here](https://github.com/oobabooga/text-generation-webui/issues/20#issuecomment-1411650652) and [here](https://www.reddit.com/r/PygmalionAI/comments/1115gom/running_pygmalion_6b_with_8gb_of_vram/) if you are on Windows).
+* Load large models in 8-bit mode.
 * Split large models across your GPU(s), CPU, and disk.
 * CPU mode.
 * [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen).
 * [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed).
 * Get responses via API, [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) or [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming.
-* [Supports the LLaMA model, including 4-bit mode](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
-* [Supports the RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model).
+* [LLaMA model, including 4-bit mode](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
+* [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model).
+* [Supports LoRAs](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs).
 * Supports softprompts.
 * [Supports extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions).
 * [Works on Google Colab](https://github.com/oobabooga/text-generation-webui/wiki/Running-on-Colab).
 
-## Installation option 1: conda
+## Installation
 
-Open a terminal and copy and paste these commands one at a time ([install conda](https://docs.conda.io/en/latest/miniconda.html) first if you don't have it already):
+The recommended installation methods are the following:
+
+* Linux and MacOS: using conda natively.
+* Windows: using conda on WSL ([WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/Windows-Subsystem-for-Linux-(Ubuntu)-Installation-Guide)).
+
+Conda can be downloaded here: https://docs.conda.io/en/latest/miniconda.html
+
+On Linux or WSL, it can be automatically installed with these two commands:
 
 ```
-conda create -n textgen
-conda activate textgen
-conda install torchvision torchaudio pytorch-cuda=11.7 git -c pytorch -c nvidia
-git clone https://github.com/oobabooga/text-generation-webui
-cd text-generation-webui
-pip install -r requirements.txt
+curl -sL "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" > "Miniconda3.sh"
+bash Miniconda3.sh
 ```
 
-The third line assumes that you have an NVIDIA GPU. 
+Source: https://educe-ubc.github.io/conda.html
 
-* If you have an AMD GPU, replace the third command with this one:
+#### 1. Create a new conda environment
 
 ```
-pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2
+conda create -n textgen python=3.10.9
+conda activate textgen
 ```
-  	  
-* If you are running it in CPU mode, replace the third command with this one:
+
+#### 2. Install Pytorch
+
+| System | GPU | Command |
+|--------|---------|---------|
+| Linux/WSL | NVIDIA | `pip3 install torch torchvision torchaudio` |
+| Linux | AMD | `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.4.2` |
+| MacOS + MPS (untested) | Any | `pip3 install torch torchvision torchaudio` |
+
+The up to date commands can be found here: https://pytorch.org/get-started/locally/. 
+
+MacOS users, refer to the comments here: https://github.com/oobabooga/text-generation-webui/pull/393
+
+
+#### 3. Install the web UI
 
 ```
-conda install pytorch torchvision torchaudio git -c pytorch
+git clone https://github.com/oobabooga/text-generation-webui
+cd text-generation-webui
+pip install -r requirements.txt
 ```
 
 > **Note**
-> 1. If you are on Windows, it may be easier to run the commands above in a WSL environment. The performance may also be better.
-> 2. For a more detailed, user-contributed guide, see: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings).
+> 
+> For bitsandbytes and `--load-in-8bit` to work on Linux/WSL, this dirty fix is currently necessary: https://github.com/oobabooga/text-generation-webui/issues/400#issuecomment-1474876859
+
+### Alternative: native Windows installation
 
-## Installation option 2: one-click installers
+As an alternative to the recommended WSL method, you can install the web UI natively on Windows using this guide. It will be a lot harder and the performance may be slower: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings).
+
+### Alternative: one-click installers
 
 [oobabooga-windows.zip](https://github.com/oobabooga/one-click-installers/archive/refs/heads/oobabooga-windows.zip)
 
@@ -75,19 +99,25 @@ Just download the zip above, extract it, and double click on "install". The web
 * To download a model, double click on "download-model"
 * To start the web UI, double click on "start-webui" 
 
-## Downloading models
+Source codes: https://github.com/oobabooga/one-click-installers
+
+This method lags behind the newest developments and does not support 8-bit mode on Windows without additional set up: https://github.com/oobabooga/text-generation-webui/issues/147#issuecomment-1456040134, https://github.com/oobabooga/text-generation-webui/issues/20#issuecomment-1411650652
 
-Models should be placed under `models/model-name`. For instance, `models/gpt-j-6B` for [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B/tree/main).
+### Alternative: Docker
 
-#### Hugging Face
+https://github.com/oobabooga/text-generation-webui/issues/174, https://github.com/oobabooga/text-generation-webui/issues/87
+
+## Downloading models
+
+Models should be placed inside the `models` folder.
 
 [Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads) is the main place to download models. These are some noteworthy examples:
 
-* [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B/tree/main)
-* [GPT-Neo](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=eleutherai+%2F+gpt-neo)
 * [Pythia](https://huggingface.co/models?search=eleutherai/pythia)
 * [OPT](https://huggingface.co/models?search=facebook/opt)
 * [GALACTICA](https://huggingface.co/models?search=facebook/galactica)
+* [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B/tree/main)
+* [GPT-Neo](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=eleutherai+%2F+gpt-neo)
 * [\*-Erebus](https://huggingface.co/models?search=erebus) (NSFW)
 * [Pygmalion](https://huggingface.co/models?search=pygmalion) (NSFW)
 
@@ -101,7 +131,7 @@ For instance:
 
 If you want to download a model manually, note that all you need are the json, txt, and pytorch\*.bin (or model*.safetensors) files. The remaining files are not necessary.
 
-#### GPT-4chan
+### GPT-4chan
 
 [GPT-4chan](https://huggingface.co/ykilcher/gpt-4chan) has been shut down from Hugging Face, so you need to download it elsewhere. You have two options:
 
@@ -123,6 +153,7 @@ python download-model.py EleutherAI/gpt-j-6B --text-only
 ## Starting the web UI
 
     conda activate textgen
+    cd text-generation-webui
     python server.py
 
 Then browse to 
@@ -133,41 +164,42 @@ Then browse to
 
 Optionally, you can use the following command-line flags:
 
-| Flag        | Description |
-|-------------|-------------|
-| `-h`, `--help`  | show this help message and exit |
-| `--model MODEL`    | Name of the model to load by default. |
-| `--notebook`  | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
-| `--chat`      | Launch the web UI in chat mode.|
-| `--cai-chat`  | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
-| `--cpu`       | Use the CPU to generate text.|
-| `--load-in-8bit`  | Load the model with 8-bit precision.|
-| `--load-in-4bit`  | DEPRECATED: use `--gptq-bits 4` instead. |
-| `--gptq-bits GPTQ_BITS`  |  Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA and OPT. |
-| `--gptq-model-type MODEL_TYPE`  |  Model type of pre-quantized model. Currently only LLaMa and OPT are supported. |
-| `--bf16`  | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
+| Flag             | Description |
+|------------------|-------------|
+| `-h`, `--help`   | show this help message and exit |
+| `--model MODEL`  | Name of the model to load by default. |
+| `--lora LORA`    | Name of the LoRA to apply to the model by default. |
+| `--notebook`     | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
+| `--chat`         | Launch the web UI in chat mode.|
+| `--cai-chat`     | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
+| `--cpu`          | Use the CPU to generate text.|
+| `--load-in-8bit` | Load the model with 8-bit precision.|
+| `--load-in-4bit` | DEPRECATED: use `--gptq-bits 4` instead. |
+| `--gptq-bits GPTQ_BITS` |  Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA and OPT. |
+| `--gptq-model-type MODEL_TYPE` |  Model type of pre-quantized model. Currently only LLaMa and OPT are supported. |
+| `--bf16`         | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
 | `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
-| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
+| `--disk`         | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
 | `--disk-cache-dir DISK_CACHE_DIR` | Directory to save the disk cache to. Defaults to `cache/`. |
 |  `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` |  Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. |
-| `--cpu-memory CPU_MEMORY`    | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.|
-| `--flexgen`                   |         Enable the use of FlexGen offloading. |
-|  `--percent PERCENT [PERCENT ...]`    |  FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0). |
-|  `--compress-weight`                  |  FlexGen: Whether to compress weight (default: False).|
-|  `--pin-weight [PIN_WEIGHT]`          |       FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%). |
+| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.|
+| `--flexgen`      |         Enable the use of FlexGen offloading. |
+|  `--percent PERCENT [PERCENT ...]` |  FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0). |
+|  `--compress-weight` |  FlexGen: Whether to compress weight (default: False).|
+|  `--pin-weight [PIN_WEIGHT]` |       FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%). |
 | `--deepspeed`    | Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration. |
-| `--nvme-offload-dir NVME_OFFLOAD_DIR`    | DeepSpeed: Directory to use for ZeRO-3 NVME offloading. |
-| `--local_rank LOCAL_RANK`    | DeepSpeed: Optional argument for distributed setups. |
-|  `--rwkv-strategy RWKV_STRATEGY`         |    RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". |
-|  `--rwkv-cuda-on`                        |   RWKV: Compile the CUDA kernel for better performance. |
-| `--no-stream`   | Don't stream the text output in real time. |
+| `--nvme-offload-dir NVME_OFFLOAD_DIR` | DeepSpeed: Directory to use for ZeRO-3 NVME offloading. |
+| `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. |
+|  `--rwkv-strategy RWKV_STRATEGY` |    RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". |
+|  `--rwkv-cuda-on` |   RWKV: Compile the CUDA kernel for better performance. |
+| `--no-stream`    | Don't stream the text output in real time. |
 | `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.|
 |  `--extensions EXTENSIONS [EXTENSIONS ...]` |  The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
-| `--listen`   | Make the web UI reachable from your local network.|
+| `--listen`       | Make the web UI reachable from your local network.|
 |  `--listen-port LISTEN_PORT` | The listening port that the server will use. |
-| `--share`   | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
-| `--auto-launch` | Open the web UI in the default browser upon launch. |
-| `--verbose`   | Print the prompts to the terminal. |
+| `--share`        | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
+| `--auto-launch`  | Open the web UI in the default browser upon launch. |
+| `--verbose`      | Print the prompts to the terminal. |
 
 Out of memory errors? [Check this guide](https://github.com/oobabooga/text-generation-webui/wiki/Low-VRAM-guide).
 
@@ -192,7 +224,7 @@ Before reporting a bug, make sure that you have:
 
 ## Credits
 
-- Gradio dropdown menu refresh button: https://github.com/AUTOMATIC1111/stable-diffusion-webui
+- Gradio dropdown menu refresh button, code for reloading the interface: https://github.com/AUTOMATIC1111/stable-diffusion-webui
 - Verbose preset: Anonymous 4chan user.
 - NovelAI and KoboldAI presets: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets
 - Pygmalion preset, code for early stopping in chat mode, code for some of the sliders, --chat mode colors: https://github.com/PygmalionAI/gradio-ui/

+ 4 - 2
api-example-stream.py

@@ -26,6 +26,7 @@ async def run(context):
         'top_p': 0.9,
         'typical_p': 1,
         'repetition_penalty': 1.05,
+        'encoder_repetition_penalty': 1.0,
         'top_k': 0,
         'min_length': 0,
         'no_repeat_ngram_size': 0,
@@ -43,14 +44,14 @@ async def run(context):
                 case "send_hash":
                     await websocket.send(json.dumps({
                         "session_hash": session,
-                        "fn_index": 7
+                        "fn_index": 9
                     }))
                 case "estimation":
                     pass
                 case "send_data":
                     await websocket.send(json.dumps({
                         "session_hash": session,
-                        "fn_index": 7,
+                        "fn_index": 9,
                         "data": [
                             context,
                             params['max_new_tokens'],
@@ -59,6 +60,7 @@ async def run(context):
                             params['top_p'],
                             params['typical_p'],
                             params['repetition_penalty'],
+                            params['encoder_repetition_penalty'],
                             params['top_k'],
                             params['min_length'],
                             params['no_repeat_ngram_size'],

+ 2 - 0
api-example.py

@@ -24,6 +24,7 @@ params = {
     'top_p': 0.9,
     'typical_p': 1,
     'repetition_penalty': 1.05,
+    'encoder_repetition_penalty': 1.0,
     'top_k': 0,
     'min_length': 0,
     'no_repeat_ngram_size': 0,
@@ -45,6 +46,7 @@ response = requests.post(f"http://{server}:7860/run/textgen", json={
         params['top_p'],
         params['typical_p'],
         params['repetition_penalty'],
+        params['encoder_repetition_penalty'],
         params['top_k'],
         params['min_length'],
         params['no_repeat_ngram_size'],
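
Note (illustrative, not part of this commit): the `data` list posted to the `textgen` endpoint is positional, so the new `encoder_repetition_penalty` value has to occupy the same slot as in the web UI's generate function, and the streaming client's `fn_index` has to point at that function's new position in the Gradio app (hence the 7 → 9 change above). A sketch that makes the ordering explicit; the order assumed below mirrors the `generate_reply()` signature shown later in this commit, but `server.py` remains the authority and may add fields:

```python
# Hypothetical helper: build the positional payload from a dict so that new
# sampling parameters land in the right slot (order assumed from generate_reply()).
PARAM_ORDER = [
    'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p',
    'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length',
    'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty',
    'early_stopping',
]

def build_data(context, params):
    # The first element is the prompt; the rest must follow PARAM_ORDER exactly.
    return [context] + [params[name] for name in PARAM_ORDER]
```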

+ 25 - 0
css/chat.css

@@ -0,0 +1,25 @@
+.h-\[40vh\], .wrap.svelte-byatnx.svelte-byatnx.svelte-byatnx {
+    height: 66.67vh
+}
+
+.gradio-container {
+    margin-left: auto !important;
+    margin-right: auto !important;
+}
+
+.w-screen {
+    width: unset
+}
+
+div.svelte-362y77>*, div.svelte-362y77>.form>* {
+    flex-wrap: nowrap
+}
+
+/* fixes the API documentation in chat mode */
+.api-docs.svelte-1iguv9h.svelte-1iguv9h.svelte-1iguv9h {
+    display: grid;
+}
+
+.pending.svelte-1ed2p3z {
+    opacity: 1;
+}

+ 4 - 0
css/chat.js

@@ -0,0 +1,4 @@
+document.getElementById("main").childNodes[0].style = "max-width: 800px; margin-left: auto; margin-right: auto";
+document.getElementById("extensions").style.setProperty("max-width", "800px");
+document.getElementById("extensions").style.setProperty("margin-left", "auto");
+document.getElementById("extensions").style.setProperty("margin-right", "auto");

+ 103 - 0
css/html_4chan_style.css

@@ -0,0 +1,103 @@
+#parent #container {
+    background-color: #eef2ff;
+    padding: 17px;
+}
+#parent #container .reply {
+    background-color: rgb(214, 218, 240);
+    border-bottom-color: rgb(183, 197, 217);
+    border-bottom-style: solid;
+    border-bottom-width: 1px;
+    border-image-outset: 0;
+    border-image-repeat: stretch;
+    border-image-slice: 100%;
+    border-image-source: none;
+    border-image-width: 1;
+    border-left-color: rgb(0, 0, 0);
+    border-left-style: none;
+    border-left-width: 0px;
+    border-right-color: rgb(183, 197, 217);
+    border-right-style: solid;
+    border-right-width: 1px;
+    border-top-color: rgb(0, 0, 0);
+    border-top-style: none;
+    border-top-width: 0px;
+    color: rgb(0, 0, 0);
+    display: table;
+    font-family: arial, helvetica, sans-serif;
+    font-size: 13.3333px;
+    margin-bottom: 4px;
+    margin-left: 0px;
+    margin-right: 0px;
+    margin-top: 4px;
+    overflow-x: hidden;
+    overflow-y: hidden;
+    padding-bottom: 4px;
+    padding-left: 2px;
+    padding-right: 2px;
+    padding-top: 4px;
+}
+
+#parent #container .number {
+    color: rgb(0, 0, 0);
+    font-family: arial, helvetica, sans-serif;
+    font-size: 13.3333px;
+    width: 342.65px;
+    margin-right: 7px;
+}
+
+#parent #container .op {
+    color: rgb(0, 0, 0);
+    font-family: arial, helvetica, sans-serif;
+    font-size: 13.3333px;
+    margin-bottom: 8px;
+    margin-left: 0px;
+    margin-right: 0px;
+    margin-top: 4px;
+    overflow-x: hidden;
+    overflow-y: hidden;
+}
+
+#parent #container .op blockquote {
+    margin-left: 0px !important;
+}
+
+#parent #container .name {
+    color: rgb(17, 119, 67);
+    font-family: arial, helvetica, sans-serif;
+    font-size: 13.3333px;
+    font-weight: 700;
+    margin-left: 7px;
+}
+
+#parent #container .quote {
+    color: rgb(221, 0, 0);
+    font-family: arial, helvetica, sans-serif;
+    font-size: 13.3333px;
+    text-decoration-color: rgb(221, 0, 0);
+    text-decoration-line: underline;
+    text-decoration-style: solid;
+    text-decoration-thickness: auto;
+}
+
+#parent #container .greentext {
+    color: rgb(120, 153, 34);
+    font-family: arial, helvetica, sans-serif;
+    font-size: 13.3333px;
+}
+
+#parent #container blockquote {
+    margin: 0px !important;
+    margin-block-start: 1em;
+    margin-block-end: 1em;
+    margin-inline-start: 40px;
+    margin-inline-end: 40px;
+    margin-top: 13.33px !important;
+    margin-bottom: 13.33px !important;
+    margin-left: 40px !important;
+    margin-right: 40px !important;
+}
+
+#parent #container .message {
+    color: black;
+    border: none;
+}

+ 73 - 0
css/html_cai_style.css

@@ -0,0 +1,73 @@
+.chat {
+    margin-left: auto;
+    margin-right: auto;
+    max-width: 800px;
+    height: 66.67vh;
+    overflow-y: auto;
+    padding-right: 20px;
+    display: flex;
+    flex-direction: column-reverse;
+}
+
+.message {
+    display: grid;
+    grid-template-columns: 60px 1fr;
+    padding-bottom: 25px;
+    font-size: 15px;
+    font-family: Helvetica, Arial, sans-serif;
+    line-height: 1.428571429;
+}
+
+.circle-you {
+    width: 50px;
+    height: 50px;
+    background-color: rgb(238, 78, 59);
+    border-radius: 50%;
+}
+
+.circle-bot {
+    width: 50px;
+    height: 50px;
+    background-color: rgb(59, 78, 244);
+    border-radius: 50%;
+}
+
+.circle-bot img,
+.circle-you img {
+    border-radius: 50%;
+    width: 100%;
+    height: 100%;
+    object-fit: cover;
+}
+
+.text {}
+
+.text p {
+    margin-top: 5px;
+}
+
+.username {
+    font-weight: bold;
+}
+
+.message-body {}
+
+.message-body img {
+    max-width: 300px;
+    max-height: 300px;
+    border-radius: 20px;
+}
+
+.message-body p {
+    margin-bottom: 0 !important;
+    font-size: 15px !important;
+    line-height: 1.428571429 !important;
+}
+
+.dark .message-body p em {
+    color: rgb(138, 138, 138) !important;
+}
+
+.message-body p em {
+    color: rgb(110, 110, 110) !important;
+}

+ 14 - 0
css/html_readable_style.css

@@ -0,0 +1,14 @@
+.container {
+    max-width: 600px;
+    margin-left: auto;
+    margin-right: auto;
+    background-color: rgb(31, 41, 55);
+    padding:3em;
+}
+
+.container p {
+    font-size: 16px !important;
+    color: white !important;
+    margin-bottom: 22px;
+    line-height: 1.4 !important;
+}

+ 52 - 0
css/main.css

@@ -0,0 +1,52 @@
+.tabs.svelte-710i53 {
+    margin-top: 0
+}
+
+.py-6 {
+    padding-top: 2.5rem
+}
+
+.dark #refresh-button {
+    background-color: #ffffff1f;
+}
+
+#refresh-button {
+  flex: none;
+  margin: 0;
+  padding: 0;
+  min-width: 50px;
+  border: none;
+  box-shadow: none;
+  border-radius: 10px;
+  background-color: #0000000d;
+}
+
+#download-label, #upload-label {
+  min-height: 0
+}
+
+#accordion {
+}
+
+.dark svg {
+  fill: white;
+}
+
+.dark a {
+  color: white !important;
+  text-decoration: none !important;
+}
+
+svg {
+  display: unset !important;
+  vertical-align: middle !important;
+  margin: 5px;
+}
+
+ol li p, ul li p {
+    display: inline-block;
+}
+
+#main, #parameters, #chat-settings, #interface-mode, #lora {
+  border: 0;
+}

+ 18 - 0
css/main.js

@@ -0,0 +1,18 @@
+document.getElementById("main").parentNode.childNodes[0].style = "border: none; background-color: #8080802b; margin-bottom: 40px";
+document.getElementById("main").parentNode.style = "padding: 0; margin: 0";
+document.getElementById("main").parentNode.parentNode.parentNode.style = "padding: 0";
+
+// Get references to the elements
+let main = document.getElementById('main');
+let main_parent = main.parentNode;
+let extensions = document.getElementById('extensions');
+
+// Add an event listener to the main element
+main_parent.addEventListener('click', function(e) {
+    // Check if the main element is visible
+    if (main.offsetHeight > 0 && main.offsetWidth > 0) {
+        extensions.style.display = 'block';
+    } else {
+        extensions.style.display = 'none';
+    }
+});

+ 11 - 6
download-model.py

@@ -101,6 +101,7 @@ def get_download_links_from_huggingface(model, branch):
     classifications = []
     has_pytorch = False
     has_safetensors = False
+    is_lora = False
     while True:
         content = requests.get(f"{base}{page}{cursor.decode()}").content
 
@@ -110,8 +111,10 @@ def get_download_links_from_huggingface(model, branch):
 
         for i in range(len(dict)):
             fname = dict[i]['path']
+            if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
+                is_lora = True
 
-            is_pytorch = re.match("pytorch_model.*\.bin", fname)
+            is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
             is_safetensors = re.match("model.*\.safetensors", fname)
             is_tokenizer = re.match("tokenizer.*\.model", fname)
             is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer
@@ -130,6 +133,7 @@ def get_download_links_from_huggingface(model, branch):
                         has_pytorch = True
                         classifications.append('pytorch')
 
+
         cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
         cursor = base64.b64encode(cursor)
         cursor = cursor.replace(b'=', b'%3D')
@@ -140,7 +144,7 @@ def get_download_links_from_huggingface(model, branch):
             if classifications[i] == 'pytorch':
                 links.pop(i)
 
-    return links
+    return links, is_lora
 
 if __name__ == '__main__':
     model = args.MODEL
@@ -159,15 +163,16 @@ if __name__ == '__main__':
             except ValueError as err_branch:
                 print(f"Error: {err_branch}")
                 sys.exit()
+
+    links, is_lora = get_download_links_from_huggingface(model, branch)
+    base_folder = 'models' if not is_lora else 'loras'
     if branch != 'main':
-        output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')
+        output_folder = Path(base_folder) / (model.split('/')[-1] + f'_{branch}')
     else:
-        output_folder = Path("models") / model.split('/')[-1]
+        output_folder = Path(base_folder) / model.split('/')[-1]
     if not output_folder.exists():
         output_folder.mkdir()
 
-    links = get_download_links_from_huggingface(model, branch)
-
     # Downloading the files
     print(f"Downloading the model to {output_folder}")
     pool = multiprocessing.Pool(processes=args.threads)
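
Note (illustrative, not part of this commit): with the change above, a repository that ships PEFT adapter files (`adapter_config.json` / `adapter_model.bin`) is detected as a LoRA and is saved under `loras/` instead of `models/`. A minimal sketch of the resulting destination logic; the repository name is hypothetical:

```python
from pathlib import Path

def pick_output_folder(model, branch, is_lora):
    # Mirrors the logic above: LoRAs go to loras/, non-main branches get a suffix.
    base_folder = 'models' if not is_lora else 'loras'
    suffix = f'_{branch}' if branch != 'main' else ''
    return Path(base_folder) / (model.split('/')[-1] + suffix)

print(pick_output_folder('someuser/some-lora-adapter', 'main', True))  # loras/some-lora-adapter
```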

+ 1 - 0
extensions/api/requirements.txt

@@ -0,0 +1 @@
+flask_cloudflared==0.0.12

+ 90 - 0
extensions/api/script.py

@@ -0,0 +1,90 @@
+from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
+from threading import Thread
+from modules import shared
+from modules.text_generation import generate_reply, encode
+import json
+
+params = {
+    'port': 5000,
+}
+
+class Handler(BaseHTTPRequestHandler):
+    def do_GET(self):
+        if self.path == '/api/v1/model':
+            self.send_response(200)
+            self.end_headers()
+            response = json.dumps({
+                'result': shared.model_name
+            })
+
+            self.wfile.write(response.encode('utf-8'))
+        else:
+            self.send_error(404)
+
+    def do_POST(self):
+        content_length = int(self.headers['Content-Length'])
+        body = json.loads(self.rfile.read(content_length).decode('utf-8'))
+
+        if self.path == '/api/v1/generate':
+            self.send_response(200)
+            self.send_header('Content-Type', 'application/json')
+            self.end_headers()
+
+            prompt = body['prompt']
+            prompt_lines = [l.strip() for l in prompt.split('\n')]
+
+            max_context = body.get('max_context_length', 2048)
+
+            while len(prompt_lines) >= 0 and len(encode('\n'.join(prompt_lines))) > max_context:
+                prompt_lines.pop(0)
+
+            prompt = '\n'.join(prompt_lines)
+
+            generator = generate_reply(
+                question = prompt, 
+                max_new_tokens = body.get('max_length', 200), 
+                do_sample=True, 
+                temperature=body.get('temperature', 0.5), 
+                top_p=body.get('top_p', 1), 
+                typical_p=body.get('typical', 1), 
+                repetition_penalty=body.get('rep_pen', 1.1), 
+                encoder_repetition_penalty=1, 
+                top_k=body.get('top_k', 0), 
+                min_length=0, 
+                no_repeat_ngram_size=0, 
+                num_beams=1, 
+                penalty_alpha=0, 
+                length_penalty=1,
+                early_stopping=False,
+            )
+
+            answer = ''
+            for a in generator:
+                answer = a[0]
+
+            response = json.dumps({
+                'results': [{
+                    'text': answer[len(prompt):]
+                }]
+            })
+            self.wfile.write(response.encode('utf-8'))
+        else:
+            self.send_error(404)
+
+
+def run_server():
+    server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
+    server = ThreadingHTTPServer(server_addr, Handler)
+    if shared.args.share: 
+        try:
+            from flask_cloudflared import  _run_cloudflared
+            public_url = _run_cloudflared(params['port'], params['port'] + 1)
+            print(f'Starting KoboldAI compatible api at {public_url}/api')
+        except ImportError:
+            print('You should install flask_cloudflared manually')
+    else:
+        print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
+    server.serve_forever()
+
+def ui():
+    Thread(target=run_server, daemon=True).start()
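
Note (illustrative, not part of this commit): this extension exposes a minimal KoboldAI-compatible HTTP endpoint on port 5000 (or through a Cloudflare tunnel when `--share` is set). A sketch of a client call, assuming the web UI was started with `--extensions api` and a model is loaded; field names follow the handler above, the values are only examples:

```python
import json
import urllib.request

payload = {
    'prompt': 'Write a haiku about GPUs.',
    'max_length': 100,           # mapped to max_new_tokens above
    'temperature': 0.5,
    'top_p': 1,
    'rep_pen': 1.1,              # mapped to repetition_penalty above
    'max_context_length': 2048,  # prompt is trimmed line by line to fit
}

request = urllib.request.Request(
    'http://127.0.0.1:5000/api/v1/generate',
    data=json.dumps(payload).encode('utf-8'),
    headers={'Content-Type': 'application/json'},
)
with urllib.request.urlopen(request) as response:
    result = json.loads(response.read())
    print(result['results'][0]['text'])
```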

+ 2 - 2
extensions/elevenlabs_tts/script.py

@@ -1,8 +1,8 @@
 from pathlib import Path
 
 import gradio as gr
-from elevenlabslib import *
-from elevenlabslib.helpers import *
+from elevenlabslib import ElevenLabsUser
+from elevenlabslib.helpers import save_bytes_to_path
 
 params = {
     'activate': True,

+ 1 - 1
extensions/gallery/script.py

@@ -76,7 +76,7 @@ def generate_html():
     return container_html
 
 def ui():
-    with gr.Accordion("Character gallery"):
+    with gr.Accordion("Character gallery", open=False):
         update = gr.Button("Refresh")
         gallery = gr.HTML(value=generate_html())
     update.click(generate_html, [], gallery)

+ 4 - 0
extensions/whisper_stt/requirements.txt

@@ -0,0 +1,4 @@
+git+https://github.com/Uberi/speech_recognition.git@010382b
+openai-whisper
+soundfile
+ffmpeg

+ 54 - 0
extensions/whisper_stt/script.py

@@ -0,0 +1,54 @@
+import gradio as gr
+import speech_recognition as sr
+
+input_hijack = {
+    'state': False,
+    'value': ["", ""]
+}
+
+
+def do_stt(audio, text_state=""):
+    transcription = ""
+    r = sr.Recognizer()
+
+    # Convert to AudioData
+    audio_data = sr.AudioData(sample_rate=audio[0], frame_data=audio[1], sample_width=4)
+
+    try:
+        transcription = r.recognize_whisper(audio_data, language="english", model="base.en")
+    except sr.UnknownValueError:
+        print("Whisper could not understand audio")
+    except sr.RequestError as e:
+        print("Could not request results from Whisper", e)
+
+    input_hijack.update({"state": True, "value": [transcription, transcription]})
+
+    text_state += transcription + " "
+    return text_state, text_state
+
+
+def update_hijack(val):
+    input_hijack.update({"state": True, "value": [val, val]})
+    return val
+
+
+def auto_transcribe(audio, audio_auto, text_state=""):
+    if audio is None:
+        return "", ""
+    if audio_auto:
+        return do_stt(audio, text_state)
+    return "", ""
+
+
+def ui():
+    tr_state = gr.State(value="")
+    output_transcription = gr.Textbox(label="STT-Input",
+                                      placeholder="Speech Preview. Click \"Generate\" to send",
+                                      interactive=True)
+    output_transcription.change(fn=update_hijack, inputs=[output_transcription], outputs=[tr_state])
+    audio_auto = gr.Checkbox(label="Auto-Transcribe", value=True)
+    with gr.Row():
+        audio = gr.Audio(source="microphone")
+        audio.change(fn=auto_transcribe, inputs=[audio, audio_auto, tr_state], outputs=[output_transcription, tr_state])
+        transcribe_button = gr.Button(value="Transcribe")
+        transcribe_button.click(do_stt, inputs=[audio, tr_state], outputs=[output_transcription, tr_state])

+ 0 - 0
loras/place-your-loras-here.txt


+ 1 - 1
modules/GPTQ_loader.py

@@ -61,7 +61,7 @@ def load_quantized(model_name):
             max_memory[i] = f"{shared.args.gpu_memory[i]}GiB"
         max_memory['cpu'] = f"{shared.args.cpu_memory or '99'}GiB"
 
-        device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LLaMADecoderLayer"])
+        device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
         model = accelerate.dispatch_model(model, device_map=device_map)
 
     # Single GPU

+ 22 - 0
modules/LoRA.py

@@ -0,0 +1,22 @@
+from pathlib import Path
+
+import modules.shared as shared
+from modules.models import load_model
+
+
+def add_lora_to_model(lora_name):
+
+    from peft import PeftModel
+
+    # Is there a more efficient way of returning to the base model?
+    if lora_name == "None":
+        print("Reloading the model to remove the LoRA...")
+        shared.model, shared.tokenizer = load_model(shared.model_name)
+    else:
+        # Why doesn't this work in 16-bit mode?
+        print(f"Adding the LoRA {lora_name} to the model...")
+
+        params = {}
+        params['device_map'] = {'': 0}
+        #params['dtype'] = shared.model.dtype
+        shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params)
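
Note (illustrative, not part of this commit): `add_lora_to_model` is the hook behind the new `--lora` flag; it wraps the already-loaded model in a `PeftModel`, and passing `"None"` reloads the base model to drop the adapter. A minimal usage sketch; the LoRA folder name is hypothetical and must exist under `loras/`:

```python
from modules.LoRA import add_lora_to_model

add_lora_to_model('my-lora')   # applies loras/my-lora on top of shared.model
add_lora_to_model('None')      # reloads the base model, removing the adapter
```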

+ 1 - 0
modules/callbacks.py

@@ -7,6 +7,7 @@ import transformers
 
 import modules.shared as shared
 
+
 # Copied from https://github.com/PygmalionAI/gradio-ui/
 class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
 

+ 13 - 19
modules/chat.py

@@ -11,17 +11,11 @@ from PIL import Image
 import modules.extensions as extensions_module
 import modules.shared as shared
 from modules.extensions import apply_extensions
-from modules.html_generator import generate_chat_html
-from modules.text_generation import encode, generate_reply, get_max_prompt_length
+from modules.html_generator import fix_newlines, generate_chat_html
+from modules.text_generation import (encode, generate_reply,
+                                     get_max_prompt_length)
 
 
-# This gets the new line characters right.
-def clean_chat_message(text):
-    text = text.replace('\n', '\n\n')
-    text = re.sub(r"\n{3,}", "\n\n", text)
-    text = text.strip()
-    return text
-
 def generate_chat_output(history, name1, name2, character):
     if shared.args.cai_chat:
         return generate_chat_html(history, name1, name2, character)
@@ -29,7 +23,7 @@ def generate_chat_output(history, name1, name2, character):
         return history
 
 def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False):
-    user_input = clean_chat_message(user_input)
+    user_input = fix_newlines(user_input)
     rows = [f"{context.strip()}\n"]
 
     if shared.soft_prompt:
@@ -82,7 +76,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
         if idx != -1:
             reply = reply[:idx]
             next_character_found = True
-        reply = clean_chat_message(reply)
+        reply = fix_newlines(reply)
 
         # If something like "\nYo" is generated just before "\nYou:"
         # is completed, trim it
@@ -97,7 +91,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
 def stop_everything_event():
     shared.stop_everything = True
 
-def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
+def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
     shared.stop_everything = False
     just_started = True
     eos_token = '\n' if check else None
@@ -133,7 +127,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
     # Generate
     reply = ''
     for i in range(chat_generation_attempts):
-        for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
+        for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
 
             # Extracting the reply
             reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check)
@@ -160,7 +154,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
 
     yield shared.history['visible']
 
-def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
+def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
     eos_token = '\n' if check else None
 
     if 'pygmalion' in shared.model_name.lower():
@@ -172,18 +166,18 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
     # Yield *Is typing...*
     yield shared.processing_message
     for i in range(chat_generation_attempts):
-        for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
+        for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
             reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True)
             yield reply
             if next_character_found:
                 break
         yield reply
 
-def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
-    for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
+def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
+    for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
         yield generate_chat_html(_history, name1, name2, shared.character)
 
-def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
+def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
     if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
         yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
     else:
@@ -191,7 +185,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi
         last_internal = shared.history['internal'].pop()
         # Yield '*Is typing...*'
         yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character)
-        for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
+        for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
             if shared.args.cai_chat:
                 shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
             else:

+ 14 - 6
modules/extensions.py

@@ -1,3 +1,5 @@
+import gradio as gr
+
 import extensions
 import modules.shared as shared
 
@@ -9,9 +11,12 @@ def load_extensions():
     for i, name in enumerate(shared.args.extensions):
         if name in available_extensions:
             print(f'Loading the extension "{name}"... ', end='')
-            exec(f"import extensions.{name}.script")
-            state[name] = [True, i]
-            print('Ok.')
+            try:
+                exec(f"import extensions.{name}.script")
+                state[name] = [True, i]
+                print('Ok.')
+            except:
+                print('Fail.')
 
 # This iterator returns the extensions in the order specified in the command-line
 def iterator():
@@ -40,6 +45,9 @@ def create_extensions_block():
                     extension.params[param] = shared.settings[_id]
 
     # Creating the extension ui elements
-    for extension, name in iterator():
-        if hasattr(extension, "ui"):
-            extension.ui()
+    if len(state) > 0:
+        with gr.Box(elem_id="extensions"):
+            gr.Markdown("Extensions")
+            for extension, name in iterator():
+                if hasattr(extension, "ui"):
+                    extension.ui()

+ 46 - 240
modules/html_generator.py

@@ -1,6 +1,6 @@
 '''
 
-This is a library for formatting GPT-4chan and chat outputs as nice HTML.
+This is a library for formatting text outputs as nice HTML.
 
 '''
 
@@ -8,30 +8,39 @@ import os
 import re
 from pathlib import Path
 
+import markdown
 from PIL import Image
 
 # This is to store the paths to the thumbnails of the profile pictures
 image_cache = {}
 
-def generate_basic_html(s):
-    css = """
-    .container {
-        max-width: 600px;
-        margin-left: auto;
-        margin-right: auto;
-        background-color: rgb(31, 41, 55);
-        padding:3em;
-    }
-    .container p {
-        font-size: 16px !important;
-        color: white !important;
-        margin-bottom: 22px;
-        line-height: 1.4 !important;
-    }
-    """
-    s = '\n'.join([f'<p>{line}</p>' for line in s.split('\n')])
-    s = f'<style>{css}</style><div class="container">{s}</div>'
-    return s
+with open(Path(__file__).resolve().parent / '../css/html_readable_style.css', 'r') as f:
+    readable_css = f.read()
+with open(Path(__file__).resolve().parent / '../css/html_4chan_style.css', 'r') as css_f:
+    _4chan_css = css_f.read()
+with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as f:
+    cai_css = f.read()
+
+def fix_newlines(string):
+    string = string.replace('\n', '\n\n')
+    string = re.sub(r"\n{3,}", "\n\n", string)
+    string = string.strip()
+    return string
+
+# This could probably be generalized and improved
+def convert_to_markdown(string):
+    string = string.replace('\\begin{code}', '```')
+    string = string.replace('\\end{code}', '```')
+    string = string.replace('\\begin{blockquote}', '> ')
+    string = string.replace('\\end{blockquote}', '')
+    string = re.sub(r"(.)```", r"\1\n```", string)
+#    string = fix_newlines(string)
+    return markdown.markdown(string, extensions=['fenced_code']) 
+
+def generate_basic_html(string):
+    string = convert_to_markdown(string)
+    string = f'<style>{readable_css}</style><div class="container">{string}</div>'
+    return string
 
 
 def process_post(post, c):
 def process_post(post, c):
     t = post.split('\n')
     t = post.split('\n')
@@ -48,113 +57,6 @@ def process_post(post, c):
     return src
     return src
 
 
 def generate_4chan_html(f):
 def generate_4chan_html(f):
-    css = """
-
-    #parent #container {
-        background-color: #eef2ff;
-        padding: 17px;
-    }
-    #parent #container .reply {
-        background-color: rgb(214, 218, 240);
-        border-bottom-color: rgb(183, 197, 217);
-        border-bottom-style: solid;
-        border-bottom-width: 1px;
-        border-image-outset: 0;
-        border-image-repeat: stretch;
-        border-image-slice: 100%;
-        border-image-source: none;
-        border-image-width: 1;
-        border-left-color: rgb(0, 0, 0);
-        border-left-style: none;
-        border-left-width: 0px;
-        border-right-color: rgb(183, 197, 217);
-        border-right-style: solid;
-        border-right-width: 1px;
-        border-top-color: rgb(0, 0, 0);
-        border-top-style: none;
-        border-top-width: 0px;
-        color: rgb(0, 0, 0);
-        display: table;
-        font-family: arial, helvetica, sans-serif;
-        font-size: 13.3333px;
-        margin-bottom: 4px;
-        margin-left: 0px;
-        margin-right: 0px;
-        margin-top: 4px;
-        overflow-x: hidden;
-        overflow-y: hidden;
-        padding-bottom: 4px;
-        padding-left: 2px;
-        padding-right: 2px;
-        padding-top: 4px;
-    }
-
-    #parent #container .number {
-        color: rgb(0, 0, 0);
-        font-family: arial, helvetica, sans-serif;
-        font-size: 13.3333px;
-        width: 342.65px;
-        margin-right: 7px;
-    }
-
-    #parent #container .op {
-        color: rgb(0, 0, 0);
-        font-family: arial, helvetica, sans-serif;
-        font-size: 13.3333px;
-        margin-bottom: 8px;
-        margin-left: 0px;
-        margin-right: 0px;
-        margin-top: 4px;
-        overflow-x: hidden;
-        overflow-y: hidden;
-    }
-
-    #parent #container .op blockquote {
-        margin-left: 0px !important;
-    }
-
-    #parent #container .name {
-        color: rgb(17, 119, 67);
-        font-family: arial, helvetica, sans-serif;
-        font-size: 13.3333px;
-        font-weight: 700;
-        margin-left: 7px;
-    }
-
-    #parent #container .quote {
-        color: rgb(221, 0, 0);
-        font-family: arial, helvetica, sans-serif;
-        font-size: 13.3333px;
-        text-decoration-color: rgb(221, 0, 0);
-        text-decoration-line: underline;
-        text-decoration-style: solid;
-        text-decoration-thickness: auto;
-    }
-
-    #parent #container .greentext {
-        color: rgb(120, 153, 34);
-        font-family: arial, helvetica, sans-serif;
-        font-size: 13.3333px;
-    }
-
-    #parent #container blockquote {
-        margin: 0px !important;
-        margin-block-start: 1em;
-        margin-block-end: 1em;
-        margin-inline-start: 40px;
-        margin-inline-end: 40px;
-        margin-top: 13.33px !important;
-        margin-bottom: 13.33px !important;
-        margin-left: 40px !important;
-        margin-right: 40px !important;
-    }
-
-    #parent #container .message {
-        color: black;
-        border: none;
-    }
-    """
-
     posts = []
     posts = []
     post = ''
     post = ''
     c = -2
     c = -2
@@ -181,7 +83,7 @@ def generate_4chan_html(f):
             posts[i] = f'<div class="reply">{posts[i]}</div>\n'
             posts[i] = f'<div class="reply">{posts[i]}</div>\n'
     
     
     output = ''
     output = ''
-    output += f'<style>{css}</style><div id="parent"><div id="container">'
+    output += f'<style>{_4chan_css}</style><div id="parent"><div id="container">'
     for post in posts:
     for post in posts:
         output += post
         output += post
     output += '</div></div>'
     output += '</div></div>'
@@ -208,135 +110,39 @@ def get_image_cache(path):
 
 
     return image_cache[path][1]
     return image_cache[path][1]
 
 
-def generate_chat_html(history, name1, name2, character):
-    css = """
-    .chat {
-      margin-left: auto;
-      margin-right: auto;
-      max-width: 800px;
-      height: 66.67vh;
-      overflow-y: auto;
-      padding-right: 20px;
-      display: flex;
-      flex-direction: column-reverse;
-    }       
-
-    .message {
-      display: grid;
-      grid-template-columns: 60px 1fr;
-      padding-bottom: 25px;
-      font-size: 15px;
-      font-family: Helvetica, Arial, sans-serif;
-      line-height: 1.428571429;
-    }   
-        
-    .circle-you {
-      width: 50px;
-      height: 50px;
-      background-color: rgb(238, 78, 59);
-      border-radius: 50%;
-    }
-          
-    .circle-bot {
-      width: 50px;
-      height: 50px;
-      background-color: rgb(59, 78, 244);
-      border-radius: 50%;
-    }
-
-    .circle-bot img, .circle-you img {
-      border-radius: 50%;
-      width: 100%;
-      height: 100%;
-      object-fit: cover;
-    }
-
-    .text {
-    }
+def load_html_image(paths):
+    for str_path in paths:
+          path = Path(str_path)
+          if path.exists():
+              return f'<img src="file/{get_image_cache(path)}">'
+    return ''
 
 
-    .text p {
-      margin-top: 5px;
-    }
-
-    .username {
-      font-weight: bold;
-    }
-
-    .message-body {
-    }
-
-    .message-body img {
-      max-width: 300px;
-      max-height: 300px;
-      border-radius: 20px;
-    }
-
-    .message-body p {
-      margin-bottom: 0 !important;
-      font-size: 15px !important;
-      line-height: 1.428571429 !important;
-    }
-
-    .dark .message-body p em {
-      color: rgb(138, 138, 138) !important;
-    }
-
-    .message-body p em {
-      color: rgb(110, 110, 110) !important;
-    }
-
-    """
-
-    output = ''
-    output += f'<style>{css}</style><div class="chat" id="chat">'
-    img = ''
-
-    for i in [
-            f"characters/{character}.png",
-            f"characters/{character}.jpg",
-            f"characters/{character}.jpeg",
-            "img_bot.png",
-            "img_bot.jpg",
-            "img_bot.jpeg"
-            ]:
-
-        path = Path(i)
-        if path.exists():
-            img = f'<img src="file/{get_image_cache(path)}">'
-            break
-
-    img_me = ''
-    for i in ["img_me.png", "img_me.jpg", "img_me.jpeg"]:
-        path = Path(i)
-        if path.exists():
-            img_me = f'<img src="file/{get_image_cache(path)}">'
-            break
+def generate_chat_html(history, name1, name2, character):
+    output = f'<style>{cai_css}</style><div class="chat" id="chat">'
+    
+    img_bot = load_html_image([f"characters/{character}.{ext}" for ext in ['png', 'jpg', 'jpeg']] + ["img_bot.png","img_bot.jpg","img_bot.jpeg"])
+    img_me = load_html_image(["img_me.png", "img_me.jpg", "img_me.jpeg"])
 
 
     for i,_row in enumerate(history[::-1]):
     for i,_row in enumerate(history[::-1]):
-        row = _row.copy()
-        row[0] = re.sub(r"(\*\*)([^\*\n]*)(\*\*)", r"<b>\2</b>", row[0])
-        row[1] = re.sub(r"(\*\*)([^\*\n]*)(\*\*)", r"<b>\2</b>", row[1])
-        row[0] = re.sub(r"(\*)([^\*\n]*)(\*)", r"<em>\2</em>", row[0])
-        row[1] = re.sub(r"(\*)([^\*\n]*)(\*)", r"<em>\2</em>", row[1])
-        p = '\n'.join([f"<p>{x}</p>" for x in row[1].split('\n')])
+        row = [convert_to_markdown(entry) for entry in _row]
+        
         output += f"""
         output += f"""
               <div class="message">
               <div class="message">
                 <div class="circle-bot">
                 <div class="circle-bot">
-                  {img}
+                  {img_bot}
                 </div>
                 </div>
                 <div class="text">
                 <div class="text">
                   <div class="username">
                   <div class="username">
                     {name2}
                     {name2}
                   </div>
                   </div>
                   <div class="message-body">
                   <div class="message-body">
-                    {p}
+                    {row[1]}
                   </div>
                   </div>
                 </div>
                 </div>
               </div>
               </div>
             """
             """
 
 
         if not (i == len(history)-1 and len(row[0]) == 0):
         if not (i == len(history)-1 and len(row[0]) == 0):
-            p = '\n'.join([f"<p>{x}</p>" for x in row[0].split('\n')])
             output += f"""
             output += f"""
                   <div class="message">
                   <div class="message">
                     <div class="circle-you">
                     <div class="circle-you">
@@ -347,7 +153,7 @@ def generate_chat_html(history, name1, name2, character):
                         {name1}
                         {name1}
                       </div>
                       </div>
                       <div class="message-body">
                       <div class="message-body">
-                        {p}
+                        {row[0]}
                       </div>
                       </div>
                     </div>
                     </div>
                   </div>
                   </div>

+ 53 - 26
modules/models.py

@@ -7,7 +7,9 @@ from pathlib import Path
 import numpy as np
 import numpy as np
 import torch
 import torch
 import transformers
 import transformers
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from accelerate import infer_auto_device_map, init_empty_weights
+from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
+                          BitsAndBytesConfig)
 
 
 import modules.shared as shared
 import modules.shared as shared
 
 
@@ -16,8 +18,7 @@ transformers.logging.set_verbosity_error()
 local_rank = None
 local_rank = None
 
 
 if shared.args.flexgen:
 if shared.args.flexgen:
-    from flexgen.flex_opt import (CompressionConfig, ExecutionEnv, OptLM,
-                                  Policy, str2bool)
+    from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
 
 
 if shared.args.deepspeed:
 if shared.args.deepspeed:
     import deepspeed
     import deepspeed
@@ -46,7 +47,12 @@ def load_model(model_name):
         if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
         if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
             model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
             model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
         else:
         else:
-            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16).cuda()
+            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
+            if torch.has_mps:
+                device = torch.device('mps')
+                model = model.to(device)
+            else:
+                model = model.cuda()
 
 
     # FlexGen
     # FlexGen
     elif shared.args.flexgen:
     elif shared.args.flexgen:
@@ -95,39 +101,60 @@ def load_model(model_name):
 
 
     # Custom
     # Custom
     else:
     else:
-        command = "AutoModelForCausalLM.from_pretrained"
-        params = ["low_cpu_mem_usage=True"]
-        if not shared.args.cpu and not torch.cuda.is_available():
-            print("Warning: no GPU has been detected.\nFalling back to CPU mode.\n")
+        params = {"low_cpu_mem_usage": True}
+        if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
+            print("Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
             shared.args.cpu = True
             shared.args.cpu = True
 
 
         if shared.args.cpu:
         if shared.args.cpu:
-            params.append("low_cpu_mem_usage=True")
-            params.append("torch_dtype=torch.float32")
+            params["torch_dtype"] = torch.float32
         else:
         else:
-            params.append("device_map='auto'")
-            params.append("load_in_8bit=True" if shared.args.load_in_8bit else "torch_dtype=torch.bfloat16" if shared.args.bf16 else "torch_dtype=torch.float16")
+            params["device_map"] = 'auto'
+            if shared.args.load_in_8bit and any((shared.args.auto_devices, shared.args.gpu_memory)):
+                params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
+            elif shared.args.load_in_8bit:
+                params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
+            elif shared.args.bf16:
+                params["torch_dtype"] = torch.bfloat16
+            else:
+                params["torch_dtype"] = torch.float16
 
 
             if shared.args.gpu_memory:
             if shared.args.gpu_memory:
                 memory_map = shared.args.gpu_memory
                 memory_map = shared.args.gpu_memory
-                max_memory = f"max_memory={{0: '{memory_map[0]}GiB'"
-                for i in range(1, len(memory_map)):
-                    max_memory += (f", {i}: '{memory_map[i]}GiB'")
-                max_memory += (f", 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
-                params.append(max_memory)
-            elif not shared.args.load_in_8bit:
-                total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024))
-                suggestion = round((total_mem-1000)/1000)*1000
-                if total_mem-suggestion < 800:
+                max_memory = {}
+                for i in range(len(memory_map)):
+                    max_memory[i] = f'{memory_map[i]}GiB'
+                max_memory['cpu'] = f'{shared.args.cpu_memory or 99}GiB'
+                params['max_memory'] = max_memory
+            elif shared.args.auto_devices:
+                total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024*1024))
+                suggestion = round((total_mem-1000) / 1000) * 1000
+                if total_mem - suggestion < 800:
                     suggestion -= 1000
                     suggestion -= 1000
                 suggestion = int(round(suggestion/1000))
                 suggestion = int(round(suggestion/1000))
                 print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
                 print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
-                params.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
-            if shared.args.disk:
-                params.append(f"offload_folder='{shared.args.disk_cache_dir}'")
+                
+                max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
+                params['max_memory'] = max_memory
 
 
-        command = f"{command}(Path(f'models/{shared.model_name}'), {', '.join(set(params))})"
-        model = eval(command)
+            if shared.args.disk:
+                params["offload_folder"] = shared.args.disk_cache_dir
+
+        checkpoint = Path(f'models/{shared.model_name}')
+
+        if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
+            config = AutoConfig.from_pretrained(checkpoint)
+            with init_empty_weights():
+                model = AutoModelForCausalLM.from_config(config)
+            model.tie_weights()
+            params['device_map'] = infer_auto_device_map(
+                model, 
+                dtype=torch.int8, 
+                max_memory=params['max_memory'],
+                no_split_module_classes = model._no_split_modules
+            )
+
+        model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
 
 
     # Loading the tokenizer
     # Loading the tokenizer
     if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists():
     if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists():

+ 11 - 2
modules/shared.py

@@ -2,7 +2,8 @@ import argparse
 
 
 model = None
 model = None
 tokenizer = None
 tokenizer = None
-model_name = ""
+model_name = "None"
+lora_name = "None"
 soft_prompt_tensor = None
 soft_prompt_tensor = None
 soft_prompt = False
 soft_prompt = False
 is_RWKV = False
 is_RWKV = False
@@ -19,6 +20,9 @@ gradio = {}
 # Generation input parameters
 # Generation input parameters
 input_params = []
 input_params = []
 
 
+# For restarting the interface
+need_restart = False
+
 settings = {
 settings = {
     'max_new_tokens': 200,
     'max_new_tokens': 200,
     'max_new_tokens_min': 1,
     'max_new_tokens_min': 1,
@@ -26,7 +30,7 @@ settings = {
     'name1': 'Person 1',
     'name1': 'Person 1',
     'name2': 'Person 2',
     'name2': 'Person 2',
     'context': 'This is a conversation between two people.',
     'context': 'This is a conversation between two people.',
-    'stop_at_newline': True,
+    'stop_at_newline': False,
     'chat_prompt_size': 2048,
     'chat_prompt_size': 2048,
     'chat_prompt_size_min': 0,
     'chat_prompt_size_min': 0,
     'chat_prompt_size_max': 2048,
     'chat_prompt_size_max': 2048,
@@ -49,6 +53,10 @@ settings = {
         '^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n',
         '^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n',
         '(rosey|chip|joi)_.*_instruct.*': 'User: \n',
         '(rosey|chip|joi)_.*_instruct.*': 'User: \n',
         'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>'
         'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>'
+    },
+    'lora_prompts': {
+        'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
+        '(alpaca-lora-7b|alpaca-lora-13b|alpaca-lora-30b)': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
     }
     }
 }
 }
 
 
@@ -64,6 +72,7 @@ def str2bool(v):
 
 
 parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
 parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
 parser.add_argument('--model', type=str, help='Name of the model to load by default.')
 parser.add_argument('--model', type=str, help='Name of the model to load by default.')
+parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
 parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
 parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
 parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
 parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
 parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
 parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')

+ 10 - 4
modules/text_generation.py

@@ -33,12 +33,15 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
             return input_ids.numpy()
             return input_ids.numpy()
         elif shared.args.deepspeed:
         elif shared.args.deepspeed:
             return input_ids.to(device=local_rank)
             return input_ids.to(device=local_rank)
+        elif torch.has_mps:
+            device = torch.device('mps')
+            return input_ids.to(device)
         else:
         else:
             return input_ids.cuda()
             return input_ids.cuda()
 
 
 def decode(output_ids):
 def decode(output_ids):
     # Open Assistant relies on special tokens like <|endoftext|>
     # Open Assistant relies on special tokens like <|endoftext|>
-    if re.match('oasst-*', shared.model_name.lower()):
+    if re.match('(oasst|galactica)-*', shared.model_name.lower()):
         return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
         return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
     else:
     else:
         reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
         reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
@@ -89,7 +92,7 @@ def clear_torch_cache():
     if not shared.args.cpu:
     if not shared.args.cpu:
         torch.cuda.empty_cache()
         torch.cuda.empty_cache()
 
 
-def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
+def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
     clear_torch_cache()
     clear_torch_cache()
     t0 = time.time()
     t0 = time.time()
 
 
@@ -101,7 +104,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
                 reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
                 reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
                 yield formatted_outputs(reply, shared.model_name)
                 yield formatted_outputs(reply, shared.model_name)
             else:
             else:
-                yield formatted_outputs(question, shared.model_name)
+                if not (shared.args.chat or shared.args.cai_chat):
+                    yield formatted_outputs(question, shared.model_name)
                 # RWKV has proper streaming, which is very nice.
                 # RWKV has proper streaming, which is very nice.
                 # No need to generate 8 tokens at a time.
                 # No need to generate 8 tokens at a time.
                 for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
                 for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
@@ -143,6 +147,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             "top_p": top_p,
             "top_p": top_p,
             "typical_p": typical_p,
             "typical_p": typical_p,
             "repetition_penalty": repetition_penalty,
             "repetition_penalty": repetition_penalty,
+            "encoder_repetition_penalty": encoder_repetition_penalty,
             "top_k": top_k,
             "top_k": top_k,
             "min_length": min_length if shared.args.no_stream else 0,
             "min_length": min_length if shared.args.no_stream else 0,
             "no_repeat_ngram_size": no_repeat_ngram_size,
             "no_repeat_ngram_size": no_repeat_ngram_size,
@@ -196,7 +201,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             def generate_with_streaming(**kwargs):
             def generate_with_streaming(**kwargs):
                 return Iteratorize(generate_with_callback, kwargs, callback=None)
                 return Iteratorize(generate_with_callback, kwargs, callback=None)
 
 
-            yield formatted_outputs(original_question, shared.model_name)
+            if not (shared.args.chat or shared.args.cai_chat):
+                yield formatted_outputs(original_question, shared.model_name)
             with generate_with_streaming(**generate_params) as generator:
             with generate_with_streaming(**generate_params) as generator:
                 for output in generator:
                 for output in generator:
                     if shared.soft_prompt:
                     if shared.soft_prompt:

+ 10 - 61
modules/ui.py

@@ -1,68 +1,17 @@
+from pathlib import Path
+
 import gradio as gr
 import gradio as gr
 
 
 refresh_symbol = '\U0001f504'  # 🔄
 refresh_symbol = '\U0001f504'  # 🔄
 
 
-css = """
-.tabs.svelte-710i53 {
-    margin-top: 0
-}
-.py-6 {
-    padding-top: 2.5rem
-}
-.dark #refresh-button {
-    background-color: #ffffff1f;
-}
-#refresh-button {
-  flex: none;
-  margin: 0;
-  padding: 0;
-  min-width: 50px;
-  border: none;
-  box-shadow: none;
-  border-radius: 10px;
-  background-color: #0000000d;
-}
-#download-label, #upload-label {
-  min-height: 0
-}
-#accordion {
-}
-.dark svg {
-  fill: white;
-}
-svg {
-  display: unset !important;
-  vertical-align: middle !important;
-  margin: 5px;
-}
-ol li p, ul li p {
-    display: inline-block;
-}
-"""
-
-chat_css = """
-.h-\[40vh\], .wrap.svelte-byatnx.svelte-byatnx.svelte-byatnx {
-    height: 66.67vh
-}
-.gradio-container {
-    max-width: 800px !important;
-    margin-left: auto !important;
-    margin-right: auto !important;
-}
-.w-screen {
-    width: unset
-}
-div.svelte-362y77>*, div.svelte-362y77>.form>* {
-    flex-wrap: nowrap
-}
-/* fixes the API documentation in chat mode */
-.api-docs.svelte-1iguv9h.svelte-1iguv9h.svelte-1iguv9h {
-    display: grid;
-}
-.pending.svelte-1ed2p3z {
-    opacity: 1;
-}
-"""
+with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f:
+    css = f.read()
+with open(Path(__file__).resolve().parent / '../css/chat.css', 'r') as f:
+    chat_css = f.read()
+with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
+    main_js = f.read()
+with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
+    chat_js = f.read()
 
 
 class ToolButton(gr.Button, gr.components.FormComponent):
 class ToolButton(gr.Button, gr.components.FormComponent):
     """Small button with single emoji as text, fits inside gradio forms"""
     """Small button with single emoji as text, fits inside gradio forms"""

+ 5 - 10
presets/Default.txt

@@ -1,12 +1,7 @@
 do_sample=True
 do_sample=True
-temperature=1
-top_p=1
-typical_p=1
-repetition_penalty=1
-top_k=50
-num_beams=1
-penalty_alpha=0
-min_length=0
-length_penalty=1
-no_repeat_ngram_size=0
+top_p=0.5
+top_k=40
+temperature=0.7
+repetition_penalty=1.2
+typical_p=1.0
 early_stopping=False
 early_stopping=False

+ 0 - 6
presets/Individual Today.txt

@@ -1,6 +0,0 @@
-do_sample=True
-top_p=0.9
-top_k=50
-temperature=1.39
-repetition_penalty=1.08
-typical_p=0.2

+ 4 - 2
requirements.txt

@@ -2,10 +2,12 @@ accelerate==0.17.1
 bitsandbytes==0.37.1
 bitsandbytes==0.37.1
 flexgen==0.1.7
 flexgen==0.1.7
 gradio==3.18.0
 gradio==3.18.0
+markdown
 numpy
 numpy
+peft==0.2.0
 requests
 requests
-rwkv==0.4.2
+rwkv==0.7.0
 safetensors==0.3.0
 safetensors==0.3.0
 sentencepiece
 sentencepiece
 tqdm
 tqdm
-git+https://github.com/zphang/transformers.git@68d640f7c368bcaaaecfc678f11908ebbd3d6176
+git+https://github.com/huggingface/transformers

+ 287 - 193
server.py

@@ -15,6 +15,7 @@ import modules.extensions as extensions_module
 import modules.shared as shared
 import modules.shared as shared
 import modules.ui as ui
 import modules.ui as ui
 from modules.html_generator import generate_chat_html
 from modules.html_generator import generate_chat_html
+from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt
 from modules.models import load_model, load_soft_prompt
 from modules.text_generation import generate_reply
 from modules.text_generation import generate_reply
 
 
@@ -34,7 +35,7 @@ def get_available_models():
     if shared.args.flexgen:
     if shared.args.flexgen:
         return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower)
         return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower)
     else:
     else:
-        return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt'))], key=str.lower)
+        return sorted([re.sub('.pth$', '', item.name) for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
 
 
 def get_available_presets():
 def get_available_presets():
     return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
     return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
@@ -48,6 +49,9 @@ def get_available_extensions():
 def get_available_softprompts():
 def get_available_softprompts():
     return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
     return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
 
 
+def get_available_loras():
+    return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
+
 def load_model_wrapper(selected_model):
 def load_model_wrapper(selected_model):
     if selected_model != shared.model_name:
     if selected_model != shared.model_name:
         shared.model_name = selected_model
         shared.model_name = selected_model
@@ -59,6 +63,17 @@ def load_model_wrapper(selected_model):
 
 
     return selected_model
     return selected_model
 
 
+def load_lora_wrapper(selected_lora):
+    shared.lora_name = selected_lora
+    default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
+
+    if not shared.args.cpu:
+        gc.collect()
+        torch.cuda.empty_cache()
+    add_lora_to_model(selected_lora)
+
+    return selected_lora, default_text
+
 def load_preset_values(preset_menu, return_dict=False):
 def load_preset_values(preset_menu, return_dict=False):
     generate_params = {
     generate_params = {
         'do_sample': True,
         'do_sample': True,
@@ -66,6 +81,7 @@ def load_preset_values(preset_menu, return_dict=False):
         'top_p': 1,
         'top_p': 1,
         'typical_p': 1,
         'typical_p': 1,
         'repetition_penalty': 1,
         'repetition_penalty': 1,
+        'encoder_repetition_penalty': 1,
         'top_k': 50,
         'top_k': 50,
         'num_beams': 1,
         'num_beams': 1,
         'penalty_alpha': 0,
         'penalty_alpha': 0,
@@ -86,7 +102,7 @@ def load_preset_values(preset_menu, return_dict=False):
     if return_dict:
     if return_dict:
         return generate_params
         return generate_params
     else:
     else:
-        return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
+        return preset_menu, generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
 
 
 def upload_soft_prompt(file):
 def upload_soft_prompt(file):
     with zipfile.ZipFile(io.BytesIO(file)) as zf:
     with zipfile.ZipFile(io.BytesIO(file)) as zf:
@@ -100,9 +116,7 @@ def upload_soft_prompt(file):
 
 
     return name
     return name
 
 
-def create_settings_menus(default_preset):
-    generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
-
+def create_model_and_preset_menus():
     with gr.Row():
     with gr.Row():
         with gr.Column():
         with gr.Column():
             with gr.Row():
             with gr.Row():
@@ -113,31 +127,48 @@ def create_settings_menus(default_preset):
                 shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
                 shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
                 ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
                 ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
 
 
-    with gr.Accordion('Custom generation parameters', open=False, elem_id='accordion'):
-        with gr.Row():
-            with gr.Column():
-                shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
-                shared.gradio['repetition_penalty'] = gr.Slider(1.0, 2.99, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty')
-                shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k')
-                shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p')
-            with gr.Column():
+def create_settings_menus(default_preset):
+    generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
+
+    with gr.Row():
+        shared.gradio['preset_menu_mirror'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
+        ui.create_refresh_button(shared.gradio['preset_menu_mirror'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
+
+    with gr.Row():
+        with gr.Column():
+            with gr.Box():
+                gr.Markdown('Custom generation parameters ([reference](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig))')
+                with gr.Row():
+                    with gr.Column():
+                        shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
+                        shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p')
+                        shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k')
+                        shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p')
+                    with gr.Column():
+                        shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty')
+                        shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'],step=0.01,label='encoder_repetition_penalty')
+                        shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
+                        shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
                 shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
                 shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
-                shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p')
-                shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
-                shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
+        with gr.Column():
+            with gr.Box():
+                gr.Markdown('Contrastive search')
+                shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
 
 
-        gr.Markdown('Contrastive search:')
-        shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
+            with gr.Box():
+                gr.Markdown('Beam search (uses a lot of VRAM)')
+                with gr.Row():
+                    with gr.Column():
+                        shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams')
+                    with gr.Column():
+                        shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
+                shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
 
 
-        gr.Markdown('Beam search (uses a lot of VRAM):')
-        with gr.Row():
-            with gr.Column():
-                shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams')
-            with gr.Column():
-                shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
-        shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
+    with gr.Row():
+        shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
+        ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
 
 
-    with gr.Accordion('Soft prompt', open=False, elem_id='accordion'):
+    with gr.Accordion('Soft prompt', open=False):
         with gr.Row():
         with gr.Row():
             shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
             shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
             ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda : None, lambda : {'choices': get_available_softprompts()}, 'refresh-button')
             ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda : None, lambda : {'choices': get_available_softprompts()}, 'refresh-button')
@@ -147,14 +178,35 @@ def create_settings_menus(default_preset):
             shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
             shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
 
 
     shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
     shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
-    shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']])
+    shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['preset_menu_mirror', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
+    shared.gradio['preset_menu_mirror'].change(load_preset_values, [shared.gradio['preset_menu_mirror']], [shared.gradio[k] for k in ['preset_menu', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
+    shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu'], shared.gradio['textbox']], show_progress=True)
     shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
     shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
     shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
     shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
 
 
+def set_interface_arguments(interface_mode, extensions, cmd_active):
+    modes = ["default", "notebook", "chat", "cai_chat"]
+    cmd_list = vars(shared.args)
+    cmd_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
+
+    shared.args.extensions = extensions
+    for k in modes[1:]:
+        exec(f"shared.args.{k} = False")
+    if interface_mode != "default":
+        exec(f"shared.args.{interface_mode} = True")
+
+    for k in cmd_list:
+        exec(f"shared.args.{k} = False")
+    for k in cmd_active:
+        exec(f"shared.args.{k} = True")
+
+    shared.need_restart = True
+
 available_models = get_available_models()
 available_models = get_available_models()
 available_presets = get_available_presets()
 available_presets = get_available_presets()
 available_characters = get_available_characters()
 available_characters = get_available_characters()
 available_softprompts = get_available_softprompts()
 available_softprompts = get_available_softprompts()
+available_loras = get_available_loras()
 
 
 # Default extensions
 # Default extensions
 extensions_module.available_extensions = get_available_extensions()
 extensions_module.available_extensions = get_available_extensions()
@@ -168,8 +220,6 @@ else:
         shared.args.extensions = shared.args.extensions or []
         shared.args.extensions = shared.args.extensions or []
         if extension not in shared.args.extensions:
         if extension not in shared.args.extensions:
             shared.args.extensions.append(extension)
             shared.args.extensions.append(extension)
-if shared.args.extensions is not None and len(shared.args.extensions) > 0:
-    extensions_module.load_extensions()
 
 
 # Default model
 # Default model
 if shared.args.model is not None:
 if shared.args.model is not None:
@@ -189,191 +239,235 @@ else:
         print()
         print()
     shared.model_name = available_models[i]
     shared.model_name = available_models[i]
 shared.model, shared.tokenizer = load_model(shared.model_name)
 shared.model, shared.tokenizer = load_model(shared.model_name)
+if shared.args.lora:
+    print(shared.args.lora)
+    shared.lora_name = shared.args.lora
+    add_lora_to_model(shared.lora_name)
 
 
 # Default UI settings
 # Default UI settings
-gen_events = []
 default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
 default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
-default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
+default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
+if default_text == '':
+    default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
 title ='Text generation web UI'
 title ='Text generation web UI'
 description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
 description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
 suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
 suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
 
 
-if shared.args.chat or shared.args.cai_chat:
-    with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
-        if shared.args.cai_chat:
-            shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
-        else:
-            shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528"))
-        shared.gradio['textbox'] = gr.Textbox(label='Input')
-        with gr.Row():
-            shared.gradio['Stop'] = gr.Button('Stop')
-            shared.gradio['Generate'] = gr.Button('Generate')
-        with gr.Row():
-            shared.gradio['Impersonate'] = gr.Button('Impersonate')
-            shared.gradio['Regenerate'] = gr.Button('Regenerate')
-        with gr.Row():
-            shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
-            shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
-            shared.gradio['Remove last'] = gr.Button('Remove last')
-
-            shared.gradio['Clear history'] = gr.Button('Clear history')
-            shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
-            shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
-        with gr.Tab('Chat settings'):
-            shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
-            shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
-            shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context')
-            with gr.Row():
-                shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
-                ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
+def create_interface():
 
 
-            with gr.Row():
-                shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
-            with gr.Row():
-                with gr.Tab('Chat history'):
-                    with gr.Row():
-                        with gr.Column():
-                            gr.Markdown('Upload')
-                            shared.gradio['upload_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'])
-                        with gr.Column():
-                            gr.Markdown('Download')
-                            shared.gradio['download'] = gr.File()
-                            shared.gradio['download_button'] = gr.Button(value='Click me')
-                with gr.Tab('Upload character'):
-                    with gr.Row():
-                        with gr.Column():
-                            gr.Markdown('1. Select the JSON file')
-                            shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json'])
-                        with gr.Column():
-                            gr.Markdown('2. Select your character\'s profile picture (optional)')
-                            shared.gradio['upload_img_bot'] = gr.File(type='binary', file_types=['image'])
-                    shared.gradio['Upload character'] = gr.Button(value='Submit')
-                with gr.Tab('Upload your profile picture'):
-                    shared.gradio['upload_img_me'] = gr.File(type='binary', file_types=['image'])
-                with gr.Tab('Upload TavernAI Character Card'):
-                    shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
-
-        with gr.Tab('Generation settings'):
-            with gr.Row():
-                with gr.Column():
-                    shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
-                with gr.Column():
-                    shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
-                shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
-            create_settings_menus(default_preset)
-
-        shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
-        if shared.args.extensions is not None:
-            with gr.Tab('Extensions'):
-                extensions_module.create_extensions_block()
-
-        function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
-
-        gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
-        gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
-        gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
-        gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
-        shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)
-
-        shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
-        shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
-
-        # Clear history with confirmation
-        clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
-        shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
-        shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
-        shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
-        shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
-
-        shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
-        shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
-        shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
-
-        # Clearing stuff and saving the history
-        for i in ['Generate', 'Regenerate', 'Replace last reply']:
-            shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
-            shared.gradio[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
-        shared.gradio['Clear history-confirm'].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
-        shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
-        shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
-
-        shared.gradio['character_menu'].change(chat.load_character, [shared.gradio['character_menu'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['name2'], shared.gradio['context'], shared.gradio['display']])
-        shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], [])
-        shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
-        shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], [])
-
-        reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible']
-        reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
-        shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
-        shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
-        shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
-
-        shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
-        shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
-
-elif shared.args.notebook:
-    with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']:
-        gr.Markdown(description)
-        with gr.Tab('Raw'):
-            shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=23)
-        with gr.Tab('Markdown'):
-            shared.gradio['markdown'] = gr.Markdown()
-        with gr.Tab('HTML'):
-            shared.gradio['html'] = gr.HTML()
-
-        shared.gradio['Generate'] = gr.Button('Generate')
-        shared.gradio['Stop'] = gr.Button('Stop')
-        shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
-
-        create_settings_menus(default_preset)
-        if shared.args.extensions is not None:
-            extensions_module.create_extensions_block()
+    gen_events = []
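+    # Load the selected extensions (if any) before the Gradio interface is built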
+    if shared.args.extensions is not None and len(shared.args.extensions) > 0:
+        extensions_module.load_extensions()
 
 
-        shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
-        output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
-        gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
-        gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
-        shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
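+    # Use the chat stylesheet on top of the base CSS whenever a chat mode is active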
+    with gr.Blocks(css=ui.css if not any((shared.args.chat, shared.args.cai_chat)) else ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
+        if shared.args.chat or shared.args.cai_chat:
+            with gr.Tab("Text generation", elem_id="main"):
+                if shared.args.cai_chat:
+                    shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
+                else:
+                    shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528"))
+                shared.gradio['textbox'] = gr.Textbox(label='Input')
+                with gr.Row():
+                    shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
+                    shared.gradio['Generate'] = gr.Button('Generate')
+                with gr.Row():
+                    shared.gradio['Impersonate'] = gr.Button('Impersonate')
+                    shared.gradio['Regenerate'] = gr.Button('Regenerate')
+                with gr.Row():
+                    shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
+                    shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
+                    shared.gradio['Remove last'] = gr.Button('Remove last')
 
 
-else:
-    with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']:
-        gr.Markdown(description)
-        with gr.Row():
-            with gr.Column():
-                shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input')
-                shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
-                shared.gradio['Generate'] = gr.Button('Generate')
+                    shared.gradio['Clear history'] = gr.Button('Clear history')
+                    shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
+                    shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
+
+                create_model_and_preset_menus()
+
+            with gr.Tab("Character", elem_id="chat-settings"):
+                shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
+                shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
+                shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context')
                 with gr.Row():
                 with gr.Row():
-                    with gr.Column():
-                        shared.gradio['Continue'] = gr.Button('Continue')
-                    with gr.Column():
-                        shared.gradio['Stop'] = gr.Button('Stop')
+                    shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
+                    ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
+
+                with gr.Row():
+                    with gr.Tab('Chat history'):
+                        with gr.Row():
+                            with gr.Column():
+                                gr.Markdown('Upload')
+                                shared.gradio['upload_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'])
+                            with gr.Column():
+                                gr.Markdown('Download')
+                                shared.gradio['download'] = gr.File()
+                                shared.gradio['download_button'] = gr.Button(value='Click me')
+                    with gr.Tab('Upload character'):
+                        with gr.Row():
+                            with gr.Column():
+                                gr.Markdown('1. Select the JSON file')
+                                shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json'])
+                            with gr.Column():
+                                gr.Markdown('2. Select your character\'s profile picture (optional)')
+                                shared.gradio['upload_img_bot'] = gr.File(type='binary', file_types=['image'])
+                        shared.gradio['Upload character'] = gr.Button(value='Submit')
+                    with gr.Tab('Upload your profile picture'):
+                        shared.gradio['upload_img_me'] = gr.File(type='binary', file_types=['image'])
+                    with gr.Tab('Upload TavernAI Character Card'):
+                        shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
+
+            with gr.Tab("Parameters", elem_id="parameters"):
+                with gr.Box():
+                    gr.Markdown("Chat parameters")
+                    with gr.Row():
+                        with gr.Column():
+                            shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
+                            shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
+                        with gr.Column():
+                            shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
+                            shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
 
 
                 create_settings_menus(default_preset)
                 create_settings_menus(default_preset)
-                if shared.args.extensions is not None:
-                    extensions_module.create_extensions_block()
 
 
-            with gr.Column():
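+            # Pick the chat wrapper that matches the active display: HTML (cai-chat) or the plain Chatbot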
+            function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
+            shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
+
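+            # Keep a handle on each generation event so the Stop button can cancel them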
+            gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
+            gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
+            gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
+            gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
+            shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)
+
+            shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
+            shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
+
+            # Clear history with confirmation
+            clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
+            shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
+            shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
+            shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
+            shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
+
+            shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
+            shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
+            shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
+
+            # Clear the input textbox and save the chat history after these events
+            for i in ['Generate', 'Regenerate', 'Replace last reply']:
+                shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
+                shared.gradio[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
+            shared.gradio['Clear history-confirm'].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
+            shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
+            shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
+
+            shared.gradio['character_menu'].change(chat.load_character, [shared.gradio['character_menu'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['name2'], shared.gradio['context'], shared.gradio['display']])
+            shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], [])
+            shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
+            shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], [])
+
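+            # Redraw the chat display after a history upload, a profile-picture upload, or Stop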
+            reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible']
+            reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
+            shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
+            shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
+            shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
+
+            shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
+            shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
+            shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
+
+        elif shared.args.notebook:
+            with gr.Tab("Text generation", elem_id="main"):
                 with gr.Tab('Raw'):
                 with gr.Tab('Raw'):
-                    shared.gradio['output_textbox'] = gr.Textbox(lines=15, label='Output')
+                    shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=25)
                 with gr.Tab('Markdown'):
                 with gr.Tab('Markdown'):
                     shared.gradio['markdown'] = gr.Markdown()
                     shared.gradio['markdown'] = gr.Markdown()
                 with gr.Tab('HTML'):
                 with gr.Tab('HTML'):
                     shared.gradio['html'] = gr.HTML()
                     shared.gradio['html'] = gr.HTML()
 
 
-        shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
-        output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
-        gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
-        gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
-        gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
-        shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
+                with gr.Row():
+                    shared.gradio['Stop'] = gr.Button('Stop')
+                    shared.gradio['Generate'] = gr.Button('Generate')
+                shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
 
 
-shared.gradio['interface'].queue()
-if shared.args.listen:
-    shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
-else:
-    shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
+                create_model_and_preset_menus()
+            with gr.Tab("Parameters", elem_id="parameters"):
+                create_settings_menus(default_preset)
+
+            shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
+            output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
+            gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
+            gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
+            shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
+            shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
+
+        else:
+            with gr.Tab("Text generation", elem_id="main"):
+                with gr.Row():
+                    with gr.Column():
+                        shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input')
+                        shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
+                        shared.gradio['Generate'] = gr.Button('Generate')
+                        with gr.Row():
+                            with gr.Column():
+                                shared.gradio['Continue'] = gr.Button('Continue')
+                            with gr.Column():
+                                shared.gradio['Stop'] = gr.Button('Stop')
+
+                        create_model_and_preset_menus()
+
+                    with gr.Column():
+                        with gr.Tab('Raw'):
+                            shared.gradio['output_textbox'] = gr.Textbox(lines=25, label='Output')
+                        with gr.Tab('Markdown'):
+                            shared.gradio['markdown'] = gr.Markdown()
+                        with gr.Tab('HTML'):
+                            shared.gradio['html'] = gr.HTML()
+            with gr.Tab("Parameters", elem_id="parameters"):
+                create_settings_menus(default_preset)
+
+            shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
+            output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
+            gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
+            gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
+            gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
+            shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
+            shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
+
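+        # "Interface mode" tab: switch between default/notebook/chat/cai_chat, toggle boolean flags, and restart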
+        with gr.Tab("Interface mode", elem_id="interface-mode"):
+            modes = ["default", "notebook", "chat", "cai_chat"]
+            current_mode = "default"
+            for mode in modes[1:]:
+                if eval(f"shared.args.{mode}"):
+                    current_mode = mode
+                    break
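+            # Collect the boolean command-line flags (excluding the mode flags) and note which are currently set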
+            cmd_list = vars(shared.args)
+            cmd_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
+            active_cmd_list = [k for k in cmd_list if vars(shared.args)[k]]
+
+            gr.Markdown("*Experimental*")
+            shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
+            shared.gradio['extensions_menu'] = gr.CheckboxGroup(choices=get_available_extensions(), value=shared.args.extensions, label="Available extensions")
+            shared.gradio['cmd_arguments_menu'] = gr.CheckboxGroup(choices=cmd_list, value=active_cmd_list, label="Boolean command-line flags")
+            shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", type="primary")
+
+            shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'cmd_arguments_menu']], None)
+            shared.gradio['reset_interface'].click(lambda : None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500)}')
+
+        if shared.args.extensions is not None:
+            extensions_module.create_extensions_block()
+
+    # Launch the interface
+    shared.gradio['interface'].queue()
+    if shared.args.listen:
+        shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
+    else:
+        shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
+
+create_interface()
 
 
-# I think that I will need this later
 while True:
 while True:
     time.sleep(0.5)
     time.sleep(0.5)
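+    # Rebuild the interface when a restart has been requested (e.g. via "Apply and restart the interface")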
+    if shared.need_restart:
+        shared.need_restart = False
+        shared.gradio['interface'].close()
+        create_interface()

+ 6 - 3
settings-template.json

@@ -5,7 +5,7 @@
     "name1": "Person 1",
     "name1": "Person 1",
     "name2": "Person 2",
     "name2": "Person 2",
     "context": "This is a conversation between two people.",
     "context": "This is a conversation between two people.",
-    "stop_at_newline": true,
+    "stop_at_newline": false,
     "chat_prompt_size": 2048,
     "chat_prompt_size": 2048,
     "chat_prompt_size_min": 0,
     "chat_prompt_size_min": 0,
     "chat_prompt_size_max": 2048,
     "chat_prompt_size_max": 2048,
@@ -23,13 +23,16 @@
     "presets": {
     "presets": {
         "default": "NovelAI-Sphinx Moth",
         "default": "NovelAI-Sphinx Moth",
         "pygmalion-*": "Pygmalion",
         "pygmalion-*": "Pygmalion",
-        "RWKV-*": "Naive",
-        "(rosey|chip|joi)_.*_instruct.*": "Instruct Joi (Contrastive Search)"
+        "RWKV-*": "Naive"
     },
     },
     "prompts": {
     "prompts": {
         "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
         "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
         "^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n",
         "^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n",
         "(rosey|chip|joi)_.*_instruct.*": "User: \n",
         "(rosey|chip|joi)_.*_instruct.*": "User: \n",
         "oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>"
         "oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>"
+    },
+    "lora_prompts": {
+        "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
+        "alpaca-lora-7b": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
     }
     }
 }
 }