# deepspeed_parameters.py

  1. def generate_ds_config(ds_bf16, train_batch_size, nvme_offload_dir):
  2. '''
  3. DeepSpeed configration
  4. https://huggingface.co/docs/transformers/main_classes/deepspeed
  5. '''
  6. if nvme_offload_dir:
  7. ds_config = {
  8. "fp16": {
  9. "enabled": not ds_bf16,
  10. },
  11. "bf16": {
  12. "enabled": ds_bf16,
  13. },
  14. "zero_optimization": {
  15. "stage": 3,
  16. "offload_param": {
  17. "device": "nvme",
  18. "nvme_path": nvme_offload_dir,
  19. "pin_memory": True,
  20. "buffer_count": 5,
  21. "buffer_size": 1e9,
  22. "max_in_cpu": 1e9
  23. },
  24. "overlap_comm": True,
  25. "reduce_bucket_size": "auto",
  26. "contiguous_gradients": True,
  27. "sub_group_size": 1e8,
  28. "stage3_prefetch_bucket_size": "auto",
  29. "stage3_param_persistence_threshold": "auto",
  30. "stage3_max_live_parameters": "auto",
  31. "stage3_max_reuse_distance": "auto",
  32. },
  33. "aio": {
  34. "block_size": 262144,
  35. "queue_depth": 32,
  36. "thread_count": 1,
  37. "single_submit": False,
  38. "overlap_events": True
  39. },
  40. "steps_per_print": 2000,
  41. "train_batch_size": train_batch_size,
  42. "train_micro_batch_size_per_gpu": 1,
  43. "wall_clock_breakdown": False
  44. }
  45. else:
  46. ds_config = {
  47. "fp16": {
  48. "enabled": not ds_bf16,
  49. },
  50. "bf16": {
  51. "enabled": ds_bf16,
  52. },
  53. "zero_optimization": {
  54. "stage": 3,
  55. "offload_param": {
  56. "device": "cpu",
  57. "pin_memory": True
  58. },
  59. "overlap_comm": True,
  60. "contiguous_gradients": True,
  61. "reduce_bucket_size": "auto",
  62. "stage3_prefetch_bucket_size": "auto",
  63. "stage3_param_persistence_threshold": "auto",
  64. "stage3_max_live_parameters": "auto",
  65. "stage3_max_reuse_distance": "auto",
  66. },
  67. "steps_per_print": 2000,
  68. "train_batch_size": train_batch_size,
  69. "train_micro_batch_size_per_gpu": 1,
  70. "wall_clock_breakdown": False
  71. }
  72. return ds_config