Source code preview

Add --cai-chat option that mimics Character.AI's interface

oobabooga, 3 years ago
parent commit 6136da419c
3 changed files with 136 additions and 9 deletions
  1. +3 -0   README.md
  2. +104 -0 html_generator.py
  3. +29 -9  server.py

README.md: +3 -0

@@ -125,6 +125,9 @@ Optionally, you can use the following command-line flags:
 --notebook      Launch the webui in notebook mode, where the output is written to the same text
                 box as the input.
 --chat          Launch the webui in chat mode.
+--cai-chat      Launch the webui in chat mode with a style similar to Character.AI's. If the
+                file profile.png exists in the same folder as server.py, this image will be used
+                as the bot's profile picture.
 --cpu           Use the CPU to generate text.
 --auto-devices  Automatically split the model across the available GPU(s) and CPU.
 --load-in-8bit  Load the model with 8-bit precision.

html_generator.py: +104 -0

@@ -5,6 +5,7 @@ This is a library for formatting gpt4chan outputs as nice HTML.
 '''

 import re
+from pathlib import Path

 def process_post(post, c):
     t = post.split('\n')
@@ -153,3 +154,106 @@ def generate_4chan_html(f):
     output = '\n'.join(output)

     return output
+
def generate_chat_html(history, name1, name2):
    """Render a chat history as a Character.AI-style HTML transcript.

    history: list of (user_message, bot_message) string pairs, oldest first.
    name1:   display name of the human user.
    name2:   display name of the bot.

    Returns a single HTML string containing inline CSS, the messages
    (emitted newest-first and displayed bottom-up via
    flex-direction: column-reverse), and a script that scrolls the chat
    box to the latest message.
    """
    css = """
    .chat {
      margin-left: auto;
      margin-right: auto;
      max-width: 800px;
      height: 50vh;
      overflow-y: auto;
      padding-right: 20px;
      display: flex;
      flex-direction: column-reverse;
    }

    .message {
      display: grid;
      grid-template-columns: 50px 1fr;
      padding-bottom: 20px;
      font-size: 15px;
      font-family: helvetica;
    }

    .circle-you {
      width: 45px;
      height: 45px;
      background-color: rgb(244, 78, 59);
      border-radius: 50%;
    }

    .circle-bot {
      width: 45px;
      height: 45px;
      background-color: rgb(59, 78, 244);
      border-radius: 50%;
    }

    .circle-bot img {
      border-radius: 50%;
      width: 100%;
      height: 100%;
      object-fit: cover;
    }

    .text {
    }

    .text p {
      margin-top: 5px;
    }

    .username {
      font-weight: bold;
    }

    .body {
    }
    """

    # Use profile.png (next to server.py) as the bot avatar when it exists;
    # Gradio serves local files under the "file/" route.
    if Path("profile.png").exists():
        img = '<img src="file/profile.png">'
    else:
        img = ''

    def _as_paragraphs(message):
        # One <p> per line so newlines in the message survive as breaks.
        # NOTE(review): the text is NOT HTML-escaped, matching the original
        # behavior -- any markup inside the message is rendered as-is.
        return '\n'.join(f"<p>{line}</p>" for line in message.split('\n'))

    # Collect fragments and join once at the end; repeated += on a string
    # inside the loop is quadratic in the number of messages.
    pieces = [f'<style>{css}</style><div class="chat" id="chat">']

    # Newest exchange first: combined with column-reverse this pins the
    # latest message to the bottom of the scrollable area.
    for user_text, bot_text in reversed(history):
        pieces.append(f"""
              <div class="message">
                <div class="circle-bot">
                  {img}
                </div>
                <div class="text">
                  <div class="username">
                    {name2}
                  </div>
                  <div class="body">
                    {_as_paragraphs(bot_text)}
                  </div>
                </div>
              </div>
            """)

        pieces.append(f"""
              <div class="message">
                <div class="circle-you">
                </div>
                <div class="text">
                  <div class="username">
                    {name1}
                  </div>
                  <div class="body">
                    {_as_paragraphs(user_text)}
                  </div>
                </div>
              </div>
            """)

    # Auto-scroll to the newest message once the element is in the DOM.
    pieces.append('<script>document.getElementById("chat").scrollTo(0, document.getElementById("chat").scrollHeight);</script>')
    pieces.append("</div>")

    return ''.join(pieces)

server.py: +29 -9

@@ -16,6 +16,7 @@ parser = argparse.ArgumentParser()
 parser.add_argument('--model', type=str, help='Name of the model to load by default.')
 parser.add_argument('--notebook', action='store_true', help='Launch the webui in notebook mode, where the output is written to the same text box as the input.')
 parser.add_argument('--chat', action='store_true', help='Launch the webui in chat mode.')
+parser.add_argument('--cai-chat', action='store_true', help='Launch the webui in chat mode with a style similar to Character.AI\'s. If the file profile.png exists in the same folder as server.py, this image will be used as the bot\'s profile picture.')
 parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
 parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
 parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
@@ -189,7 +190,7 @@ if args.notebook:

         btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=True, api_name="textgen")
         textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=True)
-elif args.chat:
+elif args.chat or args.cai_chat:
     history = []

     # This gets the new line characters right.
@@ -218,19 +219,29 @@ elif args.chat:
             idx = reply.find(f"\n{name1}:")
             if idx != -1:
                 reply = reply[:idx]
-            reply = chat_response_cleaner(response)
+            reply = chat_response_cleaner(reply)

         history.append((text, reply))
         return history
-    def remove_last_message():
def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
    """Run the regular chat generation, then render the updated history
    as Character.AI-style HTML for the gr.HTML display."""
    updated_history = chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check)
    return generate_chat_html(updated_history, name1, name2)
+
def remove_last_message(name1, name2):
    """Drop the most recent (user, bot) exchange from the shared history.

    name1/name2 are the display names needed to re-render the HTML view.
    Returns CAI-style HTML when --cai-chat is active, otherwise the raw
    history list for the gr.Chatbot component.
    """
    # Pop exactly once (the scraped diff duplicated this statement).
    history.pop()
    if args.cai_chat:
        return generate_chat_html(history, name1, name2)
    else:
        return history
 
 
     def clear():
     def clear():
         global history
         global history
         history = []
         history = []
 
 
def clear_html():
    """Return the empty-chat HTML used to reset the CAI-style display."""
    return generate_chat_html([], "", "")
+
     if 'pygmalion' in model_name.lower():
         context_str = "This is a conversation between two people.\n<START>"
         name1_str = "You"
@@ -258,7 +269,10 @@ elif args.chat:
                     check = gr.Checkbox(value=True, label='Stop generating at new line character?')

             with gr.Column():
-                display1 = gr.Chatbot()
+                if args.cai_chat:
+                    display1 = gr.HTML(value=generate_chat_html([], "", ""))
+                else:
+                    display1 = gr.Chatbot()
                 textbox = gr.Textbox(lines=2, label='Input')
                 btn = gr.Button("Generate")
                 with gr.Row():
@@ -267,13 +281,19 @@ elif args.chat:
                     with gr.Column():
                         btn2 = gr.Button("Clear history")

-        btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True, api_name="textgen")
-        textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True)
-        btn3.click(remove_last_message, [], display1, show_progress=False)
+        if args.cai_chat:
+            btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True, api_name="textgen")
+            textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True)
+            btn2.click(clear_html, [], display1, show_progress=False)
+        else:
+            btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True, api_name="textgen")
+            textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True)
+            btn2.click(lambda x: "", display1, display1)
+
         btn2.click(clear)
+        btn3.click(remove_last_message, [name1, name2], display1, show_progress=False)
         btn.click(lambda x: "", textbox, textbox, show_progress=False)
         textbox.submit(lambda x: "", textbox, textbox, show_progress=False)
-        btn2.click(lambda x: "", display1, display1)
 else:

     def continue_wrapper(question, tokens, inference_settings, selected_model):