Sfoglia il codice sorgente

Improve usage of stopping_criteria

oobabooga 2 anni fa
parent
commit
59b5f7a4b7
1 ha cambiato i file con 6 aggiunte e 13 eliminazioni
  1. 6 13
      modules/text_generation.py

+ 6 - 13
modules/text_generation.py

@@ -119,18 +119,11 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     output = input_ids[0]
     cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
     n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1])
+    stopping_criteria_list = transformers.StoppingCriteriaList()
     if stopping_string is not None:
-        # The stopping_criteria code below was copied from
-        # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
+        # Copied from https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
         t = encode(stopping_string, 0, add_special_tokens=False)
-        stopping_criteria_list = transformers.StoppingCriteriaList([
-            _SentinelTokenStoppingCriteria(
-                sentinel_token_ids=t,
-                starting_idx=len(input_ids[0])
-            )
-        ])
-    else:
-        stopping_criteria_list = []
+        stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
 
     if not shared.args.flexgen:
         generate_params = [
@@ -184,17 +177,17 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     elif not shared.args.flexgen:
 
         def generate_with_callback(callback=None, **kwargs):
-            if 'stopping_criteria' not in kwargs:
-                kwargs['stopping_criteria'] = []
             kwargs['stopping_criteria'].append(Stream(callback_func=callback))
             clear_torch_cache()
-            shared.model.generate(**kwargs)
+            with torch.no_grad():
+                shared.model.generate(**kwargs)
 
         def generate_with_streaming(**kwargs):
             return Iteratorize(generate_with_callback, kwargs, callback=None)
 
         yield formatted_outputs(original_question, shared.model_name)
         for output in eval(f"generate_with_streaming({', '.join(generate_params)})"):
+            print(print('Used vram in gib:', torch.cuda.memory_allocated() / 1024**3))
             if shared.soft_prompt:
                 output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
             reply = decode(output)