Remove variables
This commit is contained in:
@@ -103,12 +103,8 @@ def stop_everything_event():
|
|||||||
shared.stop_everything = True
|
shared.stop_everything = True
|
||||||
|
|
||||||
def generate_reply(question, generate_params, eos_token=None, stopping_strings=[]):
|
def generate_reply(question, generate_params, eos_token=None, stopping_strings=[]):
|
||||||
max_new_tokens = generate_params['max_new_tokens']
|
|
||||||
seed = generate_params['seed']
|
|
||||||
print(generate_params)
|
|
||||||
print('---------------')
|
|
||||||
clear_torch_cache()
|
clear_torch_cache()
|
||||||
set_manual_seed(seed)
|
set_manual_seed(generate_params['seed'])
|
||||||
shared.stop_everything = False
|
shared.stop_everything = False
|
||||||
updated_params = {}
|
updated_params = {}
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
@@ -155,7 +151,7 @@ def generate_reply(question, generate_params, eos_token=None, stopping_strings=[
|
|||||||
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
|
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
|
||||||
return
|
return
|
||||||
|
|
||||||
input_ids = encode(question, max_new_tokens)
|
input_ids = encode(question, generate_params['max_new_tokens'])
|
||||||
original_input_ids = input_ids
|
original_input_ids = input_ids
|
||||||
output = input_ids[0]
|
output = input_ids[0]
|
||||||
|
|
||||||
@@ -168,7 +164,7 @@ def generate_reply(question, generate_params, eos_token=None, stopping_strings=[
|
|||||||
t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings]
|
t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings]
|
||||||
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
|
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
|
||||||
|
|
||||||
updated_params["max_new_tokens"] = max_new_tokens
|
updated_params["max_new_tokens"] = generate_params['max_new_tokens']
|
||||||
if not shared.args.flexgen:
|
if not shared.args.flexgen:
|
||||||
updated_params["eos_token_id"] = eos_token_ids
|
updated_params["eos_token_id"] = eos_token_ids
|
||||||
updated_params["stopping_criteria"] = stopping_criteria_list
|
updated_params["stopping_criteria"] = stopping_criteria_list
|
||||||
@@ -244,7 +240,7 @@ def generate_reply(question, generate_params, eos_token=None, stopping_strings=[
|
|||||||
|
|
||||||
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||||
else:
|
else:
|
||||||
for i in range(max_new_tokens//8+1):
|
for i in range(generate_params['max_new_tokens']//8+1):
|
||||||
clear_torch_cache()
|
clear_torch_cache()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = shared.model.generate(**updated_params)[0]
|
output = shared.model.generate(**updated_params)[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user