llama_attn_hijack.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. import math
  2. import sys
  3. import torch
  4. import torch.nn as nn
  5. import transformers.models.llama.modeling_llama
  6. from typing import Optional
  7. from typing import Tuple
  8. import modules.shared as shared
  9. if shared.args.xformers:
  10. try:
  11. import xformers.ops
  12. except Exception:
  13. print("🔴 xformers not found! Please install it before trying to use it.", file=sys.stderr)
  def hijack_llama_attention():
      """Swap ``LlamaAttention.forward`` for an optimized implementation.

      ``shared.args.xformers`` is checked first and wins over
      ``shared.args.sdp_attention``; when neither is set, the transformers
      implementation is left in place and nothing is printed.
      """
      if shared.args.xformers:
          replacement = xformers_forward
          message = "Replaced attention with xformers_attention"
      elif shared.args.sdp_attention:
          replacement = sdp_attention_forward
          message = "Replaced attention with sdp_attention"
      else:
          # No optimized attention requested: keep the stock forward.
          return

      llama_module = transformers.models.llama.modeling_llama
      llama_module.LlamaAttention.forward = replacement
      print(message)
  def xformers_forward(
      self,
      hidden_states: torch.Tensor,
      attention_mask: Optional[torch.Tensor] = None,
      position_ids: Optional[torch.LongTensor] = None,
      past_key_value: Optional[Tuple[torch.Tensor]] = None,
      output_attentions: bool = False,
      use_cache: bool = False,
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
      """Replacement for ``LlamaAttention.forward`` built on xformers memory-efficient attention.

      Installed by ``hijack_llama_attention`` when ``shared.args.xformers`` is set.
      When ``output_attentions`` is requested the full attention matrix is needed,
      so the eager matmul/softmax path is used instead of xformers.

      Args:
          hidden_states: 3-D input ``(bsz, q_len, features)`` fed to the q/k/v projections.
          attention_mask: Additive mask, validated as ``(bsz, 1, q_len, kv_seq_len)`` on the
              eager path, or ``None``.
          position_ids: Positions forwarded to ``apply_rotary_pos_emb``.
          past_key_value: Cached ``(key, value)``, each ``(bsz, num_heads, past_len, head_dim)``,
              or ``None``.
          output_attentions: If true, compute and return the attention weights (eager path).
          use_cache: If true, return the concatenated ``(key, value)`` cache.

      Returns:
          ``(attn_output, attn_weights, past_key_value)``: ``attn_output`` is
          ``o_proj`` applied to a ``(bsz, q_len, hidden_size)`` tensor; ``attn_weights`` is
          ``None`` unless ``output_attentions``; ``past_key_value`` is ``None`` unless ``use_cache``.

      Raises:
          ValueError: On the eager path, if the weights, mask or output shape is unexpected.
      """
      bsz, q_len, _ = hidden_states.size()

      # Project and split heads: (bsz, num_heads, q_len, head_dim).
      query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
      key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
      value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

      # The rotary table must cover cached positions as well as the new ones.
      kv_seq_len = key_states.shape[-2]
      if past_key_value is not None:
          kv_seq_len += past_key_value[0].shape[-2]
      cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
      query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
      # [bsz, nh, t, hd]

      if past_key_value is not None:
          # reuse k, v, self_attention
          key_states = torch.cat([past_key_value[0], key_states], dim=2)
          value_states = torch.cat([past_key_value[1], value_states], dim=2)

      past_key_value = (key_states, value_states) if use_cache else None

      # We only apply xformers optimizations if we don't need to output the whole attention matrix
      if not output_attentions:
          # xformers takes and returns (bsz, seq_len, num_heads, head_dim).
          query_states = query_states.transpose(1, 2)
          key_states = key_states.transpose(1, 2)
          value_states = value_states.transpose(1, 2)

          # This is a nasty hack. We assume attention_mask in transformers is either lower-triangular
          # or all zeros, so we probe one element of the upper-triangular part: if it is zero, the
          # mask is treated as all zeros. With fewer than two key positions there is no such element
          # (and no future key to hide), so skip the probe instead of indexing out of bounds.
          # NOTE(review): only element [0, 0, 0, 1] is inspected, so padding in the mask is ignored;
          # and for q_len > 1 with a non-empty cache, row 0 presumably already sees key 1, so no
          # causal bias would be applied among the new tokens — confirm callers never hit this.
          if attention_mask is None or attention_mask.shape[-1] < 2 or attention_mask[0, 0, 0, 1] == 0:
              attn_bias = None
          else:
              attn_bias = xformers.ops.LowerTriangularMask()
          # input and output are of form (bsz, q_len, num_heads, head_dim)
          attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=attn_bias)
          attn_weights = None
      else:
          attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

          if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
              raise ValueError(
                  f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                  f" {attn_weights.size()}"
              )

          if attention_mask is not None:
              if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                  raise ValueError(
                      f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                  )
              attn_weights = attn_weights + attention_mask
              # Clamp to the dtype's finite minimum (e.g. -inf becomes finfo.min); the bound is
              # built on the same device as the weights.
              attn_weights = torch.max(
                  attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
              )

          # upcast attention to fp32
          attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
          attn_output = torch.matmul(attn_weights, value_states)

          if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
              raise ValueError(
                  f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                  f" {attn_output.size()}"
              )
          # Only the eager result is head-major; bring it to the (bsz, q_len, nh, hd) layout that
          # the xformers branch already returns.
          attn_output = attn_output.transpose(1, 2)

      attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
      attn_output = self.o_proj(attn_output)
      return attn_output, attn_weights, past_key_value
  def sdp_attention_forward(
      self,
      hidden_states: torch.Tensor,
      attention_mask: Optional[torch.Tensor] = None,
      position_ids: Optional[torch.LongTensor] = None,
      past_key_value: Optional[Tuple[torch.Tensor]] = None,
      output_attentions: bool = False,
      use_cache: bool = False,
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
      """Replacement for ``LlamaAttention.forward`` built on PyTorch scaled_dot_product_attention.

      Installed by ``hijack_llama_attention`` when ``shared.args.sdp_attention`` is set
      (and ``shared.args.xformers`` is not). When ``output_attentions`` is requested the
      full attention matrix is needed, so the eager matmul/softmax path is used instead.

      Args:
          hidden_states: 3-D input ``(bsz, q_len, features)`` fed to the q/k/v projections.
          attention_mask: Mask passed as ``attn_mask`` to SDPA, validated as
              ``(bsz, 1, q_len, kv_seq_len)`` on the eager path, or ``None``.
          position_ids: Positions forwarded to ``apply_rotary_pos_emb``.
          past_key_value: Cached ``(key, value)``, each ``(bsz, num_heads, past_len, head_dim)``,
              or ``None``.
          output_attentions: If true, compute and return the attention weights (eager path).
          use_cache: If true, return the concatenated ``(key, value)`` cache.

      Returns:
          ``(attn_output, attn_weights, past_key_value)``: ``attn_output`` is
          ``o_proj`` applied to a ``(bsz, q_len, hidden_size)`` tensor; ``attn_weights`` is
          ``None`` unless ``output_attentions``; ``past_key_value`` is ``None`` unless ``use_cache``.

      Raises:
          ValueError: On the eager path, if the weights, mask or output shape is unexpected.
      """
      bsz, q_len, _ = hidden_states.size()

      # Project and split heads: (bsz, num_heads, q_len, head_dim).
      query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
      key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
      value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

      # The rotary table must cover cached positions as well as the new ones.
      kv_seq_len = key_states.shape[-2]
      if past_key_value is not None:
          kv_seq_len += past_key_value[0].shape[-2]
      cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
      query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
      # [bsz, nh, t, hd]

      if past_key_value is not None:
          # reuse k, v, self_attention
          key_states = torch.cat([past_key_value[0], key_states], dim=2)
          value_states = torch.cat([past_key_value[1], value_states], dim=2)

      past_key_value = (key_states, value_states) if use_cache else None

      # We only apply sdp attention if we don't need to output the whole attention matrix
      if not output_attentions:
          # NOTE(review): with is_causal=False, causality comes entirely from attention_mask;
          # if it is ever None while q_len > 1, no causal mask is applied — confirm callers.
          attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False)
          attn_weights = None
      else:
          attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

          if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
              raise ValueError(
                  f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                  f" {attn_weights.size()}"
              )

          if attention_mask is not None:
              if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                  raise ValueError(
                      f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                  )
              attn_weights = attn_weights + attention_mask
              # Clamp to the dtype's finite minimum (e.g. -inf becomes finfo.min); the bound is
              # built on the same device as the weights.
              attn_weights = torch.max(
                  attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
              )

          # upcast attention to fp32
          attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
          attn_output = torch.matmul(attn_weights, value_states)

          if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
              raise ValueError(
                  f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                  f" {attn_output.size()}"
              )

      # Both branches produce (bsz, num_heads, q_len, head_dim); merge heads back.
      attn_output = attn_output.transpose(1, 2)
      attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
      attn_output = self.o_proj(attn_output)
      return attn_output, attn_weights, past_key_value