
Use 'with' statement to better handle streaming memory

oobabooga · 2 years ago
parent · commit 0bd5430988
3 changed files with 39 additions and 19 deletions
  1. modules/RWKV.py (+5, -5)
  2. modules/callbacks.py (+23, -4)
  3. modules/text_generation.py (+11, -10)

+ 5 - 5
modules/RWKV.py

@@ -50,11 +50,11 @@ class RWKVModel:
         return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
 
     def generate_with_streaming(self, **kwargs):
-        iterable = Iteratorize(self.generate, kwargs, callback=None)
-        reply = kwargs['context']
-        for token in iterable:
-            reply += token
-            yield reply
+        with Iteratorize(self.generate, kwargs, callback=None) as generator:
+            reply = kwargs['context']
+            for token in generator:
+                reply += token
+                yield reply
 
 class RWKVTokenizer:
     def __init__(self):
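
For context (not part of the commit), a minimal consumer-side sketch, assuming `model` is an already loaded RWKVModel and that the generation parameters shown are illustrative: because generate_with_streaming now wraps the iterator in a `with` block, abandoning the stream early still runs Iteratorize.__exit__, which stops the worker thread and clears the torch cache.

# Hypothetical usage sketch; `model` is assumed to be a loaded RWKVModel.
for reply in model.generate_with_streaming(context="Hello", token_count=200):
    print(reply)
    if len(reply) > 500:
        # Breaking out lets the generator be closed (promptly under CPython),
        # so the 'with' inside generate_with_streaming runs Iteratorize.__exit__,
        # stopping the background generation thread and freeing the cache.
        break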

+ 23 - 4
modules/callbacks.py

@@ -1,3 +1,4 @@
+import gc
 from queue import Queue
 from threading import Thread
 
@@ -6,7 +7,6 @@ import transformers
 
 import modules.shared as shared
 
-
 # Copied from https://github.com/PygmalionAI/gradio-ui/
 class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
 
@@ -52,17 +52,24 @@ class Iteratorize:
         self.q = Queue()
         self.sentinel = object()
         self.kwargs = kwargs
+        self.stop_now = False
 
         def _callback(val):
+            if self.stop_now:
+                raise ValueError
             self.q.put(val)
 
         def gentask():
-            ret = self.mfunc(callback=_callback, **self.kwargs)
+            try:
+                ret = self.mfunc(callback=_callback, **self.kwargs)
+            except ValueError:
+                pass
             self.q.put(self.sentinel)
             if self.c_callback:
                 self.c_callback(ret)
 
-        Thread(target=gentask).start()
+        self.thread = Thread(target=gentask)
+        self.thread.start()
 
     def __iter__(self):
         return self
@@ -75,4 +82,16 @@ class Iteratorize:
             return obj
 
     def __del__(self):
-        pass
+        clear_torch_cache()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stop_now = True
+        clear_torch_cache()
+
+def clear_torch_cache():
+    gc.collect()
+    if not shared.args.cpu:
+        torch.cuda.empty_cache()
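
A self-contained sketch of the mechanism added above, with the torch/shared details dropped; MiniIteratorize and slow_producer are illustrative stand-ins, not code from this repo. A blocking, callback-driven producer runs on a worker thread and is consumed as an iterator; leaving the `with` block flips stop_now, so the next callback invocation raises ValueError on the worker thread and unwinds the producer.

# Standalone sketch of the pattern above (illustrative names, no torch dependency).
import time
from queue import Queue
from threading import Thread

class MiniIteratorize:
    def __init__(self, func, kwargs={}):
        self.q = Queue()
        self.sentinel = object()
        self.stop_now = False

        def _callback(val):
            if self.stop_now:
                raise ValueError  # unwinds the producer on the worker thread
            self.q.put(val)

        def gentask():
            try:
                func(callback=_callback, **kwargs)
            except ValueError:
                pass  # expected when the consumer stops early
            self.q.put(self.sentinel)

        self.thread = Thread(target=gentask)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        obj = self.q.get(True, None)
        if obj is self.sentinel:
            raise StopIteration
        return obj

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_now = True  # the next callback call raises and ends the thread

def slow_producer(callback=None, n=100):
    # Stand-in for model.generate(...): emits tokens through the callback.
    for i in range(n):
        time.sleep(0.05)
        callback(f"token_{i} ")

with MiniIteratorize(slow_producer, {"n": 100}) as gen:
    for i, tok in enumerate(gen):
        print(tok, end="", flush=True)
        if i == 4:
            break  # __exit__ sets stop_now; the producer stops shortly after

Raising inside the callback is the interruption mechanism here because the blocking generate call exposes no cancellation API; gentask then swallows the ValueError and still pushes the sentinel so the queue is terminated cleanly.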

+ 11 - 10
modules/text_generation.py

@@ -186,17 +186,18 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             return Iteratorize(generate_with_callback, kwargs, callback=None)
 
         yield formatted_outputs(original_question, shared.model_name)
-        for output in eval(f"generate_with_streaming({', '.join(generate_params)})"):
-            if shared.soft_prompt:
-                output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
-            reply = decode(output)
-
-            if not (shared.args.chat or shared.args.cai_chat):
-                reply = original_question + apply_extensions(reply[len(question):], "output")
-            yield formatted_outputs(reply, shared.model_name)
+        with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator:
+            for output in generator:
+                if shared.soft_prompt:
+                    output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+                reply = decode(output)
+
+                if not (shared.args.chat or shared.args.cai_chat):
+                    reply = original_question + apply_extensions(reply[len(question):], "output")
+                yield formatted_outputs(reply, shared.model_name)
 
-            if output[-1] == n:
-                break
+                if output[-1] == n:
+                    break
 
     # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
     else: