Преглед изворни кода

Use Path.stem for simplicity

oobabooga пре 2 година
родитељ
комит
2a267011dc
2 измењених фајлова са 6 додато и 6 уклоњено
  1. 2 2
      modules/training.py
  2. 4 4
      server.py

+ 2 - 2
modules/training.py

@@ -20,7 +20,7 @@ MAX_STEPS = 0
 CURRENT_GRADIENT_ACCUM = 1
 CURRENT_GRADIENT_ACCUM = 1
 
 
 def get_dataset(path: str, ext: str):
 def get_dataset(path: str, ext: str):
-    return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path(path).glob(f'*.{ext}'))), key=str.lower)
+    return ['None'] + sorted(set((k.stem for k in Path(path).glob(f'*.{ext}'))), key=str.lower)
 
 
 def create_train_interface():
 def create_train_interface():
     with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
     with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
@@ -104,7 +104,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
     actual_lr = float(learning_rate)
     actual_lr = float(learning_rate)
 
 
     if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
     if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
-        yield f"Cannot input zeroes."
+        yield "Cannot input zeroes."
         return
         return
 
 
     gradient_accumulation_steps = batch_size // micro_batch_size
     gradient_accumulation_steps = batch_size // micro_batch_size

+ 4 - 4
server.py

@@ -36,12 +36,12 @@ def get_available_models():
         return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
         return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
 
 
 def get_available_presets():
 def get_available_presets():
-    return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
+    return sorted(set((k.stem for k in Path('presets').glob('*.txt'))), key=str.lower)
 
 
 def get_available_prompts():
 def get_available_prompts():
     prompts = []
     prompts = []
-    prompts += sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True)
-    prompts += sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('prompts').glob('*.txt'))), key=str.lower)
+    prompts += sorted(set((k.stem for k in Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True)
+    prompts += sorted(set((k.stem for k in Path('prompts').glob('*.txt'))), key=str.lower)
     prompts += ['None']
     prompts += ['None']
     return prompts
     return prompts
 
 
@@ -53,7 +53,7 @@ def get_available_extensions():
     return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
     return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
 
 
 def get_available_softprompts():
 def get_available_softprompts():
-    return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
+    return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=str.lower)
 
 
 def get_available_loras():
 def get_available_loras():
     return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
     return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)