Просмотр исходного кода

Print the softprompt metadata when it is loaded

oobabooga 3 лет назад
Родитель
Сommit
8c9dd95d55
1 измененных файлов с 13 добавлено и 0 удалено
  1. 13 0
      server.py

+ 13 - 0
server.py

@@ -173,7 +173,19 @@ def load_soft_prompt(name):
     else:
         with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
             zf.extract('tensor.npy')
+            zf.extract('meta.json')
+            j = json.loads(open('meta.json', 'r').read())
+            print(f"\nLoading the softprompt \"{name}\".")
+            for field in j:
+                if field != 'name':
+                    if type(j[field]) is list:
+                        print(f"{field}: {', '.join(j[field])}")
+                    else:
+                        print(f"{field}: {j[field]}")
+            print()
             tensor = np.load('tensor.npy')
+            Path('tensor.npy').unlink()
+            Path('meta.json').unlink()
         tensor = torch.Tensor(tensor).to(device=model.device, dtype=model.dtype)
         tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
 
@@ -187,6 +199,7 @@ def upload_soft_prompt(file):
         zf.extract('meta.json')
         j = json.loads(open('meta.json', 'r').read())
         name = j['name']
+        Path('meta.json').unlink()
 
     with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
         f.write(file)