فهرست منبع

Fix training dataset loading #636

oobabooga 2 سال پیش
والد
کامیت
a6d0373063
1فایلهای تغییر یافته به همراه7 افزوده شده و 6 حذف شده
  1. 7 6
      modules/training.py

+ 7 - 6
modules/training.py

@@ -119,7 +119,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
         }
 
     # == Prep the dataset, format, etc ==
-    if raw_text_file is not None:
+    if raw_text_file not in ['None', '']:
         print("Loading raw text file dataset...")
         with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r') as file:
             raw_text = file.read()
@@ -136,16 +136,17 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
         del text_chunks
 
     else:
-        with open(clean_path('training/formats', f'{format}.json'), 'r') as formatFile:
-            format_data: dict[str, str] = json.load(formatFile)
-
-        if dataset is None:
+        if dataset in ['None', '']:
             yield "**Missing dataset choice input, cannot continue.**"
             return
-        if format is None:
+
+        if format in ['None', '']:
             yield "**Missing format choice input, cannot continue.**"
             return
 
+        with open(clean_path('training/formats', f'{format}.json'), 'r') as formatFile:
+            format_data: dict[str, str] = json.load(formatFile)
+
         def generate_prompt(data_point: dict[str, str]):
             for options, data in format_data.items():
                 if set(options.split(',')) == set(x[0] for x in data_point.items() if len(x[1].strip()) > 0):