
Add nice HTML output for all models

oobabooga 3 years ago
parent
commit
d5e01c80e3
3 changed files with 17 additions and 10 deletions
  1. README.md (+1 -1)
  2. html_generator.py (+1 -1)
  3. server.py (+15 -8)

+ 1 - 1
README.md

@@ -94,7 +94,7 @@ Optionally, you can use the following command-line flags:
 --cpu           Use the CPU to generate text.
 --auto-devices  Automatically split the model across the available GPU(s) and CPU.
 --load-in-8bit  Load the model with 8-bit precision.
---listen        Make the webui reachable from your local network.
+--no-listen     Make the webui unreachable from your local network.
 ```

 ## Presets

+ 1 - 1
html_generator.py

@@ -20,7 +20,7 @@ def process_post(post, c):
     src = f'<span class="name">Anonymous </span> <span class="number">No.{number}</span>\n{src}'
     return src

-def generate_html(f):
+def generate_4chan_html(f):
     css = """
     #container {
         background-color: #eef2ff;
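A note on the rename: server.py (below) now defines a generic generate_html for plain replies, so the 4chan-thread renderer gets the more specific name generate_4chan_html. A minimal call sketch, assuming html_generator.py is importable from the working directory and that the function accepts the post-processed reply text (the sample reply here is invented):

```python
# Sketch only: the import path and sample reply are assumptions for illustration.
from html_generator import generate_4chan_html

sample_reply = "--- 1\nHello, anon.\n--- 2\nHello back."
html_page = generate_4chan_html(sample_reply)  # expected to return thread-styled HTML
print(html_page[:120])
```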

+ 15 - 8
server.py

@@ -18,7 +18,7 @@ parser.add_argument('--chat', action='store_true', help='Launch the webui in cha
 parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
 parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
 parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
-parser.add_argument('--listen', action='store_true', help='Make the webui reachable from your local network.')
+parser.add_argument('--no-listen', action='store_true', help='Make the webui unreachable from your local network.')
 args = parser.parse_args()
 loaded_preset = None
 available_models = sorted(set(map(lambda x : str(x.name).replace('.pt', ''), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*')))))
@@ -63,7 +63,7 @@ def load_model(model_name):
         model = eval(command)

     # Loading the tokenizer
-    if model_name.lower().startswith('gpt4chan') and Path(f"models/gpt-j-6B/").exists():
+    if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path(f"models/gpt-j-6B/").exists():
         tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
     else:
         tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{model_name}/"))
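The widened check works because str.startswith accepts a tuple of prefixes and matches if any one of them matches, so every listed spelling of the model name falls back to the shared gpt-j-6B tokenizer. A quick self-contained illustration (the model names below are just examples):

```python
prefixes = ('gpt4chan', 'gpt-4chan', '4chan')

for name in ('gpt4chan_model_float16', 'GPT-4chan', '4chan-8bit', 'galactica-6.7b'):
    # str.startswith accepts a tuple and returns True if any prefix matches.
    print(name, name.lower().startswith(prefixes))
```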
@@ -79,6 +79,7 @@ def fix_gpt4chan(s):
         s = re.sub("--- [0-9]*\n\n\n---", "---", s)
     return s

+# Fix the LaTeX equations in GALACTICA
 def fix_galactica(s):
     s = s.replace(r'\[', r'$')
     s = s.replace(r'\]', r'$')
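The new comment spells out what fix_galactica is for: GALACTICA tends to emit LaTeX delimited with \[ \] or $$, and the function rewrites those to single-dollar delimiters, presumably so the downstream Markdown/HTML output renders them inline. A runnable sketch of the visible part of that normalization (the sample string is made up):

```python
def normalize_latex(s: str) -> str:
    # Rewrite \[ ... \] and $$ ... $$ delimiters to single dollars,
    # mirroring the replacements visible in this hunk.
    s = s.replace(r'\[', r'$')
    s = s.replace(r'\]', r'$')
    s = s.replace(r'$$', r'$')
    return s

print(normalize_latex(r"The energy is \[ E = mc^2 \]."))
# -> The energy is $ E = mc^2 $.
```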
@@ -87,6 +88,11 @@ def fix_galactica(s):
     s = s.replace(r'$$', r'$')
     return s

+def generate_html(s):
+    s = '\n'.join([f'<p style="margin-bottom: 20px">{line}</p>' for line in s.split('\n')])
+    s = f'<div style="max-width: 600px; margin-left: auto; margin-right: auto; background-color:#eef2ff; color:#0b0f19; padding:3em; font-size:1.2em;">{s}</div>'
+    return s
+
 def generate_reply(question, temperature, max_length, inference_settings, selected_model, eos_token=None):
     global model, tokenizer, model_name, loaded_preset, preset

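This new generate_html is the core of the commit: each line of the reply becomes a <p> element and the whole reply is wrapped in a styled container, so every model gets a readable HTML pane rather than the old placeholder text. A self-contained sketch of the same transformation with an invented reply (the styling values are copied from the diff above):

```python
def generate_html(s: str) -> str:
    # One <p> per line, then a centered, padded container around everything.
    s = '\n'.join(f'<p style="margin-bottom: 20px">{line}</p>' for line in s.split('\n'))
    return ('<div style="max-width: 600px; margin-left: auto; margin-right: auto; '
            'background-color:#eef2ff; color:#0b0f19; padding:3em; font-size:1.2em;">'
            f'{s}</div>')

print(generate_html("First line of a reply.\nSecond line."))
```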
@@ -117,14 +123,15 @@ def generate_reply(question, temperature, max_length, inference_settings, select
         output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")

     reply = tokenizer.decode(output[0], skip_special_tokens=True)
+    reply = reply.replace(r'<|endoftext|>', '')
     if model_name.lower().startswith('galactica'):
         reply = fix_galactica(reply)
-        return reply, reply, 'Only applicable for gpt4chan.'
+        return reply, reply, generate_html(reply)
     elif model_name.lower().startswith('gpt4chan'):
         reply = fix_gpt4chan(reply)
-        return reply, 'Only applicable for galactica models.', generate_html(reply)
+        return reply, 'Only applicable for galactica models.', generate_4chan_html(reply)
     else:
-        return reply, 'Only applicable for galactica models.', 'Only applicable for gpt4chan.'
+        return reply, 'Only applicable for galactica models.', generate_html(reply)

 # Choosing the default model
 if args.model is not None:
@@ -248,7 +255,7 @@ else:
         btn.click(generate_reply, [textbox, temp_slider, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True)
         textbox.submit(generate_reply, [textbox, temp_slider, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True)

-if args.listen:
-    interface.launch(share=False, server_name="0.0.0.0")
-else:
+if args.no_listen:
     interface.launch(share=False)
+else:
+    interface.launch(share=False, server_name="0.0.0.0")
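The launch logic is inverted to match the renamed flag: binding to 0.0.0.0 (reachable from the local network) is now the default, and --no-listen opts out, which is what the README change above documents. A minimal sketch of just the flag handling, with the Gradio interface construction omitted and the chosen keyword arguments printed instead of being passed to a launch call:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-listen', action='store_true',
                    help='Make the webui unreachable from your local network.')
args = parser.parse_args()

# By default bind to all interfaces; --no-listen keeps the server local-only.
launch_kwargs = {'share': False}
if not args.no_listen:
    launch_kwargs['server_name'] = '0.0.0.0'

# In server.py this would be interface.launch(**launch_kwargs).
print(launch_kwargs)
```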