Просмотр исходного кода

Merge branch 'main' into pt-path-changes

oobabooga 2 лет назад
Родитель
Сommit
e9dbdafb14
3 измененных файлов с 32 добавлено и 9 удалено
  1. 1 1
      README.md
  2. 14 6
      download-model.py
  3. 17 2
      modules/models.py

+ 1 - 1
README.md

@@ -54,7 +54,7 @@ The third line assumes that you have an NVIDIA GPU.
 pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2
 ```
   	  
-* If you are running in CPU mode, replace the third command with this one:
+* If you are running it in CPU mode, replace the third command with this one:
 
 ```
 conda install pytorch torchvision torchaudio git -c pytorch

+ 14 - 6
download-model.py

@@ -5,7 +5,9 @@ Example:
 python download-model.py facebook/opt-1.3b
 
 '''
+
 import argparse
+import base64
 import json
 import multiprocessing
 import re
@@ -93,23 +95,28 @@ facebook/opt-1.3b
 def get_download_links_from_huggingface(model, branch):
     base = "https://huggingface.co"
     page = f"/api/models/{model}/tree/{branch}?cursor="
+    cursor = b""
 
     links = []
     classifications = []
     has_pytorch = False
     has_safetensors = False
-    while page is not None:
-        content = requests.get(f"{base}{page}").content
+    while True:
+        content = requests.get(f"{base}{page}{cursor.decode()}").content
+
         dict = json.loads(content)
+        if len(dict) == 0:
+            break
 
         for i in range(len(dict)):
             fname = dict[i]['path']
 
             is_pytorch = re.match("pytorch_model.*\.bin", fname)
             is_safetensors = re.match("model.*\.safetensors", fname)
-            is_text = re.match(".*\.(txt|json)", fname)
+            is_tokenizer = re.match("tokenizer.*\.model", fname)
+            is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer
 
-            if is_text or is_safetensors or is_pytorch:
+            if any((is_pytorch, is_safetensors, is_text, is_tokenizer)):
                 if is_text:
                     links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
                     classifications.append('text')
@@ -123,8 +130,9 @@ def get_download_links_from_huggingface(model, branch):
                         has_pytorch = True
                         classifications.append('pytorch')
 
-        #page = dict['nextUrl']
-        page = None
+        cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
+        cursor = base64.b64encode(cursor)
+        cursor = cursor.replace(b'=', b'%3D')
 
     # If both pytorch and safetensors are available, download safetensors only
     if has_pytorch and has_safetensors:

+ 17 - 2
modules/models.py

@@ -116,8 +116,23 @@ def load_model(model_name):
             print(f"Could not find {pt_model}, exiting...")
             exit()
 
-        model = load_quant(path_to_model, pt_path, 4)
-        model = model.to(torch.device('cuda:0'))
+        model = load_quant(path_to_model, Path(f"models/{pt_model}"), 4)
+
+        # Multi-GPU setup
+        if shared.args.gpu_memory:
+            import accelerate
+
+            max_memory = {}
+            for i in range(len(shared.args.gpu_memory)):
+                max_memory[i] = f"{shared.args.gpu_memory[i]}GiB"
+            max_memory['cpu'] = f"{shared.args.cpu_memory or '99'}GiB"
+
+            device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LLaMADecoderLayer"])
+            model = accelerate.dispatch_model(model, device_map=device_map)
+
+        # Single GPU
+        else:
+            model = model.to(torch.device('cuda:0'))
 
     # Custom
     else: