
Bump transformers (16-bit llama must be reconverted/redownloaded)

oobabooga 2 years ago
parent
commit
113f94b61e
3 changed files with 8 additions and 2 deletions
1. modules/models.py  +3 -1
2. modules/text_generation.py  +4 -0
3. requirements.txt  +1 -1

modules/models.py  +3 -1

@@ -10,7 +10,7 @@ import torch
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
 from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
-                          BitsAndBytesConfig)
+                          BitsAndBytesConfig, LlamaTokenizer)
 
 import modules.shared as shared
 
@@ -172,6 +172,8 @@ def load_model(model_name):
     # Loading the tokenizer
     if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
+    elif type(model) is transformers.LlamaForCausalLM:
+        tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
     else:
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
     tokenizer.truncation_side = 'left'
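
This hunk special-cases LLaMA when loading the tokenizer: if the already-loaded model is a transformers.LlamaForCausalLM, the dedicated LlamaTokenizer is used instead of AutoTokenizer. In isolation, the dispatch looks roughly like the sketch below (a minimal standalone version; the load_tokenizer name and the model_dir/model_name parameters are illustrative, not from the repository, and isinstance is used here where the diff uses an exact type check):

    from pathlib import Path

    import transformers
    from transformers import AutoTokenizer, LlamaTokenizer

    def load_tokenizer(model, model_dir, model_name):
        # LLaMA checkpoints get the dedicated slow tokenizer; everything else
        # falls back to AutoTokenizer. clean_up_tokenization_spaces=True strips
        # tokenization artifacts (e.g. spaces before punctuation) when decoding.
        if isinstance(model, transformers.LlamaForCausalLM):
            return LlamaTokenizer.from_pretrained(
                Path(f"{model_dir}/{model_name}/"),
                clean_up_tokenization_spaces=True,
            )
        return AutoTokenizer.from_pretrained(Path(f"{model_dir}/{model_name}/"))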

modules/text_generation.py  +4 -0

@@ -28,6 +28,10 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
         return input_ids
     else:
         input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
+
+        if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
+            input_ids = input_ids[:,1:]
+
         if shared.args.cpu:
             return input_ids
         elif shared.args.flexgen:
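
The encode-side change drops a spurious leading token when a LLaMA tokenizer is active: token id 29871 corresponds to the standalone "▁" (space) piece in the LLaMA vocabulary, and stripping it keeps prompts from starting with an extra space token. A minimal standalone sketch of the same trimming (the trim_leading_space_token name and the LLAMA_SPACE_TOKEN_ID constant are illustrative, not from the repository):

    import transformers

    LLAMA_SPACE_TOKEN_ID = 29871  # the bare "▁" piece in the LLaMA vocabulary

    def trim_leading_space_token(tokenizer, input_ids):
        # input_ids is a 2D torch tensor, as returned by encode(..., return_tensors='pt').
        # If the first token of the prompt is the lone space piece, drop it.
        if type(tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == LLAMA_SPACE_TOKEN_ID:
            return input_ids[:, 1:]
        return input_ids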

requirements.txt  +1 -1

@@ -13,4 +13,4 @@ safetensors==0.3.0
 sentencepiece
 pyyaml
 tqdm
-git+https://github.com/huggingface/transformers@9eae4aa57650c1dbe1becd4e0979f6ad1e572ac0
+git+https://github.com/huggingface/transformers
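
requirements.txt now tracks the tip of transformers main instead of a pinned commit, which is what makes LlamaTokenizer/LlamaForCausalLM available here, and presumably why the commit message warns that 16-bit llama checkpoints converted against the old pin must be reconverted or redownloaded. A quick sanity check that the installed build actually ships the LLaMA classes (a sketch, not part of the commit):

    import transformers

    # LlamaTokenizer only exists in sufficiently recent transformers builds;
    # older installs must be upgraded before LLaMA models will load this way.
    assert hasattr(transformers, "LlamaTokenizer"), (
        f"transformers {transformers.__version__} lacks LlamaTokenizer; "
        "reinstall it from the git main branch"
    )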