@@ -76,7 +76,7 @@ def fix_galactica(s):
     return s

 def formatted_outputs(reply, model_name):
-    if not (shared.args.chat or shared.args.cai_chat):
+    if not shared.is_chat():
         if 'galactica' in model_name.lower():
             reply = fix_galactica(reply)
             return reply, reply, generate_basic_html(reply)
@@ -109,7 +109,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     t0 = time.time()

     original_question = question
-    if not (shared.args.chat or shared.args.cai_chat):
+    if not shared.is_chat():
         question = apply_extensions(question, "input")
     if shared.args.verbose:
         print(f"\n\n{question}\n--------------------\n")
@@ -121,18 +121,18 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         if shared.args.no_stream:
             reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
             output = original_question+reply
-            if not (shared.args.chat or shared.args.cai_chat):
+            if not shared.is_chat():
                 reply = original_question + apply_extensions(reply, "output")
             yield formatted_outputs(reply, shared.model_name)
         else:
-            if not (shared.args.chat or shared.args.cai_chat):
+            if not shared.is_chat():
                 yield formatted_outputs(question, shared.model_name)

             # RWKV has proper streaming, which is very nice.
             # No need to generate 8 tokens at a time.
             for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty):
                 output = original_question+reply
-                if not (shared.args.chat or shared.args.cai_chat):
+                if not shared.is_chat():
                     reply = original_question + apply_extensions(reply, "output")
                 yield formatted_outputs(reply, shared.model_name)

@@ -208,7 +208,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi

         new_tokens = len(output) - len(input_ids[0])
         reply = decode(output[-new_tokens:])
-        if not (shared.args.chat or shared.args.cai_chat):
+        if not shared.is_chat():
             reply = original_question + apply_extensions(reply, "output")

         yield formatted_outputs(reply, shared.model_name)
@@ -226,7 +226,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         def generate_with_streaming(**kwargs):
             return Iteratorize(generate_with_callback, kwargs, callback=None)

-        if not (shared.args.chat or shared.args.cai_chat):
+        if not shared.is_chat():
             yield formatted_outputs(original_question, shared.model_name)
         with generate_with_streaming(**generate_params) as generator:
             for output in generator:
@@ -235,7 +235,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi

                 new_tokens = len(output) - len(input_ids[0])
                 reply = decode(output[-new_tokens:])
-                if not (shared.args.chat or shared.args.cai_chat):
+                if not shared.is_chat():
                     reply = original_question + apply_extensions(reply, "output")

                 if output[-1] in eos_token_ids:
@@ -253,7 +253,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi

             new_tokens = len(output) - len(original_input_ids[0])
             reply = decode(output[-new_tokens:])
-            if not (shared.args.chat or shared.args.cai_chat):
+            if not shared.is_chat():
                 reply = original_question + apply_extensions(reply, "output")

             if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):