class Linear(nn.Linear, LoraLayer):
    """Lora implemented in a dense layer with DoRA support.

    The layer keeps the base ``nn.Linear`` weight frozen and adds:

    * ``weight_m_wdecomp`` -- a bias-free ``nn.Linear(1, out_features)`` whose
      weight (viewed as a vector of length ``out_features``) is divided by a
      per-row weight norm to form the scale applied in ``forward``.
    * ``lora_A`` / ``lora_B`` -- the low-rank pair, created only when
      ``Wdecompose`` is false and ``r > 0``.

    Args:
        in_features: Input size, forwarded to ``nn.Linear``.
        out_features: Output size, forwarded to ``nn.Linear``.
        r: Rank of the low-rank pair; ``0`` creates no ``lora_A``/``lora_B``.
        lora_alpha: Numerator of the scaling factor ``lora_alpha / r``.
        lora_dropout: Forwarded to ``LoraLayer.__init__``.
        fan_in_fan_out: When true the stored weight is transposed at the end
            of construction and transposed back (via ``transpose``) for the
            matrix products in ``forward``.
        merge_weights: Forwarded to ``LoraLayer.__init__``.
        Wdecompose: When true no low-rank pair is created and ``forward`` only
            rescales the base projection.
        dora_simple: When true the weight norm is detached from the autograd
            graph in the low-rank branch of ``forward``.
        **kwargs: Extra keyword arguments forwarded to ``nn.Linear``.

    Note:
        ``self.r``, ``self.lora_alpha``, ``self.lora_dropout``, ``self.merged``
        and ``self.disable_adapters`` are read here but never assigned in this
        class; presumably ``LoraLayer.__init__`` sets them -- verify there.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        Wdecompose: bool = False,
        dora_simple: bool = True,
        **kwargs,
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoraLayer.__init__(
            self,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            merge_weights=merge_weights,
        )

        # Weight has shape (out_features, 1); forward() flattens it with view(-1).
        self.weight_m_wdecomp = nn.Linear(1, out_features, bias=False)

        self.fan_in_fan_out = fan_in_fan_out
        self.Wdecompose = Wdecompose
        self.dora_simple = dora_simple

        if not self.Wdecompose and r > 0:
            self.lora_A = nn.Linear(in_features, r, bias=False)
            self.lora_B = nn.Linear(r, out_features, bias=False)
            self.scaling = self.lora_alpha / self.r

        # The base weight is never trained directly.
        self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        """Re-initialise the base layer and, if present, the low-rank pair.

        ``nn.Linear.__init__`` calls this before ``lora_A`` exists, hence the
        ``hasattr`` guard. ``lora_B`` is zeroed so the low-rank product starts
        at zero.
        """
        nn.Linear.reset_parameters(self)
        if hasattr(self, "lora_A"):
            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B.weight)

    def train(self, mode: bool = True):
        """Set training/evaluation mode on this layer and its sub-modules.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching the ``nn.Module.train`` contract so that
            ``layer.train()`` / ``layer.eval()`` can be chained or assigned.
        """
        nn.Linear.train(self, mode)
        # lora_A / lora_B only exist when Wdecompose is false and r > 0.
        if not self.Wdecompose and hasattr(self, "lora_A"):
            self.lora_A.train(mode)
            self.lora_B.train(mode)
        self.weight_m_wdecomp.train(mode)
        return self

    def forward(self, x: torch.Tensor):
        """Apply the layer to ``x``.

        Branches, in order:

        1. ``disable_adapters`` set: not supported.
        2. ``Wdecompose`` and not merged: base projection rescaled per output
           feature by ``weight_m_wdecomp / ||weight||``.
        3. ``r > 0`` and not merged: as (2) but the norm is taken over
           ``weight + scaling * (lora_B @ lora_A)`` and the scaled low-rank
           projection of the dropped-out input is added.
        4. Otherwise: plain ``F.linear`` with the stored weight and bias.

        Args:
            x: Input tensor whose last dimension is consumed by ``F.linear``.

        Returns:
            The output tensor, cast to the dtype of ``self.weight``.

        Raises:
            NotImplementedError: If ``self.disable_adapters`` is truthy.
        """
        if self.disable_adapters:
            raise NotImplementedError

        previous_dtype = self.weight.dtype
        # Orientation expected by F.linear: (out_features, in_features).
        base_weight = transpose(self.weight, self.fan_in_fan_out)

        # NOTE(review): the norms below are taken over the *stored* weight
        # (self.weight), not base_weight. With fan_in_fan_out=True these differ
        # in orientation -- confirm against the code that initialises
        # weight_m_wdecomp before changing this.
        if self.Wdecompose and not self.merged:
            norm_scale = self.weight_m_wdecomp.weight.view(-1) / (
                torch.linalg.norm(self.weight, dim=1)
            )
            org_result = F.linear(x, base_weight)
            # org + (s - 1) * W(dropout(x)); equals s * W(x) when dropout is a no-op.
            result = org_result + (norm_scale - 1) * (
                F.linear(self.lora_dropout(x), base_weight)
            )
            if self.bias is not None:
                result += self.bias.view(1, -1).expand_as(result)
        elif self.r > 0 and not self.merged:
            new_weight_v = (
                self.weight + (self.lora_B.weight @ self.lora_A.weight) * self.scaling
            )
            weight_norm = torch.linalg.norm(new_weight_v, dim=1)
            if self.dora_simple:
                # Treat the norm as a constant: no gradient flows through it.
                weight_norm = weight_norm.detach()
            norm_scale = self.weight_m_wdecomp.weight.view(-1) / weight_norm

            org_result = F.linear(x, base_weight)
            dropout_x = self.lora_dropout(x)
            result = org_result + (norm_scale - 1) * (
                F.linear(dropout_x, base_weight)
            )
            if self.bias is not None:
                result += self.bias.view(1, -1).expand_as(result)
            # Low-rank path runs in lora_A's dtype, then is rescaled like the base path.
            result += (
                norm_scale
                * (self.lora_B(self.lora_A(dropout_x.to(self.lora_A.weight.dtype))))
                * self.scaling
            )
        else:
            result = F.linear(x, base_weight, bias=self.bias)

        if result.dtype != previous_dtype:
            result = result.to(previous_dtype)
        return result