Просмотр исходного кода

Exception handling (#454)

* Update text_generation.py
* Update extensions.py
oobabooga 2 лет назад
Родитель
Сommit
75a7a84ef2
2 измененных файлов с 8 добавлено и 0 удалено
  1. 3 0
      modules/extensions.py
  2. 5 0
      modules/text_generation.py

+ 3 - 0
modules/extensions.py

@@ -1,3 +1,5 @@
+import traceback
+
 import gradio as gr
 
 import extensions
@@ -17,6 +19,7 @@ def load_extensions():
                 print('Ok.')
             except:
                 print('Fail.')
+                traceback.print_exc()
 
 # This iterator returns the extensions in the order specified in the command-line
 def iterator():

+ 5 - 0
modules/text_generation.py

@@ -1,6 +1,7 @@
 import gc
 import re
 import time
+import traceback
 
 import numpy as np
 import torch
@@ -110,6 +111,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
                 # No need to generate 8 tokens at a time.
                 for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
                     yield formatted_outputs(reply, shared.model_name)
+        except:
+            traceback.print_exc()
         finally:
             t1 = time.time()
             output = encode(reply)[0]
@@ -243,6 +246,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
 
             yield formatted_outputs(reply, shared.model_name)
 
+    except:
+        traceback.print_exc()
     finally:
         t1 = time.time()
         print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")