diff --git a/source/pysph/solver/controller.py b/source/pysph/solver/controller.py index 7a507d9..3997ad9 100644 --- a/source/pysph/solver/controller.py +++ b/source/pysph/solver/controller.py @@ -171,6 +171,9 @@ def set_blocking(self, block): def get_blocking(self): ''' get the blocking mode ( True/False ) ''' return self.block + + def ping(self): + return True def on_root_proc(f): ''' run the decorated function only on the root proc ''' @@ -212,6 +215,7 @@ def __init__(self, solver, comm=None): logger.info('CommandManager: using comm: %s'%self.comm) self.solver = solver self.interfaces = [] + self.func_dict = {} self.rlock = threading.RLock() self.res_lock = threading.Lock() self.plock = threading.Condition() @@ -238,6 +242,11 @@ def add_interface(self, callable, block=True): thr.daemon = True thr.start() return thr + + def add_function(self, callable, interval=1): + ''' add a function to to be called every `interval` iterations ''' + l = self.func_dict[interval] = self.func_dict.get(interval, []) + l.append(callable) def execute_commands(self, solver): ''' called by the solver after each timestep ''' @@ -247,6 +256,11 @@ def execute_commands(self, solver): self.run_queued_commands() logger.info('control handler: count=%d'%solver.count) + for interval in self.func_dict: + if solver.count%interval == 0: + for func in self.func_dict[interval]: + func(solver) + self.wait_for_cmd() def wait_for_cmd(self): diff --git a/source/pysph/solver/solver_interfaces.py b/source/pysph/solver/solver_interfaces.py index f10cca3..6be1eb4 100644 --- a/source/pysph/solver/solver_interfaces.py +++ b/source/pysph/solver/solver_interfaces.py @@ -1,23 +1,24 @@ import threading import os - +import socket from SimpleXMLRPCServer import SimpleXMLRPCServer, SimpleXMLRPCRequestHandler -from multiprocessing.managers import BaseManager from SimpleHTTPServer import SimpleHTTPRequestHandler +from multiprocessing.managers import BaseManager, BaseProxy + class MultiprocessingInterface(BaseManager): - """ A multiprocessing interface to the solver command_manager + """ A multiprocessing interface to the solver controller - This object exports a command_manager instance proxy over the multiprocessing + This object exports a controller instance proxy over the multiprocessing interface. Control actions can be performed by connecting to the interface - and calling methods on the command_manager proxy instance """ + and calling methods on the controller proxy instance """ def get_controller(self): - return self.command_manager + return self.controller def start(self, controller): - self.command_manager = controller + self.controller = controller self.register('get_controller', self.get_controller) self.get_server().serve_forever() @@ -32,13 +33,28 @@ def __init__(self, address=None, authkey=None, serializer='pickle', start=True): if start: self.start() - def start(self): + def start(self, connect=True): self.interfaces = [] + + # to work around a python caching bug + # http://stackoverflow.com/questions/3649458/broken-pipe-when-using-python-multiprocessing-managers-basemanager-syncmanager + if self.address in BaseProxy._address_to_local: + del BaseProxy._address_to_local[self.address][0].connection + self.register('get_controller') - self.connect() - self.controller = self.get_controller() + if connect: + self.connect() + self.controller = self.get_controller() self.run(self.controller) + @staticmethod + def is_available(address): + try: + socket.create_connection(address, 1).close() + return True + except socket.error: + return False + def run(self, controller): pass @@ -83,7 +99,7 @@ def end_headers(self): SimpleXMLRPCRequestHandler.end_headers(self) class XMLRPCInterface(SimpleXMLRPCServer): - """ An XML-RPC interface to the solver command_manager + """ An XML-RPC interface to the solver controller Currently cannot work with objects which cannot be marshalled (which is basically most custom classes, most importantly @@ -101,7 +117,7 @@ def start(self, controller): class CommandlineInterface(object): - """ command-line interface to the solver command_manager """ + """ command-line interface to the solver controller """ def start(self, controller): while True: try: diff --git a/source/pysph/tools/mayavi_viewer.py b/source/pysph/tools/mayavi_viewer.py index 3295df1..ef8be1f 100644 --- a/source/pysph/tools/mayavi_viewer.py +++ b/source/pysph/tools/mayavi_viewer.py @@ -1,4 +1,4 @@ -"""A particle viewer using Mayavi. +a"""A particle viewer using Mayavi. This code uses the :py:class:`MultiprocessingClient` solver interface to communicate with a running solver and displays the particles using @@ -8,9 +8,10 @@ import sys import math import numpy +import socket from enthought.traits.api import (HasTraits, Instance, on_trait_change, - List, Str, Int, Range, Float, Bool, Password) + List, Str, Int, Range, Float, Bool, Password, Property) from enthought.traits.ui.api import (View, Item, Group, HSplit, ListEditor, EnumEditor, TitleEditor) from enthought.mayavi.core.api import PipelineBase @@ -23,6 +24,8 @@ from pysph.base.api import ParticleArray from pysph.solver.solver_interfaces import MultiprocessingClient +import logging +logger = logging.getLogger() def set_arrays(dataset, particle_array): """ Code to add all the arrays to a dataset given a particle array.""" @@ -97,7 +100,7 @@ class ParticleArrayHelper(HasTraits): def _particle_array_changed(self, pa): self.name = pa.name # Setup the scalars. - self.scalar_list = pa.properties.keys() + self.scalar_list = sorted(pa.properties.keys()) # Update the plot. x, y, z, u, v, w = pa.x, pa.y, pa.z, pa.u, pa.v, pa.w @@ -141,22 +144,24 @@ class MayaviViewer(HasTraits): are queried from a running solver. """ - particle_arrays = List(Instance(ParticleArrayHelper)) - pa_names = List(Str) + particle_arrays = List(Instance(ParticleArrayHelper), []) + pa_names = List(Str, []) client = Instance(MultiprocessingClient) host = Str('localhost', desc='machine to connect to') port = Int(8800, desc='port to use to connect to solver') authkey = Password('pysph', desc='authorization key') + host_changed = Bool(False) scene = Instance(MlabSceneModel, ()) + controller = Property() ######################################## # Timer traits. timer = Instance(Timer) - interval = Range(0.5, 20.0, 5.0, + interval = Range(2, 20.0, 5.0, desc='frequency in seconds with which plot is updated') ######################################## @@ -169,12 +174,12 @@ class MayaviViewer(HasTraits): # The layout of the dialog created view = View(HSplit( Group( - #Group( - # Item(name='host'), - # Item(name='port'), - # Item(name='authkey'), - # label='Connection', - # ), + Group( + Item(name='host'), + Item(name='port'), + Item(name='authkey'), + label='Connection', + ), Group( Item(name='current_time'), Item(name='iteration'), @@ -202,7 +207,7 @@ class MayaviViewer(HasTraits): ###################################################################### # `MayaviViewer` interface. - ###################################################################### + ###################################################################### @on_trait_change('scene.activated') def start_timer(self): # Just accessing the timer will start it. @@ -212,31 +217,85 @@ def start_timer(self): @on_trait_change('scene.activated') def update_plot(self): - c = self.client.controller - self.iteration = c.get_count() - self.current_time = c.get_t() + # do not update if solver is paused + if self.pause_solver: + return + controller = self.controller + if controller is None: + return + + self.current_time = controller.get_t() for idx, name in enumerate(self.pa_names): - pa = c.get_named_particle_array(name) + pa = controller.get_named_particle_array(name) self.particle_arrays[idx].particle_array = pa ###################################################################### # Private interface. ###################################################################### - def _client_default(self): - return MultiprocessingClient(address=(self.host, self.port), - authkey=self.authkey) + @on_trait_change('host,port,authkey') + def _mark_reconnect(self): + self.host_changed = True - def _pa_names_default(self): - c = self.client.controller - return c.get_particle_array_names() + def _client_default(self): + try: + if: + MultiprocessingClient.is_available((self.host, self.port)) + return MultiprocessingClient(address=(self.host, self.port), + authkey=self.authkey) + except socket.error, e: + logger.info('Could not connect: check if solver is running') + return None + + def _get_controller(self): + ''' get the controller, also sets the iteration count ''' + reconnect = self.host_changed + + if not reconnect: + try: + c = self.client.controller + self.iteration = c.get_count() + except Exception as e: + logger.info('Error: no connection or connection closed: reconnecting') + reconnect = True + self.client = None + + if reconnect: + self.host_changed = False + try: + if MultiprocessingClient.is_available((self.host, self.port)): + self.client = MultiprocessingClient(address=(self.host, self.port), + authkey=self.authkey) + else: + return None + except Exception as e: + logger.info('Could not connect: check if solver is running') + return None + c = self.client.controller + self.iteration = c.get_count() + + return self.client.controller + + def _client_changed(self, old, new): + if self.client is None: + return + else: + self.pa_names = self.client.controller.get_particle_array_names() - def _particle_arrays_default(self): - r = [ParticleArrayHelper(scene=self.scene, name=x) for x in - self.pa_names] + for pa in self.particle_arrays: + if pa.plot is not None: + pa.plot.remove() + self.particle_arrays = [ParticleArrayHelper(scene=self.scene, name=x) for x in + self.pa_names] # Turn on the legend for the first particle array. - if len(r) > 0: - r[0].show_legend = True - return r + if len(self.particle_arrays) > 0: + self.particle_arrays[0].show_legend = True + + def _timer_event(self): + # catch all Exceptions else timer will stop + try: + self.update_plot() + except Exception: + pass def _interval_changed(self, value): t = self.timer @@ -247,7 +306,7 @@ def _interval_changed(self, value): t.Start(int(value*1000)) def _timer_default(self): - return Timer(int(self.interval*1000), self.update_plot) + return Timer(int(self.interval*1000), self._timer_event) def _pause_solver_changed(self, value): c = self.client.controller