소스 검색

Simplify encode() function

oobabooga 3 년 전
부모
커밋
3f05cf5ddd
1개의 변경된 파일4개의 추가작업 그리고 7개의 파일을 삭제
  1. 4 7
      server.py

+ 4 - 7
server.py

@@ -168,16 +168,13 @@ def fix_galactica(s):
     return s
 
 def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
+    input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens)
     if args.cpu:
-        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens)
-    else:
-        torch.cuda.empty_cache()
-        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens).cuda()
-
-    if not args.deepspeed:
         return input_ids
-    else:
+    elif args.deepspeed:
         return input_ids.to(device=local_rank)
+    else:
+        return input_ids.cuda()
 
 def decode(output_ids):
     reply = tokenizer.decode(output_ids, skip_special_tokens=True)