
Empty the cuda cache at model.generate()

oobabooga 2 years ago
Parent
Commit
700311ce40
1 changed file with 2 additions and 0 deletions

+ 2 - 0
modules/text_generation.py

@@ -73,6 +73,8 @@ def formatted_outputs(reply, model_name):
         return reply
 
 def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
+    torch.cuda.empty_cache()
+
     original_question = question
     if not (shared.args.chat or shared.args.cai_chat):
         question = apply_extensions(question, "input")