|
|
@@ -136,14 +136,17 @@ def decode(output_ids):
|
|
|
return reply
|
|
|
|
|
|
def formatted_outputs(reply, model_name):
    """Shape a generated reply for the active interface mode.

    When either ``args.chat`` or ``args.cai_chat`` is set, the reply is
    returned unchanged as a single value.

    Otherwise a 3-tuple ``(reply, second, html)`` is returned:

    * model names starting with ``galactica`` (case-insensitive): the reply
      is passed through ``fix_galactica`` and used for both the first and
      second slots; ``html`` comes from ``generate_basic_html``.
    * model names starting with ``gpt4chan`` (case-insensitive): the reply
      is passed through ``fix_gpt4chan``; the second slot holds a fixed
      placeholder message and ``html`` comes from ``generate_4chan_html``.
    * any other model: the reply is left as-is, the second slot holds the
      same placeholder and ``html`` comes from ``generate_basic_html``.

    NOTE(review): the three slots presumably feed three separate output
    widgets in the non-chat UI -- confirm against the caller.
    """
    # Chat modes consume the bare reply; callers iterate over plain strings.
    if args.chat or args.cai_chat:
        return reply

    # Lower-case once; both prefix checks are case-insensitive.
    lowered_name = model_name.lower()

    if lowered_name.startswith('galactica'):
        reply = fix_galactica(reply)
        return reply, reply, generate_basic_html(reply)

    # Placeholder for the second slot whenever the model is not GALACTICA.
    placeholder = 'Only applicable for GALACTICA models.'

    if lowered_name.startswith('gpt4chan'):
        reply = fix_gpt4chan(reply)
        return reply, placeholder, generate_4chan_html(reply)

    return reply, placeholder, generate_basic_html(reply)
|
|
|
|
|
|
def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
|
|
|
global model, tokenizer, model_name, loaded_preset, preset
|
|
|
@@ -245,16 +248,17 @@ if args.chat or args.cai_chat:
|
|
|
question = generate_chat_prompt(text, tokens, name1, name2, context)
|
|
|
history.append(['', ''])
|
|
|
eos_token = '\n' if check else None
|
|
|
- for i in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
|
|
|
- reply = i[0]
|
|
|
+ for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
|
|
|
next_character_found = False
|
|
|
|
|
|
+ previous_idx = [m.start() for m in re.finditer(f"\n{name2}:", question)]
|
|
|
+ idx = [m.start() for m in re.finditer(f"(^|\n){name2}:", reply)]
|
|
|
+ idx = idx[len(previous_idx)-1]
|
|
|
+ reply = reply[idx + len(f"\n{name2}:"):]
|
|
|
+
|
|
|
if check:
|
|
|
- idx = reply.rfind(question[-1024:])
|
|
|
- reply = reply[idx+min(1024, len(question)):].split('\n')[0].strip()
|
|
|
+ reply = reply.split('\n')[0].strip()
|
|
|
else:
|
|
|
- idx = reply.rfind(question[-1024:])
|
|
|
- reply = reply[idx+min(1024, len(question)):]
|
|
|
idx = reply.find(f"\n{name1}:")
|
|
|
if idx != -1:
|
|
|
reply = reply[:idx]
|