当前位置:   article > 正文

Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation

论文:Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation

一、模型架构

卷积时域音频分离网络(Conv-TasNet)包括三个处理阶段,如原论文结构图(A)所示:编码器、分离器和解码器。首先,编码器模块将混合波形的短片段转换为其在中间特征空间中的对应表示。接着,利用这一表示为每个源在每个时间步估计一个乘性函数(即掩码)。最后,解码器模块将经过掩码加权的编码器特征变换回时域,从而重建各个源的波形。

二、代码

1、TCN

  1. import numpy as np
  2. import os
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from torch.autograd import Variable
  class cLN(nn.Module):
      """Cumulative layer normalization (used by the causal configuration).

      For every time step ``t`` the input is normalized with the mean and
      standard deviation of all entries in frames ``0..t`` over all channels,
      so no future frame contributes to the statistics.

      Args:
          dimension: number of channels (size of dim 1 of the input).
          eps: small constant added to the variance before the square root.
          trainable: if True, ``gain`` and ``bias`` are learnable
              ``nn.Parameter``s; otherwise they are fixed tensors of ones and
              zeros that are not optimized.
      """

      def __init__(self, dimension, eps=1e-8, trainable=True):
          super(cLN, self).__init__()
          self.eps = eps
          if trainable:
              self.gain = nn.Parameter(torch.ones(1, dimension, 1))
              self.bias = nn.Parameter(torch.zeros(1, dimension, 1))
          else:
              # Plain tensors (requires_grad=False); they are cast to the
              # input's type in forward().
              self.gain = torch.ones(1, dimension, 1)
              self.bias = torch.zeros(1, dimension, 1)

      def forward(self, input):
          """Normalize ``input`` of shape (Batch, Freq, Time) cumulatively over time.

          Returns:
              Tensor with the same shape as ``input``.
          """
          channel = input.size(1)
          time_step = input.size(2)

          # Per-frame sums over channels, then running sums over time.
          step_sum = input.sum(1)  # B, T
          step_pow_sum = input.pow(2).sum(1)  # B, T
          cum_sum = torch.cumsum(step_sum, dim=1)  # B, T
          cum_pow_sum = torch.cumsum(step_pow_sum, dim=1)  # B, T

          # Number of entries seen up to each frame: C, 2C, ..., T*C.
          # Built directly on the input's device instead of NumPy + a
          # host-to-device copy on every call.
          entry_cnt = (torch.arange(1, time_step + 1, device=input.device) * channel).to(input.dtype)
          entry_cnt = entry_cnt.view(1, -1).expand_as(cum_sum)

          cum_mean = cum_sum / entry_cnt  # B, T
          # Equals E[x^2] - E[x]^2. Float round-off can push it slightly below
          # zero (e.g. near-constant input), which would give NaN after sqrt
          # whenever the error exceeds eps, so clamp it.
          cum_var = (cum_pow_sum - 2 * cum_mean * cum_sum) / entry_cnt + cum_mean.pow(2)  # B, T
          cum_var = cum_var.clamp(min=0)
          cum_std = (cum_var + self.eps).sqrt()  # B, T

          cum_mean = cum_mean.unsqueeze(1)
          cum_std = cum_std.unsqueeze(1)
          x = (input - cum_mean.expand_as(input)) / cum_std.expand_as(input)
          return x * self.gain.expand_as(x).type(x.type()) + self.bias.expand_as(x).type(x.type())
  def repackage_hidden(h):
      """Detach hidden state(s) from their autograd history.

      Args:
          h: a tensor, or an arbitrarily nested tuple/list of tensors (for
              example the ``(h, c)`` pair returned by an LSTM).

      Returns:
          The same structure with every tensor replaced by ``tensor.detach()``
          (shares storage, ``requires_grad=False``). Sequences are returned as
          tuples.
      """
      # Since PyTorch 0.4 Variable and Tensor are merged and ``type(h)`` is
      # ``torch.Tensor``, never ``Variable``. The old ``type(h) == Variable``
      # test therefore always failed, and the function recursed into the
      # tensor's elements until it crashed on a 0-d tensor.
      if isinstance(h, torch.Tensor):
          return h.detach()
      return tuple(repackage_hidden(v) for v in h)
  class MultiRNN(nn.Module):
      """Stack of ``num_layers`` recurrent layers built from ``torch.nn``.

      Args:
          rnn_type: name of the ``torch.nn`` recurrent class: 'RNN', 'LSTM' or 'GRU'.
          input_size: feature dimension of the input, which is shaped
              (batch, seq_len, input_size) (``batch_first=True``).
          hidden_size: size of the hidden state of each direction.
          dropout: dropout between stacked layers, passed to the RNN. Default 0.
          num_layers: number of stacked layers. Default 1.
          bidirectional: whether the layers are bidirectional. Default False.

      ``forward`` always starts from an all-zero hidden state and returns
      whatever the wrapped RNN returns: ``(output, h_n)``, or
      ``(output, (h_n, c_n))`` for an LSTM.
      """

      def __init__(self, rnn_type, input_size, hidden_size, dropout=0, num_layers=1, bidirectional=False):
          super(MultiRNN, self).__init__()
          rnn_cls = getattr(nn, rnn_type)
          self.rnn = rnn_cls(
              input_size,
              hidden_size,
              num_layers,
              dropout=dropout,
              batch_first=True,
              bidirectional=bidirectional,
          )
          self.rnn_type = rnn_type
          self.hidden_size = hidden_size
          self.num_layers = num_layers
          self.num_direction = int(bidirectional) + 1

      def forward(self, input):
          initial_state = self.init_hidden(input.size(0))
          # Compacts the weights for the cuDNN fast path; a no-op on CPU.
          self.rnn.flatten_parameters()
          return self.rnn(input, initial_state)

      def init_hidden(self, batch_size):
          """Return an all-zero initial state with the weights' dtype and device."""
          reference = next(self.parameters()).data
          state_shape = (self.num_layers * self.num_direction, batch_size, self.hidden_size)
          if self.rnn_type == 'LSTM':
              # LSTMs carry both a hidden state and a cell state.
              return reference.new_zeros(state_shape), reference.new_zeros(state_shape)
          return reference.new_zeros(state_shape)
  class FCLayer(nn.Module):
      """Single fully-connected layer with an optional functional nonlinearity.

      Args:
          input_size: input feature dimension; input shaped (batch, input_size).
          hidden_size: output dimension; output shaped (batch, hidden_size).
          bias: whether the linear layer has a bias term. Default True.
          nonlinearity: name of a function in ``torch.nn.functional`` (for
              example 'relu' or 'tanh') applied to the linear output, or a
              falsy value for none. Default None.
      """

      def __init__(self, input_size, hidden_size, bias=True, nonlinearity=None):
          super(FCLayer, self).__init__()
          self.input_size = input_size
          self.hidden_size = hidden_size
          self.bias = bias
          self.FC = nn.Linear(input_size, hidden_size, bias=bias)
          self.nonlinearity = getattr(F, nonlinearity) if nonlinearity else None
          self.init_hidden()

      def forward(self, input):
          projected = self.FC(input)
          if self.nonlinearity is None:
              return projected
          return self.nonlinearity(projected)

      def init_hidden(self):
          """Re-initialize the weights uniformly in +-1/sqrt(input_size * hidden_size) and zero the bias."""
          bound = 1. / np.sqrt(self.input_size * self.hidden_size)
          self.FC.weight.data.uniform_(-bound, bound)
          if self.bias:
              self.FC.bias.data.fill_(0)
  # Depthwise-separable 1-D convolution block used inside the TCN:
  # 1x1 conv -> PReLU -> norm -> dilated depthwise conv -> PReLU -> norm,
  # followed by 1x1 output convs for the residual and (optionally) skip paths.
  class DepthConv1d(nn.Module):
      """One TCN block built around a dilated depthwise convolution.

      Args:
          input_channel: channels of the block input and of the residual/skip outputs.
          hidden_channel: channels inside the block.
          kernel: depthwise-convolution kernel size.
          padding: depthwise-convolution padding in the non-causal case
              (ignored when ``causal`` is True).
          dilation: depthwise-convolution dilation. Default 1.
          skip: if True, also compute a skip-connection output. Default True.
          causal: if True, pad by ``(kernel - 1) * dilation`` and trim that many
              trailing frames after the depthwise conv, and use cLN instead of
              GroupNorm for normalization. Default False.

      ``forward`` returns ``(residual, skip)`` when ``skip`` is True, otherwise
      just ``residual``.
      """

      def __init__(self, input_channel, hidden_channel, kernel, padding, dilation=1, skip=True, causal=False):
          super(DepthConv1d, self).__init__()
          self.causal = causal
          self.skip = skip
          self.conv1d = nn.Conv1d(input_channel, hidden_channel, 1)
          if self.causal:
              self.padding = (kernel - 1) * dilation
          else:
              self.padding = padding
          # groups == channels makes this a depthwise convolution.
          self.dconv1d = nn.Conv1d(hidden_channel, hidden_channel, kernel, dilation=dilation,
                                   groups=hidden_channel,
                                   padding=self.padding)
          self.res_out = nn.Conv1d(hidden_channel, input_channel, 1)
          self.nonlinearity1 = nn.PReLU()
          self.nonlinearity2 = nn.PReLU()
          if self.causal:
              self.reg1 = cLN(hidden_channel, eps=1e-08)
              self.reg2 = cLN(hidden_channel, eps=1e-08)
          else:
              self.reg1 = nn.GroupNorm(1, hidden_channel, eps=1e-08)
              self.reg2 = nn.GroupNorm(1, hidden_channel, eps=1e-08)
          if self.skip:
              self.skip_out = nn.Conv1d(hidden_channel, input_channel, 1)

      def forward(self, input):
          output = self.reg1(self.nonlinearity1(self.conv1d(input)))
          output = self.dconv1d(output)
          if self.causal and self.padding > 0:
              # Drop the frames produced by the symmetric padding on the right
              # so no output depends on future input. Guarded because slicing
              # with ``:-0`` would return an empty tensor (kernel == 1).
              output = output[:, :, :-self.padding]
          output = self.reg2(self.nonlinearity2(output))
          residual = self.res_out(output)
          if self.skip:
              skip = self.skip_out(output)
              return residual, skip
          return residual
  class TCN(nn.Module):
      """Temporal convolutional network used as the Conv-TasNet separator.

      Args:
          input_dim: channels N of the input features, shaped (B, N, L).
          output_dim: channels of the output, shaped (B, output_dim, L).
          BN_dim: bottleneck channels produced by the input 1x1 conv.
          hidden_dim: hidden channels inside each DepthConv1d block.
          layer: blocks per stack; block ``i`` uses dilation ``2**i`` when
              ``dilated`` is True.
          stack: number of repeated stacks.
          kernel: depthwise kernel size. The non-causal blocks use "same"
              padding, which preserves length for odd kernels. Default 3.
          skip: if True, the summed skip outputs of all blocks feed the output
              layer; otherwise the final residual stream does. Default True.
          causal: use cLN and causal convolutions instead of GroupNorm and
              symmetric padding. Default False.
          dilated: exponentially increasing dilation (True) or dilation 1
              everywhere (False). Default True.

      Attributes:
          receptive_field: receptive field of the block stack, in frames.
      """

      def __init__(self, input_dim, output_dim, BN_dim, hidden_dim,
                   layer, stack, kernel=3, skip=True,
                   causal=False, dilated=True):
          super(TCN, self).__init__()
          # normalization
          if not causal:
              self.LN = nn.GroupNorm(1, input_dim, eps=1e-8)
          else:
              self.LN = cLN(input_dim, eps=1e-8)
          self.BN = nn.Conv1d(input_dim, BN_dim, 1)

          # TCN blocks for feature extraction
          self.receptive_field = 0
          self.dilated = dilated
          self.TCN = nn.ModuleList([])
          for s in range(stack):
              for i in range(layer):
                  dilation = 2 ** i if self.dilated else 1
                  # "Same" padding for the non-causal blocks. For the default
                  # kernel=3 this equals the previous hard-coded 2**i (or 1),
                  # and it keeps the sequence length for any odd kernel so the
                  # residual additions in forward() line up. Causal blocks
                  # compute their own padding and ignore this value.
                  padding = (kernel - 1) // 2 * dilation
                  self.TCN.append(DepthConv1d(BN_dim, hidden_dim, kernel, dilation=dilation,
                                              padding=padding, skip=skip, causal=causal))
                  if i == 0 and s == 0:
                      self.receptive_field += kernel
                  else:
                      self.receptive_field += (kernel - 1) * dilation

          # output layer
          self.output = nn.Sequential(nn.PReLU(),
                                      nn.Conv1d(BN_dim, output_dim, 1)
                                      )
          self.skip = skip

      def forward(self, input):
          """Map (B, input_dim, L) features to (B, output_dim, L)."""
          output = self.BN(self.LN(input))
          if self.skip:
              skip_connection = 0.
              for block in self.TCN:
                  residual, skip = block(output)
                  output = output + residual
                  skip_connection = skip_connection + skip
              return self.output(skip_connection)
          for block in self.TCN:
              output = output + block(output)
          return self.output(output)

2、Conv-TasNet(注意:下面的代码直接使用了第 1 节定义的 TCN 类,但没有导入它;如果把两段代码保存为不同的文件,需要先从第 1 节代码所在的模块中导入 TCN)

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.autograd import Variable
  # Conv-TasNet
  class TasNet(nn.Module):
      """Conv-TasNet: learned encoder -> TCN mask estimator -> learned decoder.

      Args:
          enc_dim: number of encoder filters N. Default 512.
          feature_dim: TCN bottleneck channels; TCN hidden channels are 4x this. Default 128.
          sr: sample rate in Hz. Default 16000.
          win: encoder window length in milliseconds, converted to
              ``int(sr * win / 1000)`` samples; the hop is half a window. Default 2.
          layer: TCN blocks per stack. Default 8.
          stack: number of TCN stacks. Default 3.
          kernel: TCN depthwise kernel size. Default 3.
          num_spk: number of sources to separate. Default 2.
          causal: build the causal TCN variant. Default False.
      """

      def __init__(self, enc_dim=512, feature_dim=128, sr=16000, win=2, layer=8, stack=3,
                   kernel=3, num_spk=2, causal=False):
          super(TasNet, self).__init__()
          # hyper parameters
          self.num_spk = num_spk
          self.enc_dim = enc_dim
          self.feature_dim = feature_dim
          self.win = int(sr * win / 1000)
          self.stride = self.win // 2
          self.layer = layer
          self.stack = stack
          self.kernel = kernel
          self.causal = causal

          # waveform -> (B, N, L) encoder
          self.encoder = nn.Conv1d(1, enc_dim, self.win, bias=False, stride=self.stride)
          # mask estimator producing num_spk masks of enc_dim channels each
          self.TCN = TCN(enc_dim, enc_dim * num_spk, feature_dim, feature_dim * 4,
                         layer, stack, kernel, causal=causal)
          self.receptive_field = self.TCN.receptive_field
          # (B*C, N, L) -> waveform decoder
          self.decoder = nn.ConvTranspose1d(enc_dim, 1, self.win, bias=False, stride=self.stride)

      def pad_signal(self, input):
          """Zero-pad waveforms of shape (B, T) or (B, 1, T) for the encoder.

          Appends ``rest`` zeros at the end, then ``self.stride`` zeros on both
          sides.

          Returns:
              ``(padded, rest)`` where ``padded`` has shape
              (B, 1, T + rest + 2 * stride).

          Raises:
              RuntimeError: if ``input`` is neither 2- nor 3-dimensional.
          """
          if input.dim() not in [2, 3]:
              raise RuntimeError("Input can only be 2 or 3 dimensional.")
          if input.dim() == 2:
              input = input.unsqueeze(1)
          n_batch = input.size(0)
          n_samples = input.size(2)

          rest = self.win - (self.stride + n_samples % self.win) % self.win
          if rest > 0:
              tail = torch.zeros(n_batch, 1, rest).type(input.type())
              input = torch.cat([input, tail], 2)

          edge = torch.zeros(n_batch, 1, self.stride).type(input.type())
          padded = torch.cat([edge, input, edge], 2)
          return padded, rest

      def forward(self, input):
          """Separate a (B, T) or (B, 1, T) mixture into (B, num_spk, T') waveforms."""
          padded, rest = self.pad_signal(input)
          n_batch = padded.size(0)

          encoded = self.encoder(padded)  # B, N, L

          # One sigmoid mask per speaker, applied to the shared encoding.
          mask_logits = self.TCN(encoded)
          masks = torch.sigmoid(mask_logits).view(n_batch, self.num_spk, self.enc_dim, -1)  # B, C, N, L
          masked = encoded.unsqueeze(1) * masks  # B, C, N, L

          decoded = self.decoder(masked.view(n_batch * self.num_spk, self.enc_dim, -1))  # B*C, 1, L
          # Strip the stride-length padding on both sides plus the `rest`
          # samples appended at the end by pad_signal. With the default even
          # window this restores the input's number of samples.
          decoded = decoded[:, :, self.stride:-(rest + self.stride)].contiguous()  # B*C, 1, T
          return decoded.view(n_batch, self.num_spk, -1)  # B, C, T
  def test_conv_tasnet():
      """Smoke test: run TasNet on a random batch of two 32000-sample signals.

      Prints the output shape of the first batch item, then the name and
      shape of every model parameter.
      """
      mixture = torch.rand(2, 32000)
      model = TasNet()
      separated = model(mixture)
      first_item = separated[0]
      print(first_item.shape)
      for param_name, param in model.named_parameters():
          print(param_name, param.shape)


  if __name__ == "__main__":
      test_conv_tasnet()

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/372362
推荐阅读
相关标签
  

闽ICP备14008679号