[WIP] First try at TNG writing #18

Open · wants to merge 9 commits into base: master
Changes from all commits
230 changes: 226 additions & 4 deletions pytng/pytng.pyx
@@ -15,8 +15,12 @@ np.import_array()
ctypedef enum tng_function_status: TNG_SUCCESS, TNG_FAILURE, TNG_CRITICAL
ctypedef enum tng_data_type: TNG_CHAR_DATA, TNG_INT_DATA, TNG_FLOAT_DATA, TNG_DOUBLE_DATA
ctypedef enum tng_hash_mode: TNG_SKIP_HASH, TNG_USE_HASH
ctypedef enum tng_block_type: TNG_NON_TRAJECTORY_BLOCK, TNG_TRAJECTORY_BLOCK
ctypedef enum tng_compression: TNG_UNCOMPRESSED, TNG_XTC_COMPRESSION, TNG_TNG_COMPRESSION, TNG_GZIP_COMPRESSION
ctypedef enum tng_particle_dependency: TNG_NON_PARTICLE_BLOCK_DATA, TNG_PARTICLE_BLOCK_DATA

cdef long long TNG_TRAJ_BOX_SHAPE = 0x0000000010000000LL
cdef long long TNG_TRAJ_POSITIONS = 0x0000000010000001LL

status_error_message = ['OK', 'Failure', 'Critical']

@@ -28,6 +32,7 @@ cdef extern from "tng/tng_io.h":
const char *filename,
const char mode,
tng_trajectory_t *tng_data_p)

tng_function_status tng_util_trajectory_close(
tng_trajectory_t *tng_data_p)

@@ -66,6 +71,114 @@ cdef extern from "tng/tng_io.h":
int64_t *n_values_per_frame,
char *type)

tng_function_status tng_util_box_shape_write_interval_set(
const tng_trajectory_t tng_data,
const int64_t interval)

tng_function_status tng_util_pos_write_interval_set(
const tng_trajectory_t tng_data,
const int64_t interval)

tng_function_status tng_util_vel_write_interval_set(
const tng_trajectory_t tng_data,
const int64_t interval)

tng_function_status tng_util_force_write_interval_set(
const tng_trajectory_t tng_data,
const int64_t interval)

tng_function_status tng_util_box_shape_write(
const tng_trajectory_t tng_data,
const int64_t frame_nr,
const float *box_shape)

tng_function_status tng_util_box_shape_with_time_write(
const tng_trajectory_t tng_data,
const int64_t frame_nr,
const double time,
const float *box_shape)

tng_function_status tng_util_pos_write(
const tng_trajectory_t tng_data,
const int64_t frame_nr,
const float *positions)

tng_function_status tng_util_pos_with_time_write(
const tng_trajectory_t tng_data,
const int64_t frame_nr,
const double time,
const float *positions)

tng_function_status tng_util_vel_write(
const tng_trajectory_t tng_data,
const int64_t frame_nr,
const float *velocities)

tng_function_status tng_util_vel_with_time_write(
const tng_trajectory_t tng_data,
const int64_t frame_nr,
const double time,
const float *velocities)

tng_function_status tng_util_force_write(
const tng_trajectory_t tng_data,
const int64_t frame_nr,
const float *forces)

tng_function_status tng_util_force_with_time_write(
const tng_trajectory_t tng_data,
const int64_t frame_nr,
const double time,
const float *forces)

tng_function_status tng_implicit_num_particles_set(
const tng_trajectory_t tng_data,
const int64_t n)

tng_function_status tng_num_frames_per_frame_set_set(
const tng_trajectory_t tng_data,
const int64_t n)

tng_function_status tng_util_generic_with_time_write(
const tng_trajectory_t tng_data,
const int64_t frame_nr,
const double time,
const float *values,
const int64_t n_values_per_frame,
const int64_t block_id,
const char *block_name,
const char particle_dependency,
const char compression)

tng_function_status tng_time_per_frame_set(
const tng_trajectory_t tng_data,
const double time)

tng_function_status tng_util_generic_write_interval_set(
const tng_trajectory_t tng_data,
const int64_t i,
const int64_t n_values_per_frame,
const int64_t block_id,
const char *block_name,
const char particle_dependency,
const char compression)


tng_function_status tng_frame_set_write(
const tng_trajectory_t tng_data,
const char hash_mode)

TNGFrame = namedtuple("TNGFrame", "positions time step box")

cdef class TNGFile:
@@ -108,7 +221,6 @@ cdef class TNGFile:
_mode = 'r'
elif self.mode == 'w':
_mode = 'w'
- raise NotImplementedError('Writing is not implemented yet.')
else:
raise IOError('mode must be one of "r" or "w", you '
'supplied {}'.format(mode))
@@ -138,6 +250,9 @@ cdef class TNGFile:
if ok != TNG_SUCCESS:
raise IOError("An error ocurred reading distance unit exponent. {}".format(status_error_message[ok]))
self.distance_scale = 10.0**(exponent+9)
elif self.mode == 'w':
self._n_frames = 0  # No frames have been written yet
# self._n_atoms ?

self.is_open = True
self.step = 0
@@ -146,9 +261,11 @@
def close(self):
"""Make sure the file handle is closed"""
if self.is_open:
tng_util_trajectory_close(&self._traj)
ok = tng_util_trajectory_close(&self._traj)
if_not_ok(ok, "couldn't close")
self.is_open = False
self._n_frames = -1
print("closed file")

def __enter__(self):
# Support context manager
@@ -203,8 +320,8 @@ cdef class TNGFile:
if not self.is_open:
raise IOError('No file opened')
if self.mode != 'r':
- raise IOError('File opened in mode: {}. Reading only allow '
- 'in mode "r"'.format('self.mode'))
raise IOError('File opened in mode: {}. Reading only allowed '
'in mode "r"'.format(self.mode))
if self.step >= self.n_frames:
self.reached_eof = True
raise StopIteration("Reached EOF in read")
@@ -260,10 +377,110 @@ cdef class TNGFile:
finally:
if box_shape != NULL:
free(box_shape)
# DO NOT FREE float_box or double_box here. They point to the same
# memory as box_shape

self.step += 1
return TNGFrame(xyz, time, self.step - 1, box)

def write(self,
np.ndarray[np.float32_t, ndim=2, mode='c'] positions,
np.ndarray[np.float32_t, ndim=2, mode='c'] box,
time=None):
if self.mode != 'w':
raise IOError('File opened in mode: {}. Writing only allowed '
'in mode "w"'.format(self.mode))
if not self.is_open:
raise IOError('No file currently opened')

cdef int64_t ok
cdef np.ndarray[float, ndim=2, mode='c'] xyz
cdef np.ndarray[float, ndim=2, mode='c'] box_contiguous
cdef double dt

if self._n_frames == 0:
# TODO: The number of frames per frame set should be tunable.
ok = tng_num_frames_per_frame_set_set(self._traj, 1)
if_not_ok(ok, 'Could not set the number of frames per frame set')
# The number of atoms must be set either with a full description
# of the system content (topology), or with just the number of
# particles. We should fall back on the latter, but being
# able to write the topology would be a nice addition
# in the future.
self._n_atoms = positions.shape[0]
ok = tng_implicit_num_particles_set(self._traj, self.n_atoms)
if_not_ok(ok, 'Could not set the number of particles')
# Set the writing interval to 1 for all blocks.
ok = tng_util_pos_write_interval_set(self._traj, 1)
if_not_ok(ok, 'Could not set the writing interval for positions')
# When we use the "tn_util_box_shape_*" functions to write the box
# shape, gromacs tools fail to uncompress the data block. Instead of
# using the default gzip to compress the box, we do not compress it.
# ok = tng_util_box_shape_write_interval_set(self._traj, 1)
ok = tng_util_generic_write_interval_set(
self._traj, 1, 9,
TNG_TRAJ_BOX_SHAPE,
"BOX SHAPE",
TNG_NON_PARTICLE_BLOCK_DATA,
TNG_UNCOMPRESSED
)
if_not_ok(ok, 'Could not set the writing interval for the box shape')
elif self.n_atoms != positions.shape[0]:
message = ('Only fixed number of particles is currently supported. '
'Cannot write {} particles instead of {}.'
.format(positions.shape[0], self.n_atoms))
raise NotImplementedError(message)

if time is not None:
try:
time = float(time)  # Make sure time is a real number
# Time is provided to this function in picoseconds,
# but functions from tng_io expect seconds.
time *= 1e-12
except ValueError:
raise ValueError('time must be a real number or None')
# The time per frame has to be set for the time to be written in
# the frames.
# To be able to set an arbitrary time, we need to set the time per
# frame to 0 and to use one frame per frame set. Using the actual
# time difference between consecutive frames can cause issues if
# the difference is negative, or if the difference is 0 and the
# frame is not the first of the frame set.
ok = tng_time_per_frame_set(self._traj, 0)
if_not_ok(ok, 'Could not set the time per frame')

box_contiguous = np.ascontiguousarray(box, dtype=np.float32)
if time is None:
ok = tng_util_box_shape_write(self._traj, self.step,
&box_contiguous[0, 0])
else:
#ok = tng_util_box_shape_with_time_write(self._traj,
# self.step,
# time,
# &box_contiguous[0, 0])
ok = tng_util_generic_with_time_write(
self._traj, self.step, time,
&box[0, 0],
9, TNG_TRAJ_BOX_SHAPE, "BOX SHAPE",
TNG_NON_PARTICLE_BLOCK_DATA,
TNG_UNCOMPRESSED
)
if_not_ok(ok, 'Could not write box shape')

xyz = np.ascontiguousarray(positions, dtype=np.float32)
if time is None:
ok = tng_util_pos_write(self._traj, self.step, &xyz[0, 0])
else:
ok = tng_util_pos_with_time_write(self._traj, self.step,
time, &xyz[0, 0])
Member


Why write the time twice? That seems strange. Maybe we need to ask on the gromacs devel list.

Contributor Author


I do not know. But I am also not sure this is the right way to do it. I am still trying to figure things out.

if_not_ok(ok, 'Could not write positions')

# Finish the frame set so the step is written; hashing should be configurable.
ok = tng_frame_set_write(self._traj, TNG_USE_HASH)
if_not_ok(ok, 'Could not write the frame set')

self.step += 1
self._n_frames += 1

def seek(self, step):
"""Move the file handle to a particular frame number

@@ -319,3 +536,8 @@ cdef class TNGFile:
else:
raise TypeError("Trajectories must be an indexed using an integer,"
" slice or list of indices")


def if_not_ok(ok, message, exception=IOError):
if ok != TNG_SUCCESS:
raise exception(message)
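
For reviewers who want to try the branch, the sketch below shows roughly how the write mode added above is meant to be driven from Python. It is a minimal sketch, not part of the diff: the paths 'ref.tng' and 'out.tng' are placeholders, and it assumes a TNG file that pytng can already read (as in the round-trip test further down).

# Round-trip sketch for the new write mode (not part of the diff;
# 'ref.tng' and 'out.tng' are placeholder paths).
import pytng

with pytng.TNGFile('ref.tng', 'r') as ref, pytng.TNGFile('out.tng', 'w') as out:
    for frame in ref:
        # write() takes C-contiguous float32 arrays: positions of shape
        # (n_atoms, 3) and the box shape as a (3, 3) matrix. According to
        # the comment in write(), time is expected in picoseconds and may
        # be None.
        out.write(frame.positions, frame.box, time=frame.time)

The first call to write() fixes the number of particles and the write intervals, so every later frame must contain the same number of atoms.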
22 changes: 13 additions & 9 deletions tests/test_tng.py
@@ -19,11 +19,6 @@ def test_open_missing_file_mode_r(MISSING_FILEPATH):
assert 'does not exist' in str(excinfo.value)


- def test_open_mode_w(MISSING_FILEPATH):
- with pytest.raises(NotImplementedError):
- pytng.TNGFile(MISSING_FILEPATH, mode='w')


def test_open_invalide_mode(GMX_REF_FILEPATH):
with pytest.raises(IOError) as excinfo:
pytng.TNGFile(GMX_REF_FILEPATH, mode='invalid')
@@ -150,7 +145,6 @@ def test_seek_IndexError(idx, GMX_REF_FILEPATH):
tng[idx]


- @pytest.mark.skip(reason="Write mode not implemented yet.")
def test_seek_write(MISSING_FILEPATH):
with pytng.TNGFile(MISSING_FILEPATH, mode='w') as tng:
with pytest.raises(IOError) as excinfo:
@@ -210,9 +204,19 @@ def test_read_not_open(GMX_REF_FILEPATH):
assert 'No file opened' in str(excinfo.value)


- @pytest.mark.skip(reason="Write mode not implemented yet.")
def test_read_not_mode_r(MISSING_FILEPATH):
- with pytest.raises(IOError) as excinfo:
with pytest.raises(IOError, match='Reading only allow'):
with pytng.TNGFile(MISSING_FILEPATH, mode='w') as tng:
tng.read()
- assert 'Reading only allow in mode "r"' in str(excinfo.value)


def test_writing(GMX_REF_FILEPATH, tmpdir):
outfile = str(tmpdir.join('foo.tng'))
with pytng.TNGFile(GMX_REF_FILEPATH) as ref, pytng.TNGFile(outfile,
'w') as out:
for ts in ref:
out.write(ts.positions, ts.box, ts.time)

with pytng.TNGFile(GMX_REF_FILEPATH) as ref, pytng.TNGFile(outfile) as out:
for r, o in zip(ref, out):
np.testing.assert_almost_equal(r.positions, o.positions, decimal=4)
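
A stricter version of this round-trip check is sketched below; it is not part of the PR. It additionally compares the box and, tentatively, the time stamps, and it assumes that read() and write() agree on the time unit, which this PR does not yet verify.

# Sketch of a stricter round-trip check (not in this PR).
import numpy as np
import pytng

def check_roundtrip(ref_path, out_path):
    with pytng.TNGFile(ref_path) as ref, pytng.TNGFile(out_path) as out:
        for r, o in zip(ref, out):
            # Positions may lose precision to TNG compression, hence the
            # loose tolerance, matching the test above.
            np.testing.assert_almost_equal(r.positions, o.positions, decimal=4)
            # The box is written uncompressed, so it should match closely.
            np.testing.assert_almost_equal(r.box, o.box, decimal=4)
            # Tentative: only meaningful if both files store comparable times.
            np.testing.assert_almost_equal(r.time, o.time, decimal=6)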