@@ -0,0 +1,2 @@
+do_sample=False,
+max_new_tokens=tokens,
@@ -136,11 +136,11 @@ def fix_galactica(s):
return s
def encode(prompt, tokens):
- if not args.cpu:
+ if args.cpu:
+ input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens)
+ else:
torch.cuda.empty_cache()
input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens).cuda()
- else:
- input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens)
return input_ids
def decode(output_ids):