@@ -116,6 +116,14 @@ def fix_galactica(s):
     s = s.replace(r'$$', r'$')
     return s
 
+def encode(prompt, tokens):
+    if not args.cpu:
+        torch.cuda.empty_cache()
+        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens).cuda()
+    else:
+        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens)
+    return input_ids
+
 def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
     global model, tokenizer, model_name, loaded_preset, preset
 
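Note (outside the diff): the new encode() helper centralizes the tokenizer call that generate_reply() previously duplicated for the CPU and GPU paths. A minimal sketch of how it is meant to be used, assuming the server.py globals (tokenizer, args, torch) are already initialized and that we plan to generate `tokens` new tokens:

    # Illustrative sketch only; relies on server.py globals (tokenizer, args, torch).
    tokens = 200                                  # new tokens we intend to generate
    input_ids = encode("Hello, my name is", tokens)
    print(input_ids.shape)                        # torch.Size([1, n]) with n <= 2048 - tokens,
                                                  # leaving room for the reply in the 2048-token context
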
@@ -131,14 +139,9 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
             preset = infile.read()
         loaded_preset = inference_settings
 
-    if not args.cpu:
-        torch.cuda.empty_cache()
-        input_ids = tokenizer.encode(str(question), return_tensors='pt', truncation=True, max_length=2048-tokens).cuda()
-        cuda = ".cuda()"
-    else:
-        input_ids = tokenizer.encode(str(question), return_tensors='pt', truncation=True, max_length=2048-tokens)
-        cuda = ""
+    input_ids = encode(question, tokens)
+    cuda = "" if args.cpu else ".cuda()"
 
     if eos_token is None:
         output = eval(f"model.generate(input_ids, {preset}){cuda}")
     else:
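Note (outside the diff): the two replacement lines keep generate_reply()'s existing eval pattern intact; `preset` is still the keyword-argument string read from the selected preset file, and `cuda` is just an optional ".cuda()" suffix spliced onto the call. A hypothetical preset string makes the expansion easier to see (the actual preset files may differ):

    # Illustrative only: assumes a preset file whose contents look like the string below.
    preset = "do_sample=True, temperature=0.7, top_p=0.9, max_new_tokens=tokens"
    cuda = "" if args.cpu else ".cuda()"
    # The f-string then evaluates to, for example:
    #   model.generate(input_ids, do_sample=True, temperature=0.7, top_p=0.9, max_new_tokens=tokens).cuda()
    output = eval(f"model.generate(input_ids, {preset}){cuda}")
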
@@ -217,16 +220,20 @@ elif args.chat or args.cai_chat:
     def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
         text = chat_response_cleaner(text)
 
-        question = f"{context}\n\n"
-        for i in range(len(history)):
-            if args.cai_chat:
-                question += f"{name1}: {history[i][0].strip()}\n"
-                question += f"{name2}: {history[i][1].strip()}\n"
-            else:
-                question += f"{name1}: {history[i][0][3:-5].strip()}\n"
-                question += f"{name2}: {history[i][1][3:-5].strip()}\n"
-        question += f"{name1}: {text}\n"
-        question += f"{name2}:"
+        rows = [f"{context}\n\n"]
+        i = len(history)-1
+        while i >= 0 and len(encode(''.join(rows), tokens)[0]) < 2048-tokens:
+            rows.insert(1, f"{name2}: {history[i][1].strip()}\n")
+            rows.insert(1, f"{name1}: {history[i][0].strip()}\n")
+            i -= 1
+        rows.append(f"{name1}: {text}\n")
+        rows.append(f"{name2}:")
+
+        while len(rows) > 3 and len(encode(''.join(rows), tokens)[0]) >= 2048-tokens:
+            rows.pop(1)
+            rows.pop(1)
+
+        question = ''.join(rows)
 
         if check:
             reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0]
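Note (outside the diff): the chatbot_wrapper hunk replaces the unconditional history concatenation with a token-budgeted prompt builder. A standalone sketch of the idea, with a count_tokens stand-in for len(encode(text, tokens)[0]) and a budget mirroring 2048 - tokens (names and the toy history are illustrative assumptions, not repo code):

    def build_prompt(context, history, text, name1, name2, budget, count_tokens):
        rows = [f"{context}\n\n"]
        i = len(history) - 1
        # Walk the history newest-first, inserting each exchange just after the context;
        # earlier insertions get pushed down, so kept exchanges stay in chronological order.
        while i >= 0 and count_tokens(''.join(rows)) < budget:
            rows.insert(1, f"{name2}: {history[i][1].strip()}\n")
            rows.insert(1, f"{name1}: {history[i][0].strip()}\n")
            i -= 1
        rows.append(f"{name1}: {text}\n")
        rows.append(f"{name2}:")
        # If the last pair pushed the prompt over budget, drop the oldest kept pairs,
        # always preserving the context plus the final user line and bot prefix.
        while len(rows) > 3 and count_tokens(''.join(rows)) >= budget:
            rows.pop(1)
            rows.pop(1)
        return ''.join(rows)

    # Toy usage with a crude whitespace "token" count:
    prompt = build_prompt("You are a helpful assistant.",
                          [("Hi", "Hello!"), ("How are you?", "Fine, thanks.")],
                          "Tell me a joke.", "You", "Bot",
                          budget=60, count_tokens=lambda s: len(s.split()))
    print(prompt)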