import asyncio
import os
import yaml
import io
import tarfile
import logging
from gssa.family import Family
from gssa.docker import Submitter
from gssa.parameters import convert_parameter
import gssa.error


class DockerFamily(Family):
    _retrievable_files = ['logs/job.err', 'logs/job.out']

    def __init__(self, files_required):
        self._needles = {}
        self._needle_order = {}
        self._files_required = files_required
        self._submitter = Submitter()

    # Needle index can be either a needle index (as given in the XML input) or an
    # integer n indicating the nth needle in the order of the needles XML block
    def get_needle_parameter(self, needle_index, key, try_json=True):
        if needle_index not in self._needles and needle_index in self._needle_order:
            needle_index = self._needle_order[needle_index]

        value = self.get_parameter(key, try_json, self._needles[needle_index]["parameters"])

        return value

    def get_parameter(self, key, try_json=True, parameters=None):
        if parameters is None:
            parameters = self._parameters

        if key not in parameters:
            return None

        parameter, typ = parameters[key]

        return convert_parameter(parameter, typ, try_json)
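
    # Illustrative sketch (hypothetical entry): parameter values are stored as
    # (value, type) pairs, e.g.
    #   self._parameters["SETTING_FINAL_TIMESTEP"] = ("100", "int")
    # so get_parameter("SETTING_FINAL_TIMESTEP") returns the result of
    # convert_parameter("100", "int", True) - presumably the string coerced to
    # its declared type, with a JSON decode attempted when try_json is True.
    # The key name here is an assumption for illustration only.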

    @asyncio.coroutine
    def prepare_simulation(self, working_directory):
        # Subclasses may override this; the base implementation always proceeds
        return True

    def get_percentage_socket_location(self, working_directory):
        return os.path.join(working_directory, 'update.sock')

    @asyncio.coroutine
    def simulate(self, working_directory):
        proceed = yield from self.prepare_simulation(working_directory)

        if not proceed:
            logging.warning("Prepare simulation told us not to proceed")
            return False

        update_socket = self.get_percentage_socket_location(working_directory)
        os.chmod(update_socket, 0o777)
        self._submitter.set_update_socket(update_socket)
        logging.debug("Set update socket")

        regions_yaml = os.path.join(working_directory, "input", "regions.yml")
        regions = self._regions

        with open(regions_yaml, "w") as f:
            yaml.dump(regions, f, default_flow_style=False)

        self._submitter.add_input(regions_yaml)
        logging.debug("Wrote regions.yml")

        parameters_yaml = os.path.join(working_directory, "input", "parameters.yml")
        parameters = self._parameters

        for k, v in parameters.items():
            parameters[k] = [v[1], v[0]]

        with open(parameters_yaml, "w") as f:
            yaml.dump(parameters, f, default_flow_style=False)

        self._submitter.add_input(parameters_yaml)
        logging.debug("Wrote parameters.yml")

        needle_parameters_yaml = os.path.join(working_directory, "input", "needle_parameters.yml")

        for j, w in self._needles.items():
            needle_parameters = w['parameters']

            for k, v in needle_parameters.items():
                needle_parameters[k] = [v[1], v[0]]

            self._needles[j]['index'] = j
            self._needles[j]['parameters'] = needle_parameters

        with open(needle_parameters_yaml, "w") as f:
            yaml.dump_all(self._needles.values(), f, default_flow_style=False)

        self._submitter.add_input(needle_parameters_yaml)
        logging.debug("Wrote needle_parameters.yml")

        definition_tar = os.path.join("input", "start.tar.gz")
        self._submitter.add_input(os.path.join(working_directory, definition_tar))

        # No magic script is supplied here; None is passed through to the
        # submitter below
        magic_script = None

        if self._definition is not None:
            try:
                declared_parameters, python_script = self._definition.split("\n==========ENDPARAMETERS========\n")
            except ValueError:
                files = (('start.py', self._definition),)
            else:
                files = (('start.py', python_script), ('parameters.yml', declared_parameters))

            with tarfile.open(os.path.join(working_directory, definition_tar), "w:gz") as tar:
                for name, content in files:
                    encoded_definition = content.encode('utf-8')
                    stringio = io.BytesIO(encoded_definition)
                    info = tarfile.TarInfo(name=name)
                    info.size = len(encoded_definition)
                    tar.addfile(tarinfo=info, fileobj=stringio)

            logging.debug("Created definition tarball")

        # Need to make sure this is uploaded last
        if definition_tar in self._files_required:
            del self._files_required[definition_tar]
            logging.debug("Removing definition tar from files required")

        loop = asyncio.get_event_loop()

        try:
            logging.debug("Submitting")
            outcome = yield from self._submitter.run_script(
                loop,
                working_directory,
                self._docker_image,
                self._files_required.keys(),
                magic_script
            )
            logging.debug("DONE")
        except gssa.error.ErrorException as e:
            outcome = e.get_error()
            logging.debug("Failed [%s]", e)

        return outcome
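
    # Usage sketch (hypothetical driver, old-style asyncio as used above):
    #   family = DockerFamily(files_required={})
    #   family.load_definition(xml, parameters, algorithms)
    #   loop = asyncio.get_event_loop()
    #   outcome = loop.run_until_complete(family.simulate(working_directory))
    # simulate assumes self._regions, self._parameters and self._docker_image
    # have already been populated (e.g. via load_definition) and that the
    # working directory's input/ subdirectory and update socket exist.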

    @asyncio.coroutine
    def clean(self):
        yield from self._submitter.destroy()
        self._submitter.finalize()

    def load_definition(self, xml, parameters, algorithms):
        self.load_core_definition(xml, parameters, algorithms)

    def retrieve_files(self, destination):
        for f in self._retrievable_files:
            logging.debug("{fm} -> {to}".format(fm=f, to=destination))
            logging.debug(self._submitter.copy_output(f, destination))
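
    # Illustrative sketch: with the class-level _retrievable_files above,
    # retrieve_files("/tmp/results") asks the submitter to copy logs/job.err and
    # logs/job.out out of the container into /tmp/results (the destination path
    # here is illustrative only).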

    @asyncio.coroutine
    def logs(self, only=None):
        logs = yield from self._submitter.logs(only)

        if only is not None:
            return {only: logs[only]} if only in logs else {}

        return logs
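
    # Illustrative sketch (hypothetical log name): logs(only="job.out") returns
    # {"job.out": ...} when the submitter reports that log and {} otherwise,
    # while logs() with no argument passes the submitter's full mapping through.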

    @asyncio.coroutine
    def cancel(self):
        success = yield from self._submitter.cancel()
        return success