|
@@ -3,6 +3,7 @@
|
|
|
Converts a transformers model to a format compatible with flexgen.
|
|
Converts a transformers model to a format compatible with flexgen.
|
|
|
|
|
|
|
|
'''
|
|
'''
|
|
|
|
|
+
|
|
|
import argparse
|
|
import argparse
|
|
|
import os
|
|
import os
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
@@ -10,9 +11,8 @@ from pathlib import Path
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
import torch
|
|
import torch
|
|
|
from tqdm import tqdm
|
|
from tqdm import tqdm
|
|
|
-from transformers import AutoModelForCausalLM
|
|
|
|
|
-from transformers import AutoTokenizer
|
|
|
|
|
-
|
|
|
|
|
|
|
+from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
+
|
|
|
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
|
|
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
|
|
|
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
|
|
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
|
|
|
args = parser.parse_args()
|
|
args = parser.parse_args()
|
|
@@ -31,7 +31,6 @@ def disable_torch_init():
|
|
|
torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
|
|
torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
|
|
|
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
|
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
|
|
|
|
|
|
|
-
|
|
|
|
|
def restore_torch_init():
|
|
def restore_torch_init():
|
|
|
"""Rollback the change made by disable_torch_init."""
|
|
"""Rollback the change made by disable_torch_init."""
|
|
|
import torch
|
|
import torch
|