mc_conv_tasnet
Adapted from https://github.com/kaituoxu/Conv-TasNet
- class clarity.enhancer.dnn.mc_conv_tasnet.ChannelwiseLayerNorm(channel_size)
Bases: Module
Channel-wise Layer Normalization (cLN)
- forward(y)
- Parameters:
y – [M, N, K], M is batch size, N is channel size, K is length
- Returns:
cLN_y – normalized tensor, [M, N, K]
- reset_parameters()
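As a rough illustration of what cLN computes, the sketch below normalizes each time step of an [M, N, K] tensor over its N channels, then applies a learnable per-channel gain and bias. It is a hypothetical re-implementation modeled on the upstream Conv-TasNet code: the class name ChannelwiseLayerNormSketch and the constant EPS are assumptions, not part of this module.

```python
import torch
import torch.nn as nn

EPS = 1e-8  # assumed numerical-stability constant


class ChannelwiseLayerNormSketch(nn.Module):
    """Hypothetical cLN sketch; not the clarity implementation."""

    def __init__(self, channel_size: int):
        super().__init__()
        # Per-channel gain and bias, broadcast over batch (M) and time (K).
        self.gamma = nn.Parameter(torch.ones(1, channel_size, 1))
        self.beta = nn.Parameter(torch.zeros(1, channel_size, 1))

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: [M, N, K]. Statistics are taken over the channel axis only,
        # so every (batch item, time step) pair gets its own mean and variance.
        mean = y.mean(dim=1, keepdim=True)                  # [M, 1, K]
        var = ((y - mean) ** 2).mean(dim=1, keepdim=True)   # [M, 1, K]
        return self.gamma * (y - mean) / torch.sqrt(var + EPS) + self.beta


y = torch.randn(2, 4, 10)                 # M=2, N=4, K=10
cln_y = ChannelwiseLayerNormSketch(4)(y)
print(cln_y.shape)                        # torch.Size([2, 4, 10])
```

Because each frame is normalized using only its own channels, cLN never looks ahead in time, which makes it usable in causal models.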
- class clarity.enhancer.dnn.mc_conv_tasnet.Chomp1d(chomp_size)
Bases: Module
Trims chomp_size padded frames from the end of the sequence so that the output length equals the input length.
- forward(x)
- Parameters:
x – [M, H, Kpad]
- Returns:
[M, H, K]
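Chomp1d is typically paired with a causal dilated convolution. In the sketch below (the helper chomp1d and all sizes are hypothetical), a Conv1d padded by (P - 1) * d frames on both sides yields Kpad = K + (P - 1) * d frames; slicing that many frames off the end restores length K and leaves each output depending only on current and past inputs. It assumes, as in the upstream Conv-TasNet, that the padding is removed from the end of the sequence.

```python
import torch
import torch.nn as nn


def chomp1d(x: torch.Tensor, chomp_size: int) -> torch.Tensor:
    """Hypothetical helper: [M, H, Kpad] -> [M, H, Kpad - chomp_size].

    chomp_size must be > 0, since x[..., :-0] would return an empty tensor.
    """
    return x[:, :, :-chomp_size].contiguous()


torch.manual_seed(0)
M, H, K, P, d = 1, 3, 16, 3, 2
padding = (P - 1) * d                     # 4 frames of padding on each side
conv = nn.Conv1d(H, H, kernel_size=P, padding=padding, dilation=d)

x = torch.randn(M, H, K)
y = chomp1d(conv(x), padding)             # conv output has K + padding = 20 frames
print(y.shape)                            # torch.Size([1, 3, 16])

# Causality check: changing the last input frame must not alter earlier outputs.
x_late = x.clone()
x_late[:, :, -1] += 1.0
y_late = chomp1d(conv(x_late), padding)
print(torch.allclose(y[:, :, :-1], y_late[:, :, :-1]))  # True
```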
- class clarity.enhancer.dnn.mc_conv_tasnet.ConvTasNet(N_spec, N_spat, L, B, H, P, X, R, C, num_channels, norm_type='cLN', causal=False, mask_nonlinear='relu', device=None)
Bases: Module
- forward(mixture)
- Parameters:
mixture – [M, T], M is batch size, T is the number of samples
- Returns:
est_source – estimated sources, [M, C, T]
- class clarity.enhancer.dnn.mc_conv_tasnet.Decoder(N, L, device: str | None = None)
Bases: Module
- forward(mixture_w, est_mask)
- Parameters:
mixture_w – [M, N, K]
est_mask – [M, C, N, K]
- Returns:
est_source – [M, C, T]
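The decoder's internals are not documented above. The following is a sketch under the assumption that it follows the upstream Conv-TasNet design: each of the C masks multiplies the mixture representation, a bias-free linear layer maps every N-dimensional frame back to L samples, and the frames are overlap-added with a hop of L/2 (the hop implied by the encoders' K = (T-L)/(L/2)+1). All sizes and variable names are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
M, C, N, K, L = 2, 1, 8, 5, 4
hop = L // 2

mixture_w = torch.rand(M, N, K)          # encoder output
est_mask = torch.rand(M, C, N, K)        # separator output, one mask per source

# Apply every mask to the shared mixture representation (broadcast over C).
source_w = mixture_w.unsqueeze(1) * est_mask           # [M, C, N, K]

# Map each masked N-dim frame to L samples with a learned linear basis (assumed).
basis = nn.Linear(N, L, bias=False)
frames = basis(source_w.transpose(2, 3))               # [M, C, K, L]

# Overlap-add the K frames with a hop of L/2 samples.
T = (K - 1) * hop + L                                  # inverse of K = (T-L)/(L/2)+1
est_source = frames.new_zeros((M, C, T))
for k in range(K):
    est_source[..., k * hop:k * hop + L] += frames[..., k, :]

print(est_source.shape)                                # torch.Size([2, 1, 12])
```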
- class clarity.enhancer.dnn.mc_conv_tasnet.DepthwiseSeparableConv(in_channels, out_channels, kernel_size, stride, padding, dilation, norm_type='gLN', causal=False)
Bases: Module
- forward(x)
- Parameters:
x – [M, H, K]
- Returns:
result – [M, B, K]
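A depthwise separable convolution factorizes a full convolution into a per-channel (depthwise) convolution, expressed in PyTorch with groups=H, followed by a 1×1 (pointwise) convolution that mixes channels and maps H to B. The sketch below shows only that factorization and its parameter saving; it omits normalization, activation, and causal chomping, and the sizes and "same" padding are assumptions rather than a description of this class.

```python
import torch
import torch.nn as nn

M, H, B, K, P, d = 2, 6, 4, 20, 3, 2
padding = (P - 1) * d // 2   # symmetric "same" padding for the non-causal case (assumed)

# Depthwise: one length-P filter per channel (groups=H), no channel mixing.
depthwise = nn.Conv1d(H, H, kernel_size=P, padding=padding, dilation=d, groups=H, bias=False)
# Pointwise: 1x1 convolution that mixes channels and maps H -> B.
pointwise = nn.Conv1d(H, B, kernel_size=1, bias=False)

x = torch.randn(M, H, K)
result = pointwise(depthwise(x))
print(result.shape)  # torch.Size([2, 4, 20])

n_separable = sum(p.numel() for p in (*depthwise.parameters(), *pointwise.parameters()))
n_full = sum(p.numel() for p in nn.Conv1d(H, B, kernel_size=P, bias=False).parameters())
print(n_separable, n_full)  # 42 72
```

The separable form needs H·P + H·B weights instead of H·B·P, a saving that grows with the kernel size.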
- class clarity.enhancer.dnn.mc_conv_tasnet.GlobalLayerNorm(channel_size)
Bases: Module
Global Layer Normalization (gLN)
- forward(y)
- Parameters:
y – [M, N, K], M is batch size, N is channel size, K is length
- Returns:
gLN_y – normalized tensor, [M, N, K]
- reset_parameters()
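gLN differs from cLN in the axes it pools over: a single mean and variance per batch item, computed across both the channel (N) and time (K) axes. The hypothetical sketch below (the class name and EPS are assumptions) checks this against torch.nn.GroupNorm with one group, which normalizes over the same axes and also applies a per-channel affine transform.

```python
import torch
import torch.nn as nn

EPS = 1e-8  # assumed numerical-stability constant


class GlobalLayerNormSketch(nn.Module):
    """Hypothetical gLN sketch; not the clarity implementation."""

    def __init__(self, channel_size: int):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, channel_size, 1))
        self.beta = nn.Parameter(torch.zeros(1, channel_size, 1))

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: [M, N, K]; one mean/variance per batch item, pooled over N and K.
        mean = y.mean(dim=(1, 2), keepdim=True)                  # [M, 1, 1]
        var = ((y - mean) ** 2).mean(dim=(1, 2), keepdim=True)   # [M, 1, 1]
        return self.gamma * (y - mean) / torch.sqrt(var + EPS) + self.beta


y = torch.randn(2, 4, 10)
gln = GlobalLayerNormSketch(4)
ref = nn.GroupNorm(num_groups=1, num_channels=4, eps=EPS)
print(torch.allclose(gln(y), ref(y), atol=1e-5))  # True
```

Because the statistics include future frames, gLN is not causal.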
- class clarity.enhancer.dnn.mc_conv_tasnet.SpatialEncoder(L, N, num_channels)
Bases: Module
Estimates the nonnegative mixture weights with a 1-D convolutional layer.
- forward(mixture)
- Parameters:
mixture – [M, num_channels, T], M is batch size, T is the number of samples
- Returns:
mixture_w – [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
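The shape rule above can be checked with a sketch that assumes the spatial encoder is a bias-free Conv1d taking all num_channels microphone signals as input channels, with kernel size L and stride L/2 (the 50% overlap that the formula for K implies), followed by a ReLU to keep the weights nonnegative. These layer details are assumptions, not taken from the clarity source.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

M, num_channels, T, L, N = 2, 6, 160, 16, 32

# Assumed layer: every output filter spans all microphones and L samples.
conv = nn.Conv1d(num_channels, N, kernel_size=L, stride=L // 2, bias=False)

mixture = torch.randn(M, num_channels, T)      # [M, num_channels, T]
mixture_w = F.relu(conv(mixture))              # ReLU keeps the weights nonnegative
K = 2 * T // L - 1                             # 2T/L - 1 = 19

print(mixture_w.shape)                         # torch.Size([2, 32, 19])
```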
- class clarity.enhancer.dnn.mc_conv_tasnet.SpectralEncoder(L, N)
Bases: Module
Estimates the nonnegative mixture weights with a 1-D convolutional layer.
- forward(mixture)
- Parameters:
mixture – [M, T], M is batch size, T is the number of samples
- Returns:
mixture_w – [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
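For the single-channel spectral encoder, the same assumed Conv1d-plus-ReLU structure (here with one input channel) illustrates a caveat of the closed form: K = 2T/L - 1 is exact only when T is a multiple of L/2. Otherwise a strided convolution yields floor((T - L)/(L/2)) + 1 frames and the leftover trailing samples are dropped. Layer details and sizes are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

L, N = 16, 32
conv = nn.Conv1d(1, N, kernel_size=L, stride=L // 2, bias=False)  # assumed layer

results = []
for T in (160, 100):
    mixture = torch.randn(2, T)                        # [M, T]
    mixture_w = F.relu(conv(mixture.unsqueeze(1)))     # [M, 1, T] -> [M, N, K]
    K_formula = 2 * T / L - 1
    results.append((T, mixture_w.shape[-1], K_formula))
    print(T, mixture_w.shape[-1], K_formula)
# 160 19 19.0
# 100 11 11.5
```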
- class clarity.enhancer.dnn.mc_conv_tasnet.TemporalBlock(in_channels, out_channels, kernel_size, stride, padding, dilation, norm_type='gLN', causal=False)
Bases: Module
- forward(x)
- Parameters:
x – [M, B, K]
- Returns:
[M, B, K]
- class clarity.enhancer.dnn.mc_conv_tasnet.TemporalConvNet(N_spec, N_spat, B, H, P, X, R, C, norm_type='gLN', causal=False, mask_nonlinear='relu')
Bases: Module
- forward(mixture_w)
Keeps the same API as TasNet.
- Parameters:
mixture_w – [M, N, K], M is batch size
- Returns:
est_mask – [M, C, N, K]
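The parameters X, R, and P determine how much temporal context the separator sees. Assuming the standard Conv-TasNet schedule, in which each of the R repeats stacks X blocks with kernel size P and dilations 1, 2, 4, ..., 2^(X-1), the receptive field in encoder frames is 1 + R(P-1)(2^X - 1). The helpers below are hypothetical and only do this arithmetic.

```python
def dilations(X: int, R: int) -> list[int]:
    """Dilation of every block, repeat by repeat (assumed 2**x schedule)."""
    return [2 ** x for _ in range(R) for x in range(X)]


def receptive_field(P: int, X: int, R: int) -> int:
    """Frames visible to one output frame: each block adds (P - 1) * dilation."""
    return 1 + sum((P - 1) * d for d in dilations(X, R))


print(dilations(X=3, R=2))              # [1, 2, 4, 1, 2, 4]
print(receptive_field(P=3, X=8, R=3))   # 1531, i.e. 1 + 3 * 2 * (2**8 - 1)
```

With an encoder hop of L/2 samples, 1531 frames span (1531 - 1)·L/2 + L samples; for L = 16 that is 12,256 samples, about 0.77 s at 16 kHz.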
- clarity.enhancer.dnn.mc_conv_tasnet.chose_norm(norm_type, channel_size)
The input to the normalization has shape (M, C, K), where M is batch size, C is channel size and K is sequence length.
- Parameters:
norm_type – normalization type, e.g. 'gLN' or 'cLN'
channel_size – number of channels C
- Returns:
the normalization layer to apply
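A sketch of the dispatch, assuming the upstream Conv-TasNet convention that 'gLN' and 'cLN' select the global and channel-wise layer norms and any other value falls back to batch normalization. To stay self-contained it substitutes built-in PyTorch layers for GlobalLayerNorm and ChannelwiseLayerNorm; the function chose_norm_sketch and the class _ChannelLayerNorm are hypothetical.

```python
import torch
import torch.nn as nn


class _ChannelLayerNorm(nn.Module):
    """Hypothetical cLN stand-in: LayerNorm over channels at each time step."""

    def __init__(self, channel_size: int):
        super().__init__()
        self.ln = nn.LayerNorm(channel_size, eps=1e-8)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # [M, C, K] -> [M, K, C] so LayerNorm sees channels last, then back.
        return self.ln(y.transpose(1, 2)).transpose(1, 2)


def chose_norm_sketch(norm_type: str, channel_size: int) -> nn.Module:
    if norm_type == "gLN":
        # One group: statistics pooled over channels and time, like gLN.
        return nn.GroupNorm(1, channel_size, eps=1e-8)
    if norm_type == "cLN":
        return _ChannelLayerNorm(channel_size)
    # Fallback (assumed): BatchNorm1d accepts [M, C, K] and normalizes over M and K.
    return nn.BatchNorm1d(channel_size)


y = torch.randn(3, 5, 7)  # (M, C, K)
for norm_type in ("gLN", "cLN", "BN"):
    # Every option preserves the (M, C, K) shape.
    print(norm_type, chose_norm_sketch(norm_type, 5)(y).shape)
```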
- clarity.enhancer.dnn.mc_conv_tasnet.overlap_and_add(signal, frame_step, device)
Reconstructs a signal from a framed representation.
Adds potentially overlapping frames of a signal with shape […, frames, frame_length], offsetting subsequent frames by frame_step. The resulting tensor has shape […, output_size] where
output_size = (frames - 1) * frame_step + frame_length
- Parameters:
signal – A […, frames, frame_length] Tensor. All dimensions may be unknown, and rank must be at least 2.
frame_step – An integer denoting overlap offsets. Must be less than or equal to frame_length.
device – Device on which to run the computation, ‘cpu’ or ‘cuda’.
- Returns:
A Tensor with shape […, output_size] containing the overlap-added frames of signal’s inner-most two dimensions.
- Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
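The definition above can be checked with a plain loop. The sketch below (overlap_and_add_sketch is a hypothetical name; unlike the documented function it takes no device argument and simply allocates its output on signal's device) adds each frame into the output at offset f * frame_step, so samples covered by several frames are summed.

```python
import torch


def overlap_and_add_sketch(signal: torch.Tensor, frame_step: int) -> torch.Tensor:
    """Loop-based overlap-add over the last two dims of `signal`."""
    *outer, num_frames, frame_length = signal.shape
    output_size = (num_frames - 1) * frame_step + frame_length
    out = signal.new_zeros((*outer, output_size))
    for f in range(num_frames):
        start = f * frame_step
        out[..., start:start + frame_length] += signal[..., f, :]
    return out


framed = torch.ones(1, 4, 4)          # [..., frames=4, frame_length=4]
out = overlap_and_add_sketch(framed, frame_step=2)
print(out)
# tensor([[1., 1., 2., 2., 2., 2., 2., 2., 1., 1.]])
```

With frame_step equal to half of frame_length, every interior sample is covered by exactly two frames, which is why the middle of the output is 2 for all-ones frames; output_size = (4 - 1) * 2 + 4 = 10.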