
Make model loading more transparent

This commit is contained in:
oobabooga
2023-01-06 01:41:52 -03:00
parent c65bad40dc
commit 285032da36
2 changed files with 9 additions and 11 deletions


@@ -46,15 +46,11 @@ The files that you need to download and put under `models/model-name` (for insta
 ## Converting to pytorch
-This webui allows you to switch between different models on the fly, so it must be fast to load the models from disk.
-One way to make this process about 10x faster is to convert the models to pytorch format using the script `convert-to-torch.py`. Create a folder called `torch-dumps` and then make the conversion with:
+The script `convert-to-torch.py` allows you to convert models to .pt format, which is about 10x faster to load:
     python convert-to-torch.py models/model-name/
-The output model will be saved to `torch-dumps/model-name.pt`. This is the default way to load all models except for `gpt-neox-20b`, `opt-13b`, `OPT-13B-Erebus`, `gpt-j-6B`, and `flan-t5`. I don't remember why these models are exceptions.
-If I get enough ⭐s on this repository, I will make the process of loading models saner and more customizable.
+The output model will be saved to `torch-dumps/model-name.pt`. When you load a new model from the webui, it will first look for this .pt file; if it is not found, it will load the model as usual from `models/model-name/`.
 ## Starting the webui


@@ -1,3 +1,4 @@
+import os
 import re
 import time
 import glob
@@ -20,17 +21,18 @@ model_name = 'galactica-6.7b'
 settings_name = "Default"
 def load_model(model_name):
-    print(f"Loading {model_name}")
+    print(f"Loading {model_name}...")
     t0 = time.time()
-    if model_name in ['gpt-neox-20b', 'opt-13b', 'OPT-13B-Erebus']:
+    if os.path.exists(f"torch-dumps/{model_name}.pt"):
+        print("Loading in .pt format...")
+        model = torch.load(f"torch-dumps/{model_name}.pt").cuda()
+    elif model_name in ['gpt-neox-20b', 'opt-13b', 'OPT-13B-Erebus']:
         model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", device_map='auto', load_in_8bit=True)
     elif model_name in ['gpt-j-6B']:
         model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
     elif model_name in ['flan-t5']:
         model = T5ForConditionalGeneration.from_pretrained(f"models/{model_name}").cuda()
     else:
         model = torch.load(f"torch-dumps/{model_name}.pt").cuda()
     if model_name in ['gpt4chan_model_float16']:
         tokenizer = AutoTokenizer.from_pretrained("models/gpt-j-6B/")
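
The lookup order described in the README change (try `torch-dumps/model-name.pt` first, otherwise fall back to `models/model-name/`) can be sketched in isolation. This is a minimal illustration, not code from the commit: the helper name `resolve_model_source` and its `dumps_dir` / `models_dir` parameters are hypothetical, and it only decides which path to use without loading anything.

```python
from pathlib import Path


def resolve_model_source(model_name, dumps_dir="torch-dumps", models_dir="models"):
    """Decide where a model should be loaded from.

    Hypothetical helper: returns ("pt", path) when a converted .pt dump
    exists, otherwise ("hf", path) pointing at the model folder.
    """
    pt_file = Path(dumps_dir) / f"{model_name}.pt"
    if pt_file.exists():
        # A converted dump is present, so prefer the faster .pt route.
        return "pt", pt_file
    # No dump found: fall back to the regular model directory.
    return "hf", Path(models_dir) / model_name


if __name__ == "__main__":
    kind, path = resolve_model_source("galactica-6.7b")
    print(f"Would load {path} ({kind})")
```

Separating the path decision from the actual loading call keeps the fallback rule easy to test without a GPU or any model files on disk.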