Merge branch 'main' into da3dsoul-main
This commit is contained in:
27
README.md
27
README.md
@@ -15,6 +15,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
|
|||||||
* Dropdown menu for switching between models
|
* Dropdown menu for switching between models
|
||||||
* Notebook mode that resembles OpenAI's playground
|
* Notebook mode that resembles OpenAI's playground
|
||||||
* Chat mode for conversation and role playing
|
* Chat mode for conversation and role playing
|
||||||
|
* Instruct mode compatible with Alpaca and Open Assistant formats **\*NEW!\***
|
||||||
* Nice HTML output for GPT-4chan
|
* Nice HTML output for GPT-4chan
|
||||||
* Markdown output for [GALACTICA](https://github.com/paperswithcode/galai), including LaTeX rendering
|
* Markdown output for [GALACTICA](https://github.com/paperswithcode/galai), including LaTeX rendering
|
||||||
* [Custom chat characters](https://github.com/oobabooga/text-generation-webui/wiki/Custom-chat-characters)
|
* [Custom chat characters](https://github.com/oobabooga/text-generation-webui/wiki/Custom-chat-characters)
|
||||||
@@ -26,11 +27,11 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
|
|||||||
* CPU mode
|
* CPU mode
|
||||||
* [FlexGen](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen)
|
* [FlexGen](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen)
|
||||||
* [DeepSpeed ZeRO-3](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed)
|
* [DeepSpeed ZeRO-3](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed)
|
||||||
* API [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) streaming and [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming
|
* API [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-stream.py) streaming and [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming
|
||||||
* [LLaMA model, including 4-bit GPTQ](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model)
|
* [LLaMA model, including 4-bit GPTQ](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model)
|
||||||
* [llama.cpp](https://github.com/oobabooga/text-generation-webui/wiki/llama.cpp-models) **\*NEW!\***
|
* [llama.cpp](https://github.com/oobabooga/text-generation-webui/wiki/llama.cpp-models) **\*NEW!\***
|
||||||
* [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model)
|
* [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model)
|
||||||
* [LoRa (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs)
|
* [LoRA (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs)
|
||||||
* Softprompts
|
* Softprompts
|
||||||
* [Extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions)
|
* [Extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions)
|
||||||
* [Google Colab](https://github.com/oobabooga/text-generation-webui/wiki/Running-on-Colab)
|
* [Google Colab](https://github.com/oobabooga/text-generation-webui/wiki/Running-on-Colab)
|
||||||
@@ -62,7 +63,7 @@ Recommended if you have some experience with the command-line.
|
|||||||
|
|
||||||
On Windows, I additionally recommend carrying out the installation on WSL instead of the base system: [WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/WSL-installation-guide).
|
On Windows, I additionally recommend carrying out the installation on WSL instead of the base system: [WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/WSL-installation-guide).
|
||||||
|
|
||||||
0. Install Conda
|
#### 0. Install Conda
|
||||||
|
|
||||||
https://docs.conda.io/en/latest/miniconda.html
|
https://docs.conda.io/en/latest/miniconda.html
|
||||||
|
|
||||||
@@ -75,14 +76,14 @@ bash Miniconda3.sh
|
|||||||
|
|
||||||
Source: https://educe-ubc.github.io/conda.html
|
Source: https://educe-ubc.github.io/conda.html
|
||||||
|
|
||||||
1. Create a new conda environment
|
#### 1. Create a new conda environment
|
||||||
|
|
||||||
```
|
```
|
||||||
conda create -n textgen python=3.10.9
|
conda create -n textgen python=3.10.9
|
||||||
conda activate textgen
|
conda activate textgen
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Install Pytorch
|
#### 2. Install Pytorch
|
||||||
|
|
||||||
| System | GPU | Command |
|
| System | GPU | Command |
|
||||||
|--------|---------|---------|
|
|--------|---------|---------|
|
||||||
@@ -92,10 +93,12 @@ conda activate textgen
|
|||||||
|
|
||||||
The up to date commands can be found here: https://pytorch.org/get-started/locally/.
|
The up to date commands can be found here: https://pytorch.org/get-started/locally/.
|
||||||
|
|
||||||
MacOS users, refer to the comments here: https://github.com/oobabooga/text-generation-webui/pull/393
|
#### 2.1 Special instructions
|
||||||
|
|
||||||
|
* MacOS users: https://github.com/oobabooga/text-generation-webui/pull/393
|
||||||
|
* AMD users: https://rentry.org/eq3hg
|
||||||
|
|
||||||
3. Install the web UI
|
#### 3. Install the web UI
|
||||||
|
|
||||||
```
|
```
|
||||||
git clone https://github.com/oobabooga/text-generation-webui
|
git clone https://github.com/oobabooga/text-generation-webui
|
||||||
@@ -116,6 +119,15 @@ As an alternative to the recommended WSL method, you can install the web UI nati
|
|||||||
|
|
||||||
https://github.com/oobabooga/text-generation-webui/issues/174, https://github.com/oobabooga/text-generation-webui/issues/87
|
https://github.com/oobabooga/text-generation-webui/issues/174, https://github.com/oobabooga/text-generation-webui/issues/87
|
||||||
|
|
||||||
|
### Updating the requirements
|
||||||
|
|
||||||
|
From time to time, the `requirements.txt` changes. To update, use this command:
|
||||||
|
|
||||||
|
```
|
||||||
|
conda activate textgen
|
||||||
|
cd text-generation-webui
|
||||||
|
pip install -r requirements.txt --upgrade
|
||||||
|
```
|
||||||
## Downloading models
|
## Downloading models
|
||||||
|
|
||||||
Models should be placed inside the `models` folder.
|
Models should be placed inside the `models` folder.
|
||||||
@@ -175,7 +187,6 @@ Optionally, you can use the following command-line flags:
|
|||||||
| `-h`, `--help` | show this help message and exit |
|
| `-h`, `--help` | show this help message and exit |
|
||||||
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
|
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
|
||||||
| `--chat` | Launch the web UI in chat mode.|
|
| `--chat` | Launch the web UI in chat mode.|
|
||||||
| `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
|
|
||||||
| `--model MODEL` | Name of the model to load by default. |
|
| `--model MODEL` | Name of the model to load by default. |
|
||||||
| `--lora LORA` | Name of the LoRA to apply to the model by default. |
|
| `--lora LORA` | Name of the LoRA to apply to the model by default. |
|
||||||
| `--model-dir MODEL_DIR` | Path to directory with all the models |
|
| `--model-dir MODEL_DIR` | Path to directory with all the models |
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ async def run(context):
|
|||||||
'early_stopping': False,
|
'early_stopping': False,
|
||||||
'seed': -1,
|
'seed': -1,
|
||||||
}
|
}
|
||||||
|
payload = json.dumps([context, params])
|
||||||
session = random_hash()
|
session = random_hash()
|
||||||
|
|
||||||
async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket:
|
async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket:
|
||||||
@@ -54,22 +55,7 @@ async def run(context):
|
|||||||
"session_hash": session,
|
"session_hash": session,
|
||||||
"fn_index": 12,
|
"fn_index": 12,
|
||||||
"data": [
|
"data": [
|
||||||
context,
|
payload
|
||||||
params['max_new_tokens'],
|
|
||||||
params['do_sample'],
|
|
||||||
params['temperature'],
|
|
||||||
params['top_p'],
|
|
||||||
params['typical_p'],
|
|
||||||
params['repetition_penalty'],
|
|
||||||
params['encoder_repetition_penalty'],
|
|
||||||
params['top_k'],
|
|
||||||
params['min_length'],
|
|
||||||
params['no_repeat_ngram_size'],
|
|
||||||
params['num_beams'],
|
|
||||||
params['penalty_alpha'],
|
|
||||||
params['length_penalty'],
|
|
||||||
params['early_stopping'],
|
|
||||||
params['seed'],
|
|
||||||
]
|
]
|
||||||
}))
|
}))
|
||||||
case "process_starts":
|
case "process_starts":
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ Optionally, you can also add the --share flag to generate a public gradio URL,
|
|||||||
allowing you to use the API remotely.
|
allowing you to use the API remotely.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
import json
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
# Server address
|
# Server address
|
||||||
@@ -38,24 +40,11 @@ params = {
|
|||||||
# Input prompt
|
# Input prompt
|
||||||
prompt = "What I would like to say is the following: "
|
prompt = "What I would like to say is the following: "
|
||||||
|
|
||||||
|
payload = json.dumps([prompt, params])
|
||||||
|
|
||||||
response = requests.post(f"http://{server}:7860/run/textgen", json={
|
response = requests.post(f"http://{server}:7860/run/textgen", json={
|
||||||
"data": [
|
"data": [
|
||||||
prompt,
|
payload
|
||||||
params['max_new_tokens'],
|
|
||||||
params['do_sample'],
|
|
||||||
params['temperature'],
|
|
||||||
params['top_p'],
|
|
||||||
params['typical_p'],
|
|
||||||
params['repetition_penalty'],
|
|
||||||
params['encoder_repetition_penalty'],
|
|
||||||
params['top_k'],
|
|
||||||
params['min_length'],
|
|
||||||
params['no_repeat_ngram_size'],
|
|
||||||
params['num_beams'],
|
|
||||||
params['penalty_alpha'],
|
|
||||||
params['length_penalty'],
|
|
||||||
params['early_stopping'],
|
|
||||||
params['seed'],
|
|
||||||
]
|
]
|
||||||
}).json()
|
}).json()
|
||||||
|
|
||||||
|
|||||||
@@ -1,32 +1,16 @@
|
|||||||
name: Chiharu Yamada
|
name: "Chiharu Yamada"
|
||||||
context: 'Chiharu Yamada''s Persona: Chiharu Yamada is a young, computer engineer-nerd
|
context: "Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology."
|
||||||
with a knack for problem solving and a passion for technology.'
|
greeting: |-
|
||||||
greeting: '*Chiharu strides into the room with a smile, her eyes lighting up
|
*Chiharu strides into the room with a smile, her eyes lighting up when she sees you. She's wearing a light blue t-shirt and jeans, her laptop bag slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in the air*
|
||||||
when she sees you. She''s wearing a light blue t-shirt and jeans, her laptop bag
|
Hey! I'm so excited to finally meet you. I've heard so many great things about you and I'm eager to pick your brain about computers. I'm sure you have a wealth of knowledge that I can learn from. *She grins, eyes twinkling with excitement* Let's get started!
|
||||||
slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in
|
example_dialogue: |-
|
||||||
the air*
|
{{user}}: So how did you get into computer engineering?
|
||||||
|
{{char}}: I've always loved tinkering with technology since I was a kid.
|
||||||
Hey! I''m so excited to finally meet you. I''ve heard so many great things about
|
{{user}}: That's really impressive!
|
||||||
you and I''m eager to pick your brain about computers. I''m sure you have a wealth
|
|
||||||
of knowledge that I can learn from. *She grins, eyes twinkling with excitement*
|
|
||||||
Let''s get started!'
|
|
||||||
example_dialogue: '{{user}}: So how did you get into computer engineering?
|
|
||||||
|
|
||||||
{{char}}: I''ve always loved tinkering with technology since I was a kid.
|
|
||||||
|
|
||||||
{{user}}: That''s really impressive!
|
|
||||||
|
|
||||||
{{char}}: *She chuckles bashfully* Thanks!
|
{{char}}: *She chuckles bashfully* Thanks!
|
||||||
|
{{user}}: So what do you do when you're not working on computers?
|
||||||
{{user}}: So what do you do when you''re not working on computers?
|
{{char}}: I love exploring, going out with friends, watching movies, and playing video games.
|
||||||
|
{{user}}: What's your favorite type of computer hardware to work with?
|
||||||
{{char}}: I love exploring, going out with friends, watching movies, and playing
|
{{char}}: Motherboards, they're like puzzles and the backbone of any system.
|
||||||
video games.
|
|
||||||
|
|
||||||
{{user}}: What''s your favorite type of computer hardware to work with?
|
|
||||||
|
|
||||||
{{char}}: Motherboards, they''re like puzzles and the backbone of any system.
|
|
||||||
|
|
||||||
{{user}}: That sounds great!
|
{{user}}: That sounds great!
|
||||||
|
{{char}}: Yeah, it's really fun. I'm lucky to be able to do this as a job.
|
||||||
{{char}}: Yeah, it''s really fun. I''m lucky to be able to do this as a job.'
|
|
||||||
|
|||||||
3
characters/instruction-following/Alpaca.yaml
Normal file
3
characters/instruction-following/Alpaca.yaml
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
name: "### Response:"
|
||||||
|
your_name: "### Instruction:"
|
||||||
|
context: "Below is an instruction that describes a task. Write a response that appropriately completes the request."
|
||||||
3
characters/instruction-following/Open Assistant.yaml
Normal file
3
characters/instruction-following/Open Assistant.yaml
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
name: "<|assistant|>"
|
||||||
|
your_name: "<|prompter|>"
|
||||||
|
end_of_turn: "<|endoftext|>"
|
||||||
3
characters/instruction-following/Vicuna.yaml
Normal file
3
characters/instruction-following/Vicuna.yaml
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
name: "### Assistant:"
|
||||||
|
your_name: "### Human:"
|
||||||
|
context: "Below is an instruction that describes a task. Write a response that appropriately completes the request."
|
||||||
@@ -64,6 +64,15 @@
|
|||||||
line-height: 1.428571429 !important;
|
line-height: 1.428571429 !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.message-body li {
|
||||||
|
margin-top: 0.5em !important;
|
||||||
|
margin-bottom: 0.5em !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-body li > p {
|
||||||
|
display: inline !important;
|
||||||
|
}
|
||||||
|
|
||||||
.dark .message-body p em {
|
.dark .message-body p em {
|
||||||
color: rgb(138, 138, 138) !important;
|
color: rgb(138, 138, 138) !important;
|
||||||
}
|
}
|
||||||
|
|||||||
65
css/html_instruct_style.css
Normal file
65
css/html_instruct_style.css
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
.chat {
|
||||||
|
margin-left: auto;
|
||||||
|
margin-right: auto;
|
||||||
|
max-width: 800px;
|
||||||
|
height: 66.67vh;
|
||||||
|
overflow-y: auto;
|
||||||
|
padding-right: 20px;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column-reverse;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: 60px 1fr;
|
||||||
|
padding-bottom: 25px;
|
||||||
|
font-size: 15px;
|
||||||
|
font-family: Helvetica, Arial, sans-serif;
|
||||||
|
line-height: 1.428571429;
|
||||||
|
}
|
||||||
|
|
||||||
|
.username {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-body {}
|
||||||
|
|
||||||
|
.message-body p {
|
||||||
|
margin-bottom: 0 !important;
|
||||||
|
font-size: 15px !important;
|
||||||
|
line-height: 1.428571429 !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-body li {
|
||||||
|
margin-top: 0.5em !important;
|
||||||
|
margin-bottom: 0.5em !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-body li > p {
|
||||||
|
display: inline !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.dark .message-body p em {
|
||||||
|
color: rgb(138, 138, 138) !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-body p em {
|
||||||
|
color: rgb(110, 110, 110) !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.gradio-container .chat .assistant-message {
|
||||||
|
padding: 15px;
|
||||||
|
border-radius: 20px;
|
||||||
|
background-color: #0000000f;
|
||||||
|
margin-bottom: 17.5px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.gradio-container .chat .user-message {
|
||||||
|
padding: 15px;
|
||||||
|
border-radius: 20px;
|
||||||
|
margin-bottom: 17.5px !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.dark .chat .assistant-message {
|
||||||
|
background-color: #ffffff21;
|
||||||
|
}
|
||||||
@@ -41,7 +41,7 @@ ol li p, ul li p {
|
|||||||
display: inline-block;
|
display: inline-block;
|
||||||
}
|
}
|
||||||
|
|
||||||
#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab {
|
#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab, #model-tab {
|
||||||
border: 0;
|
border: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -63,3 +63,7 @@ span.math.inline {
|
|||||||
font-size: 27px;
|
font-size: 27px;
|
||||||
vertical-align: baseline !important;
|
vertical-align: baseline !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * {
|
||||||
|
flex-wrap: nowrap;
|
||||||
|
}
|
||||||
|
|||||||
@@ -40,24 +40,27 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
prompt_lines.pop(0)
|
prompt_lines.pop(0)
|
||||||
|
|
||||||
prompt = '\n'.join(prompt_lines)
|
prompt = '\n'.join(prompt_lines)
|
||||||
|
generate_params = {
|
||||||
|
'max_new_tokens': int(body.get('max_length', 200)),
|
||||||
|
'do_sample': bool(body.get('do_sample', True)),
|
||||||
|
'temperature': float(body.get('temperature', 0.5)),
|
||||||
|
'top_p': float(body.get('top_p', 1)),
|
||||||
|
'typical_p': float(body.get('typical', 1)),
|
||||||
|
'repetition_penalty': float(body.get('rep_pen', 1.1)),
|
||||||
|
'encoder_repetition_penalty': 1,
|
||||||
|
'top_k': int(body.get('top_k', 0)),
|
||||||
|
'min_length': int(body.get('min_length', 0)),
|
||||||
|
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size',0)),
|
||||||
|
'num_beams': int(body.get('num_beams',1)),
|
||||||
|
'penalty_alpha': float(body.get('penalty_alpha', 0)),
|
||||||
|
'length_penalty': float(body.get('length_penalty', 1)),
|
||||||
|
'early_stopping': bool(body.get('early_stopping', False)),
|
||||||
|
'seed': int(body.get('seed', -1)),
|
||||||
|
}
|
||||||
|
|
||||||
generator = generate_reply(
|
generator = generate_reply(
|
||||||
question = prompt,
|
prompt,
|
||||||
max_new_tokens = int(body.get('max_length', 200)),
|
generate_params,
|
||||||
do_sample=bool(body.get('do_sample', True)),
|
|
||||||
temperature=float(body.get('temperature', 0.5)),
|
|
||||||
top_p=float(body.get('top_p', 1)),
|
|
||||||
typical_p=float(body.get('typical', 1)),
|
|
||||||
repetition_penalty=float(body.get('rep_pen', 1.1)),
|
|
||||||
encoder_repetition_penalty=1,
|
|
||||||
top_k=int(body.get('top_k', 0)),
|
|
||||||
min_length=int(body.get('min_length', 0)),
|
|
||||||
no_repeat_ngram_size=int(body.get('no_repeat_ngram_size',0)),
|
|
||||||
num_beams=int(body.get('num_beams',1)),
|
|
||||||
penalty_alpha=float(body.get('penalty_alpha', 0)),
|
|
||||||
length_penalty=float(body.get('length_penalty', 1)),
|
|
||||||
early_stopping=bool(body.get('early_stopping', False)),
|
|
||||||
seed=int(body.get('seed', -1)),
|
|
||||||
stopping_strings=body.get('stopping_strings', []),
|
stopping_strings=body.get('stopping_strings', []),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -2,9 +2,8 @@ from pathlib import Path
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules.chat import load_character
|
|
||||||
from modules.html_generator import get_image_cache
|
from modules.html_generator import get_image_cache
|
||||||
from modules.shared import gradio, settings
|
from modules.shared import gradio
|
||||||
|
|
||||||
|
|
||||||
def generate_css():
|
def generate_css():
|
||||||
@@ -64,22 +63,13 @@ def generate_html():
|
|||||||
for file in sorted(Path("characters").glob("*")):
|
for file in sorted(Path("characters").glob("*")):
|
||||||
if file.suffix in [".json", ".yml", ".yaml"]:
|
if file.suffix in [".json", ".yml", ".yaml"]:
|
||||||
character = file.stem
|
character = file.stem
|
||||||
container_html = f'<div class="character-container">'
|
container_html = '<div class="character-container">'
|
||||||
image_html = "<div class='placeholder'></div>"
|
image_html = "<div class='placeholder'></div>"
|
||||||
|
|
||||||
for i in [
|
for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
|
||||||
f"characters/{character}.png",
|
|
||||||
f"characters/{character}.jpg",
|
|
||||||
f"characters/{character}.jpeg",
|
|
||||||
]:
|
|
||||||
|
|
||||||
path = Path(i)
|
|
||||||
if path.exists():
|
if path.exists():
|
||||||
try:
|
image_html = f'<img src="file/{get_image_cache(path)}">'
|
||||||
image_html = f'<img src="file/{get_image_cache(path)}">'
|
break
|
||||||
break
|
|
||||||
except:
|
|
||||||
continue
|
|
||||||
|
|
||||||
container_html += f'{image_html} <span class="character-name">{character}</span>'
|
container_html += f'{image_html} <span class="character-name">{character}</span>'
|
||||||
container_html += "</div>"
|
container_html += "</div>"
|
||||||
|
|||||||
@@ -176,4 +176,4 @@ def ui():
|
|||||||
|
|
||||||
force_btn.click(force_pic)
|
force_btn.click(force_pic)
|
||||||
generate_now_btn.click(force_pic)
|
generate_now_btn.click(force_pic)
|
||||||
generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||||
@@ -2,12 +2,11 @@ import base64
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import modules.chat as chat
|
|
||||||
import modules.shared as shared
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
from transformers import BlipForConditionalGeneration, BlipProcessor
|
from transformers import BlipForConditionalGeneration, BlipProcessor
|
||||||
|
|
||||||
|
from modules import chat, shared
|
||||||
|
|
||||||
# If 'state' is True, will hijack the next chat generation with
|
# If 'state' is True, will hijack the next chat generation with
|
||||||
# custom input text given by 'value' in the format [text, visible_text]
|
# custom input text given by 'value' in the format [text, visible_text]
|
||||||
input_hijack = {
|
input_hijack = {
|
||||||
@@ -36,13 +35,11 @@ def generate_chat_picture(picture, name1, name2):
|
|||||||
def ui():
|
def ui():
|
||||||
picture_select = gr.Image(label='Send a picture', type='pil')
|
picture_select = gr.Image(label='Send a picture', type='pil')
|
||||||
|
|
||||||
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
|
|
||||||
|
|
||||||
# Prepare the hijack with custom inputs
|
# Prepare the hijack with custom inputs
|
||||||
picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None)
|
picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None)
|
||||||
|
|
||||||
# Call the generation function
|
# Call the generation function
|
||||||
picture_select.upload(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||||
|
|
||||||
# Clear the picture from the upload field
|
# Clear the picture from the upload field
|
||||||
picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
|
picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import inspect
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -33,16 +34,31 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
|
|||||||
for name in exclude_layers:
|
for name in exclude_layers:
|
||||||
if name in layers:
|
if name in layers:
|
||||||
del layers[name]
|
del layers[name]
|
||||||
make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold)
|
|
||||||
|
gptq_args = inspect.getfullargspec(make_quant).args
|
||||||
|
|
||||||
|
make_quant_kwargs = {
|
||||||
|
'module': model,
|
||||||
|
'names': layers,
|
||||||
|
'bits': wbits,
|
||||||
|
}
|
||||||
|
if 'groupsize' in gptq_args:
|
||||||
|
make_quant_kwargs['groupsize'] = groupsize
|
||||||
|
if 'faster' in gptq_args:
|
||||||
|
make_quant_kwargs['faster'] = faster_kernel
|
||||||
|
if 'kernel_switch_threshold' in gptq_args:
|
||||||
|
make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
|
||||||
|
|
||||||
|
make_quant(**make_quant_kwargs)
|
||||||
|
|
||||||
del layers
|
del layers
|
||||||
|
|
||||||
print('Loading model ...')
|
print('Loading model ...')
|
||||||
if checkpoint.endswith('.safetensors'):
|
if checkpoint.endswith('.safetensors'):
|
||||||
from safetensors.torch import load_file as safe_load
|
from safetensors.torch import load_file as safe_load
|
||||||
model.load_state_dict(safe_load(checkpoint))
|
model.load_state_dict(safe_load(checkpoint), strict = False)
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(torch.load(checkpoint))
|
model.load_state_dict(torch.load(checkpoint), strict = False)
|
||||||
model.seqlen = 2048
|
model.seqlen = 2048
|
||||||
print('Done.')
|
print('Done.')
|
||||||
|
|
||||||
@@ -52,7 +68,7 @@ def load_quantized(model_name):
|
|||||||
if not shared.args.model_type:
|
if not shared.args.model_type:
|
||||||
# Try to determine model type from model name
|
# Try to determine model type from model name
|
||||||
name = model_name.lower()
|
name = model_name.lower()
|
||||||
if any((k in name for k in ['llama', 'alpaca'])):
|
if any((k in name for k in ['llama', 'alpaca', 'vicuna'])):
|
||||||
model_type = 'llama'
|
model_type = 'llama'
|
||||||
elif any((k in name for k in ['opt-', 'galactica'])):
|
elif any((k in name for k in ['opt-', 'galactica'])):
|
||||||
model_type = 'opt'
|
model_type = 'opt'
|
||||||
@@ -65,16 +81,18 @@ def load_quantized(model_name):
|
|||||||
else:
|
else:
|
||||||
model_type = shared.args.model_type.lower()
|
model_type = shared.args.model_type.lower()
|
||||||
|
|
||||||
if model_type == 'llama' and shared.args.pre_layer:
|
if shared.args.pre_layer and model_type == 'llama':
|
||||||
load_quant = llama_inference_offload.load_quant
|
load_quant = llama_inference_offload.load_quant
|
||||||
elif model_type in ('llama', 'opt', 'gptj'):
|
elif model_type in ('llama', 'opt', 'gptj'):
|
||||||
|
if shared.args.pre_layer:
|
||||||
|
print("Warning: ignoring --pre_layer because it only works for llama model type.")
|
||||||
load_quant = _load_quant
|
load_quant = _load_quant
|
||||||
else:
|
else:
|
||||||
print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
|
print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
# Now we are going to try to locate the quantized model file.
|
# Now we are going to try to locate the quantized model file.
|
||||||
path_to_model = Path(f'models/{model_name}')
|
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
||||||
found_pts = list(path_to_model.glob("*.pt"))
|
found_pts = list(path_to_model.glob("*.pt"))
|
||||||
found_safetensors = list(path_to_model.glob("*.safetensors"))
|
found_safetensors = list(path_to_model.glob("*.safetensors"))
|
||||||
pt_path = None
|
pt_path = None
|
||||||
@@ -95,8 +113,8 @@ def load_quantized(model_name):
|
|||||||
else:
|
else:
|
||||||
pt_model = f'{model_name}-{shared.args.wbits}bit'
|
pt_model = f'{model_name}-{shared.args.wbits}bit'
|
||||||
|
|
||||||
# Try to find the .safetensors or .pt both in models/ and in the subfolder
|
# Try to find the .safetensors or .pt both in the model dir and in the subfolder
|
||||||
for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
|
for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
|
||||||
if path.exists():
|
if path.exists():
|
||||||
print(f"Found {path}")
|
print(f"Found {path}")
|
||||||
pt_path = path
|
pt_path = path
|
||||||
@@ -107,7 +125,7 @@ def load_quantized(model_name):
|
|||||||
exit()
|
exit()
|
||||||
|
|
||||||
# qwopqwop200's offload
|
# qwopqwop200's offload
|
||||||
if shared.args.pre_layer:
|
if model_type == 'llama' and shared.args.pre_layer:
|
||||||
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
|
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
|
||||||
else:
|
else:
|
||||||
threshold = False if model_type == 'gptj' else 128
|
threshold = False if model_type == 'gptj' else 128
|
||||||
|
|||||||
38
modules/api.py
Normal file
38
modules/api.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
from modules.text_generation import generate_reply
|
||||||
|
|
||||||
|
|
||||||
|
def generate_reply_wrapper(string):
|
||||||
|
generate_params = {
|
||||||
|
'do_sample': True,
|
||||||
|
'temperature': 1,
|
||||||
|
'top_p': 1,
|
||||||
|
'typical_p': 1,
|
||||||
|
'repetition_penalty': 1,
|
||||||
|
'encoder_repetition_penalty': 1,
|
||||||
|
'top_k': 50,
|
||||||
|
'num_beams': 1,
|
||||||
|
'penalty_alpha': 0,
|
||||||
|
'min_length': 0,
|
||||||
|
'length_penalty': 1,
|
||||||
|
'no_repeat_ngram_size': 0,
|
||||||
|
'early_stopping': False,
|
||||||
|
}
|
||||||
|
params = json.loads(string)
|
||||||
|
for k in params[1]:
|
||||||
|
generate_params[k] = params[1][k]
|
||||||
|
for i in generate_reply(params[0], generate_params):
|
||||||
|
yield i
|
||||||
|
|
||||||
|
def create_apis():
|
||||||
|
t1 = gr.Textbox(visible=False)
|
||||||
|
t2 = gr.Textbox(visible=False)
|
||||||
|
dummy = gr.Button(visible=False)
|
||||||
|
|
||||||
|
input_params = [t1]
|
||||||
|
output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']]
|
||||||
|
dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen')
|
||||||
205
modules/chat.py
205
modules/chat.py
@@ -12,45 +12,55 @@ from PIL import Image
|
|||||||
import modules.extensions as extensions_module
|
import modules.extensions as extensions_module
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.extensions import apply_extensions
|
from modules.extensions import apply_extensions
|
||||||
from modules.html_generator import fix_newlines, generate_chat_html
|
from modules.html_generator import (fix_newlines, chat_html_wrapper,
|
||||||
|
make_thumbnail)
|
||||||
from modules.text_generation import (encode, generate_reply,
|
from modules.text_generation import (encode, generate_reply,
|
||||||
get_max_prompt_length)
|
get_max_prompt_length)
|
||||||
|
|
||||||
|
|
||||||
def generate_chat_output(history, name1, name2, character):
|
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs):
|
||||||
if shared.args.cai_chat:
|
is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
|
||||||
return generate_chat_html(history, name1, name2, character)
|
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
|
||||||
else:
|
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
||||||
return history
|
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
||||||
|
|
||||||
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False, also_return_rows=False):
|
|
||||||
user_input = fix_newlines(user_input)
|
|
||||||
rows = [f"{context.strip()}\n"]
|
rows = [f"{context.strip()}\n"]
|
||||||
|
|
||||||
|
# Finding the maximum prompt size
|
||||||
if shared.soft_prompt:
|
if shared.soft_prompt:
|
||||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
||||||
max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
|
max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
|
||||||
|
|
||||||
|
if is_instruct:
|
||||||
|
prefix1 = f"{name1}\n"
|
||||||
|
prefix2 = f"{name2}\n"
|
||||||
|
else:
|
||||||
|
prefix1 = f"{name1}: "
|
||||||
|
prefix2 = f"{name2}: "
|
||||||
|
|
||||||
i = len(shared.history['internal'])-1
|
i = len(shared.history['internal'])-1
|
||||||
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
|
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
|
||||||
rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n")
|
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
|
||||||
prev_user_input = shared.history['internal'][i][0]
|
string = shared.history['internal'][i][0]
|
||||||
if prev_user_input not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
||||||
rows.insert(1, f"{name1}: {prev_user_input.strip()}\n")
|
rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
|
||||||
i -= 1
|
i -= 1
|
||||||
|
|
||||||
if not impersonate:
|
if impersonate:
|
||||||
if len(user_input) > 0:
|
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
|
||||||
rows.append(f"{name1}: {user_input}\n")
|
|
||||||
rows.append(apply_extensions(f"{name2}:", "bot_prefix"))
|
|
||||||
limit = 3
|
|
||||||
else:
|
|
||||||
rows.append(f"{name1}:")
|
|
||||||
limit = 2
|
limit = 2
|
||||||
|
else:
|
||||||
|
# Adding the user message
|
||||||
|
user_input = fix_newlines(user_input)
|
||||||
|
if len(user_input) > 0:
|
||||||
|
rows.append(f"{prefix1}{user_input}{end_of_turn}\n")
|
||||||
|
|
||||||
|
# Adding the Character prefix
|
||||||
|
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
|
||||||
|
limit = 3
|
||||||
|
|
||||||
while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
|
while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
|
||||||
rows.pop(1)
|
rows.pop(1)
|
||||||
|
|
||||||
prompt = ''.join(rows)
|
prompt = ''.join(rows)
|
||||||
|
|
||||||
if also_return_rows:
|
if also_return_rows:
|
||||||
@@ -81,13 +91,20 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
|||||||
if reply[-j:] == string[:j]:
|
if reply[-j:] == string[:j]:
|
||||||
reply = reply[:-j]
|
reply = reply[:-j]
|
||||||
break
|
break
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
break
|
||||||
|
|
||||||
reply = fix_newlines(reply)
|
reply = fix_newlines(reply)
|
||||||
return reply, next_character_found
|
return reply, next_character_found
|
||||||
|
|
||||||
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
|
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
|
||||||
just_started = True
|
if mode == 'instruct':
|
||||||
eos_token = '\n' if stop_at_newline else None
|
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
||||||
|
else:
|
||||||
|
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
||||||
|
|
||||||
|
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
||||||
name1_original = name1
|
name1_original = name1
|
||||||
if 'pygmalion' in shared.model_name.lower():
|
if 'pygmalion' in shared.model_name.lower():
|
||||||
name1 = "You"
|
name1 = "You"
|
||||||
@@ -104,14 +121,13 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
|
|||||||
|
|
||||||
if visible_text is None:
|
if visible_text is None:
|
||||||
visible_text = text
|
visible_text = text
|
||||||
if shared.args.chat:
|
|
||||||
visible_text = visible_text.replace('\n', '<br>')
|
|
||||||
text = apply_extensions(text, "input")
|
text = apply_extensions(text, "input")
|
||||||
|
|
||||||
|
kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
|
||||||
if custom_generate_chat_prompt is None:
|
if custom_generate_chat_prompt is None:
|
||||||
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
|
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
||||||
else:
|
else:
|
||||||
prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
|
prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
||||||
|
|
||||||
# Yield *Is typing...*
|
# Yield *Is typing...*
|
||||||
if not regenerate:
|
if not regenerate:
|
||||||
@@ -119,17 +135,16 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
|
|||||||
|
|
||||||
# Generate
|
# Generate
|
||||||
cumulative_reply = ''
|
cumulative_reply = ''
|
||||||
for i in range(chat_generation_attempts):
|
just_started = True
|
||||||
|
for i in range(generate_state['chat_generation_attempts']):
|
||||||
reply = None
|
reply = None
|
||||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
|
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||||
reply = cumulative_reply + reply
|
reply = cumulative_reply + reply
|
||||||
|
|
||||||
# Extracting the reply
|
# Extracting the reply
|
||||||
reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
|
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
|
||||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
|
visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
|
||||||
visible_reply = apply_extensions(visible_reply, "output")
|
visible_reply = apply_extensions(visible_reply, "output")
|
||||||
if shared.args.chat:
|
|
||||||
visible_reply = visible_reply.replace('\n', '<br>')
|
|
||||||
|
|
||||||
# We need this global variable to handle the Stop event,
|
# We need this global variable to handle the Stop event,
|
||||||
# otherwise gradio gets confused
|
# otherwise gradio gets confused
|
||||||
@@ -152,23 +167,27 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
|
|||||||
|
|
||||||
yield shared.history['visible']
|
yield shared.history['visible']
|
||||||
|
|
||||||
def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
|
def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||||
eos_token = '\n' if stop_at_newline else None
|
if mode == 'instruct':
|
||||||
|
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
||||||
|
else:
|
||||||
|
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
||||||
|
|
||||||
|
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
||||||
if 'pygmalion' in shared.model_name.lower():
|
if 'pygmalion' in shared.model_name.lower():
|
||||||
name1 = "You"
|
name1 = "You"
|
||||||
|
|
||||||
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
|
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], impersonate=True, end_of_turn=end_of_turn)
|
||||||
|
|
||||||
# Yield *Is typing...*
|
# Yield *Is typing...*
|
||||||
yield shared.processing_message
|
yield shared.processing_message
|
||||||
|
|
||||||
cumulative_reply = ''
|
cumulative_reply = ''
|
||||||
for i in range(chat_generation_attempts):
|
for i in range(generate_state['chat_generation_attempts']):
|
||||||
reply = None
|
reply = None
|
||||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
|
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||||
reply = cumulative_reply + reply
|
reply = cumulative_reply + reply
|
||||||
reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
|
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
|
||||||
yield reply
|
yield reply
|
||||||
if next_character_found:
|
if next_character_found:
|
||||||
break
|
break
|
||||||
@@ -178,36 +197,30 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
|
|||||||
|
|
||||||
yield reply
|
yield reply
|
||||||
|
|
||||||
def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
|
def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||||
for history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts):
|
for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||||
yield generate_chat_html(history, name1, name2, shared.character)
|
yield chat_html_wrapper(history, name1, name2, mode)
|
||||||
|
|
||||||
def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
|
def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||||
if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
|
if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
|
||||||
yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
|
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
else:
|
else:
|
||||||
last_visible = shared.history['visible'].pop()
|
last_visible = shared.history['visible'].pop()
|
||||||
last_internal = shared.history['internal'].pop()
|
last_internal = shared.history['internal'].pop()
|
||||||
# Yield '*Is typing...*'
|
# Yield '*Is typing...*'
|
||||||
yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character)
|
yield chat_html_wrapper(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, mode)
|
||||||
for history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=True):
|
for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True):
|
||||||
if shared.args.cai_chat:
|
shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
|
||||||
shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
|
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
else:
|
|
||||||
shared.history['visible'][-1] = (last_visible[0], history[-1][1])
|
|
||||||
yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
|
|
||||||
|
|
||||||
def remove_last_message(name1, name2):
|
def remove_last_message(name1, name2, mode):
|
||||||
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
||||||
last = shared.history['visible'].pop()
|
last = shared.history['visible'].pop()
|
||||||
shared.history['internal'].pop()
|
shared.history['internal'].pop()
|
||||||
else:
|
else:
|
||||||
last = ['', '']
|
last = ['', '']
|
||||||
|
|
||||||
if shared.args.cai_chat:
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0]
|
||||||
return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0]
|
|
||||||
else:
|
|
||||||
return shared.history['visible'], last[0]
|
|
||||||
|
|
||||||
def send_last_reply_to_input():
|
def send_last_reply_to_input():
|
||||||
if len(shared.history['internal']) > 0:
|
if len(shared.history['internal']) > 0:
|
||||||
@@ -215,20 +228,17 @@ def send_last_reply_to_input():
|
|||||||
else:
|
else:
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
def replace_last_reply(text, name1, name2):
|
def replace_last_reply(text, name1, name2, mode):
|
||||||
if len(shared.history['visible']) > 0:
|
if len(shared.history['visible']) > 0:
|
||||||
if shared.args.cai_chat:
|
shared.history['visible'][-1][1] = text
|
||||||
shared.history['visible'][-1][1] = text
|
|
||||||
else:
|
|
||||||
shared.history['visible'][-1] = (shared.history['visible'][-1][0], text)
|
|
||||||
shared.history['internal'][-1][1] = apply_extensions(text, "input")
|
shared.history['internal'][-1][1] = apply_extensions(text, "input")
|
||||||
|
|
||||||
return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
def clear_html():
|
def clear_html():
|
||||||
return generate_chat_html([], "", "", shared.character)
|
return chat_html_wrapper([], "", "")
|
||||||
|
|
||||||
def clear_chat_log(name1, name2, greeting):
|
def clear_chat_log(name1, name2, greeting, mode):
|
||||||
shared.history['visible'] = []
|
shared.history['visible'] = []
|
||||||
shared.history['internal'] = []
|
shared.history['internal'] = []
|
||||||
|
|
||||||
@@ -236,12 +246,12 @@ def clear_chat_log(name1, name2, greeting):
|
|||||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||||
|
|
||||||
return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
def redraw_html(name1, name2):
|
def redraw_html(name1, name2, mode):
|
||||||
return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
def tokenize_dialogue(dialogue, name1, name2):
|
def tokenize_dialogue(dialogue, name1, name2, mode):
|
||||||
history = []
|
history = []
|
||||||
|
|
||||||
dialogue = re.sub('<START>', '', dialogue)
|
dialogue = re.sub('<START>', '', dialogue)
|
||||||
@@ -326,15 +336,35 @@ def build_pygmalion_style_context(data):
|
|||||||
context = f"{context.strip()}\n<START>\n"
|
context = f"{context.strip()}\n<START>\n"
|
||||||
return context
|
return context
|
||||||
|
|
||||||
def load_character(character, name1, name2):
|
def generate_pfp_cache(character):
|
||||||
|
cache_folder = Path("cache")
|
||||||
|
if not cache_folder.exists():
|
||||||
|
cache_folder.mkdir()
|
||||||
|
|
||||||
|
for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
|
||||||
|
if path.exists():
|
||||||
|
img = make_thumbnail(Image.open(path))
|
||||||
|
img.save(Path('cache/pfp_character.png'), format='PNG')
|
||||||
|
return img
|
||||||
|
return None
|
||||||
|
|
||||||
|
def load_character(character, name1, name2, mode):
|
||||||
shared.character = character
|
shared.character = character
|
||||||
shared.history['internal'] = []
|
shared.history['internal'] = []
|
||||||
shared.history['visible'] = []
|
shared.history['visible'] = []
|
||||||
greeting = ""
|
context = greeting = end_of_turn = ""
|
||||||
|
greeting_field = 'greeting'
|
||||||
|
picture = None
|
||||||
|
|
||||||
|
# Deleting the profile picture cache, if any
|
||||||
|
if Path("cache/pfp_character.png").exists():
|
||||||
|
Path("cache/pfp_character.png").unlink()
|
||||||
|
|
||||||
if character != 'None':
|
if character != 'None':
|
||||||
|
folder = 'characters' if not mode == 'instruct' else 'characters/instruction-following'
|
||||||
|
picture = generate_pfp_cache(character)
|
||||||
for extension in ["yml", "yaml", "json"]:
|
for extension in ["yml", "yaml", "json"]:
|
||||||
filepath = Path(f'characters/{character}.{extension}')
|
filepath = Path(f'{folder}/{character}.{extension}')
|
||||||
if filepath.exists():
|
if filepath.exists():
|
||||||
break
|
break
|
||||||
file_contents = open(filepath, 'r', encoding='utf-8').read()
|
file_contents = open(filepath, 'r', encoding='utf-8').read()
|
||||||
@@ -350,19 +380,21 @@ def load_character(character, name1, name2):
|
|||||||
|
|
||||||
if 'context' in data:
|
if 'context' in data:
|
||||||
context = f"{data['context'].strip()}\n\n"
|
context = f"{data['context'].strip()}\n\n"
|
||||||
greeting_field = 'greeting'
|
elif "char_persona" in data:
|
||||||
else:
|
|
||||||
context = build_pygmalion_style_context(data)
|
context = build_pygmalion_style_context(data)
|
||||||
greeting_field = 'char_greeting'
|
greeting_field = 'char_greeting'
|
||||||
|
|
||||||
if 'example_dialogue' in data and data['example_dialogue'] != '':
|
if 'example_dialogue' in data:
|
||||||
context += f"{data['example_dialogue'].strip()}\n"
|
context += f"{data['example_dialogue'].strip()}\n"
|
||||||
if greeting_field in data and len(data[greeting_field].strip()) > 0:
|
if greeting_field in data:
|
||||||
greeting = data[greeting_field]
|
greeting = data[greeting_field]
|
||||||
|
if 'end_of_turn' in data:
|
||||||
|
end_of_turn = data['end_of_turn']
|
||||||
else:
|
else:
|
||||||
context = shared.settings['context']
|
context = shared.settings['context']
|
||||||
name2 = shared.settings['name2']
|
name2 = shared.settings['name2']
|
||||||
greeting = shared.settings['greeting']
|
greeting = shared.settings['greeting']
|
||||||
|
end_of_turn = shared.settings['end_of_turn']
|
||||||
|
|
||||||
if Path(f'logs/{shared.character}_persistent.json').exists():
|
if Path(f'logs/{shared.character}_persistent.json').exists():
|
||||||
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
|
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
|
||||||
@@ -370,13 +402,10 @@ def load_character(character, name1, name2):
|
|||||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||||
|
|
||||||
if shared.args.cai_chat:
|
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
|
||||||
return name1, name2, greeting, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character)
|
|
||||||
else:
|
|
||||||
return name1, name2, greeting, context, shared.history['visible']
|
|
||||||
|
|
||||||
def load_default_history(name1, name2):
|
def load_default_history(name1, name2):
|
||||||
load_character("None", name1, name2)
|
load_character("None", name1, name2, "chat")
|
||||||
|
|
||||||
def upload_character(json_file, img, tavern=False):
|
def upload_character(json_file, img, tavern=False):
|
||||||
json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
|
json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
|
||||||
@@ -404,7 +433,17 @@ def upload_tavern_character(img, name1, name2):
|
|||||||
_json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
|
_json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
|
||||||
return upload_character(json.dumps(_json), img, tavern=True)
|
return upload_character(json.dumps(_json), img, tavern=True)
|
||||||
|
|
||||||
def upload_your_profile_picture(img):
|
def upload_your_profile_picture(img, name1, name2, mode):
|
||||||
img = Image.open(io.BytesIO(img))
|
cache_folder = Path("cache")
|
||||||
img.save(Path('img_me.png'))
|
if not cache_folder.exists():
|
||||||
print('Profile picture saved to "img_me.png"')
|
cache_folder.mkdir()
|
||||||
|
|
||||||
|
if img == None:
|
||||||
|
if Path("cache/pfp_me.png").exists():
|
||||||
|
Path("cache/pfp_me.png").unlink()
|
||||||
|
else:
|
||||||
|
img = make_thumbnail(img)
|
||||||
|
img.save(Path('cache/pfp_me.png'))
|
||||||
|
print('Profile picture saved to "cache/pfp_me.png"')
|
||||||
|
|
||||||
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
|
||||||
|
|||||||
@@ -6,10 +6,11 @@ This is a library for formatting text outputs as nice HTML.
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import markdown
|
import markdown
|
||||||
from PIL import Image
|
from PIL import Image, ImageOps
|
||||||
|
|
||||||
# This is to store the paths to the thumbnails of the profile pictures
|
# This is to store the paths to the thumbnails of the profile pictures
|
||||||
image_cache = {}
|
image_cache = {}
|
||||||
@@ -20,6 +21,8 @@ with open(Path(__file__).resolve().parent / '../css/html_4chan_style.css', 'r')
|
|||||||
_4chan_css = css_f.read()
|
_4chan_css = css_f.read()
|
||||||
with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as f:
|
with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as f:
|
||||||
cai_css = f.read()
|
cai_css = f.read()
|
||||||
|
with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f:
|
||||||
|
instruct_css = f.read()
|
||||||
|
|
||||||
def fix_newlines(string):
|
def fix_newlines(string):
|
||||||
string = string.replace('\n', '\n\n')
|
string = string.replace('\n', '\n\n')
|
||||||
@@ -95,6 +98,13 @@ def generate_4chan_html(f):
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def make_thumbnail(image):
|
||||||
|
image = image.resize((350, round(image.size[1]/image.size[0]*350)), Image.Resampling.LANCZOS)
|
||||||
|
if image.size[1] > 470:
|
||||||
|
image = ImageOps.fit(image, (350, 470), Image.ANTIALIAS)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
def get_image_cache(path):
|
def get_image_cache(path):
|
||||||
cache_folder = Path("cache")
|
cache_folder = Path("cache")
|
||||||
if not cache_folder.exists():
|
if not cache_folder.exists():
|
||||||
@@ -102,26 +112,52 @@ def get_image_cache(path):
|
|||||||
|
|
||||||
mtime = os.stat(path).st_mtime
|
mtime = os.stat(path).st_mtime
|
||||||
if (path in image_cache and mtime != image_cache[path][0]) or (path not in image_cache):
|
if (path in image_cache and mtime != image_cache[path][0]) or (path not in image_cache):
|
||||||
img = Image.open(path)
|
img = make_thumbnail(Image.open(path))
|
||||||
img.thumbnail((200, 200))
|
|
||||||
output_file = Path(f'cache/{path.name}_cache.png')
|
output_file = Path(f'cache/{path.name}_cache.png')
|
||||||
img.convert('RGB').save(output_file, format='PNG')
|
img.convert('RGB').save(output_file, format='PNG')
|
||||||
image_cache[path] = [mtime, output_file.as_posix()]
|
image_cache[path] = [mtime, output_file.as_posix()]
|
||||||
|
|
||||||
return image_cache[path][1]
|
return image_cache[path][1]
|
||||||
|
|
||||||
def load_html_image(paths):
|
def generate_instruct_html(history):
|
||||||
for str_path in paths:
|
output = f'<style>{instruct_css}</style><div class="chat" id="chat">'
|
||||||
path = Path(str_path)
|
for i,_row in enumerate(history[::-1]):
|
||||||
if path.exists():
|
row = [convert_to_markdown(entry) for entry in _row]
|
||||||
return f'<img src="file/{get_image_cache(path)}">'
|
|
||||||
return ''
|
|
||||||
|
|
||||||
def generate_chat_html(history, name1, name2, character):
|
output += f"""
|
||||||
|
<div class="assistant-message">
|
||||||
|
<div class="text">
|
||||||
|
<div class="message-body">
|
||||||
|
{row[1]}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
"""
|
||||||
|
|
||||||
|
if len(row[0]) == 0: # don't display empty user messages
|
||||||
|
continue
|
||||||
|
|
||||||
|
output += f"""
|
||||||
|
<div class="user-message">
|
||||||
|
<div class="text">
|
||||||
|
<div class="message-body">
|
||||||
|
{row[0]}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
"""
|
||||||
|
|
||||||
|
output += "</div>"
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def generate_cai_chat_html(history, name1, name2, reset_cache=False):
|
||||||
output = f'<style>{cai_css}</style><div class="chat" id="chat">'
|
output = f'<style>{cai_css}</style><div class="chat" id="chat">'
|
||||||
|
|
||||||
img_bot = load_html_image([f"characters/{character}.{ext}" for ext in ['png', 'jpg', 'jpeg']] + ["img_bot.png","img_bot.jpg","img_bot.jpeg"])
|
# The time.time() is to prevent the brower from caching the image
|
||||||
img_me = load_html_image(["img_me.png", "img_me.jpg", "img_me.jpeg"])
|
suffix = f"?{time.time()}" if reset_cache else f"?{name2}"
|
||||||
|
img_bot = f'<img src="file/cache/pfp_character.png{suffix}">' if Path("cache/pfp_character.png").exists() else ''
|
||||||
|
img_me = f'<img src="file/cache/pfp_me.png{suffix}">' if Path("cache/pfp_me.png").exists() else ''
|
||||||
|
|
||||||
for i,_row in enumerate(history[::-1]):
|
for i,_row in enumerate(history[::-1]):
|
||||||
row = [convert_to_markdown(entry) for entry in _row]
|
row = [convert_to_markdown(entry) for entry in _row]
|
||||||
@@ -163,3 +199,16 @@ def generate_chat_html(history, name1, name2, character):
|
|||||||
|
|
||||||
output += "</div>"
|
output += "</div>"
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def generate_chat_html(history, name1, name2):
|
||||||
|
return generate_cai_chat_html(history, name1, name2)
|
||||||
|
|
||||||
|
def chat_html_wrapper(history, name1, name2, mode, reset_cache=False):
|
||||||
|
if mode == "cai-chat":
|
||||||
|
return generate_cai_chat_html(history, name1, name2, reset_cache)
|
||||||
|
elif mode == "chat":
|
||||||
|
return generate_chat_html(history, name1, name2)
|
||||||
|
elif mode == "instruct":
|
||||||
|
return generate_instruct_html(history)
|
||||||
|
else:
|
||||||
|
return ''
|
||||||
|
|||||||
65
modules/llamacpp_model_alternative.py
Normal file
65
modules/llamacpp_model_alternative.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
'''
|
||||||
|
Based on
|
||||||
|
https://github.com/abetlen/llama-cpp-python
|
||||||
|
|
||||||
|
Documentation:
|
||||||
|
https://abetlen.github.io/llama-cpp-python/
|
||||||
|
'''
|
||||||
|
|
||||||
|
import multiprocessing
|
||||||
|
|
||||||
|
from llama_cpp import Llama
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
from modules.callbacks import Iteratorize
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaCppModel:
|
||||||
|
def __init__(self):
|
||||||
|
self.initialized = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, path):
|
||||||
|
result = self()
|
||||||
|
|
||||||
|
params = {
|
||||||
|
'model_path': str(path),
|
||||||
|
'n_ctx': 2048,
|
||||||
|
'seed': 0,
|
||||||
|
'n_threads': shared.args.threads or None
|
||||||
|
}
|
||||||
|
self.model = Llama(**params)
|
||||||
|
|
||||||
|
# This is ugly, but the model and the tokenizer are the same object in this library.
|
||||||
|
return result, result
|
||||||
|
|
||||||
|
def encode(self, string):
|
||||||
|
if type(string) is str:
|
||||||
|
string = string.encode()
|
||||||
|
return self.model.tokenize(string)
|
||||||
|
|
||||||
|
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
|
||||||
|
if type(context) is str:
|
||||||
|
context = context.encode()
|
||||||
|
tokens = self.model.tokenize(context)
|
||||||
|
|
||||||
|
output = b""
|
||||||
|
count = 0
|
||||||
|
for token in self.model.generate(tokens, top_k=top_k, top_p=top_p, temp=temperature, repeat_penalty=repetition_penalty):
|
||||||
|
text = self.model.detokenize([token])
|
||||||
|
output += text
|
||||||
|
if callback:
|
||||||
|
callback(text.decode())
|
||||||
|
|
||||||
|
count += 1
|
||||||
|
if count >= token_count or (token == self.model.token_eos()):
|
||||||
|
break
|
||||||
|
|
||||||
|
return output.decode()
|
||||||
|
|
||||||
|
def generate_with_streaming(self, **kwargs):
|
||||||
|
with Iteratorize(self.generate, kwargs, callback=None) as generator:
|
||||||
|
reply = ''
|
||||||
|
for token in generator:
|
||||||
|
reply += token
|
||||||
|
yield reply
|
||||||
@@ -10,7 +10,7 @@ import torch
|
|||||||
import transformers
|
import transformers
|
||||||
from accelerate import infer_auto_device_map, init_empty_weights
|
from accelerate import infer_auto_device_map, init_empty_weights
|
||||||
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
|
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
|
||||||
BitsAndBytesConfig)
|
BitsAndBytesConfig, LlamaTokenizer)
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
|
||||||
@@ -42,7 +42,7 @@ def load_model(model_name):
|
|||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
shared.is_RWKV = 'rwkv-' in model_name.lower()
|
shared.is_RWKV = 'rwkv-' in model_name.lower()
|
||||||
shared.is_llamacpp = len(list(Path(f'models/{model_name}').glob('ggml*.bin'))) > 0
|
shared.is_llamacpp = len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))) > 0
|
||||||
|
|
||||||
# Default settings
|
# Default settings
|
||||||
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]):
|
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]):
|
||||||
@@ -103,9 +103,9 @@ def load_model(model_name):
|
|||||||
|
|
||||||
# llamacpp model
|
# llamacpp model
|
||||||
elif shared.is_llamacpp:
|
elif shared.is_llamacpp:
|
||||||
from modules.llamacpp_model import LlamaCppModel
|
from modules.llamacpp_model_alternative import LlamaCppModel
|
||||||
|
|
||||||
model_file = list(Path(f'models/{model_name}').glob('ggml*.bin'))[0]
|
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))[0]
|
||||||
print(f"llama.cpp weights detected: {model_file}\n")
|
print(f"llama.cpp weights detected: {model_file}\n")
|
||||||
|
|
||||||
model, tokenizer = LlamaCppModel.from_pretrained(model_file)
|
model, tokenizer = LlamaCppModel.from_pretrained(model_file)
|
||||||
@@ -172,6 +172,8 @@ def load_model(model_name):
|
|||||||
# Loading the tokenizer
|
# Loading the tokenizer
|
||||||
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
||||||
|
elif type(model) is transformers.LlamaForCausalLM:
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
|
||||||
else:
|
else:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
|
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
|
||||||
tokenizer.truncation_side = 'left'
|
tokenizer.truncation_side = 'left'
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ settings = {
|
|||||||
'name2': 'Assistant',
|
'name2': 'Assistant',
|
||||||
'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
|
'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
|
||||||
'greeting': 'Hello there!',
|
'greeting': 'Hello there!',
|
||||||
|
'end_of_turn': '',
|
||||||
'stop_at_newline': False,
|
'stop_at_newline': False,
|
||||||
'chat_prompt_size': 2048,
|
'chat_prompt_size': 2048,
|
||||||
'chat_prompt_size_min': 0,
|
'chat_prompt_size_min': 0,
|
||||||
@@ -44,6 +45,7 @@ settings = {
|
|||||||
'chat_default_extensions': ["gallery"],
|
'chat_default_extensions': ["gallery"],
|
||||||
'presets': {
|
'presets': {
|
||||||
'default': 'NovelAI-Sphinx Moth',
|
'default': 'NovelAI-Sphinx Moth',
|
||||||
|
'.*(alpaca|llama)': "LLaMA-Precise",
|
||||||
'.*pygmalion': 'NovelAI-Storywriter',
|
'.*pygmalion': 'NovelAI-Storywriter',
|
||||||
'.*RWKV': 'Naive',
|
'.*RWKV': 'Naive',
|
||||||
},
|
},
|
||||||
@@ -73,8 +75,8 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma
|
|||||||
|
|
||||||
# Basic settings
|
# Basic settings
|
||||||
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
|
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
|
||||||
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
|
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.')
|
||||||
parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
|
parser.add_argument('--cai-chat', action='store_true', help='DEPRECATED: use --chat instead.')
|
||||||
parser.add_argument('--model', type=str, help='Name of the model to load by default.')
|
parser.add_argument('--model', type=str, help='Name of the model to load by default.')
|
||||||
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
|
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
|
||||||
parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
|
parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
|
||||||
@@ -131,12 +133,17 @@ parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authent
|
|||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Provisional, this will be deleted later
|
# Deprecation warnings for parameters that have been renamed
|
||||||
deprecated_dict = {'gptq_bits': ['wbits', 0], 'gptq_model_type': ['model_type', None], 'gptq_pre_layer': ['prelayer', 0]}
|
deprecated_dict = {'gptq_bits': ['wbits', 0], 'gptq_model_type': ['model_type', None], 'gptq_pre_layer': ['prelayer', 0]}
|
||||||
for k in deprecated_dict:
|
for k in deprecated_dict:
|
||||||
if eval(f"args.{k}") != deprecated_dict[k][1]:
|
if eval(f"args.{k}") != deprecated_dict[k][1]:
|
||||||
print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
|
print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
|
||||||
exec(f"args.{deprecated_dict[k][0]} = args.{k}")
|
exec(f"args.{deprecated_dict[k][0]} = args.{k}")
|
||||||
|
|
||||||
|
# Deprecation warnings for parameters that have been removed
|
||||||
|
if args.cai_chat:
|
||||||
|
print("Warning: --cai-chat is deprecated. Use --chat instead.")
|
||||||
|
args.chat = True
|
||||||
|
|
||||||
def is_chat():
|
def is_chat():
|
||||||
return any((args.chat, args.cai_chat))
|
return args.chat
|
||||||
|
|||||||
@@ -28,6 +28,10 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
|||||||
return input_ids
|
return input_ids
|
||||||
else:
|
else:
|
||||||
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
|
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
|
||||||
|
|
||||||
|
if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
|
||||||
|
input_ids = input_ids[:,1:]
|
||||||
|
|
||||||
if shared.args.cpu:
|
if shared.args.cpu:
|
||||||
return input_ids
|
return input_ids
|
||||||
elif shared.args.flexgen:
|
elif shared.args.flexgen:
|
||||||
@@ -102,10 +106,11 @@ def set_manual_seed(seed):
|
|||||||
def stop_everything_event():
|
def stop_everything_event():
|
||||||
shared.stop_everything = True
|
shared.stop_everything = True
|
||||||
|
|
||||||
def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]):
|
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
|
||||||
clear_torch_cache()
|
clear_torch_cache()
|
||||||
set_manual_seed(seed)
|
set_manual_seed(generate_state['seed'])
|
||||||
shared.stop_everything = False
|
shared.stop_everything = False
|
||||||
|
generate_params = {}
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
original_question = question
|
original_question = question
|
||||||
@@ -117,9 +122,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
# These models are not part of Hugging Face, so we handle them
|
# These models are not part of Hugging Face, so we handle them
|
||||||
# separately and terminate the function call earlier
|
# separately and terminate the function call earlier
|
||||||
if any((shared.is_RWKV, shared.is_llamacpp)):
|
if any((shared.is_RWKV, shared.is_llamacpp)):
|
||||||
|
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
|
||||||
|
generate_params[k] = generate_state[k]
|
||||||
|
generate_params["token_count"] = generate_state["max_new_tokens"]
|
||||||
try:
|
try:
|
||||||
if shared.args.no_stream:
|
if shared.args.no_stream:
|
||||||
reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
|
reply = shared.model.generate(context=question, **generate_params)
|
||||||
output = original_question+reply
|
output = original_question+reply
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
reply = original_question + apply_extensions(reply, "output")
|
reply = original_question + apply_extensions(reply, "output")
|
||||||
@@ -130,7 +138,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
|
|
||||||
# RWKV has proper streaming, which is very nice.
|
# RWKV has proper streaming, which is very nice.
|
||||||
# No need to generate 8 tokens at a time.
|
# No need to generate 8 tokens at a time.
|
||||||
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty):
|
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
||||||
output = original_question+reply
|
output = original_question+reply
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
reply = original_question + apply_extensions(reply, "output")
|
reply = original_question + apply_extensions(reply, "output")
|
||||||
@@ -145,7 +153,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
|
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
|
||||||
return
|
return
|
||||||
|
|
||||||
input_ids = encode(question, max_new_tokens)
|
input_ids = encode(question, generate_state['max_new_tokens'])
|
||||||
original_input_ids = input_ids
|
original_input_ids = input_ids
|
||||||
output = input_ids[0]
|
output = input_ids[0]
|
||||||
|
|
||||||
@@ -158,33 +166,21 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings]
|
t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings]
|
||||||
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
|
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
|
||||||
|
|
||||||
generate_params = {}
|
generate_params["max_new_tokens"] = generate_state['max_new_tokens']
|
||||||
if not shared.args.flexgen:
|
if not shared.args.flexgen:
|
||||||
generate_params.update({
|
for k in ["do_sample", "temperature", "top_p", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "top_k", "min_length", "no_repeat_ngram_size", "num_beams", "penalty_alpha", "length_penalty", "early_stopping"]:
|
||||||
"max_new_tokens": max_new_tokens,
|
generate_params[k] = generate_state[k]
|
||||||
"eos_token_id": eos_token_ids,
|
generate_params["eos_token_id"] = eos_token_ids
|
||||||
"stopping_criteria": stopping_criteria_list,
|
generate_params["stopping_criteria"] = stopping_criteria_list
|
||||||
"do_sample": do_sample,
|
if shared.args.no_stream:
|
||||||
"temperature": temperature,
|
generate_params["min_length"] = 0
|
||||||
"top_p": top_p,
|
|
||||||
"typical_p": typical_p,
|
|
||||||
"repetition_penalty": repetition_penalty,
|
|
||||||
"encoder_repetition_penalty": encoder_repetition_penalty,
|
|
||||||
"top_k": top_k,
|
|
||||||
"min_length": min_length if shared.args.no_stream else 0,
|
|
||||||
"no_repeat_ngram_size": no_repeat_ngram_size,
|
|
||||||
"num_beams": num_beams,
|
|
||||||
"penalty_alpha": penalty_alpha,
|
|
||||||
"length_penalty": length_penalty,
|
|
||||||
"early_stopping": early_stopping,
|
|
||||||
})
|
|
||||||
else:
|
else:
|
||||||
generate_params.update({
|
for k in ["do_sample", "temperature"]:
|
||||||
"max_new_tokens": max_new_tokens if shared.args.no_stream else 8,
|
generate_params[k] = generate_state[k]
|
||||||
"do_sample": do_sample,
|
generate_params["stop"] = generate_state["eos_token_ids"][-1]
|
||||||
"temperature": temperature,
|
if not shared.args.no_stream:
|
||||||
"stop": eos_token_ids[-1],
|
generate_params["max_new_tokens"] = 8
|
||||||
})
|
|
||||||
if shared.args.no_cache:
|
if shared.args.no_cache:
|
||||||
generate_params.update({"use_cache": False})
|
generate_params.update({"use_cache": False})
|
||||||
if shared.args.deepspeed:
|
if shared.args.deepspeed:
|
||||||
@@ -244,7 +240,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
|
|
||||||
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||||
else:
|
else:
|
||||||
for i in range(max_new_tokens//8+1):
|
for i in range(generate_state['max_new_tokens']//8+1):
|
||||||
clear_torch_cache()
|
clear_torch_cache()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = shared.model.generate(**generate_params)[0]
|
output = shared.model.generate(**generate_params)[0]
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ MAX_STEPS = 0
|
|||||||
CURRENT_GRADIENT_ACCUM = 1
|
CURRENT_GRADIENT_ACCUM = 1
|
||||||
|
|
||||||
def get_dataset(path: str, ext: str):
|
def get_dataset(path: str, ext: str):
|
||||||
return ['None'] + sorted(set((k.stem for k in Path(path).glob(f'*.{ext}'))), key=str.lower)
|
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
|
||||||
|
|
||||||
def create_train_interface():
|
def create_train_interface():
|
||||||
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
||||||
@@ -45,22 +45,26 @@ def create_train_interface():
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
|
dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
|
||||||
ui.create_refresh_button(dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
ui.create_refresh_button(dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
||||||
eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The dataset file used to evaluate the model after training.')
|
eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.')
|
||||||
ui.create_refresh_button(eval_dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
ui.create_refresh_button(eval_dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
||||||
format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
|
format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
|
||||||
ui.create_refresh_button(format, lambda : None, lambda : {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
|
ui.create_refresh_button(format, lambda : None, lambda : {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
|
||||||
|
|
||||||
with gr.Tab(label="Raw Text File"):
|
with gr.Tab(label="Raw Text File"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
|
raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
|
||||||
ui.create_refresh_button(raw_text_file, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
|
ui.create_refresh_button(raw_text_file, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
|
||||||
overlap_len = gr.Slider(label='Overlap Length', minimum=0,maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length above). Setting overlap to exactly half the cutoff length may be ideal.')
|
with gr.Row():
|
||||||
|
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
|
||||||
|
newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
start_button = gr.Button("Start LoRA Training")
|
start_button = gr.Button("Start LoRA Training")
|
||||||
stop_button = gr.Button("Interrupt")
|
stop_button = gr.Button("Interrupt")
|
||||||
|
|
||||||
output = gr.Markdown(value="Ready")
|
output = gr.Markdown(value="Ready")
|
||||||
start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len], [output])
|
start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout,
|
||||||
|
cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output])
|
||||||
stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
|
stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
|
||||||
|
|
||||||
def do_interrupt():
|
def do_interrupt():
|
||||||
@@ -91,8 +95,8 @@ def clean_path(base_path: str, path: str):
|
|||||||
return path
|
return path
|
||||||
return f'{Path(base_path).absolute()}/{path}'
|
return f'{Path(base_path).absolute()}/{path}'
|
||||||
|
|
||||||
def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int,
|
def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float,
|
||||||
lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int):
|
cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int):
|
||||||
global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
|
global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
|
||||||
WANT_INTERRUPT = False
|
WANT_INTERRUPT = False
|
||||||
CURRENT_STEPS = 0
|
CURRENT_STEPS = 0
|
||||||
@@ -103,6 +107,25 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||||||
lora_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}"
|
lora_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}"
|
||||||
actual_lr = float(learning_rate)
|
actual_lr = float(learning_rate)
|
||||||
|
|
||||||
|
model_type = type(shared.model).__name__
|
||||||
|
if model_type != "LlamaForCausalLM":
|
||||||
|
if model_type == "PeftModelForCausalLM":
|
||||||
|
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||||
|
print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.")
|
||||||
|
else:
|
||||||
|
yield "LoRA training has only currently been validated for LLaMA models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||||
|
print(f"Warning: LoRA training has only currently been validated for LLaMA models. (Found model type: {model_type})")
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
if shared.args.wbits > 0 or shared.args.gptq_bits > 0:
|
||||||
|
yield "LoRA training does not yet support 4bit. Please use `--load-in-8bit` for now."
|
||||||
|
return
|
||||||
|
|
||||||
|
elif not shared.args.load_in_8bit:
|
||||||
|
yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
|
||||||
|
print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.")
|
||||||
|
time.sleep(2) # Give it a moment for the message to show in UI before continuing
|
||||||
|
|
||||||
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
|
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
|
||||||
yield "Cannot input zeroes."
|
yield "Cannot input zeroes."
|
||||||
return
|
return
|
||||||
@@ -126,15 +149,20 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||||||
raw_text = file.read()
|
raw_text = file.read()
|
||||||
tokens = shared.tokenizer.encode(raw_text)
|
tokens = shared.tokenizer.encode(raw_text)
|
||||||
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
|
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
|
||||||
|
|
||||||
tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
|
tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
|
||||||
for i in range(1, len(tokens)):
|
for i in range(1, len(tokens)):
|
||||||
tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
|
tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
|
||||||
text_chunks = [shared.tokenizer.decode(x) for x in tokens]
|
text_chunks = [shared.tokenizer.decode(x) for x in tokens]
|
||||||
del tokens
|
del tokens
|
||||||
data = Dataset.from_list([tokenize(x) for x in text_chunks])
|
|
||||||
train_data = data.shuffle()
|
if newline_favor_len > 0:
|
||||||
eval_data = None
|
text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]
|
||||||
|
|
||||||
|
train_data = Dataset.from_list([tokenize(x) for x in text_chunks])
|
||||||
del text_chunks
|
del text_chunks
|
||||||
|
train_data = train_data.shuffle()
|
||||||
|
eval_data = None
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if dataset in ['None', '']:
|
if dataset in ['None', '']:
|
||||||
@@ -232,33 +260,37 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||||||
# TODO: save/load checkpoints to resume from?
|
# TODO: save/load checkpoints to resume from?
|
||||||
print("Starting training...")
|
print("Starting training...")
|
||||||
yield "Starting..."
|
yield "Starting..."
|
||||||
|
if WANT_INTERRUPT:
|
||||||
|
yield "Interrupted before start."
|
||||||
|
return
|
||||||
|
|
||||||
def threadedRun():
|
def threaded_run():
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
thread = threading.Thread(target=threadedRun)
|
thread = threading.Thread(target=threaded_run)
|
||||||
thread.start()
|
thread.start()
|
||||||
lastStep = 0
|
last_step = 0
|
||||||
startTime = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
while thread.is_alive():
|
while thread.is_alive():
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
if WANT_INTERRUPT:
|
if WANT_INTERRUPT:
|
||||||
yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
|
yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
|
||||||
elif CURRENT_STEPS != lastStep:
|
|
||||||
lastStep = CURRENT_STEPS
|
elif CURRENT_STEPS != last_step:
|
||||||
timeElapsed = time.perf_counter() - startTime
|
last_step = CURRENT_STEPS
|
||||||
if timeElapsed <= 0:
|
time_elapsed = time.perf_counter() - start_time
|
||||||
timerInfo = ""
|
if time_elapsed <= 0:
|
||||||
totalTimeEstimate = 999
|
timer_info = ""
|
||||||
|
total_time_estimate = 999
|
||||||
else:
|
else:
|
||||||
its = CURRENT_STEPS / timeElapsed
|
its = CURRENT_STEPS / time_elapsed
|
||||||
if its > 1:
|
if its > 1:
|
||||||
timerInfo = f"`{its:.2f}` it/s"
|
timer_info = f"`{its:.2f}` it/s"
|
||||||
else:
|
else:
|
||||||
timerInfo = f"`{1.0/its:.2f}` s/it"
|
timer_info = f"`{1.0/its:.2f}` s/it"
|
||||||
totalTimeEstimate = (1.0/its) * (MAX_STEPS)
|
total_time_estimate = (1.0/its) * (MAX_STEPS)
|
||||||
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds"
|
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
|
||||||
|
|
||||||
print("Training complete, saving...")
|
print("Training complete, saving...")
|
||||||
lora_model.save_pretrained(lora_name)
|
lora_model.save_pretrained(lora_name)
|
||||||
@@ -273,3 +305,25 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||||||
def split_chunks(arr, step):
|
def split_chunks(arr, step):
|
||||||
for i in range(0, len(arr), step):
|
for i in range(0, len(arr), step):
|
||||||
yield arr[i:i + step]
|
yield arr[i:i + step]
|
||||||
|
|
||||||
|
def cut_chunk_for_newline(chunk: str, max_length: int):
|
||||||
|
if '\n' not in chunk:
|
||||||
|
return chunk
|
||||||
|
first_newline = chunk.index('\n')
|
||||||
|
if first_newline < max_length:
|
||||||
|
chunk = chunk[first_newline + 1:]
|
||||||
|
if '\n' not in chunk:
|
||||||
|
return chunk
|
||||||
|
last_newline = chunk.rindex('\n')
|
||||||
|
if len(chunk) - last_newline < max_length:
|
||||||
|
chunk = chunk[:last_newline]
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
def format_time(seconds: float):
|
||||||
|
if seconds < 120:
|
||||||
|
return f"`{seconds:.0f}` seconds"
|
||||||
|
minutes = seconds / 60
|
||||||
|
if minutes < 120:
|
||||||
|
return f"`{minutes:.0f}` minutes"
|
||||||
|
hours = minutes / 60
|
||||||
|
return f"`{hours:.0f}` hours"
|
||||||
|
|||||||
6
presets/LLaMA-Precise.txt
Normal file
6
presets/LLaMA-Precise.txt
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
do_sample=True
|
||||||
|
top_p=0.1
|
||||||
|
top_k=40
|
||||||
|
temperature=0.7
|
||||||
|
repetition_penalty=1.18
|
||||||
|
typical_p=1.0
|
||||||
@@ -3,12 +3,11 @@ bitsandbytes==0.37.2
|
|||||||
datasets
|
datasets
|
||||||
flexgen==0.1.7
|
flexgen==0.1.7
|
||||||
gradio==3.24.1
|
gradio==3.24.1
|
||||||
llamacpp==0.1.11
|
|
||||||
markdown
|
markdown
|
||||||
numpy
|
numpy
|
||||||
peft==0.2.0
|
peft==0.2.0
|
||||||
requests
|
requests
|
||||||
rwkv==0.7.2
|
rwkv==0.7.3
|
||||||
safetensors==0.3.0
|
safetensors==0.3.0
|
||||||
sentencepiece
|
sentencepiece
|
||||||
pyyaml
|
pyyaml
|
||||||
|
|||||||
156
server.py
156
server.py
@@ -1,3 +1,7 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
@@ -8,10 +12,11 @@ from datetime import datetime
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
import modules.extensions as extensions_module
|
import modules.extensions as extensions_module
|
||||||
from modules import chat, shared, training, ui
|
from modules import chat, shared, training, ui, api
|
||||||
from modules.html_generator import generate_chat_html
|
from modules.html_generator import chat_html_wrapper
|
||||||
from modules.LoRA import add_lora_to_model
|
from modules.LoRA import add_lora_to_model
|
||||||
from modules.models import load_model, load_soft_prompt
|
from modules.models import load_model, load_soft_prompt
|
||||||
from modules.text_generation import (clear_torch_cache, generate_reply,
|
from modules.text_generation import (clear_torch_cache, generate_reply,
|
||||||
@@ -47,6 +52,13 @@ def get_available_prompts():
|
|||||||
|
|
||||||
def get_available_characters():
|
def get_available_characters():
|
||||||
paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
|
paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
|
||||||
|
return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)
|
||||||
|
|
||||||
|
def get_available_instruction_templates():
|
||||||
|
path = "characters/instruction-following"
|
||||||
|
paths = []
|
||||||
|
if os.path.exists(path):
|
||||||
|
paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
|
||||||
return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower)
|
return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower)
|
||||||
|
|
||||||
def get_available_extensions():
|
def get_available_extensions():
|
||||||
@@ -76,7 +88,7 @@ def load_lora_wrapper(selected_lora):
|
|||||||
add_lora_to_model(selected_lora)
|
add_lora_to_model(selected_lora)
|
||||||
return selected_lora
|
return selected_lora
|
||||||
|
|
||||||
def load_preset_values(preset_menu, return_dict=False):
|
def load_preset_values(preset_menu, state, return_dict=False):
|
||||||
generate_params = {
|
generate_params = {
|
||||||
'do_sample': True,
|
'do_sample': True,
|
||||||
'temperature': 1,
|
'temperature': 1,
|
||||||
@@ -98,13 +110,13 @@ def load_preset_values(preset_menu, return_dict=False):
|
|||||||
i = i.rstrip(',').strip().split('=')
|
i = i.rstrip(',').strip().split('=')
|
||||||
if len(i) == 2 and i[0].strip() != 'tokens':
|
if len(i) == 2 and i[0].strip() != 'tokens':
|
||||||
generate_params[i[0].strip()] = eval(i[1].strip())
|
generate_params[i[0].strip()] = eval(i[1].strip())
|
||||||
|
|
||||||
generate_params['temperature'] = min(1.99, generate_params['temperature'])
|
generate_params['temperature'] = min(1.99, generate_params['temperature'])
|
||||||
|
|
||||||
if return_dict:
|
if return_dict:
|
||||||
return generate_params
|
return generate_params
|
||||||
else:
|
else:
|
||||||
return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
|
state.update(generate_params)
|
||||||
|
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
|
||||||
|
|
||||||
def upload_soft_prompt(file):
|
def upload_soft_prompt(file):
|
||||||
with zipfile.ZipFile(io.BytesIO(file)) as zf:
|
with zipfile.ZipFile(io.BytesIO(file)) as zf:
|
||||||
@@ -118,19 +130,8 @@ def upload_soft_prompt(file):
|
|||||||
|
|
||||||
return name
|
return name
|
||||||
|
|
||||||
def create_model_and_preset_menus():
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column():
|
|
||||||
with gr.Row():
|
|
||||||
shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
|
|
||||||
ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button')
|
|
||||||
with gr.Column():
|
|
||||||
with gr.Row():
|
|
||||||
shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
|
|
||||||
ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
|
|
||||||
|
|
||||||
def save_prompt(text):
|
def save_prompt(text):
|
||||||
fname = f"{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.txt"
|
fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt"
|
||||||
with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
|
with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
|
||||||
f.write(text)
|
f.write(text)
|
||||||
return f"Saved to prompts/{fname}"
|
return f"Saved to prompts/{fname}"
|
||||||
@@ -160,12 +161,31 @@ def create_prompt_menus():
|
|||||||
shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
|
shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
|
||||||
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
|
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
|
||||||
|
|
||||||
|
def create_model_menus():
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
|
||||||
|
ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button')
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
|
||||||
|
ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
|
||||||
|
|
||||||
|
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
|
||||||
|
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
|
||||||
|
|
||||||
def create_settings_menus(default_preset):
|
def create_settings_menus(default_preset):
|
||||||
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
|
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
|
||||||
|
for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
|
||||||
|
generate_params[k] = shared.settings[k]
|
||||||
|
shared.gradio['generate_state'] = gr.State(generate_params)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
create_model_and_preset_menus()
|
with gr.Row():
|
||||||
|
shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
|
||||||
|
ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
|
shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
|
||||||
|
|
||||||
@@ -199,9 +219,6 @@ def create_settings_menus(default_preset):
|
|||||||
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
||||||
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
||||||
|
|
||||||
with gr.Row():
|
|
||||||
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
|
|
||||||
ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
|
|
||||||
|
|
||||||
with gr.Accordion('Soft prompt', open=False):
|
with gr.Accordion('Soft prompt', open=False):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@@ -212,17 +229,14 @@ def create_settings_menus(default_preset):
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
|
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
|
||||||
|
|
||||||
shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
|
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'generate_state']], [shared.gradio[k] for k in ['generate_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
|
||||||
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
|
shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
|
||||||
shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu']], show_progress=True)
|
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
|
||||||
shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
|
|
||||||
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
|
|
||||||
|
|
||||||
def set_interface_arguments(interface_mode, extensions, bool_active):
|
def set_interface_arguments(interface_mode, extensions, bool_active):
|
||||||
modes = ["default", "notebook", "chat", "cai_chat"]
|
modes = ["default", "notebook", "chat", "cai_chat"]
|
||||||
cmd_list = vars(shared.args)
|
cmd_list = vars(shared.args)
|
||||||
bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
|
bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
|
||||||
#int_list = [k for k in cmd_list if type(k) is int]
|
|
||||||
|
|
||||||
shared.args.extensions = extensions
|
shared.args.extensions = extensions
|
||||||
for k in modes[1:]:
|
for k in modes[1:]:
|
||||||
@@ -295,10 +309,7 @@ def create_interface():
|
|||||||
if shared.is_chat():
|
if shared.is_chat():
|
||||||
shared.gradio['Chat input'] = gr.State()
|
shared.gradio['Chat input'] = gr.State()
|
||||||
with gr.Tab("Text generation", elem_id="main"):
|
with gr.Tab("Text generation", elem_id="main"):
|
||||||
if shared.args.cai_chat:
|
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
|
||||||
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], shared.character))
|
|
||||||
else:
|
|
||||||
shared.gradio['display'] = gr.Chatbot(value=shared.history['visible'], elem_id="gradio-chatbot")
|
|
||||||
shared.gradio['textbox'] = gr.Textbox(label='Input')
|
shared.gradio['textbox'] = gr.Textbox(label='Input')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['Generate'] = gr.Button('Generate')
|
shared.gradio['Generate'] = gr.Button('Generate')
|
||||||
@@ -315,11 +326,20 @@ def create_interface():
|
|||||||
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
|
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
|
||||||
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
|
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
|
||||||
|
|
||||||
|
shared.gradio["Chat mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
|
||||||
|
shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False)
|
||||||
|
|
||||||
with gr.Tab("Character", elem_id="chat-settings"):
|
with gr.Tab("Character", elem_id="chat-settings"):
|
||||||
shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
|
with gr.Row():
|
||||||
shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
|
with gr.Column(scale=8):
|
||||||
shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=2, label='Greeting')
|
shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
|
||||||
shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=8, label='Context')
|
shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
|
||||||
|
shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting')
|
||||||
|
shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context')
|
||||||
|
shared.gradio['end_of_turn'] = gr.Textbox(value=shared.settings["end_of_turn"], lines=1, label='End of turn string')
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
shared.gradio['character_picture'] = gr.Image(label='Character picture', type="pil")
|
||||||
|
shared.gradio['your_picture'] = gr.Image(label='Your picture', type="pil", value=Image.open(Path("cache/pfp_me.png")) if Path("cache/pfp_me.png").exists() else None)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
|
shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
|
||||||
ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
|
ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
|
||||||
@@ -347,8 +367,6 @@ def create_interface():
|
|||||||
|
|
||||||
gr.Markdown("# TavernAI PNG format")
|
gr.Markdown("# TavernAI PNG format")
|
||||||
shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
|
shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
|
||||||
with gr.Tab('Upload your profile picture'):
|
|
||||||
shared.gradio['upload_img_me'] = gr.File(type='binary', file_types=['image'])
|
|
||||||
|
|
||||||
with gr.Tab("Parameters", elem_id="parameters"):
|
with gr.Tab("Parameters", elem_id="parameters"):
|
||||||
with gr.Box():
|
with gr.Box():
|
||||||
@@ -359,35 +377,35 @@ def create_interface():
|
|||||||
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
|
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
|
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
|
||||||
shared.gradio['check'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
|
shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
|
||||||
|
|
||||||
create_settings_menus(default_preset)
|
create_settings_menus(default_preset)
|
||||||
|
|
||||||
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
|
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']]
|
||||||
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
|
|
||||||
|
|
||||||
def set_chat_input(textbox):
|
def set_chat_input(textbox):
|
||||||
return textbox, ""
|
return textbox, ""
|
||||||
|
|
||||||
gen_events.append(shared.gradio['Generate'].click(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
|
gen_events.append(shared.gradio['Generate'].click(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
|
||||||
gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['Generate'].click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||||
gen_events.append(shared.gradio['textbox'].submit(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
|
gen_events.append(shared.gradio['textbox'].submit(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
|
||||||
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['textbox'].submit(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||||
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||||
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
|
||||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||||
|
|
||||||
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||||
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
|
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||||
|
|
||||||
# Clear history with confirmation
|
# Clear history with confirmation
|
||||||
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
|
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
|
||||||
shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
||||||
shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||||
shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2'], shared.gradio['greeting']], shared.gradio['display'])
|
shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display'])
|
||||||
shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||||
|
shared.gradio['Chat mode'].change(lambda x : gr.update(visible= x=='instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates'])
|
||||||
|
|
||||||
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
|
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
|
||||||
shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
|
shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
|
||||||
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
|
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
|
||||||
|
|
||||||
@@ -399,20 +417,21 @@ def create_interface():
|
|||||||
shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
|
shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
|
||||||
shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||||
|
|
||||||
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'context', 'display']])
|
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
||||||
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], [])
|
shared.gradio['Instruction templates'].change(lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
||||||
|
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], [])
|
||||||
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
|
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
|
||||||
shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], [])
|
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display'])
|
||||||
|
|
||||||
reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible']
|
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']]
|
||||||
reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
|
shared.gradio['upload_chat_history'].upload(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
||||||
shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
|
shared.gradio['Stop'].click(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
||||||
shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
|
shared.gradio['Instruction templates'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
||||||
shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
|
shared.gradio['Chat mode'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
||||||
|
|
||||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
|
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
|
||||||
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
|
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
|
||||||
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
|
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True)
|
||||||
|
|
||||||
elif shared.args.notebook:
|
elif shared.args.notebook:
|
||||||
with gr.Tab("Text generation", elem_id="main"):
|
with gr.Tab("Text generation", elem_id="main"):
|
||||||
@@ -442,9 +461,9 @@ def create_interface():
|
|||||||
with gr.Tab("Parameters", elem_id="parameters"):
|
with gr.Tab("Parameters", elem_id="parameters"):
|
||||||
create_settings_menus(default_preset)
|
create_settings_menus(default_preset)
|
||||||
|
|
||||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']]
|
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
|
||||||
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
||||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
|
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||||
@@ -475,14 +494,17 @@ def create_interface():
|
|||||||
with gr.Tab("Parameters", elem_id="parameters"):
|
with gr.Tab("Parameters", elem_id="parameters"):
|
||||||
create_settings_menus(default_preset)
|
create_settings_menus(default_preset)
|
||||||
|
|
||||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']]
|
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
|
||||||
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
|
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
|
||||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
|
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||||
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
|
||||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||||
|
|
||||||
|
with gr.Tab("Model", elem_id="model-tab"):
|
||||||
|
create_model_menus()
|
||||||
|
|
||||||
with gr.Tab("Training", elem_id="training-tab"):
|
with gr.Tab("Training", elem_id="training-tab"):
|
||||||
training.create_train_interface()
|
training.create_train_interface()
|
||||||
|
|
||||||
@@ -496,7 +518,6 @@ def create_interface():
|
|||||||
cmd_list = vars(shared.args)
|
cmd_list = vars(shared.args)
|
||||||
bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
|
bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
|
||||||
bool_active = [k for k in bool_list if vars(shared.args)[k]]
|
bool_active = [k for k in bool_list if vars(shared.args)[k]]
|
||||||
#int_list = [k for k in cmd_list if type(k) is int]
|
|
||||||
|
|
||||||
gr.Markdown("*Experimental*")
|
gr.Markdown("*Experimental*")
|
||||||
shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
|
shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
|
||||||
@@ -510,6 +531,21 @@ def create_interface():
|
|||||||
if shared.args.extensions is not None:
|
if shared.args.extensions is not None:
|
||||||
extensions_module.create_extensions_block()
|
extensions_module.create_extensions_block()
|
||||||
|
|
||||||
|
def change_dict_value(d, key, value):
|
||||||
|
d[key] = value
|
||||||
|
return d
|
||||||
|
|
||||||
|
for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
|
||||||
|
if k not in shared.gradio:
|
||||||
|
continue
|
||||||
|
if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]:
|
||||||
|
shared.gradio[k].change(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
|
||||||
|
else:
|
||||||
|
shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
|
||||||
|
|
||||||
|
if not shared.is_chat():
|
||||||
|
api.create_apis()
|
||||||
|
|
||||||
# Authentication
|
# Authentication
|
||||||
auth = None
|
auth = None
|
||||||
if shared.args.gradio_auth_path is not None:
|
if shared.args.gradio_auth_path is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user