Bläddra i källkod

Better variable names

oobabooga 3 år sedan
förälder
incheckning
6be571cff7
1 ändrade filer med 27 tillägg och 28 borttagningar
  1. 27 28
      server.py

+ 27 - 28
server.py

@@ -115,17 +115,17 @@ def load_model(model_name):
     # Custom
     else:
         command = "AutoModelForCausalLM.from_pretrained"
-        settings = ["low_cpu_mem_usage=True"]
+        params = ["low_cpu_mem_usage=True"]
 
         if args.cpu:
-            settings.append("low_cpu_mem_usage=True")
-            settings.append("torch_dtype=torch.float32")
+            params.append("low_cpu_mem_usage=True")
+            params.append("torch_dtype=torch.float32")
         else:
-            settings.append("device_map='auto'")
-            settings.append("load_in_8bit=True" if args.load_in_8bit else "torch_dtype=torch.float16")
+            params.append("device_map='auto'")
+            params.append("load_in_8bit=True" if args.load_in_8bit else "torch_dtype=torch.float16")
 
             if args.gpu_memory:
-                settings.append(f"max_memory={{0: '{args.gpu_memory or '99'}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}")
+                params.append(f"max_memory={{0: '{args.gpu_memory or '99'}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}")
             elif (args.gpu_memory or args.cpu_memory) and not args.load_in_8bit:
                 total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024))
                 suggestion = round((total_mem-1000)/1000)*1000
@@ -133,11 +133,11 @@ def load_model(model_name):
                     suggestion -= 1000
                 suggestion = int(round(suggestion/1000))
                 print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
-                settings.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}")
+                params.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}")
             if args.disk:
-                settings.append(f"offload_folder='{args.disk_cache_dir or 'cache'}'")
+                params.append(f"offload_folder='{args.disk_cache_dir or 'cache'}'")
 
-        command = f"{command}(Path(f'models/{model_name}'), {','.join(set(settings))})"
+        command = f"{command}(Path(f'models/{model_name}'), {','.join(set(params))})"
         model = eval(command)
 
     # Loading the tokenizer
@@ -162,7 +162,7 @@ def load_model_wrapper(selected_model):
         model, tokenizer = load_model(model_name)
 
 def load_preset_values(preset_menu, return_dict=False):
-    settings = {
+    generate_params = {
         'do_sample': True,
         'temperature': 1,
         'top_p': 1,
@@ -180,14 +180,14 @@ def load_preset_values(preset_menu, return_dict=False):
     for i in preset.split(','):
         i = i.strip().split('=')
         if len(i) == 2 and i[0].strip() != 'tokens':
-            settings[i[0].strip()] = eval(i[1].strip())
+            generate_params[i[0].strip()] = eval(i[1].strip())
 
-    settings['temperature'] = min(1.99, settings['temperature'])
+    generate_params['temperature'] = min(1.99, generate_params['temperature'])
 
     if return_dict:
-        return settings
+        return generate_params
     else:
-        return settings['do_sample'], settings['temperature'], settings['top_p'], settings['typical_p'], settings['repetition_penalty'], settings['top_k'], settings['min_length'], settings['no_repeat_ngram_size'], settings['num_beams'], settings['length_penalty'], settings['early_stopping']
+        return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['length_penalty'], generate_params['early_stopping']
 
 # Removes empty replies from gpt4chan outputs
 def fix_gpt4chan(s):
@@ -365,7 +365,7 @@ def create_extensions_block():
     btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], [])
 
 def create_settings_menus():
-    defaults = load_preset_values(settings[f'preset{suffix}'], return_dict=True)
+    generate_params = load_preset_values(settings[f'preset{suffix}'], return_dict=True)
 
     with gr.Row():
         with gr.Column():
@@ -380,23 +380,23 @@ def create_settings_menus():
     with gr.Accordion("Custom generation parameters", open=False):
         with gr.Row():
             with gr.Column():
-                do_sample = gr.Checkbox(value=defaults['do_sample'], label="do_sample")
-                temperature = gr.Slider(0.01, 1.99, value=defaults['temperature'], step=0.01, label="temperature")
-                top_p = gr.Slider(0.0,1.0,value=defaults['top_p'],step=0.01,label="top_p")
-                typical_p = gr.Slider(0.0,1.0,value=defaults['typical_p'],step=0.01,label="typical_p")
+                do_sample = gr.Checkbox(value=generate_params['do_sample'], label="do_sample")
+                temperature = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label="temperature")
+                top_p = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label="top_p")
+                typical_p = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label="typical_p")
             with gr.Column():
-                repetition_penalty = gr.Slider(1.0,4.99,value=defaults['repetition_penalty'],step=0.01,label="repetition_penalty")
-                top_k = gr.Slider(0,200,value=defaults['top_k'],step=1,label="top_k")
-                no_repeat_ngram_size = gr.Slider(0, 20, step=1, value=defaults["no_repeat_ngram_size"], label="no_repeat_ngram_size")
+                repetition_penalty = gr.Slider(1.0,4.99,value=generate_params['repetition_penalty'],step=0.01,label="repetition_penalty")
+                top_k = gr.Slider(0,200,value=generate_params['top_k'],step=1,label="top_k")
+                no_repeat_ngram_size = gr.Slider(0, 20, step=1, value=generate_params["no_repeat_ngram_size"], label="no_repeat_ngram_size")
 
         gr.Markdown("Special parameters (only use them if you really need them):")
         with gr.Row():
             with gr.Column():
-                num_beams = gr.Slider(0, 20, step=1, value=defaults["num_beams"], label="num_beams")
-                length_penalty = gr.Slider(-5, 5, value=defaults["length_penalty"], label="length_penalty")
+                num_beams = gr.Slider(0, 20, step=1, value=generate_params["num_beams"], label="num_beams")
+                length_penalty = gr.Slider(-5, 5, value=generate_params["length_penalty"], label="length_penalty")
             with gr.Column():
-                min_length = gr.Slider(0, 2000, step=1, value=defaults["min_length"] if args.no_stream else 0, label="min_length", interactive=args.no_stream)
-                early_stopping = gr.Checkbox(value=defaults["early_stopping"], label="early_stopping")
+                min_length = gr.Slider(0, 2000, step=1, value=generate_params["min_length"] if args.no_stream else 0, label="min_length", interactive=args.no_stream)
+                early_stopping = gr.Checkbox(value=generate_params["early_stopping"], label="early_stopping")
 
     model_menu.change(load_model_wrapper, [model_menu], [])
     preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping])
@@ -737,10 +737,9 @@ loaded_preset = None
 default_text = settings['prompt_gpt4chan'] if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) else settings['prompt']
 description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
 css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem} #refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} #download-label, #upload-label {min-height: 0}"
+suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else ''
 buttons = {}
 gen_events = []
-
-suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else ''
 history = {'internal': [], 'visible': []}
 character = None