@@ -10,7 +10,7 @@ import torch
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
 from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
-                          BitsAndBytesConfig)
+                          BitsAndBytesConfig, LlamaTokenizer)
 
 import modules.shared as shared
 
@@ -172,6 +172,8 @@ def load_model(model_name):
     # Loading the tokenizer
     if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
+    elif type(model) is transformers.LlamaForCausalLM:
+        tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
     else:
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
     tokenizer.truncation_side = 'left'
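
For readability, a sketch of the tokenizer-selection logic as it stands after this change, pulled out into a standalone helper. load_tokenizer is a hypothetical name used only for illustration; in the repo these branches live inline in load_model(), and `model` is the model object loaded earlier in that function.

from pathlib import Path

import transformers
from transformers import AutoTokenizer, LlamaTokenizer

import modules.shared as shared


def load_tokenizer(model):
    # gpt4chan checkpoints reuse the local GPT-J-6B tokenizer when present.
    if any(k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan']) \
            and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
        tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
    # New branch: LLaMA models get the dedicated LlamaTokenizer, selected by
    # the class of the already-loaded model rather than by matching its name.
    elif type(model) is transformers.LlamaForCausalLM:
        tokenizer = LlamaTokenizer.from_pretrained(
            Path(f"{shared.args.model_dir}/{shared.model_name}/"),
            clean_up_tokenization_spaces=True)
    # Everything else keeps the generic AutoTokenizer path.
    else:
        tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
    tokenizer.truncation_side = 'left'
    return tokenizer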