Browse Source

Better way to generate custom prompts

oobabooga 2 years ago
parent
commit
13f2688134
2 changed files with 13 additions and 30 deletions
  1. extensions/send_pictures/script.py (+1, -6)
  2. modules/chat.py (+12, -24)

+ 1 - 6
extensions/send_pictures/script.py

@@ -14,12 +14,7 @@ params = {
 # custom output text
 input_hijack = {
     'state': 'off',
-    'value': []
-}
-
-prompt_hijack = {
-    'state': 'off',
-    'value': ""
+    'value': ["", ""]
 }
 
 def generate_chat_picture(picture, name1, name2):
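With this change, input_hijack['value'] is always a two-element list of [text, visible_text], and the separate prompt_hijack dict is gone. A minimal sketch of how an extension could trigger the hijack under the new contract (the handler name and the visible text are illustrative assumptions, not part of this commit):

    # Illustrative sketch only: names below are not taken from this commit.
    def on_picture_selected(picture, name1, name2):
        internal_text = generate_chat_picture(picture, name1, name2)
        visible_text = f'*{name1} sends a picture*'
        input_hijack['state'] = True   # chat.py now checks input_hijack['state'] == True
        input_hijack['value'] = [internal_text, visible_text]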

+ 12 - 24
modules/chat.py

@@ -86,42 +86,30 @@ def stop_everything_event():
 
 def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
     shared.stop_everything = False
-
-    # Check if any extension wants to hijack this function call
-    visible_text = None
-    prompt = None
-    for extension, _ in extensions_module.iterator():
-        if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] in ['temporary', 'permanent']:
-            if extension.input_hijack['state'] == 'temporary':
-                extension.input_hijack['state'] = 'off'
-            values = extension.input_hijack['value']
-            if len(values) == 2:
-                text, visible_text = values
-            elif len(values) == 4:
-                text, visible_text, reply, visible_reply = values
-                if not shared.stop_everything:
-                    shared.history['internal'].append([text, reply])
-                    shared.history['visible'].append([visible_text, visible_reply])
-                return shared.history['visible']
-        if hasattr(extension, 'prompt_hijack') and extension.prompt_hijack['state'] in ['temporary', 'permanent']:
-            if extension.prompt_hijack['state'] == 'temporary':
-                extension.prompt_hijack['state'] = 'off'
-            prompt = extension.prompt_hijack['value']
-                
     just_started = True
     eos_token = '\n' if check else None
-
     if 'pygmalion' in shared.model_name.lower():
         name1 = "You"
 
+    # Check if any extension wants to hijack this function call
+    visible_text = None
+    custom_prompt_generator = None
+    for extension, _ in extensions_module.iterator():
+        if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True:
+            text, visible_text = extension.input_hijack['value']
+        if custom_prompt_generator is None and hasattr(extension, 'custom_prompt_generator'):
+            custom_prompt_generator = extension.custom_prompt_generator
+
     if visible_text is None:
         visible_text = text
         if shared.args.chat:
             visible_text = visible_text.replace('\n', '<br>')
         text = apply_extensions(text, "input")
 
-    if prompt is None:
+    if custom_prompt_generator is None:
         prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
+    else:
+        prompt = custom_prompt_generator(text, max_new_tokens, name1, name2, context, chat_prompt_size)
 
     # Generate
     for reply in generate_reply(prompt, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
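
With prompt_hijack removed, an extension now exposes a custom_prompt_generator callable with the same signature that chat.py uses for generate_chat_prompt, and chatbot_wrapper calls it in place of the default. A minimal sketch of what such an extension hook could look like (the prompt layout is an illustrative assumption, not defined by this commit):

    # Illustrative sketch only: chat.py detects this attribute on the extension
    # module via hasattr(extension, 'custom_prompt_generator') and calls it
    # instead of generate_chat_prompt.
    def custom_prompt_generator(text, max_new_tokens, name1, name2, context, chat_prompt_size):
        # Build whatever prompt the extension needs; this trivial example just
        # prepends the context and the latest user message.
        return f"{context}\n{name1}: {text}\n{name2}:"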