Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement transform precoding (DFT-s-OFDM) #459

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions doc/source/api/nr.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,18 @@ PUSCHTransmitter
:exclude-members: build, call
:members:

PUSCHTransformDeprecoder
------------------------
.. autoclass:: sionna.nr.PUSCHTransformDeprecoder
:exclude-members: call
:members:

PUSCHTransformPrecoder
----------------------
.. autoclass:: sionna.nr.PUSCHTransformPrecoder
:exclude-members: call
:members:


Transport Block
***************
Expand Down
9 changes: 9 additions & 0 deletions sionna/mimo/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ class LinearDetector(Layer):
constellation point indices instead of soft-values.
Defaults to `False`.

post_equalizer_transformation: None or Layer
Optional layer that applies a transformation after the equalizer and
before the demapper. This can be used to apply transform precoding
when DFT-s-OFDM is enabled in NR PUSCH.

dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
The dtype of ``y``. Defaults to tf.complex64.
The output dtype is the corresponding real dtype (tf.float32 or tf.float64).
Expand Down Expand Up @@ -96,11 +101,13 @@ def __init__(self,
num_bits_per_symbol=None,
constellation=None,
hard_out=False,
post_equalizer_transformation=None,
dtype=tf.complex64,
**kwargs):
super().__init__(dtype=dtype, **kwargs)
self._output = output
self._hard_out = hard_out
self._post_equalizer_transformation = post_equalizer_transformation

# Determine the equalizer to use
if isinstance(equalizer, str):
Expand Down Expand Up @@ -137,6 +144,8 @@ def __init__(self,

def call(self, inputs):
x_hat, no_eff = self._equalizer(*inputs)
if self._post_equalizer_transformation is not None:
x_hat, no_eff = self._post_equalizer_transformation(x_hat, no_eff)
z = self._demapper([x_hat, no_eff])

# Reshape to the expected output shape
Expand Down
3 changes: 2 additions & 1 deletion sionna/nr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
from .pusch_dmrs_config import PUSCHDMRSConfig
from .pusch_pilot_pattern import PUSCHPilotPattern
from .pusch_precoder import PUSCHPrecoder
from .pusch_transform_precoder import PUSCHTransformPrecoder, PUSCHTransformDeprecoder
from .pusch_transmitter import PUSCHTransmitter
from .pusch_receiver import PUSCHReceiver
from .pusch_channel_estimation import PUSCHLSChannelEstimator
from .tb_config import TBConfig
from .utils import generate_prng_seq, select_mcs, calculate_tb_size
from .utils import generate_prng_seq, generate_low_papr_seq_type_1, select_mcs, calculate_tb_size
from .tb_encoder import TBEncoder
from .tb_decoder import TBDecoder
from .layer_mapping import LayerMapper, LayerDemapper
120 changes: 103 additions & 17 deletions sionna/nr/pusch_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
"""
# pylint: disable=line-too-long

import functools
import numpy as np
from .utils import generate_prng_seq
from .utils import generate_prng_seq, generate_low_papr_seq_type_1
from .config import Config
from sionna import nr
from .utils import calculate_tb_size
Expand Down Expand Up @@ -233,7 +234,7 @@ def n_rnti(self, value):
assert value in range(65536), "n_rnti must be in [0, 65535]"
self._n_rnti = value

#---transform_precoding---#
#---precoding---#
@property
def precoding(self):
"""
Expand Down Expand Up @@ -427,9 +428,9 @@ def n(self):
used for DMRS generation
"""
if self.dmrs.config_type==1:
n_max = self.num_resource_blocks*12//4 -1
n_max = self.num_effective_subcarriers//4 -1
elif self.dmrs.config_type==2:
n_max = self.num_resource_blocks*12//6 -1
n_max = self.num_effective_subcarriers//6 -1
return list(range(n_max+1))

@property
Expand All @@ -450,6 +451,31 @@ def num_resource_blocks(self):
else:
return self.n_size_bwp

@property
def num_effective_resource_blocks(self):
"""
int, read-only : Number of allocated resource blocks for the
PUSCH transmissions, that are actually used (can differ from
num_subcarriers when transform precoding is enabled,
because of constraints on the largest prime factor of the
subcarrier count)
"""
@functools.lru_cache
def adjust_prbs_to_prime_factor_constraints(prbs):
# Decreases the number of PRBs until the largest prime factor is at most 5
for eff_prbs in range(prbs, 1, -1):
n = eff_prbs
for p in [2, 3, 5]:
while n % p == 0:
n /= p
if n == 1:
return eff_prbs

if self.transform_precoding:
return adjust_prbs_to_prime_factor_constraints(self.num_resource_blocks)
else:
return self.num_resource_blocks

@property
def num_subcarriers(self):
"""
Expand All @@ -458,6 +484,17 @@ def num_subcarriers(self):
"""
return 12*self.num_resource_blocks

@property
def num_effective_subcarriers(self):
"""
int, read-only : Number of allocated subcarriers for the
PUSCH transmissions, that are actually used (can differ from
num_subcarriers when transform precoding is enabled,
because of constraints on the largest prime factor of the
subcarrier count)
"""
return 12 * self.num_effective_resource_blocks

@property
def num_res_per_prb(self):
"""
Expand Down Expand Up @@ -488,7 +525,7 @@ def dmrs_mask(self):
resource elements in the resource grid. `True` corresponds to
resource elements on which no data is transmitted.
"""
mask = np.zeros([self.num_subcarriers,
mask = np.zeros([self.num_effective_subcarriers,
self.carrier.num_symbols_per_slot],
dtype=bool)

Expand All @@ -503,7 +540,7 @@ def dmrs_mask(self):
cdm_ind[:,i] = np.array([0,1, 6, 7])+2*i

for i in self.dmrs_symbol_indices:
for j in range(self.num_resource_blocks):
for j in range(self.num_effective_resource_blocks):
for k in range(num_cdm_groups):
mask[cdm_ind[:, k] + 12*j, i] = True
return mask
Expand All @@ -518,7 +555,7 @@ def dmrs_grid(self):
This property returns for each configured DMRS port an empty
resource grid filled with DMRS signals as defined in
Section 6.4.1.1 [3GPP38211]. Not all possible options are implemented,
e.g., frequency hopping and transform precoding are not available.
e.g., frequency hopping is not available.

This property provides the *unprecoded* DMRS for each configured DMRS port.
Precoding might be applied to map the DMRS to the antenna ports. However,
Expand All @@ -536,7 +573,7 @@ def dmrs_grid(self):

# Generate empty resource grid for each port
a_tilde = np.zeros([len(self.dmrs.dmrs_port_set),
self.num_subcarriers,
self.num_effective_subcarriers,
self.carrier.num_symbols_per_slot],
dtype=complex)

Expand All @@ -546,15 +583,23 @@ def dmrs_grid(self):
# For every l_prime
for l_prime in self.l_prime:

# Compute c_init
l = l_bar + l_prime
c_init = self.c_init(l)

# Generate RNG
c = generate_prng_seq(2*self.num_subcarriers, c_init=c_init)
if self.transform_precoding:
if self.dmrs.n_sid is None:
n_id = self.carrier.n_cell_id
else:
n_id = self.dmrs.n_sid
r = generate_low_papr_seq_type_1(self.num_effective_subcarriers // 2, n_id % 30, 0, 0)
else:
# Compute c_init
c_init = self.c_init(l)

# Generate RNG
c = generate_prng_seq(2*self.num_effective_subcarriers, c_init=c_init)

# Map to QAM
r = 1/np.sqrt(2)*((1-2*c[::2]) + 1j*(1-2*c[1::2]))
# Map to QAM
r = 1/np.sqrt(2)*((1-2*c[::2]) + 1j*(1-2*c[1::2]))

# For every port in the dmrs port set
for j_ind, _ in enumerate(self.dmrs.dmrs_port_set):
Expand Down Expand Up @@ -625,8 +670,38 @@ def precoding_matrix(self):

w /= np.sqrt(2)

# Table 6.3.1.5-2
elif self.transform_precoding and self.num_antenna_ports == 4:
w = np.zeros([28, 4, 1], complex)

# TPMI index 0-7
w[:8,0,0] = [ 1, 0, 0, 0, 1, 1, 1, 1]
w[:8,1,0] = [ 0, 1, 0, 0, 0, 0, 0, 0]
w[:8,2,0] = [ 0, 0, 1, 0, 1, -1, 1j,-1j]
w[:8,3,0] = [ 0, 0, 0, 1, 0, 0, 0, 0]

# TPMI index 8-15
w[8:16,0,0] = [ 0, 0, 0, 0, 1, 1, 1, 1]
w[8:16,1,0] = [ 1, 1, 1, 1, 1, 1, 1, 1]
w[8:16,2,0] = [ 0, 0, 0, 0, 1, 1j, -1,-1j]
w[8:16,3,0] = [ 1, -1, 1j,-1j, -1, 1j, 1,-1j]

# TPMI index 16-23
w[16:24,0,0] = [ 1, 1, 1, 1, 1, 1, 1, 1]
w[16:24,1,0] = [ 1j, 1j, 1j, 1j, -1, -1, -1, -1]
w[16:24,2,0] = [ 1, 1j, -1,-1j, 1, 1j, -1,-1j]
w[16:24,3,0] = [ 1j, 1,-1j, -1, 1,-1j, -1, 1j]

# TPMI index 24-27
w[24:28,0,0] = [ 1, 1, 1, 1]
w[24:28,1,0] = [-1j,-1j,-1j,-1j]
w[24:28,2,0] = [ 1, 1j, -1,-1j]
w[24:28,3,0] = [-1j, -1, 1j, 1]

w /= 2

# Table 6.3.1.5-3
elif self.num_antenna_ports==4:
elif not self.transform_precoding and self.num_antenna_ports==4:
w = np.zeros([28,4,1], complex)

# TPMI index 0-7
Expand Down Expand Up @@ -825,7 +900,7 @@ def num_coded_bits(self):
n_re_per_prb = self.num_res_per_prb - self.num_ov

# number of allocated REs
n_re = n_re_per_prb * self.num_resource_blocks
n_re = n_re_per_prb * self.num_effective_resource_blocks

# total number of bits per slot
num_coded_bits = int(self.tb.tb_scaling * self.tb.num_bits_per_symbol \
Expand All @@ -842,7 +917,7 @@ def tb_size(self):

# number of allocated REs
# the max. number of REs per PRB is limited to 156 in 38.214
n_re = min(156, n_re_per_prb) * self.num_resource_blocks
n_re = min(156, n_re_per_prb) * self.num_effective_resource_blocks

# include tb_scaling as defined in Tab. 5.1.3.2-2 38.214
target_tb_size = int(self.tb.target_coderate * self.tb.tb_scaling \
Expand Down Expand Up @@ -924,6 +999,14 @@ def check_config(self):
assert self.num_layers == self.num_antenna_ports,\
"num_layers must be == num_antenna_ports"

if self.transform_precoding:
assert self.num_layers == 1,\
"When transform precoding is used, only a single MIMO layer is supported"
assert self.dmrs.config_type == 1, \
"When transform precoding is used, DMRS config type must be 1"
assert self.dmrs.num_cdm_groups_without_data == 2, \
"When transform precoding is used, num_cdm_groups_without_data must be 2"

# Check Tables 6.4.1.1.3-3/4 are valid
if self.dmrs.length==1:
if self.mapping_type=="A":
Expand Down Expand Up @@ -1026,18 +1109,21 @@ def check_pusch_configs(pusch_configs):

# Create dictionary with extracted configuration parameters
pc = pusch_configs[0]
pc.tb.transform_precoding = pc.transform_precoding
carrier = pc.carrier

params = {
"num_bits_per_symbol" : pc.tb.num_bits_per_symbol,
"num_tx" : len(pusch_configs),
"num_layers" : pc.num_layers,
"num_subcarriers" : pc.num_subcarriers,
"num_effective_subcarriers": pc.num_effective_subcarriers,
"num_ofdm_symbols" : pc.symbol_allocation[1],
"subcarrier_spacing" : pc.carrier.subcarrier_spacing*1e3,
"num_antenna_ports" : pc.num_antenna_ports,
"precoding" : pc.precoding,
"precoding_matrices" : [],
"transform_precoding" : pc.transform_precoding,
"pusch_config" : pc,
"carrier_config" : pc.carrier,
"num_coded_bits" : pc.num_coded_bits,
Expand Down
19 changes: 17 additions & 2 deletions sionna/nr/pusch_dmrs_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,29 @@ def n_id(self, value):
if value is None:
self._n_id = None
elif isinstance(value, int):
assert value in list(range(65536)), "n_id must be in [0, 65535]"
assert value in range(65536), "n_id must be in [0, 65535]"
self._n_id = [value, value]
else:
assert len(value)==2, "n_id must be either [] or a two-tuple"
for e in value:
assert e in list(range(65536)), "Each element of n_id must be in [0, 65535]"
assert e in range(65536), "Each element of n_id must be in [0, 65535]"
self._n_id = value

#---n_sid---#
@property
def n_sid(self):
r"""
None (default), [0,...,1007] : DMRS scrambling identity for DFT-s-OFDM
:math:`n_\text{ID}^\text{PUSCH}`
"""
self._ifndef("n_sid", None)
return self._n_sid

@n_sid.setter
def n_sid(self, value):
assert value is None or (isinstance(value, int) and value in range(1008)), "n_sid must None or in [0, 1007]"
self._n_sid = value

#---n_scid---#
@property
def n_scid(self):
Expand Down
17 changes: 11 additions & 6 deletions sionna/nr/pusch_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sionna.ofdm import OFDMDemodulator, LinearDetector
from sionna.utils import insert_dims
from sionna.channel import time_to_ofdm_channel
from .pusch_transform_precoder import PUSCHTransformDeprecoder

class PUSCHReceiver(Layer):
# pylint: disable=line-too-long
Expand Down Expand Up @@ -197,14 +198,19 @@ def __init__(self,
# Use or create default MIMODetector
if mimo_detector is None:
# Default MIMO detector
transformation = PUSCHTransformDeprecoder(pusch_transmitter.resource_grid.num_effective_subcarriers,
dtype=dtype) if pusch_transmitter._transform_precoding else None
self._mimo_detector = LinearDetector("lmmse", "bit", "maxlog",
pusch_transmitter.resource_grid,
self._stream_management,
"qam",
pusch_transmitter._num_bits_per_symbol,
dtype=dtype)
pusch_transmitter.resource_grid,
self._stream_management,
"qam",
pusch_transmitter._num_bits_per_symbol,
post_equalizer_transformation=transformation,
dtype=dtype)
else:
# User-provided MIMO detector
if pusch_transmitter._transform_precoding:
print("WARNING: Using custom mimo detector which might not support transform precoding.")
self._mimo_detector = mimo_detector

# Create LayerDemapper
Expand Down Expand Up @@ -248,7 +254,6 @@ def call(self, inputs):
if self._input_domain=="time":
h = time_to_ofdm_channel(h, self.resource_grid, self._l_min)


if self._w is not None:
# Reshape h to put channel matrix dimensions last
# [batch size, num_rx, num_tx, num_ofdm_symbols,...
Expand Down
Loading